Enzyme main
Loading...
Searching...
No Matches
AffineAutoDiffOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- AffineAutoDiffOpInterfaceImpl.cpp - Interface external model -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the external model implementation of the automatic
10// differentiation op interfaces for the upstream MLIR Affine dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "Dialect/Ops.h"
18#include "Passes/RemovalUtils.h"
19#include "Passes/Utils.h"
20#include "mlir/Dialect/Affine/IR/AffineOps.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/IR/IntegerSet.h"
23#include "llvm/ADT/ScopeExit.h"
24
25using namespace mlir;
26using namespace mlir::enzyme;
27using namespace mlir::affine;
28
29namespace {
// Builds the augmented forward clone of an `affine.for` for forward-mode AD.
// `remappedOperands` holds the already-remapped lower/upper bound operands
// followed by the iter args (which may include caller-appended shadow iter
// args — see the comment below); `rettys` is unused since result types follow
// from the iter args.
affine::AffineForOp
createAffineForWithShadows(Operation *op, OpBuilder &builder,
                           MGradientUtils *gutils, Operation *original,
                           ValueRange remappedOperands, TypeRange rettys) {
  affine::AffineForOpAdaptor adaptor(remappedOperands,
                                     cast<affine::AffineForOp>(original));
  auto repFor = affine::AffineForOp::create(
      builder, original->getLoc(), adaptor.getLowerBoundOperands(),
      adaptor.getLowerBoundMap(), adaptor.getUpperBoundOperands(),
      adaptor.getUpperBoundMap(), adaptor.getStep().getZExtValue(),
      // This dance is necessary because the adaptor accessors are based on the
      // internal attribute containing the number of operands associated with
      // each named operand group. This attribute is carried over from the
      // original operation and does not account for the shadow-related iter
      // args. Instead, assume lower/upper bound operands must not have shadows
      // since they are integer-typed and take the result of operands as iter
      // args.
      remappedOperands.drop_front(adaptor.getLowerBoundOperands().size() +
                                  adaptor.getUpperBoundOperands().size()));
  return repFor;
}
51
52affine::AffineIfOp createAffineIfWithShadows(Operation *op, OpBuilder &builder,
53 MGradientUtils *gutils,
54 affine::AffineIfOp original,
55 ValueRange remappedOperands,
56 TypeRange rettys) {
57 affine::AffineIfOpAdaptor adaptor(remappedOperands, original);
58 return affine::AffineIfOp::create(
59 builder, original->getLoc(), rettys, original.getIntegerSet(),
60 adaptor.getOperands(), !original.getElseRegion().empty());
61}
62
63struct AffineForOpInterfaceReverse
64 : public ReverseAutoDiffOpInterface::ExternalModel<
65 AffineForOpInterfaceReverse, affine::AffineForOp> {
66 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
68 SmallVector<Value> caches) const {
69 auto forOp = cast<affine::AffineForOp>(op);
70
71 affine::AffineBound lb = forOp.getLowerBound();
72 affine::AffineBound ub = forOp.getUpperBound();
73
74 if (lb.getMap().getNumResults() != 1 || ub.getMap().getNumResults() != 1) {
75 op->emitError() << "cannot differentiate loop with minmax bounds yet";
76 return failure();
77 }
78
79 SmallVector<bool> operandsActive;
80 for (auto [operand, result] : llvm::zip_equal(
81 op->getOperands().slice(forOp.getNumControlOperands(),
82 forOp->getNumOperands() -
83 forOp.getNumControlOperands()),
84 op->getResults())) {
85 operandsActive.push_back(!gutils->isConstantValue(operand) ||
86 !gutils->isConstantValue(result));
87 }
88
89 SmallVector<Value> revLBOperands, revUBOperands, incomingGradients;
90
91 for (int i = 0, e = lb.getNumOperands(); i < e; ++i) {
92 revLBOperands.push_back(gutils->popCache(caches[i], builder));
93 }
94
95 for (int i = lb.getNumOperands(), e = forOp.getNumControlOperands(); i < e;
96 ++i) {
97 revUBOperands.push_back(gutils->popCache(caches[i], builder));
98 }
99
100 for (auto &&[active, res] :
101 llvm::zip_equal(operandsActive, op->getResults())) {
102 if (active) {
103 incomingGradients.push_back(gutils->diffe(res, builder));
104 if (!gutils->isConstantValue(res))
105 gutils->zeroDiffe(res, builder);
106 }
107 }
108
109 auto revFor = affine::AffineForOp::create(
110 builder, op->getLoc(), revLBOperands, lb.getMap(), revUBOperands,
111 ub.getMap(), forOp.getStepAsInt(), incomingGradients);
112
113 bool valid = true;
114 for (auto &&[oldReg, newReg] :
115 llvm::zip(op->getRegions(), revFor->getRegions())) {
116 for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
117 OpBuilder bodyBuilder(&revBB, revBB.end());
118
119 // Create implicit terminator if not present (when num results > 0)
120 if (revBB.empty()) {
121 affine::AffineYieldOp::create(bodyBuilder, revFor->getLoc());
122 }
123 bodyBuilder.setInsertionPoint(revBB.getTerminator());
124
125 // All values defined in the body should have no use outside this block
126 // therefore we can set their diffe to zero upon entering the reverse
127 // block to simplify the work of the remove-unnecessary-enzyme-ops pass.
128 for (auto operand : oBB.getArguments().slice(1)) {
129 if (!gutils->isConstantValue(operand)) {
130 gutils->zeroDiffe(operand, bodyBuilder);
131 }
132 }
133
134 for (auto &it : oBB.getOperations()) {
135 for (auto res : it.getResults()) {
136 if (!gutils->isConstantValue(res)) {
137 auto iface = dyn_cast<AutoDiffTypeInterface>(res.getType());
138 if (iface && !iface.isMutable())
139 gutils->zeroDiffe(res, bodyBuilder);
140 }
141 }
142 }
143
144 auto term = oBB.getTerminator();
145
146 unsigned argIdx = 1; // Skip over the reversed IV
147 for (auto &&[active, operand] :
148 llvm::zip_equal(operandsActive, term->getOperands())) {
149 if (active) {
150 // Set diffe here, not add because it should not accumulate across
151 // iterations. Instead the new gradient for this operand is passed
152 // in the return of the reverse for body.
153 gutils->setDiffe(operand, revBB.getArgument(argIdx), bodyBuilder);
154 argIdx++;
155 }
156 }
157
158 auto first = oBB.rbegin();
159 first++; // skip terminator
160
161 auto last = oBB.rend();
162
163 for (auto it = first; it != last; ++it) {
164 Operation *op = &*it;
165 valid &=
166 gutils->Logic.visitChild(op, bodyBuilder, gutils).succeeded();
167 }
168
169 SmallVector<Value> newResults;
170 newResults.reserve(incomingGradients.size());
171
172 for (auto &&[active, arg] :
173 llvm::zip_equal(operandsActive, oBB.getArguments().slice(1))) {
174 if (active) {
175 newResults.push_back(gutils->diffe(arg, bodyBuilder));
176 if (!gutils->isConstantValue(arg))
177 gutils->zeroDiffe(arg, bodyBuilder);
178 }
179 }
180
181 // yield new gradient values
182 revBB.getTerminator()->setOperands(newResults);
183 }
184 }
185
186 unsigned resIdx = 0;
187 for (auto &&[active, arg] :
188 llvm::zip_equal(operandsActive, forOp.getInits())) {
189 if (active) {
190 if (!gutils->isConstantValue(arg)) {
191 gutils->addToDiffe(arg, revFor.getResult(resIdx), builder);
192 resIdx++;
193 }
194 }
195 }
196
197 return success(valid);
198 }
199
200 SmallVector<Value> cacheValues(Operation *op,
201 MGradientUtilsReverse *gutils) const {
202 auto forOp = cast<affine::AffineForOp>(op);
203
204 SmallVector<Value> caches;
205 OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
206 for (auto operand : forOp.getControlOperands()) {
207 caches.push_back(gutils->initAndPushCache(
208 gutils->getNewFromOriginal(operand), cacheBuilder));
209 }
210
211 return caches;
212 }
213
214 void createShadowValues(Operation *op, OpBuilder &builder,
215 MGradientUtilsReverse *gutils) const {}
216};
217
// Reverse-mode (adjoint) implementation for `affine.parallel`.
struct AffineParallelOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<
          AffineParallelOpInterfaceReverse, affine::AffineParallelOp> {
  // Emits an adjoint affine.parallel over the same iteration space. Since
  // reverse iterations may accumulate into the same gradient slots
  // concurrently, AtomicAdd is forced on while visiting the body.
  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    auto parOp = cast<affine::AffineParallelOp>(op);
    if (!parOp.getReductions().empty()) {
      return parOp.emitError() << "parallel reductions not yet implemented";
    }
    if (parOp.hasMinMaxBounds()) {
      return parOp.emitError() << "minmax bounds not yet supported";
    }

    // Rebuild the bound operands from the values cached in the forward sweep.
    SmallVector<Value> bounds = llvm::map_to_vector(
        caches, [&](Value cache) { return gutils->popCache(cache, builder); });
    auto revPar = affine::AffineParallelOp::create(
        builder, op->getLoc(), parOp.getResultTypes(), parOp.getReductions(),
        parOp.getLowerBoundsMap(), parOp.getLowerBoundsGroups(),
        parOp.getUpperBoundsMap(), parOp.getUpperBoundsGroups(),
        parOp.getSteps(), bounds);

    // Create the body block and terminator
    OpBuilder::InsertionGuard guard(builder);
    SmallVector<Type> ivTypes(parOp.getIVs().size(), builder.getIndexType());
    SmallVector<Location> ivLocs(parOp.getIVs().size(), parOp.getLoc());
    builder.createBlock(&revPar.getBodyRegion(), revPar.getBodyRegion().begin(),
                        ivTypes, ivLocs);
    affine::AffineYieldOp::create(builder, parOp.getLoc());

    bool valid = true;
    // Force atomic gradient accumulation inside the parallel body; the
    // previous setting is restored after the body has been emitted.
    bool wasAtomic = gutils->AtomicAdd;
    gutils->AtomicAdd = true;

    {
      Block *oBB = parOp.getBody();
      Block *rBB = revPar.getBody();

      OpBuilder bodyBuilder = revPar.getBodyBuilder();

      // Prepare gradients of body-local values at the top of the reverse
      // body (see localizeGradients in Utils).
      bodyBuilder.setInsertionPointToStart(revPar.getBody());
      mlir::enzyme::localizeGradients(bodyBuilder, gutils, oBB);

      bodyBuilder.setInsertionPoint(rBB->getTerminator());

      // Walk the primal body backwards (skipping the terminator) and emit
      // the adjoint of each operation.
      auto first = oBB->rbegin();
      first++; // skip terminator

      auto last = oBB->rend();

      for (auto it = first; it != last; ++it) {
        Operation *op = &*it;
        valid &= gutils->Logic.visitChild(op, bodyBuilder, gutils).succeeded();
      }
    }

    gutils->AtomicAdd = wasAtomic;
    return success(valid);
  }

  // Caches the affine map (bound) operands in the forward sweep so the
  // adjoint above can be built with identical bounds.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    auto parOp = cast<affine::AffineParallelOp>(op);

    SmallVector<Value> caches;
    OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
    for (auto operand : parOp.getMapOperands()) {
      caches.push_back(gutils->initAndPushCache(
          gutils->getNewFromOriginal(operand), cacheBuilder));
    }
    return caches;
  }

  // No shadow values needed for the loop op itself.
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}
};
294
// Hooks for the generic ForLikeEnzymeOpsRemover rewrite (see RemovalUtils)
// describing `affine.parallel`: iteration-space bounds, canonical
// (zero-based, unit-step) induction variables, and how to rebuild the loop
// with additional result operands.
struct AffineParallelOpEnzymeOpsRemover
    : public ForLikeEnzymeOpsRemover<AffineParallelOpEnzymeOpsRemover,
                                     affine::AffineParallelOp> {
  // Per-dimension trip counts: constants when statically known, otherwise
  // materialized in the IR as (ub - lb) [/ step].
  static SmallVector<IntOrValue, 1>
  getDimensionBounds(OpBuilder &builder, affine::AffineParallelOp parOp) {
    SmallVector<IntOrValue, 1> bounds;
    auto ranges = parOp.getConstantRanges();
    if (ranges) {
      for (auto &&[r, step] : llvm::zip(*ranges, parOp.getSteps())) {
        bounds.push_back(r / step);
      }
    } else {
      for (auto &&[dim, step] : llvm::enumerate(parOp.getSteps())) {
        auto lb = AffineApplyOp::create(builder, parOp.getLoc(),
                                        parOp.getLowerBoundMap(dim),
                                        parOp.getLowerBoundsOperands());
        auto ub = AffineApplyOp::create(builder, parOp.getLoc(),
                                        parOp.getUpperBoundMap(dim),
                                        parOp.getUpperBoundsOperands());
        Value diff = arith::SubIOp::create(builder, parOp.getLoc(), ub, lb);
        if (step != 1) {
          Value stepVal =
              arith::ConstantIndexOp::create(builder, parOp.getLoc(), step);
          diff = arith::DivUIOp::create(builder, parOp.getLoc(), diff, stepVal);
        }
        bounds.push_back(diff);
      }
    }
    return bounds;
  }

  // Iteration order does not matter for a parallel loop, so the "reversed"
  // indices are just the other loop's induction variables unchanged.
  static SmallVector<Value> computeReversedIndices(
      PatternRewriter &rewriter, affine::AffineParallelOp parOp,
      ArrayRef<Value> otherInductionVariable, ArrayRef<IntOrValue> bounds) {
    return SmallVector<Value>(otherInductionVariable);
  }

  // Normalizes each IV to be zero-based with unit step: (iv - lb) / step,
  // skipping the subtraction when the lower bound is the constant 0.
  static SmallVector<Value>
  getCanonicalLoopIVs(OpBuilder &builder, affine::AffineParallelOp parOp) {
    SmallVector<Value> ivs(parOp.getIVs());
    for (auto &&[dim, step] : llvm::enumerate(parOp.getSteps())) {
      Value iv = ivs[dim];
      auto lbMap = parOp.getLowerBoundMap(dim);
      if (!(lbMap.isSingleConstant() && lbMap.getSingleConstantResult() == 0)) {
        auto lb = AffineApplyOp::create(builder, parOp.getLoc(), lbMap,
                                        parOp.getLowerBoundsOperands());
        iv = arith::SubIOp::create(builder, parOp.getLoc(), iv, lb);
      }

      if (step != 1) {
        auto stepVal =
            arith::ConstantIndexOp::create(builder, parOp.getLoc(), step);
        iv = arith::DivUIOp::create(builder, parOp.getLoc(), iv, stepVal);
      }

      ivs[dim] = iv;
    }
    return ivs;
  }

  // Maps values of `parOp`'s body onto `otherParOp`'s body. IVs not covered
  // by the explicit index mapping are mapped 1:1, which is only valid when
  // both loops share lower bounds and steps (asserted below).
  static IRMapping createArgumentMap(PatternRewriter &rewriter,
                                     affine::AffineParallelOp parOp,
                                     ArrayRef<Value> indPar,
                                     affine::AffineParallelOp otherParOp,
                                     ArrayRef<Value> indOther) {
    IRMapping map;
    for (auto &&[f, o] : llvm::zip_equal(indPar, indOther))
      map.map(f, o);

    for (auto &&[fiv, oiv] :
         llvm::zip_equal(parOp.getIVs(), otherParOp.getIVs())) {
      if (!map.contains(fiv)) {
        assert(parOp.getLowerBoundsMap() == otherParOp.getLowerBoundsMap());
        for (auto &&[f, o] :
             llvm::zip_equal(parOp.getLowerBoundsOperands(),
                             otherParOp.getLowerBoundsOperands())) {
          (void)f;
          (void)o;
          assert(Equivalent(f, o));
        }
        for (auto [fstep, ostep] :
             llvm::zip_equal(parOp.getSteps(), otherParOp.getSteps())) {
          (void)fstep;
          (void)ostep;
          assert(fstep == ostep);
        }
        map.map(fiv, oiv);
      }
    }
    return map;
  }

  // Rebuilds `otherParOp` so its results are `operands`; each entry beyond
  // the original ones becomes an addf reduction. Original results are
  // replaced by the leading results of the new op.
  static affine::AffineParallelOp
  replaceWithNewOperands(PatternRewriter &rewriter,
                         affine::AffineParallelOp otherParOp,
                         ArrayRef<Value> operands) {
    SmallVector<mlir::Attribute> reductionKinds(
        otherParOp.getReductions().begin(), otherParOp.getReductions().end());

    // NOTE(review): the loop start compares the op's *operand* count with the
    // size of the new result-operand list; presumably operand and result
    // counts coincide for the ops reaching this rewrite — confirm against
    // callers (getNumResults would be the more obvious bound).
    for (unsigned i = otherParOp->getNumOperands(); i < operands.size(); i++) {
      reductionKinds.push_back(arith::AtomicRMWKindAttr::get(
          otherParOp.getContext(), arith::AtomicRMWKind::addf));
    }

    ValueRange operands_(operands);
    auto newOtherParOp = affine::AffineParallelOp::create(
        rewriter, otherParOp.getLoc(), operands_.getTypes(),
        ArrayAttr::get(otherParOp.getContext(), reductionKinds),
        otherParOp.getLowerBoundsMap(), otherParOp.getLowerBoundsGroups(),
        otherParOp.getUpperBoundsMap(), otherParOp.getUpperBoundsGroups(),
        otherParOp.getSteps(), otherParOp.getMapOperands());

    newOtherParOp.getRegion().takeBody(otherParOp.getRegion());
    rewriter.replaceOp(otherParOp, newOtherParOp->getResults().slice(
                                       0, otherParOp->getNumResults()));
    return newOtherParOp;
  }

  // The loop's init values (empty when the op carries no reductions).
  static ValueRange getInits(affine::AffineParallelOp parOp) {
    return parOp.getInits();
  }

  // Gradients must be combined after the loop: parallel iterations cannot
  // chain partial sums through loop-carried values.
  static bool mustPostAdd(affine::AffineParallelOp forOp) { return true; }

  // Seeds a gradient accumulator with the null value of its base type at the
  // start of `body`.
  static Value initialValueInBlock(OpBuilder &builder, Block *body,
                                   Value grad) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(body);
    return cast<AutoDiffTypeInterface>(
               cast<enzyme::GradientType>(grad.getType()).getBasetype())
        .createNullValue(builder, grad.getLoc());
  }
};
428
// Reverse-mode (adjoint) implementation for `affine.load`.
struct AffineLoadOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<
          AffineLoadOpInterfaceReverse, affine::AffineLoadOp> {
  // The adjoint of `x = load memref[idx]` accumulates d(x) into the shadow
  // memref at the same indices: read-modify-write when non-atomic updates
  // suffice, atomic RMW when gutils->AtomicAdd is set.
  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    auto loadOp = cast<affine::AffineLoadOp>(op);
    Value memref = loadOp.getMemref();

    if (auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) {
      if (!gutils->isConstantValue(loadOp) &&
          !gutils->isConstantValue(memref)) {
        Value gradient = gutils->diffe(loadOp, builder);
        // caches[0] is the shadow memref; the remaining caches are the
        // index operands saved during the forward sweep.
        Value memrefGradient = gutils->popCache(caches.front(), builder);

        SmallVector<Value> retrievedArguments;
        for (Value cache : ValueRange(caches).drop_front(1)) {
          Value retrievedValue = gutils->popCache(cache, builder);
          retrievedArguments.push_back(retrievedValue);
        }

        if (!gutils->AtomicAdd) {
          bool hasIndex = loadOp.getAffineMap().getNumDims() > 0;
          // if index had to be cached, the pop is not necessarily a valid index
          if (hasIndex) {
            // Materialize the affine map applied to the cached operands and
            // use plain memref load/store for the accumulation.
            SmallVector<Value> indices;
            computeAffineIndices(builder, loadOp.getLoc(),
                                 loadOp.getAffineMap(), retrievedArguments,
                                 indices);

            Value loadedGradient = memref::LoadOp::create(
                builder, loadOp.getLoc(), memrefGradient, indices);
            Value addedGradient = iface.createAddOp(builder, loadOp.getLoc(),
                                                    loadedGradient, gradient);
            memref::StoreOp::create(builder, loadOp.getLoc(), addedGradient,
                                    memrefGradient, indices);
          } else {
            Value loadedGradient = affine::AffineLoadOp::create(
                builder, loadOp.getLoc(), memrefGradient, loadOp.getAffineMap(),
                ArrayRef<Value>(retrievedArguments));
            Value addedGradient = iface.createAddOp(builder, loadOp.getLoc(),
                                                    loadedGradient, gradient);
            affine::AffineStoreOp::create(
                builder, loadOp.getLoc(), addedGradient, memrefGradient,
                loadOp.getAffineMap(), ArrayRef<Value>(retrievedArguments));
          }
        } else {
          bool hasIndex = loadOp.getAffineMap().getNumDims() > 0;
          // if index had to be cached, the pop is not necessarily a valid index
          if (hasIndex) {
            SmallVector<Value> indices;
            computeAffineIndices(builder, loadOp.getLoc(),
                                 loadOp.getAffineMap(), retrievedArguments,
                                 indices);
            memref::AtomicRMWOp::create(builder, loadOp.getLoc(),
                                        arith::AtomicRMWKind::addf, gradient,
                                        memrefGradient, indices);
          } else {
            enzyme::AffineAtomicRMWOp::create(
                builder, loadOp.getLoc(), gradient.getType(),
                arith::AtomicRMWKind::addf, gradient, memrefGradient,
                retrievedArguments, loadOp.getAffineMap());
          }
        }
      }
    }
    return success();
  }

  // Caches the shadow memref and the (new) index operands during the forward
  // sweep so the adjoint above can address the same slot.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    auto loadOp = cast<affine::AffineLoadOp>(op);
    Value memref = loadOp.getMemref();
    ValueRange indices = loadOp.getIndices();
    if (auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) {
      if (!gutils->isConstantValue(loadOp) &&
          !gutils->isConstantValue(memref)) {
        OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
        SmallVector<Value> caches;
        caches.push_back(gutils->initAndPushCache(
            gutils->invertPointerM(memref, cacheBuilder), cacheBuilder));
        for (Value v : indices) {
          caches.push_back(gutils->initAndPushCache(
              gutils->getNewFromOriginal(v), cacheBuilder));
        }
        return caches;
      }
    }
    return SmallVector<Value>();
  }

  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {
    // auto loadOp = cast<memref::LoadOp>(op);
    // Value memref = loadOp.getMemref();
    // Value shadow = gutils->getShadowValue(memref);
    // Do nothing yet. In the future support memref<memref<...>>
  }
};
528
529struct AffineStoreOpInterfaceReverse
530 : public ReverseAutoDiffOpInterface::ExternalModel<
531 AffineStoreOpInterfaceReverse, affine::AffineStoreOp> {
532 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
533 MGradientUtilsReverse *gutils,
534 SmallVector<Value> caches) const {
535 auto storeOp = cast<affine::AffineStoreOp>(op);
536 Value val = storeOp.getValue();
537 Value memref = storeOp.getMemref();
538 // ValueRange indices = storeOp.getIndices();
539
540 auto iface = cast<AutoDiffTypeInterface>(val.getType());
541
542 if (!gutils->isConstantValue(memref)) {
543 OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
544
545 Value memrefGradient = gutils->popCache(caches.front(), builder);
546
547 SmallVector<Value> retrievedArguments;
548 for (Value cache : ValueRange(caches).drop_front(1)) {
549 Value retrievedValue = gutils->popCache(cache, builder);
550 retrievedArguments.push_back(retrievedValue);
551 }
552
553 bool hasIndex = storeOp.getAffineMap().getNumDims() > 0;
554
555 if (!iface.isMutable()) {
556 if (!gutils->isConstantValue(val)) {
557 Value loadedGradient;
558 if (hasIndex) {
559 SmallVector<Value> indices;
560 computeAffineIndices(builder, storeOp.getLoc(),
561 storeOp.getAffineMap(), retrievedArguments,
562 indices);
563 loadedGradient = memref::LoadOp::create(builder, storeOp.getLoc(),
564 memrefGradient, indices);
565 } else {
566 loadedGradient = affine::AffineLoadOp::create(
567 builder, storeOp.getLoc(), memrefGradient,
568 storeOp.getAffineMap(), ArrayRef<Value>(retrievedArguments));
569 }
570 gutils->addToDiffe(val, loadedGradient, builder);
571 }
572
573 auto zero =
574 cast<AutoDiffTypeInterface>(gutils->getShadowType(val.getType()))
575 .createNullValue(builder, op->getLoc());
576
577 // if index had to be cached, the pop is not necessarily a valid index
578 if (hasIndex) {
579 SmallVector<Value> indices;
580 computeAffineIndices(builder, storeOp.getLoc(),
581 storeOp.getAffineMap(), retrievedArguments,
582 indices);
583 memref::StoreOp::create(builder, storeOp.getLoc(), zero,
584 memrefGradient, indices);
585 } else {
586 affine::AffineStoreOp::create(builder, storeOp.getLoc(), zero,
587 memrefGradient, storeOp.getAffineMap(),
588 ArrayRef<Value>(retrievedArguments));
589 }
590 }
591 }
592 return success();
593 }
594
595 SmallVector<Value> cacheValues(Operation *op,
596 MGradientUtilsReverse *gutils) const {
597 auto storeOp = cast<affine::AffineStoreOp>(op);
598 Value memref = storeOp.getMemref();
599 ValueRange indices = storeOp.getIndices();
600 Value val = storeOp.getValue();
601 if (auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType())) {
602 if (!gutils->isConstantValue(memref)) {
603 OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
604 SmallVector<Value> caches;
605 caches.push_back(gutils->initAndPushCache(
606 gutils->invertPointerM(memref, cacheBuilder), cacheBuilder));
607 for (Value v : indices) {
608 caches.push_back(gutils->initAndPushCache(
609 gutils->getNewFromOriginal(v), cacheBuilder));
610 }
611 return caches;
612 }
613 }
614 return SmallVector<Value>();
615 }
616
617 void createShadowValues(Operation *op, OpBuilder &builder,
618 MGradientUtilsReverse *gutils) const {
619 // auto storeOp = cast<memref::StoreOp>(op);
620 // Value memref = storeOp.getMemref();
621 // Value shadow = gutils->getShadowValue(memref);
622 // Do nothing yet. In the future support memref<memref<...>>
623 }
624};
625
626struct AffineForOpADDataFlow
627 : public ADDataFlowOpInterface::ExternalModel<AffineForOpADDataFlow,
628 affine::AffineForOp> {
629 SmallVector<Value> getPotentialIncomingValuesRes(Operation *op,
630 OpResult res) const {
631 auto forOp = cast<affine::AffineForOp>(op);
632 return {
633 forOp.getInits()[res.getResultNumber()],
634 forOp.getBody()->getTerminator()->getOperand(res.getResultNumber())};
635 }
636 SmallVector<Value> getPotentialIncomingValuesArg(Operation *op,
637 BlockArgument arg) const {
638 auto forOp = cast<affine::AffineForOp>(op);
639 if (arg.getArgNumber() < 1) {
640 return {};
641 }
642 auto idx = arg.getArgNumber() - 1;
643 return {forOp.getInits()[idx],
644 forOp.getBody()->getTerminator()->getOperand(idx)};
645 }
646 SmallVector<Value> getPotentialTerminatorUsers(Operation *op, Operation *term,
647 Value val) const {
648 auto forOp = cast<affine::AffineForOp>(op);
649 SmallVector<Value> sv;
650
651 for (auto &&[res, arg, barg] :
652 llvm::zip_equal(forOp->getResults(), term->getOperands(),
653 forOp.getRegionIterArgs())) {
654 if (arg == val) {
655 sv.push_back(res);
656 sv.push_back(barg);
657 }
658 }
659
660 return sv;
661 }
662};
663
// Hooks for the generic ForLikeEnzymeOpsRemover rewrite (see RemovalUtils)
// describing `affine.for`: trip count, canonical (zero-based, unit-step)
// induction variable, and how to rebuild the loop with extra iter args.
struct AffineForOpEnzymeOpsRemover
    : public ForLikeEnzymeOpsRemover<AffineForOpEnzymeOpsRemover,
                                     affine::AffineForOp> {
public:
  // TODO: support non constant number of iteration by using unknown dimensions
  // Trip count when both bounds are compile-time constants, nullopt otherwise.
  static std::optional<int64_t>
  getConstantNumberOfIterations(affine::AffineForOp forOp) {
    if (!forOp.hasConstantLowerBound())
      return std::nullopt;
    if (!forOp.hasConstantUpperBound())
      return std::nullopt;
    return (forOp.getConstantUpperBound() - forOp.getConstantLowerBound()) /
           forOp.getStepAsInt();
  }

  // Single-dimension trip count: constant when statically known, otherwise
  // materialized in the IR as (ub - lb) [/ step].
  static SmallVector<IntOrValue, 1>
  getDimensionBounds(OpBuilder &builder, affine::AffineForOp forOp) {
    auto iters = getConstantNumberOfIterations(forOp);
    if (iters) {
      return {IntOrValue(*iters)};
    } else {
      auto lb = AffineApplyOp::create(builder, forOp.getLoc(),
                                      forOp.getLowerBoundMap(),
                                      forOp.getLowerBoundOperands());
      auto ub = AffineApplyOp::create(builder, forOp.getLoc(),
                                      forOp.getUpperBoundMap(),
                                      forOp.getUpperBoundOperands());

      Value diff = arith::SubIOp::create(builder, forOp->getLoc(), ub, lb);
      if (forOp.getStepAsInt() != 1) {
        auto step = arith::ConstantIntOp::create(
            builder, forOp->getLoc(), diff.getType(), forOp.getStepAsInt());
        diff = arith::DivUIOp::create(builder, forOp->getLoc(), diff, step);
      }
      return {IntOrValue(diff)};
    }
  }

  // Normalizes the IV (block argument 0) to be zero-based with unit step:
  // (iv - lb) / step, skipping the subtraction for a constant-zero lb.
  static SmallVector<Value> getCanonicalLoopIVs(OpBuilder &builder,
                                                affine::AffineForOp forOp) {
    Value val = forOp.getBody()->getArgument(0);
    if (!forOp.hasConstantLowerBound() || forOp.getConstantLowerBound() != 0) {
      auto lb = AffineApplyOp::create(builder, forOp.getLoc(),
                                      forOp.getLowerBoundMap(),
                                      forOp.getLowerBoundOperands());
      val = arith::SubIOp::create(builder, forOp->getLoc(), val, lb);
    }

    if (forOp.getStepAsInt() != 1) {
      auto step = arith::ConstantIntOp::create(
          builder, forOp->getLoc(), val.getType(), forOp.getStepAsInt());
      val = arith::DivUIOp::create(builder, forOp->getLoc(), val, step);
    }
    return {val};
  }

  // Maps values of `forOp`'s body onto `otherForOp`'s body. If the IV is not
  // covered by the explicit index mapping it is mapped 1:1, which is only
  // valid when both loops share lower bound and step (asserted below).
  static IRMapping createArgumentMap(PatternRewriter &rewriter,
                                     affine::AffineForOp forOp,
                                     ArrayRef<Value> indFor,
                                     affine::AffineForOp otherForOp,
                                     ArrayRef<Value> indOther) {
    IRMapping map;
    for (auto &&[f, o] : llvm::zip_equal(indFor, indOther))
      map.map(f, o);

    Value canIdx = forOp.getBody()->getArgument(0);
    if (!map.contains(canIdx)) {
      assert(forOp.getLowerBoundMap() == otherForOp.getLowerBoundMap());
      for (auto &&[f, o] :
           llvm::zip_equal(forOp.getLowerBoundOperands(),
                           otherForOp.getLowerBoundOperands())) {
        (void)f;
        (void)o;
        assert(Equivalent(f, o));
      }
      assert(forOp.getStep() == otherForOp.getStep());
      map.map(forOp.getBody()->getArgument(0),
              otherForOp.getBody()->getArgument(0));
    }
    return map;
  }

  // Rebuilds `otherForOp` with `operands` as its iter args, moving the body
  // over and replacing the original results with the new op's leading ones.
  static affine::AffineForOp
  replaceWithNewOperands(PatternRewriter &rewriter,
                         affine::AffineForOp otherForOp,
                         ArrayRef<Value> operands) {
    auto newOtherForOp = affine::AffineForOp::create(
        rewriter, otherForOp->getLoc(), otherForOp.getLowerBoundOperands(),
        otherForOp.getLowerBoundMap(), otherForOp.getUpperBoundOperands(),
        otherForOp.getUpperBoundMap(), otherForOp.getStepAsInt(), operands);

    newOtherForOp.getRegion().takeBody(otherForOp.getRegion());
    rewriter.replaceOp(otherForOp, newOtherForOp->getResults().slice(
                                       0, otherForOp->getNumResults()));
    return newOtherForOp;
  }

  // The loop's iter-arg init values.
  static ValueRange getInits(affine::AffineForOp forOp) {
    return forOp.getInits();
  }

  // Sequential loop: gradients can be chained through iter args, no post-loop
  // combination needed.
  static bool mustPostAdd(affine::AffineForOp forOp) { return false; }

  // Gradient accumulators become new iter args (block arguments) of the loop
  // body.
  static Value initialValueInBlock(OpBuilder &builder, Block *body,
                                   Value grad) {
    auto Ty = cast<enzyme::GradientType>(grad.getType()).getBasetype();
    return body->addArgument(Ty, grad.getLoc());
  }
};
773
774#include "Implementations/AffineDerivatives.inc"
775} // namespace
776
778 DialectRegistry &registry) {
779 registry.addExtension(+[](MLIRContext *context, affine::AffineDialect *) {
780 registerInterfaces(context);
781 affine::AffineLoadOp::attachInterface<AffineLoadOpInterfaceReverse>(
782 *context);
783 affine::AffineStoreOp::attachInterface<AffineStoreOpInterfaceReverse>(
784 *context);
785 affine::AffineForOp::attachInterface<AffineForOpInterfaceReverse>(*context);
786 affine::AffineForOp::attachInterface<AffineForOpEnzymeOpsRemover>(*context);
787 affine::AffineForOp::attachInterface<AffineForOpADDataFlow>(*context);
788 affine::AffineParallelOp::attachInterface<AffineParallelOpInterfaceReverse>(
789 *context);
790 affine::AffineParallelOp::attachInterface<AffineParallelOpEnzymeOpsRemover>(
791 *context);
792 });
793}
static std::optional< SmallVector< Value > > getPotentialTerminatorUsers(Operation *op, Value parent)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder)
LogicalResult visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
mlir::Type getShadowType(mlir::Type T)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
bool isConstantValue(mlir::Value v) const
void computeAffineIndices(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands, SmallVectorImpl< Value > &indices)
Definition Utils.cpp:171
void localizeGradients(OpBuilder &builder, MGradientUtilsReverse *gutils, Block *fwd)
void registerAffineDialectAutoDiffInterface(DialectRegistry &registry)