20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/Support/LogicalResult.h"
26#include "mlir/Dialect/Linalg/IR/Linalg.h"
32#include "Implementations/MemRefDerivatives.inc"
34struct LoadOpInterfaceReverse
35 :
public ReverseAutoDiffOpInterface::ExternalModel<LoadOpInterfaceReverse,
37 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
39 SmallVector<Value> caches)
const {
40 auto loadOp = cast<memref::LoadOp>(op);
41 Value memref = loadOp.getMemref();
43 if (
auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) {
46 Value gradient = gutils->
diffe(loadOp, builder);
47 Value memrefGradient = gutils->
popCache(caches.front(), builder);
49 SmallVector<Value> retrievedArguments;
50 for (Value cache : ValueRange(caches).drop_front(1)) {
51 Value retrievedValue = gutils->
popCache(cache, builder);
52 retrievedArguments.push_back(retrievedValue);
56 Value loadedGradient =
57 memref::LoadOp::create(builder, loadOp.getLoc(), memrefGradient,
58 ArrayRef<Value>(retrievedArguments));
59 Value addedGradient = iface.createAddOp(builder, loadOp.getLoc(),
60 loadedGradient, gradient);
61 memref::StoreOp::create(builder, loadOp.getLoc(), addedGradient,
63 ArrayRef<Value>(retrievedArguments));
65 memref::AtomicRMWOp::create(
66 builder, loadOp.getLoc(), arith::AtomicRMWKind::addf, gradient,
67 memrefGradient, ArrayRef<Value>(retrievedArguments));
74 SmallVector<Value> cacheValues(Operation *op,
76 auto loadOp = cast<memref::LoadOp>(op);
77 Value memref = loadOp.getMemref();
78 ValueRange indices = loadOp.getIndices();
79 if (
auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) {
83 SmallVector<Value> caches;
86 for (Value v : indices) {
93 return SmallVector<Value>();
96 void createShadowValues(Operation *op, OpBuilder &builder,
105struct StoreOpInterfaceReverse
106 :
public ReverseAutoDiffOpInterface::ExternalModel<StoreOpInterfaceReverse,
108 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
110 SmallVector<Value> caches)
const {
111 auto storeOp = cast<memref::StoreOp>(op);
112 Value val = storeOp.getValue();
113 Value memref = storeOp.getMemref();
116 auto iface = cast<AutoDiffTypeInterface>(val.getType());
119 Value memrefGradient = gutils->
popCache(caches.front(), builder);
121 SmallVector<Value> retrievedArguments;
122 for (Value cache : ValueRange(caches).drop_front(1)) {
123 Value retrievedValue = gutils->
popCache(cache, builder);
124 retrievedArguments.push_back(retrievedValue);
127 if (!iface.isMutable()) {
129 Value loadedGradient =
130 memref::LoadOp::create(builder, storeOp.getLoc(), memrefGradient,
131 ArrayRef<Value>(retrievedArguments));
132 gutils->
addToDiffe(val, loadedGradient, builder);
136 cast<AutoDiffTypeInterface>(gutils->
getShadowType(val.getType()))
137 .createNullValue(builder, op->getLoc());
139 memref::StoreOp::create(builder, storeOp.getLoc(), zero, memrefGradient,
140 ArrayRef<Value>(retrievedArguments));
146 SmallVector<Value> cacheValues(Operation *op,
148 auto storeOp = cast<memref::StoreOp>(op);
149 Value memref = storeOp.getMemref();
150 ValueRange indices = storeOp.getIndices();
151 Value val = storeOp.getValue();
152 if (
auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType())) {
155 SmallVector<Value> caches;
158 for (Value v : indices) {
165 return SmallVector<Value>();
168 void createShadowValues(Operation *op, OpBuilder &builder,
177struct SubViewOpInterfaceReverse
178 :
public ReverseAutoDiffOpInterface::ExternalModel<
179 SubViewOpInterfaceReverse, memref::SubViewOp> {
180 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
182 SmallVector<Value> caches)
const {
186 SmallVector<Value> cacheValues(Operation *op,
188 return SmallVector<Value>();
191 void createShadowValues(Operation *op, OpBuilder &builder,
193 auto subviewOp = cast<memref::SubViewOp>(op);
196 Value shadow = memref::SubViewOp::create(
197 builder, op->getLoc(), newSubviewOp.getType(),
199 newSubviewOp.getMixedOffsets(), newSubviewOp.getMixedSizes(),
200 newSubviewOp.getMixedStrides());
206class MemRefClonableTypeInterface
207 :
public ClonableTypeInterface::ExternalModel<MemRefClonableTypeInterface,
211 mlir::Value cloneValue(mlir::Type self, OpBuilder &builder,
213 MemRefType MT = cast<MemRefType>(self);
214 SmallVector<Value> dynamicSizes;
216 for (
auto [i, s] : llvm::enumerate(MT.getShape())) {
217 if (s == ShapedType::kDynamic) {
218 Value dim = arith::ConstantIndexOp::create(builder, value.getLoc(), i);
219 dynamicSizes.push_back(
220 memref::DimOp::create(builder, value.getLoc(), value, dim));
225 memref::AllocOp::create(builder, value.getLoc(), self, dynamicSizes);
226 memref::CopyOp::create(builder, value.getLoc(), value,
clone);
231 void freeClonedValue(mlir::Type self, OpBuilder &builder, Value value)
const {
232 memref::DeallocOp::create(builder, value.getLoc(), value);
236class MemRefAutoDiffTypeInterface
237 :
public AutoDiffTypeInterface::ExternalModel<MemRefAutoDiffTypeInterface,
240 mlir::Attribute createNullAttr(mlir::Type self)
const {
241 llvm_unreachable(
"Cannot create null of memref (todo polygeist null)");
243 mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
244 Location loc)
const {
247 MemRefType MT = cast<MemRefType>(self);
248 unsigned numDynamicDims = MT.getNumDynamicDims();
249 SmallVector<mlir::Value> dynamicSizes(numDynamicDims);
250 for (
unsigned i = 0; i < numDynamicDims; ++i) {
251 dynamicSizes[i] = builder.create<mlir::arith::ConstantIndexOp>(loc, 0);
253 return mlir::memref::AllocOp::create(builder, loc, MT, dynamicSizes);
256 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
258 llvm_unreachable(
"TODO");
262 assert(width == 1 &&
"unsupported width != 1");
266 Value createConjOp(Type self, OpBuilder &builder, Location loc,
268 llvm_unreachable(
"TODO");
// Memref types unconditionally report themselves as mutable; `self` is not
// inspected. Adjoint code in this file (the store adjoint and zeroInPlace)
// branches on isMutable() of the stored/element type, so a memref element
// type takes the "mutable" path there.
// NOTE(review): presumably because a memref shadow is updated in place
// through its buffer rather than replaced by a new SSA value — confirm.
271 bool isMutable(Type self)
const {
return true; }
273 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
275 auto MT = cast<MemRefType>(self);
276 if (
auto iface = dyn_cast<AutoDiffTypeInterface>(MT.getElementType())) {
277 if (!iface.isMutable()) {
278 Value zero = iface.createNullValue(builder, loc);
279 linalg::FillOp::create(builder, loc, zero, val);
// Always returns false: no memref value is ever identified as a known zero,
// regardless of `self` or `val`.
287 bool isZero(Type self, Value val)
const {
return false; }
// Attribute counterpart of isZero for memref types: always returns false,
// so no attribute is treated as a zero memref.
288 bool isZeroAttr(Type self, Attribute val)
const {
return false; }
293 DialectRegistry ®istry) {
294 registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) {
295 registerInterfaces(context);
296 MemRefType::attachInterface<MemRefAutoDiffTypeInterface>(*context);
297 MemRefType::attachInterface<MemRefClonableTypeInterface>(*context);
299 memref::LoadOp::attachInterface<LoadOpInterfaceReverse>(*context);
300 memref::StoreOp::attachInterface<StoreOpInterfaceReverse>(*context);
301 memref::SubViewOp::attachInterface<SubViewOpInterfaceReverse>(*context);
Type getShadowType(Type type, unsigned width)
Operation * clone(Operation *src, IRMapping &mapper, Operation::CloneOptions options, std::map< Operation *, Operation * > &opMap)
static bool isZero(llvm::Constant *cst)
static bool isMutable(Type type)
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
mlir::Type getShadowType(mlir::Type T)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
void setInvertedPointer(mlir::Value origv, mlir::Value newv)
bool isConstantValue(mlir::Value v) const
void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry)