17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/SCF/IR/SCF.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/Support/raw_ostream.h"
27#define GEN_PASS_DEF_ENZYMEOPSTOMEMREFPASS
28#include "Passes/Passes.h.inc"
36 Value elements, size, capacity;
/// Placeholder hook: the body is empty, so calling it emits no IR and leaves
/// `loc` and `b` unused.
39 void allocateCache(Location loc, OpBuilder &b) {}
/// Returns a symbol reference to the module-level push helper for this
/// cache's element type, creating the helper when the lookup below misses.
///
/// The helper is named "__enzyme_push_" followed by the printed element type.
/// Its arguments are (elements, size, capacity, value), typed after this
/// struct's `elements`, `size` and `capacity` fields plus `elementType`.
///
/// \param loc      Location attached to every op created here.
/// \param moduleOp Module that is searched for, and receives, the helper.
41 FlatSymbolRefAttr getOrInsertPushFunction(Location loc, ModuleOp moduleOp,
43 MLIRContext *context = b.getContext();
// Build the helper name by streaming the element type onto the prefix.
44 std::string funcName =
"__enzyme_push_";
45 llvm::raw_string_ostream funcStream{funcName};
46 funcStream << elementType;
// A helper with this name already exists: reference it instead of emitting
// a second definition.
47 if (moduleOp.lookupSymbol<func::FuncOp>(funcName)) {
48 return SymbolRefAttr::get(context, funcName);
// The guard protects the caller's insertion point while the helper is built
// at the start of the module body.
51 OpBuilder::InsertionGuard insertionGuard(b);
52 b.setInsertionPointToStart(moduleOp.getBody());
54 auto pushFnType = FunctionType::get(
56 {elements.getType(), size.getType(), capacity.getType(), elementType},
58 auto pushFn = func::FuncOp::create(b, loc, funcName, pushFnType);
// Name the entry block arguments in the same order as the function inputs.
60 Block *entryBlock = pushFn.addEntryBlock();
61 b.setInsertionPointToStart(entryBlock);
62 BlockArgument elementsField = pushFn.getArgument(0);
63 BlockArgument sizeField = pushFn.getArgument(1);
64 BlockArgument capacityField = pushFn.getArgument(2);
65 BlockArgument value = pushFn.getArgument(3);
// Read the current size and capacity out of their fields.
67 Value sizeVal = memref::LoadOp::create(b, loc, sizeField);
68 Value capacityVal = memref::LoadOp::create(b, loc, capacityField);
// The growth region below is guarded by `size == capacity`.
70 Value predicate = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
71 sizeVal, capacityVal);
// Growth path (built with `thenBuilder`): multiply the capacity by two,
// allocate a new buffer, copy the old contents over, free the old buffer,
// then store the new buffer and new capacity back into their fields.
// NOTE(review): the op that owns this region is on a line not visible here;
// the scf.yield terminator suggests scf.if - confirm.
73 b, loc, predicate, [&](OpBuilder &thenBuilder, Location loc) {
74 Value two = arith::ConstantIndexOp::create(thenBuilder, loc, 2);
76 arith::MulIOp::create(thenBuilder, loc, capacityVal, two);
78 memref::LoadOp::create(thenBuilder, loc, elementsField);
// NOTE(review): the alloc's size operand is not visible here. If it is the
// doubled capacity, the copy below has source and destination of different
// dynamic extents - confirm memref.copy accepts that.
79 Value newElements = memref::AllocOp::create(
80 thenBuilder, loc, cast<MemRefType>(oldElements.getType()),
82 memref::CopyOp::create(thenBuilder, loc, oldElements, newElements);
83 memref::DeallocOp::create(thenBuilder, loc, oldElements);
84 memref::StoreOp::create(thenBuilder, loc, newElements, elementsField);
85 memref::StoreOp::create(thenBuilder, loc, newCapacity, capacityField);
86 scf::YieldOp::create(thenBuilder, loc);
// Reload the buffer after the growth region, because that region may have
// stored a new one into the field, then store the pushed value into it.
// The store's index operand is on a line not visible here.
89 Value elementsVal = memref::LoadOp::create(b, loc, elementsField);
90 memref::StoreOp::create(b, loc, value, elementsVal,
// Write size + 1 back to the size field and return from the helper.
93 Value one = arith::ConstantIndexOp::create(b, loc, 1);
94 Value newSize = arith::AddIOp::create(b, loc, sizeVal, one);
96 memref::StoreOp::create(b, loc, newSize, sizeField);
97 func::ReturnOp::create(b, loc);
99 return SymbolRefAttr::get(context, funcName);
/// Returns a symbol reference to the module-level pop helper for this cache's
/// element type, creating the helper when the lookup below misses.
///
/// The helper is named "__enzyme_pop_" followed by the printed element type
/// and takes (elements, size, capacity) arguments typed after this struct's
/// fields. It asserts that the stored size is greater than zero, writes
/// size - 1 back to the size field, and returns the element at that index.
///
/// \param loc      Location attached to every op created here.
/// \param moduleOp Module that is searched for, and receives, the helper.
102 FlatSymbolRefAttr getOrInsertPopFunction(Location loc, ModuleOp moduleOp,
104 MLIRContext *context = b.getContext();
// Build the helper name by streaming the element type onto the prefix.
105 std::string funcName =
"__enzyme_pop_";
106 llvm::raw_string_ostream funcStream{funcName};
107 funcStream << elementType;
// Reuse an already-emitted helper with the same name.
108 if (moduleOp.lookupSymbol<func::FuncOp>(funcName)) {
109 return SymbolRefAttr::get(context, funcName);
// The guard protects the caller's insertion point while the helper is built
// at the start of the module body.
112 OpBuilder::InsertionGuard insertionGuard(b);
113 b.setInsertionPointToStart(moduleOp.getBody());
// The result type list of the function type is on a line not visible here.
114 auto popFnType = FunctionType::get(
116 {elements.getType(), size.getType(), capacity.getType()},
118 auto popFn = func::FuncOp::create(b, loc, funcName, popFnType);
// Only arguments 0 (elements) and 1 (size) are read in the visible lines;
// the capacity argument is accepted but not used here.
120 Block *entryBlock = popFn.addEntryBlock();
121 b.setInsertionPointToStart(entryBlock);
122 BlockArgument elementsField = popFn.getArgument(0);
123 BlockArgument sizeField = popFn.getArgument(1);
125 Value elementsVal = memref::LoadOp::create(b, loc, elementsField);
126 Value sizeVal = memref::LoadOp::create(b, loc, sizeField);
// Guard against popping an empty cache: assert size > 0 (signed compare).
127 Value zero = arith::ConstantIndexOp::create(b, loc, 0);
129 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, sizeVal, zero);
130 cf::AssertOp::create(b, loc, pred,
"pop on empty cache");
// Shrink the logical size by one and persist it.
132 Value one = arith::ConstantIndexOp::create(b, loc, 1);
133 Value newSize = arith::SubIOp::create(b, loc, sizeVal, one);
134 memref::StoreOp::create(b, loc, newSize, sizeField);
// The popped element is the one at the decremented index.
136 Value result = memref::LoadOp::create(b, loc, elementsVal, newSize);
137 func::ReturnOp::create(b, loc, result);
138 return SymbolRefAttr::get(context, funcName);
/// Emits a call to `pushFn` at the builder's current insertion point.
/// The call has no results and passes this cache's `elements`, `size` and
/// `capacity` fields followed by `value`.
///
/// \param loc    Location for the created call op.
/// \param value  Value appended as the last call operand.
/// \param b      Builder used to create the call.
/// \param pushFn Symbol of the callee.
141 void emitPush(Location loc, Value value, OpBuilder &b,
142 FlatSymbolRefAttr pushFn)
const {
143 func::CallOp::create(
144 b, loc, pushFn, TypeRange{},
145 ValueRange{elements, size, capacity, value});
/// Emits a call with operands {elements, size, capacity} and returns a value
/// taken from that call.
/// NOTE(review): the callee argument and the result selection are on lines not
/// visible here; presumably the callee is `popFn` - confirm.
148 Value emitPop(Location loc, OpBuilder &b, FlatSymbolRefAttr popFn)
const {
149 return func::CallOp::create(
// The result type is derived from the element type of `elements`' shaped
// type; the rest of this expression is not visible here.
152 cast<ShapedType>(elements.getType()).getElementType())
154 ValueRange{elements, size, capacity})
/// Returns a symbol reference to the module-level get helper for this cache's
/// element type, creating the helper when the lookup below misses.
///
/// The helper is named "__enzyme_get_" followed by the printed element type
/// and takes (elements, size, capacity) arguments typed after this struct's
/// fields. It asserts that the stored size is greater than zero and returns
/// the element at index size - 1. No store to the size field appears in the
/// visible lines.
///
/// \param loc      Location attached to every op created here.
/// \param moduleOp Module that is searched for, and receives, the helper.
158 FlatSymbolRefAttr getOrInsertGetFunction(Location loc, ModuleOp moduleOp,
160 MLIRContext *context = b.getContext();
// Build the helper name by streaming the element type onto the prefix.
161 std::string funcName =
"__enzyme_get_";
162 llvm::raw_string_ostream funcStream{funcName};
163 funcStream << elementType;
// Reuse an already-emitted helper with the same name.
164 if (moduleOp.lookupSymbol<func::FuncOp>(funcName)) {
165 return SymbolRefAttr::get(context, funcName);
// The guard protects the caller's insertion point while the helper is built
// at the start of the module body.
168 OpBuilder::InsertionGuard insertionGuard(b);
169 b.setInsertionPointToStart(moduleOp.getBody());
// NOTE(review): the locals below are named `popFnType` / `popFn` although
// this function builds the get helper.
// The result type list of the function type is on a line not visible here.
170 auto popFnType = FunctionType::get(
172 {elements.getType(), size.getType(), capacity.getType()},
174 auto popFn = func::FuncOp::create(b, loc, funcName, popFnType);
// Only arguments 0 (elements) and 1 (size) are read in the visible lines.
176 Block *entryBlock = popFn.addEntryBlock();
177 b.setInsertionPointToStart(entryBlock);
178 BlockArgument elementsField = popFn.getArgument(0);
179 BlockArgument sizeField = popFn.getArgument(1);
181 Value elementsVal = memref::LoadOp::create(b, loc, elementsField);
182 Value sizeVal = memref::LoadOp::create(b, loc, sizeField);
183 Value zero = arith::ConstantIndexOp::create(b, loc, 0);
184 Value one = arith::ConstantIndexOp::create(b, loc, 1);
// Guard against reading from an empty cache: assert size > 0 (signed compare).
186 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, sizeVal, zero);
187 cf::AssertOp::create(b, loc, pred,
"get on empty cache");
// Read the most recently stored element.
189 Value lastIndex = arith::SubIOp::create(b, loc, sizeVal, one);
190 Value result = memref::LoadOp::create(b, loc, elementsVal, lastIndex);
191 func::ReturnOp::create(b, loc, result);
192 return SymbolRefAttr::get(context, funcName);
/// Emits a call with operands {elements, size, capacity} and returns a value
/// taken from that call.
/// NOTE(review): the callee argument and the result selection are on lines not
/// visible here; presumably the callee is `getFn` - confirm.
195 Value emitGet(Location loc, OpBuilder &b, FlatSymbolRefAttr getFn)
const {
196 return func::CallOp::create(
// The result type is derived from the element type of `elements`' shaped
// type; the rest of this expression is not visible here.
199 cast<ShapedType>(elements.getType()).getElementType())
201 ValueRange{elements, size, capacity})
/// Builds a LoweredCache view of an enzyme cache value.
///
/// The cache type is converted through `typeConverter` into a list of result
/// types, and an UnrealizedConversionCastOp from `enzymeCache` to those types
/// is created. Results 0, 1 and 2 of that cast become the `elements`, `size`
/// and `capacity` fields; `elementType` is `cacheType.getType()`.
///
/// \param loc           Location for the created cast op.
/// \param typeConverter Converter used to expand the cache type.
/// \param enzymeCache   Value whose type must be enzyme::CacheType (asserted).
/// \param b             Builder used to create the cast.
204 static std::optional<LoweredCache>
205 getFromEnzymeCache(Location loc,
const TypeConverter *typeConverter,
206 Value enzymeCache, OpBuilder &b) {
207 assert(isa<enzyme::CacheType>(enzymeCache.getType()));
208 auto cacheType = cast<enzyme::CacheType>(enzymeCache.getType());
209 SmallVector<Type> resultTypes;
// The failure branch body is not visible here; presumably it returns an
// empty optional - confirm.
210 if (failed(typeConverter->convertType(cacheType, resultTypes))) {
// Materialize the converted values through a cast that a later step is
// expected to resolve.
214 UnrealizedConversionCastOp::create(b, loc, resultTypes, enzymeCache);
215 return LoweredCache{.elements = unpackedCache.getResult(0),
216 .size = unpackedCache.getResult(1),
217 .capacity = unpackedCache.getResult(2),
218 .elementType = cacheType.getType()};
/// Conversion pattern for enzyme::InitOp.
///
/// - Gradient result: allocates a memref of the converted type and replaces
///   the op with an UnrealizedConversionCastOp back to the original type.
/// - Cache result: allocates a backing buffer, creates three alloca'd fields
///   (buffer, size, capacity), stores the buffer, an initial size of 0 and a
///   capacity of 1 into them, and replaces the op with a cast of the three
///   fields to the original type.
/// - Any other result type: emits an error.
222struct InitOpConversion :
public OpConversionPattern<enzyme::InitOp> {
223 using OpConversionPattern::OpConversionPattern;
226 matchAndRewrite(enzyme::InitOp op, OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter)
const override {
// Gradient case.
// NOTE(review): `cast` asserts if convertType yields a null or non-memref
// type - confirm the converter always succeeds for GradientType.
232 if (isa<enzyme::GradientType>(op.getType())) {
234 cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
235 Value buffer = memref::AllocOp::create(rewriter, op.getLoc(), memrefType);
236 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(op, op.getType(),
// Cache case: the converted type expands into several result types.
241 if (isa<enzyme::CacheType>(op.getType())) {
242 SmallVector<Type> resultTypes;
243 if (failed(getTypeConverter()->convertType(op.getType(), resultTypes))) {
244 op.emitError() <<
"Failed to convert type " << op.getType();
// Initial capacity is the constant 1; the constant 0 on the next visible
// line is the initial size stored below.
248 Value capacity = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 1);
250 arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
// resultTypes is indexed as [0] = data field, [1] = size field,
// [2] = capacity field.
251 auto dataType = cast<MemRefType>(resultTypes[0]);
252 auto sizeType = cast<MemRefType>(resultTypes[1]);
253 auto capacityType = cast<MemRefType>(resultTypes[2]);
// The backing buffer has the element type of the data field, which is itself
// a memref type. The remaining alloc operands are not visible here.
254 Value buffer = memref::AllocOp::create(
255 rewriter, op.getLoc(), cast<MemRefType>(dataType.getElementType()),
// The three fields are created with memref.alloca.
258 memref::AllocaOp::create(rewriter, op.getLoc(), dataType);
260 memref::AllocaOp::create(rewriter, op.getLoc(), sizeType);
261 Value capacityField =
262 memref::AllocaOp::create(rewriter, op.getLoc(), capacityType);
// Initialize the fields, then hand them back as one cache-typed value.
263 memref::StoreOp::create(rewriter, op.getLoc(), buffer, bufferField);
264 memref::StoreOp::create(rewriter, op.getLoc(), initialSize, sizeField);
265 memref::StoreOp::create(rewriter, op.getLoc(), capacity, capacityField);
266 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
267 op, op.getType(), ValueRange{bufferField, sizeField, capacityField});
// Neither a gradient nor a cache: report the unexpected type.
273 op.emitError() <<
"Expected cache or gradient type but got: "
/// Conversion pattern for enzyme::PushOp: obtains the lowered cache fields,
/// gets or creates the push helper in the enclosing module, emits a call to
/// it, and erases the original op.
279struct PushOpConversion :
public OpConversionPattern<enzyme::PushOp> {
280 using OpConversionPattern::OpConversionPattern;
283 matchAndRewrite(enzyme::PushOp op, OneToNOpAdaptor adaptor,
284 ConversionPatternRewriter &rewriter)
const override {
285 Location loc = op.getLoc();
// NOTE(review): operands are read from `op` rather than `adaptor` in the
// visible lines.
286 auto loweredCache = LoweredCache::getFromEnzymeCache(
287 loc, getTypeConverter(), op.getCache(), rewriter);
// The handling of an empty optional is on lines not visible here.
288 if (!loweredCache.has_value()) {
// The helper is looked up in, or inserted into, the module that contains
// the op.
292 FlatSymbolRefAttr pushFn = loweredCache.value().getOrInsertPushFunction(
293 loc, op->getParentOfType<ModuleOp>(), rewriter);
294 loweredCache.value().emitPush(loc, op.getValue(), rewriter, pushFn);
// The push op is erased once the call has been emitted.
295 rewriter.eraseOp(op);
/// Conversion pattern for enzyme::PopOp: obtains the lowered cache fields,
/// gets or creates the pop helper in the enclosing module, and replaces the op
/// with the value returned by emitPop.
300struct PopOpConversion :
public OpConversionPattern<enzyme::PopOp> {
301 using OpConversionPattern::OpConversionPattern;
304 matchAndRewrite(enzyme::PopOp op, OneToNOpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter)
const override {
306 Location loc = op.getLoc();
// NOTE(review): the cache operand is read from `op` rather than `adaptor`
// in the visible lines.
307 auto loweredCache = LoweredCache::getFromEnzymeCache(
308 loc, getTypeConverter(), op.getCache(), rewriter);
// The handling of an empty optional is on lines not visible here.
309 if (!loweredCache.has_value()) {
// The helper is looked up in, or inserted into, the module that contains
// the op.
313 FlatSymbolRefAttr popFn = loweredCache.value().getOrInsertPopFunction(
314 loc, op->getParentOfType<ModuleOp>(), rewriter);
315 rewriter.replaceOp(op, loweredCache.value().emitPop(loc, rewriter, popFn));
/// Conversion pattern for enzyme::SetOp.
///
/// - Cache-typed target: emits the error "set for CacheType not implemented".
/// - Gradient-typed target: casts the gradient to its converted memref type
///   and replaces the op with a memref.store of the op's value into it.
320struct SetOpConversion :
public OpConversionPattern<enzyme::SetOp> {
321 using OpConversionPattern::OpConversionPattern;
324 matchAndRewrite(enzyme::SetOp op, OpAdaptor adaptor,
325 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): `type` is bound in this branch but not used in the visible
// lines.
326 if (
auto type = dyn_cast<enzyme::CacheType>(op.getGradient().getType())) {
327 op.emitError() <<
"set for CacheType not implemented";
329 }
else if (
auto type =
330 dyn_cast<enzyme::GradientType>(op.getGradient().getType())) {
// NOTE(review): `cast` asserts if convertType yields a null or non-memref
// type - confirm the converter always succeeds for GradientType.
331 auto memrefType = cast<MemRefType>(getTypeConverter()->convertType(type));
// Bridge from the enzyme gradient value to its memref form.
332 auto castedGradient = UnrealizedConversionCastOp::create(
333 rewriter, op.getLoc(), memrefType, op.getGradient());
334 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, op.getValue(),
335 castedGradient.getResult(0));
/// Conversion pattern for enzyme::GetOp.
///
/// - Cache-typed source: obtains the lowered cache fields, gets or creates the
///   get helper in the enclosing module, and replaces the op with the value
///   returned by emitGet.
/// - Gradient-typed source: casts the gradient to its converted memref type
///   and replaces the op with a memref.load from it.
341struct GetOpConversion :
public OpConversionPattern<enzyme::GetOp> {
342 using OpConversionPattern<enzyme::GetOp>::OpConversionPattern;
345 matchAndRewrite(enzyme::GetOp op, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): `type` is bound in this branch but not used in the visible
// lines.
348 if (
auto type = dyn_cast<enzyme::CacheType>(op.getGradient().getType())) {
349 Location loc = op.getLoc();
350 auto loweredCache = LoweredCache::getFromEnzymeCache(
351 loc, getTypeConverter(), op.getGradient(), rewriter);
// The handling of an empty optional is on lines not visible here.
352 if (!loweredCache.has_value()) {
// The helper is looked up in, or inserted into, the module that contains
// the op.
356 FlatSymbolRefAttr getFn = loweredCache.value().getOrInsertGetFunction(
357 loc, op->getParentOfType<ModuleOp>(), rewriter);
358 rewriter.replaceOp(op,
359 loweredCache.value().emitGet(loc, rewriter, getFn));
360 }
else if (
auto type =
361 dyn_cast<enzyme::GradientType>(op.getGradient().getType())) {
// NOTE(review): `cast` asserts if convertType yields a null or non-memref
// type - confirm the converter always succeeds for GradientType.
362 auto memrefType = cast<MemRefType>(getTypeConverter()->convertType(type));
// Bridge from the enzyme gradient value to its memref form, then load it.
363 auto castedGradient = UnrealizedConversionCastOp::create(
364 rewriter, op.getLoc(), memrefType, op.getGradient());
365 rewriter.replaceOpWithNewOp<memref::LoadOp>(op,
366 castedGradient.getResult(0));
372struct EnzymeToMemRefPass
373 :
public enzyme::impl::EnzymeOpsToMemRefPassBase<EnzymeToMemRefPass> {
374 void runOnOperation()
override {
375 MLIRContext *context = &getContext();
376 RewritePatternSet patterns(context);
377 TypeConverter typeConverter;
378 typeConverter.addConversion([](Type type) -> std::optional<Type> {
379 if (type.isIntOrIndexOrFloat() || isa<MemRefType>(type))
383 typeConverter.addConversion(
384 [](enzyme::GradientType type) -> std::optional<Type> {
385 return MemRefType::get({}, type.getBasetype());
387 typeConverter.addConversion(
388 [](enzyme::CacheType type, SmallVectorImpl<Type> &resultTypes) {
390 resultTypes.push_back(MemRefType::get(
391 {}, MemRefType::get({ShapedType::kDynamic}, type.getType())));
392 auto indexMemRefType =
393 MemRefType::get({}, IndexType::get(type.getContext()));
395 resultTypes.push_back(indexMemRefType);
397 resultTypes.push_back(indexMemRefType);
401 patterns.add<InitOpConversion>(typeConverter, context);
402 patterns.add<PushOpConversion>(typeConverter, context);
403 patterns.add<PopOpConversion>(typeConverter, context);
404 patterns.add<SetOpConversion>(typeConverter, context);
405 patterns.add<GetOpConversion>(typeConverter, context);
407 ConversionTarget target(*context);
408 target.addLegalDialect<memref::MemRefDialect>();
409 target.addLegalDialect<arith::ArithDialect>();
410 target.addLegalDialect<scf::SCFDialect>();
411 target.addLegalDialect<cf::ControlFlowDialect>();
412 target.addLegalDialect<func::FuncDialect>();
413 target.addLegalOp<UnrealizedConversionCastOp>();
414 target.addIllegalDialect<enzyme::EnzymeDialect>();
416 if (failed(applyPartialConversion(getOperation(), target,
417 std::move(patterns))))