20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/IR/DialectRegistry.h"
24#include "mlir/IR/Types.h"
25#include "mlir/Support/LogicalResult.h"
26#include "mlir/Transforms/RegionUtils.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/ScopeExit.h"
35#include "Implementations/SCFDerivatives.inc"
37struct ForOpEnzymeOpsRemover
41 static std::optional<int64_t>
42 getConstantNumberOfIterations(scf::ForOp forOp) {
43 auto lb = forOp.getLowerBound();
44 auto ub = forOp.getUpperBound();
45 auto step = forOp.getStep();
47 IntegerAttr lbAttr, ubAttr, stepAttr;
48 if (!matchPattern(lb, m_Constant(&lbAttr)))
50 if (!matchPattern(ub, m_Constant(&ubAttr)))
52 if (!matchPattern(step, m_Constant(&stepAttr)))
55 int64_t lbI = lbAttr.getInt(), ubI = ubAttr.getInt(),
56 stepI = stepAttr.getInt();
58 return (ubI - lbI) / stepI;
61 static SmallVector<IntOrValue, 1> getDimensionBounds(OpBuilder &builder,
63 auto iters = getConstantNumberOfIterations(forOp);
65 return {IntOrValue(*iters)};
67 Value lb = forOp.getLowerBound(), ub = forOp.getUpperBound(),
68 step = forOp.getStep();
69 Value diff = arith::SubIOp::create(builder, forOp->getLoc(), ub, lb);
71 arith::DivUIOp::create(builder, forOp->getLoc(), diff, step);
72 return {IntOrValue(nSteps)};
76 static SmallVector<Value> getCanonicalLoopIVs(OpBuilder &builder,
79 Value val = forOp.getBody()->getArgument(0);
80 if (!matchPattern(forOp.getLowerBound(), m_Zero())) {
81 val = arith::SubIOp::create(builder, forOp->getLoc(), val,
82 forOp.getLowerBound());
85 if (!matchPattern(forOp.getStep(), m_One())) {
86 val = arith::DivUIOp::create(builder, forOp->getLoc(), val,
92 static IRMapping createArgumentMap(PatternRewriter &rewriter,
93 scf::ForOp forOp, ArrayRef<Value> indFor,
94 scf::ForOp otherForOp,
95 ArrayRef<Value> reversedOther) {
97 for (
auto &&[f, o] : llvm::zip_equal(indFor, reversedOther)) {
101 Value canIdx = forOp.getBody()->getArgument(0);
102 if (!map.contains(canIdx)) {
103 assert(Equivalent(forOp.getLowerBound(), otherForOp.getLowerBound()));
104 assert(Equivalent(forOp.getStep(), otherForOp.getStep()));
105 map.map(forOp.getBody()->getArgument(0),
106 otherForOp.getBody()->getArgument(0));
111 static scf::ForOp replaceWithNewOperands(PatternRewriter &rewriter,
112 scf::ForOp otherForOp,
113 ArrayRef<Value> operands) {
114 auto newOtherForOp = scf::ForOp::create(
115 rewriter, otherForOp->getLoc(), otherForOp.getLowerBound(),
116 otherForOp.getUpperBound(), otherForOp.getStep(), operands);
118 newOtherForOp.getRegion().takeBody(otherForOp.getRegion());
119 rewriter.replaceOp(otherForOp, newOtherForOp->getResults().slice(
120 0, otherForOp->getNumResults()));
121 return newOtherForOp;
124 static ValueRange getInits(scf::ForOp forOp) {
return forOp.getInitArgs(); }
126 static bool mustPostAdd(scf::ForOp forOp) {
return false; }
128 static Value initialValueInBlock(OpBuilder &builder, Block *body,
130 auto Ty = cast<enzyme::GradientType>(grad.getType()).getBasetype();
131 return body->addArgument(Ty, grad.getLoc());
135struct ForOpInterfaceReverse
136 :
public ReverseAutoDiffOpInterface::ExternalModel<ForOpInterfaceReverse,
139 static Value makeIntConstant(Location loc, OpBuilder builder, int64_t val,
141 return arith::ConstantOp::create(builder, loc, IntegerAttr::get(ty, val))
145 static void preserveAttributesButCheckpointing(Operation *newOp,
147 for (
auto attr : oldOp->getDiscardableAttrs()) {
148 if (attr.getName() !=
"enzyme.enable_checkpointing")
149 newOp->setAttr(attr.getName(), attr.getValue());
153 static bool needsCheckpointing(scf::ForOp forOp) {
154 return forOp->hasAttrOfType<BoolAttr>(
"enzyme.enable_checkpointing") &&
155 forOp->getAttrOfType<BoolAttr>(
"enzyme.enable_checkpointing")
157 ForOpEnzymeOpsRemover::getConstantNumberOfIterations(forOp)
162 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
164 SmallVector<Value> caches)
const {
169 auto forOp = cast<scf::ForOp>(op);
171 SmallVector<bool> operandsActive(forOp.getNumOperands() - 3,
false);
172 for (
int i = 0, e = operandsActive.size(); i < e; ++i) {
177 SmallVector<Value> incomingGradients;
178 for (
auto &&[active, res] :
179 llvm::zip_equal(operandsActive, op->getResults())) {
181 incomingGradients.push_back(gutils->
diffe(res, builder));
187 if (needsCheckpointing(forOp)) {
189 ForOpEnzymeOpsRemover::getConstantNumberOfIterations(forOp).value();
190 int64_t nInner = std::sqrt(numIters), nOuter = nInner;
191 int64_t trailingIters = numIters - nInner * nOuter;
193 bool hasTrailing = trailingIters > 0;
195 auto numIterArgs = forOp.getNumRegionIterArgs();
197 SetVector<Value> outsideRefs;
198 getUsedValuesDefinedAbove(op->getRegions(), outsideRefs);
200 SmallVector<Value> immutableRefs;
201 SmallVector<Value> mutableRefs;
203 for (
auto ref : outsideRefs) {
204 if (isa<ClonableTypeInterface>(ref.getType()))
205 mutableRefs.push_back(ref);
207 immutableRefs.push_back(ref);
212 assert(outsideRefs.size() == caches.size() - numIterArgs);
214 for (
auto [i, ref] : llvm::enumerate(immutableRefs)) {
216 caches[numIterArgs + mutableRefs.size() + i], builder);
217 mapping.map(ref, refVal);
220 auto ivTy = forOp.getLowerBound().getType();
221 Value outerUB = makeIntConstant(forOp.getLowerBound().getLoc(), builder,
222 nOuter + hasTrailing, ivTy);
223 auto revOuter = scf::ForOp::create(
224 builder, op->getLoc(),
225 makeIntConstant(forOp.getLowerBound().getLoc(), builder, 0, ivTy),
227 makeIntConstant(forOp.getLowerBound().getLoc(), builder, 1, ivTy),
229 preserveAttributesButCheckpointing(revOuter, forOp);
231 OpBuilder::InsertionGuard guard(builder);
232 builder.setInsertionPointToEnd(revOuter.getBody());
234 SmallVector<Value> cachedOutsideRefs;
235 for (
auto [i, ref] : llvm::enumerate(mutableRefs)) {
236 Value refVal = gutils->
popCache(caches[numIterArgs + i], builder);
237 cachedOutsideRefs.push_back(refVal);
238 mapping.map(ref, refVal);
241 Location loc = forOp.getInductionVar().getLoc();
242 Value currentOuterStep = arith::SubIOp::create(
243 builder, loc, makeIntConstant(loc, builder, nOuter, ivTy),
244 revOuter.getInductionVar());
246 SmallVector<Value> initArgs(numIterArgs,
nullptr);
247 for (
size_t i = 0; i < numIterArgs; ++i) {
248 initArgs[i] = gutils->
popCache(caches[i], builder);
251 auto nInnerCst = makeIntConstant(forOp.getLowerBound().getLoc(), builder,
253 Value zero = makeIntConstant(forOp.getLowerBound().getLoc(), builder, 0,
255 one = makeIntConstant(forOp.getLowerBound().getLoc(), builder, 1,
258 Value nInnerUB = nInnerCst;
259 if (trailingIters > 0) {
261 Location loc = forOp.getUpperBound().getLoc();
262 nInnerUB = arith::SelectOp::create(
264 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
265 revOuter.getInductionVar(), zero),
266 makeIntConstant(loc, builder, trailingIters, ivTy), nInnerCst);
269 auto revInner = scf::ForOp::create(builder, forOp.getLoc(), zero,
270 nInnerUB, one, initArgs);
271 preserveAttributesButCheckpointing(revInner, forOp);
274 if (!matchPattern(forOp.getStep(), m_ConstantInt(&stepI))) {
275 op->emitError() <<
"step size is not known constant\n";
280 if (!matchPattern(forOp.getLowerBound(), m_ConstantInt(&startI))) {
281 op->emitError() <<
"lower bound is not known constant\n";
285 builder.setInsertionPointToEnd(revInner.getBody());
287 Value currentIV = arith::AddIOp::create(
289 arith::MulIOp::create(
291 arith::AddIOp::create(builder, loc,
292 arith::MulIOp::create(builder, loc,
295 revInner.getInductionVar()),
296 arith::ConstantOp::create(builder, loc,
297 IntegerAttr::get(ivTy, stepI))),
298 arith::ConstantOp::create(builder, loc,
299 IntegerAttr::get(ivTy, startI)));
301 for (
auto [oldArg, newArg] :
302 llvm::zip_equal(forOp.getBody()->getArguments(),
303 revInner.getBody()->getArguments()))
304 mapping.map(oldArg, newArg);
305 mapping.map(forOp.getInductionVar(), currentIV);
307 for (
auto &it : *forOp.getBody()) {
308 auto newOp = builder.clone(it, mapping);
312 builder.setInsertionPointToEnd(revOuter.getBody());
314 for (
auto outsideRef : cachedOutsideRefs) {
316 dyn_cast<ClonableTypeInterface>(outsideRef.getType())) {
317 cachableT.freeClonedValue(builder, outsideRef);
322 scf::ForOp::create(builder, forOp.getLoc(), zero, nInnerUB, one,
323 revOuter.getBody()->getArguments().drop_front());
324 preserveAttributesButCheckpointing(revLoop, forOp);
326 Block *revLoopBody = revLoop.getBody();
327 builder.setInsertionPointToEnd(revLoopBody);
330 for (
auto &&[active, operand] :
331 llvm::zip_equal(operandsActive,
332 forOp.getBody()->getTerminator()->getOperands())) {
334 gutils->
addToDiffe(operand, revLoopBody->getArgument(revIdx),
340 Block *origBody = forOp.getBody();
344 auto first = origBody->rbegin();
347 auto last = origBody->rend();
349 for (
auto it = first; it != last; ++it) {
350 Operation *op = &*it;
354 SmallVector<Value> newResults;
355 for (
auto &&[active, arg] : llvm::zip_equal(
356 operandsActive, origBody->getArguments().drop_front())) {
358 newResults.push_back(gutils->
diffe(arg, builder));
364 builder.setInsertionPointToEnd(revLoopBody);
365 scf::YieldOp::create(builder, forOp.getBody()->getTerminator()->getLoc(),
368 builder.setInsertionPointToEnd(revOuter.getBody());
369 scf::YieldOp::create(builder, forOp.getBody()->getTerminator()->getLoc(),
370 revLoop.getResults());
372 builder.setInsertionPointAfter(revOuter);
375 for (
auto &&[active, arg] : llvm::zip_equal(
377 op->getOperands().slice(3, op->getNumOperands() - 3))) {
380 gutils->
addToDiffe(arg, revOuter->getResult(revIdx), builder);
386 return success(valid);
389 auto start = gutils->
popCache(caches[0], builder);
390 auto end = gutils->
popCache(caches[1], builder);
391 auto step = gutils->
popCache(caches[2], builder);
393 auto repFor = scf::ForOp::create(builder, forOp.getLoc(), start, end, step,
395 preserveAttributesButCheckpointing(repFor, forOp);
398 for (
auto &&[oldReg, newReg] :
399 llvm::zip(op->getRegions(), repFor->getRegions())) {
400 for (
auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
401 OpBuilder bodyBuilder(&revBB, revBB.end());
405 scf::YieldOp::create(bodyBuilder, repFor->getLoc());
408 bodyBuilder.setInsertionPointToStart(&revBB);
411 bodyBuilder.setInsertionPoint(revBB.getTerminator());
413 auto term = oBB.getTerminator();
416 for (
auto &&[active, operand] :
417 llvm::zip_equal(operandsActive, term->getOperands())) {
422 gutils->
setDiffe(operand, revBB.getArgument(argIdx), bodyBuilder);
427 auto first = oBB.rbegin();
430 auto last = oBB.rend();
432 for (
auto it = first; it != last; ++it) {
433 Operation *op = &*it;
438 SmallVector<Value> newResults;
439 newResults.reserve(incomingGradients.size());
441 for (
auto &&[active, arg] :
442 llvm::zip_equal(operandsActive, oBB.getArguments().slice(1))) {
444 newResults.push_back(gutils->
diffe(arg, bodyBuilder));
451 revBB.getTerminator()->setOperands(newResults);
456 for (
auto &&[active, arg] :
457 llvm::zip_equal(operandsActive, forOp.getInitArgs())) {
460 gutils->
addToDiffe(arg, repFor.getResult(resIdx), builder);
466 return success(valid);
469 SmallVector<Value> cacheValues(Operation *op,
471 auto forOp = cast<scf::ForOp>(op);
473 OpBuilder cacheBuilder(newOp);
475 if (needsCheckpointing(forOp)) {
477 ForOpEnzymeOpsRemover::getConstantNumberOfIterations(forOp).value();
478 int64_t nInner = std::sqrt(numIters), nOuter = nInner;
479 int64_t trailingIters = numIters - nInner * nOuter;
480 bool hasTrailing = trailingIters > 0;
482 SetVector<Value> outsideRefs;
483 getUsedValuesDefinedAbove(op->getRegions(), outsideRefs);
485 SmallVector<Value> immutableRefs;
486 SmallVector<Value> mutableRefs;
488 for (
auto ref : outsideRefs) {
489 if (isa<ClonableTypeInterface>(ref.getType()))
490 mutableRefs.push_back(ref);
492 immutableRefs.push_back(ref);
495 SmallVector<Value> caches;
499 Type ty = forOp.getLowerBound().getType();
500 auto outerFwd = scf::ForOp::create(
501 cacheBuilder, op->getLoc(),
502 makeIntConstant(forOp.getLowerBound().getLoc(), cacheBuilder, 0, ty),
503 makeIntConstant(forOp.getUpperBound().getLoc(), cacheBuilder,
504 nInner * (nOuter + hasTrailing), ty),
505 makeIntConstant(forOp.getStep().getLoc(), cacheBuilder, nInner, ty),
506 newForOp.getInitArgs());
507 preserveAttributesButCheckpointing(outerFwd, forOp);
509 cacheBuilder.setInsertionPointToStart(outerFwd.getBody());
510 auto nInnerCst = makeIntConstant(forOp.getUpperBound().getLoc(),
511 cacheBuilder, nInner, ty);
513 Value nInnerUB = nInnerCst;
514 if (trailingIters > 0) {
517 Location loc = forOp.getUpperBound().getLoc();
518 nInnerUB = arith::SelectOp::create(
520 arith::CmpIOp::create(
521 cacheBuilder, loc, arith::CmpIPredicate::eq,
522 outerFwd.getInductionVar(),
523 makeIntConstant(loc, cacheBuilder, nInner * nOuter, ty)),
524 makeIntConstant(loc, cacheBuilder, trailingIters, ty), nInnerCst);
529 SmallVector<Value> mutableRefsCaches;
530 for (
auto ref : mutableRefs) {
531 auto iface = cast<ClonableTypeInterface>(ref.getType());
533 iface.cloneValue(cacheBuilder, mapping.lookupOrDefault(ref));
534 mutableRefsCaches.push_back(
538 auto innerFwd = scf::ForOp::create(
539 cacheBuilder, op->getLoc(),
540 makeIntConstant(forOp.getLowerBound().getLoc(), cacheBuilder, 0, ty),
542 makeIntConstant(forOp.getStep().getLoc(), cacheBuilder, 1, ty),
543 outerFwd.getBody()->getArguments().drop_front());
544 preserveAttributesButCheckpointing(innerFwd, forOp);
546 cacheBuilder.setInsertionPointToEnd(innerFwd.getBody());
548 Location loc = forOp.getInductionVar().getLoc();
549 auto currentIV = arith::MulIOp::create(
551 arith::AddIOp::create(
553 arith::MulIOp::create(cacheBuilder, loc,
554 outerFwd.getInductionVar(), nInnerCst),
555 innerFwd.getInductionVar()),
558 for (
auto [oldArg, newArg] :
559 llvm::zip_equal(forOp.getBody()->getArguments(),
560 innerFwd.getBody()->getArguments()))
561 mapping.map(oldArg, newArg);
562 mapping.map(forOp.getInductionVar(), currentIV);
564 for (
auto &it : *forOp.getBody())
565 cacheBuilder.clone(it, mapping);
567 cacheBuilder.setInsertionPointToEnd(outerFwd.getBody());
568 for (
auto initArg : innerFwd.getInitArgs())
571 scf::YieldOp::create(cacheBuilder,
572 forOp.getBody()->getTerminator()->getLoc(),
573 innerFwd->getResults());
575 cacheBuilder.setInsertionPointAfter(outerFwd);
577 caches.append(mutableRefsCaches);
579 for (
auto ref : immutableRefs)
584 gutils->
erase(newForOp);
602 SmallVector<Value> caches;
606 caches.push_back(cacheLB);
610 caches.push_back(cacheUB);
614 caches.push_back(cacheStep);
619 void createShadowValues(Operation *op, OpBuilder &builder,
625struct ParallelOpEnzymeOpsRemover
628 static std::optional<int64_t>
629 getConstantNumberOfIterations(Value lb, Value ub, Value step) {
630 IntegerAttr lbAttr, ubAttr, stepAttr;
631 if (!matchPattern(lb, m_Constant(&lbAttr)))
633 if (!matchPattern(ub, m_Constant(&ubAttr)))
635 if (!matchPattern(step, m_Constant(&stepAttr)))
638 int64_t lbI = lbAttr.getInt(), ubI = ubAttr.getInt(),
639 stepI = stepAttr.getInt();
640 return (ubI - lbI) / stepI;
643 static SmallVector<IntOrValue, 1> getDimensionBounds(OpBuilder &builder,
644 scf::ParallelOp parOp) {
645 SmallVector<IntOrValue, 1> bounds;
646 bounds.reserve(parOp.getNumLoops());
647 for (
auto &&[lb, ub, step] : llvm::zip_equal(
648 parOp.getLowerBound(), parOp.getUpperBound(), parOp.getStep())) {
649 auto iters = getConstantNumberOfIterations(lb, ub, step);
651 bounds.push_back(IntOrValue(*iters));
653 Value diff = arith::SubIOp::create(builder, parOp.getLoc(), ub, lb);
655 arith::DivUIOp::create(builder, parOp.getLoc(), diff, step);
656 bounds.push_back(IntOrValue(nSteps));
662 static SmallVector<Value>
663 computeReversedIndices(PatternRewriter &rewriter, scf::ParallelOp parOp,
664 ArrayRef<Value> otherInductionVariable,
665 ArrayRef<IntOrValue> bounds) {
666 return SmallVector<Value>(otherInductionVariable);
669 static SmallVector<Value> getCanonicalLoopIVs(OpBuilder &builder,
670 scf::ParallelOp parOp) {
671 SmallVector<Value> canonicalIVs;
672 canonicalIVs.reserve(parOp.getNumLoops());
673 for (
auto &&[iv, lb, step] :
674 llvm::zip_equal(parOp.getInductionVars(), parOp.getLowerBound(),
677 if (!matchPattern(lb, m_Zero())) {
678 val = arith::SubIOp::create(builder, parOp.getLoc(), val, lb);
681 if (!matchPattern(step, m_One())) {
682 val = arith::DivUIOp::create(builder, parOp.getLoc(), val, step);
684 canonicalIVs.push_back(val);
689 static IRMapping createArgumentMap(PatternRewriter &rewriter,
690 scf::ParallelOp parOp,
691 ArrayRef<Value> indPar,
692 scf::ParallelOp otherParOp,
693 ArrayRef<Value> indOther) {
695 for (
auto &&[f, o] : llvm::zip_equal(indPar, indOther))
698 for (
auto &&[iv, oiv, lb, olb, step, ostep] : llvm::zip_equal(
699 parOp.getInductionVars(), otherParOp.getInductionVars(),
700 parOp.getLowerBound(), otherParOp.getLowerBound(), parOp.getStep(),
701 otherParOp.getStep())) {
702 if (!map.contains(iv)) {
703 assert(Equivalent(lb, olb));
704 assert(Equivalent(step, ostep));
711 static scf::ParallelOp replaceWithNewOperands(PatternRewriter &rewriter,
712 scf::ParallelOp otherParallelOp,
713 ArrayRef<Value> operands) {
714 auto newOtherParOp = scf::ParallelOp::create(
715 rewriter, otherParallelOp.getLoc(), otherParallelOp.getLowerBound(),
716 otherParallelOp.getUpperBound(), otherParallelOp.getStep(), operands);
718 newOtherParOp.getRegion().takeBody(otherParallelOp.getRegion());
721 newOtherParOp.getResults().slice(0, otherParallelOp.getNumResults()));
723 if (operands.size() >= 1) {
724 OpBuilder::InsertionGuard guard(rewriter);
725 Operation *oldTerm = newOtherParOp.getBody()->getTerminator();
726 rewriter.setInsertionPointToEnd(newOtherParOp.getBody());
727 auto term = scf::ReduceOp::create(rewriter, newOtherParOp.getLoc(),
728 oldTerm->getOperands());
730 for (
auto [reg, operand] :
731 llvm::zip_equal(term->getRegions(), operands)) {
732 Block *b = ®.front();
733 rewriter.setInsertionPointToEnd(b);
735 auto Ty = cast<AutoDiffTypeInterface>(operand.getType());
736 Value reduced = Ty.createAddOp(rewriter, operand.getLoc(),
737 b->getArgument(0), b->getArgument(1));
738 scf::ReduceReturnOp::create(rewriter, reduced.getLoc(), reduced);
744 return newOtherParOp;
747 static ValueRange getInits(scf::ParallelOp parallelOp) {
748 return parallelOp.getInitVals();
751 static bool mustPostAdd(scf::ParallelOp forOp) {
return false; }
753 static Value initialValueInBlock(OpBuilder &builder, Block *body,
755 OpBuilder::InsertionGuard guard(builder);
756 builder.setInsertionPointToStart(body);
757 return cast<AutoDiffTypeInterface>(
758 cast<enzyme::GradientType>(grad.getType()).getBasetype())
759 .createNullValue(builder, grad.getLoc());
763struct ParallelOpInterfaceReverse
764 :
public ReverseAutoDiffOpInterface::ExternalModel<
765 ParallelOpInterfaceReverse, scf::ParallelOp> {
766 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
768 SmallVector<Value> caches)
const {
769 auto parallelOp = cast<scf::ParallelOp>(op);
770 if (parallelOp.getNumReductions() != 0) {
771 return parallelOp.emitError()
772 <<
"parallel reductions not yet implemented\n";
775 unsigned loopCount = parallelOp.getNumLoops();
776 SmallVector<Value> bounds = llvm::map_to_vector(
777 caches, [&](Value cache) {
return gutils->
popCache(cache, builder); });
779 auto revPar = scf::ParallelOp::create(
780 builder, op->getLoc(),
781 ValueRange(bounds).slice(0, loopCount),
782 ValueRange(bounds).slice(loopCount, loopCount),
783 ValueRange(bounds).slice(loopCount * 2, loopCount));
790 Block *oBB = parallelOp.getBody();
791 Block *revBB = revPar.getBody();
793 OpBuilder bodyBuilder(revBB, revBB->end());
795 bodyBuilder.setInsertionPointToStart(revBB);
798 bodyBuilder.setInsertionPoint(revBB->getTerminator());
800 auto first = oBB->rbegin();
803 auto last = oBB->rend();
805 for (
auto it = first; it != last; ++it) {
806 Operation *op = &*it;
807 valid &= gutils->
Logic.
visitChild(op, bodyBuilder, gutils).succeeded();
812 return success(valid);
815 SmallVector<Value> cacheValues(Operation *op,
817 auto parallelOp = cast<scf::ParallelOp>(op);
819 OpBuilder cacheBuilder(newOp);
820 SmallVector<Value> caches;
821 for (Value lb : parallelOp.getLowerBound())
824 for (Value ub : parallelOp.getUpperBound())
827 for (Value step : parallelOp.getStep())
834 void createShadowValues(Operation *op, OpBuilder &builder,
838struct IfOpEnzymeOpsRemover
840 static Block *getThenBlock(scf::IfOp ifOp, OpBuilder &builder) {
841 return ifOp.thenBlock();
844 static Block *getElseBlock(scf::IfOp ifOp, OpBuilder &builder) {
846 if (ifOp.getElseRegion().empty()) {
847 OpBuilder::InsertionGuard guard(builder);
848 Block &newBlock = ifOp.getElseRegion().emplaceBlock();
849 builder.setInsertionPointToStart(&newBlock);
850 scf::YieldOp::create(builder, ifOp.getLoc());
853 return ifOp.elseBlock();
856 static Value getDummyValue(OpBuilder &builder, Location loc, Type dummyType) {
857 return cast<AutoDiffTypeInterface>(dummyType).createNullValue(builder, loc);
860 static scf::IfOp replace(PatternRewriter &rewriter, scf::IfOp otherIfOp,
861 TypeRange resultTypes) {
862 auto newIf = scf::IfOp::create(rewriter, otherIfOp->getLoc(), resultTypes,
863 otherIfOp.getCondition());
865 newIf.getThenRegion().takeBody(otherIfOp.getThenRegion());
866 newIf.getElseRegion().takeBody(otherIfOp.getElseRegion());
868 rewriter.replaceAllUsesWith(
869 otherIfOp->getResults(),
870 newIf->getResults().slice(0, otherIfOp->getNumResults()));
871 rewriter.eraseOp(otherIfOp);
876struct IfOpInterfaceReverse
877 :
public ReverseAutoDiffOpInterface::ExternalModel<IfOpInterfaceReverse,
879 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
881 SmallVector<Value> caches)
const {
882 auto ifOp = cast<scf::IfOp>(op);
883 bool hasElse = ifOp.elseBlock() !=
nullptr;
884 Value cond = gutils->
popCache(caches[0], builder);
886 SmallVector<bool> resultsActive(ifOp.getNumResults(),
false);
887 for (
int i = 0, e = resultsActive.size(); i < e; ++i) {
891 SmallVector<Value> incomingGradients;
892 for (
auto &&[active, res] :
893 llvm::zip_equal(resultsActive, ifOp.getResults())) {
895 incomingGradients.push_back(gutils->
diffe(res, builder));
902 scf::IfOp::create(builder, ifOp.getLoc(), TypeRange{}, cond, hasElse);
904 for (
auto &&[oldReg, newReg] :
905 llvm::zip(op->getRegions(), revIf->getRegions())) {
906 for (
auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
907 OpBuilder bodyBuilder(&revBB, revBB.end());
908 bodyBuilder.setInsertionPoint(revBB.getTerminator());
914 for (
auto &it : oBB.getOperations()) {
915 for (
auto res : it.getResults()) {
917 auto iface = dyn_cast<AutoDiffTypeInterface>(res.getType());
918 if (iface && !iface.isMutable())
924 auto term = oBB.getTerminator();
926 SmallVector<Value> activeTermOperands;
927 activeTermOperands.reserve(incomingGradients.size());
928 for (
auto &&[resultActive, operand] :
929 llvm::zip_equal(resultsActive, term->getOperands())) {
931 activeTermOperands.push_back(operand);
934 for (
auto &&[arg, operand] :
935 llvm::zip_equal(incomingGradients, activeTermOperands)) {
941 gutils->
addToDiffe(operand, arg, bodyBuilder);
945 auto first = oBB.rbegin();
948 auto last = oBB.rend();
950 for (
auto it = first; it != last; ++it) {
951 Operation *op = &*it;
957 return success(valid);
960 SmallVector<Value> cacheValues(Operation *op,
962 auto ifOp = cast<scf::IfOp>(op);
965 OpBuilder cacheBuilder(newOp);
968 return SmallVector<Value>{cacheCond};
971 void createShadowValues(Operation *op, OpBuilder &builder,
975struct ForOpADDataFlow
976 :
public ADDataFlowOpInterface::ExternalModel<ForOpADDataFlow, scf::ForOp> {
977 SmallVector<Value> getPotentialIncomingValuesRes(Operation *op,
978 OpResult res)
const {
979 auto forOp = cast<scf::ForOp>(op);
981 forOp->getOperand(res.getResultNumber() + 3),
982 forOp.getBody()->getTerminator()->getOperand(res.getResultNumber())};
984 SmallVector<Value> getPotentialIncomingValuesArg(Operation *op,
985 BlockArgument arg)
const {
986 auto forOp = cast<scf::ForOp>(op);
987 if (arg.getArgNumber() < forOp.getNumInductionVars())
989 auto idx = arg.getArgNumber() - forOp.getNumInductionVars();
990 return {forOp->getOperand(idx + 3),
991 forOp.getBody()->getTerminator()->getOperand(idx)};
995 auto forOp = cast<scf::ForOp>(op);
996 SmallVector<Value> sv;
998 for (
auto &&[res, arg, barg] :
999 llvm::zip_equal(forOp->getResults(), term->getOperands(),
1000 forOp.getRegionIterArgs())) {
1011struct ParallelOpADDataFlow
1012 :
public ADDataFlowOpInterface::ExternalModel<ParallelOpADDataFlow,
1014 SmallVector<Value> getPotentialIncomingValuesRes(Operation *op,
1015 OpResult res)
const {
1016 auto parOp = cast<scf::ParallelOp>(op);
1017 const size_t num_lower = parOp.getLowerBound().size();
1018 const size_t num_upper = parOp.getUpperBound().size();
1019 const size_t num_step = parOp.getStep().size();
1020 const size_t init_vals_offset = num_lower + num_upper + num_step;
1021 return {parOp->getOperand(res.getResultNumber() + init_vals_offset),
1024 ->getRegion(res.getResultNumber())
1029 SmallVector<Value> getPotentialIncomingValuesArg(Operation *op,
1030 BlockArgument arg)
const {
1033 return SmallVector<Value>();
1037 SmallVector<Value> sv;
1039 for (
auto [idx, arg] : llvm::enumerate(term->getOperands())) {
1041 sv.push_back(term->getRegion(idx).front().getArgument(0));
1049struct ReduceOpADDataFlow
1050 :
public ADDataFlowOpInterface::ExternalModel<ReduceOpADDataFlow,
1052 SmallVector<Value> getPotentialIncomingValuesRes(Operation *op,
1053 OpResult res)
const {
1055 return SmallVector<Value>();
1057 SmallVector<Value> getPotentialIncomingValuesArg(Operation *op,
1058 BlockArgument arg)
const {
1064 auto redOp = cast<scf::ReduceOp>(op);
1065 mlir::Block *ownerBlock = arg.getOwner();
1066 auto num_args = ownerBlock->getNumArguments();
1067 auto arg_idx = arg.getArgNumber();
1068 auto region_idx = ownerBlock->getParent()->getRegionNumber();
1069 if (arg_idx == num_args - 1) {
1070 auto parOp = cast<scf::ParallelOp>(redOp->getParentOp());
1071 auto num_lb = parOp.getLowerBound().size();
1072 auto num_ub = parOp.getUpperBound().size();
1073 auto num_st = parOp.getStep().size();
1074 return {parOp->getOperand(num_lb + num_ub + num_st + region_idx),
1075 ownerBlock->getTerminator()->getOperand(0)};
1077 return {redOp->getOperand(region_idx)};
1082 auto redOp = cast<scf::ReduceOp>(op);
1083 auto parOp = cast<scf::ParallelOp>(redOp->getParentOp());
1084 mlir::Block *ownerBlock = term->getBlock();
1085 auto region_idx = ownerBlock->getParent()->getRegionNumber();
1087 return {parOp->getResult(region_idx), ownerBlock->getArgument(1)};
1091class SCFReduceAutoDiffOpInterface
1092 :
public AutoDiffOpInterface::ExternalModel<SCFReduceAutoDiffOpInterface,
1095 LogicalResult createForwardModeTangent(Operation *origTerminator,
1098 auto parentOp = origTerminator->getParentOp();
1099 if (!isa<scf::ParallelOp>(parentOp)) {
1100 origTerminator->emitError()
1101 <<
" createForwardModeTangent called with invalid parent" << *parentOp
1108 assert(parentOp->getNumResults() == origTerminator->getNumOperands());
1109 llvm::SmallDenseSet<unsigned> operandsToShadow;
1110 for (
auto res : parentOp->getResults()) {
1112 operandsToShadow.insert(res.getResultNumber());
1115 SmallVector<Value> newOperands;
1116 newOperands.reserve(origTerminator->getNumOperands() +
1117 operandsToShadow.size());
1118 for (OpOperand &operand : origTerminator->getOpOperands()) {
1120 if (operandsToShadow.contains(operand.getOperandNumber()))
1121 newOperands.push_back(gutils->
invertPointerM(operand.get(), builder));
1127 replTerminator->setOperands(newOperands);
1130 for (
auto &origRegion : origTerminator->getRegions()) {
1131 for (
auto &origBlock : origRegion) {
1132 for (Operation &o : origBlock) {
1134 replTerminator->emitError() <<
" Differentiating reducer block "
1135 << *replTerminator <<
" failed!\n";
1146 for (
auto ®ion : replTerminator->getRegions()) {
1147 for (
auto &block : region) {
1148 std::map<Operation *, bool> used;
1149 std::vector<Operation *> op_list;
1152 for (Operation &o : block) {
1154 op_list.push_back(&o);
1159 auto mark_used = [&used](
const auto &self, Operation *op) ->
void {
1160 if (op !=
nullptr) {
1161 assert(used.find(op) != used.end());
1163 for (
auto v : op->getOperands())
1164 self(self, v.getDefiningOp());
1167 mark_used(mark_used, block.getTerminator());
1171 for (
auto it = op_list.rbegin(); it != op_list.rend(); ++it) {
1180 for (
int i = block.getNumArguments() - 2; i >= 0; i -= 2) {
1181 block.eraseArgument(i);
1190 mlir::OpBuilder term_builder(replTerminator);
1191 mlir::IRMapping mapper;
1192 OperationState state(replTerminator->getLoc(),
1193 scf::ReduceOp::getOperationName());
1194 state.addOperands(newOperands);
1195 size_t num_regions = origTerminator->getNumRegions();
1196 for (
size_t i = 0; i < num_regions; ++i) {
1197 Region *new_orig_region = state.addRegion();
1198 Region *new_diff_region = state.addRegion();
1199 origTerminator->getRegion(i).cloneInto(new_orig_region, mapper);
1200 new_diff_region->takeBody(replTerminator->getRegion(i));
1202 Operation *new_terminator_op = term_builder.create(state);
1203 gutils->
erase(replTerminator);
1210class SCFReduceReturnAutoDiffOpInterface
1211 :
public AutoDiffOpInterface::ExternalModel<
1212 SCFReduceReturnAutoDiffOpInterface, scf::ReduceReturnOp> {
1214 LogicalResult createForwardModeTangent(Operation *origTerminator,
1217 auto parentOp = origTerminator->getParentOp();
1218 if (!isa<scf::ReduceOp>(parentOp)) {
1219 origTerminator->emitError()
1220 <<
" createForwardModeTangent called with invalid parent" << *parentOp
1229 auto reducer_index =
1230 origTerminator->getBlock()->getParent()->getRegionNumber();
1231 assert(reducer_index < parentOp->getParentOp()->getNumResults());
1232 assert(origTerminator->getNumOperands() == 1);
1233 llvm::SmallDenseSet<unsigned> operandsToShadow;
1235 parentOp->getParentOp()->getResult(reducer_index)))
1236 operandsToShadow.insert(0);
1241 SmallVector<Value> newOperands;
1242 newOperands.reserve(operandsToShadow.size());
1243 for (OpOperand &operand : origTerminator->getOpOperands()) {
1244 if (operandsToShadow.contains(operand.getOperandNumber()))
1245 newOperands.push_back(gutils->
invertPointerM(operand.get(), builder));
1256 replTerminator->setOperands(newOperands);
1265 DialectRegistry ®istry) {
1266 registry.addExtension(+[](MLIRContext *context, scf::SCFDialect *) {
1267 registerInterfaces(context);
1268 scf::IfOp::attachInterface<IfOpInterfaceReverse>(*context);
1269 scf::IfOp::attachInterface<IfOpEnzymeOpsRemover>(*context);
1270 scf::ParallelOp::attachInterface<ParallelOpInterfaceReverse>(*context);
1271 scf::ParallelOp::attachInterface<ParallelOpEnzymeOpsRemover>(*context);
1272 scf::ParallelOp::attachInterface<ParallelOpADDataFlow>(*context);
1273 scf::ReduceOp::attachInterface<ReduceOpADDataFlow>(*context);
1274 scf::ReduceOp::attachInterface<SCFReduceAutoDiffOpInterface>(*context);
1275 scf::ReduceReturnOp::attachInterface<SCFReduceReturnAutoDiffOpInterface>(
1277 scf::ForOp::attachInterface<ForOpInterfaceReverse>(*context);
1278 scf::ForOp::attachInterface<ForOpEnzymeOpsRemover>(*context);
1279 scf::ForOp::attachInterface<ForOpADDataFlow>(*context);
Operation * clone(Operation *src, IRMapping &mapper, Operation::CloneOptions options, std::map< Operation *, Operation * > &opMap)
static std::optional< SmallVector< Value > > getPotentialTerminatorUsers(Operation *op, Value parent)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder)
LogicalResult visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
IRMapping originalToNewFn
void replaceOrigOpWith(Operation *op, ValueRange vals)
std::map< Operation *, Operation * > originalToNewFnOps
LogicalResult visitChild(Operation *op)
void erase(Operation *op)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
bool isConstantValue(mlir::Value v) const
void localizeGradients(OpBuilder &builder, MGradientUtilsReverse *gutils, Block *fwd)
void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry)