428 SmallVector<CacheInfo> &caches0,
429 PatternRewriter &rewriter,
430 const IRMapping &fwdrevmap, Operation *lastFwd) {
431 assert(rewriter.getInsertionBlock() == reverse);
432 assert(rewriter.getInsertionPoint()->getBlock() == reverse);
439 Operation *entry = caches0[0].initOp;
441 IRMapping mapping = fwdrevmap;
442 SmallVector<CacheInfo> caches;
444 for (
auto &info : caches0) {
445 auto todo = info.pushedValue();
446 bool isDefinedOutside =
447 !forward->getParent()->isAncestor(todo.getParentRegion());
448 if (isDefinedOutside) {
449 rewriter.modifyOpInPlace(info.pushOp, [&]() {
450 if (&*rewriter.getInsertionPoint() == info.pushOp)
451 rewriter.setInsertionPoint(info.pushOp->getNextNode());
453 info.pushOp->moveBefore(forward->getParentOp());
455 rewriter.modifyOpInPlace(info.popOp, [&]() {
456 if (&*rewriter.getInsertionPoint() == info.popOp)
457 rewriter.setInsertionPoint(info.popOp->getNextNode());
458 info.popOp->moveBefore(reverse->getParentOp());
460 mapping.map(info.pushedValue(), info.popOp);
463 caches.push_back(info);
465 assert(rewriter.getInsertionPoint()->getBlock() == reverse);
467 if (caches.empty()) {
474 DenseMap<Block *, OpBuilder::InsertPoint> insertionPointMap;
475 for (
const auto &info : caches) {
476 Block *fwdBlock = info.pushOp->getBlock();
477 Block *revBlock = info.popOp->getBlock();
481 if (revBlock == reverse) {
482 insertionPointMap[fwdBlock] = rewriter.saveInsertionPoint();
484 insertionPointMap[fwdBlock] =
485 OpBuilder::InsertPoint(revBlock, revBlock->begin());
491 LLVM_DEBUG(llvm::dbgs() <<
"trying min/cut\n");
496 LLVM_DEBUG(llvm::dbgs() <<
"forward: " << *forward <<
"\n";);
497 LLVM_DEBUG(llvm::dbgs() <<
"reverse: " << *reverse <<
"\n";);
499 SmallVector<Value> worklist;
500 for (
auto &cache : caches) {
501 worklist.push_back(cache.pushedValue());
505 SetVector<Value> roots;
511 while (!worklist.empty()) {
512 Value todo = worklist.pop_back_val();
514 bool isDefinedOutside =
515 !forward->getParent()->isAncestor(todo.getParentRegion());
516 if (isDefinedOutside || fwdrevmap.contains(todo)) {
520 Operation *owner = todo.getDefiningOp();
526 auto &&[_, inserted] = G[
Node(owner)].insert(
Node(todo));
528 for (Value operand : owner->getOperands()) {
529 G[
Node(operand)].insert(
Node(owner));
530 worklist.push_back(operand);
539 SetVector<Operation *> Required;
542 for (
auto &info : caches) {
543 Value poped = info.popOp.getResult();
545 bool isRequired =
false;
546 for (
auto user : poped.getUsers()) {
547 if (user->getBlock() != reverse || !
isMovable(user)) {
548 G[info.pushedValue()].insert(
Node(user));
549 Required.insert(user);
555 for (
auto user : poped.getUsers()) {
556 G[
Node(info.pushedValue())].insert(user);
557 for (Value res : user->getResults()) {
558 G[
Node(user)].insert(res);
559 worklist.push_back(res);
565 while (!worklist.empty()) {
566 Value todo = worklist.pop_back_val();
568 bool isRequired =
false;
569 for (
auto user : todo.getUsers()) {
570 if (user->getBlock() != reverse || !
isMovable(user)) {
571 G[todo].insert(
Node(user));
572 Required.insert(user);
580 for (
auto user : todo.getUsers()) {
582 auto &&[_, inserted] = G[
Node(todo)].insert(N);
584 for (Value res : user->getResults()) {
585 G[N].insert(
Node(res));
586 worklist.push_back(res);
593 if (!isa<Operation *>(N.first))
595 auto op = cast<Operation *>(N.first);
596 if (op->getBlock() != reverse)
598 for (
auto v : op->getOperands()) {
599 if (v.getParentBlock() != reverse) {
602 if (G.contains(
Node(v))) {
609 assert(rewriter.getInsertionPoint()->getBlock() == reverse);
611 LLVM_DEBUG(llvm::dbgs() <<
"Required: \n";);
612 LLVM_DEBUG(
for (
auto R : Required) llvm::dbgs() <<
" + " << *R <<
"\n";);
614 LLVM_DEBUG(llvm::dbgs() <<
"Roots: \n";);
615 LLVM_DEBUG(
for (
auto R : roots) llvm::dbgs() <<
" + " << R <<
"\n";);
618 LLVM_DEBUG(llvm::dbgs() <<
"pre filter graph: \n";);
621 LLVM_DEBUG(llvm::dbgs() <<
"post filter graph: \n";);
628 DenseMap<Node, Node> parent;
629 bfs(G, roots, parent);
631 for (
auto req : Required) {
632 if (parent.find(
Node(req)) != parent.end()) {
643 assert(parent.find(v) != parent.end());
644 Node u = parent.find(v)->second;
646 assert(G[u].count(v) == 1);
647 assert(G[v].count(u) == 0);
650 if (isa<Value>(u) && roots.contains(cast<Value>(u)))
655 assert(rewriter.getInsertionPoint()->getBlock() == reverse);
658 DenseMap<Node, Node> parent;
659 bfs(G, roots, parent);
661 LLVM_DEBUG(llvm::dbgs() <<
"residual graph: \n";);
665 SetVector<Value> newCaches;
676 for (
auto &pair : Orig) {
677 if (parent.find(pair.first) != parent.end()) {
678 for (
auto N : pair.second) {
679 if (parent.find(N) == parent.end()) {
681 if (isa<Value>(pair.first)) {
682 assert(isa<Operation *>(N));
683 newCache = cast<Value>(pair.first);
685 assert(isa<Operation *>(pair.first));
686 assert(isa<Value>(N));
687 newCache = cast<Value>(N);
689 newCaches.insert(newCache);
697 bfs(Orig, newCaches, parent);
700 llvm::dbgs() <<
"initial new caches: \n";
701 for (Value v : newCaches) {
710 LLVM_DEBUG(llvm::dbgs() <<
"cacheGraph:\n");
711 LLVM_DEBUG(
dump(cacheGraph));
713 SmallVector<CacheInfo> newCacheInfos;
716 Operation *firstClone =
nullptr;
719 if (newCaches.size()) {
725 SmallVector<Value> todo(newCaches.begin(), newCaches.end());
726 while (todo.size()) {
727 auto cur = todo.pop_back_val();
729 auto &next = cacheGraph.at(
Node(cur));
734 auto nextF = *next.begin();
735 assert(isa<Operation *>(nextF));
736 auto opNext = cast<Operation *>(nextF);
738 if (Required.count(opNext))
741 if (opNext->getNumResults() != 1)
744 Value candidate = opNext->getResult(0);
752 if (newRank < curRank || (newRank == curRank && newSize < curSize)) {
753 newCaches.remove(cur);
754 newCaches.insert(candidate);
755 todo.push_back(candidate);
756 cacheGraph.erase(cur);
757 cacheGraph.erase(opNext);
761 LLVM_DEBUG(llvm::dbgs() <<
"refined cacheGraph:\n");
762 LLVM_DEBUG(
dump(cacheGraph));
764 llvm::dbgs() <<
"refined new caches: \n";
765 for (Value v : newCaches) {
770 SetVector<Value> reverseCaches;
771 for (Value newCache : newCaches) {
772 if (!forward->getParent()->isAncestor(newCache.getParentRegion())) {
773 reverseCaches.insert(newCache);
776 assert(rewriter.getInsertionBlock() == reverse);
778 enzyme::InitOp initOp = ({
779 OpBuilder::InsertionGuard guard(rewriter);
780 rewriter.setInsertionPoint(entry);
781 enzyme::InitOp::create(
782 rewriter, newCache.getLoc(),
783 enzyme::CacheType::get(newCache.getContext(), newCache.getType()));
786 enzyme::PushOp pushOp = ({
787 OpBuilder::InsertionGuard guard(rewriter);
788 if (lastFwd && isa<BlockArgument>(newCache)) {
789 rewriter.setInsertionPointAfter(lastFwd);
791 rewriter.setInsertionPointAfterValue(newCache);
793 enzyme::PushOp::create(rewriter, newCache.getLoc(), initOp.getResult(),
797 OpBuilder::InsertionGuard guard(rewriter);
798 rewriter.restoreInsertionPoint(
799 insertionPointMap.lookup(newCache.getParentBlock()));
800 enzyme::PopOp popOp = enzyme::PopOp::create(
801 rewriter, newCache.getLoc(), newCache.getType(), initOp.getResult());
802 insertionPointMap[newCache.getParentBlock()] =
803 rewriter.saveInsertionPoint();
806 mapping.map(newCache, popOp.getResult());
812 newCacheInfos.push_back(info);
815 if (reverseCaches.size()) {
819 for (
auto &info : caches) {
820 fwdmap.map(info.popOp->getResult(0), info.pushedValue());
823 SmallVector<Operation *> toErase;
824 for (
auto &op : llvm::make_early_inc_range(*reverse)) {
825 if (!fwdGraph.contains(
Node(&op)))
829 OpBuilder::InsertionGuard guard(rewriter);
830 rewriter.setInsertionPoint(forward->getTerminator());
831 rewriter.clone(op, fwdmap);
835 for (
auto &&[res, newRes] :
836 llvm::zip_equal(op.getResults(), newO->getResults())) {
837 if (newCaches.contains(res)) {
838 enzyme::InitOp initOp = ({
839 OpBuilder::InsertionGuard guard(rewriter);
840 rewriter.setInsertionPoint(entry);
841 enzyme::InitOp::create(rewriter, newRes.getLoc(),
842 enzyme::CacheType::get(newRes.getContext(),
846 enzyme::PushOp pushOp = ({
847 OpBuilder::InsertionGuard guard(rewriter);
848 rewriter.setInsertionPoint(forward->getTerminator());
849 enzyme::PushOp::create(rewriter, newRes.getLoc(),
850 initOp.getResult(), newRes);
853 enzyme::PopOp popOp = ({
854 OpBuilder::InsertionGuard guard(rewriter);
855 rewriter.setInsertionPoint(&op);
856 enzyme::PopOp::create(rewriter, newRes.getLoc(), newRes.getType(),
860 rewriter.replaceAllUsesWith(res, popOp->getResult(0));
866 newCacheInfos.push_back(info);
869 for (
auto user : res.getUsers()) {
870 if (!fwdGraph.contains(
Node(user))) {
878 if (!hasUse && !op.hasAttr(
"enzyme.no_erase")) {
879 toErase.push_back(&op);
882 for (
auto op : llvm::reverse(toErase)) {
883 rewriter.eraseOp(op);
888 forward->walk([&](Operation *op) {
889 if (!cacheGraph.contains(
Node(op)))
892 for (
auto res : op->getResults()) {
893 if (newCaches.contains(res)) {
900 for (
auto v : op->getOperands()) {
901 if (mapping.contains(v))
903 if (forward->getParent()->isAncestor(v.getParentRegion()))
906 enzyme::InitOp initOp = ({
907 OpBuilder::InsertionGuard guard(rewriter);
908 rewriter.setInsertionPoint(entry);
909 enzyme::InitOp::create(
910 rewriter, v.getLoc(),
911 enzyme::CacheType::get(v.getContext(), v.getType()));
915 OpBuilder::InsertionGuard guard(rewriter);
916 rewriter.setInsertionPoint(forward->getParentOp());
917 enzyme::PushOp::create(rewriter, v.getLoc(), initOp.getResult(), v);
920 enzyme::PopOp popOp = ({
921 OpBuilder::InsertionGuard guard(rewriter);
922 rewriter.setInsertionPoint(reverse->getParentOp());
923 enzyme::PopOp::create(rewriter, v.getLoc(), v.getType(),
926 mapping.map(v, popOp->getResult(0));
928 OpBuilder::InsertionGuard guard(rewriter);
929 rewriter.restoreInsertionPoint(insertionPointMap.lookup(op->getBlock()));
930 auto cop = rewriter.clone(*op, mapping);
931 insertionPointMap[op->getBlock()] = rewriter.saveInsertionPoint();
937 rewriter.setInsertionPoint(firstClone);
940 for (
auto &info : caches) {
941 if (mapping.contains(info.pushedValue())) {
942 rewriter.replaceOp(info.popOp, mapping.lookup(info.pushedValue()));
944 rewriter.eraseOp(info.popOp);
946 rewriter.eraseOp(info.pushOp);
947 rewriter.eraseOp(info.initOp);
950 LLVM_DEBUG(llvm::dbgs() <<
"post min/cut\n");
956 caches0 = std::move(newCacheInfos);