18#include "mlir/IR/Builders.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/Interfaces/ControlFlowInterfaces.h"
21#include "mlir/Interfaces/FunctionInterfaces.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
25#define DEBUG_TYPE "enzyme"
29#define GEN_PASS_DEF_DIFFERENTIATEWRAPPERPASS
30#include "Passes/Passes.h.inc"
36using namespace enzyme;
41 std::vector<DIFFE_TYPE> ArgActivity;
42 SmallVector<StringRef, 1> split;
43 StringRef(inp.data(), inp.size()).split(split,
',');
44 for (
auto &
str : split) {
45 if (
str ==
"enzyme_dup")
47 else if (
str ==
"enzyme_const")
49 else if (
str ==
"enzyme_dupnoneed")
51 else if (
str ==
"enzyme_active")
54 llvm::errs() <<
"unknown activity to parse, found: '" <<
str <<
"'\n";
55 assert(0 &&
" unknown constant");
62struct DifferentiateWrapperPass
63 :
public enzyme::impl::DifferentiateWrapperPassBase<
64 DifferentiateWrapperPass> {
65 using DifferentiateWrapperPassBase::DifferentiateWrapperPassBase;
67 void runOnOperation()
override {
69 SymbolTableCollection symbolTable;
70 symbolTable.getSymbolTable(getOperation());
72 Operation *symbolOp =
nullptr;
74 symbolOp = symbolTable.lookupSymbolIn<Operation *>(
75 getOperation(), StringAttr::get(getOperation()->getContext(), infn));
77 for (
auto &op : getOperation()->getRegion(0).front()) {
78 auto fn = dyn_cast<FunctionOpInterface>(symbolOp);
81 assert(symbolOp ==
nullptr);
86 llvm::errs() <<
" Could not find function '" << infn
87 <<
"' to differentiate\n";
91 auto fn = cast<FunctionOpInterface>(symbolOp);
93 std::string postpasses =
"";
94 bool verifyPostPasses =
true;
95 bool strongZero =
false;
97 std::vector<DIFFE_TYPE> ArgActivity =
100 if (ArgActivity.size() != fn.getNumArguments()) {
102 <<
"Incorrect number of arg activity states for function, found "
103 << ArgActivity.size() <<
" expected "
104 << fn.getFunctionBody().front().getNumArguments();
108 std::vector<DIFFE_TYPE> RetActivity =
110 if (RetActivity.size() != fn.getNumResults()) {
112 <<
"Incorrect number of ret activity states for function, found "
113 << RetActivity.size() <<
" expected " << fn.getNumResults();
116 std::vector<bool> returnPrimal;
117 std::vector<bool> returnShadow;
118 for (
auto act : RetActivity) {
120 returnShadow.push_back(
false);
126 bool freeMemory =
true;
129 std::vector<bool> volatile_args;
130 for (
auto &a : fn.getFunctionBody().getArguments()) {
135 FunctionOpInterface newFunc;
138 fn, RetActivity, ArgActivity, TA, returnPrimal, mode, freeMemory,
140 nullptr, type_args, volatile_args,
141 nullptr, omp, postpasses, verifyPostPasses, strongZero);
144 fn, RetActivity, ArgActivity, TA, returnPrimal, returnShadow, mode,
146 nullptr, type_args, volatile_args,
147 nullptr, omp, postpasses, verifyPostPasses, strongZero);
155 SymbolTable::setSymbolVisibility(newFunc,
156 SymbolTable::Visibility::Public);
157 SymbolTable::setSymbolName(cast<FunctionOpInterface>(newFunc),
160 SymbolTable::setSymbolName(cast<FunctionOpInterface>(newFunc),
std::vector< DIFFE_TYPE > parseActivityString(StringRef inp)
static std::string str(AugmentedStruct c)
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, std::vector< bool > returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
FunctionOpInterface CreateForwardDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
MFnTypeInfo getAnalyzedTypeInfo(FunctionOpInterface op) const