42 llvm::ArrayRef<int64_t> batchSizes) {
43 auto T = dyn_cast<TensorType>(Ty);
45 return RankedTensorType::get(batchSizes, Ty);
48 SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
49 shape.append(T.getShape().begin(), T.getShape().end());
50 auto T2 = T.clone(shape);
55 func::CallOp callOp, OpBuilder &builder, IRMapping &mapper,
56 llvm::ArrayRef<int64_t> batchSizes,
57 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
59 auto moduleOp = callOp->getParentOfType<ModuleOp>();
61 dyn_cast<FunctionOpInterface>(moduleOp.lookupSymbol(callOp.getCallee()));
67 SmallVector<int64_t>(batchSizes.begin(), batchSizes.end())};
70 FunctionOpInterface batchedFunc;
71 auto it = batchedFunctionCache.find(key);
72 if (it != batchedFunctionCache.end()) {
73 batchedFunc = it->second;
75 std::string fnName =
"batched_" + calledFunc.getName().str();
77 batchedFunctionCache);
80 batchedFunctionCache[key] = batchedFunc;
84 SmallVector<Value> newOperands;
85 for (
auto operand : callOp->getOperands())
86 newOperands.push_back(mapper.lookup(operand));
89 func::CallOp::create(builder, callOp.getLoc(), batchedFunc.getName(),
90 batchedFunc.getResultTypes(), newOperands);
93 for (
auto [oldResult, newResult] :
94 llvm::zip(callOp.getResults(), newCall.getResults()))
95 mapper.map(oldResult, newResult);
101 OpBuilder &builder, Block *blk, IRMapping &mapper,
102 llvm::ArrayRef<int64_t> batchSizes,
103 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache,
104 bool withoutTerminator) {
105 for (
auto &src : *blk) {
106 if (
auto callOp = dyn_cast<func::CallOp>(&src)) {
107 if (succeeded(
handleCallOp(callOp, builder, mapper, batchSizes,
108 batchedFunctionCache)))
112 if (
auto ifaceOp = dyn_cast<BatchOpInterface>(&src)) {
113 auto res = ifaceOp.createBatch(builder, mapper, batchSizes);
118 SmallVector<Value, 8> operands;
119 SmallVector<Block *, 2> successors;
122 operands.reserve(src.getNumOperands());
123 for (
auto opValue : src.getOperands())
124 operands.push_back(mapper.lookup(opValue));
126 if (withoutTerminator && src.hasTrait<OpTrait::IsTerminator>()) {
128 for (
unsigned i = 0, e = src.getNumResults(); i != e; ++i)
129 mapper.map(src.getResult(i), operands[i]);
134 successors.reserve(src.getNumSuccessors());
135 for (Block *successor : src.getSuccessors())
136 successors.push_back(mapper.lookup(successor));
138 SmallVector<Type> resultTypes(src.getResultTypes().begin(),
139 src.getResultTypes().end());
140 for (
auto &Ty : resultTypes) {
144 Operation *newOp = Operation::create(
145 src.getLoc(), src.getName(), resultTypes, operands, src.getAttrs(),
146 mlir::PropertyRef(), successors, src.getNumRegions());
149 for (
auto &&[oldReg, newReg] :
150 llvm::zip(src.getRegions(), newOp->getRegions())) {
152 batchedFunctionCache);
156 for (
unsigned i = 0, e = src.getNumResults(); i != e; ++i)
157 mapper.map(src.getResult(i), newOp->getResult(i));
159 builder.insert(newOp);
164 OpBuilder &builder, Region *src, Region *dest, IRMapping &mapper,
165 llvm::ArrayRef<int64_t> batchSizes,
166 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
168 for (
auto &blk : *src) {
169 auto newBlk =
new Block();
170 dest->push_back(newBlk);
172 mapper.map(&blk, newBlk);
174 for (
auto arg : blk.getArguments()) {
175 Value newArg = newBlk->addArgument(
177 mapper.map(arg, newArg);
181 for (
auto &&[blk, newBlk] : llvm::zip(*src, *dest)) {
182 IRRewriter::InsertionGuard insertGuard(builder);
183 builder.setInsertionPointToEnd(&newBlk);
184 batchCloneBlock(builder, &blk, mapper, batchSizes, batchedFunctionCache,
190 OpBuilder &builder, FunctionOpInterface F, Twine name,
191 llvm::ArrayRef<int64_t> batchSizes,
192 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
193 assert(!F.getFunctionBody().empty());
195 auto FTy = cast<FunctionType>(F.getFunctionType());
197 llvm::SmallVector<mlir::Type> RetTypes;
198 RetTypes.reserve(FTy.getNumResults());
200 for (
auto Ty : FTy.getResults()) {
204 SmallVector<mlir::Type, 4> ArgTypes;
205 ArgTypes.reserve(FTy.getNumInputs());
207 for (
auto Ty : FTy.getInputs()) {
211 FunctionType newFTy = builder.getFunctionType(ArgTypes, RetTypes);
213 auto NewF = cast<FunctionOpInterface>(F->cloneWithoutRegions());
214 SymbolTable::setSymbolName(NewF, name.str());
215 NewF.setType(newFTy);
217 Operation *parent = F->getParentWithTrait<OpTrait::SymbolTable>();
218 SymbolTable table(parent);
220 SymbolTable::setSymbolVisibility(NewF, SymbolTable::Visibility::Private);
225 SmallVector<int64_t>(batchSizes.begin(), batchSizes.end())};
226 batchedFunctionCache[key] = NewF;
228 auto &origReg = F.getFunctionBody();
229 auto &newReg = NewF.getFunctionBody();
233 batchedFunctionCache);
240 OpBuilder &builder, T CI, FunctionOpInterface fn,
241 std::map<BatchCacheKey, FunctionOpInterface> &batchedFunctionCache) {
243 fn, SmallVector<int64_t>(CI.getBatchShape().begin(),
244 CI.getBatchShape().end())};
247 auto it = batchedFunctionCache.find(key);
248 FunctionOpInterface newFunc;
250 if (it != batchedFunctionCache.end()) {
254 std::string newFnName =
"batched_" + fn.getName().str();
256 batchedFunctionCache);
FunctionOpInterface batchCloneFunction(OpBuilder &builder, FunctionOpInterface F, Twine name, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
void batchCloneBlock(OpBuilder &builder, Block *blk, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache, bool withoutTerminator)
LogicalResult handleCallOp(func::CallOp callOp, OpBuilder &builder, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)
void batchCloneRegion(OpBuilder &builder, Region *src, Region *dest, IRMapping &mapper, llvm::ArrayRef< int64_t > batchSizes, std::map< BatchCacheKey, FunctionOpInterface > &batchedFunctionCache)