20#include "mlir/Analysis/TopologicalSortUtils.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/IR/Dominance.h"
24#include "mlir/Interfaces/FunctionInterfaces.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "mlir/Transforms/RegionUtils.h"
27#include "mlir/Transforms/WalkPatternRewriteDriver.h"
28#include "llvm/ADT/TypeSwitch.h"
31using namespace enzyme;
34#define GEN_PASS_DEF_HOISTENZYMEFROMREGIONPASS
35#include "Passes/Passes.h.inc"
39#define DEBUG_TYPE "enzyme-hoist"
40#define ENZYME_DBGS llvm::dbgs() << "[" << DEBUG_TYPE << "]"
44template <
typename SourceOp>
46checkRangeDominance(IRMapping &btop, DominanceInfo &dom, SourceOp &rootOp,
47 SetVector<Operation *> &specialOps, ValueRange values) {
48 SmallVector<Value> blockArgs(rootOp.getBody().getArguments());
49 for (
auto value : values) {
50 if (dom.properlyDominates(value, rootOp))
55 if (isa<BlockArgument>(value)) {
57 if (btop.contains(value))
63 if (!llvm::is_contained(specialOps, value.getDefiningOp())) {
73template <
typename SourceOp>
76 LogicalResult matchAndRewrite(SourceOp rootOp,
77 PatternRewriter &rewriter)
const override {
78 DominanceInfo dom(rootOp);
79 PostDominanceInfo pdom(rootOp);
80 Region &autodiffRegion = rootOp.getBody();
81 SmallVector<Value> primalArgs = rootOp.getPrimalInputs();
82 SmallVector<Value> blockArgs(autodiffRegion.getArguments());
84 if (primalArgs.size() != blockArgs.size())
89 for (
auto [pval, bval, act] :
90 llvm::zip(primalArgs, blockArgs,
91 rootOp.getActivity().template getAsRange<ActivityAttr>())) {
92 auto act_val = act.getValue();
93 if (act_val == Activity::enzyme_const) {
99 llvm::SetVector<Value> freeValues;
100 getUsedValuesDefinedAbove(autodiffRegion, freeValues);
101 for (Value value : freeValues) {
102 for (
auto [pval, bval] : llvm::zip(primalArgs, blockArgs)) {
104 for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) {
105 if (rootOp->isProperAncestor(use.getOwner()))
112 llvm::SetVector<Operation *> liftOps;
113 llvm::SetVector<Operation *> stationaryOps;
114 llvm::SmallVector<MemoryEffects::EffectInstance> stationaryEffects;
115 for (Block &blk : autodiffRegion.getBlocks()) {
118 if (pdom.postDominates(&blk, &autodiffRegion.front())) {
119 for (Operation &bodyOp : blk.without_terminator()) {
121 llvm::SmallVector<MemoryEffects::EffectInstance> bodyOpEffects;
123 bool couldCollectEffects =
126 if (!couldCollectEffects)
129 canLift = checkRangeDominance(btop, dom, rootOp, liftOps,
130 bodyOp.getOperands());
132 llvm::SetVector<Value> inside_values;
133 if (bodyOp.getNumRegions()) {
135 getUsedValuesDefinedAbove(bodyOp.getRegions(), inside_values);
136 canLift = checkRangeDominance(btop, dom, rootOp, liftOps,
137 inside_values.getArrayRef());
141 for (
auto stationaryEffect : stationaryEffects) {
142 for (
auto bodyOpEffect : bodyOpEffects) {
143 if ((isa<MemoryEffects::Write>(stationaryEffect.getEffect()) &&
144 isa<MemoryEffects::Read>(bodyOpEffect.getEffect())) ||
145 (isa<MemoryEffects::Read>(stationaryEffect.getEffect()) &&
146 isa<MemoryEffects::Write>(bodyOpEffect.getEffect())) ||
147 (isa<MemoryEffects::Write>(stationaryEffect.getEffect()) &&
148 isa<MemoryEffects::Write>(bodyOpEffect.getEffect()))) {
163 for (Value inner : inside_values) {
164 if (btop.contains(inner)) {
165 auto pval = btop.lookup(inner);
166 for (
auto ®ion : bodyOp.getRegions()) {
167 replaceAllUsesInRegionWith(inner, pval, region);
172 for (OpOperand &inner : bodyOp.getOpOperands()) {
173 inner.assign(btop.lookupOrDefault(inner.get()));
176 liftOps.insert(&bodyOp);
178 stationaryOps.insert(&bodyOp);
179 stationaryEffects.append(bodyOpEffects.begin(),
180 bodyOpEffects.end());
187 for (Operation *op : llvm::make_early_inc_range(liftOps)) {
188 rewriter.moveOpBefore(op, rootOp);
195struct HoistEnzymeFromRegion
196 :
public enzyme::impl::HoistEnzymeFromRegionPassBase<
197 HoistEnzymeFromRegion> {
198 void runOnOperation()
override {
199 RewritePatternSet patterns(&getContext());
200 patterns.add<HoistEnzymeAutoDiff<AutoDiffRegionOp>,
201 HoistEnzymeAutoDiff<ForwardDiffRegionOp>>(&getContext());
202 GreedyRewriteConfig config;
203 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
bool mayAlias(Value v1, Value v2)
bool collectOpEffects(Operation *rootOp, SmallVector< MemoryEffects::EffectInstance > &effects)
Returns the side effects of an operation (similar to mlir::getEffectsRecursively).