14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/SymbolTable.h"
16#include "mlir/Interfaces/FunctionInterfaces.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/IR/Dominance.h"
25#include "llvm/ADT/BreadthFirstIterator.h"
33 IRMapping invertedPointers_,
const llvm::ArrayRef<bool> returnPrimals,
34 const llvm::ArrayRef<bool> returnShadows,
35 const SmallPtrSetImpl<mlir::Value> &constantvalues_,
36 const SmallPtrSetImpl<mlir::Value> &activevals_,
37 ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
38 IRMapping &originalToNewFn_,
39 std::map<Operation *, Operation *> &originalToNewFnOps_,
40 DerivativeMode mode_,
unsigned width,
bool omp, StringRef postpasses,
41 bool verifyPostPasses,
bool strongZero)
43 invertedPointers_, returnPrimals, returnShadows,
44 constantvalues_, activevals_, ReturnActivity,
45 ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_,
46 mode_, width, omp, postpasses, verifyPostPasses,
50 Type indexType = getIndexType();
55 return mlir::IntegerType::get(initializationBlock->begin()->getContext(), 32);
59 OpBuilder builder(initializationBlock, initializationBlock->begin());
60 return enzyme::InitOp::create(builder,
61 (initializationBlock->rbegin())->getLoc(), t);
67 CacheType::get(initializationBlock->begin()->getContext(), t);
72 std::function<std::pair<Value, Value>(Type)> hook) {
73 if (hook !=
nullptr) {
79 std::function<std::pair<Value, Value>(Type)> hook) {
80 if (hook !=
nullptr) {
88 return {cache, cache};
97 enzyme::PushOp::create(builder, v.getLoc(), pushCache, v);
102 return enzyme::PopOp::create(
103 builder, cache.getLoc(),
104 cast<enzyme::CacheType>(cache.getType()).getType(), cache);
111 for (
auto operand : op->getOperands())
112 map.map(operand, getNewFromOriginal(operand));
113 return B.clone(*op, map);
118 OpBuilder &builder) {
119 assert(!isConstantValue(oldGradient));
120 Value operandGradient = diffe(oldGradient, builder);
121 auto iface = cast<AutoDiffTypeInterface>(addedGradient.getType());
122 auto added = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient,
124 setDiffe(oldGradient, added, builder);
129 for (
auto it =
oldFunc.getBlocks().rbegin(); it !=
oldFunc.getBlocks().rend();
132 Block *reverseBlock =
new Block();
141 const ArrayRef<bool> returnPrimals,
const ArrayRef<bool> returnShadows,
142 ArrayRef<DIFFE_TYPE> retType, ArrayRef<DIFFE_TYPE> constant_args,
143 mlir::Type additionalArg,
bool omp, llvm::StringRef postpasses,
144 bool verifyPostPasses,
bool strongZero) {
158 llvm_unreachable(
"invalid DerivativeMode: ReverseModePrimal\n");
162 prefix += std::to_string(
width);
164 IRMapping originalToNew;
165 std::map<Operation *, Operation *> originalToNewOps;
167 SmallPtrSet<mlir::Value, 1> returnvals;
168 SmallPtrSet<mlir::Value, 1> constant_values;
169 SmallPtrSet<mlir::Value, 1> nonconstant_values;
174 prefix + todiff.getName(), originalToNew, originalToNewOps,
180 constant_args, originalToNew, originalToNewOps, mode_,
width,
omp,
FunctionOpInterface CloneFunctionWithReturns(DerivativeMode mode, unsigned width, FunctionOpInterface F, IRMapping &ptrInputs, ArrayRef< DIFFE_TYPE > ArgActivity, SmallPtrSetImpl< mlir::Value > &constants, SmallPtrSetImpl< mlir::Value > &nonconstants, SmallPtrSetImpl< mlir::Value > &returnvals, const std::vector< bool > &returnPrimals, const std::vector< bool > &returnShadows, ArrayRef< DIFFE_TYPE > RetActivity, Twine name, IRMapping &VMap, std::map< Operation *, Operation * > &OpMap, mlir::Type additionalArg)
Value popCache(Value cache, OpBuilder &builder)
SmallVector< std::function< std::pair< Value, Value >(Type)> > cacheCreatorHook
Value initAndPushCache(Value v, OpBuilder &builder)
MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, IRMapping invertedPointers_, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, const SmallPtrSetImpl< mlir::Value > &constantvalues_, const SmallPtrSetImpl< mlir::Value > &activevals_, ArrayRef< DIFFE_TYPE > ReturnActivity, ArrayRef< DIFFE_TYPE > ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map< Operation *, Operation * > &originalToNewFnOps_, DerivativeMode mode_, unsigned width, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
void createReverseModeBlocks(Region &oldFunc, Region &newFunc)
void deregisterCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
void registerCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
Type getCacheType(Type t)
IRMapping mapReverseModeBlocks
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef< bool > returnPrimals, const ArrayRef< bool > returnShadows, llvm::ArrayRef< DIFFE_TYPE > retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
std::pair< Value, Value > getNewCache(Type t)
Operation * cloneWithNewOperands(OpBuilder &B, Operation *op)
llvm::StringRef postpasses
const llvm::ArrayRef< bool > returnShadows
const llvm::ArrayRef< bool > returnPrimals
FunctionOpInterface oldFunc
IRMapping invertedPointers
FunctionOpInterface newFunc
static LoopCacheType getCacheType(Operation *op)