Enzyme main
Loading...
Searching...
No Matches
MemRefAutoDiffOpInterfaceImpl.cpp
Go to the documentation of this file.
//===- MemRefAutoDiffOpInterfaceImpl.cpp - Interface external model -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR memref dialect.
//
//===----------------------------------------------------------------------===//
13
19
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/Support/LogicalResult.h"
23
24// TODO: We need a way to zero out a memref (which linalg.fill does), but
25// ideally we wouldn't depend on the linalg dialect.
26#include "mlir/Dialect/Linalg/IR/Linalg.h"
27
28using namespace mlir;
29using namespace mlir::enzyme;
30
31namespace {
32#include "Implementations/MemRefDerivatives.inc"
33
34struct LoadOpInterfaceReverse
35 : public ReverseAutoDiffOpInterface::ExternalModel<LoadOpInterfaceReverse,
36 memref::LoadOp> {
37 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
39 SmallVector<Value> caches) const {
40 auto loadOp = cast<memref::LoadOp>(op);
41 Value memref = loadOp.getMemref();
42
43 if (auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) {
44 if (!gutils->isConstantValue(loadOp) &&
45 !gutils->isConstantValue(memref)) {
46 Value gradient = gutils->diffe(loadOp, builder);
47 Value memrefGradient = gutils->popCache(caches.front(), builder);
48
49 SmallVector<Value> retrievedArguments;
50 for (Value cache : ValueRange(caches).drop_front(1)) {
51 Value retrievedValue = gutils->popCache(cache, builder);
52 retrievedArguments.push_back(retrievedValue);
53 }
54
55 if (!gutils->AtomicAdd) {
56 Value loadedGradient =
57 memref::LoadOp::create(builder, loadOp.getLoc(), memrefGradient,
58 ArrayRef<Value>(retrievedArguments));
59 Value addedGradient = iface.createAddOp(builder, loadOp.getLoc(),
60 loadedGradient, gradient);
61 memref::StoreOp::create(builder, loadOp.getLoc(), addedGradient,
62 memrefGradient,
63 ArrayRef<Value>(retrievedArguments));
64 } else {
65 memref::AtomicRMWOp::create(
66 builder, loadOp.getLoc(), arith::AtomicRMWKind::addf, gradient,
67 memrefGradient, ArrayRef<Value>(retrievedArguments));
68 }
69 }
70 }
71 return success();
72 }
73
74 SmallVector<Value> cacheValues(Operation *op,
75 MGradientUtilsReverse *gutils) const {
76 auto loadOp = cast<memref::LoadOp>(op);
77 Value memref = loadOp.getMemref();
78 ValueRange indices = loadOp.getIndices();
79 if (auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) {
80 if (!gutils->isConstantValue(loadOp) &&
81 !gutils->isConstantValue(memref)) {
82 OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
83 SmallVector<Value> caches;
84 caches.push_back(gutils->initAndPushCache(
85 gutils->invertPointerM(memref, cacheBuilder), cacheBuilder));
86 for (Value v : indices) {
87 caches.push_back(gutils->initAndPushCache(
88 gutils->getNewFromOriginal(v), cacheBuilder));
89 }
90 return caches;
91 }
92 }
93 return SmallVector<Value>();
94 }
95
96 void createShadowValues(Operation *op, OpBuilder &builder,
97 MGradientUtilsReverse *gutils) const {
98 // auto loadOp = cast<memref::LoadOp>(op);
99 // Value memref = loadOp.getMemref();
100 // Value shadow = gutils->getShadowValue(memref);
101 // Do nothing yet. In the future support memref<memref<...>>
102 }
103};
104
/// Reverse-mode AD model for memref.store.
///
/// Forward:  memref.store %v, %m[%idx...]
/// Adjoint:  d(%v) += d(%m)[%idx...];  d(%m)[%idx...] = 0
/// The zeroing reflects that the store overwrote the previous contents, so
/// no gradient flows to whatever value was there before.
struct StoreOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<StoreOpInterfaceReverse,
                                                       memref::StoreOp> {
  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    auto storeOp = cast<memref::StoreOp>(op);
    Value val = storeOp.getValue();
    Value memref = storeOp.getMemref();
    // ValueRange indices = storeOp.getIndices();

    auto iface = cast<AutoDiffTypeInterface>(val.getType());

    if (!gutils->isConstantValue(memref)) {
      // caches[0] is the shadow memref; caches[1..] are the store indices.
      // This order must match cacheValues() below.
      Value memrefGradient = gutils->popCache(caches.front(), builder);

      SmallVector<Value> retrievedArguments;
      for (Value cache : ValueRange(caches).drop_front(1)) {
        Value retrievedValue = gutils->popCache(cache, builder);
        retrievedArguments.push_back(retrievedValue);
      }

      // Mutable element types (e.g. nested memrefs) are not handled here;
      // their gradients propagate through aliasing, not through values.
      if (!iface.isMutable()) {
        if (!gutils->isConstantValue(val)) {
          // d(stored value) += shadow[indices]
          Value loadedGradient =
              memref::LoadOp::create(builder, storeOp.getLoc(), memrefGradient,
                                     ArrayRef<Value>(retrievedArguments));
          gutils->addToDiffe(val, loadedGradient, builder);
        }

        // Reset the shadow cell: the primal store clobbered this location, so
        // its pre-store contents receive no gradient.
        auto zero =
            cast<AutoDiffTypeInterface>(gutils->getShadowType(val.getType()))
                .createNullValue(builder, op->getLoc());

        memref::StoreOp::create(builder, storeOp.getLoc(), zero, memrefGradient,
                                ArrayRef<Value>(retrievedArguments));
      }
    }
    return success();
  }

  /// Caches the shadow memref followed by each index so the adjoint can
  /// address the same shadow cell during the reverse pass.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    auto storeOp = cast<memref::StoreOp>(op);
    Value memref = storeOp.getMemref();
    ValueRange indices = storeOp.getIndices();
    Value val = storeOp.getValue();
    if (auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType())) {
      if (!gutils->isConstantValue(memref)) {
        OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
        SmallVector<Value> caches;
        // Shadow memref first, then indices -- consumed in the same order by
        // createReverseModeAdjoint.
        caches.push_back(gutils->initAndPushCache(
            gutils->invertPointerM(memref, cacheBuilder), cacheBuilder));
        for (Value v : indices) {
          caches.push_back(gutils->initAndPushCache(
              gutils->getNewFromOriginal(v), cacheBuilder));
        }
        return caches;
      }
    }
    return SmallVector<Value>();
  }

  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {
    // auto storeOp = cast<memref::StoreOp>(op);
    // Value memref = storeOp.getMemref();
    // Value shadow = gutils->getShadowValue(memref);
    // Do nothing yet. In the future support memref<memref<...>>
  }
};
176
177struct SubViewOpInterfaceReverse
178 : public ReverseAutoDiffOpInterface::ExternalModel<
179 SubViewOpInterfaceReverse, memref::SubViewOp> {
180 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
181 MGradientUtilsReverse *gutils,
182 SmallVector<Value> caches) const {
183 return success();
184 }
185
186 SmallVector<Value> cacheValues(Operation *op,
187 MGradientUtilsReverse *gutils) const {
188 return SmallVector<Value>();
189 }
190
191 void createShadowValues(Operation *op, OpBuilder &builder,
192 MGradientUtilsReverse *gutils) const {
193 auto subviewOp = cast<memref::SubViewOp>(op);
194 auto newSubviewOp = cast<memref::SubViewOp>(gutils->getNewFromOriginal(op));
195 if (!gutils->isConstantValue(subviewOp.getSource())) {
196 Value shadow = memref::SubViewOp::create(
197 builder, op->getLoc(), newSubviewOp.getType(),
198 gutils->invertPointerM(subviewOp.getSource(), builder),
199 newSubviewOp.getMixedOffsets(), newSubviewOp.getMixedSizes(),
200 newSubviewOp.getMixedStrides());
201 gutils->setInvertedPointer(subviewOp, shadow);
202 }
203 }
204};
205
206class MemRefClonableTypeInterface
207 : public ClonableTypeInterface::ExternalModel<MemRefClonableTypeInterface,
208 MemRefType> {
209
210public:
211 mlir::Value cloneValue(mlir::Type self, OpBuilder &builder,
212 Value value) const {
213 MemRefType MT = cast<MemRefType>(self);
214 SmallVector<Value> dynamicSizes;
215
216 for (auto [i, s] : llvm::enumerate(MT.getShape())) {
217 if (s == ShapedType::kDynamic) {
218 Value dim = arith::ConstantIndexOp::create(builder, value.getLoc(), i);
219 dynamicSizes.push_back(
220 memref::DimOp::create(builder, value.getLoc(), value, dim));
221 }
222 }
223
224 auto clone =
225 memref::AllocOp::create(builder, value.getLoc(), self, dynamicSizes);
226 memref::CopyOp::create(builder, value.getLoc(), value, clone);
227
228 return clone;
229 }
230
231 void freeClonedValue(mlir::Type self, OpBuilder &builder, Value value) const {
232 memref::DeallocOp::create(builder, value.getLoc(), value);
233 };
234};
235
236class MemRefAutoDiffTypeInterface
237 : public AutoDiffTypeInterface::ExternalModel<MemRefAutoDiffTypeInterface,
238 MemRefType> {
239public:
240 mlir::Attribute createNullAttr(mlir::Type self) const {
241 llvm_unreachable("Cannot create null of memref (todo polygeist null)");
242 }
243 mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
244 Location loc) const {
245 // Create a memref of the given type with the required number of
246 // dynamic dimensions, all set to 0
247 MemRefType MT = cast<MemRefType>(self);
248 unsigned numDynamicDims = MT.getNumDynamicDims();
249 SmallVector<mlir::Value> dynamicSizes(numDynamicDims);
250 for (unsigned i = 0; i < numDynamicDims; ++i) {
251 dynamicSizes[i] = builder.create<mlir::arith::ConstantIndexOp>(loc, 0);
252 }
253 return mlir::memref::AllocOp::create(builder, loc, MT, dynamicSizes);
254 }
255
256 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
257 Value b) const {
258 llvm_unreachable("TODO");
259 }
260
261 Type getShadowType(Type self, unsigned width) const {
262 assert(width == 1 && "unsupported width != 1");
263 return self;
264 }
265
266 Value createConjOp(Type self, OpBuilder &builder, Location loc,
267 Value a) const {
268 llvm_unreachable("TODO");
269 }
270
271 bool isMutable(Type self) const { return true; }
272
273 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
274 Value val) const {
275 auto MT = cast<MemRefType>(self);
276 if (auto iface = dyn_cast<AutoDiffTypeInterface>(MT.getElementType())) {
277 if (!iface.isMutable()) {
278 Value zero = iface.createNullValue(builder, loc);
279 linalg::FillOp::create(builder, loc, zero, val);
280 }
281 } else {
282 return failure();
283 }
284 return success();
285 }
286
287 bool isZero(Type self, Value val) const { return false; }
288 bool isZeroAttr(Type self, Attribute val) const { return false; }
289};
290} // namespace
291
293 DialectRegistry &registry) {
294 registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) {
295 registerInterfaces(context);
296 MemRefType::attachInterface<MemRefAutoDiffTypeInterface>(*context);
297 MemRefType::attachInterface<MemRefClonableTypeInterface>(*context);
298
299 memref::LoadOp::attachInterface<LoadOpInterfaceReverse>(*context);
300 memref::StoreOp::attachInterface<StoreOpInterfaceReverse>(*context);
301 memref::SubViewOp::attachInterface<SubViewOpInterfaceReverse>(*context);
302 });
303}
Type getShadowType(Type type, unsigned width)
Operation * clone(Operation *src, IRMapping &mapper, Operation::CloneOptions options, std::map< Operation *, Operation * > &opMap)
static bool isZero(llvm::Constant *cst)
static bool isMutable(Type type)
Definition Ops.cpp:235
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
mlir::Type getShadowType(mlir::Type T)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
void setInvertedPointer(mlir::Value origv, mlir::Value newv)
bool isConstantValue(mlir::Value v) const
void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry)