Enzyme main
Loading...
Searching...
No Matches
LinalgAutoDiffOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- LinalgAutoDiffOpInterfaceImpl.cpp - Interface external model -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the external model implementation of the automatic
10// differentiation op interfaces for the upstream MLIR linalg dialect.
11//
12//===----------------------------------------------------------------------===//
13
// NOTE(review): original source lines 13-19 (project-local #includes, e.g.
// the Enzyme interface headers) were lost in extraction — restore from the
// upstream file.
20#include "mlir/Dialect/Affine/IR/AffineOps.h"
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Linalg/IR/Linalg.h"
23#include "mlir/Dialect/MemRef/IR/MemRef.h"
24
25#include "mlir/IR/DialectRegistry.h"
26#include "mlir/Support/LogicalResult.h"
27#include <functional>
28
29#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
30#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
31#include "mlir/IR/PatternMatch.h"
32#include "mlir/Transforms/DialectConversion.h"
33
34using namespace mlir;
35using namespace mlir::enzyme;
36
37namespace {
38
39Value invertMemref(Value inp, OpBuilder &builder, Location loc) {
40 MemRefType iType = cast<MemRefType>(inp.getType());
41 SmallVector<Value> dims;
42 SmallVector<Value> dimSubOnes;
43 SmallVector<Value> strides;
44 Value negOne = arith::ConstantIndexOp::create(builder, loc, -1);
45 int shapeDim = iType.getShape().size();
46 for (int i = 0; i < shapeDim; i++) {
47 Value dim = memref::DimOp::create(builder, loc, inp, i);
48 dims.push_back(dim);
49 auto dimSubOne = arith::AddIOp::create(builder, loc, dim, negOne);
50 dimSubOnes.push_back(dimSubOne);
51 strides.push_back(negOne);
52 }
53 Value view =
54 memref::SubViewOp::create(builder, loc, inp, ValueRange(dimSubOnes),
55 ValueRange(dims), ValueRange(strides));
56 return view;
57}
58
59SmallVector<AffineMap> getIndexingMapsArray(enzyme::GenericAdjointOp &op) {
60 auto attr = op.getIndexingMapsAttr();
61 SmallVector<AffineMap> indexingMaps;
62 for (auto map : attr.getValue()) {
63 indexingMaps.push_back(cast<AffineMapAttr>(map).getValue());
64 }
65 return indexingMaps;
66}
67
68template <typename T_>
69struct GenericOpInterfaceReverse
70 : public ReverseAutoDiffOpInterface::ExternalModel<
71 GenericOpInterfaceReverse<T_>, T_> {
72 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
74 SmallVector<Value> caches) const {
75 auto linalgOp = cast<linalg::LinalgOp>(op);
76 if (!linalgOp.hasPureBufferSemantics()) {
77 llvm::errs() << "Linalg op with tensor semantics not yet supported\n";
78 return failure();
79 }
80
81 linalg::LinalgOp newOp =
82 cast<linalg::LinalgOp>(gutils->getNewFromOriginal(linalgOp));
83
84 // Replace the op by a linalg.generic op if necessary
85 IRRewriter rewriter(builder.getContext(), builder.getListener());
86 auto failiureOrLinalgOp = generalizeNamedOp(rewriter, newOp);
87 if (!failed(failiureOrLinalgOp)) {
88 linalg::GenericOp replacement = failiureOrLinalgOp.value();
89 auto scope = OpBuilder::InsertionGuard(builder);
90 builder.setInsertionPointAfter(newOp);
91 builder.insert(replacement);
92 newOp.erase();
93 newOp = replacement;
94 }
95
96 auto cacheBuilder = OpBuilder(newOp, builder.getListener());
97
98 // Calculate the iteration domain
99 AffineMap aMap = newOp.getShapesToLoopsMap();
100 SmallVector<Value> dims;
101 for (OpOperand *input : newOp.getDpsInputOperands()) {
102 auto shape = cast<MemRefType>(input->get().getType()).getShape();
103 for (unsigned i = 0; i < shape.size(); i++) {
104 auto dimI =
105 arith::ConstantIndexOp::create(cacheBuilder, op->getLoc(), i);
106 auto dim = memref::DimOp::create(cacheBuilder, op->getLoc(),
107 input->get(), dimI);
108 dims.push_back(dim);
109 }
110 }
111 for (Value output : newOp.getDpsInits()) {
112 auto shape = cast<MemRefType>(output.getType()).getShape();
113 for (unsigned i = 0; i < shape.size(); i++) {
114 auto dimI =
115 arith::ConstantIndexOp::create(cacheBuilder, op->getLoc(), i);
116 auto dim =
117 memref::DimOp::create(cacheBuilder, op->getLoc(), output, dimI);
118 dims.push_back(dim);
119 }
120 }
121
122 SmallVector<Value> iterationDomains;
123 SmallVector<int64_t> shapes;
124 for (unsigned int i = 0; i < aMap.getNumResults(); i++) {
125 AffineMap subMap = aMap.getSubMap({i});
126 Value domain = affine::AffineApplyOp::create(cacheBuilder, op->getLoc(),
127 subMap, ValueRange(dims));
128 iterationDomains.push_back(domain);
129 shapes.push_back(ShapedType::kDynamic);
130 }
131 //
132
133 SmallVector<Value> inputs, outputs;
134 SmallVector<AffineMap> indexingMaps;
135 SmallVector<utils::IteratorType> iteratorTypes{
136 linalgOp.getNumLoops(), utils::IteratorType::parallel};
137
138 for (OpOperand &output : linalgOp.getDpsInitsMutable()) {
139 if (gutils->isConstantValue(output.get())) {
140 continue;
141 }
142 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&output));
143 Value out = gutils->invertPointerM(output.get(), builder);
144 Value view = invertMemref(out, builder, op->getLoc());
145 outputs.push_back(view);
146 }
147
148 for (OpOperand *input : linalgOp.getDpsInputOperands()) {
149 if (gutils->isConstantValue(input->get())) {
150 continue;
151 }
152 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(input));
153 Value inp = gutils->invertPointerM(input->get(), builder);
154 Value view = invertMemref(inp, builder, op->getLoc());
155 inputs.push_back(view);
156 }
157
158 ArrayAttr indexingMapsArrayAttr =
159 builder.getAffineMapArrayAttr(indexingMaps);
160 ArrayAttr iteratorTypesArrayAttr =
161 builder.getArrayAttr(llvm::to_vector(llvm::map_range(
162 iteratorTypes, [&](utils::IteratorType iter) -> mlir::Attribute {
163 return linalg::IteratorTypeAttr::get(builder.getContext(), iter);
164 })));
165 auto adjoint = enzyme::GenericAdjointOp::create(
166 builder, op->getLoc(), TypeRange(), ValueRange(outputs),
167 ValueRange(inputs), indexingMapsArrayAttr, iteratorTypesArrayAttr,
168 StringAttr(), StringAttr());
169
170 int numInputs = inputs.size();
171 auto buildFuncReturnOp = [&gutils, numInputs](OpBuilder &builder,
172 Block *oBB) {
173 auto loc = oBB->rbegin()->getLoc();
174 SmallVector<Value> retargs;
175 for (auto arg : oBB->getArguments()) {
176 retargs.push_back(gutils->invertPointerM(arg, builder));
177 }
178 enzyme::AddToOp::create(builder, loc,
179 ValueRange{retargs}.take_front(numInputs));
180 return;
181 };
182
183 Region *newOpRegion = newOp.getBlock()->getParent();
184 Region *adjointRegion = &adjoint.getRegion();
185 int numInputsAdjoint = adjoint.getInputs().size();
186 Location loc = op->getLoc();
187 int numCaches = 0;
188 SmallVector<Value> pushCaches;
189
190 auto hook = [newOpRegion, adjointRegion, loc, &numCaches = numCaches,
191 numInputsAdjoint, &pushCaches = pushCaches](Type t) {
192 OpBuilder builder(newOpRegion);
193 Value pushCache = enzyme::InitOp::create(builder, loc, t);
194 pushCaches.push_back(pushCache);
195 newOpRegion->addArgument(t, loc);
196
197 Value popCache =
198 adjointRegion->insertArgument(numInputsAdjoint + numCaches, t, loc);
199 numCaches++;
200 return std::make_pair(pushCache, popCache);
201 };
202
203 auto sub = gutils->Logic.differentiate(
204 gutils, *linalgOp.getBlock()->getParent(), adjoint.getRegion(),
205 buildFuncReturnOp, hook);
206 if (!sub.succeeded())
207 return sub;
208
209 auto newOpYield = cast<linalg::YieldOp>(
210 cast<linalg::GenericOp>(newOp).getBodyRegion().front().getTerminator());
211 for (Value pc : pushCaches) {
212 newOpYield.getValuesMutable().append(pc);
213 }
214
215 Block *body = &(adjoint.getRegion().front());
216 auto yieldOp = cast<enzyme::AddToOp>(body->getTerminator());
217 for (auto opOperand : yieldOp.getOperands()) {
218 body->addArgument(opOperand.getType(), opOperand.getLoc());
219 }
220
221 OpBuilder builderAdd(yieldOp);
222
223 auto newIndexingMaps = newOp.getIndexingMapsArray();
224 auto indexingMapsAdjoint = getIndexingMapsArray(adjoint);
225 for (int i = 0; i < numCaches; i++) {
226 Value cacheArg = body->getArgument(outputs.size() + i);
227
228 Type ct = cacheArg.getType();
229 Type type = MemRefType::get(shapes, ct);
230 auto alloc = memref::AllocOp::create(cacheBuilder, op->getLoc(), type,
231 ValueRange(iterationDomains));
232 Value cache = gutils->initAndPushCache(alloc, cacheBuilder);
233 // TODO use higher level API
234 alloc->setAttr(alloc.getOperandSegmentSizesAttrName(),
235 cacheBuilder.getDenseI32ArrayAttr(
236 {static_cast<int32_t>(iterationDomains.size()), 0}));
237
238 cast<linalg::GenericOp>(newOp).getOutputsMutable().append(
239 ValueRange({alloc}));
240 newIndexingMaps.push_back(AffineMap::getMultiDimIdentityMap(
241 iterationDomains.size(), cacheBuilder.getContext()));
242
243 builderAdd.setInsertionPoint(adjoint);
244 Value retrievedValue = gutils->popCache(cache, builderAdd);
245 retrievedValue = invertMemref(retrievedValue, builderAdd, op->getLoc());
246 adjoint.getInputsMutable().append(ValueRange({retrievedValue}));
247 indexingMapsAdjoint.insert(
248 indexingMapsAdjoint.begin() + numInputsAdjoint + i,
249 AffineMap::getMultiDimIdentityMap(iterationDomains.size(),
250 builderAdd.getContext()));
251 }
252 SmallVector<Attribute> indexingMapsAttr;
253 SmallVector<Attribute> indexingMapsAttrAdjoint;
254 for (auto &map : newIndexingMaps) {
255 indexingMapsAttr.push_back(AffineMapAttr::get(map));
256 }
257 for (auto &map : indexingMapsAdjoint) {
258 indexingMapsAttrAdjoint.push_back(AffineMapAttr::get(map));
259 }
260 cast<linalg::GenericOp>(newOp).setIndexingMapsAttr(
261 cacheBuilder.getArrayAttr(indexingMapsAttr));
262 adjoint->setAttr(adjoint.getIndexingMapsAttrName(),
263 builder.getArrayAttr(indexingMapsAttrAdjoint));
264 return success();
265 }
266
267 SmallVector<Value> cacheValues(Operation *op,
268 MGradientUtilsReverse *gutils) const {
269 return SmallVector<Value>();
270 }
271
272 void createShadowValues(Operation *op, OpBuilder &builder,
273 MGradientUtilsReverse *gutils) const {}
274};
275
/// Forward-mode autodiff external model for linalg.generic.
///
/// Builds a replacement linalg.generic that interleaves shadow (tangent)
/// operands and results immediately after each active primal operand/result,
/// moves the primal body into it, and then differentiates the body ops in
/// place.
class GenericFwd
    : public AutoDiffOpInterface::ExternalModel<GenericFwd, linalg::GenericOp> {
public:
  LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder,
                                         MGradientUtils *gutils) const {

    auto op = cast<linalg::GenericOp>(orig);

    // For all active results, add shadow types.
    // For now, assuming all results are relevant.
    Operation *newOp = gutils->getNewFromOriginal(op);
    SmallVector<Type> newOpResultTypes;
    newOpResultTypes.reserve(op->getNumResults() * 2);
    for (auto &&[result, init] :
         llvm::zip_equal(op->getResults(), op.getOutputs())) {
      // Primal result type always kept; shadow type appended only when the
      // result or its init operand is active.
      newOpResultTypes.push_back(result.getType());
      if (gutils->isConstantValue(result) && gutils->isConstantValue(init)) {
        continue;
      }
      auto typeIface = dyn_cast<AutoDiffTypeInterface>(result.getType());
      if (!typeIface) {
        op->emitError() << " AutoDiffTypeInterface not implemented for "
                        << result.getType() << "\n";
        return failure();
      }
      newOpResultTypes.push_back(typeIface.getShadowType(gutils->width));
    }

    // Build the operand lists of the replacement op: each active operand is
    // followed by its shadow, and the indexing map is duplicated for it.
    SmallVector<Value> newInputs;
    SmallVector<Value> newOutputs;
    SmallVector<AffineMap> indexingMaps;
    {
      size_t idx = 0;
      for (Value operand : op.getInputs()) {
        newInputs.push_back(gutils->getNewFromOriginal(operand));
        indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
        if (!gutils->isConstantValue(operand)) {
          newInputs.push_back(gutils->invertPointerM(operand, builder));
          indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
        }
        idx++;
      }
      // Outputs are walked together with their results and the matching
      // body block arguments (which start after the input arguments).
      for (auto &&[operand, res, oarg] :
           llvm::zip_equal(op.getOutputs(), op->getResults(),
                           op.getRegion().front().getArguments().slice(
                               op.getInputs().size()))) {
        newOutputs.push_back(gutils->getNewFromOriginal(operand));
        indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
        bool shadow = false;
        if (!gutils->isConstantValue(operand)) {
          shadow = true;
          newOutputs.push_back(gutils->invertPointerM(operand, builder));
          indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
        } else if (!gutils->isConstantValue(res)) {
          // Active result but constant init: seed the shadow output with a
          // null (zero) value.
          // NOTE(review): this dyn_cast result is used without a null check
          // (unlike the result-type check above) — confirm every output type
          // here implements AutoDiffTypeInterface.
          auto typeIface = dyn_cast<AutoDiffTypeInterface>(operand.getType());
          shadow = true;
          newOutputs.push_back(
              typeIface.createNullValue(builder, operand.getLoc()));
          indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
        }

        // If a shadow output was added but the body argument is constant,
        // the (moved) body still needs a block argument for the shadow.
        // NOTE(review): dyn_cast also unchecked here — see note above.
        if (shadow && gutils->isConstantValue(oarg)) {
          auto typeIface = dyn_cast<AutoDiffTypeInterface>(oarg.getType());
          auto newBA = cast<BlockArgument>(gutils->getNewFromOriginal(oarg));
          newBA.getOwner()->insertArgument(newBA.getArgNumber() + 1,
                                           typeIface.getShadowType(),
                                           newBA.getLoc());
        }

        idx++;
      }
    }
    // We are assuming the op can forward additional operands, listed
    // immediately after the original operands, to the same regions.
    // ^^
    // Our interface guarantees this.
    // We also assume that the region-holding op returns all of the values
    // yielded by terminators, and only those values.

    auto replacement = linalg::GenericOp::create(
        builder, op.getLoc(), newOpResultTypes, newInputs, newOutputs,
        indexingMaps, op.getIteratorTypesArray(),
        /*doc*/ "",
        /*libraryCall*/ "");

    assert(replacement->getNumResults() == newOpResultTypes.size());
    // Move the primal body into the replacement op.
    for (auto &&[region, replacementRegion] :
         llvm::zip(newOp->getRegions(), replacement->getRegions())) {
      replacementRegion.takeBody(region);
    }

    // Inject the mapping for the new results into GradientUtil's shadow
    // table.
    SmallVector<Value> reps;
    size_t idx = 0;
    for (OpResult r : op->getResults()) {
      // TODO only if used
      reps.push_back(replacement->getResult(idx));
      idx++;
      if (!gutils->isConstantValue(r)) {
        // Re-point the previously registered shadow (a placeholder) at the
        // replacement's shadow result and erase the placeholder op.
        auto inverted = gutils->invertedPointers.lookupOrNull(r);
        assert(inverted);
        gutils->invertedPointers.map(r, replacement->getResult(idx));
        inverted.replaceAllUsesWith(replacement->getResult(idx));
        gutils->erase(inverted.getDefiningOp());
        idx++;
      }
    }

    // Differentiate body.
    for (auto &origRegion : op->getRegions()) {
      for (auto &origBlock : origRegion) {
        for (Operation &o : origBlock) {
          if (failed(gutils->visitChild(&o))) {
            return failure();
          }
        }
      }
    }

    // Replace all uses of original results
    gutils->replaceOrigOpWith(op, reps);
    gutils->erase(newOp);
    gutils->originalToNewFnOps[op] = replacement;

    return success();
  }
};
404
405#include "Implementations/LinalgDerivatives.inc"
406} // namespace
407
408template <typename... Ts> void attachAllInterfaces(MLIRContext *context) {
409 (Ts::template attachInterface<GenericOpInterfaceReverse<Ts>>(*context), ...);
410}
411
413 DialectRegistry &registry) {
414 registry.addExtension(+[](MLIRContext *context, linalg::LinalgDialect *) {
415 registerInterfaces(context);
416 linalg::GenericOp::attachInterface<GenericFwd>(*context);
418#define GET_OP_LIST
419#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
420 >(context);
421 });
422}
void attachAllInterfaces(MLIRContext *context)
LogicalResult differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, llvm::function_ref< buildReturnFunction > buildFuncReturnOp, std::function< std::pair< Value, Value >(Type)> cacheCreator)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
void replaceOrigOpWith(Operation *op, ValueRange vals)
std::map< Operation *, Operation * > originalToNewFnOps
LogicalResult visitChild(Operation *op)
void erase(Operation *op)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
bool isConstantValue(mlir::Value v) const
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry)