Enzyme main
Loading...
Searching...
No Matches
EnzymeLogic.h
Go to the documentation of this file.
1#pragma once
2
3#include "mlir/IR/IRMapping.h"
4#include "mlir/Interfaces/FunctionInterfaces.h"
5
7#include "../../Utils.h"
8#include <functional>
9
10namespace mlir {
11namespace enzyme {
12
13typedef void(buildReturnFunction)(OpBuilder &, mlir::Block *);
14
16
18public:
19 inline bool operator<(const MFnTypeInfo &rhs) const { return false; }
20};
21
23public:
24 MFnTypeInfo getAnalyzedTypeInfo(FunctionOpInterface op) const {
25 return MFnTypeInfo();
26 }
27};
28
30public:
31 // TODO
33 TypeTree query(Value) const { return TypeTree(); }
34 ConcreteType intType(size_t num, Value val, bool errIfNotFound = true,
35 bool pointerIntSame = false) const {
36 if (isa<IntegerType, IndexType>(val.getType())) {
37 return BaseType::Integer;
38 }
39 if (errIfNotFound) {
40 llvm_unreachable("something happened");
41 }
42 return BaseType::Unknown;
43 }
44};
45
47public:
49 FunctionOpInterface todiff;
50 const std::vector<DIFFE_TYPE> retType;
51 const std::vector<DIFFE_TYPE> constant_args;
52 // std::map<llvm::Argument *, bool> uncacheable_args;
53 std::vector<bool> returnUsed;
55 unsigned width;
56 mlir::Type additionalType;
58 bool omp;
61 inline bool operator<(const MForwardCacheKey &rhs) const {
62 if (todiff < rhs.todiff)
63 return true;
64 if (rhs.todiff < todiff)
65 return false;
66
67 if (retType < rhs.retType)
68 return true;
69 if (rhs.retType < retType)
70 return false;
71
72 if (std::lexicographical_compare(
73 constant_args.begin(), constant_args.end(),
74 rhs.constant_args.begin(), rhs.constant_args.end()))
75 return true;
76 if (std::lexicographical_compare(
77 rhs.constant_args.begin(), rhs.constant_args.end(),
78 constant_args.begin(), constant_args.end()))
79 return false;
81 if (returnUsed < rhs.returnUsed)
82 return true;
83 if (rhs.returnUsed < returnUsed)
84 return false;
85
86 if (mode < rhs.mode)
87 return true;
88 if (rhs.mode < mode)
89 return false;
90
91 if (width < rhs.width)
92 return true;
93 if (rhs.width < width)
94 return false;
96 if (additionalType.getImpl() < rhs.additionalType.getImpl())
97 return true;
98 if (rhs.additionalType.getImpl() < additionalType.getImpl())
99 return false;
101 if (typeInfo < rhs.typeInfo)
102 return true;
103 if (rhs.typeInfo < typeInfo)
104 return false;
105
106 if (omp < rhs.omp)
107 return true;
108 if (rhs.omp < omp)
109 return false;
110
111 if (strongZero < rhs.strongZero)
112 return true;
114 return false;
115
116 // equal
117 return false;
119 };
122 FunctionOpInterface todiff;
123 const std::vector<DIFFE_TYPE> retActivity;
124 const std::vector<DIFFE_TYPE> argActivity;
125 const std::vector<bool> returnPrimals;
126 const std::vector<bool> returnShadows;
129 unsigned width;
130 mlir::Type additionalType;
132 const std::vector<bool> volatileArgs;
133 bool omp;
135
136 inline bool operator<(const MReverseCacheKey &rhs) const {
137 if (todiff < rhs.todiff)
138 return true;
139 if (rhs.todiff < todiff)
140 return false;
141
142 if (std::lexicographical_compare(retActivity.begin(), retActivity.end(),
143 rhs.retActivity.begin(),
144 rhs.retActivity.end()))
145 return true;
146 if (std::lexicographical_compare(rhs.retActivity.begin(),
147 rhs.retActivity.end(),
148 retActivity.begin(), retActivity.end()))
149 return false;
150
151 if (std::lexicographical_compare(argActivity.begin(), argActivity.end(),
152 rhs.argActivity.begin(),
153 rhs.argActivity.end()))
154 return true;
155 if (std::lexicographical_compare(rhs.argActivity.begin(),
156 rhs.argActivity.end(),
157 argActivity.begin(), argActivity.end()))
158 return false;
159
161 return true;
163 return false;
166 return true;
168 return false;
170 if (mode < rhs.mode)
171 return true;
172 if (rhs.mode < mode)
173 return false;
176 return true;
177 if (rhs.freeMemory < freeMemory)
178 return false;
179
180 if (width < rhs.width)
181 return true;
182 if (rhs.width < width)
183 return false;
184
185 if (additionalType.getImpl() < rhs.additionalType.getImpl())
186 return true;
187 if (rhs.additionalType.getImpl() < additionalType.getImpl())
188 return false;
189
190 if (typeInfo < rhs.typeInfo)
191 return true;
192 if (rhs.typeInfo < typeInfo)
193 return false;
194
195 if (volatileArgs < rhs.volatileArgs)
196 return true;
197 if (rhs.volatileArgs < volatileArgs)
198 return false;
199
200 if (omp < rhs.omp)
201 return true;
202 if (rhs.omp < omp)
203 return false;
204
205 if (strongZero < rhs.strongZero)
206 return true;
207 if (rhs.strongZero < strongZero)
208 return false;
209
210 // equal
211 return false;
213 };
214
215 std::map<MForwardCacheKey, FunctionOpInterface> ForwardCachedFunctions;
216 std::map<MReverseCacheKey, FunctionOpInterface> ReverseCachedFunctions;
217
218 FunctionOpInterface
219 CreateForwardDiff(FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
220 std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
221 std::vector<bool> returnPrimals, DerivativeMode mode,
222 bool freeMemory, size_t width, mlir::Type addedType,
223 MFnTypeInfo type_args, std::vector<bool> volatile_args,
224 void *augmented, bool omp, llvm::StringRef postpasses,
225 bool verifyPostPasses, bool strongZero);
226
227 FunctionOpInterface
228 CreateReverseDiff(FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
229 std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
230 std::vector<bool> returnPrimals,
231 std::vector<bool> returnShadows, DerivativeMode mode,
232 bool freeMemory, size_t width, mlir::Type addedType,
233 MFnTypeInfo type_args, std::vector<bool> volatile_args,
234 void *augmented, bool omp, llvm::StringRef postpasses,
235 bool verifyPostPasses, bool strongZero);
236
237 void
238 initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,
239 MGradientUtilsReverse *gutils);
240 void
241 handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB,
242 MGradientUtilsReverse *gutils,
243 llvm::function_ref<buildReturnFunction> buildReturnOp);
244 LogicalResult visitChildren(Block *oBB, Block *reverseBB,
245 MGradientUtilsReverse *gutils);
246 LogicalResult visitChild(Operation *op, OpBuilder &builder,
247 MGradientUtilsReverse *gutils);
248 void mapInvertArguments(Block *oBB, Block *reverseBB,
249 MGradientUtilsReverse *gutils);
250 LogicalResult
251 differentiate(MGradientUtilsReverse *gutils, Region &oldRegion,
252 Region &newRegion,
253 llvm::function_ref<buildReturnFunction> buildFuncRetrunOp,
254 std::function<std::pair<Value, Value>(Type)> cacheCreator);
255};
256
257} // Namespace enzyme
258} // Namespace mlir
DerivativeMode
Definition Utils.h:390
Concrete SubType of a given value.
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Definition TypeTree.h:72
void initializeShadowValues(SmallVector< mlir::Block * > &dominatorToposortBlocks, MGradientUtilsReverse *gutils)
void mapInvertArguments(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils)
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, std::vector< bool > returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
LogicalResult visitChildren(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils)
FunctionOpInterface CreateForwardDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
LogicalResult visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
std::map< MReverseCacheKey, FunctionOpInterface > ReverseCachedFunctions
void handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils, llvm::function_ref< buildReturnFunction > buildReturnOp)
std::map< MForwardCacheKey, FunctionOpInterface > ForwardCachedFunctions
LogicalResult differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, llvm::function_ref< buildReturnFunction > buildFuncRetrunOp, std::function< std::pair< Value, Value >(Type)> cacheCreator)
bool operator<(const MFnTypeInfo &rhs) const
Definition EnzymeLogic.h:19
MFnTypeInfo getAnalyzedTypeInfo(FunctionOpInterface op) const
Definition EnzymeLogic.h:24
TypeTree query(Value) const
Definition EnzymeLogic.h:33
ConcreteType intType(size_t num, Value val, bool errIfNotFound=true, bool pointerIntSame=false) const
Definition EnzymeLogic.h:34
void buildReturnFunction(OpBuilder &, mlir::Block *)
Definition EnzymeLogic.h:13
const std::vector< DIFFE_TYPE > retType
Definition EnzymeLogic.h:50
const std::vector< DIFFE_TYPE > constant_args
Definition EnzymeLogic.h:51
bool operator<(const MForwardCacheKey &rhs) const
Definition EnzymeLogic.h:61
const std::vector< DIFFE_TYPE > argActivity
const std::vector< DIFFE_TYPE > retActivity
bool operator<(const MReverseCacheKey &rhs) const