18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/Tensor/IR/Tensor.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/Support/LogicalResult.h"
25#include "mlir/IR/TypeSupport.h"
32struct ArithConstantOpBatchInterface
33 :
public BatchOpInterface::ExternalModel<ArithConstantOpBatchInterface,
36 mlir::LogicalResult createBatch(Operation *src, OpBuilder &builder,
38 ArrayRef<int64_t> batchSizes)
const {
40 SmallVector<Type> resultTypes(src->getResultTypes().begin(),
41 src->getResultTypes().end());
42 for (
auto &Ty : resultTypes) {
43 auto T = cast<TensorType>(Ty);
44 SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
45 shape.append(T.getShape().begin(), T.getShape().end());
48 mlir::NamedAttrList attrs;
49 for (
auto attr : src->getAttrs()) {
50 auto eattr = cast<DenseElementsAttr>(attr.getValue());
51 attr.setValue(eattr.resizeSplat(cast<ShapedType>(resultTypes[0])));
54 auto cop = mlir::Operation::create(
55 src->getLoc(), src->getName(), resultTypes, {}, std::move(attrs),
56 mlir::PropertyRef(), mlir::BlockRange(), 0);
58 mapper.map(src->getResult(0), cop->getResult(0));
63struct ArithAddFSimplifyMathInterface
64 :
public MathSimplifyInterface::ExternalModel<
65 ArithAddFSimplifyMathInterface, arith::AddFOp> {
66 mlir::LogicalResult simplifyMath(Operation *src,
67 PatternRewriter &rewriter)
const {
68 auto op = cast<arith::AddFOp>(src);
70 if (matchPattern(op.getLhs(), m_AnyZeroFloat())) {
71 rewriter.replaceOp(op, op.getRhs());
75 if (matchPattern(op.getRhs(), m_AnyZeroFloat())) {
76 rewriter.replaceOp(op, op.getLhs());
84struct ArithSubFSimplifyMathInterface
85 :
public MathSimplifyInterface::ExternalModel<
86 ArithSubFSimplifyMathInterface, arith::SubFOp> {
87 mlir::LogicalResult simplifyMath(Operation *src,
88 PatternRewriter &rewriter)
const {
89 auto op = cast<arith::SubFOp>(src);
91 if (matchPattern(op.getRhs(), m_AnyZeroFloat())) {
92 rewriter.replaceOp(op, op.getLhs());
96 if (matchPattern(op.getLhs(), m_AnyZeroFloat())) {
97 rewriter.replaceOpWithNewOp<arith::NegFOp>(op, op.getRhs());
105#include "Implementations/ArithDerivatives.inc"
109 DialectRegistry ®istry) {
// Registers a DialectRegistry extension for the arith dialect. Per MLIR's
// addExtension contract, the callback presumably runs once a context loads
// the arith dialect.
// NOTE(review): `®istry` looks like a mis-encoded `&registry` (an `&reg`
// HTML entity decoded to ®). The function header line and the closing
// `});` / `}` are not visible in this view; confirm against the real source.
110 registry.addExtension(+[](MLIRContext *context, arith::ArithDialect *) {
// registerInterfaces is declared outside this view. It is presumably
// provided by the included ArithDerivatives.inc; TODO confirm.
111 registerInterfaces(context);
// Attach the hand-written external models: batching for arith.constant,
// and zero-operand simplification for arith.addf / arith.subf.
112 arith::ConstantOp::attachInterface<ArithConstantOpBatchInterface>(*context);
113 arith::AddFOp::attachInterface<ArithAddFSimplifyMathInterface>(*context);
114 arith::SubFOp::attachInterface<ArithSubFSimplifyMathInterface>(*context);
119 DialectRegistry ®istry) {
// Registers a DialectRegistry extension for the tensor dialect. Per MLIR's
// addExtension contract, the callback presumably runs once a context loads
// the tensor dialect.
// NOTE(review): `®istry` looks like a mis-encoded `&registry`. The
// function header and the remainder of this body are not visible here.
120 registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) {
// registerInterfaces is declared outside this view; TODO confirm what it
// attaches for tensor ops.
121 registerInterfaces(context);
// NOTE(review): these two lines look like the signatures of the dialect
// registration entry points, left without a terminating `;` and with `&`
// mis-encoded as `®`. Their namespace qualification is not visible here;
// confirm whether they belong in this file at all.
void registerArithDialectAutoDiffInterface(DialectRegistry ®istry)
void registerTensorDialectAutoDiffInterface(DialectRegistry ®istry)