Enzyme main
Loading...
Searching...
No Matches
CloneFunction.cpp
Go to the documentation of this file.
1#include "llvm/ADT/APSInt.h"
2
3#include "mlir/IR/BuiltinTypes.h"
4
5#include "CloneFunction.h"
6
7using namespace mlir;
8using namespace mlir::enzyme;
9
10Type getShadowType(Type type, unsigned width) {
11 if (auto iface = dyn_cast<AutoDiffTypeInterface>(type))
12 return iface.getShadowType(width);
13 llvm::errs() << " type does not have autodifftypeinterface: " << type << "\n";
14 exit(1);
15}
16
17template <typename T>
18mlir::FunctionType
19getFunctionTypeForClone(T FTy, DerivativeMode mode, unsigned width,
20 mlir::Type additionalArg,
21 const std::vector<bool> &returnPrimals,
22 const std::vector<bool> &returnShadows,
23 llvm::ArrayRef<DIFFE_TYPE> ReturnActivity,
24 llvm::ArrayRef<DIFFE_TYPE> ArgActivity) {
25 static_assert(llvm::is_one_of<T, FunctionType, LLVM::LLVMFunctionType>::value,
26 "Expected FunctionType or LLVMFunctionType");
27 SmallVector<mlir::Type, 4> RetTypes;
28 ArrayRef<Type> origInputTypes, origResultTypes;
29 if constexpr (std::is_same<T, LLVM::LLVMFunctionType>::value) {
30 origInputTypes = FTy.getParams();
31 origResultTypes = FTy.getReturnTypes();
32 } else {
33 origInputTypes = FTy.getInputs();
34 origResultTypes = FTy.getResults();
35 }
36
37 for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip(
38 origResultTypes, returnPrimals, returnShadows, ReturnActivity)) {
39 if (returnPrimal) {
40 RetTypes.push_back(Ty);
41 }
42 if (returnShadow) {
43 assert(activity != DIFFE_TYPE::CONSTANT);
44 assert(activity != DIFFE_TYPE::OUT_DIFF);
45 RetTypes.push_back(getShadowType(Ty, width));
46 }
47 }
48
49 SmallVector<mlir::Type, 4> ArgTypes;
50
51 for (auto &&[ITy, act] : llvm::zip(origInputTypes, ArgActivity)) {
52 ArgTypes.push_back(ITy);
53 if (act == DIFFE_TYPE::DUP_ARG || act == DIFFE_TYPE::DUP_NONEED) {
54 ArgTypes.push_back(getShadowType(ITy, width));
55 } else if (act == DIFFE_TYPE::OUT_DIFF) {
56 RetTypes.push_back(getShadowType(ITy, width));
57 }
58 }
59
60 for (auto &&[Ty, activity] : llvm::zip(origResultTypes, ReturnActivity)) {
61 if (activity == DIFFE_TYPE::OUT_DIFF) {
62 ArgTypes.push_back(getShadowType(Ty, width));
63 }
64 }
65
66 if (additionalArg) {
67 ArgTypes.push_back(additionalArg);
68 }
69
70 // Create a new function type...
71 OpBuilder builder(FTy.getContext());
72 return builder.getFunctionType(ArgTypes, RetTypes);
73}
74
75Operation *clone(Operation *src, IRMapping &mapper,
76 Operation::CloneOptions options,
77 std::map<Operation *, Operation *> &opMap) {
78 SmallVector<Value, 8> operands;
79 SmallVector<Block *, 2> successors;
80
81 // Remap the operands.
82 if (options.shouldCloneOperands()) {
83 operands.reserve(src->getNumOperands());
84 for (auto opValue : src->getOperands())
85 operands.push_back(mapper.lookupOrDefault(opValue));
86 }
87
88 // Remap the successors.
89 successors.reserve(src->getNumSuccessors());
90 for (Block *successor : src->getSuccessors())
91 successors.push_back(mapper.lookupOrDefault(successor));
92
93 // Create the new operation.
94 Operation *newOp = nullptr;
95
96 if (!newOp) {
97 SmallVector<Type> resultTypes(src->getResultTypes().begin(),
98 src->getResultTypes().end());
99 newOp = Operation::create(src->getLoc(), src->getName(), resultTypes,
100 operands, src->getAttrs(), mlir::PropertyRef(),
101 successors, src->getNumRegions());
102
103 // Clone the regions.
104 if (options.shouldCloneRegions()) {
105 for (unsigned i = 0; i != src->getNumRegions(); ++i)
106 cloneInto(&src->getRegion(i), &newOp->getRegion(i), mapper, opMap);
107 }
108 }
109
110 // Remember the mapping of any results.
111 for (unsigned i = 0, e = src->getNumResults(); i != e; ++i)
112 mapper.map(src->getResult(i), newOp->getResult(i));
113
114 opMap[src] = newOp;
115 return newOp;
116}
117
118void cloneInto(Region *src, Region *dest, IRMapping &mapper,
119 std::map<Operation *, Operation *> &opMap) {
120 cloneInto(src, dest, dest->end(), mapper, opMap);
121}
122
123/// Clone this region into 'dest' before the given position in 'dest'.
124void cloneInto(Region *src, Region *dest, Region::iterator destPos,
125 IRMapping &mapper, std::map<Operation *, Operation *> &opMap) {
126 assert(src);
127 assert(dest && "expected valid region to clone into");
128 assert(src != dest && "cannot clone region into itself");
129
130 // If the list is empty there is nothing to clone.
131 if (src->empty())
132 return;
133
134 // The below clone implementation takes special care to be read only for the
135 // sake of multi threading. That essentially means not adding any uses to any
136 // of the blocks or operation results contained within this region as that
137 // would lead to a write in their use-def list. This is unavoidable for
138 // 'Value's from outside the region however, in which case it is not read
139 // only. Using the BlockAndValueMapper it is possible to remap such 'Value's
140 // to ones owned by the calling thread however, making it read only once
141 // again.
142
143 // First clone all the blocks and block arguments and map them, but don't yet
144 // clone the operations, as they may otherwise add a use to a block that has
145 // not yet been mapped
146 for (Block &block : *src) {
147 Block *newBlock = new Block();
148 mapper.map(&block, newBlock);
149
150 // Clone the block arguments. The user might be deleting arguments to the
151 // block by specifying them in the mapper. If so, we don't add the
152 // argument to the cloned block.
153 for (auto arg : block.getArguments())
154 if (!mapper.contains(arg)) {
155 auto Ty = arg.getType();
156 mapper.map(arg, newBlock->addArgument(Ty, arg.getLoc()));
157 }
158
159 dest->getBlocks().insert(destPos, newBlock);
160 }
161
162 auto newBlocksRange =
163 llvm::make_range(Region::iterator(mapper.lookup(&src->front())), destPos);
164
165 // Now follow up with creating the operations, but don't yet clone their
166 // regions, nor set their operands. Setting the successors is safe as all have
167 // already been mapped. We are essentially just creating the operation results
168 // to be able to map them.
169 // Cloning the operands and region as well would lead to uses of operations
170 // not yet mapped.
171 auto cloneOptions =
172 Operation::CloneOptions::all().cloneRegions(false).cloneOperands(false);
173 for (auto zippedBlocks : llvm::zip(*src, newBlocksRange)) {
174 Block &sourceBlock = std::get<0>(zippedBlocks);
175 Block &clonedBlock = std::get<1>(zippedBlocks);
176 // Clone and remap the operations within this block.
177 for (Operation &op : sourceBlock) {
178 clonedBlock.push_back(clone(&op, mapper, cloneOptions, opMap));
179 }
180 }
181
182 // Finally now that all operation results have been mapped, set the operands
183 // and clone the regions.
184 SmallVector<Value> operands;
185 for (auto zippedBlocks : llvm::zip(*src, newBlocksRange)) {
186 for (auto ops :
187 llvm::zip(std::get<0>(zippedBlocks), std::get<1>(zippedBlocks))) {
188 Operation &source = std::get<0>(ops);
189 Operation &clone = std::get<1>(ops);
190
191 operands.resize(source.getNumOperands());
192 llvm::transform(
193 source.getOperands(), operands.begin(),
194 [&](Value operand) { return mapper.lookupOrDefault(operand); });
195 clone.setOperands(operands);
196
197 for (auto regions : llvm::zip(source.getRegions(), clone.getRegions()))
198 cloneInto(&std::get<0>(regions), &std::get<1>(regions), mapper, opMap);
199 }
200 }
201}
202
203FunctionOpInterface CloneFunctionWithReturns(
204 DerivativeMode mode, unsigned width, FunctionOpInterface F,
205 IRMapping &ptrInputs, ArrayRef<DIFFE_TYPE> ArgActivity,
206 SmallPtrSetImpl<mlir::Value> &constants,
207 SmallPtrSetImpl<mlir::Value> &nonconstants,
208 SmallPtrSetImpl<mlir::Value> &returnvals,
209 const std::vector<bool> &returnPrimals,
210 const std::vector<bool> &returnShadows, ArrayRef<DIFFE_TYPE> RetActivity,
211 Twine name, IRMapping &VMap, std::map<Operation *, Operation *> &OpMap,
212 mlir::Type additionalArg) {
213 assert(!F.getFunctionBody().empty());
214 // F = preprocessForClone(F, mode);
215 // llvm::ValueToValueMapTy VMap;
216 FunctionType FTy;
217 if (auto llFTy = dyn_cast<LLVM::LLVMFunctionType>(F.getFunctionType())) {
218 FTy = getFunctionTypeForClone(llFTy, mode, width, additionalArg,
219 returnPrimals, returnShadows, RetActivity,
220 ArgActivity);
221 } else {
222 FTy = getFunctionTypeForClone(cast<mlir::FunctionType>(F.getFunctionType()),
223 mode, width, additionalArg, returnPrimals,
224 returnShadows, RetActivity, ArgActivity);
225 }
226
227 /*
228 for (Block &BB : F.getFunctionBody().getBlocks()) {
229 if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator())) {
230 if (auto rv = ri->getReturnValue()) {
231 returnvals.insert(rv);
232 }
233 }
234 }
235 */
236
237 // Create the new function. This needs to go through the raw Operation API
238 // instead of a concrete builder for genericity.
239 auto NewF = cast<FunctionOpInterface>(F->cloneWithoutRegions());
240 SymbolTable::setSymbolName(NewF, name.str());
241 SmallVector<Type> resultTypes(FTy.getResults());
242 if (auto iface = dyn_cast<AutoDiffFunctionInterface>(*NewF)) {
243 iface.transformResultTypes(resultTypes);
244 } else {
245 llvm::errs()
246 << F << "this function does not implement AutoDiffFunctionInterface";
247 return nullptr;
248 }
249 NewF.setType(F.cloneTypeWith(FTy.getInputs(), resultTypes));
250
251 Operation *parent = F->getParentWithTrait<OpTrait::SymbolTable>();
252 SymbolTable table(parent);
253 table.insert(NewF);
254 SymbolTable::setSymbolVisibility(NewF, SymbolTable::Visibility::Private);
255
256 cloneInto(&F.getFunctionBody(), &NewF.getFunctionBody(), VMap, OpMap);
257
258 {
259 SmallVector<mlir::Attribute> allAttrs(F.getNumArguments(), nullptr);
260 if (auto allArgAttrs = F.getAllArgAttrs())
261 allAttrs.assign(allArgAttrs.getValue().begin(),
262 allArgAttrs.getValue().end());
263
264 auto &blk = NewF.getFunctionBody().front();
265 assert(F.getFunctionBody().front().getNumArguments() == ArgActivity.size());
266 for (ssize_t i = ArgActivity.size() - 1; i >= 0; i--) {
267 mlir::Value oval = F.getFunctionBody().front().getArgument(i);
268 if (ArgActivity[i] == DIFFE_TYPE::CONSTANT)
269 constants.insert(oval);
270 else if (ArgActivity[i] == DIFFE_TYPE::OUT_DIFF)
271 nonconstants.insert(oval);
272 else if (ArgActivity[i] == DIFFE_TYPE::DUP_ARG ||
273 ArgActivity[i] == DIFFE_TYPE::DUP_NONEED) {
274 nonconstants.insert(oval);
275 mlir::Value val = blk.getArgument(i);
276 mlir::Value dval;
277 mlir::Attribute dupAttr = nullptr;
278 if ((size_t)i == ArgActivity.size() - 1) {
279 dval = blk.addArgument(getShadowType(val.getType(), width),
280 val.getLoc());
281 allAttrs.push_back(dupAttr);
282 } else {
283 dval = blk.insertArgument(blk.args_begin() + i + 1,
284 getShadowType(val.getType(), width),
285 val.getLoc());
286 allAttrs.insert(allAttrs.begin() + i + 1, dupAttr);
287 }
288 ptrInputs.map(oval, dval);
289 }
290 }
291 auto retloc = blk.getTerminator()->getLoc();
292 ArrayRef<Type> resultTypes;
293 if (auto llFTy = dyn_cast<LLVM::LLVMFunctionType>(F.getFunctionType()))
294 resultTypes = llFTy.getReturnTypes();
295 else
296 resultTypes = cast<mlir::FunctionType>(F.getFunctionType()).getResults();
297
298 for (auto &&[Ty, activity] : llvm::zip(resultTypes, RetActivity)) {
299 if (activity == DIFFE_TYPE::OUT_DIFF) {
300 blk.addArgument(getShadowType(Ty, width), retloc);
301 allAttrs.push_back(nullptr);
302 }
303 }
304
305 NewF.setAllArgAttrs(allAttrs);
306 }
307
308 std::string ToClone[] = {
309 "bufferization.writable",
310 "mhlo.sharding",
311 "sdy.sharding",
312 "mhlo.layout_mode",
313 "tt.divisibility",
314 "xla_framework.input_mapping",
315 "xla_framework.result_mapping",
316 };
317
318 size_t newxlacnt = 0;
319 {
320 SmallVector<mlir::Attribute> resultAttrs;
321 for (size_t oldi = 0, end = F.getNumResults(); oldi < end; oldi++) {
322 if (returnPrimals[oldi]) {
323 resultAttrs.push_back(F.getResultAttrDict(oldi));
324 }
325 if (returnShadows[oldi]) {
326 resultAttrs.push_back(nullptr);
327 }
328 }
329 for (auto activity : ArgActivity) {
330 if (activity == DIFFE_TYPE::OUT_DIFF)
331 resultAttrs.push_back(nullptr);
332 }
333
334 bool packedResults = resultAttrs.size() != NewF.getNumResults();
335 if (packedResults)
336 resultAttrs.assign(NewF.getNumResults(), nullptr);
337
338 NewF.setAllResultAttrs(resultAttrs);
339
340 if (!packedResults) {
341 size_t oldi = 0;
342 size_t newi = 0;
343 while (oldi < F.getNumResults()) {
344 if (returnPrimals[oldi]) {
345 for (auto attrName : ToClone) {
346 auto attrNameS = StringAttr::get(F->getContext(), attrName);
347 if (auto attr = F.getResultAttr(oldi, attrName)) {
348 if (attrName == "xla_framework.result_mapping") {
349 auto iattr = cast<IntegerAttr>(attr);
350 APSInt nc(iattr.getValue());
351 nc = newxlacnt;
352 attr = IntegerAttr::get(F->getContext(), nc);
353 newxlacnt++;
354 }
355 NewF.setResultAttr(newi, attrNameS, attr);
356 }
357 }
358 newi++;
359 }
360 if (returnShadows[oldi]) {
361 for (auto attrName : ToClone) {
362 auto attrNameS = StringAttr::get(F->getContext(), attrName);
363 if (auto attr = F.getResultAttr(oldi, attrName)) {
364 if (attrName == "xla_framework.result_mapping") {
365 auto iattr = cast<IntegerAttr>(attr);
366 APSInt nc(iattr.getValue());
367 nc = newxlacnt;
368 attr = IntegerAttr::get(F->getContext(), nc);
369 newxlacnt++;
370 }
371 NewF.setResultAttr(newi, attrNameS, attr);
372 }
373 }
374 newi++;
375 }
376 oldi++;
377 }
378 }
379 }
380 {
381 size_t oldi = 0;
382 size_t newi = 0;
383 while (oldi < F.getNumArguments()) {
384 if (auto attr = NewF.getArgAttr(newi, "xla_framework.input_mapping")) {
385 auto iattr = cast<IntegerAttr>(attr);
386 APSInt nc(iattr.getValue());
387 nc = newxlacnt;
388 attr = IntegerAttr::get(F->getContext(), nc);
389 newxlacnt++;
390 NewF.setArgAttr(newi, "xla_framework.input_mapping", attr);
391 }
392
393 newi++;
394 if (ArgActivity[oldi] == DIFFE_TYPE::DUP_ARG ||
395 ArgActivity[oldi] == DIFFE_TYPE::DUP_NONEED) {
396 for (auto attrName : ToClone) {
397 if (auto attr = NewF.getArgAttr(newi - 1, attrName)) {
398 if (attrName == "xla_framework.input_mapping") {
399 auto iattr = cast<IntegerAttr>(attr);
400 APSInt nc(iattr.getValue());
401 nc = newxlacnt;
402 attr = IntegerAttr::get(F->getContext(), nc);
403 newxlacnt++;
404 }
405 NewF.setArgAttr(newi, attrName, attr);
406 }
407 }
408 newi++;
409 }
410 oldi++;
411 }
412 }
413
414 return NewF;
415}
Type getShadowType(Type type, unsigned width)
void cloneInto(Region *src, Region *dest, IRMapping &mapper, std::map< Operation *, Operation * > &opMap)
mlir::FunctionType getFunctionTypeForClone(T FTy, DerivativeMode mode, unsigned width, mlir::Type additionalArg, const std::vector< bool > &returnPrimals, const std::vector< bool > &returnShadows, llvm::ArrayRef< DIFFE_TYPE > ReturnActivity, llvm::ArrayRef< DIFFE_TYPE > ArgActivity)
FunctionOpInterface CloneFunctionWithReturns(DerivativeMode mode, unsigned width, FunctionOpInterface F, IRMapping &ptrInputs, ArrayRef< DIFFE_TYPE > ArgActivity, SmallPtrSetImpl< mlir::Value > &constants, SmallPtrSetImpl< mlir::Value > &nonconstants, SmallPtrSetImpl< mlir::Value > &returnvals, const std::vector< bool > &returnPrimals, const std::vector< bool > &returnShadows, ArrayRef< DIFFE_TYPE > RetActivity, Twine name, IRMapping &VMap, std::map< Operation *, Operation * > &OpMap, mlir::Type additionalArg)
Operation * clone(Operation *src, IRMapping &mapper, Operation::CloneOptions options, std::map< Operation *, Operation * > &opMap)
DerivativeMode
Definition Utils.h:390