Enzyme main
Loading...
Searching...
No Matches
EnzymeBatchPass.h
Go to the documentation of this file.
1#ifndef ENZYME_BATCH_PASS_H
2#define ENZYME_BATCH_PASS_H
3
4#include "Dialect/Ops.h"
6#include "PassDetails.h"
7#include "Passes/Passes.h"
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/IR/Builders.h"
11#include "mlir/Interfaces/FunctionInterfaces.h"
12
13namespace mlir {
14namespace enzyme {
15namespace batchutils {
16
18 FunctionOpInterface function;
19 SmallVector<int64_t> batchSizes;
20
21 // for use in std::map:
22 bool operator<(const BatchCacheKey &other) const {
23 if (const_cast<FunctionOpInterface &>(function).getName() !=
24 const_cast<FunctionOpInterface &>(other.function).getName())
25 return const_cast<FunctionOpInterface &>(function).getName() <
26 const_cast<FunctionOpInterface &>(other.function).getName();
27 return batchSizes < other.batchSizes;
28 }
29};
30
// Declaration only; the definition is not part of this header. Takes a type
// and a list of batch sizes and returns a tensor type.
// NOTE(review): presumably the result is `Ty` with `batchSizes` added as
// extra dimensions -- confirm against the definition.
31mlir::TensorType applyBatchSizes(mlir::Type Ty,
32 llvm::ArrayRef<int64_t> batchSizes);
33
// Declaration only; the definition is not part of this header. Receives a
// builder, a source function `F`, a name, the batch sizes and the shared
// cache keyed by BatchCacheKey, and returns a function handle.
// NOTE(review): presumably returns a batched clone of `F` named `name`,
// reusing an entry of `batchedFunctionCache` when one exists -- confirm
// against the definition.
34FunctionOpInterface batchCloneFunction(
35 OpBuilder &builder, FunctionOpInterface F, Twine name,
36 llvm::ArrayRef<int64_t> batchSizes,
37 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
38
40 OpBuilder &builder, Block *blk, IRMapping &mapper,
41 llvm::ArrayRef<int64_t> batchSizes,
42 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache,
43 bool withoutTerminator);
44
46 OpBuilder &builder, Region *src, Region *dest, IRMapping &mapper,
47 llvm::ArrayRef<int64_t> batchSizes,
48 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
49
// Declaration only; the definition is not part of this header. Receives a
// func::CallOp together with the builder, value mapping, batch sizes and
// function cache, and reports success or failure.
// NOTE(review): presumably emits the batched counterpart of `callOp` and
// records its results in `mapper` -- confirm against the definition.
50LogicalResult handleCallOp(
51 func::CallOp callOp, OpBuilder &builder, IRMapping &mapper,
52 llvm::ArrayRef<int64_t> batchSizes,
53 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
54
// Declaration only; the definition is not part of this header. The
// batchOperation overloads in this header call it with the op `CI` and the
// resolved callee `fn`, treat a null result as failure, and otherwise use
// the returned function's name and result types to build the replacement
// func::CallOp themselves.
55template <typename T>
56FunctionOpInterface batchOperationWithoutInsertingCallOp(
57 OpBuilder &builder, T CI, FunctionOpInterface fn,
58 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
59
60template <typename T>
61LogicalResult batchOperation(
62 SymbolTableCollection &symbolTable, OpBuilder &builder, T CI,
63 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
64
65 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
66 return batchOperation(builder, CI, cast<FunctionOpInterface>(symbolOp),
67 batchedFunctionCache);
68}
69
70template <typename T>
71LogicalResult batchOperation(
72 SymbolTableCollection &symbolTable, PatternRewriter &rewriter, T CI,
73 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
74 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
75 return batchOperation(rewriter, CI, cast<FunctionOpInterface>(symbolOp),
76 batchedFunctionCache);
77}
78
79template <typename T>
80LogicalResult batchOperation(
81 OpBuilder &builder, T CI, FunctionOpInterface fn,
82 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
83 auto newFunc = batchOperationWithoutInsertingCallOp(builder, CI, fn,
84 batchedFunctionCache);
85
86 if (!newFunc)
87 return failure();
88
89 {
90 IRRewriter::InsertionGuard insertGuard(builder);
91 builder.setInsertionPoint(CI);
92 auto dCI = func::CallOp::create(builder, CI.getLoc(), newFunc.getName(),
93 newFunc.getResultTypes(), CI.getInputs());
94 CI.replaceAllUsesWith(dCI);
95 CI->erase();
96 }
97 return success();
98}
99
100template <typename T>
101LogicalResult batchOperation(
102 PatternRewriter &rewriter, T CI, FunctionOpInterface fn,
103 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
104 auto newFunc = batchOperationWithoutInsertingCallOp(rewriter, CI, fn,
105 batchedFunctionCache);
106
107 if (!newFunc)
108 return failure();
109
110 {
111 IRRewriter::InsertionGuard insertGuard(rewriter);
112 rewriter.setInsertionPoint(CI);
113 rewriter.replaceOpWithNewOp<func::CallOp>(
114 CI, newFunc.getName(), newFunc.getResultTypes(), CI.getInputs());
115 }
116 return success();
117}
118
119// instead of inserting a call op, we will inline each operation directly
120// into the caller
121inline void batchOperationInline(PatternRewriter &rewriter,
122 enzyme::BatchOp batchOp,
123 FunctionOpInterface func) {
124 auto &origRegion = func.getFunctionBody();
125 auto &origBlock = origRegion.front();
126
127 IRMapping mapper;
128 for (int i = 0; i < batchOp->getNumOperands(); i++) {
129 mapper.map(origBlock.getArguments()[i], batchOp->getOperand(i));
130 }
131
132 rewriter.setInsertionPoint(batchOp);
133 std::map<BatchCacheKey, FunctionOpInterface> batchedFunctionCache;
134 batchCloneBlock(rewriter, &origBlock, mapper, batchOp.getBatchShape(),
135 batchedFunctionCache, true);
136
137 auto origTerm = origBlock.getTerminator();
138 for (auto [i, operand] : llvm::enumerate(origTerm->getOperands())) {
139 auto mappedOperand = mapper.lookup(operand);
140 rewriter.replaceAllUsesWith(batchOp->getResult(i), mappedOperand);
141 }
142 rewriter.eraseOp(batchOp);
143 rewriter.eraseOp(func);
144}
145
146} // namespace batchutils
147} // namespace enzyme
148} // namespace mlir
149
150#endif // ENZYME_BATCH_PASS_H
FunctionOpInterface batchOperationWithoutInsertingCallOp(OpBuilder &builder, T CI, FunctionOpInterface fn, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
LogicalResult batchOperation(SymbolTableCollection &symbolTable, OpBuilder &builder, T CI, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
mlir::TensorType applyBatchSizes(mlir::Type Ty, llvm::ArrayRef< int64_t > batchSizes)
void batchOperationInline(PatternRewriter &rewriter, enzyme::BatchOp batchOp, FunctionOpInterface func)
FunctionOpInterface batchCloneFunction(OpBuilder &builder, FunctionOpInterface F, Twine name, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
void batchCloneBlock(OpBuilder &builder, Block *blk, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache, bool withoutTerminator)
LogicalResult handleCallOp(func::CallOp callOp, OpBuilder &builder, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
void batchCloneRegion(OpBuilder &builder, Region *src, Region *dest, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
bool operator<(const BatchCacheKey &other) const