Enzyme main
Loading...
Searching...
No Matches
EnzymeBatchDiffPass.h
Go to the documentation of this file.
1#ifndef ENZYME_BATCH_DIFF_PASS_H
2#define ENZYME_BATCH_DIFF_PASS_H
3
4#include "Dialect/Ops.h"
6#include "Interfaces/Utils.h"
7#include "PassDetails.h"
8#include "Passes/Passes.h"
9
10#include "mlir/Analysis/AliasAnalysis.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/Interfaces/FunctionInterfaces.h"
14#include <cstdint>
15
16namespace mlir {
17namespace enzyme {
18namespace batchutils {
19
21 FunctionOpInterface function;
22 SmallVector<mlir::Value> inputs;
23 SmallVector<enzyme::Activity> inActivity;
24 SmallVector<enzyme::Activity> retActivity;
25 Block *blk;
26
27 // for use in std::map:
28 bool operator<(const BatchDiffCacheKey &other) const {
29 auto lhs_name = const_cast<FunctionOpInterface &>(function).getName();
30 auto rhs_name = const_cast<FunctionOpInterface &>(other.function).getName();
31
32 if (lhs_name < rhs_name)
33 return true;
34 if (rhs_name < lhs_name)
35 return false;
36 if (inputs.size() < other.inputs.size())
37 return true;
38 if (other.inputs.size() < inputs.size())
39 return false;
40
41 // Sizes are equal, so compare elements
42 for (size_t i = 0; i < inputs.size(); ++i) {
43 auto lhs_ptr = inputs[i].getAsOpaquePointer();
44 auto rhs_ptr = other.inputs[i].getAsOpaquePointer();
45 if (lhs_ptr < rhs_ptr)
46 return true;
47 if (rhs_ptr < lhs_ptr)
48 return false;
49 }
50
51 if (inActivity < other.inActivity)
52 return true;
53 if (other.inActivity < inActivity)
54 return false;
55 if (retActivity < other.retActivity)
56 return true;
57 if (other.retActivity < retActivity)
58 return false;
59
60 return blk < other.blk;
61 }
62};
63
64template <typename SourceOp>
65BatchDiffCacheKey createDiffCacheKey(SourceOp uop, FunctionOpInterface fn) {
66 // extract in_activity, ret_activity, in_args
67 SmallVector<Activity> inActivity;
68 SmallVector<Activity> retActivity;
69 SmallVector<Value> in_args;
70
71 auto in_idx = 0;
72
73 for (auto [idx, act] : llvm::enumerate(uop.getActivity())) {
74 auto iattr = cast<ActivityAttr>(act);
75 auto val = iattr.getValue();
76 inActivity.push_back(val);
77
78 in_args.push_back(uop.getInputs()[in_idx]);
79 ++in_idx;
80
81 if (val == Activity::enzyme_dup || val == Activity::enzyme_dupnoneed) {
82 ++in_idx;
83 }
84 }
85
86 for (auto [idx, ract] : llvm::enumerate(uop.getRetActivity())) {
87 auto iattr = cast<ActivityAttr>(ract);
88 auto val = iattr.getValue();
89 retActivity.push_back(val);
90 }
91
92 batchutils::BatchDiffCacheKey key{fn, in_args, inActivity, retActivity,
93 uop->getBlock()};
94 return key;
95}
96
97template <typename SourceOp,
98 std::enable_if_t<
99 llvm::is_one_of<SourceOp, ForwardDiffOp, AutoDiffOp>::value,
100 bool> = true>
101SmallVector<MemoryEffects::EffectInstance> findCallerEffects(
102 SourceOp callerOp, FunctionOpInterface innerFnOp,
103 const SmallVector<MemoryEffects::EffectInstance> &innerEffects) {
104 SmallVector<MemoryEffects::EffectInstance> outerEffects;
105 for (auto &eff : innerEffects) {
106
107 Value effVal = eff.getValue();
108 if (!effVal) {
109 // unknown effect which isn't tied to a value, just add to result
110 outerEffects.push_back(eff);
111 continue;
112 }
113
114 // Find primal argument corresponding to effect value
115 size_t primalArgPos = 0;
116 bool foundPrimal = false;
117 if (auto effBA = dyn_cast<BlockArgument>(effVal)) {
118 if (llvm::is_contained(innerFnOp.getArguments(), effBA)) {
119 foundPrimal = true;
120 primalArgPos = effBA.getArgNumber();
121 }
122 }
123
124 if (!foundPrimal) {
125 // TODO: Handle this either as a global value, or a value which
126 // is inside of the MLIR function(for inter-proc alias analysis) -
127 // Just skip for now, since we don't have interprocedural alias-analysis
128 // implemented yet.
129 continue;
130 }
131
132 // Add primal effects to caller effect map for all ops
133 Value primalVal = callerOp.getPrimalInputs()[primalArgPos];
134 outerEffects.push_back(
135 oputils::getEffectOfVal(primalVal, eff.getEffect(), eff.getResource()));
136
137 // Add derivative effects(only if primal arg is dup)
138 // read(primal) -> read(derivative)
139 // write(primal) -> write(derivative)
140
141 // find position of dup arg for primal
142 bool primalIsDup =
143 (cast<ActivityAttr>(callerOp.getActivity()[primalArgPos]).getValue() ==
144 Activity::enzyme_dup) ||
145 (cast<ActivityAttr>(callerOp.getActivity()[primalArgPos]).getValue() ==
146 Activity::enzyme_dupnoneed);
147
148 if (primalIsDup) {
149 auto gradArgPos = 0;
150 for (auto [idx, act] : llvm::enumerate(callerOp.getActivity())) {
151 auto iattr = cast<ActivityAttr>(act);
152 auto act_val = iattr.getValue();
153 ++gradArgPos;
154
155 if (idx == primalArgPos)
156 break;
157
158 if (act_val == Activity::enzyme_dup ||
159 act_val == Activity::enzyme_dupnoneed) {
160 ++gradArgPos;
161 }
162 }
163
164 Value dVal = callerOp.getInputs()[gradArgPos];
165 // specialze effects based on callerOp type
166 if constexpr (std::is_same_v<SourceOp, ForwardDiffOp>) {
167 outerEffects.push_back(
168 oputils::getEffectOfVal(dVal, eff.getEffect(), eff.getResource()));
169 } else {
170 outerEffects.push_back(oputils::getEffectOfVal(
171 dVal, MemoryEffects::Write::get(), eff.getResource()));
172 outerEffects.push_back(oputils::getEffectOfVal(
173 dVal, MemoryEffects::Read::get(), eff.getResource()));
174 }
175 }
176 }
177 return outerEffects;
178}
179
180template <typename SourceOp,
181 std::enable_if_t<
182 llvm::is_one_of<SourceOp, ForwardDiffOp, AutoDiffOp>::value,
183 bool> = true>
184llvm::SmallVector<SourceOp, 2> pruneGradDefs(BatchDiffCacheKey &key,
185 SmallVector<SourceOp> &allDiffs) {
186 SmallVector<SourceOp, 2> prunedSources;
187
188 // We first prune and check that all derivative arguments are defined before
189 // the first diff in the same block. (ops in allDiffs are guaranteed to belong
190 // to the same basic block)
191 auto firstDiffOp = allDiffs[0];
192 for (auto uop : allDiffs) {
193
194 auto diffArgs = uop.getShadows();
195 if constexpr (std::is_same_v<SourceOp, AutoDiffOp>) {
196 auto diffeRet = uop.getDifferentialReturns();
197 diffArgs.append(diffeRet.begin(), diffeRet.end());
198 }
199
200 bool definedBeforeFirst = true;
201
202 for (auto diffVal : diffArgs) {
203 if (auto diffValOR = dyn_cast<OpResult>(diffVal)) {
204 // check that defining op appears before the current op
205 auto parentOp = diffValOR.getOwner();
206 if (!parentOp->isBeforeInBlock(firstDiffOp.getOperation())) {
207 definedBeforeFirst = false;
208 break;
209 }
210 }
211 }
212
213 if (definedBeforeFirst) {
214 prunedSources.push_back(uop);
215 }
216 }
217
218 return prunedSources;
219}
220
221template <typename SourceOp,
222 std::enable_if_t<
223 llvm::is_one_of<SourceOp, ForwardDiffOp, AutoDiffOp>::value,
224 bool> = true>
225llvm::SmallVector<SourceOp> pruneMemoryEffects(
226 SymbolTableCollection &symbolTable, BatchDiffCacheKey &key,
227 SmallVector<SourceOp> &prunedSources,
228 DenseMap<SourceOp, SmallVector<MemoryEffects::EffectInstance>>
229 &callerEffectMap,
230 llvm::DenseMap<FunctionOpInterface,
231 SmallVector<MemoryEffects::EffectInstance>>
232 &innerEffectCache) {
233 // Find a mergeable subset of diff operations, which do not violate memory
234 // effects wrt reads and writes. Note that callerEffects only contains the
235 // aliased set of primal effects, so we have to first map these primal effects
236 // to corresponding derivative effects in `prunedSources`
237 // TODO: Also handle global values, and non-primal values inside callerEffects
238 // through inter-procedural alias analysis. Skip for now
239
240 if (callerEffectMap.empty()) {
241 // legal to merge since there is no effect overwrite in mergeable ops
242 return prunedSources;
243 }
244
245 SmallVector<SourceOp> legalMerge;
246 auto lastOp = prunedSources[0];
247
248 SmallVector<MemoryEffects::EffectInstance, 4> betweenEffects;
249 for (auto candidateOp : prunedSources) {
250 // Update betweenEffects to include memory effects from lastOp to
251 // candidateOp
252 for (Operation *curr = lastOp.getOperation();
253 curr != candidateOp.getOperation(); curr = curr->getNextNode()) {
254 auto currSourceOp = dyn_cast<SourceOp>(curr);
255 if (currSourceOp && callerEffectMap.contains(currSourceOp)) {
256 // curr is/was a mergeable candidate, and we would have already computed
257 // its memory effects in the effect map
258 betweenEffects.append(callerEffectMap[currSourceOp]);
259 } else if (auto currFwdOp = dyn_cast<ForwardDiffOp>(curr)) {
260 // curr is a previously un-encountered fwddiff op
261 auto fnOp = dyn_cast_or_null<FunctionOpInterface>(
262 symbolTable.lookupNearestSymbolFrom(currFwdOp,
263 currFwdOp.getFnAttr()));
264 if (!fnOp)
265 continue;
266
267 if (!innerEffectCache.contains(fnOp)) {
268 innerEffectCache[fnOp] = oputils::collectFnEffects(fnOp);
269 }
270
271 // map to outerEffects
272 betweenEffects.append(batchutils::findCallerEffects(
273 currFwdOp, fnOp, innerEffectCache[fnOp]));
274 } else if (auto currBwdOp = dyn_cast<AutoDiffOp>(curr)) {
275 // curr is a previously un-encountered revdiff op
276 auto fnOp = dyn_cast_or_null<FunctionOpInterface>(
277 symbolTable.lookupNearestSymbolFrom(currBwdOp,
278 currBwdOp.getFnAttr()));
279 if (!fnOp)
280 continue;
281
282 if (!innerEffectCache.contains(fnOp)) {
283 innerEffectCache[fnOp] = oputils::collectFnEffects(fnOp);
284 }
285
286 // map to outerEffects
287 betweenEffects.append(batchutils::findCallerEffects(
288 currBwdOp, fnOp, innerEffectCache[fnOp]));
289 } else {
290 // TODO: move forwarddiff and revdiff effect collection specialization
291 // from `findCallerEffects` into collectOpEffects(), accounting for
292 // inter-procedural alias analysis
293 SmallVector<MemoryEffects::EffectInstance> currOpEffects;
294 (void)oputils::collectOpEffects(curr, currOpEffects);
295 betweenEffects.append(currOpEffects);
296 }
297 }
298
299 // Check conflicts between betweenEffects and candidateOp. Since the batched
300 // version essentially "pushes up" the candidateOp, we ideally want to stop
301 // this if it violates the final order of writes to the candidate op owned
302 // value
303 bool foundConflict = false;
304 for (auto candidateEffect : callerEffectMap[candidateOp]) {
305 for (auto prevEffect : betweenEffects) {
306 // We will disable batching any candidiate operation which re-orders the
307 // relative order of writes to the primal and derivative arguments. For
308 // this, we alias the underlying effects in the preceding effects and
309 // the current candidate operation.
310 if ((isa<MemoryEffects::Write>(prevEffect.getEffect()) &&
311 isa<MemoryEffects::Read>(candidateEffect.getEffect())) ||
312 (isa<MemoryEffects::Read>(prevEffect.getEffect()) &&
313 isa<MemoryEffects::Write>(candidateEffect.getEffect())) ||
314 (isa<MemoryEffects::Write>(prevEffect.getEffect()) &&
315 isa<MemoryEffects::Write>(candidateEffect.getEffect()))) {
316
317 // if the effects alias each other, then this is not a candidate for
318 // merging
319 if (oputils::mayAlias(candidateEffect, prevEffect)) {
320 foundConflict = true;
321 break;
322 }
323 }
324 }
325 }
326
327 if (!foundConflict)
328 legalMerge.push_back(candidateOp);
329
330 // mark start of next range
331 lastOp = candidateOp;
332 }
333
334 return legalMerge;
335}
336
337} // namespace batchutils
338} // namespace enzyme
339} // namespace mlir
340
341#endif // ENZYME_BATCH_DIFF_PASS_H
BatchDiffCacheKey createDiffCacheKey(SourceOp uop, FunctionOpInterface fn)
SmallVector< MemoryEffects::EffectInstance > findCallerEffects(SourceOp callerOp, FunctionOpInterface innerFnOp, const SmallVector< MemoryEffects::EffectInstance > &innerEffects)
llvm::SmallVector< SourceOp, 2 > pruneGradDefs(BatchDiffCacheKey &key, SmallVector< SourceOp > &allDiffs)
llvm::SmallVector< SourceOp > pruneMemoryEffects(SymbolTableCollection &symbolTable, BatchDiffCacheKey &key, SmallVector< SourceOp > &prunedSources, DenseMap< SourceOp, SmallVector< MemoryEffects::EffectInstance > > &callerEffectMap, llvm::DenseMap< FunctionOpInterface, SmallVector< MemoryEffects::EffectInstance > > &innerEffectCache)
bool mayAlias(Value v1, Value v2)
Definition Utils.cpp:136
bool collectOpEffects(Operation *rootOp, SmallVector< MemoryEffects::EffectInstance > &effects)
Returns the side effects of an operation(similar to mlir::getEffectsRecursively).
Definition Utils.cpp:297
SmallVector< MemoryEffects::EffectInstance > collectFnEffects(FunctionOpInterface fnOp)
Definition Utils.cpp:409
MemoryEffects::EffectInstance getEffectOfVal(Value val, MemoryEffects::Effect *effect, SideEffects::Resource *resource)
Definition Utils.cpp:422
SmallVector< enzyme::Activity > retActivity
bool operator<(const BatchDiffCacheKey &other) const
SmallVector< enzyme::Activity > inActivity