FuncAutoDiffOpInterfaceImpl.cpp
//===- FuncAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR func dialect.
//
//===----------------------------------------------------------------------===//

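// A rough sketch of what this file implements (hypothetical IR, for
// orientation only): in forward mode a call
//   %r = func.call @f(%x) : (f64) -> f64
// becomes a call to an Enzyme-generated clone of @f whose operands and
// results interleave each active (DUP_ARG) primal with its tangent:
//   %r, %dr = func.call @fwddiffe_f(%x, %dx) : (f64, f64) -> (f64, f64)
// The names @f and @fwddiffe_f are illustrative only.
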
// Enzyme AD interface headers (assumed include paths) for the op interfaces
// and gradient utilities used below.
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

#include "Dialect/Ops.h"
#include "mlir/IR/TypeSupport.h"

using namespace mlir;
using namespace mlir::enzyme;

namespace {
#include "Implementations/FuncDerivatives.inc"
} // namespace
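
// Forward mode: differentiate a func.call by generating a forward-mode clone
// of the callee and calling it in place of the original call.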
class AutoDiffCallFwd
    : public AutoDiffOpInterface::ExternalModel<AutoDiffCallFwd, func::CallOp> {
public:
  LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder,
                                         MGradientUtils *gutils) const {
    DerivativeMode mode = DerivativeMode::ForwardMode;

    auto callOp = cast<func::CallOp>(orig);
    SymbolTable symbolTable = SymbolTable::getNearestSymbolTable(orig);

    Operation *callee = symbolTable.lookup(callOp.getCallee());
    auto fn = cast<FunctionOpInterface>(callee);

    auto narg = orig->getNumOperands();
    auto nret = orig->getNumResults();

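    // Classify every result and operand as CONSTANT or DUP_ARG (active, with
    // a tangent) based on Enzyme's activity analysis.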
    std::vector<DIFFE_TYPE> RetActivity;
    RetActivity.reserve(nret);
    for (auto res : callOp.getResults()) {
      RetActivity.push_back(gutils->isConstantValue(res) ? DIFFE_TYPE::CONSTANT
                                                         : DIFFE_TYPE::DUP_ARG);
    }

    std::vector<DIFFE_TYPE> ArgActivity;
    ArgActivity.reserve(narg);
    for (auto arg : callOp.getOperands()) {
      ArgActivity.push_back(gutils->isConstantValue(arg) ? DIFFE_TYPE::CONSTANT
                                                         : DIFFE_TYPE::DUP_ARG);
    }

    std::vector<bool> returnPrimal(nret, true);
    std::vector<bool> returnShadow(nret, false);

    auto type_args = gutils->TA.getAnalyzedTypeInfo(fn);

    bool freeMemory = true;
    size_t width = gutils->width;

    std::vector<bool> volatile_args(narg, false);

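    // Ask EnzymeLogic for (or create) the forward-mode derivative of the
    // callee under the activities chosen above.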
    auto forwardFn = gutils->Logic.CreateForwardDiff(
        fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode,
        freeMemory, width,
        /* addedType */ nullptr, type_args, volatile_args,
        /* augmented */ nullptr, gutils->omp, gutils->postpasses,
        gutils->verifyPostPasses, gutils->strongZero);

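    // Build the operand list for the generated function: each primal operand,
    // immediately followed by its tangent when the operand is active.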
    SmallVector<Value> fwdArguments;
    for (auto &&[arg, act] :
         llvm::zip_equal(callOp.getOperands(), ArgActivity)) {
      fwdArguments.push_back(gutils->getNewFromOriginal(arg));
      if (act == DIFFE_TYPE::DUP_ARG)
        fwdArguments.push_back(gutils->invertPointerM(arg, builder));
    }

    auto fwdCallOp = func::CallOp::create(
        builder, orig->getLoc(), cast<func::FuncOp>(forwardFn), fwdArguments);

    SmallVector<Value> primals;
    primals.reserve(nret);

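    // Results come back interleaved as well: the primal result first, then
    // the tangent of every active result, which becomes that result's
    // derivative.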
    int fwdIndex = 0;
    for (auto &&[ret, act] :
         llvm::zip_equal(callOp.getResults(), RetActivity)) {
      auto fwdRet = fwdCallOp.getResult(fwdIndex);
      primals.push_back(fwdRet);
      fwdIndex++;

      if (act == DIFFE_TYPE::DUP_ARG) {
        gutils->setDiffe(ret, fwdCallOp.getResult(fwdIndex), builder);
        fwdIndex++;
      }
    }

    auto newOp = gutils->getNewFromOriginal(orig);
    gutils->replaceOrigOpWith(orig, primals);
    gutils->erase(newOp);

    return success();
  }
};
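
// Reverse mode: differentiate a func.call by generating a reverse-mode
// (adjoint) clone of the callee and invoking it during the backward sweep.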
class AutoDiffCallRev
    : public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffCallRev,
                                                       func::CallOp> {
public:
  LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    DerivativeMode mode = DerivativeMode::ReverseModeGradient;

    SymbolTable symbolTable = SymbolTable::getNearestSymbolTable(orig);

    func::CallOp callOp = cast<func::CallOp>(orig);

    Operation *callee = symbolTable.lookup(callOp.getCallee());
    auto fn = cast<FunctionOpInterface>(callee);

    auto narg = orig->getNumOperands();
    auto nret = orig->getNumResults();

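    // In reverse mode a value is CONSTANT if inactive, DUP_ARG if its type is
    // mutable (it carries a shadow), and OUT_DIFF otherwise (its gradient
    // flows through the adjoint's results).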
    std::vector<DIFFE_TYPE> RetActivity;
    for (auto res : callOp.getResults()) {
      RetActivity.push_back(
          gutils->isConstantValue(res) ? DIFFE_TYPE::CONSTANT
          : cast<AutoDiffTypeInterface>(res.getType()).isMutable()
              ? DIFFE_TYPE::DUP_ARG
              : DIFFE_TYPE::OUT_DIFF);
    }

    std::vector<DIFFE_TYPE> ArgActivity;
    for (auto arg : callOp.getOperands()) {
      ArgActivity.push_back(
          gutils->isConstantValue(arg) ? DIFFE_TYPE::CONSTANT
          : cast<AutoDiffTypeInterface>(arg.getType()).isMutable()
              ? DIFFE_TYPE::DUP_ARG
              : DIFFE_TYPE::OUT_DIFF);
    }

    if (llvm::any_of(RetActivity,
                     [&](auto act) { return act == DIFFE_TYPE::DUP_ARG; })) {
      orig->emitError()
          << "could not emit adjoint with mutable return types in: " << *orig
          << "\n";
      return failure();
    }

    std::vector<bool> volatile_args(narg, true);
    std::vector<bool> returnShadow(nret, false);
    std::vector<bool> returnPrimal(nret, false);

    auto type_args = gutils->TA.getAnalyzedTypeInfo(fn);

    bool freeMemory = true;
    size_t width = gutils->width;

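    // Ask EnzymeLogic for (or create) the reverse-mode adjoint of the callee.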
    auto revFn = gutils->Logic.CreateReverseDiff(
        fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow,
        mode, freeMemory, width, /*addedType*/ nullptr, type_args,
        volatile_args, /*augmented*/ nullptr, gutils->omp, gutils->postpasses,
        gutils->verifyPostPasses, gutils->strongZero);

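    // Adjoint operands: the cached primal for every operand, its shadow right
    // after it for DUP_ARG operands, then one incoming gradient per active
    // result.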
    SmallVector<Value> revArguments;
    for (auto [arg, act, cache] :
         llvm::zip_equal(callOp.getOperands(), ArgActivity, caches)) {
      revArguments.push_back(gutils->popCache(cache, builder));
      if (act == DIFFE_TYPE::DUP_ARG)
        revArguments.push_back(gutils->invertPointerM(arg, builder));
    }

    for (auto result : callOp.getResults()) {
      if (gutils->isConstantValue(result))
        continue;
      revArguments.push_back(gutils->diffe(result, builder));
    }

    auto revCallOp = func::CallOp::create(
        builder, orig->getLoc(), cast<func::FuncOp>(revFn), revArguments);

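    // Walk the operands again: accumulate returned gradients into OUT_DIFF
    // operands, and free the cloned cache entries backing DUP_ARG operands.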
    int revIndex = 0, fwdIndex = 0;
    for (auto [arg, act] : llvm::zip_equal(callOp.getOperands(), ArgActivity)) {
      fwdIndex++;

      if (gutils->isConstantValue(arg))
        continue;

      if (act == DIFFE_TYPE::DUP_ARG) {
        cast<ClonableTypeInterface>(arg.getType())
            .freeClonedValue(builder, revArguments[fwdIndex - 1]);
        fwdIndex++;
      } else {
        auto diffe = revCallOp.getResult(revIndex);
        gutils->addToDiffe(arg, diffe, builder);
        revIndex++;
      }
    }

    return success();
  }
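
  // Cache every operand during the forward sweep (cloning it when its type is
  // clonable) so the adjoint can still read the primal value when it runs.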
  SmallVector<Value> cacheValues(Operation *orig,
                                 MGradientUtilsReverse *gutils) const {
    SmallVector<Value> cachedArguments;

    Operation *newOp = gutils->getNewFromOriginal(orig);
    OpBuilder cacheBuilder(newOp);

    for (auto arg : orig->getOperands()) {
      Value toCache = gutils->getNewFromOriginal(arg);
      if (auto iface = dyn_cast<ClonableTypeInterface>(arg.getType())) {
        toCache = iface.cloneValue(cacheBuilder, toCache);
      }
      Value cache = gutils->initAndPushCache(toCache, cacheBuilder);
      cachedArguments.push_back(cache);
    }

    return cachedArguments;
  }

  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}
};
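
// Teaches Enzyme how to materialize calls and returns for func.func when it
// builds derivative functions.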
class AutoDiffFuncFuncFunctionInterface
    : public AutoDiffFunctionInterface::ExternalModel<
          AutoDiffFuncFuncFunctionInterface, func::FuncOp> {
public:
  void transformResultTypes(Operation *self,
                            SmallVectorImpl<Type> &types) const {}

  Operation *createCall(Operation *self, OpBuilder &builder, Location loc,
                        ValueRange args) const {
    return func::CallOp::create(builder, loc, cast<func::FuncOp>(self), args);
  }

  Operation *createReturn(Operation *self, OpBuilder &builder, Location loc,
                          ValueRange args) const {
    return func::ReturnOp::create(builder, loc, args);
  }
};
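
// Attaches the external models above to the func dialect's operations.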
void mlir::enzyme::registerFuncDialectAutoDiffInterface(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *context, func::FuncDialect *) {
    registerInterfaces(context);
    func::CallOp::attachInterface<AutoDiffCallFwd>(*context);
    func::CallOp::attachInterface<AutoDiffCallRev>(*context);
    func::FuncOp::attachInterface<AutoDiffFuncFuncFunctionInterface>(*context);
  });
}
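
// A minimal usage sketch (not from this file): register the interfaces while
// configuring the MLIRContext, e.g.
//   mlir::DialectRegistry registry;
//   mlir::enzyme::registerFuncDialectAutoDiffInterface(registry);
//   context.appendDialectRegistry(registry);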