LogicalResult matchAndRewrite(Operation *op,
                              PatternRewriter &rewriter) const {
  auto forOp = cast<OpName>(op);

  // Hoist loop-invariant code out of the loop before analyzing caches.
  auto loopLike = dyn_cast<LoopLikeOpInterface>(op);
  mlir::moveLoopInvariantCode(loopLike);
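
  // Strategy sketch: promote gradient cells updated inside the loop to
  // iteration arguments, and widen push/pop caches shared with the reverse
  // loop by one iteration dimension, so a single push/pop of a whole buffer
  // replaces one push/pop per iteration.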
  // The reverse-mode loop that pops what this loop pushes, if we find one.
  OpName otherForOp = nullptr;

  llvm::SetVector<Value> updatedGradients;
  llvm::MapVector<Value, CacheInfo> cachesMap;
  SmallVector<CacheInfo> toDelete;

  Block *body = forOp.getBody();
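
  // Collect the gradients written inside the loop and the caches pushed
  // inside it. A cache shows up as a triple of enzyme ops (sketch, names
  // illustrative):
  //
  //   %c = "enzyme.init"() : () -> !enzyme.Cache<f32>    // info.initOp
  //   scf.for ... { "enzyme.push"(%c, %v) ... }          // info.pushOp
  //   scf.for ... { %v2 = "enzyme.pop"(%c) ... }         // info.popOp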
  body->walk([&](Operation *op) {
    if (auto setOp = dyn_cast<enzyme::SetOp>(op)) {
      // Only track gradient cells defined outside the loop region.
      if (!body->getParent()->isAncestor(
              setOp.getGradient().getDefiningOp()->getParentRegion()))
        updatedGradients.insert(setOp.getGradient());
    }

    if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
      CacheInfo info(pushOp.getCache());

      Value pushedValue = info.pushedValue();
      if (cachesMap.contains(pushedValue)) {
        info = info.merge(cachesMap.lookup(pushedValue), rewriter);
      }

      if (info.pushOp->getBlock() == body && info.popOp->getBlock() == body) {
        toDelete.push_back(info);
        return;
      }

      cachesMap[pushedValue] = info;

      if (isa<OpName>(info.popOp->getParentOp())) {
        otherForOp = cast<OpName>(info.popOp->getParentOp());
      }
    }
  });
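
  // Push/pop pairs confined to this block cancel out: each pop can simply
  // read the SSA value that was pushed earlier in the same iteration.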
  while (!toDelete.empty()) {
    CacheInfo info = toDelete.pop_back_val();
    rewriter.replaceAllUsesWith(info.popOp.getResult(), info.pushedValue());
    rewriter.eraseOp(info.pushOp);
    rewriter.eraseOp(info.popOp);
    rewriter.eraseOp(info.initOp);
  }
  SmallVector<CacheInfo> caches0 =
      llvm::map_to_vector(cachesMap, [](auto p) { return std::get<1>(p); });
  SmallVector<CacheInfo> caches = caches0;

  // Nothing to do if the loop neither updates gradients nor feeds caches.
  if (updatedGradients.empty() && caches.empty())
    return failure();
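
  // Reads of gradients that are never written inside the loop are
  // loop-invariant: hoist them in front of the loop and rewrite their users.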
  DenseMap<Value, llvm::SmallVector<Operation *>> updatedGradientUsers;

  for (auto &it : llvm::make_early_inc_range(*body)) {
    Operation *op = &it;
    auto getOp = dyn_cast<enzyme::GetOp>(op);

    if (getOp && updatedGradients.contains(getOp.getGradient())) {
      updatedGradientUsers[getOp.getGradient()].push_back(getOp);
    } else if (auto setOp = dyn_cast<enzyme::SetOp>(op)) {
      updatedGradientUsers[setOp.getGradient()].push_back(setOp);
    }

    if (!getOp || updatedGradients.contains(getOp.getGradient()))
      continue;

    auto outerGet = enzyme::GetOp::create(rewriter, getOp->getLoc(),
                                          getOp.getResult().getType(),
                                          getOp.getGradient());

    rewriter.replaceAllUsesWith(getOp.getResult(), outerGet.getResult());
    rewriter.eraseOp(getOp);
  }
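
  // Updated gradients become loop-carried values: read (or zero-init) the
  // gradient before the loop, thread the running value through the body,
  // and yield the final value from the terminator.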
  bool postAdd = FinalClass::mustPostAdd(forOp);

  auto term = body->getTerminator();

  SmallVector<Value> newOperands(FinalClass::getInits(forOp));
  for (auto grad : updatedGradients) {
    auto Ty = cast<enzyme::GradientType>(grad.getType()).getBasetype();

    Value newInit;
    if (!postAdd) {
      newInit = enzyme::GetOp::create(rewriter, grad.getLoc(), Ty, grad);
    } else {
      // With post-add semantics, start from zero and accumulate the loop's
      // contribution into the gradient after the loop instead.
      newInit = cast<AutoDiffTypeInterface>(Ty).createNullValue(
          rewriter, grad.getLoc());
    }

    newOperands.push_back(newInit);

    Value val = FinalClass::initialValueInBlock(rewriter, body, grad);
    for (auto user : updatedGradientUsers[grad]) {
      if (auto getOp = dyn_cast<enzyme::GetOp>(user)) {
        // Gets observe the current running value.
        rewriter.replaceOp(getOp, val);
      } else {
        // Sets advance the running value.
        auto setOp = cast<enzyme::SetOp>(user);
        val = setOp.getValue();
        rewriter.eraseOp(setOp);
      }
    }
    term->insertOperands(term->getNumOperands(), ValueRange(val));
  }
  SmallVector<Value> inductionVariable;
  SmallVector<Value> otherInductionVariable;
  SmallVector<Value> reversedIndex;

  SmallVector<IntOrValue> revNumIters;
  SmallVector<IntOrValue> fwdNumIters;
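
  // An IntOrValue carries a loop trip count that is either a compile-time
  // constant (ival) or a dynamic SSA value (vval).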
  if (fwdNumIters.empty()) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(forOp);
    fwdNumIters = FinalClass::getDimensionBounds(rewriter, forOp);
  }

  // Remember the last op materialized for the canonical IVs so cache writes
  // can be inserted right after it.
  Operation *lastFwd = nullptr;

  rewriter.setInsertionPointToStart(forOp.getBody());
  inductionVariable = FinalClass::getCanonicalLoopIVs(rewriter, forOp);
  if (rewriter.getInsertionPoint() != forOp.getBody()->begin()) {
    lastFwd = rewriter.getInsertionPoint()->getPrevNode();
  }

  rewriter.setInsertionPointToStart(otherForOp.getBody());
  otherInductionVariable =
      FinalClass::getCanonicalLoopIVs(rewriter, otherForOp);
  if (revNumIters.empty()) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(otherForOp);
    revNumIters = FinalClass::getDimensionBounds(rewriter, otherForOp);
    for (auto &&[rev, fwd] : llvm::zip_equal(revNumIters, fwdNumIters)) {
      if (!fwd.vval && rev.vval) {
        // Assumed from context: adopt the forward loop's static bound when
        // the reverse bound is only known dynamically.
        rev = fwd;
      }
    }
  }

  reversedIndex = FinalClass::computeReversedIndices(
      rewriter, otherForOp, otherInductionVariable, revNumIters);
  auto fwdrevmap = FinalClass::createArgumentMap(
      rewriter, forOp, inductionVariable, otherForOp, reversedIndex);
  // Pin the IV-computing ops with a marker attribute so the rewrite below
  // cannot erase them.
  for (auto v : inductionVariable) {
    if (auto op = v.getDefiningOp()) {
      op->setAttr("enzyme.no_erase", rewriter.getUnitAttr());
    }
  }
  for (auto v : otherInductionVariable) {
    if (auto op = v.getDefiningOp()) {
      op->setAttr("enzyme.no_erase", rewriter.getUnitAttr());
    }
  }
  // (call target elided)
  // ...(rewriter, fwdrevmap, lastFwd);
  for (auto v : inductionVariable) {
    if (auto op = v.getDefiningOp()) {
      op->removeAttr("enzyme.no_erase");
    }
  }
  for (auto v : otherInductionVariable) {
    if (auto op = v.getDefiningOp()) {
      op->removeAttr("enzyme.no_erase");
    }
  }
  auto revIP = rewriter.saveInsertionPoint();

  SmallVector<Value> newPushValues;
  // Number of cache buffers that become results of the rebuilt forward loop.
  unsigned numNewValuePushes = 0;

  if (lastFwd)
    rewriter.setInsertionPointAfter(lastFwd);
  else
    rewriter.setInsertionPointToStart(forOp.getBody());
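
  // Materialize each cache in the forward loop. For example (sketch for an
  // scf.for instantiation; names illustrative), a scalar f32 push becomes a
  // tensor insert into a loop-carried buffer:
  //
  //   %buf = tensor.empty(%n) : tensor<?xf32>
  //   %r = scf.for ... iter_args(%b = %buf) -> (tensor<?xf32>) {
  //     %b2 = tensor.insert %v into %b[%iv] : tensor<?xf32>
  //     scf.yield %b2 : tensor<?xf32>
  //   }
  //   "enzyme.push"(%c, %r) : (!enzyme.Cache<tensor<?xf32>>, tensor<?xf32>) -> ()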
  for (auto &info : caches) {
    Value pushedValue = info.pushedValue();

    assert(forOp.getRegion().isAncestor(pushedValue.getParentRegion()));

    // Synthesize a canonical induction variable if the loop has none:
    // start at 0 and increment by 1 every iteration.
    if (inductionVariable.empty()) {
      Value zero = arith::ConstantOp::create(rewriter, forOp->getLoc(),
                                             rewriter.getIndexAttr(0));
      newOperands.push_back(zero);

      inductionVariable = {
          body->addArgument(zero.getType(), forOp->getLoc())};
      {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(term);

        auto one = arith::ConstantOp::create(rewriter, forOp->getLoc(),
                                             rewriter.getIndexAttr(1));
        auto newInductionVar = arith::AddIOp::create(
            rewriter, forOp->getLoc(), inductionVariable[0], one);
        term->insertOperands(term->getNumOperands(),
                             ValueRange(newInductionVar));
      }
    }
    // One leading dimension per loop dimension, dynamic when the trip count
    // is only known at runtime.
    SmallVector<int64_t> newShape;
    SmallVector<Value> dynamicDims;
    for (const auto &dim : fwdNumIters) {
      if (dim.vval) {
        newShape.push_back(mlir::ShapedType::kDynamic);
        dynamicDims.push_back(dim.vval);
      } else {
        newShape.push_back(dim.ival);
      }
    }
    auto ET = info.cachedType();

    bool multiDim = false;
    if (auto ST = dyn_cast<ShapedType>(ET)) {
      auto allocOp = pushedValue.getDefiningOp<memref::AllocOp>();
      if (allocOp && allocOp.getSymbolOperands().empty() &&
          llvm::all_of(allocOp.getDynamicSizes(), [&](Value dynSize) {
            return !forOp.getRegion().isAncestor(dynSize.getParentRegion());
          })) {
        // The pushed buffer's dynamic sizes are loop-invariant, so the whole
        // buffer can be cached per iteration through one wide allocation.
        multiDim = true;
        dynamicDims.append(allocOp.getDynamicSizes().begin(),
                           allocOp.getDynamicSizes().end());
      } else if (llvm::all_of(ST.getShape(), [](int64_t dim) {
                   return dim != ShapedType::kDynamic;
                 })) {
        // Statically shaped values are cached one element per iteration.
      }
      newShape.append(ST.getShape().begin(), ST.getShape().end());
      ET = ST.getElementType();
    }

    auto newType = isa<TensorType>(info.cachedType())
                       ? cast<ShapedType>(RankedTensorType::get(newShape, ET))
                       : cast<ShapedType>(MemRefType::get(newShape, ET));
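
    // Tensor-typed caches thread the widened buffer through the loop as an
    // iteration argument; memref-typed caches write into a buffer allocated
    // in front of the loop.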
    if (isa<TensorType>(newType)) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(forOp);
      Value initValue = tensor::EmptyOp::create(
          rewriter, info.initOp->getLoc(), newType, dynamicDims);

      newOperands.push_back(initValue);
      ++numNewValuePushes; // the widened buffer becomes a loop result

      auto cacheValue = body->addArgument(newType, info.pushOp->getLoc());

      {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(info.pushOp);

        Value newCacheValue;
        if (auto TT = dyn_cast<TensorType>(info.cachedType())) {
          // Tensor pushes become a rank-extending insert_slice at the
          // current iteration index.
          auto shape = TT.getShape();

          SmallVector<int64_t> offsets(shape.size() + 1, 0);
          offsets[0] = ShapedType::kDynamic;

          SmallVector<int64_t> sizes;
          sizes.reserve(shape.size() + 1);
          sizes.push_back(1);
          sizes.append(shape.begin(), shape.end());

          SmallVector<int64_t> strides(shape.size() + 1, 1);

          newCacheValue = tensor::InsertSliceOp::create(
              rewriter, info.pushOp->getLoc(), info.pushOp.getValue(),
              cacheValue, inductionVariable, ValueRange(), ValueRange(),
              rewriter.getDenseI64ArrayAttr(offsets),
              rewriter.getDenseI64ArrayAttr(sizes),
              rewriter.getDenseI64ArrayAttr(strides));
        } else {
          newCacheValue = tensor::InsertOp::create(
              rewriter, info.pushOp->getLoc(), info.pushOp.getValue(),
              cacheValue, inductionVariable);
        }

        term->insertOperands(term->getNumOperands(),
                             ValueRange(newCacheValue));
      }
    } else {
      Value initValue;
      {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(forOp);
        // MemRef caches write into one wide buffer allocated before the loop.
        initValue =
            memref::AllocOp::create(rewriter, info.initOp->getLoc(),
                                    cast<MemRefType>(newType), dynamicDims);
      }
      newPushValues.push_back(initValue);

      {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(info.pushOp);

        auto MT = dyn_cast<MemRefType>(info.cachedType());
        if (multiDim && MT) {
          // Replace the per-iteration allocation with a rank-reduced subview
          // into the wide buffer.
          auto memref = info.pushOp.getValue();
          auto shape = MT.getShape();

          SmallVector<int64_t> offsets(newShape.size(), 0);
          SmallVector<int64_t> sizes;
          for (auto [i, _] : llvm::enumerate(inductionVariable)) {
            offsets[i] = ShapedType::kDynamic;
            sizes.push_back(1);
          }

          SmallVector<Value> dynSizes;
          for (size_t i = inductionVariable.size(); i < dynamicDims.size();
               ++i) {
            dynSizes.push_back(dynamicDims[i]);
          }

          sizes.append(shape.begin(), shape.end());

          SmallVector<int64_t> strides(newShape.size(), 1);

          auto RT = memref::SubViewOp::inferRankReducedResultType(
              MT.getShape(), cast<MemRefType>(initValue.getType()), offsets,
              sizes, strides);

          rewriter.setInsertionPoint(memref.getDefiningOp());
          rewriter.replaceOpWithNewOp<memref::SubViewOp>(
              memref.getDefiningOp(), RT, initValue, inductionVariable,
              dynSizes, ValueRange(),
              rewriter.getDenseI64ArrayAttr(offsets),
              rewriter.getDenseI64ArrayAttr(sizes),
              rewriter.getDenseI64ArrayAttr(strides));
        } else {
          memref::StoreOp::create(rewriter, info.pushOp->getLoc(),
                                  info.pushOp.getValue(), initValue,
                                  inductionVariable);
        }
      }
    }
  }
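
  // Rebuild the forward loop with the extra iteration arguments, then pick
  // up the tensor cache buffers from its new results.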
  auto numInitArgs = FinalClass::getInits(forOp).size();
  rewriter.setInsertionPoint(forOp);

  forOp = FinalClass::replaceWithNewOperands(rewriter, forOp, newOperands);

  for (size_t i = 0; i < numNewValuePushes; ++i)
    newPushValues.push_back(
        forOp->getResult(forOp->getNumResults() - numNewValuePushes + i));

  rewriter.setInsertionPointAfter(forOp);
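
  // Write the loop-carried gradient values back into their gradient cells.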
  unsigned resultIdx = numInitArgs;
  for (auto grad : updatedGradients) {
    Value incoming = forOp->getResult(resultIdx);
    ++resultIdx;

    Value outgoing = incoming;
    if (postAdd) {
      // Accumulate into the existing gradient instead of overwriting it.
      auto T = cast<AutoDiffTypeInterface>(incoming.getType());
      Value current = enzyme::GetOp::create(rewriter, grad.getLoc(), T, grad);
      outgoing = T.createAddOp(rewriter, grad.getLoc(), incoming, current);
    }

    enzyme::SetOp::create(rewriter, grad.getLoc(), grad, outgoing);
  }
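
  // Now mirror the caches on the reverse loop: pop the widened buffer once
  // in front of it and read the per-iteration slice at the reversed index.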
  int pushedValueIdx = 0;

  if (!otherInductionVariable.empty()) {
    rewriter.restoreInsertionPoint(revIP);
  } else {
    rewriter.setInsertionPointToStart(otherForOp.getBody());
  }
  for (auto &info : caches) {
    assert(
        forOp.getRegion().isAncestor(info.pushedValue().getParentRegion()));

    Value cache = info.initOp.getResult();

    if (revNumIters.empty()) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(otherForOp);
      revNumIters = FinalClass::getDimensionBounds(rewriter, otherForOp);
      for (auto &&[rev, fwd] : llvm::zip_equal(revNumIters, fwdNumIters)) {
        if (!fwd.vval && rev.vval) {
          // Assumed from context: adopt the forward loop's static bound when
          // the reverse bound is only known dynamically.
          rev = fwd;
        }
      }
    }

    if (!otherInductionVariable.empty() && reversedIndex.empty()) {
      reversedIndex = FinalClass::computeReversedIndices(
          rewriter, otherForOp, otherInductionVariable, revNumIters);
    }
    // Synthesize a canonical induction variable for the reverse loop if it
    // has none, operating on the reverse loop's own body and terminator.
    if (otherInductionVariable.empty()) {
      Value zero = arith::ConstantOp::create(rewriter, otherForOp->getLoc(),
                                             rewriter.getIndexAttr(0));
      SmallVector<Value> newOperands =
          llvm::to_vector(FinalClass::getInits(otherForOp));
      newOperands.push_back(zero);

      Block *revBody = otherForOp.getBody();
      otherInductionVariable = {
          revBody->addArgument(zero.getType(), otherForOp->getLoc())};
      {
        OpBuilder::InsertionGuard guard(rewriter);
        Operation *revTerm = revBody->getTerminator();
        rewriter.setInsertionPoint(revTerm);

        auto one = arith::ConstantOp::create(rewriter, otherForOp->getLoc(),
                                             rewriter.getIndexAttr(1));
        auto newInductionVar = arith::AddIOp::create(
            rewriter, otherForOp->getLoc(), otherInductionVariable[0], one);
        revTerm->insertOperands(revTerm->getNumOperands(),
                                ValueRange(newInductionVar));
      }
      {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(otherForOp);
        otherForOp = FinalClass::replaceWithNewOperands(rewriter, otherForOp,
                                                        newOperands);
      }
      reversedIndex = FinalClass::computeReversedIndices(
          rewriter, otherForOp, otherInductionVariable, revNumIters);
    }
    SmallVector<int64_t> newShape;
    for (const auto &dim : revNumIters) {
      if (dim.vval)
        newShape.push_back(mlir::ShapedType::kDynamic);
      else
        newShape.push_back(dim.ival);
    }
    auto ET = info.cachedType();

    bool multiDim = false;
    if (auto ST = dyn_cast<ShapedType>(ET)) {
      auto svOp = info.pushedValue().getDefiningOp<memref::SubViewOp>();
      if (svOp) {
        // The forward pass cached this buffer wholesale through a subview.
        multiDim = true;
      } else if (llvm::all_of(ST.getShape(), [](int64_t dim) {
                   return dim != ShapedType::kDynamic;
                 })) {
        // Statically shaped values were cached one element per iteration.
      }
      newShape.append(ST.getShape().begin(), ST.getShape().end());
      ET = ST.getElementType();
    }

    auto newType = isa<TensorType>(info.cachedType())
                       ? cast<ShapedType>(RankedTensorType::get(newShape, ET))
                       : cast<ShapedType>(MemRefType::get(newShape, ET));

    enzyme::InitOp newInit = ({
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(info.initOp);

      enzyme::InitOp::create(
          rewriter, info.initOp->getLoc(),
          enzyme::CacheType::get(cache.getContext(), newType));
    });
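
    // Reroute the forward push to the widened cache: push the wide buffer
    // once after the forward loop instead of one value per iteration.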
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfter(forOp);
      auto newPush = enzyme::PushOp::create(rewriter, cache.getLoc(),
                                            newInit.getResult(),
                                            newPushValues[pushedValueIdx]);
      rewriter.eraseOp(info.pushOp);
      ++pushedValueIdx;
    }
    {
      OpBuilder::InsertionGuard guard(rewriter);
      // Pop the widened buffer once, in front of the reverse loop.
      rewriter.setInsertionPoint(otherForOp);
      auto popNewValue = enzyme::PopOp::create(rewriter, info.popOp->getLoc(),
                                               newType, newInit.getResult());

      rewriter.setInsertionPoint(info.popOp);

      Value popValue;
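      // Read this iteration's value back out. Sketch of the tensor case:
      //
      //   %all = "enzyme.pop"(%c) : (!enzyme.Cache<tensor<?xf32>>) -> tensor<?xf32>
      //   scf.for %i = ... {
      //     %v = tensor.extract %all[%rev_i] : tensor<?xf32>
      //     ...
      //   }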
      if (isa<TensorType>(newType)) {
        if (auto TT = dyn_cast<TensorType>(info.cachedType())) {
          auto shape = TT.getShape();
          SmallVector<int64_t> offsets(shape.size() + 1, 0);
          offsets[0] = ShapedType::kDynamic;

          SmallVector<int64_t> sizes;
          sizes.reserve(shape.size() + 1);
          sizes.push_back(1);
          sizes.append(shape.begin(), shape.end());

          SmallVector<int64_t> strides(shape.size() + 1, 1);

          popValue = tensor::ExtractSliceOp::create(
                         rewriter, info.popOp->getLoc(), TT, popNewValue,
                         reversedIndex, ValueRange(), ValueRange(),
                         rewriter.getDenseI64ArrayAttr(offsets),
                         rewriter.getDenseI64ArrayAttr(sizes),
                         rewriter.getDenseI64ArrayAttr(strides))
                         .getResult();
        } else {
          popValue = tensor::ExtractOp::create(rewriter, info.popOp->getLoc(),
                                               popNewValue, reversedIndex)
                         .getResult();
        }
      } else {
        auto MT = dyn_cast<MemRefType>(info.cachedType());
        if (multiDim && MT) {
          // Carve this iteration's buffer out of the wide buffer as a
          // rank-reduced subview.
          auto shape = MT.getShape();

          SmallVector<int64_t> offsets(newShape.size(), 0);
          SmallVector<int64_t> sizes;
          for (auto [i, _] : llvm::enumerate(reversedIndex)) {
            offsets[i] = ShapedType::kDynamic;
            sizes.push_back(1);
          }
          sizes.append(shape.begin(), shape.end());

          SmallVector<Value> dynSizes;
          for (size_t i = reversedIndex.size(); i < newShape.size(); ++i) {
            // Recover dynamic extents from the popped wide buffer itself.
            if (newShape[i] == ShapedType::kDynamic)
              dynSizes.push_back(memref::DimOp::create(
                  rewriter, popNewValue.getLoc(), popNewValue,
                  arith::ConstantIndexOp::create(rewriter,
                                                 popNewValue.getLoc(), i)));
          }

          // One stride entry per dimension of the wide buffer.
          SmallVector<int64_t> strides(newShape.size(), 1);

          auto RT = memref::SubViewOp::inferRankReducedResultType(
              MT.getShape(), cast<MemRefType>(popNewValue.getType()), offsets,
              sizes, strides);

          popValue = memref::SubViewOp::create(
              rewriter, info.popOp->getLoc(), RT, popNewValue, reversedIndex,
              dynSizes, ValueRange(),
              rewriter.getDenseI64ArrayAttr(offsets),
              rewriter.getDenseI64ArrayAttr(sizes),
              rewriter.getDenseI64ArrayAttr(strides));

          // The subview aliases the wide buffer, so per-iteration deallocs
          // of the old popped value must go.
          for (auto user :
               llvm::make_early_inc_range(info.popOp.getResult().getUsers())) {
            if (isa<memref::DeallocOp>(user))
              rewriter.eraseOp(user);
          }
        } else {
          popValue = memref::LoadOp::create(rewriter, info.popOp->getLoc(),
                                            popNewValue, reversedIndex);
        }
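
        // The wide pop buffer is consumed entirely inside the reverse loop,
        // so free it right after the loop.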
        rewriter.setInsertionPointAfter(otherForOp);
        memref::DeallocOp::create(rewriter, info.initOp->getLoc(),
                                  popNewValue);
      }

      rewriter.replaceAllUsesWith(info.popOp.getResult(), popValue);
      rewriter.eraseOp(info.popOp);
    }
  }

  return success();
}