19#include "mlir/Dialect/Arith/IR/Arith.h"
20#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/MemRef/IR/MemRef.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/Dialect/Tensor/IR/Tensor.h"
25#include "mlir/Transforms/DialectConversion.h"
27#include "mlir/Rewrite/PatternApplicator.h"
28#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "mlir/IR/Dominance.h"
31#include "llvm/Support/raw_ostream.h"
35#define GEN_PASS_DEF_REMOVEUNUSEDENZYMEOPSPASS
36#include "Passes/Passes.h.inc"
41using namespace enzyme;
// mayExecuteBefore: CFG reachability query inside the region that owns `blk`.
//
// The traversal is seeded with the successors of `blk`. `blk` itself is not
// scanned up front; it is only visited if it is reached again, e.g. through a
// back edge. It then explores the blocks of that region with an explicit
// worklist, visiting each block at most once (`visitedBlocks`), so CFG cycles
// terminate. Within each visited block the ops are scanned in program order.
// Each op is tested with Operation::isAncestor(), which is also true for the
// op itself, against `check` and against `end`.
//
// Precondition (asserted): `end` is nested somewhere inside blk's parent
// region.
//
// NOTE(review): the bodies of the two isAncestor() branches and the final
// return are not visible in this view. From the name, this presumably returns
// true when some path reaches an op containing `check` before reaching one
// containing `end`; confirm against the full source.
46bool mayExecuteBefore(Block *blk, Operation *check, Operation *end) {
47 auto reg = blk->getParent();
49 assert(reg->isAncestor(end->getParentRegion()));
 // Blocks already explored; guarantees termination on cyclic CFGs.
51 DenseSet<Block *> visitedBlocks;
 // Pending blocks, seeded with the immediate successors of `blk`.
53 SmallVector<Block *> blocksToVisit;
54 for (
auto succ : blk->getSuccessors()) {
55 blocksToVisit.push_back(succ);
58 while (!blocksToVisit.empty()) {
59 Block *cur = blocksToVisit.pop_back_val();
61 if (visitedBlocks.contains(cur))
64 visitedBlocks.insert(cur);
 // Scan ops in order. isAncestor() also matches when `check`/`end` are
 // nested inside `op`'s regions.
67 for (
auto &op : *cur) {
70 if (op.isAncestor(check)) {
87 if (op.isAncestor(end)) {
 // Continue the search into this block's successors.
97 for (
auto succ : cur->getSuccessors()) {
98 blocksToVisit.push_back(succ);
// mayExecuteBetween: asks whether `check` may execute after `start` and before
// `end`.
//
// Strategy visible in this view:
//  1. Lift `check` and `end` to their ancestors that sit directly in start's
//     block (`checkAnc` / `endAnc`). Either becomes null if no such ancestor
//     exists.
//  2. Compare those ancestors with isBeforeInBlock(), both against `start`
//     and against each other.
//  3. If `end` is nested in the region that owns start's block, defer to a
//     CFG search: mayExecuteBefore(blk, check, end).
//  4. Otherwise, after a test on check's region, recurse with
//     start->getParentOp() as the new start.
//
// NOTE(review): the values returned by the ordering checks, any null-guards
// around the `checkAnc->` / `endAnc->` dereferences, the declaration of
// `endAnc`, and the handling of a null start->getParentOp() are on lines not
// visible here. Confirm in the full source.
105bool mayExecuteBetween(Operation *start, Operation *check, Operation *end) {
106 Block *blk = start->getBlock();
 // Walk up from `check` until reaching an op placed directly in `blk`.
107 auto checkAnc = check;
108 while (checkAnc && checkAnc->getBlock() != blk)
109 checkAnc = checkAnc->getParentOp();
 // Same lifting for `end`.
112 while (endAnc && endAnc->getBlock() != blk)
113 endAnc = endAnc->getParentOp();
 // In-block ordering checks on the lifted ancestors.
117 if (checkAnc && checkAnc == endAnc)
121 if (checkAnc->isBeforeInBlock(start))
125 if (endAnc->isBeforeInBlock(start))
129 if (checkAnc && endAnc) {
130 if (checkAnc->isBeforeInBlock(endAnc))
141 auto reg = blk->getParent();
142 if (reg->isAncestor(end->getParentRegion())) {
143 return mayExecuteBefore(blk, check, end);
148 if (reg->isAncestor(check->getParentRegion())) {
152 return mayExecuteBetween(start->getParentOp(), check, end);
// findNearestDominatingOpByUse: among the users of `v` that are of type `T`,
// look for one that dominates `op` and is not interfered with by another
// relevant user of `v`.
//
//  - `options`   : users of `v` that are `T` (the candidates).
//  - `conflicts` : users of `v` that are `T`, plus users that are `T2`. With
//                  the default `T2 = T`, this set contains only `T` users.
//
// A candidate is only considered if it dominates `op`
// (dInfo.dominates(opt, op)). It is then checked against each conflict with
// mayExecuteBetween(opt, opt2, op). The result is a `T`, which call sites
// test for null (e.g. `if (auto set = ...)`), so presumably a null op is
// returned when no candidate qualifies.
//
// NOTE(review): `dInfo` is declared on a line not visible here, and no use of
// `pdInfo` is visible. If it is genuinely unused, constructing a
// PostDominanceInfo on every call is wasted work. The way `conflict` is set
// and the final selection among surviving candidates are also not visible.
159template <
class T,
class T2 = T>
160T findNearestDominatingOpByUse(Operation *op, Value v) {
162 PostDominanceInfo pdInfo;
164 SmallVector<T, 1> options;
165 SmallVector<Operation *, 1> conflicts;
 // Classify the users of `v` into candidates and conflicts.
166 for (Operation *userSet : v.getUsers()) {
167 if (
auto setOp = dyn_cast<T>(userSet)) {
168 options.push_back(setOp);
169 conflicts.push_back(setOp);
172 if (
auto setOp = dyn_cast<T2>(userSet)) {
173 conflicts.push_back(setOp);
 // Filter candidates: must dominate `op`, then be checked against every
 // conflicting user via mayExecuteBetween.
178 for (
auto opt : options) {
179 if (!dInfo.dominates(opt, op))
181 bool conflict =
false;
182 for (
auto opt2 : conflicts) {
188 if (!mayExecuteBetween(opt, opt2, op)) {
// Forward a pushed value directly to a matching pop of the same cache.
//
// Finds the enzyme::InitOp that defines pop's cache and classifies the
// cache's users into PushOps (`pushes`) and PopOps. The pop branch's
// bookkeeping is on a line not visible here. It then calls
// findNearestDominatingOpByUse with PushOp candidates and PopOp conflicts.
// If a push is found and it lives in the same block as `pop`, the pop's
// result is replaced by the pushed value and the push is erased.
//
// NOTE(review): the handling of a missing InitOp, the arguments passed to
// findNearestDominatingOpByUse, and the return statements are on lines not
// visible here.
// NOTE(review): the `pop` bound by the PopOp dyn_cast inside the loop shadows
// the `pop` parameter. Renaming it (e.g. `otherPop`) would avoid confusion.
205 LogicalResult matchAndRewrite(enzyme::PopOp pop,
206 PatternRewriter &rewriter)
const final {
208 auto init = pop.getCache().getDefiningOp<enzyme::InitOp>();
 // Partition the cache's users by kind.
212 SmallVector<enzyme::PopOp, 1> pops;
213 SmallVector<enzyme::PushOp, 1> pushes;
214 for (Operation *userSet : init.getResult().getUsers()) {
215 if (
auto push = dyn_cast<enzyme::PushOp>(userSet)) {
216 pushes.push_back(push);
219 if (
auto pop = dyn_cast<enzyme::PopOp>(userSet)) {
 // Look for a dominating push, treating other pops as conflicts.
226 if (
auto push = findNearestDominatingOpByUse<enzyme::PushOp, enzyme::PopOp>(
 // Only forward when push and pop are in the same block.
229 if (pop->getBlock() == push->getBlock()) {
230 rewriter.replaceOp(pop, push.getValue());
231 rewriter.eraseOp(push);
// Replace an enzyme::GetOp with the value stored by a suitable enzyme::SetOp.
//
// `get` reads the gradient value produced by an enzyme::InitOp. The init's
// users are scanned for GetOp and SetOp; what is done for each is on lines
// not visible here. If findNearestDominatingOpByUse<SetOp>(get, init) yields
// a SetOp, `get` is replaced by that set's value. PatternRewriter::replaceOp
// also erases `get`.
//
// NOTE(review): the handling of a missing InitOp and of other user kinds is
// not visible in this view.
243 LogicalResult matchAndRewrite(enzyme::GetOp get,
244 PatternRewriter &rewriter)
const final {
246 auto init = get.getGradient().getDefiningOp<enzyme::InitOp>();
 // Inspect every user of the gradient buffer.
250 for (Operation *userSet : init.getResult().getUsers()) {
251 if (isa<enzyme::GetOp>(userSet))
253 if (isa<enzyme::SetOp>(userSet))
 // Forward the stored value from a dominating set, if one qualifies.
258 if (
auto set = findNearestDominatingOpByUse<enzyme::SetOp>(get, init)) {
259 rewriter.replaceOp(get, set.getValue());
// Erase an enzyme::SetOp based on the users of its gradient buffer.
//
// Looks up the enzyme::InitOp behind the set's gradient and iterates that
// init's users. SetOp users are accepted. The handling of other users is on
// lines not visible here; presumably the pattern bails out when any non-SetOp
// user, i.e. a reader, exists, so the buffer is write-only. Confirm in the
// full source. If the checks pass, the set op is erased.
//
// NOTE(review): the parameter is named `get` although it is an
// enzyme::SetOp. Renaming it (e.g. `set`) would make the pattern clearer.
269 LogicalResult matchAndRewrite(enzyme::SetOp get,
270 PatternRewriter &rewriter)
const final {
272 auto init = get.getGradient().getDefiningOp<enzyme::InitOp>();
 // Examine every user of the gradient buffer.
276 for (Operation *userSet : init.getResult().getUsers()) {
277 if (isa<enzyme::SetOp>(userSet))
282 rewriter.eraseOp(get);
// Erase an enzyme::PushOp based on the users of its cache.
//
// Looks up the enzyme::InitOp behind the push's cache and iterates that
// init's users. PushOp users are accepted. The handling of other users is on
// lines not visible here; presumably the pattern bails out when a PopOp (a
// reader) exists, so the cache is never read. Confirm in the full source. If
// the checks pass, the push op is erased.
//
// NOTE(review): the parameter is named `get` although it is an
// enzyme::PushOp. Renaming it (e.g. `push`) would make the pattern clearer.
290 LogicalResult matchAndRewrite(enzyme::PushOp get,
291 PatternRewriter &rewriter)
const final {
293 auto init = get.getCache().getDefiningOp<enzyme::InitOp>();
 // Examine every user of the cache.
297 for (Operation *userSet : init.getResult().getUsers()) {
298 if (isa<enzyme::PushOp>(userSet))
303 rewriter.eraseOp(get);
// Erase an enzyme::InitOp whose result has no remaining uses. This typically
// follows the other patterns removing the ops that used it. The return
// statements are on lines not visible here.
311 LogicalResult matchAndRewrite(enzyme::InitOp init,
312 PatternRewriter &rewriter)
const final {
313 if (init.use_empty()) {
314 rewriter.eraseOp(init);
// Pattern that removes every enzyme::IgnoreDerivativesOp by replacing its
// result with its operand. PatternRewriter::replaceOp also erases the op.
321struct IgnoreDerivativesSimplifyPattern
325 LogicalResult matchAndRewrite(enzyme::IgnoreDerivativesOp op,
326 PatternRewriter &rewriter)
const override {
327 rewriter.replaceOp(op, op.getOperand());
// Greedily apply the local cleanup patterns to the ops nested under `op`:
// Pop/Get/Push/Set/Init simplification and IgnoreDerivatives removal. Folding
// is enabled. The LogicalResult of applyPatternsGreedily, which reports
// whether a fixpoint was reached, is explicitly discarded.
// NOTE(review): the constructor arguments passed to patterns.insert are on
// lines not visible here.
332static void applyPatterns(Operation *op) {
333 RewritePatternSet patterns(op->getContext());
334 patterns.insert<PopSimplify, GetSimplify, PushSimplify, SetSimplify,
335 InitSimplify, IgnoreDerivativesSimplifyPattern>(
 // Let the greedy driver also fold ops while it iterates.
338 GreedyRewriteConfig config;
339 config.enableFolding();
340 (void)applyPatternsGreedily(op, std::move(patterns), config);
// For every LoopLikeOpInterface op under `op`, visit the
// RegionBranchOpInterface ops nested inside it. Operation::walk also visits
// the loop op itself. Region-branching ops whose own branch graph has no loop
// (hasLoop() is false) but that sit inside a loop get a UnitAttr attached.
// The attribute name and setter call are on a line not visible here.
343static void annotateRegionOpsInLoops(Operation *op) {
347 op->walk([](LoopLikeOpInterface loop) {
348 loop->walk([](RegionBranchOpInterface regionBranch) {
349 if (!regionBranch.hasLoop()) {
351 UnitAttr::get(regionBranch.getContext()));
// Start with capacity for a few ops to avoid early reallocations.
362 Worklist() { list.reserve(8); }
// Queue `op`; ops already present are not queued twice.
365 void push(Operation *op);
// Clear `op`'s slot (set to nullptr) if it is queued.
366 void remove(Operation *op);
// Queue storage; slots cleared by remove() remain as nullptr entries.
371 std::vector<Operation *> list;
// Index of each queued op within `list`, used for dedup and O(1) removal.
372 llvm::DenseMap<Operation *, unsigned> map;
// Returns true when no live (non-null) op remains. Slots cleared by remove()
// stay in `list` as nullptr, so this scans the vector (linear in its size)
// rather than checking list.empty().
375bool Worklist::empty() {
377 return !llvm::any_of(list,
378 [](Operation *op) {
return static_cast<bool>(op); });
// Queue `op`, ignoring duplicates. map.insert() fails when `op` is already
// present, and the op is then not queued again; that early exit is on a line
// not visible here. The recorded index is list.size(), the slot the op
// presumably occupies once appended on the following hidden line.
381void Worklist::push(Operation *op) {
382 assert(op &&
"cannot push nullptr to worklist");
384 if (!map.insert({op, list.size()}).second)
// Reverse the processing order, then walk the new positions to re-sync `map`
// with `list`. The loop body is on a line not visible here.
389void Worklist::reverse() {
390 std::reverse(list.begin(), list.end());
391 for (
size_t i = 0, e = list.size(); i < e; ++i)
// Remove and return an op from the back of `list` (LIFO order). Afterwards,
// trailing nullptr slots are trimmed so the back of the list is either a live
// op or the list is empty.
// NOTE(review): the precondition checks, the removal of the returned op from
// `list`/`map`, and the return statement are on lines not visible here.
395Operation *Worklist::pop() {
399 Operation *op = list.back();
403 while (!list.empty() && !list.back())
// Remove `op` from the queue if present. Its slot is overwritten with nullptr
// instead of being erased, so the indices that `map` holds for other entries
// stay valid. empty() ignores such slots and pop() trims them from the back.
408void Worklist::remove(Operation *op) {
409 assert(op &&
"cannot remove nullptr from worklist");
410 auto it = map.find(op);
411 if (it != map.end()) {
412 assert(list[it->second] == op &&
"malformed worklist data structure");
413 list[it->second] =
nullptr;
// Driver that calls EnzymeOpsRemoverOpInterface::removeEnzymeOps on the
// interface ops nested under `root`. It also serves as the rewriter's
// Listener, so it is notified when ops are inserted or erased during a
// removal, and it turns match failures into diagnostics.
421class PostOrderWalkDriver :
public RewriterBase::Listener {
 // `root` is the op whose nested interface ops are processed.
423 PostOrderWalkDriver(Operation *root_) : root(root_) {}
425 void initializeWorklist();
426 LogicalResult processWorklist();
 // RewriterBase::Listener hooks.
429 void notifyOperationInserted(Operation *op,
430 OpBuilder::InsertPoint previous)
override;
431 void notifyOperationErased(Operation *op)
override;
434 notifyMatchFailure(Location loc,
435 function_ref<
void(Diagnostic &)> reasonCallback)
override;
 // Queues `op` if it implements EnzymeOpsRemoverOpInterface.
438 void addToWorklist(Operation *op);
 // Op currently being processed; used as the rewriter insertion point.
442 Operation *current =
nullptr;
// Queue `op` only if it implements EnzymeOpsRemoverOpInterface; other ops are
// ignored. The early exit and the push itself are on lines not visible here.
446void PostOrderWalkDriver::addToWorklist(Operation *op) {
448 if (!isa<EnzymeOpsRemoverOpInterface>(op))
// Listener hook, called when the rewriter inserts `op`. Ops that do not
// implement EnzymeOpsRemoverOpInterface are ignored. For interface ops, the
// driver walks `root` in Operation::walk's default post-order and finds which
// of `current` and `op` is reached first:
//  - if `op` is reached first, `shouldInsert` stays false;
//  - if `current` is reached first, the walk stops there. The assignment in
//    that branch is on a line not visible here; presumably it sets
//    `shouldInsert = true`, meaning the new op comes after the op currently
//    being processed.
// How `shouldInsert` is consumed afterwards is also not visible.
// NOTE(review): each insertion of an interface op triggers a walk over the
// whole of `root`. The C-style `(Operation *)iface` casts could be written as
// `iface.getOperation()`.
453void PostOrderWalkDriver::notifyOperationInserted(
454 Operation *op, OpBuilder::InsertPoint previous) {
455 if (!isa<EnzymeOpsRemoverOpInterface>(op))
465 bool shouldInsert =
false;
 // Determine the post-order position of `op` relative to `current`.
466 (void)root->walk([&](EnzymeOpsRemoverOpInterface iface) {
467 if ((Operation *)iface == current) {
469 return WalkResult::interrupt();
472 if ((Operation *)iface == op) {
473 shouldInsert =
false;
474 return WalkResult::interrupt();
477 return WalkResult::advance();
// Listener hook, called when the rewriter erases `op`. Ops that do not
// implement EnzymeOpsRemoverOpInterface are ignored. The handling for
// interface ops is on a line not visible here; presumably it removes `op`
// from the worklist so an erased op is never popped.
484void PostOrderWalkDriver::notifyOperationErased(Operation *op) {
488 if (!isa<EnzymeOpsRemoverOpInterface>(op))
// Listener hook: every match failure reported through the rewriter becomes
// an error diagnostic at `loc`, and `reasonCallback` fills in its text.
// mlir::emitError returns an InFlightDiagnostic, which is reported when `diag`
// goes out of scope.
493void PostOrderWalkDriver::notifyMatchFailure(
494 Location loc, function_ref<
void(Diagnostic &)> reasonCallback) {
495 auto diag = mlir::emitError(loc);
496 reasonCallback(*diag.getUnderlyingDiagnostic());
// Seed the worklist with every EnzymeOpsRemoverOpInterface op under `root`,
// visited in pre-order, so enclosing ops are queued before the ops nested in
// them. Worklist::pop() takes from the back, so ops queued later are popped
// first unless the list is reversed elsewhere.
499void PostOrderWalkDriver::initializeWorklist() {
500 root->walk<WalkOrder::PreOrder>(
501 [
this](EnzymeOpsRemoverOpInterface iface) { addToWorklist(iface); });
// Drain the worklist. For each popped op, the PatternRewriter is positioned
// at `current` and removeEnzymeOps is invoked on it. This driver is the
// rewriter's listener, so its insert/erase notifications reach the hooks
// above. No early exit on failure is visible: all queued ops are processed,
// and the result is failure if any removeEnzymeOps call failed.
// NOTE(review): the declaration of `result` and the assignment of `current`
// are on lines not visible here. setInsertionPoint(current) relies on
// `current` having been updated to the popped op.
504LogicalResult PostOrderWalkDriver::processWorklist() {
505 PatternRewriter rewriter(root->getContext());
506 rewriter.setListener(
this);
509 while (!worklist.empty()) {
510 auto op = worklist.pop();
511 auto iface = cast<EnzymeOpsRemoverOpInterface>(op);
513 rewriter.setInsertionPoint(current);
514 result &= iface.removeEnzymeOps(rewriter).succeeded();
518 return LogicalResult::success(result);
// Pass entry point. It first annotates region-branching ops that sit inside
// loops. It then runs a PostOrderWalkDriver over each FunctionOpInterface op
// nested in the pass's operation. Driver failures are OR-ed into `failed`,
// which is declared and consumed on lines not visible here. Other steps may
// occur on the hidden lines between these statements.
521struct RemoveUnusedEnzymeOpsPass
522 :
public enzyme::impl::RemoveUnusedEnzymeOpsPassBase<
523 RemoveUnusedEnzymeOpsPass> {
524 void runOnOperation()
override {
525 auto op = getOperation();
529 annotateRegionOpsInLoops(op);
 // Process each function independently with its own driver and worklist.
531 op->walk([&](FunctionOpInterface func) {
532 PostOrderWalkDriver driver(func);
533 driver.initializeWorklist();
534 failed |= driver.processWorklist().failed();
static constexpr llvm::StringLiteral kPreserveCacheAttrName