18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/Transforms/DialectConversion.h"
25#include "mlir/Rewrite/PatternApplicator.h"
26#include "llvm/Support/raw_ostream.h"
29#include "mlir/Dialect/Affine/IR/AffineOps.h"
30#include "mlir/Dialect/Linalg/IR/Linalg.h"
36#define GEN_PASS_DEF_ADDTOOPTOSPLITPASS
37#include "Passes/Passes.h.inc"
42using namespace enzyme;
// Searches the operations nested under `adjoint` for an enzyme::AddToOp.
// The result pointer `addToOp` starts out as nullptr.
// NOTE(review): the remainder of the walk callback and the return statement
// are not visible in this excerpt; presumably the matched op (or nullptr when
// none is found) is returned -- confirm against the full file.
46Operation *getAddToOp(linalg::GenericOp &adjoint) {
47 Operation *addToOp =
nullptr;
// Visit every operation nested under `adjoint`.
48 adjoint.walk([&](Operation *op) {
// Only enzyme::AddToOp instances are of interest to this search.
49 if (isa<enzyme::AddToOp>(op)) {
// Tests whether `type` is a MemRefType whose element type is a CacheType.
// NOTE(review): the value returned when `type` is not a MemRefType is not
// visible in this excerpt; presumably false -- confirm against the full file.
56bool isMemrefCacheType(Type type) {
57 if (
auto memrefType = dyn_cast<MemRefType>(type)) {
// A memref qualifies only when its element type is a CacheType.
58 return isa<CacheType>(memrefType.getElementType());
// Clones the linalg.generic `op` at the builder's current insertion point and
// trims the clone's outputs (together with the matching block arguments and
// indexing maps), then replaces the clone's enzyme::AddToOp with a
// linalg.yield of a single accumulated value.
//
// Parameters visible here: `op` (the operation to clone; it is cast to
// linalg::GenericOp), `builder` (used for cloning and op creation) and `loc`
// (location given to the newly created ops).
// NOTE(review): the last parameter is cut off in this excerpt; the body uses
// an index named `i` and the caller passes a loop counter in that position,
// so it is presumably the output / AddToOp-operand index -- confirm.
// NOTE(review): several body lines are missing from this excerpt (e.g. the
// declarations of `idx`, `outputOperand`, `returnValue`, and closing braces).
63void processGenericDuplication(Operation *op, OpBuilder &builder, Location loc,
// Work on a clone so that the original `op` is left untouched by this function.
65 auto clonedAdjoint = builder.clone(*op);
66 linalg::GenericOp clonedAdjointGenericOp =
67 cast<linalg::GenericOp>(clonedAdjoint);
// Output-related block arguments and indexing maps are addressed below at
// offset `numInputs`, i.e. after the entries belonging to the inputs.
70 unsigned numInputs = clonedAdjointGenericOp.getInputsMutable().size();
// Mutable copy of the indexing maps; written back via setIndexingMapsAttr.
71 SmallVector<mlir::Attribute> indexingMaps(
72 clonedAdjointGenericOp.getIndexingMaps().getValue());
// Drop the outputs that follow index `i`, when there are any.
// NOTE(review): if this subtraction is evaluated in an unsigned type it
// wraps around when `i` >= the number of outputs, making the test true --
// confirm `i` is always a valid output index.
74 if (clonedAdjointGenericOp.getOutputsMutable().size() - i - 1 > 0) {
// `len` is the number of outputs after index `i`.
// NOTE(review): `idx` is declared on a line not shown here (presumably i + 1).
76 unsigned len = clonedAdjointGenericOp.getOutputsMutable().size() - i - 1;
// The same [idx, idx + len) range is removed from the output operands, the
// region's block arguments, and the indexing maps so that all three agree.
77 clonedAdjointGenericOp.getOutputsMutable().erase(idx, len);
78 clonedAdjointGenericOp.getRegion().front().eraseArguments(numInputs + idx,
80 indexingMaps.erase(indexingMaps.begin() + numInputs + idx,
81 indexingMaps.begin() + numInputs + idx + len);
// Drop the first `i` outputs, again together with their block arguments and
// indexing maps.
// NOTE(review): any guard around this group (e.g. on `i` being non-zero) is
// not visible in this excerpt.
85 clonedAdjointGenericOp.getOutputsMutable().erase(0, i);
86 clonedAdjointGenericOp.getRegion().front().eraseArguments(numInputs, i);
87 indexingMaps.erase(indexingMaps.begin() + numInputs,
88 indexingMaps.begin() + numInputs + i);
// Install the trimmed indexing maps on the clone.
90 clonedAdjointGenericOp.setIndexingMapsAttr(
91 builder.getArrayAttr(indexingMaps));
// Locate the AddToOp inside the clone.
93 auto clonedAddToOp = getAddToOp(clonedAdjointGenericOp);
// Restore the builder's insertion point when this scope ends.
94 auto scope = OpBuilder::InsertionGuard(builder);
// Create the new linalg.yield right after the AddToOp it will replace.
96 builder.setInsertionPointAfter(clonedAddToOp);
97 auto terminator = linalg::YieldOp::create(builder, loc);
// Operand `i` of the AddToOp is the value to be accumulated.
99 auto operand = clonedAddToOp->getOperand(i);
// NOTE(review): this is the tail of a statement whose left-hand side is not
// shown (presumably the initializer of `outputOperand`). Given that `i`
// leading output arguments were erased above, verify that `numInputs + i` is
// still the intended (and in-range) block-argument index.
101 clonedAdjointGenericOp.getRegion().front().getArgument(numInputs + i);
// The add op is built through the operand type's AutoDiffTypeInterface.
102 auto operandType = cast<AutoDiffTypeInterface>(operand.getType());
// Emit the add immediately before the new yield.
104 builder.setInsertionPoint(terminator);
// NOTE(review): presumably the initializer of `returnValue`, whose declaration
// is on a line not shown in this excerpt.
106 operandType.createAddOp(builder, loc, outputOperand, operand);
// Yield the accumulated value, then remove the now-replaced AddToOp.
108 terminator->setOperands({returnValue});
109 clonedAddToOp->erase();
// Pass that walks the operation it runs on, looks at enzyme::GenericAdjointOp
// instances, and calls processGenericDuplication once per operand of the
// enzyme::AddToOp found for each of them.
// NOTE(review): several lines of this definition are missing from this
// excerpt (null checks, the definition of `adjoint`, closing braces).
112struct AddToOpToSplitPass
113 :
public enzyme::impl::AddToOpToSplitPassBase<AddToOpToSplitPass> {
// Pass entry point: visits every operation nested under the pass's root op.
114 void runOnOperation()
override {
115 getOperation()->walk([&](Operation *op) {
// dyn_cast produces a null value for ops that are not GenericAdjointOp.
// NOTE(review): the corresponding null check is not visible here.
116 auto enzymeAdjoint = dyn_cast<enzyme::GenericAdjointOp>(op);
117 auto loc = op->getLoc();
// Builder initially positioned at the adjoint op.
121 OpBuilder builder(enzymeAdjoint);
// NOTE(review): `adjoint` is defined on a line not shown in this excerpt;
// since it is passed to getAddToOp it is presumably a linalg::GenericOp
// derived from `enzymeAdjoint` -- confirm.
124 Operation *addToOp = getAddToOp(adjoint);
// Abort with a fatal error if any input is a memref whose element type is
// CacheType.
130 for (
auto input : adjoint.getInputs()) {
131 if (isMemrefCacheType(input.getType())) {
132 llvm::report_fatal_error(
133 "Cannot split AddToOp with memref<CacheType> inputs");
// Ops created from here on are inserted before `adjoint`.
138 builder.setInsertionPoint(adjoint);
// One processGenericDuplication call per AddToOp operand, passing the operand
// index as the last argument.
// NOTE(review): `addToOp` is dereferenced here; any null check is on lines
// not shown in this excerpt.
140 for (
size_t i = 0; i < addToOp->getNumOperands(); i++) {
141 processGenericDuplication(adjoint.getOperation(), builder, loc, i);
static mlir::linalg::GenericOp adjointToGeneric(enzyme::GenericAdjointOp &op, OpBuilder &builder, Location loc)