5#include "mlir/Dialect/Arith/IR/Arith.h"
6#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
7#include "mlir/Dialect/Func/IR/FuncOps.h"
8#include "mlir/Dialect/MemRef/IR/MemRef.h"
9#include "mlir/Dialect/SCF/IR/SCF.h"
10#include "mlir/Transforms/DialectConversion.h"
12#include "mlir/Dialect/Linalg/IR/Linalg.h"
13#include "mlir/Rewrite/PatternApplicator.h"
15#include "mlir/IR/Dominance.h"
16#include "llvm/Support/raw_ostream.h"
20#define GEN_PASS_DEF_SIMPLIFYMEMREFCACHEPASS
21#include "Passes/Passes.h.inc"
26using namespace enzyme;
30struct SimplifyMemrefCachePass
31 :
public enzyme::impl::SimplifyMemrefCachePassBase<
32 SimplifyMemrefCachePass> {
// Rewrites the memref that `pushOp` pushes so that it stores `c2.getType()`
// values directly instead of `c2` cache values.
//
// What the visible lines do:
//  1. Take the pushed value and look for the memref.alloc that produces it.
//     Otherwise it hits llvm_unreachable("Unknown user of memref<CacheType>").
//     The guard condition itself is not present in this copy; presumably
//     it is `!allocOp`.
//  2. Create a replacement memref.alloc of type `newType` immediately before
//     the old one, reusing its dynamic sizes, symbol operands and alignment.
//  3. For every linalg.generic user that has the old alloc as an output:
//     - replace that output's region argument with a fresh `c2.getType()`
//       argument;
//     - make the region terminator yield the value pushed onto `cache` (the
//       terminator operand for that output) instead of `cache` itself;
//     - erase the op that defined `cache`.
//  4. Replace every use of the old alloc with the new one.
//
// Parameters:
//  pushOp  - enzyme.push whose pushed value is the memref to rewrite.
//  newType - replacement type for the alloc; runOnOperation builds it as the
//            original memref type with element type `c2.getType()`.
//  c2      - the cache type that is the element type of the original memref.
//
// NOTE(review): this copy of the function is missing several source lines
// (if-guards, loop-skip bodies, closing braces). The comments below describe
// only the visible statements.
34 void handlePushOp(enzyme::PushOp pushOp, Type newType, enzyme::CacheType c2) {
35 auto v = pushOp.getValue();
// NOTE(review): Value::getDefiningOp() returns null when `v` is a block
// argument, and llvm::dyn_cast asserts on a null input. dyn_cast_if_present
// (or dyn_cast_or_null) would send that case to the diagnostic below instead.
36 auto definingOp = v.getDefiningOp();
37 auto allocOp = dyn_cast<memref::AllocOp>(definingOp);
39 llvm_unreachable(
"Unknown user of memref<CacheType>");
// Build the replacement allocation right before the old one (OpBuilder(op)
// inserts before `op`), with the old alloc's dynamic sizes, symbol operands
// and alignment.
// NOTE(review): dyn_cast<MemRefType> yields a null type if `newType` is not a
// memref; cast<MemRefType> would state and assert that expectation.
41 OpBuilder allocBuilder(allocOp);
42 auto newAllocOp = memref::AllocOp::create(
43 allocBuilder, allocOp.getLoc(), dyn_cast<MemRefType>(newType),
44 allocOp.getDynamicSizes(), allocOp.getSymbolOperands(),
45 allocOp.getAlignmentAttr());
// Walk the users of the old allocation:
// - enzyme.push users take the first branch (its body is not present in this
//   copy);
// - anything that is neither a push nor a linalg.generic hits
//   llvm_unreachable.
47 for (
auto user : allocOp->getUsers()) {
48 auto linalgOp = dyn_cast<linalg::GenericOp>(user);
49 if (isa<enzyme::PushOp>(user)) {
51 }
else if (!linalgOp) {
52 llvm_unreachable(
"Unknown user of memref<CacheType>");
// Locate the generic op's output operand(s) that are the old allocation.
// The body of the mismatch branch is not present in this copy.
55 for (
auto &&output : llvm::enumerate(linalgOp.getOutputs())) {
56 if (output.value() != allocOp) {
// The region argument for output #i sits after all DPS inputs.
59 unsigned outputIndex =
60 linalgOp.getNumDpsInputs() + (unsigned)output.index();
// The argument must be unused before it is erased. A fresh argument of type
// `c2.getType()` is then inserted at the same position.
63 assert(linalgOp.getRegion().getArgument(outputIndex).use_empty());
65 linalgOp.getRegion().eraseArgument(outputIndex);
66 linalgOp.getRegion().insertArgument(outputIndex, c2.getType(),
67 output.value().getLoc());
// `cache` is the value the region terminator yields for this output.
69 Value cache = linalgOp.getRegion().front().getTerminator()->getOperand(
70 (
unsigned)output.index());
// Users of `cache`: an enzyme.pop is rejected. For an enzyme.push, the
// terminator is changed to yield the pushed value instead of `cache`.
// NOTE(review): the inner `pushOp` shadows the parameter, and no null check
// on it is visible before `pushOp.getValue()`. If the lines missing from this
// loop erase `pushOp`, iterate llvm::make_early_inc_range(cache.getUsers())
// so that erasing the current user does not invalidate the iterator.
71 for (
auto user : cache.getUsers()) {
72 if (isa<enzyme::PopOp>(user)) {
73 llvm_unreachable(
"PopOp should not be used in forward pass");
75 auto pushOp = dyn_cast<enzyme::PushOp>(user);
79 linalgOp.getRegion().front().getTerminator()->setOperand(
80 (
unsigned)output.index(), pushOp.getValue());
// Remove the op that produced `cache`.
// NOTE(review): getDefiningOp() is null if `cache` is a block argument.
83 cache.getDefiningOp()->erase();
// Redirect all remaining uses of the old allocation to the new one.
87 allocOp.replaceAllUsesWith((Value)newAllocOp);
// Rewrites the pop side of the cached memref. It replaces `popOp` with an
// enzyme.pop whose result type is `newType`, and rewires the memref.subview
// and enzyme::GenericAdjointOp users so that the adjoint region receives
// `c2.getType()` values as block arguments instead of popping them.
//
// What the visible lines do:
//  1. Create a new enzyme::PopOp of type `newType` right before `popOp`. The
//     rest of its operand list is not present in this copy.
//  2. For each memref.subview user of `popOp`, and each GenericAdjointOp user
//     of that subview, handle every input equal to the subview:
//     - insert a `c2.getType()` region argument in front of the matching
//       argument;
//     - replace the results of the enzyme.pop ops that use the old argument
//       with the new argument;
//     - erase the old argument.
//  3. Recreate the subview on top of the new pop and redirect its uses.
//  4. Replace the remaining uses of `popOp` with the new pop.
//
// Parameters:
//  popOp   - enzyme.pop whose result is the cached memref being retyped.
//  newType - result type for the replacement pop; runOnOperation builds it as
//            the original memref type with element type `c2.getType()`.
//  c2      - the cache type that is the element type of the original memref.
//
// NOTE(review): this copy of the function is missing several source lines
// (the tail of the PopOp::create call, null/skip guards, closing braces).
// The comments below describe only the visible statements.
91 void handlePopOp(enzyme::PopOp popOp, Type newType, enzyme::CacheType c2) {
92 OpBuilder popBuilder(popOp);
93 auto newPopOp = enzyme::PopOp::create(popBuilder, popOp.getLoc(), newType,
97 for (
auto user : popOp->getUsers()) {
98 auto subviewOp = dyn_cast<memref::SubViewOp>(user);
102 for (
auto user : subviewOp->getUsers()) {
103 auto linalgOp = dyn_cast<enzyme::GenericAdjointOp>(user);
107 for (
auto &&input : llvm::enumerate(linalgOp.getInputs())) {
108 if (input.value() != subviewOp) {
// Region argument i is taken to correspond to input i. The replacement
// argument is inserted in front of it, so the old argument moves to
// inputIndex + 1.
111 unsigned inputIndex = (unsigned)input.index();
112 Value inputCacheSSA = linalgOp.getRegion().insertArgument(
113 inputIndex, c2.getType(), input.value().getLoc());
114 Value oldArg = linalgOp.getRegion().getArgument(inputIndex + 1);
// Every user of the old argument is expected to be an enzyme.pop (the guard
// before llvm_unreachable is not present in this copy). Each pop's result is
// replaced by the new argument.
// NOTE(review): the inner `popOp` shadows the parameter of the same name. If
// the lines missing from this loop erase that pop, iterate
// llvm::make_early_inc_range(oldArg.getUsers()) instead.
115 for (
auto user : oldArg.getUsers()) {
116 auto popOp = dyn_cast<enzyme::PopOp>(user);
118 llvm_unreachable(
"Unknown user");
120 popOp.replaceAllUsesWith(inputCacheSSA);
// Drop the old argument.
// NOTE(review): an erased block argument must have no remaining uses. This
// holds only if the missing lines above erase the pops that used it.
125 linalgOp.getRegion().eraseArgument(inputIndex + 1);
// Recreate the subview on top of the new pop and redirect its users.
// NOTE(review): getOffsets()/getSizes()/getStrides() return only the dynamic
// operands of memref.subview. Static entries live in the static_offsets /
// static_sizes / static_strides attributes. The builder taking ValueRange
// offsets/sizes/strides treats every entry as dynamic, so any subview with a
// static offset, size or stride is rebuilt incorrectly.
// getMixedOffsets() / getMixedSizes() / getMixedStrides() carry both kinds.
129 OpBuilder subviewBuilder(subviewOp);
130 auto newSubviewOp = memref::SubViewOp::create(
131 subviewBuilder, subviewOp.getLoc(), newPopOp, subviewOp.getOffsets(),
132 subviewOp.getSizes(), subviewOp.getStrides());
133 subviewOp.replaceAllUsesWith((Value)newSubviewOp);
// Finally redirect the remaining uses of the original pop (the parameter).
136 popOp.replaceAllUsesWith((Value)newPopOp);
140 void runOnOperation()
override {
141 MLIRContext *context = &getContext();
143 getOperation()->walk([&](Operation *op) {
144 auto initOp = dyn_cast<enzyme::InitOp>(op);
148 auto c1 = dyn_cast<enzyme::CacheType>(initOp.getType());
152 auto memref = dyn_cast<MemRefType>(c1.getType());
156 auto c2 = dyn_cast<enzyme::CacheType>(memref.getElementType());
160 mlir::MemRefType::Builder memrefTypeBuilder(memref);
161 memrefTypeBuilder.setElementType(c2.getType());
162 Type newType = memrefTypeBuilder;
163 Type newCacheType = enzyme::CacheType::get(context, newType);
164 for (
auto user : op->getUsers()) {
165 if (
auto pushOp = dyn_cast<enzyme::PushOp>(user)) {
166 handlePushOp(pushOp, newType, c2);
167 }
else if (
auto popOp = dyn_cast<enzyme::PopOp>(user)) {
168 handlePopOp(popOp, newType, c2);
170 llvm_unreachable(
"Unknown user of InitOp");
174 OpBuilder builder(op);
176 enzyme::InitOp::create(builder, op->getLoc(), newCacheType);
177 op->replaceAllUsesWith(newInit);