Enzyme main
Loading...
Searching...
No Matches
GradientUtils.h
Go to the documentation of this file.
1//===- GradientUtils.h - Utilities for gradient interfaces -------* C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#pragma once
9
12
14#include "mlir/IR/IRMapping.h"
15#include "mlir/Interfaces/FunctionInterfaces.h"
16
17namespace mlir {
18namespace enzyme {
19
21public:
22 // From CacheUtility
23 FunctionOpInterface newFunc;
24
28 FunctionOpInterface oldFunc;
30 IRMapping originalToNewFn;
31 std::map<Operation *, Operation *> originalToNewFnOps;
32
33 SmallPtrSet<Block *, 4> blocksNotForAnalysis;
34 DenseMap<Operation *, bool> readOnlyCache;
35 std::unique_ptr<enzyme::ActivityAnalyzer> activityAnalyzer;
36
39 bool omp;
41 llvm::StringRef postpasses;
43 const llvm::ArrayRef<bool> returnPrimals;
44 const llvm::ArrayRef<bool> returnShadows;
45
46 unsigned width;
47 ArrayRef<DIFFE_TYPE> ArgDiffeTypes;
48 ArrayRef<DIFFE_TYPE> RetDiffeTypes;
49
50 SmallVector<mlir::Value, 1> getNewFromOriginal(ValueRange originst) const;
51 mlir::Value getNewFromOriginal(const mlir::Value originst) const;
52 mlir::Block *getNewFromOriginal(mlir::Block *originst) const;
53 Operation *getNewFromOriginal(Operation *originst) const;
54
55 MGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
56 FunctionOpInterface oldFunc_, MTypeAnalysis &TA_,
57 MTypeResults TR_, IRMapping &invertedPointers_,
58 const llvm::ArrayRef<bool> returnPrimals,
59 const llvm::ArrayRef<bool> returnShadows,
60 const SmallPtrSetImpl<mlir::Value> &constantvalues_,
61 const SmallPtrSetImpl<mlir::Value> &activevals_,
62 ArrayRef<DIFFE_TYPE> ReturnActivities,
63 ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
64 IRMapping &originalToNewFn_,
65 std::map<Operation *, Operation *> &originalToNewFnOps_,
66 DerivativeMode mode, unsigned width, bool omp,
67 llvm::StringRef postpasses, bool verifyPostPasses,
68 bool strongZero);
69 void erase(Operation *op) { op->erase(); }
70 void replaceOrigOpWith(Operation *op, ValueRange vals) {
71 for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) {
72 originalToNewFn.map(res, rep);
73 }
74 auto newOp = getNewFromOriginal(op);
75 newOp->replaceAllUsesWith(vals);
76 originalToNewFnOps.erase(op);
77 }
78 void eraseIfUnused(Operation *op, bool erase = true, bool check = true) {
79 // TODO
80 }
81 bool isConstantInstruction(mlir::Operation *v) const;
82 bool isConstantValue(mlir::Value v) const;
83 mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2);
85
86 Operation *cloneWithNewOperands(OpBuilder &B, Operation *op);
87
88 LogicalResult visitChild(Operation *op);
89
90 void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder);
91 void setInvertedPointer(mlir::Value origv, mlir::Value newv);
92
93 mlir::Type getShadowType(mlir::Type T) {
94 auto iface = cast<AutoDiffTypeInterface>(T);
95 return iface.getShadowType(width);
96 }
97
98 static llvm::SmallVector<mlir::Value, 1>
99 reindex_arguments(llvm::ArrayRef<mlir::Value> vals,
100 mlir::OperandRange range) {
101 llvm::SmallVector<mlir::Value, 1> results;
102 for (size_t i = 0; i < range.size(); i++) {
103 results.push_back(vals[range.getBeginOperandIndex() + i]);
105 return results;
106 }
107};
110protected:
111 IRMapping differentials;
114 SmallVector<std::function<Value(Location, Type)>> gradientCreatorHook;
115
116public:
117 void registerGradientCreatorHook(std::function<Value(Location, Type)> hook);
118 void deregisterGradientCreatorHook(std::function<Value(Location, Type)> hook);
119 Value getNewGradient(Location loc, Type t);
120
121 mlir::Value getDifferential(mlir::Value origv);
122
123 void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder);
124
125 void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder);
126
127 mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder);
129 MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
130 FunctionOpInterface oldFunc_, MTypeAnalysis &TA,
131 MTypeResults TR, IRMapping &invertedPointers_,
132 const llvm::ArrayRef<bool> returnPrimals,
133 const llvm::ArrayRef<bool> returnShadows,
134 const SmallPtrSetImpl<mlir::Value> &constantvalues_,
135 const SmallPtrSetImpl<mlir::Value> &activevals_,
136 ArrayRef<DIFFE_TYPE> RetActivity,
137 ArrayRef<DIFFE_TYPE> ArgActivity, IRMapping &origToNew_,
138 std::map<Operation *, Operation *> &origToNewOps_,
139 DerivativeMode mode, unsigned width, bool omp,
140 llvm::StringRef postpasses, bool verifyPostPasses,
142 : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_,
143 returnPrimals, returnShadows, constantvalues_,
144 activevals_, RetActivity, ArgActivity, origToNew_,
145 origToNewOps_, mode, width, omp, postpasses,
147 initializationBlock(&*(newFunc.getFunctionBody().begin())) {}
149 // Technically diffe constructor
152 FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
153 const llvm::ArrayRef<bool> returnPrimals,
154 const llvm::ArrayRef<bool> returnShadows,
155 ArrayRef<DIFFE_TYPE> RetActivity, ArrayRef<DIFFE_TYPE> ArgActivity,
156 mlir::Type additionalArg, bool omp, llvm::StringRef postpasses,
158 std::string prefix;
159
160 switch (mode) {
164 prefix = "fwddiffe";
165 break;
168 prefix = "diffe";
169 break;
171 llvm_unreachable("invalid DerivativeMode: ReverseModePrimal\n");
173
174 if (width > 1)
175 prefix += std::to_string(width);
177 IRMapping originalToNew;
178 std::map<Operation *, Operation *> originalToNewOps;
180 SmallPtrSet<mlir::Value, 1> returnvals;
181 SmallPtrSet<mlir::Value, 1> constant_values;
182 SmallPtrSet<mlir::Value, 1> nonconstant_values;
183 IRMapping invertedPointers;
184 FunctionOpInterface newFunc = CloneFunctionWithReturns(
185 mode, width, todiff, invertedPointers, ArgActivity, constant_values,
186 nonconstant_values, returnvals, returnPrimals, returnShadows,
187 RetActivity, prefix + todiff.getName(), originalToNew, originalToNewOps,
188 additionalArg);
189 MTypeResults TR; // TODO
190 return new MDiffeGradientUtils(
192 returnShadows, constant_values, nonconstant_values, RetActivity,
193 ArgActivity, originalToNew, originalToNewOps, mode, width, omp,
196};
197
198}; // namespace enzyme
199}; // namespace mlir
FunctionOpInterface CloneFunctionWithReturns(DerivativeMode mode, unsigned width, FunctionOpInterface F, IRMapping &ptrInputs, ArrayRef< DIFFE_TYPE > ArgActivity, SmallPtrSetImpl< mlir::Value > &constants, SmallPtrSetImpl< mlir::Value > &nonconstants, SmallPtrSetImpl< mlir::Value > &returnvals, const std::vector< bool > &returnPrimals, const std::vector< bool > &returnShadows, ArrayRef< DIFFE_TYPE > RetActivity, Twine name, IRMapping &VMap, std::map< Operation *, Operation * > &OpMap, mlir::Type additionalArg)
DerivativeMode
Definition Utils.h:390
SmallVector< std::function< Value(Location, Type)> > gradientCreatorHook
void registerGradientCreatorHook(std::function< Value(Location, Type)> hook)
mlir::Value getDifferential(mlir::Value origv)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
void deregisterGradientCreatorHook(std::function< Value(Location, Type)> hook)
static MDiffeGradientUtils * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, ArrayRef< DIFFE_TYPE > RetActivity, ArrayRef< DIFFE_TYPE > ArgActivity, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
Value getNewGradient(Location loc, Type t)
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder)
MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA, MTypeResults TR, IRMapping &invertedPointers_, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, const SmallPtrSetImpl< mlir::Value > &constantvalues_, const SmallPtrSetImpl< mlir::Value > &activevals_, ArrayRef< DIFFE_TYPE > RetActivity, ArrayRef< DIFFE_TYPE > ArgActivity, IRMapping &origToNew_, std::map< Operation *, Operation * > &origToNewOps_, DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
Operation * cloneWithNewOperands(OpBuilder &B, Operation *op)
ArrayRef< DIFFE_TYPE > RetDiffeTypes
static llvm::SmallVector< mlir::Value, 1 > reindex_arguments(llvm::ArrayRef< mlir::Value > vals, mlir::OperandRange range)
const llvm::ArrayRef< bool > returnShadows
mlir::Type getShadowType(mlir::Type T)
const llvm::ArrayRef< bool > returnPrimals
FunctionOpInterface oldFunc
void replaceOrigOpWith(Operation *op, ValueRange vals)
std::map< Operation *, Operation * > originalToNewFnOps
DenseMap< Operation *, bool > readOnlyCache
ArrayRef< DIFFE_TYPE > ArgDiffeTypes
LogicalResult visitChild(Operation *op)
void erase(Operation *op)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
void eraseIfUnused(Operation *op, bool erase=true, bool check=true)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
MGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, MTypeResults TR_, IRMapping &invertedPointers_, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, const SmallPtrSetImpl< mlir::Value > &constantvalues_, const SmallPtrSetImpl< mlir::Value > &activevals_, ArrayRef< DIFFE_TYPE > ReturnActivities, ArrayRef< DIFFE_TYPE > ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map< Operation *, Operation * > &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
std::unique_ptr< enzyme::ActivityAnalyzer > activityAnalyzer
FunctionOpInterface newFunc
bool isConstantInstruction(mlir::Operation *v) const
void setInvertedPointer(mlir::Value origv, mlir::Value newv)
bool isConstantValue(mlir::Value v) const
SmallPtrSet< Block *, 4 > blocksNotForAnalysis