20#include "mlir/Dialect/Complex/IR/Complex.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/Support/LogicalResult.h"
28#include "Implementations/ComplexDerivatives.inc"
// MathSimplifyInterface external model for complex::AddOp.
// Uses the operand type's AutoDiffTypeInterface to recognise a zero addend and
// replaces the add with the other operand:
//   0 + y -> y
//   x + 0 -> x
// NOTE(review): the return statements and closing braces of this struct are
// not visible in this view. Confirm that each fold returns success() and that
// the fall-through path returns failure().
30struct ComplexAddSimplifyMathInterface
31 :
public MathSimplifyInterface::ExternalModel<
32 ComplexAddSimplifyMathInterface, complex::AddOp> {
// Attempts to simplify `src` in place via `rewriter`.
// `src` must be a complex::AddOp; LLVM cast<> asserts on any other op.
33 mlir::LogicalResult simplifyMath(Operation *src,
34 PatternRewriter &rewriter)
const {
35 auto op = cast<complex::AddOp>(src);
// Zero detection is delegated to the type's AutoDiffTypeInterface.
// Only the LHS type is queried and then used for both operands. This
// presumably relies on complex.add operands sharing one type — confirm.
37 auto ATy = cast<AutoDiffTypeInterface>(op.getLhs().getType());
// 0 + y -> y
39 if (ATy.isZero(op.getLhs())) {
40 rewriter.replaceOp(op, op.getRhs());
// x + 0 -> x
44 if (ATy.isZero(op.getRhs())) {
45 rewriter.replaceOp(op, op.getLhs());
// MathSimplifyInterface external model for complex::SubOp.
// Uses the operand type's AutoDiffTypeInterface to recognise zero operands:
//   x - 0 -> x
//   0 - y -> complex.neg(y)
// The RHS check comes first, so when both operands are zero the op is replaced
// by its (zero) LHS.
// NOTE(review): the return statements and closing braces of this struct are
// not visible in this view. Confirm that each fold returns success() and that
// the fall-through path returns failure().
53struct ComplexSubSimplifyMathInterface
54 :
public MathSimplifyInterface::ExternalModel<
55 ComplexSubSimplifyMathInterface, complex::SubOp> {
// Attempts to simplify `src` in place via `rewriter`.
// `src` must be a complex::SubOp; LLVM cast<> asserts on any other op.
56 mlir::LogicalResult simplifyMath(Operation *src,
57 PatternRewriter &rewriter)
const {
58 auto op = cast<complex::SubOp>(src);
// Zero detection is delegated to the type's AutoDiffTypeInterface.
// Only the LHS type is queried and then used for both operands. This
// presumably relies on complex.sub operands sharing one type — confirm.
60 auto ATy = cast<AutoDiffTypeInterface>(op.getLhs().getType());
// x - 0 -> x
62 if (ATy.isZero(op.getRhs())) {
63 rewriter.replaceOp(op, op.getLhs());
// 0 - y -> -y: a new complex::NegOp replaces the sub.
67 if (ATy.isZero(op.getLhs())) {
68 rewriter.replaceOpWithNewOp<complex::NegOp>(op, op.getRhs());
// Body of the complex-dialect registration hook.
// It adds a DialectRegistry extension whose callback runs for a context when
// the complex dialect is loaded into it. The callback:
//   - attaches the Add/Sub simplification external models to their ops;
//   - calls registerInterfaces(context).
// registerInterfaces is not defined in this view; it is presumably provided by
// the included ComplexDerivatives.inc — confirm.
// NOTE(review): `®istry` appears to be a mis-decoded `&registry` (an `&reg`
// HTML-entity artifact). The parameter should read `DialectRegistry &registry`.
// NOTE(review): the head of this definition (return type and name) and its
// closing `});` / `}` are not visible in this view.
79 DialectRegistry ®istry) {
// The unary `+` converts the capture-less lambda into a plain function
// pointer before it is passed to addExtension.
80 registry.addExtension(+[](MLIRContext *context, complex::ComplexDialect *) {
81 complex::AddOp::attachInterface<ComplexAddSimplifyMathInterface>(*context);
82 complex::SubOp::attachInterface<ComplexSubSimplifyMathInterface>(*context);
83 registerInterfaces(context);
// Entry point that registers the complex-dialect autodiff/simplification
// interfaces with a DialectRegistry.
// NOTE(review): `®istry` appears to be a mis-decoded `&registry` (an `&reg`
// HTML-entity artifact).
// NOTE(review): this line has neither a terminating `;` nor a body. Confirm
// whether it is meant to be a declaration (e.g. from a header) or the head of
// the definition.
void registerComplexDialectAutoDiffInterface(DialectRegistry ®istry)