18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/IR/DialectRegistry.h"
20#include "mlir/Support/LogicalResult.h"
23#include "mlir/IR/TypeSupport.h"
29#include "Implementations/FuncDerivatives.inc"
33 :
public AutoDiffOpInterface::ExternalModel<AutoDiffCallFwd, func::CallOp> {
39 auto callOp = cast<func::CallOp>(orig);
40 SymbolTable symbolTable = SymbolTable::getNearestSymbolTable(orig);
42 Operation *callee = symbolTable.lookup(callOp.getCallee());
43 auto fn = cast<FunctionOpInterface>(callee);
45 auto narg = orig->getNumOperands();
46 auto nret = orig->getNumResults();
48 std::vector<DIFFE_TYPE> RetActivity;
49 RetActivity.reserve(nret);
50 for (
auto res : callOp.getResults()) {
55 std::vector<DIFFE_TYPE> ArgActivity;
56 ArgActivity.reserve(narg);
57 for (
auto arg : callOp.getOperands()) {
62 std::vector<bool> returnPrimal(nret,
true);
63 std::vector<bool> returnShadow(nret,
false);
67 bool freeMemory =
true;
68 size_t width = gutils->
width;
70 std::vector<bool> volatile_args(narg,
false);
73 fn, RetActivity, ArgActivity, gutils->
TA, returnPrimal, mode,
75 nullptr, type_args, volatile_args,
79 SmallVector<Value> fwdArguments;
81 for (
auto &&[arg, act] :
82 llvm::zip_equal(callOp.getOperands(), ArgActivity)) {
89 auto fwdCallOp = func::CallOp::create(
90 builder, orig->getLoc(), cast<func::FuncOp>(forwardFn), fwdArguments);
92 SmallVector<Value> primals;
93 primals.reserve(nret);
96 for (
auto &&[ret, act] :
97 llvm::zip_equal(callOp.getResults(), RetActivity)) {
98 auto fwdRet = fwdCallOp.getResult(fwdIndex);
99 primals.push_back(fwdRet);
104 gutils->
setDiffe(ret, fwdCallOp.getResult(fwdIndex), builder);
111 gutils->
erase(newOp);
118 :
public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffCallRev,
123 SmallVector<Value> caches)
const {
126 SymbolTable symbolTable = SymbolTable::getNearestSymbolTable(orig);
128 func::CallOp callOp = cast<func::CallOp>(orig);
130 Operation *callee = symbolTable.lookup(callOp.getCallee());
131 auto fn = cast<FunctionOpInterface>(callee);
133 auto narg = orig->getNumOperands();
134 auto nret = orig->getNumResults();
136 std::vector<DIFFE_TYPE> RetActivity;
137 for (
auto res : callOp.getResults()) {
138 RetActivity.push_back(
140 : cast<AutoDiffTypeInterface>(res.getType()).isMutable()
145 std::vector<DIFFE_TYPE> ArgActivity;
146 for (
auto arg : callOp.getOperands()) {
147 ArgActivity.push_back(
149 : cast<AutoDiffTypeInterface>(arg.getType()).isMutable()
154 if (llvm::any_of(RetActivity,
157 <<
"could not emit adjoint with mutable return types in: " << *orig
162 std::vector<bool> volatile_args(narg,
true);
163 std::vector<bool> returnShadow(nret,
false);
164 std::vector<bool> returnPrimal(nret,
false);
168 bool freeMemory =
true;
169 size_t width = gutils->
width;
172 fn, RetActivity, ArgActivity, gutils->
TA, returnPrimal, returnShadow,
173 mode, freeMemory, width,
nullptr, type_args,
177 SmallVector<Value> revArguments;
179 for (
auto [arg, act, cache] :
180 llvm::zip_equal(callOp.getOperands(), ArgActivity, caches)) {
181 revArguments.push_back(gutils->
popCache(cache, builder));
186 for (
auto result : callOp.getResults()) {
189 revArguments.push_back(gutils->
diffe(result, builder));
192 auto revCallOp = func::CallOp::create(
193 builder, orig->getLoc(), cast<func::FuncOp>(revFn), revArguments);
195 int revIndex = 0, fwdIndex = 0;
196 for (
auto [arg, act] : llvm::zip_equal(callOp.getOperands(), ArgActivity)) {
203 cast<ClonableTypeInterface>(arg.getType())
204 .freeClonedValue(builder, revArguments[fwdIndex - 1]);
207 auto diffe = revCallOp.getResult(revIndex);
218 SmallVector<Value> cachedArguments;
221 OpBuilder cacheBuilder(newOp);
223 for (
auto arg : orig->getOperands()) {
225 if (
auto iface = dyn_cast<ClonableTypeInterface>(arg.getType())) {
226 toCache = iface.cloneValue(cacheBuilder, toCache);
229 cachedArguments.push_back(cache);
232 return cachedArguments;
240 :
public AutoDiffFunctionInterface::ExternalModel<
241 AutoDiffFuncFuncFunctionInterface, func::FuncOp> {
244 SmallVectorImpl<Type> &types)
const {}
246 Operation *
createCall(Operation *self, OpBuilder &builder, Location loc,
247 ValueRange args)
const {
248 return func::CallOp::create(builder, loc, cast<func::FuncOp>(self), args);
251 Operation *
createReturn(Operation *self, OpBuilder &builder, Location loc,
252 ValueRange args)
const {
253 return func::ReturnOp::create(builder, loc, args);
258 DialectRegistry ®istry) {
259 registry.addExtension(+[](MLIRContext *context, func::FuncDialect *) {
260 registerInterfaces(context);
261 func::CallOp::attachInterface<AutoDiffCallFwd>(*context);
262 func::CallOp::attachInterface<AutoDiffCallRev>(*context);
263 func::FuncOp::attachInterface<AutoDiffFuncFuncFunctionInterface>(*context);
LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder, MGradientUtils *gutils) const
SmallVector< Value > cacheValues(Operation *orig, MGradientUtilsReverse *gutils) const
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector< Value > caches) const
Operation * createCall(Operation *self, OpBuilder &builder, Location loc, ValueRange args) const
Operation * createReturn(Operation *self, OpBuilder &builder, Location loc, ValueRange args) const
void transformResultTypes(Operation *self, SmallVectorImpl< Type > &types) const
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, std::vector< bool > returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
FunctionOpInterface CreateForwardDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
llvm::StringRef postpasses
void replaceOrigOpWith(Operation *op, ValueRange vals)
void erase(Operation *op)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
bool isConstantValue(mlir::Value v) const
MFnTypeInfo getAnalyzedTypeInfo(FunctionOpInterface op) const
void registerFuncDialectAutoDiffInterface(DialectRegistry &registry)