Enzyme main
Loading...
Searching...
No Matches
ComplexAutoDiffOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- ComplexAutoDiffOpInterfaceImpl.cpp - Interface external model --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the external model implementation of the automatic
10// differentiation op interfaces for the upstream MLIR complex dialect.
11//
12//===----------------------------------------------------------------------===//
13
19
20#include "mlir/Dialect/Complex/IR/Complex.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/Support/LogicalResult.h"
23
24using namespace mlir;
25using namespace mlir::enzyme;
26
27namespace {
28#include "Implementations/ComplexDerivatives.inc"
29
30struct ComplexAddSimplifyMathInterface
31 : public MathSimplifyInterface::ExternalModel<
32 ComplexAddSimplifyMathInterface, complex::AddOp> {
33 mlir::LogicalResult simplifyMath(Operation *src,
34 PatternRewriter &rewriter) const {
35 auto op = cast<complex::AddOp>(src);
36
37 auto ATy = cast<AutoDiffTypeInterface>(op.getLhs().getType());
38
39 if (ATy.isZero(op.getLhs())) {
40 rewriter.replaceOp(op, op.getRhs());
41 return success();
42 }
43
44 if (ATy.isZero(op.getRhs())) {
45 rewriter.replaceOp(op, op.getLhs());
46 return success();
47 }
48
49 return failure();
50 }
51};
52
53struct ComplexSubSimplifyMathInterface
54 : public MathSimplifyInterface::ExternalModel<
55 ComplexSubSimplifyMathInterface, complex::SubOp> {
56 mlir::LogicalResult simplifyMath(Operation *src,
57 PatternRewriter &rewriter) const {
58 auto op = cast<complex::SubOp>(src);
59
60 auto ATy = cast<AutoDiffTypeInterface>(op.getLhs().getType());
61
62 if (ATy.isZero(op.getRhs())) {
63 rewriter.replaceOp(op, op.getLhs());
64 return success();
65 }
66
67 if (ATy.isZero(op.getLhs())) {
68 rewriter.replaceOpWithNewOp<complex::NegOp>(op, op.getRhs());
69 return success();
70 }
71
72 return failure();
73 }
74};
75
76} // namespace
77
79 DialectRegistry &registry) {
80 registry.addExtension(+[](MLIRContext *context, complex::ComplexDialect *) {
81 complex::AddOp::attachInterface<ComplexAddSimplifyMathInterface>(*context);
82 complex::SubOp::attachInterface<ComplexSubSimplifyMathInterface>(*context);
83 registerInterfaces(context);
84 });
85}
void registerComplexDialectAutoDiffInterface(DialectRegistry &registry)