28 const ArrayRef<bool> returnPrimals,
29 const ArrayRef<bool> returnShadows) {
30 auto inst = oBB->getTerminator();
34 auto newInst = nBB->getTerminator();
36 OpBuilder nBuilder(inst);
37 nBuilder.setInsertionPointToEnd(nBB);
39 if (
auto binst = dyn_cast<BranchOpInterface>(inst)) {
45 if (!inst->hasTrait<OpTrait::ReturnLike>())
48 SmallVector<mlir::Value, 2> retargs;
50 for (
auto &&[ret, returnPrimal, returnShadow] :
51 llvm::zip(inst->getOperands(), returnPrimals, returnShadows)) {
59 Type retTy = cast<AutoDiffTypeInterface>(ret.getType()).getShadowType();
60 auto toret = cast<AutoDiffTypeInterface>(retTy).createNullValue(
61 nBuilder, ret.getLoc());
62 retargs.push_back(toret);
67 nBB->push_back(newInst->create(newInst->getLoc(), newInst->getName(),
68 TypeRange(), retargs, newInst->getAttrs(),
69 mlir::PropertyRef(), newInst->getSuccessors(),
70 newInst->getNumRegions()));
71 gutils->
erase(newInst);
79 FunctionOpInterface fn, std::vector<DIFFE_TYPE> RetActivity,
81 std::vector<bool> returnPrimals,
DerivativeMode mode,
bool freeMemory,
82 size_t width, mlir::Type addedType,
MFnTypeInfo type_args,
83 std::vector<bool> volatile_args,
void *augmented,
bool omp,
84 llvm::StringRef postpasses,
bool verifyPostPasses,
bool strongZero) {
85 if (fn.getFunctionBody().empty()) {
86 llvm::errs() << fn <<
"\n";
87 llvm_unreachable(
"Differentiating empty function");
89 assert(fn.getFunctionBody().front().getNumArguments() == ArgActivity.size());
90 assert(fn.getFunctionBody().front().getNumArguments() ==
91 volatile_args.size());
94 fn, RetActivity, ArgActivity,
97 returnPrimals, mode,
static_cast<unsigned>(width), addedType, type_args,
103 std::vector<bool> returnShadows;
104 for (
auto act : RetActivity) {
107 SmallVector<bool> returnPrimalsP(returnPrimals.begin(), returnPrimals.end());
108 SmallVector<bool> returnShadowsP(returnShadows.begin(), returnShadows.end());
110 *
this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
111 RetActivity, ArgActivity, addedType,
112 false, postpasses, verifyPostPasses, strongZero);
120 const SmallPtrSet<mlir::Block *, 4> guaranteedUnreachable;
124 gutils->forceAugmentedReturns();
150 for (Block &oBB : gutils->oldFunc.getFunctionBody().getBlocks()) {
152 if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) {
153 auto newBB = gutils->getNewFromOriginal(&oBB);
155 for (
auto &I : make_early_inc_range(reverse(oBB))) {
156 gutils->eraseIfUnused(&I,
true,
false);
159 OpBuilder builder(gutils->oldFunc.getContext());
160 builder.setInsertionPointToEnd(newBB);
161 LLVM::UnreachableOp::create(builder, gutils->oldFunc.getLoc());
165 assert(oBB.getTerminator());
167 auto first = oBB.begin();
168 auto last = oBB.empty() ? oBB.end() : std::prev(oBB.end());
169 for (
auto it = first; it != last; ++it) {
171 auto res = gutils->visitChild(&*it);
172 valid &= res.succeeded();
196 auto nf = gutils->newFunc;
202 if (postpasses !=
"") {
203 mlir::PassManager pm(nf->getContext());
204 pm.enableVerifier(verifyPostPasses);
205 std::string error_message;
207 mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
208 if (mlir::failed(result)) {
212 if (!mlir::succeeded(pm.run(nf))) {
static MDiffeGradientUtils * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, ArrayRef< DIFFE_TYPE > RetActivity, ArrayRef< DIFFE_TYPE > ArgActivity, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
FunctionOpInterface CreateForwardDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)