Enzyme main
Loading...
Searching...
No Matches
EnzymeToMemRef.cpp
Go to the documentation of this file.
1//===- EnzymeToMemRef.cpp - Lower custom Enzyme operations ------------------ //
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to lower custom ops generated by the Enzyme AD
10// procedure to the MemRef dialect.
11//===----------------------------------------------------------------------===//
12
13#include "Dialect/Dialect.h"
14#include "Dialect/Ops.h"
15#include "PassDetails.h"
16#include "Passes/Passes.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/SCF/IR/SCF.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/Support/raw_ostream.h"
24
25namespace mlir {
26namespace enzyme {
27#define GEN_PASS_DEF_ENZYMEOPSTOMEMREFPASS
28#include "Passes/Passes.h.inc"
29} // namespace enzyme
30} // namespace mlir
31
32using namespace mlir;
33using llvm::errs;
34namespace {
35struct LoweredCache {
36 Value elements, size, capacity;
37 Type elementType;
38
39 void allocateCache(Location loc, OpBuilder &b) {}
40
41 FlatSymbolRefAttr getOrInsertPushFunction(Location loc, ModuleOp moduleOp,
42 OpBuilder &b) const {
43 MLIRContext *context = b.getContext();
44 std::string funcName = "__enzyme_push_";
45 llvm::raw_string_ostream funcStream{funcName};
46 funcStream << elementType;
47 if (moduleOp.lookupSymbol<func::FuncOp>(funcName)) {
48 return SymbolRefAttr::get(context, funcName);
49 }
50
51 OpBuilder::InsertionGuard insertionGuard(b);
52 b.setInsertionPointToStart(moduleOp.getBody());
53
54 auto pushFnType = FunctionType::get(
55 context, /*inputs=*/
56 {elements.getType(), size.getType(), capacity.getType(), elementType},
57 /*outputs=*/{});
58 auto pushFn = func::FuncOp::create(b, loc, funcName, pushFnType);
59 pushFn.setPrivate();
60 Block *entryBlock = pushFn.addEntryBlock();
61 b.setInsertionPointToStart(entryBlock);
62 BlockArgument elementsField = pushFn.getArgument(0);
63 BlockArgument sizeField = pushFn.getArgument(1);
64 BlockArgument capacityField = pushFn.getArgument(2);
65 BlockArgument value = pushFn.getArgument(3);
66
67 Value sizeVal = memref::LoadOp::create(b, loc, sizeField);
68 Value capacityVal = memref::LoadOp::create(b, loc, capacityField);
69
70 Value predicate = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
71 sizeVal, capacityVal);
72 scf::IfOp::create(
73 b, loc, predicate, [&](OpBuilder &thenBuilder, Location loc) {
74 Value two = arith::ConstantIndexOp::create(thenBuilder, loc, 2);
75 Value newCapacity =
76 arith::MulIOp::create(thenBuilder, loc, capacityVal, two);
77 Value oldElements =
78 memref::LoadOp::create(thenBuilder, loc, elementsField);
79 Value newElements = memref::AllocOp::create(
80 thenBuilder, loc, cast<MemRefType>(oldElements.getType()),
81 newCapacity);
82 memref::CopyOp::create(thenBuilder, loc, oldElements, newElements);
83 memref::DeallocOp::create(thenBuilder, loc, oldElements);
84 memref::StoreOp::create(thenBuilder, loc, newElements, elementsField);
85 memref::StoreOp::create(thenBuilder, loc, newCapacity, capacityField);
86 scf::YieldOp::create(thenBuilder, loc);
87 });
88
89 Value elementsVal = memref::LoadOp::create(b, loc, elementsField);
90 memref::StoreOp::create(b, loc, value, elementsVal,
91 /*indices=*/sizeVal);
92
93 Value one = arith::ConstantIndexOp::create(b, loc, 1);
94 Value newSize = arith::AddIOp::create(b, loc, sizeVal, one);
95
96 memref::StoreOp::create(b, loc, newSize, sizeField);
97 func::ReturnOp::create(b, loc);
98
99 return SymbolRefAttr::get(context, funcName);
100 }
101
102 FlatSymbolRefAttr getOrInsertPopFunction(Location loc, ModuleOp moduleOp,
103 OpBuilder &b) {
104 MLIRContext *context = b.getContext();
105 std::string funcName = "__enzyme_pop_";
106 llvm::raw_string_ostream funcStream{funcName};
107 funcStream << elementType;
108 if (moduleOp.lookupSymbol<func::FuncOp>(funcName)) {
109 return SymbolRefAttr::get(context, funcName);
110 }
111
112 OpBuilder::InsertionGuard insertionGuard(b);
113 b.setInsertionPointToStart(moduleOp.getBody());
114 auto popFnType = FunctionType::get(
115 context,
116 /*inputs=*/{elements.getType(), size.getType(), capacity.getType()},
117 /*outputs=*/elementType);
118 auto popFn = func::FuncOp::create(b, loc, funcName, popFnType);
119 popFn.setPrivate();
120 Block *entryBlock = popFn.addEntryBlock();
121 b.setInsertionPointToStart(entryBlock);
122 BlockArgument elementsField = popFn.getArgument(0);
123 BlockArgument sizeField = popFn.getArgument(1);
124
125 Value elementsVal = memref::LoadOp::create(b, loc, elementsField);
126 Value sizeVal = memref::LoadOp::create(b, loc, sizeField);
127 Value zero = arith::ConstantIndexOp::create(b, loc, 0);
128 Value pred =
129 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, sizeVal, zero);
130 cf::AssertOp::create(b, loc, pred, "pop on empty cache");
131
132 Value one = arith::ConstantIndexOp::create(b, loc, 1);
133 Value newSize = arith::SubIOp::create(b, loc, sizeVal, one);
134 memref::StoreOp::create(b, loc, newSize, sizeField);
135
136 Value result = memref::LoadOp::create(b, loc, elementsVal, newSize);
137 func::ReturnOp::create(b, loc, result);
138 return SymbolRefAttr::get(context, funcName);
139 }
140
141 void emitPush(Location loc, Value value, OpBuilder &b,
142 FlatSymbolRefAttr pushFn) const {
143 func::CallOp::create(
144 b, loc, pushFn, /*results=*/TypeRange{},
145 /*operands=*/ValueRange{elements, size, capacity, value});
146 }
147
148 Value emitPop(Location loc, OpBuilder &b, FlatSymbolRefAttr popFn) const {
149 return func::CallOp::create(
150 b, loc, popFn, /*results=*/
151 cast<ShapedType>(
152 cast<ShapedType>(elements.getType()).getElementType())
153 .getElementType(),
154 ValueRange{elements, size, capacity})
155 .getResult(0);
156 }
157
158 FlatSymbolRefAttr getOrInsertGetFunction(Location loc, ModuleOp moduleOp,
159 OpBuilder &b) {
160 MLIRContext *context = b.getContext();
161 std::string funcName = "__enzyme_get_";
162 llvm::raw_string_ostream funcStream{funcName};
163 funcStream << elementType;
164 if (moduleOp.lookupSymbol<func::FuncOp>(funcName)) {
165 return SymbolRefAttr::get(context, funcName);
166 }
167
168 OpBuilder::InsertionGuard insertionGuard(b);
169 b.setInsertionPointToStart(moduleOp.getBody());
170 auto popFnType = FunctionType::get(
171 context,
172 /*inputs=*/{elements.getType(), size.getType(), capacity.getType()},
173 /*outputs=*/elementType);
174 auto popFn = func::FuncOp::create(b, loc, funcName, popFnType);
175 popFn.setPrivate();
176 Block *entryBlock = popFn.addEntryBlock();
177 b.setInsertionPointToStart(entryBlock);
178 BlockArgument elementsField = popFn.getArgument(0);
179 BlockArgument sizeField = popFn.getArgument(1);
180
181 Value elementsVal = memref::LoadOp::create(b, loc, elementsField);
182 Value sizeVal = memref::LoadOp::create(b, loc, sizeField);
183 Value zero = arith::ConstantIndexOp::create(b, loc, 0);
184 Value one = arith::ConstantIndexOp::create(b, loc, 1);
185 Value pred =
186 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, sizeVal, zero);
187 cf::AssertOp::create(b, loc, pred, "get on empty cache");
188
189 Value lastIndex = arith::SubIOp::create(b, loc, sizeVal, one);
190 Value result = memref::LoadOp::create(b, loc, elementsVal, lastIndex);
191 func::ReturnOp::create(b, loc, result);
192 return SymbolRefAttr::get(context, funcName);
193 }
194
195 Value emitGet(Location loc, OpBuilder &b, FlatSymbolRefAttr getFn) const {
196 return func::CallOp::create(
197 b, loc, getFn, /*results=*/
198 cast<ShapedType>(
199 cast<ShapedType>(elements.getType()).getElementType())
200 .getElementType(),
201 ValueRange{elements, size, capacity})
202 .getResult(0);
203 }
204 static std::optional<LoweredCache>
205 getFromEnzymeCache(Location loc, const TypeConverter *typeConverter,
206 Value enzymeCache, OpBuilder &b) {
207 assert(isa<enzyme::CacheType>(enzymeCache.getType()));
208 auto cacheType = cast<enzyme::CacheType>(enzymeCache.getType());
209 SmallVector<Type> resultTypes;
210 if (failed(typeConverter->convertType(cacheType, resultTypes))) {
211 return {};
212 }
213 auto unpackedCache =
214 UnrealizedConversionCastOp::create(b, loc, resultTypes, enzymeCache);
215 return LoweredCache{.elements = unpackedCache.getResult(0),
216 .size = unpackedCache.getResult(1),
217 .capacity = unpackedCache.getResult(2),
218 .elementType = cacheType.getType()};
219 }
220};
221
222struct InitOpConversion : public OpConversionPattern<enzyme::InitOp> {
223 using OpConversionPattern::OpConversionPattern;
224
225 LogicalResult
226 matchAndRewrite(enzyme::InitOp op, OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter) const override {
228 // `enzyme.init` is overloaded to initialize both gradients and caches.
229 // Gradients lower to single element MemRefs, while caches lower to
230 // variable-sized MemRefs.
231
232 if (isa<enzyme::GradientType>(op.getType())) {
233 auto memrefType =
234 cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
235 Value buffer = memref::AllocOp::create(rewriter, op.getLoc(), memrefType);
236 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(op, op.getType(),
237 buffer);
238 return success();
239 }
240
241 if (isa<enzyme::CacheType>(op.getType())) {
242 SmallVector<Type> resultTypes;
243 if (failed(getTypeConverter()->convertType(op.getType(), resultTypes))) {
244 op.emitError() << "Failed to convert type " << op.getType();
245 return failure();
246 }
247
248 Value capacity = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 1);
249 Value initialSize =
250 arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
251 auto dataType = cast<MemRefType>(resultTypes[0]);
252 auto sizeType = cast<MemRefType>(resultTypes[1]);
253 auto capacityType = cast<MemRefType>(resultTypes[2]);
254 Value buffer = memref::AllocOp::create(
255 rewriter, op.getLoc(), cast<MemRefType>(dataType.getElementType()),
256 /*dynamicSize=*/capacity);
257 Value bufferField =
258 memref::AllocaOp::create(rewriter, op.getLoc(), dataType);
259 Value sizeField =
260 memref::AllocaOp::create(rewriter, op.getLoc(), sizeType);
261 Value capacityField =
262 memref::AllocaOp::create(rewriter, op.getLoc(), capacityType);
263 memref::StoreOp::create(rewriter, op.getLoc(), buffer, bufferField);
264 memref::StoreOp::create(rewriter, op.getLoc(), initialSize, sizeField);
265 memref::StoreOp::create(rewriter, op.getLoc(), capacity, capacityField);
266 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
267 op, op.getType(), ValueRange{bufferField, sizeField, capacityField});
268 return success();
269 }
270
271 // TODO: Add verification to the init op to verify the valid types,
272 // or break out init gradient semantics from init cache semantics
273 op.emitError() << "Expected cache or gradient type but got: "
274 << op.getType();
275 return failure();
276 }
277};
278
279struct PushOpConversion : public OpConversionPattern<enzyme::PushOp> {
280 using OpConversionPattern::OpConversionPattern;
281
282 LogicalResult
283 matchAndRewrite(enzyme::PushOp op, OneToNOpAdaptor adaptor,
284 ConversionPatternRewriter &rewriter) const override {
285 Location loc = op.getLoc();
286 auto loweredCache = LoweredCache::getFromEnzymeCache(
287 loc, getTypeConverter(), op.getCache(), rewriter);
288 if (!loweredCache.has_value()) {
289 return failure();
290 }
291
292 FlatSymbolRefAttr pushFn = loweredCache.value().getOrInsertPushFunction(
293 loc, op->getParentOfType<ModuleOp>(), rewriter);
294 loweredCache.value().emitPush(loc, op.getValue(), rewriter, pushFn);
295 rewriter.eraseOp(op);
296 return success();
297 }
298};
299
300struct PopOpConversion : public OpConversionPattern<enzyme::PopOp> {
301 using OpConversionPattern::OpConversionPattern;
302
303 LogicalResult
304 matchAndRewrite(enzyme::PopOp op, OneToNOpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter) const override {
306 Location loc = op.getLoc();
307 auto loweredCache = LoweredCache::getFromEnzymeCache(
308 loc, getTypeConverter(), op.getCache(), rewriter);
309 if (!loweredCache.has_value()) {
310 return failure();
311 }
312
313 FlatSymbolRefAttr popFn = loweredCache.value().getOrInsertPopFunction(
314 loc, op->getParentOfType<ModuleOp>(), rewriter);
315 rewriter.replaceOp(op, loweredCache.value().emitPop(loc, rewriter, popFn));
316 return success();
317 }
318};
319
320struct SetOpConversion : public OpConversionPattern<enzyme::SetOp> {
321 using OpConversionPattern::OpConversionPattern;
322
323 LogicalResult
324 matchAndRewrite(enzyme::SetOp op, OpAdaptor adaptor,
325 ConversionPatternRewriter &rewriter) const override {
326 if (auto type = dyn_cast<enzyme::CacheType>(op.getGradient().getType())) {
327 op.emitError() << "set for CacheType not implemented";
328 return failure();
329 } else if (auto type =
330 dyn_cast<enzyme::GradientType>(op.getGradient().getType())) {
331 auto memrefType = cast<MemRefType>(getTypeConverter()->convertType(type));
332 auto castedGradient = UnrealizedConversionCastOp::create(
333 rewriter, op.getLoc(), memrefType, op.getGradient());
334 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, op.getValue(),
335 castedGradient.getResult(0));
336 }
337 return success();
338 }
339};
340
341struct GetOpConversion : public OpConversionPattern<enzyme::GetOp> {
342 using OpConversionPattern<enzyme::GetOp>::OpConversionPattern;
343
344 LogicalResult
345 matchAndRewrite(enzyme::GetOp op, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter) const override {
347
348 if (auto type = dyn_cast<enzyme::CacheType>(op.getGradient().getType())) {
349 Location loc = op.getLoc();
350 auto loweredCache = LoweredCache::getFromEnzymeCache(
351 loc, getTypeConverter(), op.getGradient(), rewriter);
352 if (!loweredCache.has_value()) {
353 return failure();
354 }
355
356 FlatSymbolRefAttr getFn = loweredCache.value().getOrInsertGetFunction(
357 loc, op->getParentOfType<ModuleOp>(), rewriter);
358 rewriter.replaceOp(op,
359 loweredCache.value().emitGet(loc, rewriter, getFn));
360 } else if (auto type =
361 dyn_cast<enzyme::GradientType>(op.getGradient().getType())) {
362 auto memrefType = cast<MemRefType>(getTypeConverter()->convertType(type));
363 auto castedGradient = UnrealizedConversionCastOp::create(
364 rewriter, op.getLoc(), memrefType, op.getGradient());
365 rewriter.replaceOpWithNewOp<memref::LoadOp>(op,
366 castedGradient.getResult(0));
367 }
368 return success();
369 }
370};
371
372struct EnzymeToMemRefPass
373 : public enzyme::impl::EnzymeOpsToMemRefPassBase<EnzymeToMemRefPass> {
374 void runOnOperation() override {
375 MLIRContext *context = &getContext();
376 RewritePatternSet patterns(context);
377 TypeConverter typeConverter;
378 typeConverter.addConversion([](Type type) -> std::optional<Type> {
379 if (type.isIntOrIndexOrFloat() || isa<MemRefType>(type))
380 return type;
381 return {};
382 });
383 typeConverter.addConversion(
384 [](enzyme::GradientType type) -> std::optional<Type> {
385 return MemRefType::get({}, type.getBasetype());
386 });
387 typeConverter.addConversion(
388 [](enzyme::CacheType type, SmallVectorImpl<Type> &resultTypes) {
389 // Data
390 resultTypes.push_back(MemRefType::get(
391 {}, MemRefType::get({ShapedType::kDynamic}, type.getType())));
392 auto indexMemRefType =
393 MemRefType::get({}, IndexType::get(type.getContext()));
394 // Size
395 resultTypes.push_back(indexMemRefType);
396 // Capacity
397 resultTypes.push_back(indexMemRefType);
398 return success();
399 });
400
401 patterns.add<InitOpConversion>(typeConverter, context);
402 patterns.add<PushOpConversion>(typeConverter, context);
403 patterns.add<PopOpConversion>(typeConverter, context);
404 patterns.add<SetOpConversion>(typeConverter, context);
405 patterns.add<GetOpConversion>(typeConverter, context);
406
407 ConversionTarget target(*context);
408 target.addLegalDialect<memref::MemRefDialect>();
409 target.addLegalDialect<arith::ArithDialect>();
410 target.addLegalDialect<scf::SCFDialect>();
411 target.addLegalDialect<cf::ControlFlowDialect>();
412 target.addLegalDialect<func::FuncDialect>();
413 target.addLegalOp<UnrealizedConversionCastOp>();
414 target.addIllegalDialect<enzyme::EnzymeDialect>();
415
416 if (failed(applyPartialConversion(getOperation(), target,
417 std::move(patterns))))
418 signalPassFailure();
419 };
420};
421} // end anonymous namespace