Enzyme main
Loading...
Searching...
No Matches
RemovalUtils.h
Go to the documentation of this file.
1//===- RemovalUtils.h - Utilities to remove Enzyme ops -------* C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#pragma once
9
10#include "Dialect/Ops.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/MemRef/IR/MemRef.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/IRMapping.h"
18#include "mlir/Interfaces/LoopLikeInterface.h"
19#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
20
21#include "mlir/Interfaces/FunctionInterfaces.h"
22
23#include "mlir/IR/Matchers.h"
24
25namespace mlir {
26namespace enzyme {
27
28constexpr static llvm::StringLiteral kPreserveCacheAttrName = "preserve_cache";
29
30/// Information about a cache, each cache init should have one corresponding
31/// push and pop.
32struct CacheInfo {
33 enzyme::InitOp initOp;
34 enzyme::PushOp pushOp;
35 enzyme::PopOp popOp;
36
38 initOp = nullptr;
39 pushOp = nullptr;
40 popOp = nullptr;
41 }
42 CacheInfo(Value cache) {
43 initOp = cache.getDefiningOp<enzyme::InitOp>();
44 unsigned nusers = 0;
45 for (auto user : cache.getUsers()) {
46 nusers++;
47 if (!popOp)
48 popOp = dyn_cast<enzyme::PopOp>(user);
49 if (!pushOp)
50 pushOp = dyn_cast<enzyme::PushOp>(user);
51 }
52 (void)nusers;
53 assert(nusers == 2); // TODO: support more uses
54 }
55
56 Value pushedValue() { return pushOp.getValue(); }
57 Type cachedType() {
58 return cast<enzyme::CacheType>(initOp.getResult().getType()).getType();
59 }
60
61 // Pushed values must be the same
62 CacheInfo merge(CacheInfo other, PatternRewriter &rewriter);
63};
64
// Tries to limit the amount of values cache from block `forward` to `reverse`
// using a mincut algorithm and heuristics based on the size of values.
// All pushes must go after `lastFwd`, if non null.
// `caches` is taken by non-const reference and may be updated in place;
// `fwdrevmap` maps forward values to their reverse-block counterparts.
void minCutCache(Block *forward, Block *reverse, SmallVector<CacheInfo> &caches,
                 PatternRewriter &rewriter, const IRMapping &fwdrevmap,
                 Operation *lastFwd = nullptr);
71
// Storage strategy for materialized loop caches: immutable tensors (threaded
// through iter-args) or mutable memrefs. memref is the default; tensor mode is
// opted into via the "enzyme.cache_use_tensor" attribute (see getCacheType).
enum class LoopCacheType { TENSOR, MEMREF };
73
74static LoopCacheType getCacheType(Operation *op) {
76 if (op->hasAttr("enzyme.cache_use_tensor")) {
77 cacheType = LoopCacheType::TENSOR;
78 }
79 return cacheType;
80}
81
82static bool hasMinCut(Operation *op) {
83 return !op->hasAttr("enzyme.disable_mincut");
84}
85
86static bool hasLICM(Operation *op) {
87 return !op->hasAttr("enzyme.disable_licm");
88}
89
90template <typename FinalClass, typename OpName>
92 : public EnzymeOpsRemoverOpInterface::ExternalModel<FinalClass, OpName> {
93private:
94public:
  // A loop bound that is either a compile-time constant (`ival`) or a runtime
  // SSA value (`vval`). A null `vval` is the discriminant for the static case.
  struct IntOrValue {
    size_t ival;      // static bound; meaningful only when vval is null
    mlir::Value vval; // dynamic bound; null in the static case
    // Dynamic bound: ival is zeroed and unused.
    IntOrValue(mlir::Value vval) : ival(0), vval(vval) {}
    // Static bound: vval stays null to mark this variant.
    IntOrValue(size_t ival) : ival(ival), vval(nullptr) {}
  };
101
102 static bool Equivalent(Value lhs, Value rhs) {
103 if (lhs == rhs)
104 return true;
105 Attribute la, ra;
106 if (matchPattern(lhs, m_Constant(&la)) &&
107 matchPattern(rhs, m_Constant(&ra)))
108 return true;
109 auto pop = rhs.getDefiningOp<enzyme::PopOp>();
110 if (!pop)
111 return false;
112 auto init = pop.getOperand().getDefiningOp<enzyme::InitOp>();
113 if (!init)
114 return false;
115 for (auto u : init->getResult(0).getUsers()) {
116 if (u == pop)
117 continue;
118 auto push = dyn_cast<enzyme::PushOp>(u);
119 if (!push)
120 continue;
121 return push.getValue() == lhs;
122 }
123 return false;
124 }
125
  /// For each (bound, iv) pair, emits IR computing the reversed index
  /// `(bound - 1) - iv`, so that iteration i of the forward loop maps to
  /// iteration (N-1-i) of the reverse loop. Constants are materialized with
  /// the induction variable's type (index or integer).
  static llvm::SmallVector<mlir::Value>
  computeReversedIndices(PatternRewriter &rewriter, OpName op,
                         llvm::ArrayRef<mlir::Value> otherInductionVariable,
                         llvm::ArrayRef<IntOrValue> bounds) {
    llvm::SmallVector<mlir::Value> results;
    for (auto &&[bound, iv] : llvm::zip_equal(bounds, otherInductionVariable)) {
      Value boundv;
      if (bound.vval) {
        // Dynamic bound: compute bound - 1 at runtime.
        Value c1;
        if (iv.getType().isIndex())
          c1 = arith::ConstantIndexOp::create(rewriter, op->getLoc(), 1);
        else
          c1 = arith::ConstantIntOp::create(rewriter, op->getLoc(),
                                            iv.getType(), 1);
        boundv = arith::SubIOp::create(rewriter, op->getLoc(), bound.vval, c1);
      } else {
        // Static bound: fold bound - 1 into a constant.
        if (iv.getType().isIndex())
          boundv = arith::ConstantIndexOp::create(rewriter, op->getLoc(),
                                                  bound.ival - 1);
        else
          boundv = arith::ConstantIntOp::create(rewriter, op->getLoc(),
                                                iv.getType(), bound.ival - 1);
      }
      // reversed = (bound - 1) - iv
      Value result = arith::SubIOp::create(rewriter, op->getLoc(), boundv, iv);
      results.push_back(result);
    }
    return results;
  }
154
  /// Rewrites enzyme.get/set (gradients) and enzyme.push/pop (caches) used
  /// across a forward loop `op` and its reverse counterpart into plain loop
  /// iter-args plus tensor- or memref-backed caches, then erases the enzyme
  /// ops. The code below mutates IR through `rewriter` in a strict order;
  /// it is intentionally left byte-identical here (documentation only).
  LogicalResult removeEnzymeOps(Operation *op,
                                PatternRewriter &rewriter) const {
    auto forOp = cast<OpName>(op);

    if (hasLICM(op)) { // perform licm
      auto loopLike = dyn_cast<LoopLikeOpInterface>(op);

      if (loopLike) {
        mlir::moveLoopInvariantCode(loopLike);
      }
    }

    OpName otherForOp = nullptr; // where caches pops are

    // There is support for two push/pop removal modes, one is using immutable
    // tensors, the other uses memrefs. memref is the default, but tensor can be
    // enabled with enzyme.cache_use_tensor
    auto cacheType = getCacheType(op);

    // Gradients whose values need to be passed as iteration variables.
    llvm::SetVector<Value> updatedGradients;

    llvm::MapVector<Value, CacheInfo> cachesMap;
    SmallVector<CacheInfo> toDelete;

    Block *body = forOp.getBody();

    // ---- Stage 1: scan the body for gradient sets and cache pushes. ----
    // Note: the lambda parameter `op` shadows the method parameter.
    body->walk([&](Operation *op) {
      if (auto setOp = dyn_cast<enzyme::SetOp>(op)) {
        // Only gradients defined outside this region become iter-args.
        if (!body->getParent()->isAncestor(
                setOp.getGradient().getDefiningOp()->getParentRegion()))
          updatedGradients.insert(setOp.getGradient());
      }

      if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
        CacheInfo info(pushOp.getCache());

        Value pushedValue = info.pushedValue();
        // Deduplicate caches that push the same value.
        if (cachesMap.contains(pushedValue)) {
          info = info.merge(cachesMap.lookup(pushedValue), rewriter);
        }

        // Push dominating its pop in the same block: trivially forwardable,
        // queue for deletion instead of caching.
        if (info.pushOp->getBlock() == body && info.popOp->getBlock() == body &&
            info.pushOp->isBeforeInBlock(info.popOp)) {
          toDelete.push_back(info);
          return;
        }
        cachesMap[pushedValue] = info;

        // Remember the loop that contains the pops (the reverse loop).
        if (isa<OpName>(info.popOp->getParentOp())) {
          otherForOp = cast<OpName>(info.popOp->getParentOp());
        }
      }
    });

    // Erase trivially-forwardable push/pop pairs (pop uses become the pushed
    // value directly).
    while (!toDelete.empty()) {
      CacheInfo info = toDelete.pop_back_val();
      rewriter.replaceAllUsesWith(info.popOp.getResult(),
                                  info.pushOp.getValue());
      rewriter.eraseOp(info.pushOp);
      rewriter.eraseOp(info.popOp);
      rewriter.eraseOp(info.initOp);
    }

    SmallVector<CacheInfo> caches0 =
        llvm::map_to_vector(cachesMap, [](auto p) { return std::get<1>(p); });

    // Working copy; minCutCache may rewrite this vector in place below.
    SmallVector<CacheInfo> caches = caches0;

    // nothing to do
    if (updatedGradients.empty() && caches.empty())
      return success();

    DenseMap<Value, llvm::SmallVector<Operation *>> updatedGradientUsers;

    // ---- Stage 2: collect gradient users in program order; hoist gets of
    // gradients that are never set in this loop out of the body. ----
    for (auto &it : llvm::make_early_inc_range(*body)) {
      Operation *op = &it;

      auto getOp = dyn_cast<enzyme::GetOp>(op);

      if (getOp && updatedGradients.contains(getOp.getGradient())) {
        updatedGradientUsers[getOp.getGradient()].push_back(getOp);
      } else if (auto setOp = dyn_cast<enzyme::SetOp>(op)) {
        updatedGradientUsers[setOp.getGradient()].push_back(setOp);
      }

      if (!getOp || updatedGradients.contains(getOp.getGradient()))
        continue;

      // Read-only gradient: replace in-loop gets with a single get placed at
      // the current (outer) insertion point.
      auto outerGet = enzyme::GetOp::create(rewriter, getOp->getLoc(),
                                            getOp.getResult().getType(),
                                            getOp.getGradient());

      rewriter.replaceAllUsesWith(getOp.getResult(), outerGet.getResult());
      rewriter.eraseOp(getOp);
    }

    // postadd means that the loops init is zero and that the result
    // is added with the previous grad after the loop.
    bool postAdd = FinalClass::mustPostAdd(forOp);

    auto term = body->getTerminator();

    // ---- Stage 3: turn each updated gradient into a loop iter-arg and
    // mem2reg its in-body gets/sets. ----
    SmallVector<Value> newOperands(FinalClass::getInits(forOp));
    for (auto grad : updatedGradients) {
      auto Ty = cast<enzyme::GradientType>(grad.getType()).getBasetype();

      Value newInit;

      if (!postAdd) {
        newInit = enzyme::GetOp::create(rewriter, grad.getLoc(), Ty, grad);
      } else {
        newInit = cast<AutoDiffTypeInterface>(Ty).createNullValue(
            rewriter, grad.getLoc());
      }

      newOperands.push_back(newInit);

      // here we do a primitive form of mem2reg within the loop. We have a
      // sorted (by instruction number) list of all users of the
      // instruction.
      Value val = FinalClass::initialValueInBlock(rewriter, body, grad);
      for (auto user : updatedGradientUsers[grad]) {
        if (auto getOp = dyn_cast<enzyme::GetOp>(user)) {
          rewriter.replaceOp(getOp, val);
        } else {
          auto setOp = cast<enzyme::SetOp>(user);
          val = setOp.getValue();
          rewriter.eraseOp(setOp);
        }
      }

      // Yield the last written value so it flows to the next iteration.
      term->insertOperands(term->getNumOperands(), ValueRange(val));
    }

    IRMapping fwdrevmap;
    bool mincut = false;

    // [0,..., N - 1] counter
    SmallVector<Value> inductionVariable;
    SmallVector<Value> otherInductionVariable;
    SmallVector<Value> reversedIndex;

    SmallVector<IntOrValue> revNumIters;
    SmallVector<IntOrValue> fwdNumIters;

    // Always true here (fwdNumIters was just default-constructed); kept for
    // symmetry with the revNumIters lazy-initialization pattern below.
    if (!fwdNumIters.size()) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(forOp);
      fwdNumIters = FinalClass::getDimensionBounds(rewriter, forOp);
    }

    // ---- Stage 4: set up canonical IVs and the fwd->rev value map needed by
    // the mincut pass. ----
    Operation *lastFwd = nullptr;
    if (caches.size()) {
      rewriter.setInsertionPointToStart(forOp.getBody());
      inductionVariable = FinalClass::getCanonicalLoopIVs(rewriter, forOp);
      // Any IV-materializing ops inserted at the block start delimit where
      // pushes may be placed.
      if (rewriter.getInsertionPoint() != forOp.getBody()->begin()) {
        lastFwd = rewriter.getInsertionPoint()->getPrevNode();
      }

      rewriter.setInsertionPointToStart(otherForOp.getBody());
      otherInductionVariable =
          FinalClass::getCanonicalLoopIVs(rewriter, otherForOp);

      // The reverse iteration count may not be known at this point, as it may
      // be cached via a push/pop, use the fwd count in that case.
      if (!revNumIters.size()) {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(otherForOp);
        revNumIters = FinalClass::getDimensionBounds(rewriter, otherForOp);
        for (auto &&[rev, fwd] : llvm::zip_equal(revNumIters, fwdNumIters)) {
          if (!fwd.vval && rev.vval) {
            rev.vval = nullptr;
            rev.ival = fwd.ival;
          }
        }
      }

      reversedIndex = FinalClass::computeReversedIndices(
          rewriter, otherForOp, otherInductionVariable, revNumIters);
      fwdrevmap = FinalClass::createArgumentMap(
          rewriter, forOp, inductionVariable, otherForOp, reversedIndex);
      // Protect IV-defining ops from erasure while mincut runs.
      for (auto v : inductionVariable) {
        if (auto op = v.getDefiningOp()) {
          op->setAttr("enzyme.no_erase", rewriter.getUnitAttr());
        }
      }
      for (auto v : otherInductionVariable) {
        if (auto op = v.getDefiningOp()) {
          op->setAttr("enzyme.no_erase", rewriter.getUnitAttr());
        }
      }
    }

    if (hasMinCut(forOp) && caches.size()) {
      mincut = true;
      mlir::enzyme::minCutCache(forOp.getBody(), otherForOp.getBody(), caches,
                                rewriter, fwdrevmap, lastFwd);
    }
    // Drop the temporary protection markers.
    for (auto v : inductionVariable) {
      if (auto op = v.getDefiningOp()) {
        op->removeAttr("enzyme.no_erase");
      }
    }
    for (auto v : otherInductionVariable) {
      if (auto op = v.getDefiningOp()) {
        op->removeAttr("enzyme.no_erase");
      }
    }
    // Remember where reverse-side rewrites should resume later.
    auto revIP = rewriter.saveInsertionPoint();

    SmallVector<Value> newPushValues;

    unsigned numNewValuePushes = 0;

    // ---- Stage 5: materialize each cache as a tensor iter-arg or an
    // allocated memref, indexed by the forward counter. ----
    if (lastFwd)
      rewriter.setInsertionPointAfter(lastFwd);
    else
      rewriter.setInsertionPointToStart(forOp.getBody());
    for (auto &info : caches) {

      Value pushedValue = info.pushedValue();
      if (mincut)
        assert(forOp.getRegion().isAncestor(pushedValue.getParentRegion()));

      // Otherwise, add a new variable to keep track.
      if (!inductionVariable.size()) {
        Value zero = arith::ConstantOp::create(rewriter, forOp->getLoc(),
                                               rewriter.getIndexAttr(0));
        newOperands.push_back(zero);

        inductionVariable = {
            body->addArgument(zero.getType(), forOp->getLoc())};
        {
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(term);

          auto one = arith::ConstantOp::create(rewriter, forOp->getLoc(),
                                               rewriter.getIndexAttr(1));
          auto newInductionVar = arith::AddIOp::create(
              rewriter, forOp->getLoc(), inductionVariable[0], one);
          term->insertOperands(term->getNumOperands(),
                               ValueRange(newInductionVar));
        }
      }

      // Leading cache dimensions: one per loop dimension (dynamic where the
      // trip count is an SSA value).
      SmallVector<int64_t> newShape;
      SmallVector<Value> dynamicDims;
      for (const auto &dim : fwdNumIters) {
        if (dim.vval) {
          newShape.push_back(mlir::ShapedType::kDynamic);
          dynamicDims.push_back(dim.vval);
        } else {
          newShape.push_back(dim.ival);
        }
      }

      auto ET = info.cachedType();
      ShapedType NT; // NOTE(review): unused here.

      // multiDim: the cached value is itself shaped and can be flattened into
      // extra trailing dimensions of the cache.
      bool multiDim = false;
      if (auto ST = dyn_cast<ShapedType>(ET)) {
        auto allocOp = pushedValue.getDefiningOp<memref::AllocOp>();
        if (cacheType == LoopCacheType::MEMREF && allocOp &&
            allocOp.getSymbolOperands().empty() &&
            llvm::all_of(allocOp.getDynamicSizes(), [&](Value dynSize) {
              return !forOp.getRegion().isAncestor(dynSize.getParentRegion());
            })) {
          multiDim = true;

          dynamicDims.append(allocOp.getDynamicSizes().begin(),
                             allocOp.getDynamicSizes().end());

        } else if (llvm::all_of(ST.getShape(), [](int64_t dim) {
                     return dim != ShapedType::kDynamic;
                   })) {
          multiDim = true;
        }

        if (multiDim) {
          newShape.append(ST.getShape().begin(), ST.getShape().end());
          ET = ST.getElementType();
        }
      }

      auto newType = cacheType == LoopCacheType::TENSOR
                         ? cast<ShapedType>(RankedTensorType::get(newShape, ET))
                         : cast<ShapedType>(MemRefType::get(newShape, ET));

      if (cacheType == LoopCacheType::TENSOR) {
        // Tensor mode: thread an empty tensor through the loop and insert the
        // pushed value at [iv, ...] each iteration.
        {
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(forOp);
          Value initValue = tensor::EmptyOp::create(
              rewriter, info.initOp->getLoc(), newType, dynamicDims);

          newOperands.push_back(initValue);
        }

        auto cacheValue = body->addArgument(newType, info.pushOp->getLoc());

        {
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(info.pushOp);

          Value newCacheValue;
          if (auto TT = dyn_cast<TensorType>(info.cachedType())) {
            auto shape = TT.getShape();

            SmallVector<int64_t> offsets(shape.size() + 1, 0);
            offsets[0] = ShapedType::kDynamic;

            SmallVector<int64_t> sizes;
            sizes.reserve(shape.size() + 1);
            sizes.push_back(1);
            sizes.append(shape.begin(), shape.end());

            SmallVector<int64_t> strides(shape.size() + 1, 1);

            newCacheValue = tensor::InsertSliceOp::create(
                rewriter, info.pushOp->getLoc(), info.pushOp.getValue(),
                cacheValue, inductionVariable, ValueRange(), ValueRange(),
                rewriter.getDenseI64ArrayAttr(offsets),
                rewriter.getDenseI64ArrayAttr(sizes),
                rewriter.getDenseI64ArrayAttr(strides));
          } else {
            newCacheValue = tensor::InsertOp::create(
                rewriter, info.pushOp->getLoc(), info.pushOp.getValue(),
                cacheValue, inductionVariable);
          }

          term->insertOperands(term->getNumOperands(),
                               ValueRange(newCacheValue));

          numNewValuePushes++;
        }
      } else if (cacheType == LoopCacheType::MEMREF) {
        // Memref mode: allocate one buffer before the loop and store/subview
        // into it at [iv, ...] each iteration.
        Value initValue;
        {
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(forOp);
          initValue =
              memref::AllocOp::create(rewriter, info.initOp->getLoc(),
                                      cast<MemRefType>(newType), dynamicDims);
          newPushValues.push_back(initValue);
        }

        {
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(info.pushOp);

          auto MT = dyn_cast<MemRefType>(info.cachedType());
          if (multiDim && MT) {
            // Replace the in-loop alloc with a rank-reduced subview of the
            // big cache buffer, so writes land directly in the cache.
            auto memref = info.pushOp.getValue();
            auto shape = MT.getShape();

            SmallVector<int64_t> offsets(newShape.size(), 0);
            SmallVector<int64_t> sizes;
            for (auto [i, _] : llvm::enumerate(inductionVariable)) {
              offsets[i] = ShapedType::kDynamic;
              sizes.push_back(1);
            }

            SmallVector<Value> dynSizes;
            for (size_t i = inductionVariable.size(); i < dynamicDims.size();
                 ++i) {
              dynSizes.push_back(dynamicDims[i]);
            }

            sizes.append(shape.begin(), shape.end());

            SmallVector<int64_t> strides(newShape.size(), 1);

            auto RT = memref::SubViewOp::inferRankReducedResultType(
                MT.getShape(), cast<MemRefType>(initValue.getType()), offsets,
                sizes, strides);

            rewriter.setInsertionPoint(memref.getDefiningOp());
            rewriter.replaceOpWithNewOp<memref::SubViewOp>(
                memref.getDefiningOp(), RT, initValue,
                /*offsets*/ inductionVariable,
                /*sizes*/ dynSizes,
                /*strides*/ ValueRange(),
                /*static_offsets*/ rewriter.getDenseI64ArrayAttr(offsets),
                /*static_sizes*/ rewriter.getDenseI64ArrayAttr(sizes),
                /*static_strides*/ rewriter.getDenseI64ArrayAttr(strides));

          } else {
            memref::StoreOp::create(rewriter, info.pushOp->getLoc(),
                                    info.pushOp.getValue(), initValue,
                                    inductionVariable);
          }
        }
      }
    }

    // ---- Stage 6: rebuild the forward loop with the extended operand list
    // and wire its results back to gradients. ----
    auto numInitArgs = FinalClass::getInits(forOp).size();
    rewriter.setInsertionPoint(forOp);

    forOp = FinalClass::replaceWithNewOperands(rewriter, forOp, newOperands);
    if (cacheType == LoopCacheType::TENSOR) {
      // The freshly yielded cache tensors are the trailing results.
      for (size_t i = 0; i < numNewValuePushes; ++i)
        newPushValues.push_back(
            forOp->getResult(forOp->getNumResults() - numNewValuePushes + i));
    }

    rewriter.setInsertionPointAfter(forOp);

    unsigned resultIdx = numInitArgs;
    for (auto grad : updatedGradients) {
      // set the updated gradient after the new for op.

      Value incoming = forOp->getResult(resultIdx);
      Value outgoing;
      if (!postAdd) {
        outgoing = incoming;
      } else {
        auto T = cast<AutoDiffTypeInterface>(incoming.getType());
        Value current = enzyme::GetOp::create(rewriter, grad.getLoc(), T, grad);
        outgoing = T.createAddOp(rewriter, grad.getLoc(), incoming, current);
      }
      enzyme::SetOp::create(rewriter, grad.getLoc(), grad, outgoing);
      ++resultIdx;
    }

    int pushedValueIdx = 0;

    // ---- Stage 7: rewrite the reverse-side pops to read from the new
    // cache values at the reversed indices. ----
    if (caches.size()) {
      if (otherInductionVariable.size()) {
        rewriter.restoreInsertionPoint(revIP);
      } else
        rewriter.setInsertionPointToStart(otherForOp.getBody());
    }
    for (auto &info : caches) {
      if (mincut)
        assert(
            forOp.getRegion().isAncestor(info.pushedValue().getParentRegion()));

      Value cache = info.initOp.getResult();

      // The reverse iteration count may not be known at this point, as it may
      // be cached via a push/pop, use the fwd count in that case.
      if (!revNumIters.size()) {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(otherForOp);
        revNumIters = FinalClass::getDimensionBounds(rewriter, otherForOp);
        for (auto &&[rev, fwd] : llvm::zip_equal(revNumIters, fwdNumIters)) {
          if (!fwd.vval && rev.vval) {
            rev.vval = nullptr;
            rev.ival = fwd.ival;
          }
        }
      }

      // First, try to get canonical vars from looking up directly
      if (otherInductionVariable.size() && !reversedIndex.size()) {
        reversedIndex = FinalClass::computeReversedIndices(
            rewriter, otherForOp, otherInductionVariable, revNumIters);
      }

      // Otherwise, add a new variable to keep track.
      if (!otherInductionVariable.size()) {
        Value zero = arith::ConstantOp::create(rewriter, otherForOp->getLoc(),
                                               rewriter.getIndexAttr(0));
        SmallVector<Value> newOperands =
            llvm::to_vector(FinalClass::getInits(otherForOp));
        newOperands.push_back(zero);

        otherInductionVariable = {
            body->addArgument(zero.getType(), otherForOp->getLoc())};
        {
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(term);

          auto one = arith::ConstantOp::create(rewriter, forOp->getLoc(),
                                               rewriter.getIndexAttr(1));
          auto newInductionVar = arith::AddIOp::create(
              rewriter, forOp->getLoc(), otherInductionVariable[0], one);
          term->insertOperands(term->getNumOperands(),
                               ValueRange(newInductionVar));
        }

        {
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(otherForOp);
          otherForOp = FinalClass::replaceWithNewOperands(rewriter, otherForOp,
                                                          newOperands);
        }

        reversedIndex = FinalClass::computeReversedIndices(
            rewriter, otherForOp, otherInductionVariable, revNumIters);
      }

      // Recompute the cache type on the reverse side (shape from the reverse
      // trip counts).
      SmallVector<int64_t> newShape;
      for (const auto &dim : revNumIters) {
        if (dim.vval) {
          newShape.push_back(mlir::ShapedType::kDynamic);
        } else {
          newShape.push_back(dim.ival);
        }
      }

      auto ET = info.cachedType();
      ShapedType NT; // NOTE(review): unused here.

      bool multiDim = false;
      if (auto ST = dyn_cast<ShapedType>(ET)) {
        auto svOp = info.pushedValue().getDefiningOp<memref::SubViewOp>();
        if (cacheType == LoopCacheType::MEMREF && svOp) {
          multiDim = true;
        } else if (llvm::all_of(ST.getShape(), [](int64_t dim) {
                     return dim != ShapedType::kDynamic;
                   })) {
          multiDim = true;
        }

        if (multiDim) {
          newShape.append(ST.getShape().begin(), ST.getShape().end());
          ET = ST.getElementType();
        }
      }

      auto newType = cacheType == LoopCacheType::TENSOR
                         ? cast<ShapedType>(RankedTensorType::get(newShape, ET))
                         : cast<ShapedType>(MemRefType::get(newShape, ET));
      // Replace the old scalar cache with one carrying the whole buffer:
      // a single push after the forward loop, a single pop before the
      // reverse loop. (GNU statement-expressions used for scoped guards.)
      enzyme::InitOp newInit = ({
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(info.initOp);

        enzyme::InitOp::create(
            rewriter, info.initOp->getLoc(),
            enzyme::CacheType::get(cache.getContext(), newType));
      });
      info.pushOp = ({
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointAfter(forOp);
        auto newPush = enzyme::PushOp::create(rewriter, cache.getLoc(),
                                              newInit.getResult(),
                                              newPushValues[pushedValueIdx]);
        rewriter.eraseOp(info.pushOp);
        newPush;
      });

      pushedValueIdx++;

      OpBuilder::InsertionGuard guard(rewriter);

      rewriter.setInsertionPoint(otherForOp);

      auto popNewValue = enzyme::PopOp::create(rewriter, info.popOp->getLoc(),
                                               newType, newInit.getResult());

      rewriter.setInsertionPoint(info.popOp);

      Value popValue;
      if (cacheType == LoopCacheType::TENSOR) {
        if (auto TT = dyn_cast<TensorType>(info.cachedType())) {
          auto shape = TT.getShape();
          SmallVector<int64_t> offsets(shape.size() + 1, 0);
          offsets[0] = ShapedType::kDynamic;

          SmallVector<int64_t> sizes;
          sizes.reserve(shape.size() + 1);
          sizes.push_back(1);
          sizes.append(shape.begin(), shape.end());

          SmallVector<int64_t> strides(shape.size() + 1, 1);

          popValue = tensor::ExtractSliceOp::create(
                         rewriter, info.popOp->getLoc(), TT, popNewValue,
                         reversedIndex, ValueRange(), ValueRange(),
                         rewriter.getDenseI64ArrayAttr(offsets),
                         rewriter.getDenseI64ArrayAttr(sizes),
                         rewriter.getDenseI64ArrayAttr(strides))
                         .getResult();
        } else {
          popValue = tensor::ExtractOp::create(rewriter, info.popOp->getLoc(),
                                               popNewValue, reversedIndex)
                         .getResult();
        }
      } else if (cacheType == LoopCacheType::MEMREF) {

        auto MT = dyn_cast<MemRefType>(info.cachedType());
        if (multiDim && MT) {
          auto shape = MT.getShape();

          SmallVector<int64_t> offsets(newShape.size(), 0);
          SmallVector<int64_t> sizes;
          for (auto [i, _] : llvm::enumerate(reversedIndex)) {
            offsets[i] = ShapedType::kDynamic;
            sizes.push_back(1);
          }

          sizes.append(shape.begin(), shape.end());

          SmallVector<Value> dynSizes;
          for (size_t i = reversedIndex.size(); i < newShape.size(); ++i) {
            // we use memref.dim here to know the size, hopefully further
            // optimization/canonicalizations can just forward the right size
            // here.
            if (newShape[i] == ShapedType::kDynamic)
              dynSizes.push_back(memref::DimOp::create(
                  rewriter, popNewValue.getLoc(), popNewValue,
                  arith::ConstantIndexOp::create(rewriter, popNewValue.getLoc(),
                                                 i)));
          }

          SmallVector<int64_t> strides(shape.size() + 1, 1);

          auto RT = memref::SubViewOp::inferRankReducedResultType(
              MT.getShape(), cast<MemRefType>(popNewValue.getType()), offsets,
              sizes, strides);

          popValue = memref::SubViewOp::create(
              rewriter, info.popOp->getLoc(), RT, popNewValue,
              /*offsets*/ reversedIndex,
              /*sizes*/ dynSizes,
              /*strides*/ ValueRange(),
              /*static_offsets*/ rewriter.getDenseI64ArrayAttr(offsets),
              /*static_sizes*/ rewriter.getDenseI64ArrayAttr(sizes),
              /*static_strides*/ rewriter.getDenseI64ArrayAttr(strides));

          // Deallocs of the old per-iteration buffer are now invalid: the
          // subview aliases the long-lived cache.
          for (auto user :
               llvm::make_early_inc_range(info.popOp.getResult().getUsers())) {
            if (isa<memref::DeallocOp>(user))
              rewriter.eraseOp(user);
          }
        } else {
          popValue = memref::LoadOp::create(rewriter, info.popOp->getLoc(),
                                            popNewValue, reversedIndex);
        }

        // this memref was allocated on push, dealloc it
        rewriter.setInsertionPointAfter(otherForOp);
        memref::DeallocOp::create(rewriter, info.initOp->getLoc(), popNewValue);
      }

      rewriter.replaceAllUsesWith(info.popOp.getResult(), popValue);
      rewriter.eraseOp(info.popOp);
    }

    return success();
  }
798};
799
800// All values defined in fwd should have no use outside this block
801// therefore we can localize their differential to only the rev block in order
802// to simplify the work of the remove-unnecessary-enzyme-ops pass.
803//
804// The builder insertion point should be at the start of the corresponding rev
805// block.
806void localizeGradients(OpBuilder &builder, MGradientUtilsReverse *gutils,
807 Block *fwd);
808
809void removalBlockExplore(Block *block, IRMapping &mapping,
810 PatternRewriter &rewriter,
811 llvm::SetVector<Value> &gradients,
812 llvm::MapVector<Value, CacheInfo> &caches);
813
814template <typename FinalClass, typename OpName>
816 : public EnzymeOpsRemoverOpInterface::ExternalModel<FinalClass, OpName> {
  /// Hoists enzyme.get/set and enzyme.push/pop out of the two branches of an
  /// if-like op `op` by yielding the affected values from both branches and
  /// re-emitting the gradient sets / cache pushes+pops outside the op.
  /// Documentation-only pass over intricate rewriter logic; code unchanged.
  LogicalResult removeEnzymeOps(Operation *op,
                                PatternRewriter &rewriter) const {
    auto ifOp = cast<OpName>(op);
    // Gradients:
    //
    // For each set in a branch, we instead set after the if by using the
    // return value.
    //
    // if %pred {
    //   enzyme.set %grad, %2
    // } else {
    // }
    //
    // %0 = enzyme.get %grad
    // %1 = if %pred {
    //   return %2
    // } else {
    //   return %0
    // }
    // enzyme.set %grad, %1
    //
    // For each get in a branch, we get before and use that instead of the
    // get.

    // Caches:
    //
    // For each push, push after the if instead add a dummy value in the other
    // branch.
    //
    // For each pop in the reverse if, pop before the if instead of inside a
    // branch.

    Block *trueBlock = FinalClass::getThenBlock(ifOp, rewriter),
          *falseBlock = FinalClass::getElseBlock(ifOp, rewriter);

    // Gradients whose value is set in either branches.
    llvm::SetVector<Value> gradients;

    // We assume pushes are exclusive.
    llvm::MapVector<Value, CacheInfo> pushedCaches;

    // Grad to value
    IRMapping trueMapping, falseMapping;

    removalBlockExplore(trueBlock, trueMapping, rewriter, gradients,
                        pushedCaches);
    removalBlockExplore(falseBlock, falseMapping, rewriter, gradients,
                        pushedCaches);

    if (gradients.empty() && pushedCaches.empty())
      return success();
    // The preserve_cache attribute opts this op out of cache rewriting.
    bool removeCaches = !op->hasAttr(kPreserveCacheAttrName);

    Operation *trueTerm = trueBlock->getTerminator();
    Operation *falseTerm = falseBlock->getTerminator();

    // Yield from both branches the value each gradient should have after the
    // if; a branch that did not set the gradient yields its current value
    // (read via a get placed at the current insertion point).
    for (auto grad : gradients) {
      auto trueValue = trueMapping.lookupOrNull(grad);
      if (!trueValue) {
        trueValue = enzyme::GetOp::create(
            rewriter, grad.getLoc(),
            cast<enzyme::GradientType>(grad.getType()).getBasetype(), grad);
      }
      trueTerm->insertOperands(trueTerm->getNumOperands(),
                               ValueRange(trueValue));

      auto falseValue = falseMapping.lookupOrNull(grad);
      if (!falseValue) {
        falseValue = enzyme::GetOp::create(
            rewriter, grad.getLoc(),
            cast<enzyme::GradientType>(grad.getType()).getBasetype(), grad);
      }
      falseTerm->insertOperands(falseTerm->getNumOperands(),
                                ValueRange(falseValue));
    }

    // Yield each pushed value from the branch that produced it and a dummy
    // value from the other branch.
    if (removeCaches) {
      for (auto &[pushedValue, info] : pushedCaches) {
        Value dummy = FinalClass::getDummyValue(rewriter, pushedValue.getLoc(),
                                                pushedValue.getType());

        Value trueValue =
            pushedValue.getParentBlock() == trueBlock ? pushedValue : dummy;
        Value falseValue =
            pushedValue.getParentBlock() == falseBlock ? pushedValue : dummy;

        trueTerm->insertOperands(trueTerm->getNumOperands(),
                                 ValueRange(trueValue));
        falseTerm->insertOperands(falseTerm->getNumOperands(),
                                  ValueRange(falseValue));
      }
    }

    // Rebuild the if with the widened result list; `idx` walks the results
    // appended after the op's original ones.
    size_t idx = ifOp->getNumResults();
    ifOp = FinalClass::replace(rewriter, ifOp, trueTerm->getOperandTypes());

    for (auto grad : gradients) {
      enzyme::SetOp::create(rewriter, grad.getLoc(), grad,
                            ifOp->getResult(idx));
      idx++;
    }

    if (removeCaches) {
      for (auto &[pushedValue, info] : pushedCaches) {
        // Push the yielded value after the if; pop before the enclosing op of
        // the old pop rather than inside the branch.
        enzyme::PushOp::create(rewriter, info.pushOp->getLoc(),
                               info.initOp.getResult(), ifOp->getResult(idx));
        rewriter.eraseOp(info.pushOp);

        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(info.popOp->getParentOp());

        auto newPop = enzyme::PopOp::create(rewriter, info.popOp->getLoc(),
                                            info.popOp.getResult().getType(),
                                            info.popOp.getCache());
        rewriter.replaceAllUsesWith(info.popOp.getResult(), newPop);
        rewriter.eraseOp(info.popOp);

        idx++;
      }
    }

    return success();
  }
940};
941
942} // namespace enzyme
943} // namespace mlir
#define getOp(vtmp)
void removalBlockExplore(Block *block, IRMapping &mapping, PatternRewriter &rewriter, llvm::SetVector< Value > &gradients, llvm::MapVector< Value, CacheInfo > &caches)
void localizeGradients(OpBuilder &builder, MGradientUtilsReverse *gutils, Block *fwd)
static constexpr llvm::StringLiteral kPreserveCacheAttrName
void minCutCache(Block *forward, Block *reverse, SmallVector< CacheInfo > &caches, PatternRewriter &rewriter, const IRMapping &fwdrevmap, Operation *lastFwd=nullptr)
static bool hasLICM(Operation *op)
static bool hasMinCut(Operation *op)
static LoopCacheType getCacheType(Operation *op)
Information about a cache, each cache init should have one corresponding push and pop.
CacheInfo merge(CacheInfo other, PatternRewriter &rewriter)
enzyme::InitOp initOp
enzyme::PushOp pushOp
static bool Equivalent(Value lhs, Value rhs)
static llvm::SmallVector< mlir::Value > computeReversedIndices(PatternRewriter &rewriter, OpName op, llvm::ArrayRef< mlir::Value > otherInductionVariable, llvm::ArrayRef< IntOrValue > bounds)
LogicalResult removeEnzymeOps(Operation *op, PatternRewriter &rewriter) const
LogicalResult removeEnzymeOps(Operation *op, PatternRewriter &rewriter) const