18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/Interfaces/FunctionInterfaces.h"
22#include "mlir/Pass/PassManager.h"
24#define DEBUG_TYPE "enzyme"
28using namespace enzyme;
32#define GEN_PASS_DEF_DIFFERENTIATEPASS
33#include "Passes/Passes.h.inc"
38struct DifferentiatePass
39 :
public enzyme::impl::DifferentiatePassBase<DifferentiatePass> {
40 using DifferentiatePassBase::DifferentiatePassBase;
44 void runOnOperation()
override;
46 void getDependentDialects(DialectRegistry ®istry)
const override {
47 mlir::OpPassManager pm;
48 mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
49 if (!mlir::failed(result)) {
50 pm.getDependentDialects(registry);
53 registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
54 mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect,
55 mlir::enzyme::EnzymeDialect>();
58 static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,
60 std::vector<DIFFE_TYPE> retTypes;
61 for (
auto ty : fn.getResultTypes()) {
62 if (isa<IntegerType>(ty)) {
76 LogicalResult HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) {
77 std::vector<DIFFE_TYPE> constants;
78 SmallVector<mlir::Value, 2> args;
81 auto activityAttr = CI.getActivity();
83 for (
unsigned i = 0; i < CI.getInputs().size(); ++i) {
84 mlir::Value res = CI.getInputs()[i];
86 auto mop = activityAttr[truei];
87 auto iattr = cast<mlir::enzyme::ActivityAttr>(mop);
90 switch (iattr.getValue()) {
91 case mlir::enzyme::Activity::enzyme_active:
94 case mlir::enzyme::Activity::enzyme_dup:
97 case mlir::enzyme::Activity::enzyme_const:
100 case mlir::enzyme::Activity::enzyme_dupnoneed:
103 case mlir::enzyme::Activity::enzyme_activenoneed:
105 assert(0 &&
"unsupported arg activenoneed");
107 case mlir::enzyme::Activity::enzyme_constnoneed:
109 assert(0 &&
"unsupported arg constnoneed");
113 constants.push_back(ty);
117 res = CI.getInputs()[i];
124 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
125 auto fn = cast<FunctionOpInterface>(symbolOp);
128 std::vector<DIFFE_TYPE> retType;
130 std::vector<bool> returnPrimals;
131 for (
auto act : CI.getRetActivity()) {
132 auto iattr = cast<mlir::enzyme::ActivityAttr>(act);
133 auto val = iattr.getValue();
135 bool primalNeeded =
true;
137 case mlir::enzyme::Activity::enzyme_active:
140 case mlir::enzyme::Activity::enzyme_dup:
143 case mlir::enzyme::Activity::enzyme_const:
146 case mlir::enzyme::Activity::enzyme_dupnoneed:
148 primalNeeded =
false;
150 case mlir::enzyme::Activity::enzyme_activenoneed:
152 primalNeeded =
false;
154 case mlir::enzyme::Activity::enzyme_constnoneed:
156 primalNeeded =
false;
159 retType.push_back(ty);
160 returnPrimals.push_back(primalNeeded);
165 bool freeMemory =
true;
167 size_t width = CI.getWidth();
169 std::vector<bool> volatile_args;
170 for (
auto &a : fn.getFunctionBody().getArguments()) {
176 fn, retType, constants, TA, returnPrimals, mode, freeMemory, width,
177 nullptr, type_args, volatile_args,
178 nullptr, omp, postpasses, verifyPostPasses,
183 OpBuilder builder(CI);
184 auto dCI = func::CallOp::create(builder, CI.getLoc(), newFunc.getName(),
185 newFunc.getResultTypes(), args);
186 if (dCI.getNumResults() != CI.getNumResults()) {
187 CI.emitError() <<
"Incorrect number of results for enzyme operation: "
188 << *CI <<
" expected " << *dCI;
191 CI.replaceAllUsesWith(dCI);
196 template <
typename T>
197 LogicalResult HandleAutoDiffReverse(SymbolTableCollection &symbolTable,
200 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
201 auto fn = cast<FunctionOpInterface>(symbolOp);
203 if (CI.getActivity().size() != fn.getNumArguments()) {
204 llvm::errs() <<
"Incorrect number of argument activities on autodiff op"
205 <<
"CI: " << CI <<
", expected " << fn.getNumArguments()
206 <<
" found " << CI.getActivity().size() <<
"\n";
209 if (CI.getRetActivity().size() != fn.getNumResults()) {
210 llvm::errs() <<
"Incorrect number of result activities on autodiff op"
211 <<
"CI: " << CI <<
", expected " << fn.getNumResults()
212 <<
" found " << CI.getRetActivity().size() <<
"\n";
216 std::vector<DIFFE_TYPE> arg_activities;
217 SmallVector<mlir::Value, 2> args;
221 for (
auto act : CI.getActivity()) {
222 if (call_idx >= CI.getInputs().size()) {
223 llvm::errs() <<
"Too few arguments to autodiff op"
224 <<
"CI: " << CI <<
"\n";
227 mlir::Value res = CI.getInputs()[call_idx];
230 auto iattr = cast<mlir::enzyme::ActivityAttr>(act);
231 auto val = iattr.getValue();
234 case mlir::enzyme::Activity::enzyme_active:
237 case mlir::enzyme::Activity::enzyme_dup:
240 case mlir::enzyme::Activity::enzyme_const:
243 case mlir::enzyme::Activity::enzyme_dupnoneed:
246 case mlir::enzyme::Activity::enzyme_activenoneed:
248 assert(0 &&
"unsupported arg activenoneed");
250 case mlir::enzyme::Activity::enzyme_constnoneed:
252 assert(0 &&
"unsupported arg constnoneed");
255 arg_activities.push_back(ty);
258 if (call_idx >= CI.getInputs().size()) {
259 llvm::errs() <<
"Too few arguments to autodiff op"
260 <<
"CI: " << CI <<
"\n";
263 res = CI.getInputs()[call_idx];
272 std::vector<DIFFE_TYPE> retType;
273 std::vector<bool> returnPrimals;
274 std::vector<bool> returnShadows;
277 for (
auto act : CI.getRetActivity()) {
278 auto iattr = cast<mlir::enzyme::ActivityAttr>(act);
279 auto val = iattr.getValue();
281 bool primalNeeded =
true;
283 case mlir::enzyme::Activity::enzyme_active:
286 case mlir::enzyme::Activity::enzyme_dup:
289 case mlir::enzyme::Activity::enzyme_const:
292 case mlir::enzyme::Activity::enzyme_dupnoneed:
294 primalNeeded =
false;
296 case mlir::enzyme::Activity::enzyme_activenoneed:
298 primalNeeded =
false;
300 case mlir::enzyme::Activity::enzyme_constnoneed:
302 primalNeeded =
false;
305 retType.push_back(ty);
306 returnPrimals.push_back(primalNeeded);
307 returnShadows.push_back(
false);
309 if (call_idx >= CI.getInputs().size()) {
310 llvm::errs() <<
"Too few arguments to autodiff op"
311 <<
"CI: " << CI <<
"\n";
314 mlir::Value res = CI.getInputs()[call_idx];
322 bool freeMemory =
true;
323 size_t width = CI.getWidth();
325 std::vector<bool> volatile_args;
326 for (
auto &a : fn.getFunctionBody().getArguments()) {
331 FunctionOpInterface newFunc =
333 returnShadows, mode, freeMemory, width,
334 nullptr, type_args, volatile_args,
335 nullptr, omp, postpasses,
336 verifyPostPasses, CI.getStrongZero());
340 OpBuilder builder(CI);
342 dyn_cast<AutoDiffFunctionInterface>(newFunc.getOperation())) {
343 auto dCI = iface.createCall(builder, CI.getLoc(), args);
344 CI.replaceAllUsesWith(dCI);
346 newFunc.getOperation()->emitError()
347 <<
"this function operation does not implement "
348 "AutoDiffFunctionInterface";
355 void lowerEnzymeCalls(SymbolTableCollection &symbolTable,
356 FunctionOpInterface op) {
358 SmallVector<Operation *> toLower;
359 op->walk([&](enzyme::ForwardDiffOp dop) {
361 symbolTable.lookupNearestSymbolFrom(dop, dop.getFnAttr());
362 auto callableOp = cast<FunctionOpInterface>(symbolOp);
364 lowerEnzymeCalls(symbolTable, callableOp);
365 toLower.push_back(dop);
368 for (
auto T : toLower) {
369 if (
auto F = dyn_cast<enzyme::ForwardDiffOp>(T)) {
370 auto res = HandleAutoDiff(symbolTable, F);
371 if (!res.succeeded()) {
376 llvm_unreachable(
"Illegal type");
382 SmallVector<Operation *> toLower;
383 op->walk([&](enzyme::AutoDiffOp dop) {
385 symbolTable.lookupNearestSymbolFrom(dop, dop.getFnAttr());
386 auto callableOp = cast<FunctionOpInterface>(symbolOp);
388 lowerEnzymeCalls(symbolTable, callableOp);
389 toLower.push_back(dop);
392 for (
auto T : toLower) {
393 if (
auto F = dyn_cast<enzyme::AutoDiffOp>(T)) {
394 auto res = HandleAutoDiffReverse(symbolTable, F);
395 if (!res.succeeded()) {
400 llvm_unreachable(
"Illegal type");
409void DifferentiatePass::runOnOperation() {
410 SymbolTableCollection symbolTable;
411 symbolTable.getSymbolTable(getOperation());
412 getOperation()->walk(
413 [&](FunctionOpInterface op) { lowerEnzymeCalls(symbolTable, op); });
DIFFE_TYPE
Potential differentiable argument classifications.
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, std::vector< bool > returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
FunctionOpInterface CreateForwardDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
MFnTypeInfo getAnalyzedTypeInfo(FunctionOpInterface op) const