17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Complex/IR/Complex.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#define GEN_PASS_DEF_MATHEMATICSIMPLIFICATIONPASS
31#include "Passes/Passes.h.inc"
38struct ApplySimplificationPattern
39 :
public OpInterfaceRewritePattern<enzyme::MathSimplifyInterface> {
40 using OpInterfaceRewritePattern<
41 enzyme::MathSimplifyInterface>::OpInterfaceRewritePattern;
43 LogicalResult matchAndRewrite(enzyme::MathSimplifyInterface op,
44 PatternRewriter &rewriter)
const override {
45 return op.simplifyMath(rewriter);
// Pass implementation derived from the generated
// `MathematicSimplificationPassBase`. Its `runOnOperation` registers
// `ApplySimplificationPattern` and applies it greedily to the operation the
// pass runs on.
//
// NOTE(review): the leading numerals on these lines look like line-number
// residue from extraction, and the closing braces of this definition are not
// visible in this view -- confirm against the original file.
49struct MathematicSimplification
50 :
public enzyme::impl::MathematicSimplificationPassBase<
51 MathematicSimplification> {
// Entry point of the pass: build the pattern set, configure the greedy
// driver, and run it over `getOperation()`.
52 void runOnOperation()
override {
// Pattern set bound to this pass's context; the only pattern registered here
// is `ApplySimplificationPattern`.
54 RewritePatternSet patterns(&getContext());
55 patterns.insert<ApplySimplificationPattern>(&getContext());
// Driver configuration: `enableFolding()` is called on the otherwise
// default-constructed config.
57 GreedyRewriteConfig config;
58 config.enableFolding();
// The driver's return value is explicitly discarded with the `(void)` cast,
// so its result is not inspected on this line.
59 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);