Enzyme main
Loading...
Searching...
No Matches
RemoveUnusedEnzymeOps.cpp
Go to the documentation of this file.
1//===- RemoveUnusedEnzymeOps.cpp - Remove unnecessary or unused gradient and
2// cache ops
3//------------------ //
4//
5// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6// See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9//===----------------------------------------------------------------------===//
10//
11//===----------------------------------------------------------------------===//
12
13#include "Dialect/Dialect.h"
14#include "Dialect/Ops.h"
16#include "PassDetails.h"
17#include "Passes/Passes.h"
18#include "Passes/RemovalUtils.h"
19#include "mlir/Dialect/Arith/IR/Arith.h"
20#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/MemRef/IR/MemRef.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/Dialect/Tensor/IR/Tensor.h"
25#include "mlir/Transforms/DialectConversion.h"
26
27#include "mlir/Rewrite/PatternApplicator.h"
28#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29
30#include "mlir/IR/Dominance.h"
31#include "llvm/Support/raw_ostream.h"
32
33namespace mlir {
34namespace enzyme {
35#define GEN_PASS_DEF_REMOVEUNUSEDENZYMEOPSPASS
36#include "Passes/Passes.h.inc"
37} // namespace enzyme
38} // namespace mlir
39
40using namespace mlir;
41using namespace enzyme;
42namespace {
43
44// Starting at the beginning of blk, is there a path that can execute
45// check before end.
46bool mayExecuteBefore(Block *blk, Operation *check, Operation *end) {
47 auto reg = blk->getParent();
48 (void)reg;
49 assert(reg->isAncestor(end->getParentRegion()));
50
51 DenseSet<Block *> visitedBlocks;
52
53 SmallVector<Block *> blocksToVisit;
54 for (auto succ : blk->getSuccessors()) {
55 blocksToVisit.push_back(succ);
56 }
57
58 while (!blocksToVisit.empty()) {
59 Block *cur = blocksToVisit.pop_back_val();
60
61 if (visitedBlocks.contains(cur))
62 continue;
63
64 visitedBlocks.insert(cur);
65
66 bool seenEnd = false;
67 for (auto &op : *cur) {
68
69 // If we've seen the thing to check with, it may execute before
70 if (op.isAncestor(check)) {
71 // The sole exception to this is if they are in the same sub region,
72 // which is known to execute only once. TODO this later
73 /*
74 if (op.isAncestor(end)) {
75
76 for (auto reg2 : op.getRegions()) {
77
78 }
79 }
80 */
81
82 return true;
83 }
84
85 // Otherwise if we've seen the end op, this path is over as the route we
86 // found here didn't first find a check.
87 if (op.isAncestor(end)) {
88 seenEnd = true;
89 break;
90 }
91 }
92
93 if (seenEnd)
94 continue;
95
96 // If we didn't find the end, try all successors
97 for (auto succ : cur->getSuccessors()) {
98 blocksToVisit.push_back(succ);
99 }
100 }
101
102 return false;
103}
104
105bool mayExecuteBetween(Operation *start, Operation *check, Operation *end) {
106 Block *blk = start->getBlock();
107 auto checkAnc = check;
108 while (checkAnc && checkAnc->getBlock() != blk)
109 checkAnc = checkAnc->getParentOp();
110
111 auto endAnc = end;
112 while (endAnc && endAnc->getBlock() != blk)
113 endAnc = endAnc->getParentOp();
114
115 // Both operations are either the same, or in a joint parent (which could
116 // rerun)
117 if (checkAnc && checkAnc == endAnc)
118 return true;
119
120 if (checkAnc) {
121 if (checkAnc->isBeforeInBlock(start))
122 checkAnc = nullptr;
123 }
124 if (endAnc) {
125 if (endAnc->isBeforeInBlock(start))
126 endAnc = nullptr;
127 }
128
129 if (checkAnc && endAnc) {
130 if (checkAnc->isBeforeInBlock(endAnc))
131 return true;
132 else
133 return false;
134 }
135
136 if (checkAnc)
137 return true;
138 if (endAnc)
139 return false;
140
141 auto reg = blk->getParent();
142 if (reg->isAncestor(end->getParentRegion())) {
143 return mayExecuteBefore(blk, check, end);
144 }
145
146 // If the check is in the parent op, but the end is not, assume
147 // we may execute that parent op part before going to any later ops
148 if (reg->isAncestor(check->getParentRegion())) {
149 return true;
150 }
151
152 return mayExecuteBetween(start->getParentOp(), check, end);
153}
154
155// TODO this isn't necessarily correct. This is because there could be a
156// non dominating use bewteen the dominating one and the op, causing
157// correctness issues when not seen. In interim, be conservative and only
158// succeed if these have the same parent block, and no other ops in path
159template <class T, class T2 = T>
160T findNearestDominatingOpByUse(Operation *op, Value v) {
161 DominanceInfo dInfo;
162 PostDominanceInfo pdInfo;
163
164 SmallVector<T, 1> options;
165 SmallVector<Operation *, 1> conflicts;
166 for (Operation *userSet : v.getUsers()) {
167 if (auto setOp = dyn_cast<T>(userSet)) {
168 options.push_back(setOp);
169 conflicts.push_back(setOp);
170 continue;
171 }
172 if (auto setOp = dyn_cast<T2>(userSet)) {
173 conflicts.push_back(setOp);
174 continue;
175 }
176 }
177
178 for (auto opt : options) {
179 if (!dInfo.dominates(opt, op))
180 continue;
181 bool conflict = false;
182 for (auto opt2 : conflicts) {
183 if (opt == opt2)
184 continue;
185 if (opt2 == op)
186 continue;
187
188 if (!mayExecuteBetween(opt, opt2, op)) {
189 continue;
190 }
191
192 conflict = true;
193 }
194 if (!conflict) {
195 return opt;
196 }
197 }
198
199 return nullptr;
200}
201
202struct PopSimplify : public OpRewritePattern<enzyme::PopOp> {
203 using OpRewritePattern<enzyme::PopOp>::OpRewritePattern;
204
205 LogicalResult matchAndRewrite(enzyme::PopOp pop,
206 PatternRewriter &rewriter) const final {
207
208 auto init = pop.getCache().getDefiningOp<enzyme::InitOp>();
209 if (!init)
210 return failure();
211
212 SmallVector<enzyme::PopOp, 1> pops;
213 SmallVector<enzyme::PushOp, 1> pushes;
214 for (Operation *userSet : init.getResult().getUsers()) {
215 if (auto push = dyn_cast<enzyme::PushOp>(userSet)) {
216 pushes.push_back(push);
217 continue;
218 }
219 if (auto pop = dyn_cast<enzyme::PopOp>(userSet)) {
220 pops.push_back(pop);
221 continue;
222 }
223 return failure();
224 }
225
226 if (auto push = findNearestDominatingOpByUse<enzyme::PushOp, enzyme::PopOp>(
227 pop, init)) {
228 // Do the block check to conservatively avoid multi execute push/pop
229 if (pop->getBlock() == push->getBlock()) {
230 rewriter.replaceOp(pop, push.getValue());
231 rewriter.eraseOp(push);
232 return success();
233 }
234 }
235
236 return failure();
237 }
238};
239
240struct GetSimplify : public OpRewritePattern<enzyme::GetOp> {
241 using OpRewritePattern<enzyme::GetOp>::OpRewritePattern;
242
243 LogicalResult matchAndRewrite(enzyme::GetOp get,
244 PatternRewriter &rewriter) const final {
245
246 auto init = get.getGradient().getDefiningOp<enzyme::InitOp>();
247 if (!init)
248 return failure();
249
250 for (Operation *userSet : init.getResult().getUsers()) {
251 if (isa<enzyme::GetOp>(userSet))
252 continue;
253 if (isa<enzyme::SetOp>(userSet))
254 continue;
255 return failure();
256 }
257
258 if (auto set = findNearestDominatingOpByUse<enzyme::SetOp>(get, init)) {
259 rewriter.replaceOp(get, set.getValue());
260 return success();
261 }
262 return failure();
263 }
264};
265
266struct SetSimplify : public OpRewritePattern<enzyme::SetOp> {
267 using OpRewritePattern<enzyme::SetOp>::OpRewritePattern;
268
269 LogicalResult matchAndRewrite(enzyme::SetOp get,
270 PatternRewriter &rewriter) const final {
271
272 auto init = get.getGradient().getDefiningOp<enzyme::InitOp>();
273 if (!init)
274 return failure();
275
276 for (Operation *userSet : init.getResult().getUsers()) {
277 if (isa<enzyme::SetOp>(userSet))
278 continue;
279 return failure();
280 }
281
282 rewriter.eraseOp(get);
283 return success();
284 }
285};
286
287struct PushSimplify : public OpRewritePattern<enzyme::PushOp> {
288 using OpRewritePattern<enzyme::PushOp>::OpRewritePattern;
289
290 LogicalResult matchAndRewrite(enzyme::PushOp get,
291 PatternRewriter &rewriter) const final {
292
293 auto init = get.getCache().getDefiningOp<enzyme::InitOp>();
294 if (!init)
295 return failure();
296
297 for (Operation *userSet : init.getResult().getUsers()) {
298 if (isa<enzyme::PushOp>(userSet))
299 continue;
300 return failure();
301 }
302
303 rewriter.eraseOp(get);
304 return success();
305 }
306};
307
308struct InitSimplify : public OpRewritePattern<enzyme::InitOp> {
309 using OpRewritePattern<enzyme::InitOp>::OpRewritePattern;
310
311 LogicalResult matchAndRewrite(enzyme::InitOp init,
312 PatternRewriter &rewriter) const final {
313 if (init.use_empty()) {
314 rewriter.eraseOp(init);
315 return success();
316 }
317 return failure();
318 }
319};
320
321struct IgnoreDerivativesSimplifyPattern
322 : public OpRewritePattern<enzyme::IgnoreDerivativesOp> {
323 using OpRewritePattern<enzyme::IgnoreDerivativesOp>::OpRewritePattern;
324
325 LogicalResult matchAndRewrite(enzyme::IgnoreDerivativesOp op,
326 PatternRewriter &rewriter) const override {
327 rewriter.replaceOp(op, op.getOperand());
328 return success();
329 }
330};
331
332static void applyPatterns(Operation *op) {
333 RewritePatternSet patterns(op->getContext());
334 patterns.insert<PopSimplify, GetSimplify, PushSimplify, SetSimplify,
335 InitSimplify, IgnoreDerivativesSimplifyPattern>(
336 op->getContext());
337
338 GreedyRewriteConfig config;
339 config.enableFolding();
340 (void)applyPatternsGreedily(op, std::move(patterns), config);
341}
342
343static void annotateRegionOpsInLoops(Operation *op) {
344 // When we have non-looping region branch ops (e.g. scf.if) inside of a loop,
345 // we want the pushes/pops to be removed by the outer loop remover, not the
346 // inner op remover. This helps mincut reduce the overall caching overhead.
347 op->walk([](LoopLikeOpInterface loop) {
348 loop->walk([](RegionBranchOpInterface regionBranch) {
349 if (!regionBranch.hasLoop()) {
350 regionBranch->setAttr(kPreserveCacheAttrName,
351 UnitAttr::get(regionBranch.getContext()));
352 }
353 });
354 });
355}
356
357// A worklist that supports removing operations
358// original implementation is from
359// https://github.com/llvm/llvm-project/blob/9d8d538e40ef040cb53e8db7a32f3024865187f3/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp#L198
360class Worklist {
361public:
362 Worklist() { list.reserve(8); }
363
364 bool empty();
365 void push(Operation *op);
366 void remove(Operation *op);
367 Operation *pop();
368 void reverse();
369
370private:
371 std::vector<Operation *> list;
372 llvm::DenseMap<Operation *, unsigned> map;
373};
374
375bool Worklist::empty() {
376 // Skip all nullptr.
377 return !llvm::any_of(list,
378 [](Operation *op) { return static_cast<bool>(op); });
379}
380
381void Worklist::push(Operation *op) {
382 assert(op && "cannot push nullptr to worklist");
383 // Check to see if the worklist already contains this op.
384 if (!map.insert({op, list.size()}).second)
385 return;
386 list.push_back(op);
387}
388
389void Worklist::reverse() {
390 std::reverse(list.begin(), list.end());
391 for (size_t i = 0, e = list.size(); i < e; ++i)
392 map[list[i]] = i;
393}
394
395Operation *Worklist::pop() {
396 // Skip and remove all trailing nullptr.
397 while (!list.back())
398 list.pop_back();
399 Operation *op = list.back();
400 list.pop_back();
401 map.erase(op);
402 // Cleanup: Remove all trailing nullptr.
403 while (!list.empty() && !list.back())
404 list.pop_back();
405 return op;
406}
407
408void Worklist::remove(Operation *op) {
409 assert(op && "cannot remove nullptr from worklist");
410 auto it = map.find(op);
411 if (it != map.end()) {
412 assert(list[it->second] == op && "malformed worklist data structure");
413 list[it->second] = nullptr;
414 map.erase(it);
415 }
416}
417
418// Drives Enzyme ops removal with the following goals:
419// * Each EnzymeOpsRemoverOpInterface should be processed once.
420// * Inserted ops next in the post order should still be run.
421class PostOrderWalkDriver : public RewriterBase::Listener {
422public:
423 PostOrderWalkDriver(Operation *root_) : root(root_) {}
424
425 void initializeWorklist();
426 LogicalResult processWorklist();
427
428protected:
429 void notifyOperationInserted(Operation *op,
430 OpBuilder::InsertPoint previous) override;
431 void notifyOperationErased(Operation *op) override;
432
433 void
434 notifyMatchFailure(Location loc,
435 function_ref<void(Diagnostic &)> reasonCallback) override;
436
437private:
438 void addToWorklist(Operation *op);
439
440 Worklist worklist;
441
442 Operation *current = nullptr;
443 Operation *root;
444};
445
446void PostOrderWalkDriver::addToWorklist(Operation *op) {
447 // This driver only processes EnzymeOpsRemoverOpInterface ops.
448 if (!isa<EnzymeOpsRemoverOpInterface>(op))
449 return;
450 worklist.push(op);
451}
452
453void PostOrderWalkDriver::notifyOperationInserted(
454 Operation *op, OpBuilder::InsertPoint previous) {
455 if (!isa<EnzymeOpsRemoverOpInterface>(op))
456 return;
457
458 if (!current) {
459 addToWorklist(op);
460 return;
461 }
462
463 // Check if the inserted op would be next in the post order or not compared to
464 // the current operation.
465 bool shouldInsert = false;
466 (void)root->walk([&](EnzymeOpsRemoverOpInterface iface) {
467 if ((Operation *)iface == current) {
468 shouldInsert = true;
469 return WalkResult::interrupt();
470 }
471
472 if ((Operation *)iface == op) {
473 shouldInsert = false;
474 return WalkResult::interrupt();
475 }
476
477 return WalkResult::advance();
478 });
479
480 if (shouldInsert)
481 addToWorklist(op);
482}
483
484void PostOrderWalkDriver::notifyOperationErased(Operation *op) {
485 if (op == current) {
486 current = nullptr;
487 }
488 if (!isa<EnzymeOpsRemoverOpInterface>(op))
489 return;
490 worklist.remove(op);
491}
492
493void PostOrderWalkDriver::notifyMatchFailure(
494 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
495 auto diag = mlir::emitError(loc);
496 reasonCallback(*diag.getUnderlyingDiagnostic());
497}
498
499void PostOrderWalkDriver::initializeWorklist() {
500 root->walk<WalkOrder::PreOrder>(
501 [this](EnzymeOpsRemoverOpInterface iface) { addToWorklist(iface); });
502}
503
504LogicalResult PostOrderWalkDriver::processWorklist() {
505 PatternRewriter rewriter(root->getContext());
506 rewriter.setListener(this);
507
508 bool result = true;
509 while (!worklist.empty()) {
510 auto op = worklist.pop();
511 auto iface = cast<EnzymeOpsRemoverOpInterface>(op);
512 current = op;
513 rewriter.setInsertionPoint(current);
514 result &= iface.removeEnzymeOps(rewriter).succeeded();
515 current = nullptr;
516 }
517
518 return LogicalResult::success(result);
519}
520
521struct RemoveUnusedEnzymeOpsPass
522 : public enzyme::impl::RemoveUnusedEnzymeOpsPassBase<
523 RemoveUnusedEnzymeOpsPass> {
524 void runOnOperation() override {
525 auto op = getOperation();
526
527 applyPatterns(op);
528
529 annotateRegionOpsInLoops(op);
530 bool failed = false;
531 op->walk([&](FunctionOpInterface func) {
532 PostOrderWalkDriver driver(func);
533 driver.initializeWorklist();
534 failed |= driver.processWorklist().failed();
535 });
536
537 if (failed) {
538 signalPassFailure();
539 return;
540 }
541
542 applyPatterns(op);
543 }
544};
545
546} // end anonymous namespace
static constexpr llvm::StringLiteral kPreserveCacheAttrName