14#include "mlir/IR/IRMapping.h"
15#include "mlir/Interfaces/FunctionInterfaces.h"
60 const SmallPtrSetImpl<mlir::Value> &constantvalues_,
61 const SmallPtrSetImpl<mlir::Value> &activevals_,
62 ArrayRef<DIFFE_TYPE> ReturnActivities,
63 ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
64 IRMapping &originalToNewFn_,
65 std::map<Operation *, Operation *> &originalToNewFnOps_,
69 void erase(Operation *op) { op->erase(); }
71 for (
auto &&[res, rep] : llvm::zip(op->getResults(), vals)) {
75 newOp->replaceAllUsesWith(vals);
90 void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder);
94 auto iface = cast<AutoDiffTypeInterface>(T);
95 return iface.getShadowType(
width);
98 static llvm::SmallVector<mlir::Value, 1>
100 mlir::OperandRange range) {
101 llvm::SmallVector<mlir::Value, 1> results;
102 for (
size_t i = 0; i < range.size(); i++) {
103 results.push_back(vals[range.getBeginOperandIndex() + i]);
123 void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder);
125 void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder);
127 mlir::Value
diffe(mlir::Value origv, mlir::OpBuilder &builder);
134 const SmallPtrSetImpl<mlir::Value> &constantvalues_,
135 const SmallPtrSetImpl<mlir::Value> &activevals_,
136 ArrayRef<DIFFE_TYPE> RetActivity,
137 ArrayRef<DIFFE_TYPE> ArgActivity, IRMapping &origToNew_,
138 std::map<Operation *, Operation *> &origToNewOps_,
144 activevals_, RetActivity, ArgActivity, origToNew_,
155 ArrayRef<DIFFE_TYPE> RetActivity, ArrayRef<DIFFE_TYPE> ArgActivity,
156 mlir::Type additionalArg,
bool omp, llvm::StringRef
postpasses,
171 llvm_unreachable(
"invalid DerivativeMode: ReverseModePrimal\n");
175 prefix += std::to_string(
width);
177 IRMapping originalToNew;
178 std::map<Operation *, Operation *> originalToNewOps;
180 SmallPtrSet<mlir::Value, 1> returnvals;
181 SmallPtrSet<mlir::Value, 1> constant_values;
182 SmallPtrSet<mlir::Value, 1> nonconstant_values;
187 RetActivity, prefix + todiff.getName(), originalToNew, originalToNewOps,
193 ArgActivity, originalToNew, originalToNewOps,
mode,
width,
omp,
FunctionOpInterface CloneFunctionWithReturns(DerivativeMode mode, unsigned width, FunctionOpInterface F, IRMapping &ptrInputs, ArrayRef< DIFFE_TYPE > ArgActivity, SmallPtrSetImpl< mlir::Value > &constants, SmallPtrSetImpl< mlir::Value > &nonconstants, SmallPtrSetImpl< mlir::Value > &returnvals, const std::vector< bool > &returnPrimals, const std::vector< bool > &returnShadows, ArrayRef< DIFFE_TYPE > RetActivity, Twine name, IRMapping &VMap, std::map< Operation *, Operation * > &OpMap, mlir::Type additionalArg)
SmallVector< std::function< Value(Location, Type)> > gradientCreatorHook
void registerGradientCreatorHook(std::function< Value(Location, Type)> hook)
mlir::Value getDifferential(mlir::Value origv)
Block * initializationBlock
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
void deregisterGradientCreatorHook(std::function< Value(Location, Type)> hook)
static MDiffeGradientUtils * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, ArrayRef< DIFFE_TYPE > RetActivity, ArrayRef< DIFFE_TYPE > ArgActivity, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
Value getNewGradient(Location loc, Type t)
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder)
MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA, MTypeResults TR, IRMapping &invertedPointers_, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, const SmallPtrSetImpl< mlir::Value > &constantvalues_, const SmallPtrSetImpl< mlir::Value > &activevals_, ArrayRef< DIFFE_TYPE > RetActivity, ArrayRef< DIFFE_TYPE > ArgActivity, IRMapping &origToNew_, std::map< Operation *, Operation * > &origToNewOps_, DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
Operation * cloneWithNewOperands(OpBuilder &B, Operation *op)
IRMapping originalToNewFn
ArrayRef< DIFFE_TYPE > RetDiffeTypes
static llvm::SmallVector< mlir::Value, 1 > reindex_arguments(llvm::ArrayRef< mlir::Value > vals, mlir::OperandRange range)
llvm::StringRef postpasses
const llvm::ArrayRef< bool > returnShadows
mlir::Type getShadowType(mlir::Type T)
const llvm::ArrayRef< bool > returnPrimals
FunctionOpInterface oldFunc
IRMapping invertedPointers
void replaceOrigOpWith(Operation *op, ValueRange vals)
std::map< Operation *, Operation * > originalToNewFnOps
DenseMap< Operation *, bool > readOnlyCache
ArrayRef< DIFFE_TYPE > ArgDiffeTypes
LogicalResult visitChild(Operation *op)
void erase(Operation *op)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
void eraseIfUnused(Operation *op, bool erase=true, bool check=true)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
void forceAugmentedReturns()
MGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, MTypeResults TR_, IRMapping &invertedPointers_, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, const SmallPtrSetImpl< mlir::Value > &constantvalues_, const SmallPtrSetImpl< mlir::Value > &activevals_, ArrayRef< DIFFE_TYPE > ReturnActivities, ArrayRef< DIFFE_TYPE > ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map< Operation *, Operation * > &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
std::unique_ptr< enzyme::ActivityAnalyzer > activityAnalyzer
FunctionOpInterface newFunc
bool isConstantInstruction(mlir::Operation *v) const
void setInvertedPointer(mlir::Value origv, mlir::Value newv)
bool isConstantValue(mlir::Value v) const
SmallPtrSet< Block *, 4 > blocksNotForAnalysis