// EnzymeBatchPass.cpp (source listing extracted from the Enzyme
// documentation browser).
//===- EnzymeBatchPass.cpp - Replace calls with their batched versions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that replaces enzyme.batch calls with calls to
// batched clones of the target functions, prepending the batch dimensions to
// every tensor type in the cloned bodies.
//===----------------------------------------------------------------------===//
13
15#include "Dialect/Ops.h"
17#include "PassDetails.h"
18#include "Passes/Passes.h"
19
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/Interfaces/FunctionInterfaces.h"
23
24#define DEBUG_TYPE "enzyme-batch"
25
26using namespace mlir;
27using namespace mlir::enzyme;
28using namespace enzyme;
29
30namespace mlir {
31namespace enzyme {
32#define GEN_PASS_DEF_BATCHPASS
33#include "Passes/Passes.h.inc"
34} // namespace enzyme
35} // namespace mlir
36
37namespace mlir {
38namespace enzyme {
39namespace batchutils {
40
41mlir::TensorType applyBatchSizes(mlir::Type Ty,
42 llvm::ArrayRef<int64_t> batchSizes) {
43 auto T = dyn_cast<TensorType>(Ty);
44 if (!T) {
45 return RankedTensorType::get(batchSizes, Ty);
46 }
47
48 SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
49 shape.append(T.getShape().begin(), T.getShape().end());
50 auto T2 = T.clone(shape);
51 return T2;
52}
53
54LogicalResult handleCallOp(
55 func::CallOp callOp, OpBuilder &builder, IRMapping &mapper,
56 llvm::ArrayRef<int64_t> batchSizes,
57 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
58 // Get the called function
59 auto moduleOp = callOp->getParentOfType<ModuleOp>();
60 auto calledFunc =
61 dyn_cast<FunctionOpInterface>(moduleOp.lookupSymbol(callOp.getCallee()));
62 if (!calledFunc)
63 return failure();
64
65 // Create cache key for this function and batch size combination
66 BatchCacheKey key{calledFunc,
67 SmallVector<int64_t>(batchSizes.begin(), batchSizes.end())};
68
69 // Look up or create batched version of the called function
70 FunctionOpInterface batchedFunc;
71 auto it = batchedFunctionCache.find(key);
72 if (it != batchedFunctionCache.end()) {
73 batchedFunc = it->second;
74 } else {
75 std::string fnName = "batched_" + calledFunc.getName().str();
76 batchedFunc = batchCloneFunction(builder, calledFunc, fnName, batchSizes,
77 batchedFunctionCache);
78 if (!batchedFunc)
79 return failure();
80 batchedFunctionCache[key] = batchedFunc;
81 }
82
83 // Create new call operation to the batched function
84 SmallVector<Value> newOperands;
85 for (auto operand : callOp->getOperands())
86 newOperands.push_back(mapper.lookup(operand));
87
88 auto newCall =
89 func::CallOp::create(builder, callOp.getLoc(), batchedFunc.getName(),
90 batchedFunc.getResultTypes(), newOperands);
91
92 // Map the results
93 for (auto [oldResult, newResult] :
94 llvm::zip(callOp.getResults(), newCall.getResults()))
95 mapper.map(oldResult, newResult);
96
97 return success();
98}
99
101 OpBuilder &builder, Block *blk, IRMapping &mapper,
102 llvm::ArrayRef<int64_t> batchSizes,
103 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache,
104 bool withoutTerminator) {
105 for (auto &src : *blk) {
106 if (auto callOp = dyn_cast<func::CallOp>(&src)) {
107 if (succeeded(handleCallOp(callOp, builder, mapper, batchSizes,
108 batchedFunctionCache)))
109 continue;
110 }
111
112 if (auto ifaceOp = dyn_cast<BatchOpInterface>(&src)) {
113 auto res = ifaceOp.createBatch(builder, mapper, batchSizes);
114 if (res.succeeded())
115 continue;
116 }
117
118 SmallVector<Value, 8> operands;
119 SmallVector<Block *, 2> successors;
120
121 // Remap the operands.
122 operands.reserve(src.getNumOperands());
123 for (auto opValue : src.getOperands())
124 operands.push_back(mapper.lookup(opValue));
125
126 if (withoutTerminator && src.hasTrait<OpTrait::IsTerminator>()) {
127 // map the operands and the results
128 for (unsigned i = 0, e = src.getNumResults(); i != e; ++i)
129 mapper.map(src.getResult(i), operands[i]);
130 continue;
131 }
132
133 // Remap the successors.
134 successors.reserve(src.getNumSuccessors());
135 for (Block *successor : src.getSuccessors())
136 successors.push_back(mapper.lookup(successor));
137
138 SmallVector<Type> resultTypes(src.getResultTypes().begin(),
139 src.getResultTypes().end());
140 for (auto &Ty : resultTypes) {
141 Ty = applyBatchSizes(Ty, batchSizes);
142 }
143
144 Operation *newOp = Operation::create(
145 src.getLoc(), src.getName(), resultTypes, operands, src.getAttrs(),
146 mlir::PropertyRef(), successors, src.getNumRegions());
147
148 // Clone the regions.
149 for (auto &&[oldReg, newReg] :
150 llvm::zip(src.getRegions(), newOp->getRegions())) {
151 batchCloneRegion(builder, &oldReg, &newReg, mapper, batchSizes,
152 batchedFunctionCache);
153 }
154
155 // Remember the mapping of any results.
156 for (unsigned i = 0, e = src.getNumResults(); i != e; ++i)
157 mapper.map(src.getResult(i), newOp->getResult(i));
158
159 builder.insert(newOp);
160 }
161}
162
164 OpBuilder &builder, Region *src, Region *dest, IRMapping &mapper,
165 llvm::ArrayRef<int64_t> batchSizes,
166 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
167 // For each block in src, generate a corresponding block in the dest region.
168 for (auto &blk : *src) {
169 auto newBlk = new Block();
170 dest->push_back(newBlk);
171
172 mapper.map(&blk, newBlk);
173
174 for (auto arg : blk.getArguments()) {
175 Value newArg = newBlk->addArgument(
176 applyBatchSizes(arg.getType(), batchSizes), arg.getLoc());
177 mapper.map(arg, newArg);
178 }
179 }
180
181 for (auto &&[blk, newBlk] : llvm::zip(*src, *dest)) {
182 IRRewriter::InsertionGuard insertGuard(builder);
183 builder.setInsertionPointToEnd(&newBlk);
184 batchCloneBlock(builder, &blk, mapper, batchSizes, batchedFunctionCache,
185 false);
186 }
187}
188
189FunctionOpInterface batchCloneFunction(
190 OpBuilder &builder, FunctionOpInterface F, Twine name,
191 llvm::ArrayRef<int64_t> batchSizes,
192 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
193 assert(!F.getFunctionBody().empty());
194
195 auto FTy = cast<FunctionType>(F.getFunctionType());
196
197 llvm::SmallVector<mlir::Type> RetTypes;
198 RetTypes.reserve(FTy.getNumResults());
199
200 for (auto Ty : FTy.getResults()) {
201 RetTypes.push_back(applyBatchSizes(Ty, batchSizes));
202 }
203
204 SmallVector<mlir::Type, 4> ArgTypes;
205 ArgTypes.reserve(FTy.getNumInputs());
206
207 for (auto Ty : FTy.getInputs()) {
208 ArgTypes.push_back(applyBatchSizes(Ty, batchSizes));
209 }
210
211 FunctionType newFTy = builder.getFunctionType(ArgTypes, RetTypes);
212
213 auto NewF = cast<FunctionOpInterface>(F->cloneWithoutRegions());
214 SymbolTable::setSymbolName(NewF, name.str());
215 NewF.setType(newFTy);
216
217 Operation *parent = F->getParentWithTrait<OpTrait::SymbolTable>();
218 SymbolTable table(parent);
219 table.insert(NewF);
220 SymbolTable::setSymbolVisibility(NewF, SymbolTable::Visibility::Private);
221
222 // Add the function to the cache BEFORE processing its body to support
223 // recursion.
224 BatchCacheKey key{F,
225 SmallVector<int64_t>(batchSizes.begin(), batchSizes.end())};
226 batchedFunctionCache[key] = NewF;
227
228 auto &origReg = F.getFunctionBody();
229 auto &newReg = NewF.getFunctionBody();
230
231 IRMapping mapper;
232 batchCloneRegion(builder, &origReg, &newReg, mapper, batchSizes,
233 batchedFunctionCache);
234
235 return NewF;
236}
237
238template <typename T>
240 OpBuilder &builder, T CI, FunctionOpInterface fn,
241 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
243 fn, SmallVector<int64_t>(CI.getBatchShape().begin(),
244 CI.getBatchShape().end())};
245
246 // Check if we already have a batched version
247 auto it = batchedFunctionCache.find(key);
248 FunctionOpInterface newFunc;
249
250 if (it != batchedFunctionCache.end()) {
251 return it->second;
252 } else {
253 // Create new batched function and store in cache
254 std::string newFnName = "batched_" + fn.getName().str();
255 newFunc = batchCloneFunction(builder, fn, newFnName, CI.getBatchShape(),
256 batchedFunctionCache);
257 return newFunc;
258 }
259}
260
261} // namespace batchutils
262} // namespace enzyme
263} // namespace mlir
264
265namespace {
266
267struct BatchPass : public enzyme::impl::BatchPassBase<BatchPass> {
268 void runOnOperation() override;
269
270 // Cache mapping original function and batch sizes to batched function
271 std::map<enzyme::batchutils::BatchCacheKey, FunctionOpInterface>
272 batchedFunctionCache;
273
274 void lowerEnzymeBatchCalls(SymbolTableCollection &symbolTable,
275 FunctionOpInterface op) {
276 {
277 SmallVector<Operation *> toLower;
278 op->walk([&](enzyme::BatchOp dop) {
279 auto *symbolOp =
280 symbolTable.lookupNearestSymbolFrom(dop, dop.getFnAttr());
281 auto callableOp = cast<FunctionOpInterface>(symbolOp);
282
283 lowerEnzymeBatchCalls(symbolTable, callableOp);
284 toLower.push_back(dop);
285 });
286
287 OpBuilder builder(op);
288
289 for (auto T : toLower) {
290 if (auto F = dyn_cast<enzyme::BatchOp>(T)) {
291 auto res = enzyme::batchutils::batchOperation(symbolTable, builder, F,
292 batchedFunctionCache);
293 if (!res.succeeded()) {
294 signalPassFailure();
295 return;
296 }
297 } else {
298 llvm_unreachable("Illegal type");
299 }
300 }
301 };
302 };
303};
304
305} // end anonymous namespace
306
307void BatchPass::runOnOperation() {
308 SymbolTableCollection symbolTable;
309 symbolTable.getSymbolTable(getOperation());
310 getOperation()->walk(
311 [&](FunctionOpInterface op) { lowerEnzymeBatchCalls(symbolTable, op); });
312}
FunctionOpInterface batchOperationWithoutInsertingCallOp(OpBuilder &builder, T CI, FunctionOpInterface fn, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
LogicalResult batchOperation(SymbolTableCollection &symbolTable, OpBuilder &builder, T CI, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
mlir::TensorType applyBatchSizes(mlir::Type Ty, llvm::ArrayRef< int64_t > batchSizes)
FunctionOpInterface batchCloneFunction(OpBuilder &builder, FunctionOpInterface F, Twine name, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
void batchCloneBlock(OpBuilder &builder, Block *blk, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache, bool withoutTerminator)
LogicalResult handleCallOp(func::CallOp callOp, OpBuilder &builder, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
void batchCloneRegion(OpBuilder &builder, Region *src, Region *dest, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)