7#include "mlir/IR/PatternMatch.h"
8#include "mlir/Transforms/DialectConversion.h"
// Debug-type tag for this pass (conventionally consumed by LLVM_DEBUG);
// ENZYME_DBGS below also prints it as a message prefix.
10#define DEBUG_TYPE "enzyme-batch-to-tensor"
// Expands to llvm::dbgs() with "[enzyme-batch-to-tensor]" already written.
11#define ENZYME_DBGS llvm::dbgs() << "[" << DEBUG_TYPE << "]"
15using namespace enzyme;
// Request the tablegen-generated pass definitions; presumably this provides
// enzyme::impl::EnzymeBatchToTensorPassBase, which EnzymeBatchToTensorPass
// derives from (NOTE(review): confirm the .inc is included inside the
// enzyme namespace so that `enzyme::impl::` resolves).
19#define GEN_PASS_DEF_ENZYMEBATCHTOTENSORPASS
20#include "Passes/Passes.h.inc"
26struct ExtractOpConversion :
public OpConversionPattern<enzyme::ExtractOp> {
27 using OpConversionPattern<enzyme::ExtractOp>::OpConversionPattern;
30 matchAndRewrite(enzyme::ExtractOp op, OpAdaptor adaptor,
31 ConversionPatternRewriter &rewriter)
const override {
33 auto outTy = op.getOutput().getType();
34 if (
auto outTensorTy = dyn_cast<TensorType>(outTy)) {
35 auto outRankTy = dyn_cast<RankedTensorType>(outTy);
36 auto rank = outRankTy.getRank();
41 SmallVector<OpFoldResult> offset = {op.getIndexAttr()},
42 sizes = {rewriter.getI64IntegerAttr(1)},
44 rewriter.getI64IntegerAttr(1));
46 SmallVector<OpFoldResult>(rank, rewriter.getI64IntegerAttr(0)));
48 for (
auto dim : outRankTy.getShape()) {
49 sizes.push_back(rewriter.getI64IntegerAttr(dim));
52 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
53 op, outRankTy, op.getInput(), offset, sizes, strides);
56 }
else if (outTy.isIntOrIndexOrFloat()) {
59 arith::ConstantIndexOp::create(rewriter, op->getLoc(), op.getIndex());
60 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, op->getResultTypes(),
61 op.getInput(), indexOp);
72struct ConcatOpConversion :
public OpConversionPattern<enzyme::ConcatOp> {
73 using OpConversionPattern<enzyme::ConcatOp>::OpConversionPattern;
76 matchAndRewrite(enzyme::ConcatOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter)
const override {
79 SmallVector<Value> inputs = op.getInputs();
83 auto firstInTy = inputs.front().getType();
85 if (
auto firstRankTy = dyn_cast<RankedTensorType>(firstInTy)) {
88 auto rank = firstRankTy.getRank();
91 SmallVector<Attribute> reassociationMap;
94 reassociationMap.push_back(rewriter.getI64ArrayAttr({0, 1}));
97 for (
auto i = 1; i < rank; ++i) {
99 reassociationMap.push_back(rewriter.getI64ArrayAttr({i + 1}));
102 ArrayAttr reassociationAttr =
103 ArrayAttr::get(rewriter.getContext(), reassociationMap);
106 SmallVector<Value> expandedInputs;
108 for (Value in : inputs) {
109 auto inRankTy = cast<RankedTensorType>(in.getType());
110 auto inShape = inRankTy.getShape();
111 SmallVector<Value> outDynamicDims;
113 SmallVector<int64_t> newInShape = {1};
114 newInShape.append(inShape.begin(), inShape.end());
116 for (
auto i = 0; i < rank; ++i) {
117 if (inRankTy.isDynamicDim(i)) {
120 arith::ConstantIndexOp::create(rewriter, op->getLoc(), i);
122 tensor::DimOp::create(rewriter, op->getLoc(), in, dynIdx);
123 outDynamicDims.push_back(dynVal);
127 auto newInTy = inRankTy.clone(newInShape);
128 auto outStaticDimAttr =
129 rewriter.getDenseI64ArrayAttr(newInTy.getShape());
131 Value newInput = tensor::ExpandShapeOp::create(
132 rewriter, op->getLoc(), newInTy, in, reassociationAttr,
133 outDynamicDims, outStaticDimAttr);
135 expandedInputs.push_back(newInput);
138 rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, op->getResultTypes(),
141 }
else if (firstInTy.isIntOrIndexOrFloat()) {
142 rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(
143 op, op->getResultTypes(), inputs);
153struct EnzymeBatchToTensorPass
154 :
public enzyme::impl::EnzymeBatchToTensorPassBase<
155 EnzymeBatchToTensorPass> {
156 void runOnOperation()
override {
157 MLIRContext *context = &getContext();
158 RewritePatternSet patterns(context);
162 patterns.add<ConcatOpConversion, ExtractOpConversion>(context);
164 ConversionTarget target(*context);
165 target.addLegalDialect<arith::ArithDialect>();
166 target.addLegalDialect<tensor::TensorDialect>();
167 target.addLegalDialect<enzyme::EnzymeDialect>();
168 target.addIllegalOp<enzyme::ConcatOp, enzyme::ExtractOp>();
170 if (failed(applyPartialConversion(getOperation(), target,
171 std::move(patterns)))) {