10#ifndef ENZYME_MLIR_INTERFACES_GRADIENT_UTILS_REVERSE_H
11#define ENZYME_MLIR_INTERFACES_GRADIENT_UTILS_REVERSE_H
13#include "mlir/IR/IRMapping.h"
14#include "mlir/Interfaces/FunctionInterfaces.h"
30 IRMapping invertedPointers_,
33 const SmallPtrSetImpl<mlir::Value> &constantvalues_,
34 const SmallPtrSetImpl<mlir::Value> &activevals_,
35 ArrayRef<DIFFE_TYPE> ReturnActivity,
36 ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
37 IRMapping &originalToNewFn_,
38 std::map<Operation *, Operation *> &originalToNewFnOps_,
45 void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient,
65 Value
popCache(Value cache, OpBuilder &builder);
73 llvm::ArrayRef<DIFFE_TYPE> retType,
74 llvm::ArrayRef<DIFFE_TYPE> constant_args, mlir::Type additionalArg,
Value popCache(Value cache, OpBuilder &builder)
SmallVector< std::function< std::pair< Value, Value >(Type)> > cacheCreatorHook
Value initAndPushCache(Value v, OpBuilder &builder)
MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, IRMapping invertedPointers_, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, const SmallPtrSetImpl< mlir::Value > &constantvalues_, const SmallPtrSetImpl< mlir::Value > &activevals_, ArrayRef< DIFFE_TYPE > ReturnActivity, ArrayRef< DIFFE_TYPE > ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map< Operation *, Operation * > &originalToNewFnOps_, DerivativeMode mode_, unsigned width, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
void createReverseModeBlocks(Region &oldFunc, Region &newFunc)
void deregisterCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
void registerCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
Type getCacheType(Type t)
IRMapping mapReverseModeBlocks
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef< bool > returnPrimals, const ArrayRef< bool > returnShadows, llvm::ArrayRef< DIFFE_TYPE > retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
std::pair< Value, Value > getNewCache(Type t)
Operation * cloneWithNewOperands(OpBuilder &B, Operation *op)
llvm::StringRef postpasses
const llvm::ArrayRef< bool > returnShadows
const llvm::ArrayRef< bool > returnPrimals
FunctionOpInterface oldFunc
FunctionOpInterface newFunc