27 if (oBB->getNumSuccessors() == 0) {
28 Operation *returnStatement = newBB->getTerminator();
29 gutils->
erase(returnStatement);
31 OpBuilder forwardToBackwardBuilder(newBB, newBB->end());
33 Operation *newBranchOp = cf::BranchOp::create(
34 forwardToBackwardBuilder, oBB->getTerminator()->getLoc(), reverseBB);
55 if ((op->getBlock()->getTerminator() != op) &&
isFullyInactive(op, gutils)) {
58 if (
auto ifaceOp = dyn_cast<ReverseAutoDiffOpInterface>(op)) {
59 SmallVector<Value> caches = ifaceOp.cacheValues(gutils);
61 ifaceOp.createShadowValues(augmentBuilder, gutils);
62 return ifaceOp.createReverseModeAdjoint(builder, gutils, caches);
64 op->emitError() <<
"could not compute the adjoint for this operation " << *op;
85 llvm::function_ref<buildReturnFunction> buildReturnOp) {
86 OpBuilder revBuilder(reverseBB, reverseBB->end());
87 if (oBB->hasNoPredecessors()) {
88 buildReturnOp(revBuilder, oBB);
90 Location loc = oBB->rbegin()->getLoc();
96 enzyme::PopOp::create(revBuilder, loc, gutils->
getIndexType(), cache);
98 Block *defaultBlock =
nullptr;
100 SmallVector<Block *> blocks;
101 SmallVector<APInt> indices;
103 OpBuilder newBuilder(newBB, newBB->begin());
105 SmallVector<Value, 1> diffes;
106 for (
auto arg : oBB->getArguments()) {
108 !cast<AutoDiffTypeInterface>(arg.getType()).isMutable()) {
109 diffes.push_back(gutils->
diffe(arg, revBuilder));
113 diffes.push_back(
nullptr);
116 for (
auto [idx, pred] : llvm::enumerate(oBB->getPredecessors())) {
121 OpBuilder predecessorBuilder(newPred->getTerminator());
124 arith::ConstantIntOp::create(predecessorBuilder, loc, idx - 1, 32);
125 enzyme::PushOp::create(predecessorBuilder, loc, cache, pred_idx_c);
128 defaultBlock = reversePred;
131 indices.push_back(APInt(32, idx - 1));
132 blocks.push_back(reversePred);
135 auto term = pred->getTerminator();
136 if (
auto iface = dyn_cast<BranchOpInterface>(term)) {
137 for (
auto &op : term->getOpOperands())
139 iface.getSuccessorBlockArgument(op.getOperandNumber()))
141 (*blk_idx).getOwner() == oBB) {
142 auto idx = (*blk_idx).getArgNumber();
146 arith::ConstantIntOp::create(revBuilder, loc, idx - 1, 32);
148 auto to_prop = arith::SelectOp::create(
150 arith::CmpIOp::create(revBuilder, loc,
151 arith::CmpIPredicate::eq, flag,
154 cast<AutoDiffTypeInterface>(diffes[idx].getType())
155 .createNullValue(revBuilder, loc));
157 gutils->
addToDiffe(op.get(), to_prop, revBuilder);
161 assert(0 &&
"predecessor did not implement branch op interface");
165 cf::SwitchOp::create(revBuilder, loc, flag, defaultBlock, ArrayRef<Value>(),
166 ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks),
167 SmallVector<ValueRange>(indices.size(), ValueRange()));
173 llvm::function_ref<buildReturnFunction> buildFuncReturnOp,
174 std::function<std::pair<Value, Value>(Type)> cacheCreator) {
176 auto scope = llvm::make_scope_exit(
182 for (
auto &oBB : oldRegion) {
189 return success(valid);
193 FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
195 std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
196 DerivativeMode mode,
bool freeMemory,
size_t width, mlir::Type addedType,
197 MFnTypeInfo type_args, std::vector<bool> volatile_args,
void *augmented,
198 bool omp, llvm::StringRef postpasses,
bool verifyPostPasses,
201 if (fn.getFunctionBody().empty()) {
202 llvm::errs() << fn <<
"\n";
203 llvm_unreachable(
"Differentiating empty function");
213 static_cast<unsigned>(width),
222 return cachedFn->second;
225 SmallVector<bool> returnPrimalsP(returnPrimals.begin(), returnPrimals.end());
226 SmallVector<bool> returnShadowsP(returnShadows.begin(), returnShadows.end());
229 *
this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
230 retType, constants, addedType, omp, postpasses, verifyPostPasses,
235 Region &oldRegion = gutils->
oldFunc.getFunctionBody();
236 Region &newRegion = gutils->
newFunc.getFunctionBody();
238 auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) {
239 SmallVector<mlir::Value> retargs;
240 for (
auto [arg, returnPrimal] :
241 llvm::zip(oBB->getTerminator()->getOperands(), returnPrimals)) {
246 for (
auto [arg, cv] : llvm::zip(oBB->getArguments(), constants)) {
248 retargs.push_back(gutils->
diffe(arg, builder));
252 Location loc = oBB->rbegin()->getLoc();
253 if (
auto iface = dyn_cast<enzyme::AutoDiffFunctionInterface>(*fn))
254 iface.createReturn(builder, loc, retargs);
256 fn->emitError() <<
"this function operation does not implement "
257 "AutoDiffFunctionInterface";
264 differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp,
nullptr);
270 if (!res.succeeded())
273 if (postpasses !=
"") {
274 mlir::PassManager pm(nf->getContext());
275 pm.enableVerifier(verifyPostPasses);
276 std::string error_message;
278 mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
279 if (mlir::failed(result)) {
283 if (!mlir::succeeded(pm.run(nf))) {
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, std::vector< bool > returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef< bool > returnPrimals, const ArrayRef< bool > returnShadows, llvm::ArrayRef< DIFFE_TYPE > retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)