20 mlir::Type additionalArg,
21 const std::vector<bool> &returnPrimals,
22 const std::vector<bool> &returnShadows,
23 llvm::ArrayRef<DIFFE_TYPE> ReturnActivity,
24 llvm::ArrayRef<DIFFE_TYPE> ArgActivity) {
25 static_assert(llvm::is_one_of<T, FunctionType, LLVM::LLVMFunctionType>::value,
26 "Expected FunctionType or LLVMFunctionType");
27 SmallVector<mlir::Type, 4> RetTypes;
28 ArrayRef<Type> origInputTypes, origResultTypes;
29 if constexpr (std::is_same<T, LLVM::LLVMFunctionType>::value) {
30 origInputTypes = FTy.getParams();
31 origResultTypes = FTy.getReturnTypes();
33 origInputTypes = FTy.getInputs();
34 origResultTypes = FTy.getResults();
37 for (
auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip(
38 origResultTypes, returnPrimals, returnShadows, ReturnActivity)) {
40 RetTypes.push_back(Ty);
49 SmallVector<mlir::Type, 4> ArgTypes;
51 for (
auto &&[ITy, act] : llvm::zip(origInputTypes, ArgActivity)) {
52 ArgTypes.push_back(ITy);
60 for (
auto &&[Ty, activity] : llvm::zip(origResultTypes, ReturnActivity)) {
67 ArgTypes.push_back(additionalArg);
71 OpBuilder builder(FTy.getContext());
72 return builder.getFunctionType(ArgTypes, RetTypes);
// Clones a single operation `src`, recording value/block correspondences in
// `mapper`.
//
// The new operation copies src's location, name, result types, attributes,
// and number of regions:
// - Successor blocks are always remapped through `mapper`
//   (`lookupOrDefault`, so unmapped blocks are reused as-is).
// - Operands are remapped the same way only when
//   `options.shouldCloneOperands()` is set; otherwise the operation is
//   created with an empty operand list.
// - Region bodies are cloned recursively via cloneInto (forwarding `opMap`)
//   only when `options.shouldCloneRegions()` is set; otherwise this function
//   does not populate the new regions.
// - Each result of `src` is mapped in `mapper` to the corresponding result
//   of the new operation.
//
// NOTE(review): properties are passed as an empty `mlir::PropertyRef()`;
// presumably any inherent state is carried through `src->getAttrs()` —
// confirm for ops that store state in properties.
// NOTE(review): presumably returns the newly created operation (and may
// record it in `opMap`); confirm against the rest of the definition.
75Operation *
clone(Operation *src, IRMapping &mapper,
76 Operation::CloneOptions options,
77 std::map<Operation *, Operation *> &opMap) {
78 SmallVector<Value, 8> operands;
79 SmallVector<Block *, 2> successors;
// Remap operands through `mapper`; values with no mapping are reused as-is.
82 if (options.shouldCloneOperands()) {
83 operands.reserve(src->getNumOperands());
84 for (
auto opValue : src->getOperands())
85 operands.push_back(mapper.lookupOrDefault(opValue));
// Successors are remapped unconditionally, independent of `options`.
89 successors.reserve(src->getNumSuccessors());
90 for (Block *successor : src->getSuccessors())
91 successors.push_back(mapper.lookupOrDefault(successor));
94 Operation *newOp =
nullptr;
// Create the new operation with the same result types and attributes;
// `operands` is empty here when operand cloning was disabled.
97 SmallVector<Type> resultTypes(src->getResultTypes().begin(),
98 src->getResultTypes().end());
99 newOp = Operation::create(src->getLoc(), src->getName(), resultTypes,
100 operands, src->getAttrs(), mlir::PropertyRef(),
101 successors, src->getNumRegions());
// Recursively clone nested regions only when requested.
104 if (options.shouldCloneRegions()) {
105 for (
unsigned i = 0; i != src->getNumRegions(); ++i)
106 cloneInto(&src->getRegion(i), &newOp->getRegion(i), mapper, opMap);
// Map each original result to the matching new result so that later
// lookups through `mapper` resolve to the clone.
111 for (
unsigned i = 0, e = src->getNumResults(); i != e; ++i)
112 mapper.map(src->getResult(i), newOp->getResult(i));
// Clones the blocks of region `src` into region `dest`, inserting them, in
// source order, immediately before `destPos`.
//
// Cloning runs in three passes:
//   1. Create an empty block for every source block and map it in `mapper`.
//      Add a block argument for each source argument that is not already
//      mapped.
//   2. Clone every operation into its new block without operands or
//      regions. This records every result in `mapper` before any operand
//      is resolved.
//   3. Set each cloned operation's operands via `mapper.lookupOrDefault` and
//      recursively clone its nested regions.
// Separating operand remapping from operation creation lets a use of a value
// defined later in the region (e.g. in another block) still be remapped to
// the cloned value.
//
// `opMap` is forwarded to every nested clone/cloneInto call.
//
// Asserts that `dest` is non-null and differs from `src`.
// NOTE(review): `src->front()` is dereferenced when building the range of
// new blocks, so this path requires `src` to be non-empty — confirm an
// early exit for empty regions precedes it.
124void cloneInto(Region *src, Region *dest, Region::iterator destPos,
125 IRMapping &mapper, std::map<Operation *, Operation *> &opMap) {
127 assert(dest &&
"expected valid region to clone into");
128 assert(src != dest &&
"cannot clone region into itself");
// Pass 1: create the destination blocks and their arguments.
146 for (Block &block : *src) {
147 Block *newBlock =
new Block();
148 mapper.map(&block, newBlock);
// Arguments already present in `mapper` are not re-added, so a new block
// may end up with fewer arguments than its source block.
153 for (
auto arg : block.getArguments())
154 if (!mapper.contains(arg)) {
155 auto Ty = arg.getType();
156 mapper.map(arg, newBlock->addArgument(Ty, arg.getLoc()));
// Inserting each block at `destPos` keeps the new blocks in source order.
159 dest->getBlocks().insert(destPos, newBlock);
// The new blocks occupy [mapper.lookup(&src->front()), destPos).
162 auto newBlocksRange =
163 llvm::make_range(Region::iterator(mapper.lookup(&src->front())), destPos);
172 Operation::CloneOptions::all().cloneRegions(
false).cloneOperands(
false);
// Pass 2: clone operations without operands or regions; this maps every
// result before any operand is looked up.
173 for (
auto zippedBlocks : llvm::zip(*src, newBlocksRange)) {
174 Block &sourceBlock = std::get<0>(zippedBlocks);
175 Block &clonedBlock = std::get<1>(zippedBlocks);
177 for (Operation &op : sourceBlock) {
178 clonedBlock.push_back(
clone(&op, mapper, cloneOptions, opMap));
// Pass 3: remap operands and clone nested regions. `operands` is reused as
// scratch storage across operations.
184 SmallVector<Value> operands;
185 for (
auto zippedBlocks : llvm::zip(*src, newBlocksRange)) {
187 llvm::zip(std::get<0>(zippedBlocks), std::get<1>(zippedBlocks))) {
188 Operation &source = std::get<0>(ops);
189 Operation &
clone = std::get<1>(ops);
191 operands.resize(source.getNumOperands());
193 source.getOperands(), operands.begin(),
194 [&](Value operand) { return mapper.lookupOrDefault(operand); });
195 clone.setOperands(operands);
// Nested regions are cloned only now, after every operation in this region
// has had its results mapped.
197 for (
auto regions : llvm::zip(source.getRegions(),
clone.getRegions()))
198 cloneInto(&std::get<0>(regions), &std::get<1>(regions), mapper, opMap);
205 IRMapping &ptrInputs, ArrayRef<DIFFE_TYPE> ArgActivity,
206 SmallPtrSetImpl<mlir::Value> &constants,
207 SmallPtrSetImpl<mlir::Value> &nonconstants,
208 SmallPtrSetImpl<mlir::Value> &returnvals,
209 const std::vector<bool> &returnPrimals,
210 const std::vector<bool> &returnShadows, ArrayRef<DIFFE_TYPE> RetActivity,
211 Twine name, IRMapping &VMap, std::map<Operation *, Operation *> &OpMap,
212 mlir::Type additionalArg) {
213 assert(!F.getFunctionBody().empty());
217 if (
auto llFTy = dyn_cast<LLVM::LLVMFunctionType>(F.getFunctionType())) {
219 returnPrimals, returnShadows, RetActivity,
223 mode, width, additionalArg, returnPrimals,
224 returnShadows, RetActivity, ArgActivity);
239 auto NewF = cast<FunctionOpInterface>(F->cloneWithoutRegions());
240 SymbolTable::setSymbolName(NewF, name.str());
241 SmallVector<Type> resultTypes(FTy.getResults());
242 if (
auto iface = dyn_cast<AutoDiffFunctionInterface>(*NewF)) {
243 iface.transformResultTypes(resultTypes);
246 << F <<
"this function does not implement AutoDiffFunctionInterface";
249 NewF.setType(F.cloneTypeWith(FTy.getInputs(), resultTypes));
251 Operation *parent = F->getParentWithTrait<OpTrait::SymbolTable>();
252 SymbolTable table(parent);
254 SymbolTable::setSymbolVisibility(NewF, SymbolTable::Visibility::Private);
256 cloneInto(&F.getFunctionBody(), &NewF.getFunctionBody(), VMap, OpMap);
259 SmallVector<mlir::Attribute> allAttrs(F.getNumArguments(),
nullptr);
260 if (
auto allArgAttrs = F.getAllArgAttrs())
261 allAttrs.assign(allArgAttrs.getValue().begin(),
262 allArgAttrs.getValue().end());
264 auto &blk = NewF.getFunctionBody().front();
265 assert(F.getFunctionBody().front().getNumArguments() == ArgActivity.size());
266 for (ssize_t i = ArgActivity.size() - 1; i >= 0; i--) {
267 mlir::Value oval = F.getFunctionBody().front().getArgument(i);
269 constants.insert(oval);
271 nonconstants.insert(oval);
274 nonconstants.insert(oval);
275 mlir::Value val = blk.getArgument(i);
277 mlir::Attribute dupAttr =
nullptr;
278 if ((
size_t)i == ArgActivity.size() - 1) {
281 allAttrs.push_back(dupAttr);
283 dval = blk.insertArgument(blk.args_begin() + i + 1,
286 allAttrs.insert(allAttrs.begin() + i + 1, dupAttr);
288 ptrInputs.map(oval, dval);
291 auto retloc = blk.getTerminator()->getLoc();
292 ArrayRef<Type> resultTypes;
293 if (
auto llFTy = dyn_cast<LLVM::LLVMFunctionType>(F.getFunctionType()))
294 resultTypes = llFTy.getReturnTypes();
296 resultTypes = cast<mlir::FunctionType>(F.getFunctionType()).getResults();
298 for (
auto &&[Ty, activity] : llvm::zip(resultTypes, RetActivity)) {
301 allAttrs.push_back(
nullptr);
305 NewF.setAllArgAttrs(allAttrs);
308 std::string ToClone[] = {
309 "bufferization.writable",
314 "xla_framework.input_mapping",
315 "xla_framework.result_mapping",
318 size_t newxlacnt = 0;
320 SmallVector<mlir::Attribute> resultAttrs;
321 for (
size_t oldi = 0, end = F.getNumResults(); oldi < end; oldi++) {
322 if (returnPrimals[oldi]) {
323 resultAttrs.push_back(F.getResultAttrDict(oldi));
325 if (returnShadows[oldi]) {
326 resultAttrs.push_back(
nullptr);
329 for (
auto activity : ArgActivity) {
331 resultAttrs.push_back(
nullptr);
334 bool packedResults = resultAttrs.size() != NewF.getNumResults();
336 resultAttrs.assign(NewF.getNumResults(),
nullptr);
338 NewF.setAllResultAttrs(resultAttrs);
340 if (!packedResults) {
343 while (oldi < F.getNumResults()) {
344 if (returnPrimals[oldi]) {
345 for (
auto attrName : ToClone) {
346 auto attrNameS = StringAttr::get(F->getContext(), attrName);
347 if (
auto attr = F.getResultAttr(oldi, attrName)) {
348 if (attrName ==
"xla_framework.result_mapping") {
349 auto iattr = cast<IntegerAttr>(attr);
350 APSInt nc(iattr.getValue());
352 attr = IntegerAttr::get(F->getContext(), nc);
355 NewF.setResultAttr(newi, attrNameS, attr);
360 if (returnShadows[oldi]) {
361 for (
auto attrName : ToClone) {
362 auto attrNameS = StringAttr::get(F->getContext(), attrName);
363 if (
auto attr = F.getResultAttr(oldi, attrName)) {
364 if (attrName ==
"xla_framework.result_mapping") {
365 auto iattr = cast<IntegerAttr>(attr);
366 APSInt nc(iattr.getValue());
368 attr = IntegerAttr::get(F->getContext(), nc);
371 NewF.setResultAttr(newi, attrNameS, attr);
383 while (oldi < F.getNumArguments()) {
384 if (
auto attr = NewF.getArgAttr(newi,
"xla_framework.input_mapping")) {
385 auto iattr = cast<IntegerAttr>(attr);
386 APSInt nc(iattr.getValue());
388 attr = IntegerAttr::get(F->getContext(), nc);
390 NewF.setArgAttr(newi,
"xla_framework.input_mapping", attr);
396 for (
auto attrName : ToClone) {
397 if (
auto attr = NewF.getArgAttr(newi - 1, attrName)) {
398 if (attrName ==
"xla_framework.input_mapping") {
399 auto iattr = cast<IntegerAttr>(attr);
400 APSInt nc(iattr.getValue());
402 attr = IntegerAttr::get(F->getContext(), nc);
405 NewF.setArgAttr(newi, attrName, attr);