//===- EnzymeBatchToTensorPass.cpp - Lower enzyme batch ops to tensor -----===//
//
// Lowers enzyme.extract and enzyme.concat operations to equivalent
// operations in the tensor dialect.
//
//===----------------------------------------------------------------------===//
1#include "Dialect/Ops.h"
3#include "Interfaces/Utils.h"
4#include "PassDetails.h"
5#include "Passes/Passes.h"
6
7#include "mlir/IR/PatternMatch.h"
8#include "mlir/Transforms/DialectConversion.h"
9
10#define DEBUG_TYPE "enzyme-batch-to-tensor"
11#define ENZYME_DBGS llvm::dbgs() << "[" << DEBUG_TYPE << "]"
12
13using namespace mlir;
14using namespace mlir::enzyme;
15using namespace enzyme;
16
17namespace mlir {
18namespace enzyme {
19#define GEN_PASS_DEF_ENZYMEBATCHTOTENSORPASS
20#include "Passes/Passes.h.inc"
21} // namespace enzyme
22} // namespace mlir
23
24namespace {
25
26struct ExtractOpConversion : public OpConversionPattern<enzyme::ExtractOp> {
27 using OpConversionPattern<enzyme::ExtractOp>::OpConversionPattern;
28
29 LogicalResult
30 matchAndRewrite(enzyme::ExtractOp op, OpAdaptor adaptor,
31 ConversionPatternRewriter &rewriter) const override {
32 // filter based on out type
33 auto outTy = op.getOutput().getType();
34 if (auto outTensorTy = dyn_cast<TensorType>(outTy)) {
35 auto outRankTy = dyn_cast<RankedTensorType>(outTy);
36 auto rank = outRankTy.getRank();
37
38 // Offsets : [index, 0, 0 ...]
39 // Sizes : [1, out_dim1, out_dim2 ...]
40 // Strides : [1,1,1,....]
41 SmallVector<OpFoldResult> offset = {op.getIndexAttr()},
42 sizes = {rewriter.getI64IntegerAttr(1)},
43 strides(rank + 1,
44 rewriter.getI64IntegerAttr(1));
45 offset.append(
46 SmallVector<OpFoldResult>(rank, rewriter.getI64IntegerAttr(0)));
47
48 for (auto dim : outRankTy.getShape()) {
49 sizes.push_back(rewriter.getI64IntegerAttr(dim));
50 }
51
52 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
53 op, outRankTy, op.getInput(), offset, sizes, strides);
54
55 return success();
56 } else if (outTy.isIntOrIndexOrFloat()) {
57 // ExtractOp expects only index type arg
58 Value indexOp =
59 arith::ConstantIndexOp::create(rewriter, op->getLoc(), op.getIndex());
60 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, op->getResultTypes(),
61 op.getInput(), indexOp);
62 return success();
63 } else {
64 // unsupported type
65 // TODO: handle memrefs
66
67 return failure();
68 }
69 }
70};
71
72struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
73 using OpConversionPattern<enzyme::ConcatOp>::OpConversionPattern;
74
75 LogicalResult
76 matchAndRewrite(enzyme::ConcatOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter) const override {
78 // filter based on out type
79 SmallVector<Value> inputs = op.getInputs();
80 if (inputs.empty())
81 return failure();
82
83 auto firstInTy = inputs.front().getType();
84
85 if (auto firstRankTy = dyn_cast<RankedTensorType>(firstInTy)) {
86
87 // rank has to be the same for all inputs
88 auto rank = firstRankTy.getRank();
89
90 // Build the reassociation map attribute for expand_shape
91 SmallVector<Attribute> reassociationMap;
92
93 if (rank > 0) {
94 reassociationMap.push_back(rewriter.getI64ArrayAttr({0, 1}));
95 }
96
97 for (auto i = 1; i < rank; ++i) {
98 // src dim 'i' goes to dest dim 'i+1'
99 reassociationMap.push_back(rewriter.getI64ArrayAttr({i + 1}));
100 }
101
102 ArrayAttr reassociationAttr =
103 ArrayAttr::get(rewriter.getContext(), reassociationMap);
104
105 // tensor.expand_shape for every input argument
106 SmallVector<Value> expandedInputs;
107
108 for (Value in : inputs) {
109 auto inRankTy = cast<RankedTensorType>(in.getType());
110 auto inShape = inRankTy.getShape();
111 SmallVector<Value> outDynamicDims;
112
113 SmallVector<int64_t> newInShape = {1};
114 newInShape.append(inShape.begin(), inShape.end());
115
116 for (auto i = 0; i < rank; ++i) {
117 if (inRankTy.isDynamicDim(i)) {
118 // extract dynamic dim
119 Value dynIdx =
120 arith::ConstantIndexOp::create(rewriter, op->getLoc(), i);
121 Value dynVal =
122 tensor::DimOp::create(rewriter, op->getLoc(), in, dynIdx);
123 outDynamicDims.push_back(dynVal);
124 }
125 }
126
127 auto newInTy = inRankTy.clone(newInShape);
128 auto outStaticDimAttr =
129 rewriter.getDenseI64ArrayAttr(newInTy.getShape());
130
131 Value newInput = tensor::ExpandShapeOp::create(
132 rewriter, op->getLoc(), newInTy, in, reassociationAttr,
133 outDynamicDims, outStaticDimAttr);
134
135 expandedInputs.push_back(newInput);
136 }
137
138 rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, op->getResultTypes(),
139 /*dim*/ 0, expandedInputs);
140 return success();
141 } else if (firstInTy.isIntOrIndexOrFloat()) {
142 rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(
143 op, op->getResultTypes(), inputs);
144 return success();
145 } else {
146 // unsupported type
147 // TODO: handle memrefs
148 return failure();
149 }
150 }
151};
152
153struct EnzymeBatchToTensorPass
154 : public enzyme::impl::EnzymeBatchToTensorPassBase<
155 EnzymeBatchToTensorPass> {
156 void runOnOperation() override {
157 MLIRContext *context = &getContext();
158 RewritePatternSet patterns(context);
159
160 // NOTE: May need a typeConverter here when lowering batched memrefs.
161
162 patterns.add<ConcatOpConversion, ExtractOpConversion>(context);
163
164 ConversionTarget target(*context);
165 target.addLegalDialect<arith::ArithDialect>();
166 target.addLegalDialect<tensor::TensorDialect>();
167 target.addLegalDialect<enzyme::EnzymeDialect>();
168 target.addIllegalOp<enzyme::ConcatOp, enzyme::ExtractOp>();
169
170 if (failed(applyPartialConversion(getOperation(), target,
171 std::move(patterns)))) {
172 signalPassFailure();
173 }
174 };
175};
176
177} // namespace