Enzyme main
Loading...
Searching...
No Matches
SimplifyMemrefCache.cpp
Go to the documentation of this file.
1#include "Dialect/Dialect.h"
2#include "Dialect/Ops.h"
3#include "PassDetails.h"
4#include "Passes/Passes.h"
5#include "mlir/Dialect/Arith/IR/Arith.h"
6#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
7#include "mlir/Dialect/Func/IR/FuncOps.h"
8#include "mlir/Dialect/MemRef/IR/MemRef.h"
9#include "mlir/Dialect/SCF/IR/SCF.h"
10#include "mlir/Transforms/DialectConversion.h"
11
12#include "mlir/Dialect/Linalg/IR/Linalg.h"
13#include "mlir/Rewrite/PatternApplicator.h"
14
15#include "mlir/IR/Dominance.h"
16#include "llvm/Support/raw_ostream.h"
17
18namespace mlir {
19namespace enzyme {
20#define GEN_PASS_DEF_SIMPLIFYMEMREFCACHEPASS
21#include "Passes/Passes.h.inc"
22} // namespace enzyme
23} // namespace mlir
24
25using namespace mlir;
26using namespace enzyme;
27using llvm::errs;
28namespace {
29
30struct SimplifyMemrefCachePass
31 : public enzyme::impl::SimplifyMemrefCachePassBase<
32 SimplifyMemrefCachePass> {
33
34 void handlePushOp(enzyme::PushOp pushOp, Type newType, enzyme::CacheType c2) {
35 auto v = pushOp.getValue();
36 auto definingOp = v.getDefiningOp();
37 auto allocOp = dyn_cast<memref::AllocOp>(definingOp);
38 if (!allocOp) {
39 llvm_unreachable("Unknown user of memref<CacheType>");
40 }
41 OpBuilder allocBuilder(allocOp);
42 auto newAllocOp = memref::AllocOp::create(
43 allocBuilder, allocOp.getLoc(), dyn_cast<MemRefType>(newType),
44 allocOp.getDynamicSizes(), allocOp.getSymbolOperands(),
45 allocOp.getAlignmentAttr());
46
47 for (auto user : allocOp->getUsers()) {
48 auto linalgOp = dyn_cast<linalg::GenericOp>(user);
49 if (isa<enzyme::PushOp>(user)) {
50 continue;
51 } else if (!linalgOp) {
52 llvm_unreachable("Unknown user of memref<CacheType>");
53 }
54
55 for (auto &&output : llvm::enumerate(linalgOp.getOutputs())) {
56 if (output.value() != allocOp) {
57 continue;
58 }
59 unsigned outputIndex =
60 linalgOp.getNumDpsInputs() + (unsigned)output.index();
61
62 // We should never actually use the value of the output!
63 assert(linalgOp.getRegion().getArgument(outputIndex).use_empty());
64
65 linalgOp.getRegion().eraseArgument(outputIndex);
66 linalgOp.getRegion().insertArgument(outputIndex, c2.getType(),
67 output.value().getLoc());
68
69 Value cache = linalgOp.getRegion().front().getTerminator()->getOperand(
70 (unsigned)output.index());
71 for (auto user : cache.getUsers()) {
72 if (isa<enzyme::PopOp>(user)) {
73 llvm_unreachable("PopOp should not be used in forward pass");
74 }
75 auto pushOp = dyn_cast<enzyme::PushOp>(user);
76 if (!pushOp) {
77 continue;
78 }
79 linalgOp.getRegion().front().getTerminator()->setOperand(
80 (unsigned)output.index(), pushOp.getValue());
81 pushOp.erase();
82 }
83 cache.getDefiningOp()->erase();
84 }
85 }
86
87 allocOp.replaceAllUsesWith((Value)newAllocOp);
88 allocOp.erase();
89 }
90
91 void handlePopOp(enzyme::PopOp popOp, Type newType, enzyme::CacheType c2) {
92 OpBuilder popBuilder(popOp);
93 auto newPopOp = enzyme::PopOp::create(popBuilder, popOp.getLoc(), newType,
94 popOp.getCache());
95
96 // TODO: handle all the stuff inside linalg.generic
97 for (auto user : popOp->getUsers()) {
98 auto subviewOp = dyn_cast<memref::SubViewOp>(user);
99 if (!subviewOp) {
100 continue;
101 }
102 for (auto user : subviewOp->getUsers()) {
103 auto linalgOp = dyn_cast<enzyme::GenericAdjointOp>(user);
104 if (!linalgOp) {
105 continue;
106 }
107 for (auto &&input : llvm::enumerate(linalgOp.getInputs())) {
108 if (input.value() != subviewOp) {
109 continue;
110 }
111 unsigned inputIndex = (unsigned)input.index();
112 Value inputCacheSSA = linalgOp.getRegion().insertArgument(
113 inputIndex, c2.getType(), input.value().getLoc());
114 Value oldArg = linalgOp.getRegion().getArgument(inputIndex + 1);
115 for (auto user : oldArg.getUsers()) {
116 auto popOp = dyn_cast<enzyme::PopOp>(user);
117 if (!popOp) {
118 llvm_unreachable("Unknown user");
119 }
120 popOp.replaceAllUsesWith(inputCacheSSA);
121 popOp.erase();
122 }
123
124 // +1 because we inserted an argument above
125 linalgOp.getRegion().eraseArgument(inputIndex + 1);
126 }
127 }
128 // Replace Subview Op
129 OpBuilder subviewBuilder(subviewOp);
130 auto newSubviewOp = memref::SubViewOp::create(
131 subviewBuilder, subviewOp.getLoc(), newPopOp, subviewOp.getOffsets(),
132 subviewOp.getSizes(), subviewOp.getStrides());
133 subviewOp.replaceAllUsesWith((Value)newSubviewOp);
134 subviewOp.erase();
135 }
136 popOp.replaceAllUsesWith((Value)newPopOp);
137 popOp.erase();
138 }
139
140 void runOnOperation() override {
141 MLIRContext *context = &getContext();
142
143 getOperation()->walk([&](Operation *op) {
144 auto initOp = dyn_cast<enzyme::InitOp>(op);
145 if (!initOp) {
146 return;
147 }
148 auto c1 = dyn_cast<enzyme::CacheType>(initOp.getType());
149 if (!c1) {
150 return;
151 }
152 auto memref = dyn_cast<MemRefType>(c1.getType());
153 if (!memref) {
154 return;
155 }
156 auto c2 = dyn_cast<enzyme::CacheType>(memref.getElementType());
157 if (!c2) {
158 return;
159 }
160 mlir::MemRefType::Builder memrefTypeBuilder(memref);
161 memrefTypeBuilder.setElementType(c2.getType());
162 Type newType = memrefTypeBuilder;
163 Type newCacheType = enzyme::CacheType::get(context, newType);
164 for (auto user : op->getUsers()) {
165 if (auto pushOp = dyn_cast<enzyme::PushOp>(user)) {
166 handlePushOp(pushOp, newType, c2);
167 } else if (auto popOp = dyn_cast<enzyme::PopOp>(user)) {
168 handlePopOp(popOp, newType, c2);
169 } else {
170 llvm_unreachable("Unknown user of InitOp");
171 }
172 }
173
174 OpBuilder builder(op);
175 auto newInit =
176 enzyme::InitOp::create(builder, op->getLoc(), newCacheType);
177 op->replaceAllUsesWith(newInit);
178
179 op->erase();
180 });
181 };
182};
183} // namespace