Enzyme main
Loading...
Searching...
No Matches
AddToOpToSplit.cpp
Go to the documentation of this file.
//===- AddToOpToSplit.cpp - Split enzyme.addTo generic adjoints -----------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
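// This file implements a pass that splits each Enzyme generic adjoint whose
// body accumulates into several outputs through a single enzyme.addTo op into
// one linalg.generic per output, each ending in a linalg.yield of the sum.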
12//===----------------------------------------------------------------------===//
13
14#include "Dialect/Dialect.h"
15#include "Dialect/Ops.h"
16#include "PassDetails.h"
17#include "Passes/Passes.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/Transforms/DialectConversion.h"
24
25#include "mlir/Rewrite/PatternApplicator.h"
26#include "llvm/Support/raw_ostream.h"
27
29#include "mlir/Dialect/Affine/IR/AffineOps.h"
30#include "mlir/Dialect/Linalg/IR/Linalg.h"
31
32#include "Utils.h"
33
34namespace mlir {
35namespace enzyme {
36#define GEN_PASS_DEF_ADDTOOPTOSPLITPASS
37#include "Passes/Passes.h.inc"
38} // namespace enzyme
39} // namespace mlir
40
41using namespace mlir;
42using namespace enzyme;
43using llvm::errs;
44namespace {
45
46Operation *getAddToOp(linalg::GenericOp &adjoint) {
47 Operation *addToOp = nullptr;
48 adjoint.walk([&](Operation *op) {
49 if (isa<enzyme::AddToOp>(op)) {
50 addToOp = op;
51 }
52 });
53 return addToOp;
54}
55
56bool isMemrefCacheType(Type type) {
57 if (auto memrefType = dyn_cast<MemRefType>(type)) {
58 return isa<CacheType>(memrefType.getElementType());
59 }
60 return false;
61}
62
63void processGenericDuplication(Operation *op, OpBuilder &builder, Location loc,
64 int i) {
65 auto clonedAdjoint = builder.clone(*op);
66 linalg::GenericOp clonedAdjointGenericOp =
67 cast<linalg::GenericOp>(clonedAdjoint);
68
69 // Delete all but the ith output
70 unsigned numInputs = clonedAdjointGenericOp.getInputsMutable().size();
71 SmallVector<mlir::Attribute> indexingMaps(
72 clonedAdjointGenericOp.getIndexingMaps().getValue());
73
74 if (clonedAdjointGenericOp.getOutputsMutable().size() - i - 1 > 0) {
75 unsigned idx = i + 1;
76 unsigned len = clonedAdjointGenericOp.getOutputsMutable().size() - i - 1;
77 clonedAdjointGenericOp.getOutputsMutable().erase(idx, len);
78 clonedAdjointGenericOp.getRegion().front().eraseArguments(numInputs + idx,
79 len);
80 indexingMaps.erase(indexingMaps.begin() + numInputs + idx,
81 indexingMaps.begin() + numInputs + idx + len);
82 }
83
84 if (i > 0) {
85 clonedAdjointGenericOp.getOutputsMutable().erase(0, i);
86 clonedAdjointGenericOp.getRegion().front().eraseArguments(numInputs, i);
87 indexingMaps.erase(indexingMaps.begin() + numInputs,
88 indexingMaps.begin() + numInputs + i);
89 }
90 clonedAdjointGenericOp.setIndexingMapsAttr(
91 builder.getArrayAttr(indexingMaps));
92
93 auto clonedAddToOp = getAddToOp(clonedAdjointGenericOp);
94 auto scope = OpBuilder::InsertionGuard(builder);
95
96 builder.setInsertionPointAfter(clonedAddToOp);
97 auto terminator = linalg::YieldOp::create(builder, loc);
98
99 auto operand = clonedAddToOp->getOperand(i);
100 auto outputOperand =
101 clonedAdjointGenericOp.getRegion().front().getArgument(numInputs + i);
102 auto operandType = cast<AutoDiffTypeInterface>(operand.getType());
103
104 builder.setInsertionPoint(terminator);
105 auto returnValue =
106 operandType.createAddOp(builder, loc, outputOperand, operand);
107
108 terminator->setOperands({returnValue});
109 clonedAddToOp->erase();
110}
111
112struct AddToOpToSplitPass
113 : public enzyme::impl::AddToOpToSplitPassBase<AddToOpToSplitPass> {
114 void runOnOperation() override {
115 getOperation()->walk([&](Operation *op) {
116 auto enzymeAdjoint = dyn_cast<enzyme::GenericAdjointOp>(op);
117 auto loc = op->getLoc();
118 if (!enzymeAdjoint)
119 return;
120
121 OpBuilder builder(enzymeAdjoint);
122 auto adjoint = Utils::adjointToGeneric(enzymeAdjoint, builder, loc);
123
124 Operation *addToOp = getAddToOp(adjoint);
125 if (!addToOp)
126 return;
127
128 // TODO duplicate memref<CacheType> inputs
129 // For now just error out
130 for (auto input : adjoint.getInputs()) {
131 if (isMemrefCacheType(input.getType())) {
132 llvm::report_fatal_error(
133 "Cannot split AddToOp with memref<CacheType> inputs");
134 return;
135 }
136 }
137
138 builder.setInsertionPoint(adjoint);
139
140 for (size_t i = 0; i < addToOp->getNumOperands(); i++) {
141 processGenericDuplication(adjoint.getOperation(), builder, loc, i);
142 }
143 adjoint->erase();
144 });
145 };
146};
147} // end anonymous namespace
static mlir::linalg::GenericOp adjointToGeneric(enzyme::GenericAdjointOp &op, OpBuilder &builder, Location loc)
Definition Utils.cpp:15