Enzyme main
Loading...
Searching...
No Matches
ArithAutoDiffOpInterfaceImpl.cpp
Go to the documentation of this file.
//===- ArithAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR arithmetic dialect.
//
//===----------------------------------------------------------------------===//
13
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/Tensor/IR/Tensor.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/Support/LogicalResult.h"
23
24#include "Dialect/Ops.h"
25#include "mlir/IR/TypeSupport.h"
26
27using namespace mlir;
28using namespace mlir::enzyme;
29
30namespace {
31
32struct ArithConstantOpBatchInterface
33 : public BatchOpInterface::ExternalModel<ArithConstantOpBatchInterface,
34 arith::ConstantOp> {
35
36 mlir::LogicalResult createBatch(Operation *src, OpBuilder &builder,
37 IRMapping &mapper,
38 ArrayRef<int64_t> batchSizes) const {
39
40 SmallVector<Type> resultTypes(src->getResultTypes().begin(),
41 src->getResultTypes().end());
42 for (auto &Ty : resultTypes) {
43 auto T = cast<TensorType>(Ty);
44 SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
45 shape.append(T.getShape().begin(), T.getShape().end());
46 Ty = T.clone(shape);
47 }
48 mlir::NamedAttrList attrs;
49 for (auto attr : src->getAttrs()) {
50 auto eattr = cast<DenseElementsAttr>(attr.getValue());
51 attr.setValue(eattr.resizeSplat(cast<ShapedType>(resultTypes[0])));
52 attrs.append(attr);
53 }
54 auto cop = mlir::Operation::create(
55 src->getLoc(), src->getName(), resultTypes, {}, std::move(attrs),
56 mlir::PropertyRef(), mlir::BlockRange(), 0);
57 builder.insert(cop);
58 mapper.map(src->getResult(0), cop->getResult(0));
59 return success();
60 }
61};
62
63struct ArithAddFSimplifyMathInterface
64 : public MathSimplifyInterface::ExternalModel<
65 ArithAddFSimplifyMathInterface, arith::AddFOp> {
66 mlir::LogicalResult simplifyMath(Operation *src,
67 PatternRewriter &rewriter) const {
68 auto op = cast<arith::AddFOp>(src);
69
70 if (matchPattern(op.getLhs(), m_AnyZeroFloat())) {
71 rewriter.replaceOp(op, op.getRhs());
72 return success();
73 }
74
75 if (matchPattern(op.getRhs(), m_AnyZeroFloat())) {
76 rewriter.replaceOp(op, op.getLhs());
77 return success();
78 }
79
80 return failure();
81 }
82};
83
84struct ArithSubFSimplifyMathInterface
85 : public MathSimplifyInterface::ExternalModel<
86 ArithSubFSimplifyMathInterface, arith::SubFOp> {
87 mlir::LogicalResult simplifyMath(Operation *src,
88 PatternRewriter &rewriter) const {
89 auto op = cast<arith::SubFOp>(src);
90
91 if (matchPattern(op.getRhs(), m_AnyZeroFloat())) {
92 rewriter.replaceOp(op, op.getLhs());
93 return success();
94 }
95
96 if (matchPattern(op.getLhs(), m_AnyZeroFloat())) {
97 rewriter.replaceOpWithNewOp<arith::NegFOp>(op, op.getRhs());
98 return success();
99 }
100
101 return failure();
102 }
103};
104
105#include "Implementations/ArithDerivatives.inc"
106} // namespace
107
109 DialectRegistry &registry) {
110 registry.addExtension(+[](MLIRContext *context, arith::ArithDialect *) {
111 registerInterfaces(context);
112 arith::ConstantOp::attachInterface<ArithConstantOpBatchInterface>(*context);
113 arith::AddFOp::attachInterface<ArithAddFSimplifyMathInterface>(*context);
114 arith::SubFOp::attachInterface<ArithSubFSimplifyMathInterface>(*context);
115 });
116}
117
119 DialectRegistry &registry) {
120 registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) {
121 registerInterfaces(context);
122 });
123}
void registerArithDialectAutoDiffInterface(DialectRegistry &registry)
void registerTensorDialectAutoDiffInterface(DialectRegistry &registry)