Enzyme main
Loading...
Searching...
No Matches
RemovalUtils.cpp
Go to the documentation of this file.
1//===- RemovalUtils.cpp - Utilities to remove Enzyme ops -------* C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "RemovalUtils.h"
13#include "Utils.h"
14#include "mlir/Analysis/TopologicalSortUtils.h"
15#include "mlir/IR/PatternMatch.h"
16#include <cassert>
17#include <deque>
18
19#include "llvm/ADT/MapVector.h"
20
21using namespace mlir;
22using namespace mlir::enzyme;
23
24#define DEBUG_TYPE "enzyme-mincut"
25
26static llvm::cl::opt<bool>
27 DebugGraphviz("mincut-print-graphviz", llvm::cl::init(false),
28 llvm::cl::Hidden,
29 llvm::cl::desc("Use with DEBUG_TYPE 'enzyme-mincut' to print "
30 "the mincut graphs in GraphViz"));
31
/// Localize gradient storage for values used by the forward block `fwd`.
///
/// For every non-constant value (block arguments and op results of `fwd`)
/// whose type is an immutable AutoDiffTypeInterface, this moves the gradient
/// cell's zero-initializing `enzyme.set` next to the newly created null value
/// and hoists the `enzyme.init` defining the gradient cell to the start of the
/// builder's current block, so the cell dominates its uses locally.
///
/// NOTE(review): the `gutils` parameter line was reconstructed from the
/// declaration `localizeGradients(OpBuilder &, MGradientUtilsReverse *,
/// Block *)` — confirm against the header.
void mlir::enzyme::localizeGradients(OpBuilder &builder,
                                     MGradientUtilsReverse *gutils,
                                     Block *fwd) {
  Operation *parent = fwd->getParentOp();

  auto localizeGradientValue = [&](Value val) {
    // Constants carry no gradient; nothing to localize.
    if (gutils->isConstantValue(val))
      return;
    auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType());
    if (iface && !iface.isMutable()) {
      auto grad = gutils->getDifferential(val);

      // Locate the single `enzyme.set` user of the gradient that lives
      // outside of `parent` (asserts there is exactly one, and that it is a
      // SetOp).
      enzyme::SetOp initialSet = nullptr;
      for (auto user : grad.getUsers()) {
        if (!parent->isProperAncestor(user)) {
          assert(!initialSet);
          initialSet = dyn_cast<enzyme::SetOp>(user);
          assert(initialSet);
        }
      }

      auto initOp = grad.getDefiningOp<enzyme::InitOp>();

      {
        // Re-create the zero write right after the null value is
        // materialized, then drop the old out-of-region set.
        OpBuilder::InsertionGuard g(builder);
        Value zero =
            iface.createNullValue(builder, initialSet.getValue().getLoc());
        builder.setInsertionPointAfter(zero.getDefiningOp());
        enzyme::SetOp::create(builder, initialSet.getLoc(), grad, zero);
        initialSet->erase();
      }

      // Hoist the `enzyme.init` defining the gradient cell to the start of
      // the builder's current block so it dominates all remaining uses.
      builder.setInsertionPointToStart(builder.getBlock());
      initOp->remove();
      builder.insert(initOp);
    }
  };

  // Consider the gradients of the forward block's arguments...
  for (auto operand : fwd->getArguments()) {
    localizeGradientValue(operand);
  }

  // ...and of every result produced by operations inside the forward block.
  for (auto &it : fwd->getOperations()) {
    for (auto res : it.getResults()) {
      localizeGradientValue(res);
    }
  }
}
80
/// Scan `block` once, collecting gradient and cache information used by the
/// enzyme-op removal patterns:
///  - `enzyme.set`:  record gradient -> value in `mapping` and remember the
///    gradient in `gradients`.
///  - `enzyme.get`:  forward the last known value for the gradient (creating
///    a fresh `enzyme.get` at the rewriter's insertion point on first use)
///    and replace the result.
///  - `enzyme.push`: if the pushed value is defined outside `block`, hoist
///    the push (and the matching pop) out of the enclosing op; otherwise
///    record/merge its CacheInfo in `caches` keyed by the pushed value.
///
/// NOTE(review): the signature line was reconstructed from the declaration
/// `removalBlockExplore(Block *, IRMapping &, PatternRewriter &,
/// llvm::SetVector<Value> &, llvm::MapVector<Value, CacheInfo> &)` — confirm
/// against the header.
void mlir::enzyme::removalBlockExplore(
    Block *block, IRMapping &mapping, PatternRewriter &rewriter,
    llvm::SetVector<Value> &gradients,
    llvm::MapVector<Value, CacheInfo> &caches) {
  // Manual iterator loop: ops may be erased mid-iteration, so the increment
  // happens either at the bottom or explicitly before an erase.
  for (auto it = block->begin(), e = block->end(); it != e;) {
    Operation *op = &*it;

    if (auto setOp = dyn_cast<enzyme::SetOp>(op)) {
      auto grad = setOp.getGradient();
      auto value = setOp.getValue();
      // Remember the most recent value written to this gradient.
      mapping.map(grad, value);
      gradients.insert(grad);
    }

    if (auto getOp = dyn_cast<enzyme::GetOp>(op)) {
      auto grad = getOp.getGradient();
      Value value = mapping.lookupOrNull(getOp.getGradient());
      if (!value) {
        // No prior set seen in this block: materialize a single read at the
        // rewriter's insertion point and reuse it for later gets.
        value = enzyme::GetOp::create(rewriter, getOp->getLoc(),
                                      getOp.getResult().getType(), grad);
        mapping.map(grad, value);
      }
      rewriter.replaceAllUsesWith(getOp.getResult(), value);
    }

    if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
      CacheInfo info(pushOp.getCache());

      Value pushedValue = info.pushedValue();

      // Then we can push the value before the if, if it is defined before the
      // if
      if (pushedValue.getParentBlock() != block) {
        enzyme::PushOp::create(rewriter, pushOp->getLoc(), pushOp.getCache(),
                               pushedValue);

        ++it; // Increment iterator to allow in place deletion
        rewriter.eraseOp(pushOp);

        // Move the pop before the other if
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(info.popOp->getParentOp());

        auto newPop =
            enzyme::PopOp::create(rewriter, info.popOp->getLoc(),
                                  pushedValue.getType(), info.popOp.getCache());
        rewriter.replaceAllUsesWith(info.popOp.getResult(), newPop);
        rewriter.eraseOp(info.popOp);

        continue;
      }

      // Deduplicate pushes of the same value within this block.
      if (caches.contains(pushedValue)) {
        info = info.merge(caches.lookup(pushedValue), rewriter);
      }
      caches[pushedValue] = info;
    }

    ++it;
  }
}
142
143typedef llvm::PointerUnion<Operation *, Value> Node;
144
145void dump(const Node &n) {
146 if (isa<Value>(n))
147 llvm::errs() << "[" << cast<Value>(n) << ", "
148 << "Value"
149 << "]\n";
150 else if (isa<Operation *>(n))
151 llvm::errs() << "[" << *cast<Operation *>(n) << ", "
152 << "Operation"
153 << "]\n";
154 else
155 llvm::errs() << "["
156 << "NULL"
157 << ", "
158 << "None"
159 << "]\n";
160}
161
// Directed graph over mincut nodes: maps each node to its set of successors.
// Deterministic iteration order (MapVector) keeps the pass reproducible.
struct Graph : public llvm::MapVector<Node, SmallPtrSet<Node, 2>> {
  // Checked lookup: asserts that `n` is present in the graph.
  const SmallPtrSet<Node, 2> &at(const Node &n) {
    auto found = find(n);
    assert(found != end());
    return found->second;
  }
};
169
// Emit `G` to stderr in GraphViz "digraph" syntax. Nodes are labeled either
// with their op's "dbg" attribute (set by annotate_ops) or, failing that,
// the op name; values additionally carry their result number.
static void dumpGraphviz(Graph &G) {
  // Render one node to a short label string.
  auto serialize = [&](Node n) -> std::string {
    std::string s;
    llvm::raw_string_ostream ss(s);
    if (isa<Value>(n)) {
      auto v = cast<Value>(n);
      if (isa<OpResult>(v)) {
        // Label op results by owner + result index for readability.
        auto res = cast<OpResult>(v);
        ss << "[val](" << res.getResultNumber() << ")";
        if (res.getOwner()->hasAttr("dbg")) {
          auto dbg = res.getOwner()->getAttrOfType<StringAttr>("dbg");
          ss << dbg.getValue();
        } else {
          ss << res.getOwner()->getName().getStringRef();
        }
      } else {
        // Block arguments and other values print their default form.
        ss << "[val]" << v;
      }
    } else if (isa<Operation *>(n)) {
      auto op = cast<Operation *>(n);
      ss << "[op]";
      if (op->hasAttr("dbg")) {
        auto dbg = op->getAttrOfType<StringAttr>("dbg");
        ss << dbg.getValue();
      } else {
        ss << op->getName().getStringRef();
      }
    } else {
      ss << "none";
    }
    return s;
  };

  using llvm::errs;
  errs() << "digraph G {\n";
  // One edge line per (node, successor) pair.
  for (auto &pair : G) {
    for (const auto &N : pair.second) {
      errs() << " \"" << serialize(pair.first) << "\" -> \"" << serialize(N)
             << "\";\n";
    }
  }

  errs() << "}\n";
}
214
215static void dump(Graph &G) {
216 if (DebugGraphviz) {
217 dumpGraphviz(G);
218 } else {
219 for (auto &pair : G) {
220 dump(pair.first);
221 for (const auto &N : pair.second) {
222 llvm::errs() << "\t";
223 dump(N);
224 }
225 }
226 }
227}
228
229// A node in the compute graph.
230// Operation nodes have outgoing edges to value nodes that they produce and
231// incoming nodes from values they take as operands.
232
233// parent is populated with a path from each connected leaf node of G to one
234// of the Value in Source.
235static inline void bfs(const Graph &G, const llvm::SetVector<Value> &Sources,
236 DenseMap<Node, Node> &parent) {
237 std::deque<Node> q;
238 for (const auto &V : Sources) {
239 Node N(V);
240 parent.try_emplace(N, Node());
241 q.push_back(N);
242 }
243
244 // Standard BFS Loop
245
246 SmallPtrSet<Node, 2> done;
247
248 while (!q.empty()) {
249 auto u = q.front();
250 q.pop_front();
251 auto found = G.find(u);
252 if (found == G.end())
253 continue;
254
255 if (!done.insert(u).second)
256 continue;
257
258 for (const auto &v : found->second) {
259 if (parent.try_emplace(v, u).second) {
260 q.push_back(v);
261 }
262 }
263 }
264}
265
266// Whether or not an operation can be moved from the forward region to the
267// reverse region or vice-versa.
268static inline bool isMovable(Operation *op) {
269 return op->getNumRegions() == 0 && op->getBlock()->getTerminator() != op &&
270 mlir::isPure(op);
271}
272
// Given a graph `G`, construct a new graph `G2`, where all paths must terminate
// in a node in the set `Required` and start at `Root`.
template <typename T>
static Graph filterGraph(const Graph &Orig, const SetVector<Value> &Roots,
                         const SetVector<T> &Required) {
  Graph inverted;

  // Compute the graph with inverted edges by a floodfill, stopping at the first
  // `required`. This is required in the case of a root -> required -> required
  // edge. We do not want to contain the required->required subgraph.
  //
  // NOTE(review): this floodfill variant is deliberately disabled
  // (`if (false)`) in favor of the full-inversion branch below, despite the
  // comment above arguing it is needed — confirm which behavior is intended.
  if (false) {
    std::deque<Node> worklist;
    for (auto val : Roots) {
      worklist.push_back(val);
    }

    // Pre-marking `Required` nodes as done stops the fill at them.
    SmallPtrSet<Node, 2> done;
    for (auto src : Required) {
      done.insert(src);
    }

    while (!worklist.empty()) {
      Node N = worklist.front();
      worklist.pop_front();

      if (!done.insert(N).second)
        continue;

      auto pair = Orig.find(N);
      if (pair == Orig.end()) {
        continue;
      }

      for (const auto &NN : pair->second) {

        inverted[NN].insert(N);
        if (!done.contains(NN)) {
          worklist.push_back(NN);
        }
      }
    }

  } else {
    // Active path: invert every edge of the original graph.
    for (auto &pair : Orig) {
      for (auto N : pair.second) {
        inverted[N].insert(pair.first);
      }
    }
  }

  // Backward floodfill from `Required` over the inverted edges, stopping at
  // `Roots`; re-inserting edges in original orientation yields the subgraph
  // of all root-to-required paths.
  std::deque<Node> worklist;
  for (auto snk : Required) {
    worklist.emplace_back(snk);
  }

  SmallPtrSet<Node, 2> done;
  for (auto src : Roots) {
    done.insert(src);
  }

  Graph G;

  while (!worklist.empty()) {
    Node N = worklist.front();
    worklist.pop_front();

    if (!done.insert(N).second)
      continue;

    auto pair = inverted.find(N);
    if (pair == inverted.end()) {
      continue;
    }

    for (const auto &NN : pair->second) {

      G[NN].insert(N);
      if (!done.contains(NN)) {
        worklist.push_back(NN);
      }
    }
  }

  return G;
}
358
359static int64_t computeSizeOfType(Value val) {
360 auto T = dyn_cast<AutoDiffTypeInterface>(val.getType());
361 return T ? T.getApproxSize() : INT64_MAX;
362};
363
364static int64_t computeRankOfType(Value val) {
365 auto TT = dyn_cast<RankedTensorType>(val.getType());
366 return TT ? TT.getRank() : 0;
367}
368
369/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
370/// op is a top-level module op (which is expected to be isolated from above),
371/// return that op.
372static Operation *findCommonAncestor(ArrayRef<Operation *> ops) {
373 // Check if there is a top-level operation within `ops`. If so, return that
374 // op.
375 for (Operation *op : ops) {
376 if (!op->getParentOp()) {
377#ifndef NDEBUG
378 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
379 "expected top-level op to be isolated from above");
380 for (Operation *other : ops)
381 assert(op->isAncestor(other) &&
382 "expected ops to have a common ancestor");
383#endif // NDEBUG
384 return op;
385 }
386 }
387
388 // No top-level op. Find a common ancestor.
389 Operation *commonAncestor =
390 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
391 for (Operation *op : ops.drop_front()) {
392 while (!commonAncestor->isProperAncestor(op)) {
393 commonAncestor =
394 commonAncestor->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
395 assert(commonAncestor &&
396 "expected to find a common isolated from above ancestor");
397 }
398 }
399
400 return commonAncestor;
401}
402
403// Annotate operations with a debug attribute. This makes the GraphViz printing
404// nicer.
405static void annotate_ops(Block *forward, Block *reverse) {
406 unsigned counter = 0;
407 forward->walk([&](Operation *op) {
408 auto debugName =
409 StringAttr::get(op->getContext(),
410 op->getName().stripDialect() + llvm::Twine(counter++));
411 op->setAttr("dbg", debugName);
412 });
413 reverse->walk([&](Operation *op) {
414 auto debugName =
415 StringAttr::get(op->getContext(),
416 op->getName().stripDialect() + llvm::Twine(counter++));
417 op->setAttr("dbg", debugName);
418 });
419}
420
// Given the full forward/backward compute graph, the push/pop can be seen
// as a special cut of this graph. This function tries to modifies the
// boundary of the push/pop to minimize the amount of memory that is live
// across different loops.
// The insertion point of rewriter must be in the reverse block, after any
// fwdrevmap settings have been created.
//
// High-level structure:
//  1. Hoist caches of values defined outside the forward region.
//  2. Build the forward/reverse compute graph G around the current caches.
//  3. Run max-flow (Ford-Fulkerson style augmentation over unit-capacity
//     edges) from the non-recomputable "roots" to the "Required" ops; the
//     min cut gives the new, smaller set of values to cache.
//  4. Refine the cut with size/rank heuristics, then materialize new
//     init/push/pop triples, clone recomputable ops into the reverse
//     block, and erase the old caches.
void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
                               SmallVector<CacheInfo> &caches0,
                               PatternRewriter &rewriter,
                               const IRMapping &fwdrevmap, Operation *lastFwd) {
  assert(rewriter.getInsertionBlock() == reverse);
  assert(rewriter.getInsertionPoint()->getBlock() == reverse);
  if (caches0.empty())
    return;

  LLVM_DEBUG(if (DebugGraphviz) annotate_ops(forward, reverse));

  // where to build the new inits
  Operation *entry = caches0[0].initOp;

  IRMapping mapping = fwdrevmap;
  SmallVector<CacheInfo> caches;
  // Hoist out pushes of values that are defined outside of the block
  for (auto &info : caches0) {
    auto todo = info.pushedValue();
    bool isDefinedOutside =
        !forward->getParent()->isAncestor(todo.getParentRegion());
    if (isDefinedOutside) {
      // Move push/pop out of the forward/reverse parents. If the rewriter's
      // insertion point currently sits on the op being moved, bump it first
      // so it stays valid.
      rewriter.modifyOpInPlace(info.pushOp, [&]() {
        if (&*rewriter.getInsertionPoint() == info.pushOp)
          rewriter.setInsertionPoint(info.pushOp->getNextNode());

        info.pushOp->moveBefore(forward->getParentOp());
      });
      rewriter.modifyOpInPlace(info.popOp, [&]() {
        if (&*rewriter.getInsertionPoint() == info.popOp)
          rewriter.setInsertionPoint(info.popOp->getNextNode());
        info.popOp->moveBefore(reverse->getParentOp());
      });
      mapping.map(info.pushedValue(), info.popOp);
      continue;
    }
    caches.push_back(info);
  }
  assert(rewriter.getInsertionPoint()->getBlock() == reverse);

  if (caches.empty()) {
    caches0.clear();
    return;
  }

  // Maintain a mapping of forward to reverse blocks. We later use this to place
  // the new cache pops and cloned ops to the correct blocks.
  DenseMap<Block *, OpBuilder::InsertPoint> insertionPointMap;
  for (const auto &info : caches) {
    Block *fwdBlock = info.pushOp->getBlock();
    Block *revBlock = info.popOp->getBlock();
    // For the top-level reverse block, we use the provided rewriter's insertion
    // point (to skip over things like IV calculations). New operations in inner
    // blocks should be inserted at the beginning of those blocks.
    if (revBlock == reverse) {
      insertionPointMap[fwdBlock] = rewriter.saveInsertionPoint();
    } else {
      insertionPointMap[fwdBlock] =
          OpBuilder::InsertPoint(revBlock, revBlock->begin());
    }
  }

  Graph G;

  LLVM_DEBUG(llvm::dbgs() << "trying min/cut\n");
  LLVM_DEBUG(
      findCommonAncestor({forward->getParentOp(), reverse->getParentOp()})
          ->dump());

  LLVM_DEBUG(llvm::dbgs() << "forward: " << *forward << "\n";);
  LLVM_DEBUG(llvm::dbgs() << "reverse: " << *reverse << "\n";);

  SmallVector<Value> worklist;
  for (auto &cache : caches) {
    worklist.push_back(cache.pushedValue());
  }

  // nodes that cannot be recomputed
  SetVector<Value> roots;

  // Walk Backward
  //
  // Roots (sources) are either block arguments or values which are defined
  // outside of forward.
  while (!worklist.empty()) {
    Value todo = worklist.pop_back_val();

    bool isDefinedOutside =
        !forward->getParent()->isAncestor(todo.getParentRegion());
    if (isDefinedOutside || fwdrevmap.contains(todo)) {
      continue;
    }

    Operation *owner = todo.getDefiningOp();
    if (!owner || !isMovable(owner)) {
      // Block arguments and unmovable producers cannot be recomputed: they
      // become flow sources.
      roots.insert(todo);
      continue;
    }

    // Edge op -> result; the `inserted` flag stops us from re-expanding an
    // op whose operands were already enqueued.
    auto &&[_, inserted] = G[Node(owner)].insert(Node(todo));
    if (inserted) {
      for (Value operand : owner->getOperands()) {
        G[Node(operand)].insert(Node(owner));
        worklist.push_back(operand);
      }
    }
  }

  worklist.clear();

  // The operation whose use of a value forces a value to be available
  // in the reverse pass
  SetVector<Operation *> Required;

  {
    // Seed the forward walk from the users of each current pop. A user that
    // is outside the top-level reverse block or not movable pins the value
    // as "required"; otherwise its results are explored further.
    for (auto &info : caches) {
      Value poped = info.popOp.getResult();

      bool isRequired = false;
      for (auto user : poped.getUsers()) {
        if (user->getBlock() != reverse || !isMovable(user)) {
          G[info.pushedValue()].insert(Node(user));
          Required.insert(user);
          isRequired = true;
          break;
        }
      }
      if (!isRequired)
        for (auto user : poped.getUsers()) {
          G[Node(info.pushedValue())].insert(user);
          for (Value res : user->getResults()) {
            G[Node(user)].insert(res);
            worklist.push_back(res);
          }
        }
    }

    // Walk Forward
    while (!worklist.empty()) {
      Value todo = worklist.pop_back_val();

      bool isRequired = false;
      for (auto user : todo.getUsers()) {
        if (user->getBlock() != reverse || !isMovable(user)) {
          G[todo].insert(Node(user));
          Required.insert(user);
          isRequired = true;
          break;
        }
      }
      if (isRequired)
        continue;

      for (auto user : todo.getUsers()) {
        Node N(user);
        auto &&[_, inserted] = G[Node(todo)].insert(N);
        if (inserted) {
          for (Value res : user->getResults()) {
            G[N].insert(Node(res));
            worklist.push_back(res);
          }
        }
      }
    }

    // A reverse-block op in G whose operand is a reverse-block value not in
    // G depends on something we cannot recompute: mark it required too.
    for (auto N : G) {
      if (!isa<Operation *>(N.first))
        continue;
      auto op = cast<Operation *>(N.first);
      if (op->getBlock() != reverse)
        continue;
      for (auto v : op->getOperands()) {
        if (v.getParentBlock() != reverse) {
          continue;
        }
        if (G.contains(Node(v))) {
          continue;
        }
        Required.insert(op);
        break;
      }
    }
    assert(rewriter.getInsertionPoint()->getBlock() == reverse);

    LLVM_DEBUG(llvm::dbgs() << "Required: \n";);
    LLVM_DEBUG(for (auto R : Required) llvm::dbgs() << " + " << *R << "\n";);

    LLVM_DEBUG(llvm::dbgs() << "Roots: \n";);
    LLVM_DEBUG(for (auto R : roots) llvm::dbgs() << " + " << R << "\n";);
  }

  LLVM_DEBUG(llvm::dbgs() << "pre filter graph: \n";);
  LLVM_DEBUG(dump(G));
  // Restrict G to the paths that actually connect roots to required ops.
  G = filterGraph(G, roots, Required);
  LLVM_DEBUG(llvm::dbgs() << "post filter graph: \n";);
  LLVM_DEBUG(dump(G));

  // Keep a pristine copy: G is mutated into the residual graph below.
  Graph Orig = G;

  // Augment the flow while there is a path from source to sink
  while (1) {
    DenseMap<Node, Node> parent;
    bfs(G, roots, parent);
    Node end;
    for (auto req : Required) {
      if (parent.find(Node(req)) != parent.end()) {
        end = Node(req);
        break;
      }
    }
    if (end.isNull())
      break;
    // update residual capacities of the edges and reverse edges
    // along the path
    Node v = end;
    while (1) {
      assert(parent.find(v) != parent.end());
      Node u = parent.find(v)->second;
      assert(!u.isNull());
      assert(G[u].count(v) == 1);
      assert(G[v].count(u) == 0);
      // Unit capacity: saturating an edge means flipping it.
      G[u].erase(v);
      G[v].insert(u);
      if (isa<Value>(u) && roots.contains(cast<Value>(u)))
        break;
      v = u;
    }
  }
  assert(rewriter.getInsertionPoint()->getBlock() == reverse);
  // Flow is maximum now, find vertices reachable from s

  DenseMap<Node, Node> parent;
  bfs(G, roots, parent);

  LLVM_DEBUG(llvm::dbgs() << "residual graph: \n";);
  LLVM_DEBUG(dump(G));

  // Those are the new values to cache
  SetVector<Value> newCaches;

  // All edges that are from a reachable vertex to non-reachable vertex in
  // the original graph are edges for the minimum cut. The set of values to
  // cache are the values transported along those edges (either. Value ->
  // Operation or Operation -> Value).
  //
  // Note: we could use more heuristics here to select the actual cached
  // value
  //       based on sizes, existing caches, number of users in the fwd as to
  //       not duplicate work, etc...
  for (auto &pair : Orig) {
    if (parent.find(pair.first) != parent.end()) {
      for (auto N : pair.second) {
        if (parent.find(N) == parent.end()) {
          Value newCache;
          if (isa<Value>(pair.first)) {
            assert(isa<Operation *>(N));
            newCache = cast<Value>(pair.first);
          } else {
            assert(isa<Operation *>(pair.first));
            assert(isa<Value>(N));
            newCache = cast<Value>(N);
          }
          newCaches.insert(newCache);
        }
      }
    }
  }

  // compute path from new caches to required
  parent.clear();
  bfs(Orig, newCaches, parent);

  LLVM_DEBUG({
    llvm::dbgs() << "initial new caches: \n";
    for (Value v : newCaches) {
      v.dump();
    }
  });

  // The cachegraph is a sub graph of Orig with only pathes new caches
  // to Required nodes.
  Graph cacheGraph = filterGraph(Orig, newCaches, Required);

  LLVM_DEBUG(llvm::dbgs() << "cacheGraph:\n");
  LLVM_DEBUG(dump(cacheGraph));

  SmallVector<CacheInfo> newCacheInfos;

  // We guard here so then the IP after this is immediately before the new pop's
  Operation *firstClone = nullptr;

  // Refine cached values based on some heuristics
  if (newCaches.size()) {

    // sort caches to provide determinism.
    // llvm::sort(newCaches.getArrayRef().begin(),
    //            newCaches.getArrayRef().end(), mlir::enzyme::valueCmp);

    // Greedy refinement: if a cached value feeds exactly one non-required,
    // single-result op whose result is lower-rank (or same-rank but smaller),
    // cache that result instead.
    SmallVector<Value> todo(newCaches.begin(), newCaches.end());
    while (todo.size()) {
      auto cur = todo.pop_back_val();

      auto &next = cacheGraph.at(Node(cur));

      if (next.size() > 1)
        continue;

      auto nextF = *next.begin();
      assert(isa<Operation *>(nextF));
      auto opNext = cast<Operation *>(nextF);

      if (Required.count(opNext))
        continue;

      if (opNext->getNumResults() != 1)
        continue;

      Value candidate = opNext->getResult(0);

      int64_t curSize = computeSizeOfType(cur),
              curRank = computeRankOfType(cur);

      int64_t newSize = computeSizeOfType(candidate),
              newRank = computeRankOfType(candidate);

      if (newRank < curRank || (newRank == curRank && newSize < curSize)) {
        newCaches.remove(cur);
        newCaches.insert(candidate);
        todo.push_back(candidate);
        cacheGraph.erase(cur);
        cacheGraph.erase(opNext);
      }
    }

    LLVM_DEBUG(llvm::dbgs() << "refined cacheGraph:\n");
    LLVM_DEBUG(dump(cacheGraph));
    LLVM_DEBUG({
      llvm::dbgs() << "refined new caches: \n";
      for (Value v : newCaches) {
        v.dump();
      }
    });

    // Materialize init/push/pop for each new cache defined in the forward
    // region; caches of reverse-region values are handled separately below.
    SetVector<Value> reverseCaches;
    for (Value newCache : newCaches) {
      if (!forward->getParent()->isAncestor(newCache.getParentRegion())) {
        reverseCaches.insert(newCache);
        continue;
      }
      assert(rewriter.getInsertionBlock() == reverse);

      enzyme::InitOp initOp = ({
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(entry);
        enzyme::InitOp::create(
            rewriter, newCache.getLoc(),
            enzyme::CacheType::get(newCache.getContext(), newCache.getType()));
      });

      enzyme::PushOp pushOp = ({
        OpBuilder::InsertionGuard guard(rewriter);
        // Block arguments have no defining op; push after the last forward
        // op when one was provided.
        if (lastFwd && isa<BlockArgument>(newCache)) {
          rewriter.setInsertionPointAfter(lastFwd);
        } else {
          rewriter.setInsertionPointAfterValue(newCache);
        }
        enzyme::PushOp::create(rewriter, newCache.getLoc(), initOp.getResult(),
                               newCache);
      });

      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.restoreInsertionPoint(
          insertionPointMap.lookup(newCache.getParentBlock()));
      enzyme::PopOp popOp = enzyme::PopOp::create(
          rewriter, newCache.getLoc(), newCache.getType(), initOp.getResult());
      insertionPointMap[newCache.getParentBlock()] =
          rewriter.saveInsertionPoint();
      if (!firstClone)
        firstClone = popOp;
      mapping.map(newCache, popOp.getResult());

      CacheInfo info;
      info.initOp = initOp;
      info.pushOp = pushOp;
      info.popOp = popOp;
      newCacheInfos.push_back(info);
    }

    // Some selected caches live in the reverse region: clone their producing
    // ops into the forward block (rewriting pops back to the pushed values),
    // cache the cloned results, and erase now-dead reverse ops.
    if (reverseCaches.size()) {
      Graph fwdGraph = filterGraph(Orig, roots, newCaches);

      IRMapping fwdmap;
      for (auto &info : caches) {
        fwdmap.map(info.popOp->getResult(0), info.pushedValue());
      }

      SmallVector<Operation *> toErase;
      for (auto &op : llvm::make_early_inc_range(*reverse)) {
        if (!fwdGraph.contains(Node(&op)))
          continue;

        Operation *newO = ({
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(forward->getTerminator());
          rewriter.clone(op, fwdmap);
        });

        bool hasUse = false;
        for (auto &&[res, newRes] :
             llvm::zip_equal(op.getResults(), newO->getResults())) {
          if (newCaches.contains(res)) {
            enzyme::InitOp initOp = ({
              OpBuilder::InsertionGuard guard(rewriter);
              rewriter.setInsertionPoint(entry);
              enzyme::InitOp::create(rewriter, newRes.getLoc(),
                                     enzyme::CacheType::get(newRes.getContext(),
                                                            newRes.getType()));
            });

            enzyme::PushOp pushOp = ({
              OpBuilder::InsertionGuard guard(rewriter);
              rewriter.setInsertionPoint(forward->getTerminator());
              enzyme::PushOp::create(rewriter, newRes.getLoc(),
                                     initOp.getResult(), newRes);
            });

            enzyme::PopOp popOp = ({
              OpBuilder::InsertionGuard guard(rewriter);
              rewriter.setInsertionPoint(&op);
              enzyme::PopOp::create(rewriter, newRes.getLoc(), newRes.getType(),
                                    initOp.getResult());
            });

            rewriter.replaceAllUsesWith(res, popOp->getResult(0));

            CacheInfo info;
            info.initOp = initOp;
            info.pushOp = pushOp;
            info.popOp = popOp;
            newCacheInfos.push_back(info);
          }
          // The original op is still needed if some user was not cloned to
          // the forward block.
          if (!hasUse) {
            for (auto user : res.getUsers()) {
              if (!fwdGraph.contains(Node(user))) {
                hasUse = true;
                break;
              }
            }
          }
        }

        if (!hasUse && !op.hasAttr("enzyme.no_erase")) {
          toErase.push_back(&op);
        }
      }
      // Erase in reverse order so users go before producers.
      for (auto op : llvm::reverse(toErase)) {
        rewriter.eraseOp(op);
      }
    }
  }

  // Clone the recomputable forward ops into the reverse blocks. Ops whose
  // every result is itself cached need no recomputation and are skipped.
  forward->walk([&](Operation *op) {
    if (!cacheGraph.contains(Node(op)))
      return;
    bool hasUse = false;
    for (auto res : op->getResults()) {
      if (newCaches.contains(res)) {
        continue;
      }
      hasUse = true;
    }
    if (!hasUse)
      return;
    for (auto v : op->getOperands()) {
      if (mapping.contains(v))
        continue;
      if (forward->getParent()->isAncestor(v.getParentRegion()))
        continue;

      // Operand defined outside the forward region and not yet mapped:
      // transport it via a fresh cache around the forward/reverse parents.
      enzyme::InitOp initOp = ({
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(entry);
        enzyme::InitOp::create(
            rewriter, v.getLoc(),
            enzyme::CacheType::get(v.getContext(), v.getType()));
      });

      {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(forward->getParentOp());
        enzyme::PushOp::create(rewriter, v.getLoc(), initOp.getResult(), v);
      };

      enzyme::PopOp popOp = ({
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(reverse->getParentOp());
        enzyme::PopOp::create(rewriter, v.getLoc(), v.getType(),
                              initOp.getResult());
      });
      mapping.map(v, popOp->getResult(0));
    }
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.restoreInsertionPoint(insertionPointMap.lookup(op->getBlock()));
    auto cop = rewriter.clone(*op, mapping);
    insertionPointMap[op->getBlock()] = rewriter.saveInsertionPoint();
    if (!firstClone)
      firstClone = cop;
  });

  // Leave the rewriter pointing just before the first new pop/clone.
  if (firstClone)
    rewriter.setInsertionPoint(firstClone);

  // Remove old caches
  for (auto &info : caches) {
    if (mapping.contains(info.pushedValue())) {
      rewriter.replaceOp(info.popOp, mapping.lookup(info.pushedValue()));
    } else {
      rewriter.eraseOp(info.popOp);
    }
    rewriter.eraseOp(info.pushOp);
    rewriter.eraseOp(info.initOp);
  }

  LLVM_DEBUG(llvm::dbgs() << "post min/cut\n");
  LLVM_DEBUG(
      findCommonAncestor({forward->getParentOp(), reverse->getParentOp()})
          ->dump());

  // Set new caches
  caches0 = std::move(newCacheInfos);
}
958
// Merge two caches of the same value into one: keep whichever init (resp.
// pop) comes first in its block, redirect all uses of the discarded ops to
// the kept ones, and erase the redundant push. Both caches must push/pop in
// the same blocks.
//
// NOTE(review): the signature line was reconstructed from the declaration
// `CacheInfo merge(CacheInfo other, PatternRewriter &rewriter)` — confirm
// against the header.
CacheInfo mlir::enzyme::CacheInfo::merge(CacheInfo other,
                                         mlir::PatternRewriter &rewriter) {
  assert(other.pushOp->getBlock() == pushOp->getBlock());
  assert(other.popOp->getBlock() == popOp->getBlock());

  // Keep the earlier init so it still dominates every use.
  enzyme::InitOp newInitOp;
  if (other.initOp->isBeforeInBlock(initOp)) {
    newInitOp = other.initOp;
    rewriter.replaceAllUsesWith(initOp.getResult(), newInitOp.getResult());
    rewriter.eraseOp(initOp);
  } else {
    newInitOp = initOp;
    rewriter.replaceAllUsesWith(other.initOp.getResult(),
                                newInitOp.getResult());
    rewriter.eraseOp(other.initOp);
  }

  // One push of the shared value suffices.
  rewriter.eraseOp(other.pushOp);

  // Keep the earlier pop, again for dominance of its users.
  enzyme::PopOp newPopOp;
  if (other.popOp->isBeforeInBlock(popOp)) {
    newPopOp = other.popOp;
    rewriter.replaceAllUsesWith(popOp.getResult(), newPopOp.getResult());
    rewriter.eraseOp(popOp);
  } else {
    newPopOp = popOp;
    rewriter.replaceAllUsesWith(other.popOp.getResult(), newPopOp.getResult());
    rewriter.eraseOp(other.popOp);
  }

  // Rebuild the CacheInfo from the surviving init (push/pop are recovered
  // from its users by the CacheInfo constructor).
  CacheInfo newInfo{newInitOp};
  return newInfo;
}
#define getOp(vtmp)
void dump(const Node &n)
static int64_t computeSizeOfType(Value val)
static int64_t computeRankOfType(Value val)
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static Graph filterGraph(const Graph &Orig, const SetVector< Value > &Roots, const SetVector< T > &Required)
static void bfs(const Graph &G, const llvm::SetVector< Value > &Sources, DenseMap< Node, Node > &parent)
llvm::PointerUnion< Operation *, Value > Node
static llvm::cl::opt< bool > DebugGraphviz("mincut-print-graphviz", llvm::cl::init(false), llvm::cl::Hidden, llvm::cl::desc("Use with DEBUG_TYPE 'enzyme-mincut' to print " "the mincut graphs in GraphViz"))
static bool isMovable(Operation *op)
static void dumpGraphviz(Graph &G)
static void annotate_ops(Block *forward, Block *reverse)
mlir::Value getDifferential(mlir::Value origv)
bool isConstantValue(mlir::Value v) const
void removalBlockExplore(Block *block, IRMapping &mapping, PatternRewriter &rewriter, llvm::SetVector< Value > &gradients, llvm::MapVector< Value, CacheInfo > &caches)
void localizeGradients(OpBuilder &builder, MGradientUtilsReverse *gutils, Block *fwd)
void minCutCache(Block *forward, Block *reverse, SmallVector< CacheInfo > &caches, PatternRewriter &rewriter, const IRMapping &fwdrevmap, Operation *lastFwd=nullptr)
const SmallPtrSet< Node, 2 > & at(const Node &n)
Information about a cache, each cache init should have one corresponding push and pop.
CacheInfo merge(CacheInfo other, PatternRewriter &rewriter)
enzyme::InitOp initOp
enzyme::PushOp pushOp