// (doxygen page-navigation residue removed)
//===- SCFAutoDiffOpInterfaceImpl.cpp - Interface external model ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR SCF dialect.
//
//===----------------------------------------------------------------------===//

#include "Passes/RemovalUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include <cmath>
#include <cstdint>
#include <functional>

using namespace mlir;
using namespace mlir::enzyme;

namespace {
#include "Implementations/SCFDerivatives.inc"

/// CRTP hooks used by ForLikeEnzymeOpsRemover to rewrite enzyme push/pop pairs
/// around scf.for loops.
struct ForOpEnzymeOpsRemover
    : public ForLikeEnzymeOpsRemover<ForOpEnzymeOpsRemover, scf::ForOp> {
public:
  // TODO: support non constant number of iteration by using unknown dimensions
  /// Returns the trip count of `forOp` when lb, ub, and step are all arith
  /// constants, std::nullopt otherwise.
  ///
  /// NOTE(review): this is floor((ub - lb) / step), while scf.for performs
  /// ceil((ub - lb) / step) iterations; the two agree only when (ub - lb) is
  /// divisible by step. The runtime path in getDimensionBounds uses the same
  /// floor (DivUI) division, so this looks deliberate — confirm callers only
  /// rely on it for evenly divided loops.
  static std::optional<int64_t>
  getConstantNumberOfIterations(scf::ForOp forOp) {
    auto lb = forOp.getLowerBound();
    auto ub = forOp.getUpperBound();
    auto step = forOp.getStep();

    IntegerAttr lbAttr, ubAttr, stepAttr;
    if (!matchPattern(lb, m_Constant(&lbAttr)))
      return std::nullopt;
    if (!matchPattern(ub, m_Constant(&ubAttr)))
      return std::nullopt;
    if (!matchPattern(step, m_Constant(&stepAttr)))
      return std::nullopt;

    int64_t lbI = lbAttr.getInt(), ubI = ubAttr.getInt(),
            stepI = stepAttr.getInt();

    return (ubI - lbI) / stepI;
  }

  /// Returns the single iteration-space bound of the loop: a compile-time
  /// constant when possible, otherwise IR computing (ub - lb) / step
  /// (unsigned floor division).
  static SmallVector<IntOrValue, 1> getDimensionBounds(OpBuilder &builder,
                                                       scf::ForOp forOp) {
    auto iters = getConstantNumberOfIterations(forOp);
    if (iters) {
      return {IntOrValue(*iters)};
    } else {
      Value lb = forOp.getLowerBound(), ub = forOp.getUpperBound(),
            step = forOp.getStep();
      Value diff = arith::SubIOp::create(builder, forOp->getLoc(), ub, lb);
      Value nSteps =
          arith::DivUIOp::create(builder, forOp->getLoc(), diff, step);
      return {IntOrValue(nSteps)};
    }
  }

  /// Rebases the induction variable to a canonical 0..N-1 (step 1) form:
  /// subtracts lb unless it is the constant 0 and divides by step unless it
  /// is the constant 1.
  static SmallVector<Value> getCanonicalLoopIVs(OpBuilder &builder,
                                                scf::ForOp forOp) {

    Value val = forOp.getBody()->getArgument(0);
    if (!matchPattern(forOp.getLowerBound(), m_Zero())) {
      val = arith::SubIOp::create(builder, forOp->getLoc(), val,
                                  forOp.getLowerBound());
    }

    if (!matchPattern(forOp.getStep(), m_One())) {
      val = arith::DivUIOp::create(builder, forOp->getLoc(), val,
                                   forOp.getStep());
    }
    return {val};
  }

  /// Builds a mapping from values of `forOp` into `otherForOp`: the provided
  /// index pairs, plus the raw induction variable when it was not already
  /// covered (valid only when both loops share lb and step, hence the
  /// asserts).
  static IRMapping createArgumentMap(PatternRewriter &rewriter,
                                     scf::ForOp forOp, ArrayRef<Value> indFor,
                                     scf::ForOp otherForOp,
                                     ArrayRef<Value> reversedOther) {
    IRMapping map;
    for (auto &&[f, o] : llvm::zip_equal(indFor, reversedOther)) {
      map.map(f, o);
    }

    Value canIdx = forOp.getBody()->getArgument(0);
    if (!map.contains(canIdx)) {
      assert(Equivalent(forOp.getLowerBound(), otherForOp.getLowerBound()));
      assert(Equivalent(forOp.getStep(), otherForOp.getStep()));
      map.map(forOp.getBody()->getArgument(0),
              otherForOp.getBody()->getArgument(0));
    }
    return map;
  }

  /// Recreates `otherForOp` with `operands` as init args (same bounds/step),
  /// stealing its body and replacing only its original results; extra results
  /// of the new loop remain available to the caller.
  static scf::ForOp replaceWithNewOperands(PatternRewriter &rewriter,
                                           scf::ForOp otherForOp,
                                           ArrayRef<Value> operands) {
    auto newOtherForOp = scf::ForOp::create(
        rewriter, otherForOp->getLoc(), otherForOp.getLowerBound(),
        otherForOp.getUpperBound(), otherForOp.getStep(), operands);

    newOtherForOp.getRegion().takeBody(otherForOp.getRegion());
    rewriter.replaceOp(otherForOp, newOtherForOp->getResults().slice(
                                       0, otherForOp->getNumResults()));
    return newOtherForOp;
  }

  /// Loop-carried init values of the scf.for.
  static ValueRange getInits(scf::ForOp forOp) { return forOp.getInitArgs(); }

  /// scf.for iterations are sequential, so no post-loop reduction is needed.
  static bool mustPostAdd(scf::ForOp forOp) { return false; }

  /// Threads the gradient through the loop body as a new block argument of
  /// the gradient's base type.
  static Value initialValueInBlock(OpBuilder &builder, Block *body,
                                   Value grad) {
    auto Ty = cast<enzyme::GradientType>(grad.getType()).getBasetype();
    return body->addArgument(Ty, grad.getLoc());
  }
};
134
135struct ForOpInterfaceReverse
136 : public ReverseAutoDiffOpInterface::ExternalModel<ForOpInterfaceReverse,
137 scf::ForOp> {
138private:
139 static Value makeIntConstant(Location loc, OpBuilder builder, int64_t val,
140 Type ty) {
141 return arith::ConstantOp::create(builder, loc, IntegerAttr::get(ty, val))
142 .getResult();
143 };
144
145 static void preserveAttributesButCheckpointing(Operation *newOp,
146 Operation *oldOp) {
147 for (auto attr : oldOp->getDiscardableAttrs()) {
148 if (attr.getName() != "enzyme.enable_checkpointing")
149 newOp->setAttr(attr.getName(), attr.getValue());
150 }
151 }
152
153 static bool needsCheckpointing(scf::ForOp forOp) {
154 return forOp->hasAttrOfType<BoolAttr>("enzyme.enable_checkpointing") &&
155 forOp->getAttrOfType<BoolAttr>("enzyme.enable_checkpointing")
156 .getValue() &&
157 ForOpEnzymeOpsRemover::getConstantNumberOfIterations(forOp)
158 .has_value();
159 }
160
161public:
162 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
163 MGradientUtilsReverse *gutils,
164 SmallVector<Value> caches) const {
165 // SCF ForOp has 3 more operands than results (lb, ub, step).
166 // Its body has 1 more argument than yielded values (the induction
167 // variable).
168
169 auto forOp = cast<scf::ForOp>(op);
170
171 SmallVector<bool> operandsActive(forOp.getNumOperands() - 3, false);
172 for (int i = 0, e = operandsActive.size(); i < e; ++i) {
173 operandsActive[i] = !gutils->isConstantValue(op->getOperand(i + 3)) ||
174 !gutils->isConstantValue(op->getResult(i));
175 }
176
177 SmallVector<Value> incomingGradients;
178 for (auto &&[active, res] :
179 llvm::zip_equal(operandsActive, op->getResults())) {
180 if (active) {
181 incomingGradients.push_back(gutils->diffe(res, builder));
182 if (!gutils->isConstantValue(res))
183 gutils->zeroDiffe(res, builder);
184 }
185 }
186
187 if (needsCheckpointing(forOp)) {
188 int64_t numIters =
189 ForOpEnzymeOpsRemover::getConstantNumberOfIterations(forOp).value();
190 int64_t nInner = std::sqrt(numIters), nOuter = nInner;
191 int64_t trailingIters = numIters - nInner * nOuter;
192
193 bool hasTrailing = trailingIters > 0;
194
195 auto numIterArgs = forOp.getNumRegionIterArgs();
196
197 SetVector<Value> outsideRefs;
198 getUsedValuesDefinedAbove(op->getRegions(), outsideRefs);
199
200 SmallVector<Value> immutableRefs;
201 SmallVector<Value> mutableRefs;
202
203 for (auto ref : outsideRefs) {
204 if (isa<ClonableTypeInterface>(ref.getType()))
205 mutableRefs.push_back(ref);
206 else
207 immutableRefs.push_back(ref);
208 }
209
210 IRMapping &mapping = gutils->originalToNewFn;
211
212 assert(outsideRefs.size() == caches.size() - numIterArgs);
213
214 for (auto [i, ref] : llvm::enumerate(immutableRefs)) {
215 Value refVal = gutils->popCache(
216 caches[numIterArgs + mutableRefs.size() + i], builder);
217 mapping.map(ref, refVal);
218 }
219
220 auto ivTy = forOp.getLowerBound().getType();
221 Value outerUB = makeIntConstant(forOp.getLowerBound().getLoc(), builder,
222 nOuter + hasTrailing, ivTy);
223 auto revOuter = scf::ForOp::create(
224 builder, op->getLoc(),
225 makeIntConstant(forOp.getLowerBound().getLoc(), builder, 0, ivTy),
226 outerUB,
227 makeIntConstant(forOp.getLowerBound().getLoc(), builder, 1, ivTy),
228 incomingGradients);
229 preserveAttributesButCheckpointing(revOuter, forOp);
230
231 OpBuilder::InsertionGuard guard(builder);
232 builder.setInsertionPointToEnd(revOuter.getBody());
233
234 SmallVector<Value> cachedOutsideRefs;
235 for (auto [i, ref] : llvm::enumerate(mutableRefs)) {
236 Value refVal = gutils->popCache(caches[numIterArgs + i], builder);
237 cachedOutsideRefs.push_back(refVal);
238 mapping.map(ref, refVal);
239 }
240
241 Location loc = forOp.getInductionVar().getLoc();
242 Value currentOuterStep = arith::SubIOp::create(
243 builder, loc, makeIntConstant(loc, builder, nOuter, ivTy),
244 revOuter.getInductionVar());
245
246 SmallVector<Value> initArgs(numIterArgs, nullptr);
247 for (size_t i = 0; i < numIterArgs; ++i) {
248 initArgs[i] = gutils->popCache(caches[i], builder);
249 }
250
251 auto nInnerCst = makeIntConstant(forOp.getLowerBound().getLoc(), builder,
252 nInner, ivTy);
253 Value zero = makeIntConstant(forOp.getLowerBound().getLoc(), builder, 0,
254 ivTy),
255 one = makeIntConstant(forOp.getLowerBound().getLoc(), builder, 1,
256 ivTy);
257
258 Value nInnerUB = nInnerCst;
259 if (trailingIters > 0) {
260 // this is the first reverse iteration
261 Location loc = forOp.getUpperBound().getLoc();
262 nInnerUB = arith::SelectOp::create(
263 builder, loc,
264 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
265 revOuter.getInductionVar(), zero),
266 makeIntConstant(loc, builder, trailingIters, ivTy), nInnerCst);
267 }
268
269 auto revInner = scf::ForOp::create(builder, forOp.getLoc(), zero,
270 nInnerUB, one, initArgs);
271 preserveAttributesButCheckpointing(revInner, forOp);
272
273 llvm::APInt stepI;
274 if (!matchPattern(forOp.getStep(), m_ConstantInt(&stepI))) {
275 op->emitError() << "step size is not known constant\n";
276 return failure();
277 }
278
279 llvm::APInt startI;
280 if (!matchPattern(forOp.getLowerBound(), m_ConstantInt(&startI))) {
281 op->emitError() << "lower bound is not known constant\n";
282 return failure();
283 }
284
285 builder.setInsertionPointToEnd(revInner.getBody());
286
287 Value currentIV = arith::AddIOp::create(
288 builder, loc,
289 arith::MulIOp::create(
290 builder, loc,
291 arith::AddIOp::create(builder, loc,
292 arith::MulIOp::create(builder, loc,
293 currentOuterStep,
294 nInnerCst),
295 revInner.getInductionVar()),
296 arith::ConstantOp::create(builder, loc,
297 IntegerAttr::get(ivTy, stepI))),
298 arith::ConstantOp::create(builder, loc,
299 IntegerAttr::get(ivTy, startI)));
300
301 for (auto [oldArg, newArg] :
302 llvm::zip_equal(forOp.getBody()->getArguments(),
303 revInner.getBody()->getArguments()))
304 mapping.map(oldArg, newArg);
305 mapping.map(forOp.getInductionVar(), currentIV);
306
307 for (auto &it : *forOp.getBody()) {
308 auto newOp = builder.clone(it, mapping);
309 gutils->originalToNewFnOps[&it] = newOp;
310 }
311
312 builder.setInsertionPointToEnd(revOuter.getBody());
313
314 for (auto outsideRef : cachedOutsideRefs) {
315 if (auto cachableT =
316 dyn_cast<ClonableTypeInterface>(outsideRef.getType())) {
317 cachableT.freeClonedValue(builder, outsideRef);
318 }
319 }
320
321 auto revLoop =
322 scf::ForOp::create(builder, forOp.getLoc(), zero, nInnerUB, one,
323 revOuter.getBody()->getArguments().drop_front());
324 preserveAttributesButCheckpointing(revLoop, forOp);
325
326 Block *revLoopBody = revLoop.getBody();
327 builder.setInsertionPointToEnd(revLoopBody);
328
329 int revIdx = 1;
330 for (auto &&[active, operand] :
331 llvm::zip_equal(operandsActive,
332 forOp.getBody()->getTerminator()->getOperands())) {
333 if (active) {
334 gutils->addToDiffe(operand, revLoopBody->getArgument(revIdx),
335 builder);
336 revIdx++;
337 }
338 }
339
340 Block *origBody = forOp.getBody();
341
342 bool valid = true;
343
344 auto first = origBody->rbegin();
345 first++; // skip terminator
346
347 auto last = origBody->rend();
348
349 for (auto it = first; it != last; ++it) {
350 Operation *op = &*it;
351 valid &= gutils->Logic.visitChild(op, builder, gutils).succeeded();
352 }
353
354 SmallVector<Value> newResults;
355 for (auto &&[active, arg] : llvm::zip_equal(
356 operandsActive, origBody->getArguments().drop_front())) {
357 if (active) {
358 newResults.push_back(gutils->diffe(arg, builder));
359 if (!gutils->isConstantValue(arg))
360 gutils->zeroDiffe(arg, builder);
361 }
362 }
363
364 builder.setInsertionPointToEnd(revLoopBody);
365 scf::YieldOp::create(builder, forOp.getBody()->getTerminator()->getLoc(),
366 newResults);
367
368 builder.setInsertionPointToEnd(revOuter.getBody());
369 scf::YieldOp::create(builder, forOp.getBody()->getTerminator()->getLoc(),
370 revLoop.getResults());
371
372 builder.setInsertionPointAfter(revOuter);
373
374 revIdx = 0;
375 for (auto &&[active, arg] : llvm::zip_equal(
376 operandsActive,
377 op->getOperands().slice(3, op->getNumOperands() - 3))) {
378 if (active) {
379 if (!gutils->isConstantValue(arg)) {
380 gutils->addToDiffe(arg, revOuter->getResult(revIdx), builder);
381 }
382 revIdx++;
383 }
384 }
385
386 return success(valid);
387 }
388
389 auto start = gutils->popCache(caches[0], builder);
390 auto end = gutils->popCache(caches[1], builder);
391 auto step = gutils->popCache(caches[2], builder);
392
393 auto repFor = scf::ForOp::create(builder, forOp.getLoc(), start, end, step,
394 incomingGradients);
395 preserveAttributesButCheckpointing(repFor, forOp);
396
397 bool valid = true;
398 for (auto &&[oldReg, newReg] :
399 llvm::zip(op->getRegions(), repFor->getRegions())) {
400 for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
401 OpBuilder bodyBuilder(&revBB, revBB.end());
402
403 // Create implicit terminator if not present (when num results > 0)
404 if (revBB.empty()) {
405 scf::YieldOp::create(bodyBuilder, repFor->getLoc());
406 }
407
408 bodyBuilder.setInsertionPointToStart(&revBB);
409 mlir::enzyme::localizeGradients(bodyBuilder, gutils, &oBB);
410
411 bodyBuilder.setInsertionPoint(revBB.getTerminator());
412
413 auto term = oBB.getTerminator();
414
415 unsigned argIdx = 1; // Skip over the reversed IV
416 for (auto &&[active, operand] :
417 llvm::zip_equal(operandsActive, term->getOperands())) {
418 if (active) {
419 // Set diffe here, not add because it should not accumulate across
420 // iterations. Instead the new gradient for this operand is passed
421 // in the return of the reverse for body.
422 gutils->setDiffe(operand, revBB.getArgument(argIdx), bodyBuilder);
423 argIdx++;
424 }
425 }
426
427 auto first = oBB.rbegin();
428 first++; // skip terminator
429
430 auto last = oBB.rend();
431
432 for (auto it = first; it != last; ++it) {
433 Operation *op = &*it;
434 valid &=
435 gutils->Logic.visitChild(op, bodyBuilder, gutils).succeeded();
436 }
437
438 SmallVector<Value> newResults;
439 newResults.reserve(incomingGradients.size());
440
441 for (auto &&[active, arg] :
442 llvm::zip_equal(operandsActive, oBB.getArguments().slice(1))) {
443 if (active) {
444 newResults.push_back(gutils->diffe(arg, bodyBuilder));
445 if (!gutils->isConstantValue(arg))
446 gutils->zeroDiffe(arg, bodyBuilder);
447 }
448 }
449
450 // yield new gradient values
451 revBB.getTerminator()->setOperands(newResults);
452 }
453 }
454
455 unsigned resIdx = 0;
456 for (auto &&[active, arg] :
457 llvm::zip_equal(operandsActive, forOp.getInitArgs())) {
458 if (active) {
459 if (!gutils->isConstantValue(arg)) {
460 gutils->addToDiffe(arg, repFor.getResult(resIdx), builder);
461 resIdx++;
462 }
463 }
464 }
465
466 return success(valid);
467 }
468
469 SmallVector<Value> cacheValues(Operation *op,
470 MGradientUtilsReverse *gutils) const {
471 auto forOp = cast<scf::ForOp>(op);
472 Operation *newOp = gutils->getNewFromOriginal(op);
473 OpBuilder cacheBuilder(newOp);
474
475 if (needsCheckpointing(forOp)) {
476 int64_t numIters =
477 ForOpEnzymeOpsRemover::getConstantNumberOfIterations(forOp).value();
478 int64_t nInner = std::sqrt(numIters), nOuter = nInner;
479 int64_t trailingIters = numIters - nInner * nOuter;
480 bool hasTrailing = trailingIters > 0;
481
482 SetVector<Value> outsideRefs;
483 getUsedValuesDefinedAbove(op->getRegions(), outsideRefs);
484
485 SmallVector<Value> immutableRefs;
486 SmallVector<Value> mutableRefs;
487
488 for (auto ref : outsideRefs) {
489 if (isa<ClonableTypeInterface>(ref.getType()))
490 mutableRefs.push_back(ref);
491 else
492 immutableRefs.push_back(ref);
493 }
494
495 SmallVector<Value> caches;
496
497 scf::ForOp newForOp = cast<scf::ForOp>(gutils->getNewFromOriginal(op));
498
499 Type ty = forOp.getLowerBound().getType();
500 auto outerFwd = scf::ForOp::create(
501 cacheBuilder, op->getLoc(),
502 makeIntConstant(forOp.getLowerBound().getLoc(), cacheBuilder, 0, ty),
503 makeIntConstant(forOp.getUpperBound().getLoc(), cacheBuilder,
504 nInner * (nOuter + hasTrailing), ty),
505 makeIntConstant(forOp.getStep().getLoc(), cacheBuilder, nInner, ty),
506 newForOp.getInitArgs());
507 preserveAttributesButCheckpointing(outerFwd, forOp);
508
509 cacheBuilder.setInsertionPointToStart(outerFwd.getBody());
510 auto nInnerCst = makeIntConstant(forOp.getUpperBound().getLoc(),
511 cacheBuilder, nInner, ty);
512
513 Value nInnerUB = nInnerCst;
514 if (trailingIters > 0) {
515 // if this is the last iteration, then the inner
516 // loop will only make trailingIters iterations
517 Location loc = forOp.getUpperBound().getLoc();
518 nInnerUB = arith::SelectOp::create(
519 cacheBuilder, loc,
520 arith::CmpIOp::create(
521 cacheBuilder, loc, arith::CmpIPredicate::eq,
522 outerFwd.getInductionVar(),
523 makeIntConstant(loc, cacheBuilder, nInner * nOuter, ty)),
524 makeIntConstant(loc, cacheBuilder, trailingIters, ty), nInnerCst);
525 }
526
527 IRMapping &mapping = gutils->originalToNewFn;
528
529 SmallVector<Value> mutableRefsCaches;
530 for (auto ref : mutableRefs) {
531 auto iface = cast<ClonableTypeInterface>(ref.getType());
532 auto clone =
533 iface.cloneValue(cacheBuilder, mapping.lookupOrDefault(ref));
534 mutableRefsCaches.push_back(
535 gutils->initAndPushCache(clone, cacheBuilder));
536 }
537
538 auto innerFwd = scf::ForOp::create(
539 cacheBuilder, op->getLoc(),
540 makeIntConstant(forOp.getLowerBound().getLoc(), cacheBuilder, 0, ty),
541 nInnerUB,
542 makeIntConstant(forOp.getStep().getLoc(), cacheBuilder, 1, ty),
543 outerFwd.getBody()->getArguments().drop_front());
544 preserveAttributesButCheckpointing(innerFwd, forOp);
545
546 cacheBuilder.setInsertionPointToEnd(innerFwd.getBody());
547
548 Location loc = forOp.getInductionVar().getLoc();
549 auto currentIV = arith::MulIOp::create(
550 cacheBuilder, loc,
551 arith::AddIOp::create(
552 cacheBuilder, loc,
553 arith::MulIOp::create(cacheBuilder, loc,
554 outerFwd.getInductionVar(), nInnerCst),
555 innerFwd.getInductionVar()),
556 newForOp.getStep());
557
558 for (auto [oldArg, newArg] :
559 llvm::zip_equal(forOp.getBody()->getArguments(),
560 innerFwd.getBody()->getArguments()))
561 mapping.map(oldArg, newArg);
562 mapping.map(forOp.getInductionVar(), currentIV);
563
564 for (auto &it : *forOp.getBody())
565 cacheBuilder.clone(it, mapping);
566
567 cacheBuilder.setInsertionPointToEnd(outerFwd.getBody());
568 for (auto initArg : innerFwd.getInitArgs())
569 caches.push_back(gutils->initAndPushCache(initArg, cacheBuilder));
570
571 scf::YieldOp::create(cacheBuilder,
572 forOp.getBody()->getTerminator()->getLoc(),
573 innerFwd->getResults());
574
575 cacheBuilder.setInsertionPointAfter(outerFwd);
576
577 caches.append(mutableRefsCaches);
578
579 for (auto ref : immutableRefs)
580 caches.push_back(gutils->initAndPushCache(mapping.lookupOrDefault(ref),
581 cacheBuilder));
582
583 gutils->replaceOrigOpWith(op, outerFwd.getResults());
584 gutils->erase(newForOp);
585 gutils->originalToNewFnOps[op] = outerFwd;
586
587 // caches is composed of:
588 // [
589 // <caches of iter args>...,
590 // <caches of mutable values>...,
591 // <caches of immutable values>...,
592 // ]
593 //
594 // TODO: we don't need to cache refs of arith.constants
595 // .... which we can "clone" just before the inner forward
596 // .... in the reverse pass.
597 // .... create an interface that mincut can also use?
598
599 return caches;
600 }
601
602 SmallVector<Value> caches;
603
604 Value cacheLB = gutils->initAndPushCache(
605 gutils->getNewFromOriginal(forOp.getLowerBound()), cacheBuilder);
606 caches.push_back(cacheLB);
607
608 Value cacheUB = gutils->initAndPushCache(
609 gutils->getNewFromOriginal(forOp.getUpperBound()), cacheBuilder);
610 caches.push_back(cacheUB);
611
612 Value cacheStep = gutils->initAndPushCache(
613 gutils->getNewFromOriginal(forOp.getStep()), cacheBuilder);
614 caches.push_back(cacheStep);
615
616 return caches;
617 }
618
619 void createShadowValues(Operation *op, OpBuilder &builder,
620 MGradientUtilsReverse *gutils) const {
621 // auto forOp = cast<scf::ForOp>(op);
622 }
623};
624
/// CRTP hooks used by ForLikeEnzymeOpsRemover to rewrite enzyme push/pop
/// pairs around scf.parallel loops (one entry per parallel dimension).
struct ParallelOpEnzymeOpsRemover
    : public ForLikeEnzymeOpsRemover<ParallelOpEnzymeOpsRemover,
                                     scf::ParallelOp> {
  /// Trip count of one dimension when lb/ub/step are all arith constants,
  /// std::nullopt otherwise.
  ///
  /// NOTE(review): floor((ub - lb) / step); agrees with the actual iteration
  /// count only when (ub - lb) divides evenly by step — same assumption as
  /// the scf.for variant above.
  static std::optional<int64_t>
  getConstantNumberOfIterations(Value lb, Value ub, Value step) {
    IntegerAttr lbAttr, ubAttr, stepAttr;
    if (!matchPattern(lb, m_Constant(&lbAttr)))
      return std::nullopt;
    if (!matchPattern(ub, m_Constant(&ubAttr)))
      return std::nullopt;
    if (!matchPattern(step, m_Constant(&stepAttr)))
      return std::nullopt;

    int64_t lbI = lbAttr.getInt(), ubI = ubAttr.getInt(),
            stepI = stepAttr.getInt();
    return (ubI - lbI) / stepI;
  }

  /// One bound per parallel dimension: a constant when derivable, otherwise
  /// IR computing (ub - lb) / step with unsigned division.
  static SmallVector<IntOrValue, 1> getDimensionBounds(OpBuilder &builder,
                                                       scf::ParallelOp parOp) {
    SmallVector<IntOrValue, 1> bounds;
    bounds.reserve(parOp.getNumLoops());
    for (auto &&[lb, ub, step] : llvm::zip_equal(
             parOp.getLowerBound(), parOp.getUpperBound(), parOp.getStep())) {
      auto iters = getConstantNumberOfIterations(lb, ub, step);
      if (iters) {
        bounds.push_back(IntOrValue(*iters));
      } else {
        Value diff = arith::SubIOp::create(builder, parOp.getLoc(), ub, lb);
        Value nSteps =
            arith::DivUIOp::create(builder, parOp.getLoc(), diff, step);
        bounds.push_back(IntOrValue(nSteps));
      }
    }
    return bounds;
  }

  /// Parallel iterations are order-independent, so the "reversed" indices are
  /// simply the other loop's induction variables unchanged.
  static SmallVector<Value>
  computeReversedIndices(PatternRewriter &rewriter, scf::ParallelOp parOp,
                         ArrayRef<Value> otherInductionVariable,
                         ArrayRef<IntOrValue> bounds) {
    return SmallVector<Value>(otherInductionVariable);
  }

  /// Rebases each induction variable to a canonical 0..N-1 (step 1) form,
  /// skipping the subtraction/division when lb is 0 / step is 1.
  static SmallVector<Value> getCanonicalLoopIVs(OpBuilder &builder,
                                                scf::ParallelOp parOp) {
    SmallVector<Value> canonicalIVs;
    canonicalIVs.reserve(parOp.getNumLoops());
    for (auto &&[iv, lb, step] :
         llvm::zip_equal(parOp.getInductionVars(), parOp.getLowerBound(),
                         parOp.getStep())) {
      Value val = iv;
      if (!matchPattern(lb, m_Zero())) {
        val = arith::SubIOp::create(builder, parOp.getLoc(), val, lb);
      }

      if (!matchPattern(step, m_One())) {
        val = arith::DivUIOp::create(builder, parOp.getLoc(), val, step);
      }
      canonicalIVs.push_back(val);
    }
    return canonicalIVs;
  }

  /// Maps values of `parOp` into `otherParOp`: the provided index pairs plus
  /// any raw induction variable not already covered (the asserts check that
  /// the identity mapping is valid, i.e. both loops share lb and step).
  static IRMapping createArgumentMap(PatternRewriter &rewriter,
                                     scf::ParallelOp parOp,
                                     ArrayRef<Value> indPar,
                                     scf::ParallelOp otherParOp,
                                     ArrayRef<Value> indOther) {
    IRMapping map;
    for (auto &&[f, o] : llvm::zip_equal(indPar, indOther))
      map.map(f, o);

    for (auto &&[iv, oiv, lb, olb, step, ostep] : llvm::zip_equal(
             parOp.getInductionVars(), otherParOp.getInductionVars(),
             parOp.getLowerBound(), otherParOp.getLowerBound(), parOp.getStep(),
             otherParOp.getStep())) {
      if (!map.contains(iv)) {
        assert(Equivalent(lb, olb));
        assert(Equivalent(step, ostep));
        map.map(iv, oiv);
      }
    }
    return map;
  }

  /// Recreates `otherParallelOp` with `operands` as init values (same
  /// bounds/steps), stealing its body. When reduction operands are present,
  /// replaces the terminator with an scf.reduce whose regions combine values
  /// via the type's AutoDiffTypeInterface add.
  static scf::ParallelOp replaceWithNewOperands(PatternRewriter &rewriter,
                                                scf::ParallelOp otherParallelOp,
                                                ArrayRef<Value> operands) {
    auto newOtherParOp = scf::ParallelOp::create(
        rewriter, otherParallelOp.getLoc(), otherParallelOp.getLowerBound(),
        otherParallelOp.getUpperBound(), otherParallelOp.getStep(), operands);

    newOtherParOp.getRegion().takeBody(otherParallelOp.getRegion());
    rewriter.replaceOp(
        otherParallelOp,
        newOtherParOp.getResults().slice(0, otherParallelOp.getNumResults()));

    if (operands.size() >= 1) {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *oldTerm = newOtherParOp.getBody()->getTerminator();
      rewriter.setInsertionPointToEnd(newOtherParOp.getBody());
      auto term = scf::ReduceOp::create(rewriter, newOtherParOp.getLoc(),
                                        oldTerm->getOperands());

      // One reduction region per operand; each adds its two block arguments.
      for (auto [reg, operand] :
           llvm::zip_equal(term->getRegions(), operands)) {
        Block *b = &reg.front();
        rewriter.setInsertionPointToEnd(b);

        auto Ty = cast<AutoDiffTypeInterface>(operand.getType());
        Value reduced = Ty.createAddOp(rewriter, operand.getLoc(),
                                       b->getArgument(0), b->getArgument(1));
        scf::ReduceReturnOp::create(rewriter, reduced.getLoc(), reduced);
      }

      oldTerm->erase();
    }

    return newOtherParOp;
  }

  /// Reduction init values of the scf.parallel.
  static ValueRange getInits(scf::ParallelOp parallelOp) {
    return parallelOp.getInitVals();
  }

  /// Accumulation is handled by scf.reduce regions, not by a post-loop add.
  static bool mustPostAdd(scf::ParallelOp forOp) { return false; }

  /// Iterations run concurrently, so the gradient cannot be threaded through
  /// as an iter arg; start each block from the type's null (zero) value
  /// instead.
  static Value initialValueInBlock(OpBuilder &builder, Block *body,
                                   Value grad) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(body);
    return cast<AutoDiffTypeInterface>(
               cast<enzyme::GradientType>(grad.getType()).getBasetype())
        .createNullValue(builder, grad.getLoc());
  }
};
762
763struct ParallelOpInterfaceReverse
764 : public ReverseAutoDiffOpInterface::ExternalModel<
765 ParallelOpInterfaceReverse, scf::ParallelOp> {
766 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
767 MGradientUtilsReverse *gutils,
768 SmallVector<Value> caches) const {
769 auto parallelOp = cast<scf::ParallelOp>(op);
770 if (parallelOp.getNumReductions() != 0) {
771 return parallelOp.emitError()
772 << "parallel reductions not yet implemented\n";
773 }
774
775 unsigned loopCount = parallelOp.getNumLoops();
776 SmallVector<Value> bounds = llvm::map_to_vector(
777 caches, [&](Value cache) { return gutils->popCache(cache, builder); });
778
779 auto revPar = scf::ParallelOp::create(
780 builder, op->getLoc(),
781 /*lowerBounds=*/ValueRange(bounds).slice(0, loopCount),
782 /*upperBounds=*/ValueRange(bounds).slice(loopCount, loopCount),
783 /*steps=*/ValueRange(bounds).slice(loopCount * 2, loopCount));
784
785 bool valid = true;
786 bool wasAtomic = gutils->AtomicAdd;
787 gutils->AtomicAdd = true;
788
789 {
790 Block *oBB = parallelOp.getBody();
791 Block *revBB = revPar.getBody();
792
793 OpBuilder bodyBuilder(revBB, revBB->end());
794
795 bodyBuilder.setInsertionPointToStart(revBB);
796 mlir::enzyme::localizeGradients(bodyBuilder, gutils, oBB);
797
798 bodyBuilder.setInsertionPoint(revBB->getTerminator());
799
800 auto first = oBB->rbegin();
801 first++; // skip terminator
802
803 auto last = oBB->rend();
804
805 for (auto it = first; it != last; ++it) {
806 Operation *op = &*it;
807 valid &= gutils->Logic.visitChild(op, bodyBuilder, gutils).succeeded();
808 }
809 }
810
811 gutils->AtomicAdd = wasAtomic;
812 return success(valid);
813 }
814
815 SmallVector<Value> cacheValues(Operation *op,
816 MGradientUtilsReverse *gutils) const {
817 auto parallelOp = cast<scf::ParallelOp>(op);
818 Operation *newOp = gutils->getNewFromOriginal(op);
819 OpBuilder cacheBuilder(newOp);
820 SmallVector<Value> caches;
821 for (Value lb : parallelOp.getLowerBound())
822 caches.push_back(gutils->initAndPushCache(gutils->getNewFromOriginal(lb),
823 cacheBuilder));
824 for (Value ub : parallelOp.getUpperBound())
825 caches.push_back(gutils->initAndPushCache(gutils->getNewFromOriginal(ub),
826 cacheBuilder));
827 for (Value step : parallelOp.getStep())
828 caches.push_back(gutils->initAndPushCache(
829 gutils->getNewFromOriginal(step), cacheBuilder));
830
831 return caches;
832 }
833
834 void createShadowValues(Operation *op, OpBuilder &builder,
835 MGradientUtilsReverse *gutils) const {}
836};
837
838struct IfOpEnzymeOpsRemover
839 : public IfLikeEnzymeOpsRemover<IfOpEnzymeOpsRemover, scf::IfOp> {
840 static Block *getThenBlock(scf::IfOp ifOp, OpBuilder &builder) {
841 return ifOp.thenBlock();
842 }
843
844 static Block *getElseBlock(scf::IfOp ifOp, OpBuilder &builder) {
845 // Ensure the if has an else block
846 if (ifOp.getElseRegion().empty()) {
847 OpBuilder::InsertionGuard guard(builder);
848 Block &newBlock = ifOp.getElseRegion().emplaceBlock();
849 builder.setInsertionPointToStart(&newBlock);
850 scf::YieldOp::create(builder, ifOp.getLoc());
851 }
852
853 return ifOp.elseBlock();
854 }
855
856 static Value getDummyValue(OpBuilder &builder, Location loc, Type dummyType) {
857 return cast<AutoDiffTypeInterface>(dummyType).createNullValue(builder, loc);
858 }
859
860 static scf::IfOp replace(PatternRewriter &rewriter, scf::IfOp otherIfOp,
861 TypeRange resultTypes) {
862 auto newIf = scf::IfOp::create(rewriter, otherIfOp->getLoc(), resultTypes,
863 otherIfOp.getCondition());
864
865 newIf.getThenRegion().takeBody(otherIfOp.getThenRegion());
866 newIf.getElseRegion().takeBody(otherIfOp.getElseRegion());
867
868 rewriter.replaceAllUsesWith(
869 otherIfOp->getResults(),
870 newIf->getResults().slice(0, otherIfOp->getNumResults()));
871 rewriter.eraseOp(otherIfOp);
872 return newIf;
873 }
874};
875
876struct IfOpInterfaceReverse
877 : public ReverseAutoDiffOpInterface::ExternalModel<IfOpInterfaceReverse,
878 scf::IfOp> {
879 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
880 MGradientUtilsReverse *gutils,
881 SmallVector<Value> caches) const {
882 auto ifOp = cast<scf::IfOp>(op);
883 bool hasElse = ifOp.elseBlock() != nullptr;
884 Value cond = gutils->popCache(caches[0], builder);
885
886 SmallVector<bool> resultsActive(ifOp.getNumResults(), false);
887 for (int i = 0, e = resultsActive.size(); i < e; ++i) {
888 resultsActive[i] = !gutils->isConstantValue(ifOp.getResult(i));
889 }
890
891 SmallVector<Value> incomingGradients;
892 for (auto &&[active, res] :
893 llvm::zip_equal(resultsActive, ifOp.getResults())) {
894 if (active) {
895 incomingGradients.push_back(gutils->diffe(res, builder));
896 if (!gutils->isConstantValue(res))
897 gutils->zeroDiffe(res, builder);
898 }
899 }
900
901 auto revIf =
902 scf::IfOp::create(builder, ifOp.getLoc(), TypeRange{}, cond, hasElse);
903 bool valid = true;
904 for (auto &&[oldReg, newReg] :
905 llvm::zip(op->getRegions(), revIf->getRegions())) {
906 for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
907 OpBuilder bodyBuilder(&revBB, revBB.end());
908 bodyBuilder.setInsertionPoint(revBB.getTerminator());
909
910 // All values defined in the body should have no use outside this
911 // block therefore we can set their diffe to zero upon entering the
912 // reverse block to simplify the work of the
913 // remove-unnecessary-enzyme-ops pass.
914 for (auto &it : oBB.getOperations()) {
915 for (auto res : it.getResults()) {
916 if (!gutils->isConstantValue(res)) {
917 auto iface = dyn_cast<AutoDiffTypeInterface>(res.getType());
918 if (iface && !iface.isMutable())
919 gutils->zeroDiffe(res, bodyBuilder);
920 }
921 }
922 }
923
924 auto term = oBB.getTerminator();
925 // Align incomingGradients with their corresponding yield operands.
926 SmallVector<Value> activeTermOperands;
927 activeTermOperands.reserve(incomingGradients.size());
928 for (auto &&[resultActive, operand] :
929 llvm::zip_equal(resultsActive, term->getOperands())) {
930 if (resultActive)
931 activeTermOperands.push_back(operand);
932 }
933
934 for (auto &&[arg, operand] :
935 llvm::zip_equal(incomingGradients, activeTermOperands)) {
936 // Check activity of the argument separately from the result. If
937 // some branches yield inactive values while others yield active
938 // values, the result will be active, but this operand may still be
939 // inactive (and we cannot addToDiffe)
940 if (!gutils->isConstantValue(operand)) {
941 gutils->addToDiffe(operand, arg, bodyBuilder);
942 }
943 }
944
945 auto first = oBB.rbegin();
946 first++; // skip terminator
947
948 auto last = oBB.rend();
949
950 for (auto it = first; it != last; ++it) {
951 Operation *op = &*it;
952 valid &=
953 gutils->Logic.visitChild(op, bodyBuilder, gutils).succeeded();
954 }
955 }
956 }
957 return success(valid);
958 }
959
960 SmallVector<Value> cacheValues(Operation *op,
961 MGradientUtilsReverse *gutils) const {
962 auto ifOp = cast<scf::IfOp>(op);
963
964 Operation *newOp = gutils->getNewFromOriginal(op);
965 OpBuilder cacheBuilder(newOp);
966 Value cacheCond = gutils->initAndPushCache(
967 gutils->getNewFromOriginal(ifOp.getCondition()), cacheBuilder);
968 return SmallVector<Value>{cacheCond};
969 }
970
  // Intentionally a no-op: this scf.if reverse model creates no shadow
  // values up front (gradients for results are handled in the adjoint
  // generation above).
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}
973};
974
975struct ForOpADDataFlow
976 : public ADDataFlowOpInterface::ExternalModel<ForOpADDataFlow, scf::ForOp> {
977 SmallVector<Value> getPotentialIncomingValuesRes(Operation *op,
978 OpResult res) const {
979 auto forOp = cast<scf::ForOp>(op);
980 return {
981 forOp->getOperand(res.getResultNumber() + 3),
982 forOp.getBody()->getTerminator()->getOperand(res.getResultNumber())};
983 }
984 SmallVector<Value> getPotentialIncomingValuesArg(Operation *op,
985 BlockArgument arg) const {
986 auto forOp = cast<scf::ForOp>(op);
987 if (arg.getArgNumber() < forOp.getNumInductionVars())
988 return {};
989 auto idx = arg.getArgNumber() - forOp.getNumInductionVars();
990 return {forOp->getOperand(idx + 3),
991 forOp.getBody()->getTerminator()->getOperand(idx)};
992 }
993 SmallVector<Value> getPotentialTerminatorUsers(Operation *op, Operation *term,
994 Value val) const {
995 auto forOp = cast<scf::ForOp>(op);
996 SmallVector<Value> sv;
997
998 for (auto &&[res, arg, barg] :
999 llvm::zip_equal(forOp->getResults(), term->getOperands(),
1000 forOp.getRegionIterArgs())) {
1001 if (arg == val) {
1002 sv.push_back(res);
1003 sv.push_back(barg);
1004 }
1005 }
1006
1007 return sv;
1008 }
1009};
1010
1011struct ParallelOpADDataFlow
1012 : public ADDataFlowOpInterface::ExternalModel<ParallelOpADDataFlow,
1013 scf::ParallelOp> {
1014 SmallVector<Value> getPotentialIncomingValuesRes(Operation *op,
1015 OpResult res) const {
1016 auto parOp = cast<scf::ParallelOp>(op);
1017 const size_t num_lower = parOp.getLowerBound().size();
1018 const size_t num_upper = parOp.getUpperBound().size();
1019 const size_t num_step = parOp.getStep().size();
1020 const size_t init_vals_offset = num_lower + num_upper + num_step;
1021 return {parOp->getOperand(res.getResultNumber() + init_vals_offset),
1022 parOp.getBody()
1023 ->getTerminator()
1024 ->getRegion(res.getResultNumber())
1025 .front()
1026 .getTerminator()
1027 ->getOperand(0)};
1028 }
1029 SmallVector<Value> getPotentialIncomingValuesArg(Operation *op,
1030 BlockArgument arg) const {
1031 // TO DO: do we need this?
1032 assert(0);
1033 return SmallVector<Value>();
1034 }
1035 SmallVector<Value> getPotentialTerminatorUsers(Operation *op, Operation *term,
1036 Value val) const {
1037 SmallVector<Value> sv;
1038
1039 for (auto [idx, arg] : llvm::enumerate(term->getOperands())) {
1040 if (arg == val) {
1041 sv.push_back(term->getRegion(idx).front().getArgument(0));
1042 }
1043 }
1044
1045 return sv;
1046 }
1047};
1048
// AD dataflow model for scf.reduce. An scf.reduce has no SSA results of its
// own: the value produced by its i-th reducer region becomes the i-th result
// of the enclosing scf.parallel, so dataflow queries are answered in terms of
// the parent parallel op's operands and results.
struct ReduceOpADDataFlow
    : public ADDataFlowOpInterface::ExternalModel<ReduceOpADDataFlow,
                                                  scf::ReduceOp> {
  SmallVector<Value> getPotentialIncomingValuesRes(Operation *op,
                                                   OpResult res) const {
    // ReduceOp's have no results
    return SmallVector<Value>();
  }
  SmallVector<Value> getPotentialIncomingValuesArg(Operation *op,
                                                   BlockArgument arg) const {
    // The op here is the parent of the block, which is a ReduceOp
    // All but the last block arguments match up with the corresponding operand
    // of the reduce op. The last matches up with terminator operand as well as
    // the initial value. If this is the ith block, it is the ith initial value

    auto redOp = cast<scf::ReduceOp>(op);
    mlir::Block *ownerBlock = arg.getOwner();
    auto num_args = ownerBlock->getNumArguments();
    auto arg_idx = arg.getArgNumber();
    // Which reducer region of the scf.reduce this block belongs to; it
    // selects both the parallel op's init operand and the reduce operand.
    auto region_idx = ownerBlock->getParent()->getRegionNumber();
    if (arg_idx == num_args - 1) {
      // Last block argument: fed by the region_idx-th init value of the
      // enclosing scf.parallel (the init operands follow the lower-bound,
      // upper-bound, and step operand segments) or by the value yielded by
      // this region's scf.reduce.return.
      auto parOp = cast<scf::ParallelOp>(redOp->getParentOp());
      auto num_lb = parOp.getLowerBound().size();
      auto num_ub = parOp.getUpperBound().size();
      auto num_st = parOp.getStep().size();
      return {parOp->getOperand(num_lb + num_ub + num_st + region_idx),
              ownerBlock->getTerminator()->getOperand(0)};
    } else {
      // Non-last block arguments are fed by the reduce op's own operand for
      // this region.
      return {redOp->getOperand(region_idx)};
    }
  }
  SmallVector<Value> getPotentialTerminatorUsers(Operation *op, Operation *term,
                                                 Value val) const {
    auto redOp = cast<scf::ReduceOp>(op);
    auto parOp = cast<scf::ParallelOp>(redOp->getParentOp());
    mlir::Block *ownerBlock = term->getBlock();
    auto region_idx = ownerBlock->getParent()->getRegionNumber();

    // The reduced value feeds the matching scf.parallel result and the
    // reducer block's second argument.
    // NOTE(review): getArgument(1) assumes each reducer block has exactly two
    // arguments (accumulator, incoming) — confirm against the scf.reduce
    // region signature.
    return {parOp->getResult(region_idx), ownerBlock->getArgument(1)};
  }
};
1090
1091class SCFReduceAutoDiffOpInterface
1092 : public AutoDiffOpInterface::ExternalModel<SCFReduceAutoDiffOpInterface,
1093 scf::ReduceOp> {
1094public:
1095 LogicalResult createForwardModeTangent(Operation *origTerminator,
1096 OpBuilder &builder,
1097 MGradientUtils *gutils) const {
1098 auto parentOp = origTerminator->getParentOp();
1099 if (!isa<scf::ParallelOp>(parentOp)) {
1100 origTerminator->emitError()
1101 << " createForwardModeTangent called with invalid parent" << *parentOp
1102 << "\n";
1103 return failure();
1104 }
1105
1106 // Note, this works for scf::ReduceOp because it has the same number of
1107 // operands as the parent (scf::ParallelOp) has results
1108 assert(parentOp->getNumResults() == origTerminator->getNumOperands());
1109 llvm::SmallDenseSet<unsigned> operandsToShadow;
1110 for (auto res : parentOp->getResults()) {
1111 if (!gutils->isConstantValue(res))
1112 operandsToShadow.insert(res.getResultNumber());
1113 }
1114
1115 SmallVector<Value> newOperands;
1116 newOperands.reserve(origTerminator->getNumOperands() +
1117 operandsToShadow.size());
1118 for (OpOperand &operand : origTerminator->getOpOperands()) {
1119 newOperands.push_back(gutils->getNewFromOriginal(operand.get()));
1120 if (operandsToShadow.contains(operand.getOperandNumber()))
1121 newOperands.push_back(gutils->invertPointerM(operand.get(), builder));
1122 }
1123
1124 // Assuming shadows following the originals are fine.
1125 // TODO: consider extending to have a ShadowableTerminatorOpInterface
1126 Operation *replTerminator = gutils->getNewFromOriginal(origTerminator);
1127 replTerminator->setOperands(newOperands);
1128
1129 // Differentiate the body of the reducer
1130 for (auto &origRegion : origTerminator->getRegions()) {
1131 for (auto &origBlock : origRegion) {
1132 for (Operation &o : origBlock) {
1133 if (failed(gutils->visitChild(&o))) {
1134 replTerminator->emitError() << " Differentiating reducer block "
1135 << *replTerminator << " failed!\n";
1136 }
1137 }
1138 }
1139 }
1140
1141 // Delete the primal operations in each differentiated reducer block by
1142 // building a map of the operations that are ultimately used by starting
1143 // from the shadow operands of the terminator (scf::ReduceReturnOp). Then
1144 // erase all of the operations that aren't used. Note that from above, all
1145 // operands for the terminator are shadow operands.
1146 for (auto &region : replTerminator->getRegions()) {
1147 for (auto &block : region) {
1148 std::map<Operation *, bool> used;
1149 std::vector<Operation *> op_list;
1150
1151 // Initialize all operations as not used
1152 for (Operation &o : block) {
1153 used[&o] = false;
1154 op_list.push_back(&o);
1155 }
1156
1157 // Recursively mark operations that are used starting from the
1158 // terminator
1159 auto mark_used = [&used](const auto &self, Operation *op) -> void {
1160 if (op != nullptr) {
1161 assert(used.find(op) != used.end());
1162 used[op] = true;
1163 for (auto v : op->getOperands())
1164 self(self, v.getDefiningOp());
1165 }
1166 };
1167 mark_used(mark_used, block.getTerminator());
1168
1169 // Delete the unused operations squentially, starting from the last so
1170 // that all users of an operation are erased before the operation itself
1171 for (auto it = op_list.rbegin(); it != op_list.rend(); ++it) {
1172 if (!used[*it]) {
1173 (*it)->erase();
1174 }
1175 }
1176
1177 // Delete the primal arguments from the block. We have to go backwards
1178 // starting from the second-to-last as the args will shift forward after
1179 // erasing.
1180 for (int i = block.getNumArguments() - 2; i >= 0; i -= 2) {
1181 block.eraseArgument(i);
1182 }
1183 }
1184 }
1185
1186 // Create a new terminator combining the regions of differentiated and
1187 // original terminators. We clone the original region so that it still
1188 // exists for the undifferentiated reducer but we can take the region from
1189 // the originally differentiated one because we delete it later
1190 mlir::OpBuilder term_builder(replTerminator);
1191 mlir::IRMapping mapper;
1192 OperationState state(replTerminator->getLoc(),
1193 scf::ReduceOp::getOperationName());
1194 state.addOperands(newOperands);
1195 size_t num_regions = origTerminator->getNumRegions();
1196 for (size_t i = 0; i < num_regions; ++i) {
1197 Region *new_orig_region = state.addRegion();
1198 Region *new_diff_region = state.addRegion();
1199 origTerminator->getRegion(i).cloneInto(new_orig_region, mapper);
1200 new_diff_region->takeBody(replTerminator->getRegion(i));
1201 }
1202 Operation *new_terminator_op = term_builder.create(state);
1203 gutils->erase(replTerminator);
1204 gutils->originalToNewFnOps[origTerminator] = new_terminator_op;
1205
1206 return success();
1207 }
1208};
1209
1210class SCFReduceReturnAutoDiffOpInterface
1211 : public AutoDiffOpInterface::ExternalModel<
1212 SCFReduceReturnAutoDiffOpInterface, scf::ReduceReturnOp> {
1213public:
1214 LogicalResult createForwardModeTangent(Operation *origTerminator,
1215 OpBuilder &builder,
1216 MGradientUtils *gutils) const {
1217 auto parentOp = origTerminator->getParentOp();
1218 if (!isa<scf::ReduceOp>(parentOp)) {
1219 origTerminator->emitError()
1220 << " createForwardModeTangent called with invalid parent" << *parentOp
1221 << "\n";
1222 return failure();
1223 }
1224
1225 // ReduceOp has no direct results, instead the result of the ith reducer
1226 // block within the ReduceOp matches up with the ith result of the parent
1227 // ParallelOp of the ReduceOp. Therefore the terminator must have exactly 1
1228 // operand and we will shadow it
1229 auto reducer_index =
1230 origTerminator->getBlock()->getParent()->getRegionNumber();
1231 assert(reducer_index < parentOp->getParentOp()->getNumResults());
1232 assert(origTerminator->getNumOperands() == 1);
1233 llvm::SmallDenseSet<unsigned> operandsToShadow;
1234 if (!gutils->isConstantValue(
1235 parentOp->getParentOp()->getResult(reducer_index)))
1236 operandsToShadow.insert(0);
1237
1238 // For scf::ReduceReturnOp only add the
1239 // shadows as operands since the primal reducer will be in a different
1240 // region with its own scf::ReduceReturnOp
1241 SmallVector<Value> newOperands;
1242 newOperands.reserve(operandsToShadow.size());
1243 for (OpOperand &operand : origTerminator->getOpOperands()) {
1244 if (operandsToShadow.contains(operand.getOperandNumber()))
1245 newOperands.push_back(gutils->invertPointerM(operand.get(), builder));
1246 }
1247
1248 // Special handling for scf::ReduceOp where the assumption that shadows
1249 // follow originals is violated. Here the shadow operations need to be put
1250 // in a shadow region. It isn't clear how to do that directly, so instead
1251 // we will create the shadows as normal and then create a new scf::ReduceOp
1252 // terminator that combines the regions from the original and
1253 // differentiated. We then erase the primal operations from the derivative
1254 // reducer region(s).
1255 Operation *replTerminator = gutils->getNewFromOriginal(origTerminator);
1256 replTerminator->setOperands(newOperands);
1257
1258 return success();
1259 }
1260};
1261
1262} // namespace
1263
1265 DialectRegistry &registry) {
1266 registry.addExtension(+[](MLIRContext *context, scf::SCFDialect *) {
1267 registerInterfaces(context);
1268 scf::IfOp::attachInterface<IfOpInterfaceReverse>(*context);
1269 scf::IfOp::attachInterface<IfOpEnzymeOpsRemover>(*context);
1270 scf::ParallelOp::attachInterface<ParallelOpInterfaceReverse>(*context);
1271 scf::ParallelOp::attachInterface<ParallelOpEnzymeOpsRemover>(*context);
1272 scf::ParallelOp::attachInterface<ParallelOpADDataFlow>(*context);
1273 scf::ReduceOp::attachInterface<ReduceOpADDataFlow>(*context);
1274 scf::ReduceOp::attachInterface<SCFReduceAutoDiffOpInterface>(*context);
1275 scf::ReduceReturnOp::attachInterface<SCFReduceReturnAutoDiffOpInterface>(
1276 *context);
1277 scf::ForOp::attachInterface<ForOpInterfaceReverse>(*context);
1278 scf::ForOp::attachInterface<ForOpEnzymeOpsRemover>(*context);
1279 scf::ForOp::attachInterface<ForOpADDataFlow>(*context);
1280 });
1281}
Operation * clone(Operation *src, IRMapping &mapper, Operation::CloneOptions options, std::map< Operation *, Operation * > &opMap)
static std::optional< SmallVector< Value > > getPotentialTerminatorUsers(Operation *op, Value parent)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder)
LogicalResult visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
void replaceOrigOpWith(Operation *op, ValueRange vals)
std::map< Operation *, Operation * > originalToNewFnOps
LogicalResult visitChild(Operation *op)
void erase(Operation *op)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
bool isConstantValue(mlir::Value v) const
void localizeGradients(OpBuilder &builder, MGradientUtilsReverse *gutils, Block *fwd)
void registerSCFDialectAutoDiffInterface(DialectRegistry &registry)