1#ifndef ENZYME_BATCH_PASS_H
2#define ENZYME_BATCH_PASS_H
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/IR/Builders.h"
11#include "mlir/Interfaces/FunctionInterfaces.h"
// Fragment of an ordering comparison between two cache keys. Judging by the
// symbol list at the end of this file it appears to be the body of
// BatchCacheKey::operator<(const BatchCacheKey &other) const; the signature
// line itself is not visible in this extract.
// NOTE(review): the leading integers on these lines look like line numbers
// carried over from a rendered listing rather than program tokens — confirm.
// NOTE(review): the const_cast presumably exists because getName() cannot be
// called through a const FunctionOpInterface — confirm against the interface.
//
// First test: do the two keys refer to functions with different names?
23 if (
const_cast<FunctionOpInterface &
>(
function).getName() !=
24 const_cast<FunctionOpInterface &
>(other.
function).getName())
// Names differ: the result is the `<` comparison of the two function names.
25 return const_cast<FunctionOpInterface &
>(
function).getName() <
26 const_cast<FunctionOpInterface &
>(other.
function).getName();
// NOTE(review): the comparison used when the names are equal is not visible
// in this fragment.
// Trailing parameter lists of several forward declarations. The lines holding
// the return types and function names are missing from this extract; the
// pairings suggested below are inferred from the symbol list at the end of the
// file and should be confirmed against the original header.

// Presumably the tail of applyBatchSizes(Type, batchSizes) — TODO confirm.
32 llvm::ArrayRef<int64_t> batchSizes);
// Presumably batchCloneFunction: takes a builder, the function F, a name for
// the result, the batch sizes and the shared BatchCacheKey -> function cache.
35 OpBuilder &builder, FunctionOpInterface F, Twine name,
36 llvm::ArrayRef<int64_t> batchSizes,
37 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
// Presumably batchCloneBlock: operates on a single Block with an IRMapping;
// the final flag is named withoutTerminator.
40 OpBuilder &builder, Block *blk, IRMapping &mapper,
41 llvm::ArrayRef<int64_t> batchSizes,
42 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache,
43 bool withoutTerminator);
// Presumably batchCloneRegion: takes a source and a destination Region.
46 OpBuilder &builder, Region *src, Region *dest, IRMapping &mapper,
47 llvm::ArrayRef<int64_t> batchSizes,
48 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
// Presumably handleCallOp: takes a func::CallOp plus builder, mapper, batch
// sizes and the cache.
51 func::CallOp callOp, OpBuilder &builder, IRMapping &mapper,
52 llvm::ArrayRef<int64_t> batchSizes,
53 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
// Declaration of a template taking (builder, CI, fn, cache); T is the op type
// of CI. NOTE(review): which function name this belongs to is not visible.
57 OpBuilder &builder, T CI, FunctionOpInterface fn,
58 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache);
// Fragment of the batchOperation overload that takes a SymbolTableCollection
// and an OpBuilder (the return type / name line is missing from this extract;
// the symbol list at the end of the file gives LogicalResult batchOperation).
// It resolves the symbol named by CI.getFnAttr(), looked up from CI, and
// forwards to the batchOperation overload that takes the resolved function.
62 SymbolTableCollection &symbolTable, OpBuilder &builder, T CI,
63 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
// NOTE(review): no null check on symbolOp is visible in this fragment before
// it is passed to cast<FunctionOpInterface>.
65 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
66 return batchOperation(builder, CI, cast<FunctionOpInterface>(symbolOp),
67 batchedFunctionCache);
// Fragment of the batchOperation overload that takes a SymbolTableCollection
// and a PatternRewriter (the return type / name line is missing from this
// extract). Same shape as the OpBuilder variant: resolve the symbol named by
// CI.getFnAttr(), then forward to the overload taking the resolved function,
// passing the rewriter through.
72 SymbolTableCollection &symbolTable, PatternRewriter &rewriter, T CI,
73 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
// NOTE(review): no null check on symbolOp is visible in this fragment before
// it is passed to cast<FunctionOpInterface>.
74 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
75 return batchOperation(rewriter, CI, cast<FunctionOpInterface>(symbolOp),
76 batchedFunctionCache);
// Fragment of an OpBuilder-based routine taking (builder, CI, fn, cache) that
// replaces the op CI with a func::CallOp to `newFunc`.
// NOTE(review): the function name, the statement that produces `newFunc`
// (only its trailing argument is visible below) and any failure handling are
// missing from this extract.
81 OpBuilder &builder, T CI, FunctionOpInterface fn,
82 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
84 batchedFunctionCache);
// The guard is constructed from the builder before the insertion point is
// moved to CI; presumably it restores the previous point on scope exit.
90 IRRewriter::InsertionGuard insertGuard(builder);
91 builder.setInsertionPoint(CI);
// Create a call to newFunc with CI's inputs and newFunc's result types.
92 auto dCI = func::CallOp::create(builder, CI.getLoc(), newFunc.getName(),
93 newFunc.getResultTypes(), CI.getInputs());
// Redirect every use of CI's results to the new call.
// NOTE(review): removal of CI itself is not visible in this fragment.
94 CI.replaceAllUsesWith(dCI);
// Fragment of the PatternRewriter-based counterpart taking
// (rewriter, CI, fn, cache): replaces the op CI with a func::CallOp to
// `newFunc` through the rewriter.
// NOTE(review): the function name, the statement that produces `newFunc`
// (only its trailing argument is visible below) and any failure handling are
// missing from this extract.
102 PatternRewriter &rewriter, T CI, FunctionOpInterface fn,
103 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
105 batchedFunctionCache);
// The guard is constructed from the rewriter before the insertion point is
// moved to CI; presumably it restores the previous point on scope exit.
111 IRRewriter::InsertionGuard insertGuard(rewriter);
112 rewriter.setInsertionPoint(CI);
// Replace CI with a call to newFunc, using CI's inputs and newFunc's result
// types.
113 rewriter.replaceOpWithNewOp<func::CallOp>(
114 CI, newFunc.getName(), newFunc.getResultTypes(), CI.getInputs());
// Fragment of batchOperationInline(rewriter, batchOp, func) — the name and
// first parameter come from the symbol list at the end of the file; the
// signature's first line and several body lines are missing from this extract.
// Visible behaviour: values computed from func's first block are substituted
// for batchOp's results, then both batchOp and func are erased.
122 enzyme::BatchOp batchOp,
123 FunctionOpInterface func) {
// Only the first block of the function body is used.
124 auto &origRegion = func.getFunctionBody();
125 auto &origBlock = origRegion.front();
// Map the i-th block argument of the function to the i-th operand of batchOp.
// NOTE(review): the declaration of `mapper` is not visible here (presumably an
// IRMapping). `int i` is compared against getNumOperands(), which looks like a
// signed/unsigned mix, and the loop assumes the block has at least as many
// arguments as batchOp has operands — confirm both.
128 for (
int i = 0; i < batchOp->getNumOperands(); i++) {
129 mapper.map(origBlock.getArguments()[i], batchOp->getOperand(i));
// New IR is inserted at batchOp's position; the cache is local to this call.
132 rewriter.setInsertionPoint(batchOp);
133 std::map<BatchCacheKey, FunctionOpInterface> batchedFunctionCache;
// Tail of a call whose callee and leading arguments are not visible; it
// receives the local cache and `true`.
// NOTE(review): presumably batchCloneBlock with withoutTerminator = true.
135 batchedFunctionCache,
true);
// For each operand of the original block's terminator, look up its mapped
// value and use it in place of the batchOp result with the same index.
137 auto origTerm = origBlock.getTerminator();
138 for (
auto [i, operand] : llvm::enumerate(origTerm->getOperands())) {
139 auto mappedOperand = mapper.lookup(operand);
140 rewriter.replaceAllUsesWith(batchOp->getResult(i), mappedOperand);
// Finally erase the batch op and the function `func` itself.
142 rewriter.eraseOp(batchOp);
143 rewriter.eraseOp(func);
FunctionOpInterface batchOperationWithoutInsertingCallOp(OpBuilder &builder, T CI, FunctionOpInterface fn, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
LogicalResult batchOperation(SymbolTableCollection &symbolTable, OpBuilder &builder, T CI, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
mlir::TensorType applyBatchSizes(mlir::Type Ty, llvm::ArrayRef< int64_t > batchSizes)
void batchOperationInline(PatternRewriter &rewriter, enzyme::BatchOp batchOp, FunctionOpInterface func)
FunctionOpInterface batchCloneFunction(OpBuilder &builder, FunctionOpInterface F, Twine name, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
void batchCloneBlock(OpBuilder &builder, Block *blk, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache, bool withoutTerminator)
LogicalResult handleCallOp(func::CallOp callOp, OpBuilder &builder, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
void batchCloneRegion(OpBuilder &builder, Region *src, Region *dest, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
FunctionOpInterface function
bool operator<(const BatchCacheKey &other) const
SmallVector< int64_t > batchSizes