Enzyme main
Loading...
Searching...
No Matches
HoistEnzymeRegions.cpp
Go to the documentation of this file.
1//===- HoistEnzymeRegions.cpp - Invariant code motion ------------===//
2//===- within enzyme.autodiff_region ----------=== //
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements a pass that hoists invariant computations out of the
11// bodies of enzyme autodiff region ops, placing them right before the region op.
12//
13//===----------------------------------------------------------------------===//
14
15#include "Dialect/Ops.h"
17#include "Interfaces/Utils.h"
18#include "Passes/Passes.h"
19
20#include "mlir/Analysis/TopologicalSortUtils.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/IR/Dominance.h"
24#include "mlir/Interfaces/FunctionInterfaces.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "mlir/Transforms/RegionUtils.h"
27#include "mlir/Transforms/WalkPatternRewriteDriver.h"
28#include "llvm/ADT/TypeSwitch.h"
29
30using namespace mlir;
31using namespace enzyme;
32namespace mlir {
33namespace enzyme {
34#define GEN_PASS_DEF_HOISTENZYMEFROMREGIONPASS
35#include "Passes/Passes.h.inc"
36} // namespace enzyme
37} // namespace mlir
38
39#define DEBUG_TYPE "enzyme-hoist"
40#define ENZYME_DBGS llvm::dbgs() << "[" << DEBUG_TYPE << "]"
41
42namespace {
43
44template <typename SourceOp>
45static bool
46checkRangeDominance(IRMapping &btop, DominanceInfo &dom, SourceOp &rootOp,
47 SetVector<Operation *> &specialOps, ValueRange values) {
48 SmallVector<Value> blockArgs(rootOp.getBody().getArguments());
49 for (auto value : values) {
50 if (dom.properlyDominates(value, rootOp))
51 continue;
52
53 // Block arguments within autodiff_region are not supported
54 //
55 if (isa<BlockArgument>(value)) {
56 // check if it's a block argument of type enzyme_const
57 if (btop.contains(value))
58 continue;
59 else
60 return false;
61 }
62
63 if (!llvm::is_contained(specialOps, value.getDefiningOp())) {
64 return false;
65 }
66 }
67
68 // if we reach this point, it means that these current set of values are safe
69 // to hoist
70 return true;
71}
72
73template <typename SourceOp>
74struct HoistEnzymeAutoDiff : public OpRewritePattern<SourceOp> {
76 LogicalResult matchAndRewrite(SourceOp rootOp,
77 PatternRewriter &rewriter) const override {
78 DominanceInfo dom(rootOp);
79 PostDominanceInfo pdom(rootOp);
80 Region &autodiffRegion = rootOp.getBody();
81 SmallVector<Value> primalArgs = rootOp.getPrimalInputs();
82 SmallVector<Value> blockArgs(autodiffRegion.getArguments());
83
84 if (primalArgs.size() != blockArgs.size())
85 return failure();
86
87 // map for block arg -> primal arg iff activity is enzyme_const
88 IRMapping btop;
89 for (auto [pval, bval, act] :
90 llvm::zip(primalArgs, blockArgs,
91 rootOp.getActivity().template getAsRange<ActivityAttr>())) {
92 auto act_val = act.getValue();
93 if (act_val == Activity::enzyme_const) {
94 btop.map(bval, pval);
95 }
96 }
97
98 // rename all uses of primal
99 llvm::SetVector<Value> freeValues;
100 getUsedValuesDefinedAbove(autodiffRegion, freeValues);
101 for (Value value : freeValues) {
102 for (auto [pval, bval] : llvm::zip(primalArgs, blockArgs)) {
103 if (value == pval) {
104 for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) {
105 if (rootOp->isProperAncestor(use.getOwner()))
106 use.assign(bval);
107 }
108 }
109 }
110 }
111
112 llvm::SetVector<Operation *> liftOps;
113 llvm::SetVector<Operation *> stationaryOps;
114 llvm::SmallVector<MemoryEffects::EffectInstance> stationaryEffects;
115 for (Block &blk : autodiffRegion.getBlocks()) {
116 // If bodyOp is in a block which does not post-dominate the entry
117 // block to the regionOp, then we disable lifting it
118 if (pdom.postDominates(&blk, &autodiffRegion.front())) {
119 for (Operation &bodyOp : blk.without_terminator()) {
120 bool canLift = true;
121 llvm::SmallVector<MemoryEffects::EffectInstance> bodyOpEffects;
122
123 bool couldCollectEffects =
124 enzyme::oputils::collectOpEffects(&bodyOp, bodyOpEffects);
125
126 if (!couldCollectEffects)
127 canLift = false;
128
129 canLift = checkRangeDominance(btop, dom, rootOp, liftOps,
130 bodyOp.getOperands());
131
132 llvm::SetVector<Value> inside_values;
133 if (bodyOp.getNumRegions()) {
134 canLift = false;
135 getUsedValuesDefinedAbove(bodyOp.getRegions(), inside_values);
136 canLift = checkRangeDominance(btop, dom, rootOp, liftOps,
137 inside_values.getArrayRef());
138 }
139
140 // Check for memory conflicts with current set of stationary ops
141 for (auto stationaryEffect : stationaryEffects) {
142 for (auto bodyOpEffect : bodyOpEffects) {
143 if ((isa<MemoryEffects::Write>(stationaryEffect.getEffect()) &&
144 isa<MemoryEffects::Read>(bodyOpEffect.getEffect())) ||
145 (isa<MemoryEffects::Read>(stationaryEffect.getEffect()) &&
146 isa<MemoryEffects::Write>(bodyOpEffect.getEffect())) ||
147 (isa<MemoryEffects::Write>(stationaryEffect.getEffect()) &&
148 isa<MemoryEffects::Write>(bodyOpEffect.getEffect()))) {
149
150 if (enzyme::oputils::mayAlias(bodyOpEffect, stationaryEffect)) {
151 canLift = false;
152 break;
153 }
154 }
155 }
156 }
157
158 if (canLift) {
159 // replace all instances of enzyme_const block args with the
160 // equivalent primal args for both inside_values and
161 // bodyOp.getOperands()
162
163 for (Value inner : inside_values) {
164 if (btop.contains(inner)) {
165 auto pval = btop.lookup(inner);
166 for (auto &region : bodyOp.getRegions()) {
167 replaceAllUsesInRegionWith(inner, pval, region);
168 }
169 }
170 }
171
172 for (OpOperand &inner : bodyOp.getOpOperands()) {
173 inner.assign(btop.lookupOrDefault(inner.get()));
174 }
175
176 liftOps.insert(&bodyOp);
177 } else {
178 stationaryOps.insert(&bodyOp);
179 stationaryEffects.append(bodyOpEffects.begin(),
180 bodyOpEffects.end());
181 }
182 }
183 }
184 }
185
186 // Lift operations
187 for (Operation *op : llvm::make_early_inc_range(liftOps)) {
188 rewriter.moveOpBefore(op, rootOp);
189 }
190
191 return success();
192 }
193};
194
195struct HoistEnzymeFromRegion
196 : public enzyme::impl::HoistEnzymeFromRegionPassBase<
197 HoistEnzymeFromRegion> {
198 void runOnOperation() override {
199 RewritePatternSet patterns(&getContext());
200 patterns.add<HoistEnzymeAutoDiff<AutoDiffRegionOp>,
201 HoistEnzymeAutoDiff<ForwardDiffRegionOp>>(&getContext());
202 GreedyRewriteConfig config;
203 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
204 }
205};
206} // namespace
bool mayAlias(Value v1, Value v2)
Definition Utils.cpp:136
bool collectOpEffects(Operation *rootOp, SmallVector< MemoryEffects::EffectInstance > &effects)
Returns the side effects of an operation (similar to mlir::getEffectsRecursively).
Definition Utils.cpp:297