Enzyme main
Loading...
Searching...
No Matches
EnzymeMLIRPass.cpp
Go to the documentation of this file.
1//===- EnzymeMLIRPass.cpp - Replace calls with their derivatives ------------ //
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass that replaces enzyme ForwardDiffOp and AutoDiffOp
10// operations with calls to the derivative functions generated for their callees
11//===----------------------------------------------------------------------===//
12
13#include "Dialect/Ops.h"
15#include "PassDetails.h"
16#include "Passes/Passes.h"
17
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/Interfaces/FunctionInterfaces.h"
22#include "mlir/Pass/PassManager.h"
23
24#define DEBUG_TYPE "enzyme"
25
26using namespace mlir;
27using namespace mlir::enzyme;
28using namespace enzyme;
29
30namespace mlir {
31namespace enzyme {
32#define GEN_PASS_DEF_DIFFERENTIATEPASS
33#include "Passes/Passes.h.inc"
34} // namespace enzyme
35} // namespace mlir
36
37namespace {
38struct DifferentiatePass
39 : public enzyme::impl::DifferentiatePassBase<DifferentiatePass> {
40 using DifferentiatePassBase::DifferentiatePassBase;
41
42 MEnzymeLogic Logic;
43
44 void runOnOperation() override;
45
46 void getDependentDialects(DialectRegistry &registry) const override {
47 mlir::OpPassManager pm;
48 mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
49 if (!mlir::failed(result)) {
50 pm.getDependentDialects(registry);
51 }
52
53 registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
54 mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect,
55 mlir::enzyme::EnzymeDialect>();
56 }
57
/// Builds one DIFFE_TYPE entry per result type of `fn`.
///
/// Integer results are classified as CONSTANT. Every other result is
/// classified as OUT_DIFF or DUP_ARG depending on a condition.
///
/// NOTE(review): the `if` that pairs with the `else` below (listing line 67)
/// is not visible in this extract -- presumably it tests `mode`; confirm
/// against the upstream source before relying on this description.
58 static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,
59 DerivativeMode mode) {
60 std::vector<DIFFE_TYPE> retTypes;
61 for (auto ty : fn.getResultTypes()) {
// Integer-typed results are never differentiated.
62 if (isa<IntegerType>(ty)) {
63 retTypes.push_back(DIFFE_TYPE::CONSTANT);
64 continue;
65 }
66
// NOTE(review): the condition selecting between the two branches is missing
// from this listing.
68 retTypes.push_back(DIFFE_TYPE::OUT_DIFF);
69 else
70 retTypes.push_back(DIFFE_TYPE::DUP_ARG);
71 }
72 return retTypes;
73 }
74
/// Lowers the differentiation op `CI` into a direct call, using
/// DerivativeMode::ForwardMode.
///
/// Steps, as visible below:
///  1. Walk the op's inputs together with its activity attributes to build
///     the per-argument activity list (`constants`) and the call operands
///     (`args`).
///  2. Walk the return activities to build `retType` / `returnPrimals`.
///  3. Resolve the referenced function and ask `Logic.CreateForwardDiff` for
///     its derivative.
///  4. Create a `func::CallOp` to the derivative, replace all uses of `CI`
///     with it and erase `CI`.
///
/// \param symbolTable symbol table collection used to resolve `CI.getFnAttr()`.
/// \param CI the differentiation op to lower (erased on success).
/// \returns failure() when no derivative function is produced or when the
///          derivative call's result count differs from `CI`'s; success()
///          otherwise.
75 template <typename T>
76 LogicalResult HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) {
77 std::vector<DIFFE_TYPE> constants;
78 SmallVector<mlir::Value, 2> args;
79
// `truei` indexes the activity attributes while `i` indexes the inputs; the
// two diverge because a duplicated argument consumes two inputs.
80 size_t truei = 0;
81 auto activityAttr = CI.getActivity();
82
83 for (unsigned i = 0; i < CI.getInputs().size(); ++i) {
84 mlir::Value res = CI.getInputs()[i];
85
// NOTE(review): `truei` is not checked against the number of activity
// attributes before indexing -- confirm the op verifier guarantees this.
86 auto mop = activityAttr[truei];
87 auto iattr = cast<mlir::enzyme::ActivityAttr>(mop);
88 DIFFE_TYPE ty;
89
// NOTE(review): the statements assigning `ty` in each case (listing lines
// 92, 95, 98, 101, 104, 108) are not visible in this extract; only the
// `break`s and the two "unsupported" asserts survive. Confirm against the
// upstream source.
90 switch (iattr.getValue()) {
91 case mlir::enzyme::Activity::enzyme_active:
93 break;
94 case mlir::enzyme::Activity::enzyme_dup:
96 break;
97 case mlir::enzyme::Activity::enzyme_const:
99 break;
100 case mlir::enzyme::Activity::enzyme_dupnoneed:
102 break;
103 case mlir::enzyme::Activity::enzyme_activenoneed:
105 assert(0 && "unsupported arg activenoneed");
106 break;
107 case mlir::enzyme::Activity::enzyme_constnoneed:
109 assert(0 && "unsupported arg constnoneed");
110 break;
111 }
112
113 constants.push_back(ty);
114 args.push_back(res);
// Duplicated arguments are followed by a second input, which is forwarded as
// an additional call operand.
// NOTE(review): unlike HandleAutoDiffReverse, the index is not bounds-checked
// after the increment.
115 if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) {
116 ++i;
117 res = CI.getInputs()[i];
118 args.push_back(res);
119 }
120
121 truei++;
122 }
123
// NOTE(review): `cast` is applied to the lookup result without a null check.
124 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
125 auto fn = cast<FunctionOpInterface>(symbolOp);
126
127 auto mode = DerivativeMode::ForwardMode;
128 std::vector<DIFFE_TYPE> retType;
129
// One activity and one "primal needed" flag per entry of the return
// activities; the three *noneed activities clear the flag.
130 std::vector<bool> returnPrimals;
131 for (auto act : CI.getRetActivity()) {
132 auto iattr = cast<mlir::enzyme::ActivityAttr>(act);
133 auto val = iattr.getValue();
134 DIFFE_TYPE ty;
135 bool primalNeeded = true;
// NOTE(review): as above, the `ty = ...` assignments for each case are missing
// from this listing.
136 switch (val) {
137 case mlir::enzyme::Activity::enzyme_active:
139 break;
140 case mlir::enzyme::Activity::enzyme_dup:
142 break;
143 case mlir::enzyme::Activity::enzyme_const:
145 break;
146 case mlir::enzyme::Activity::enzyme_dupnoneed:
148 primalNeeded = false;
149 break;
150 case mlir::enzyme::Activity::enzyme_activenoneed:
152 primalNeeded = false;
153 break;
154 case mlir::enzyme::Activity::enzyme_constnoneed:
156 primalNeeded = false;
157 break;
158 }
159 retType.push_back(ty);
160 returnPrimals.push_back(primalNeeded);
161 }
162
163 MTypeAnalysis TA;
164 auto type_args = TA.getAnalyzedTypeInfo(fn);
165 bool freeMemory = true;
166 bool omp = false;
167 size_t width = CI.getWidth();
168
// One flag per argument of the function body. Since `mode` is ForwardMode
// here, the expression below is true for every argument.
169 std::vector<bool> volatile_args;
170 for (auto &a : fn.getFunctionBody().getArguments()) {
171 (void)a;
172 volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined));
173 }
174
175 FunctionOpInterface newFunc = Logic.CreateForwardDiff(
176 fn, retType, constants, TA, returnPrimals, mode, freeMemory, width,
177 /*addedType*/ nullptr, type_args, volatile_args,
178 /*augmented*/ nullptr, omp, postpasses, verifyPostPasses,
179 CI.getStrongZero());
// A null function means the derivative could not be generated.
180 if (!newFunc)
181 return failure();
182
// Emit the call to the generated derivative with a builder constructed
// from `CI`.
183 OpBuilder builder(CI);
184 auto dCI = func::CallOp::create(builder, CI.getLoc(), newFunc.getName(),
185 newFunc.getResultTypes(), args);
// NOTE(review): on this error path the freshly created call `dCI` is not
// erased before returning.
186 if (dCI.getNumResults() != CI.getNumResults()) {
187 CI.emitError() << "Incorrect number of results for enzyme operation: "
188 << *CI << " expected " << *dCI;
189 return failure();
190 }
191 CI.replaceAllUsesWith(dCI);
192 CI->erase();
193 return success();
194 }
195
/// Lowers the differentiation op `CI` into a call to a function produced by
/// `Logic.CreateReverseDiff`.
///
/// Steps, as visible below:
///  1. Resolve the referenced function and check that the op carries exactly
///     one activity per function argument and one per function result.
///  2. Consume the op's inputs in order: one per argument, plus one extra for
///     each DUP_ARG / DUP_NONEED argument, then one for each OUT_DIFF result.
///  3. Generate the derivative function and create the call through
///     `AutoDiffFunctionInterface::createCall`.
///  4. Replace all uses of `CI` with the call's results and erase `CI`.
///
/// \param symbolTable symbol table collection used to resolve `CI.getFnAttr()`.
/// \param CI the differentiation op to lower (erased on success).
/// \returns failure() on an activity-count mismatch, when the op has too few
///          inputs, when no derivative is produced, or when the generated
///          function does not implement AutoDiffFunctionInterface; success()
///          otherwise.
196 template <typename T>
197 LogicalResult HandleAutoDiffReverse(SymbolTableCollection &symbolTable,
198 T CI) {
199
// NOTE(review): `cast` already asserts on a null/mismatched op, so the
// `assert(fn)` below cannot fire after it.
200 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
201 auto fn = cast<FunctionOpInterface>(symbolOp);
202 assert(fn);
// The activity lists must match the callee's signature one-to-one.
203 if (CI.getActivity().size() != fn.getNumArguments()) {
204 llvm::errs() << "Incorrect number of argument activities on autodiff op"
205 << "CI: " << CI << ", expected " << fn.getNumArguments()
206 << " found " << CI.getActivity().size() << "\n";
207 return failure();
208 }
209 if (CI.getRetActivity().size() != fn.getNumResults()) {
210 llvm::errs() << "Incorrect number of result activities on autodiff op"
211 << "CI: " << CI << ", expected " << fn.getNumResults()
212 << " found " << CI.getRetActivity().size() << "\n";
213 return failure();
214 }
215
216 std::vector<DIFFE_TYPE> arg_activities;
217 SmallVector<mlir::Value, 2> args;
218
// `call_idx` is the cursor into the op's inputs; it is shared by the argument
// loop and the return-gradient loop further down.
219 size_t call_idx = 0;
220 {
221 for (auto act : CI.getActivity()) {
222 if (call_idx >= CI.getInputs().size()) {
223 llvm::errs() << "Too few arguments to autodiff op"
224 << "CI: " << CI << "\n";
225 return failure();
226 }
227 mlir::Value res = CI.getInputs()[call_idx];
228 ++call_idx;
229
230 auto iattr = cast<mlir::enzyme::ActivityAttr>(act);
231 auto val = iattr.getValue();
232 DIFFE_TYPE ty;
// NOTE(review): the statements assigning `ty` in each case (listing lines
// 235, 238, 241, 244, 247, 251) are not visible in this extract; confirm
// against the upstream source.
233 switch (val) {
234 case mlir::enzyme::Activity::enzyme_active:
236 break;
237 case mlir::enzyme::Activity::enzyme_dup:
239 break;
240 case mlir::enzyme::Activity::enzyme_const:
242 break;
243 case mlir::enzyme::Activity::enzyme_dupnoneed:
245 break;
246 case mlir::enzyme::Activity::enzyme_activenoneed:
248 assert(0 && "unsupported arg activenoneed");
249 break;
250 case mlir::enzyme::Activity::enzyme_constnoneed:
252 assert(0 && "unsupported arg constnoneed");
253 break;
254 }
255 arg_activities.push_back(ty);
256 args.push_back(res);
// Duplicated arguments consume one more input, forwarded as an extra operand.
257 if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) {
258 if (call_idx >= CI.getInputs().size()) {
259 llvm::errs() << "Too few arguments to autodiff op"
260 << "CI: " << CI << "\n";
261 return failure();
262 }
263 res = CI.getInputs()[call_idx];
264 ++call_idx;
265 args.push_back(res);
266 }
267 }
268 }
269
// NOTE(review): `mode` is used below but its declaration (listing line 271)
// is not visible in this extract -- presumably a reverse DerivativeMode;
// confirm against the upstream source.
270 bool omp = false;
272 std::vector<DIFFE_TYPE> retType;
273 std::vector<bool> returnPrimals;
274 std::vector<bool> returnShadows;
275
276 // Add the return gradient
277 for (auto act : CI.getRetActivity()) {
278 auto iattr = cast<mlir::enzyme::ActivityAttr>(act);
279 auto val = iattr.getValue();
280 DIFFE_TYPE ty;
281 bool primalNeeded = true;
// NOTE(review): the `ty = ...` assignments for each case are missing from this
// listing; only the `primalNeeded` updates for the *noneed cases survive.
282 switch (val) {
283 case mlir::enzyme::Activity::enzyme_active:
285 break;
286 case mlir::enzyme::Activity::enzyme_dup:
288 break;
289 case mlir::enzyme::Activity::enzyme_const:
291 break;
292 case mlir::enzyme::Activity::enzyme_dupnoneed:
294 primalNeeded = false;
295 break;
296 case mlir::enzyme::Activity::enzyme_activenoneed:
298 primalNeeded = false;
299 break;
300 case mlir::enzyme::Activity::enzyme_constnoneed:
302 primalNeeded = false;
303 break;
304 }
305 retType.push_back(ty);
306 returnPrimals.push_back(primalNeeded);
// Shadows are never requested as results here.
307 returnShadows.push_back(false);
// Each OUT_DIFF result consumes one more input of the op, which is appended
// to the call operands.
308 if (ty == DIFFE_TYPE::OUT_DIFF) {
309 if (call_idx >= CI.getInputs().size()) {
310 llvm::errs() << "Too few arguments to autodiff op"
311 << "CI: " << CI << "\n";
312 return failure();
313 }
314 mlir::Value res = CI.getInputs()[call_idx];
315 ++call_idx;
316 args.push_back(res);
317 }
318 }
319
320 MTypeAnalysis TA;
321 auto type_args = TA.getAnalyzedTypeInfo(fn);
322 bool freeMemory = true;
323 size_t width = CI.getWidth();
324
// One flag per argument of the function body: false when `mode` is
// ReverseModeCombined, true otherwise.
325 std::vector<bool> volatile_args;
326 for (auto &a : fn.getFunctionBody().getArguments()) {
327 (void)a;
328 volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined));
329 }
330
331 FunctionOpInterface newFunc =
332 Logic.CreateReverseDiff(fn, retType, arg_activities, TA, returnPrimals,
333 returnShadows, mode, freeMemory, width,
334 /*addedType*/ nullptr, type_args, volatile_args,
335 /*augmented*/ nullptr, omp, postpasses,
336 verifyPostPasses, CI.getStrongZero());
// A null function means the derivative could not be generated.
337 if (!newFunc)
338 return failure();
339
// The call is created through the generated function's own interface rather
// than a fixed call op type.
340 OpBuilder builder(CI);
341 if (auto iface =
342 dyn_cast<AutoDiffFunctionInterface>(newFunc.getOperation())) {
343 auto dCI = iface.createCall(builder, CI.getLoc(), args);
344 CI.replaceAllUsesWith(dCI);
345 } else {
346 newFunc.getOperation()->emitError()
347 << "this function operation does not implement "
348 "AutoDiffFunctionInterface";
349 return failure();
350 }
351 CI->erase();
352 return success();
353 }
354
355 void lowerEnzymeCalls(SymbolTableCollection &symbolTable,
356 FunctionOpInterface op) {
357 {
358 SmallVector<Operation *> toLower;
359 op->walk([&](enzyme::ForwardDiffOp dop) {
360 auto *symbolOp =
361 symbolTable.lookupNearestSymbolFrom(dop, dop.getFnAttr());
362 auto callableOp = cast<FunctionOpInterface>(symbolOp);
363
364 lowerEnzymeCalls(symbolTable, callableOp);
365 toLower.push_back(dop);
366 });
367
368 for (auto T : toLower) {
369 if (auto F = dyn_cast<enzyme::ForwardDiffOp>(T)) {
370 auto res = HandleAutoDiff(symbolTable, F);
371 if (!res.succeeded()) {
372 signalPassFailure();
373 return;
374 }
375 } else {
376 llvm_unreachable("Illegal type");
377 }
378 }
379 };
380
381 {
382 SmallVector<Operation *> toLower;
383 op->walk([&](enzyme::AutoDiffOp dop) {
384 auto *symbolOp =
385 symbolTable.lookupNearestSymbolFrom(dop, dop.getFnAttr());
386 auto callableOp = cast<FunctionOpInterface>(symbolOp);
387
388 lowerEnzymeCalls(symbolTable, callableOp);
389 toLower.push_back(dop);
390 });
391
392 for (auto T : toLower) {
393 if (auto F = dyn_cast<enzyme::AutoDiffOp>(T)) {
394 auto res = HandleAutoDiffReverse(symbolTable, F);
395 if (!res.succeeded()) {
396 signalPassFailure();
397 return;
398 }
399 } else {
400 llvm_unreachable("Illegal type");
401 }
402 }
403 }
404 };
405};
406
407} // end anonymous namespace
408
409void DifferentiatePass::runOnOperation() {
410 SymbolTableCollection symbolTable;
411 symbolTable.getSymbolTable(getOperation());
412 getOperation()->walk(
413 [&](FunctionOpInterface op) { lowerEnzymeCalls(symbolTable, op); });
414}
DIFFE_TYPE
Potential differentiable argument classifications.
Definition Utils.h:374
DerivativeMode
Definition Utils.h:390
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, std::vector< bool > returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
FunctionOpInterface CreateForwardDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
MFnTypeInfo getAnalyzedTypeInfo(FunctionOpInterface op) const
Definition EnzymeLogic.h:24