Enzyme main
Loading...
Searching...
No Matches
PrintActivityAnalysis.cpp
Go to the documentation of this file.
1//===- PrintActivityAnalysis.cpp - Pass to print activity analysis --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to print the results of running activity
10// analysis.
11//
12//===----------------------------------------------------------------------===//
16#include "Dialect/Ops.h"
18#include "Passes/PassDetails.h"
19#include "Passes/Passes.h"
20
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
24#include "mlir/Interfaces/CallInterfaces.h"
25#include "mlir/Interfaces/FunctionInterfaces.h"
26
27#include "llvm/ADT/TypeSwitch.h"
28
29#include "llvm/Demangle/Demangle.h"
30
31using namespace mlir;
32
33namespace mlir {
34namespace enzyme {
35#define GEN_PASS_DEF_PRINTACTIVITYANALYSISPASS
36#include "Passes/Passes.h.inc"
37} // namespace enzyme
38} // namespace mlir
39
40namespace {
41
42enzyme::Activity getDefaultActivity(Type argType) {
43 if (argType.isIntOrIndex())
44 return enzyme::Activity::enzyme_const;
45
46 if (isa<FloatType, ComplexType>(argType))
47 return enzyme::Activity::enzyme_active;
48
49 if (auto T = dyn_cast<TensorType>(argType))
50 return getDefaultActivity(T.getElementType());
51
52 if (isa<LLVM::LLVMPointerType, MemRefType>(argType))
53 return enzyme::Activity::enzyme_dup;
54
55 return enzyme::Activity::enzyme_const;
56}
57
58struct PrintActivityAnalysisPass
59 : public enzyme::impl::PrintActivityAnalysisPassBase<
60 PrintActivityAnalysisPass> {
61 using PrintActivityAnalysisPassBase::PrintActivityAnalysisPassBase;
62
63 /// Do the simplest possible inference of argument and result activities, or
64 /// take the user's explicit override if provided
65 void initializeArgAndResActivities(
66 FunctionOpInterface callee,
67 MutableArrayRef<enzyme::Activity> argActivities,
68 MutableArrayRef<enzyme::Activity> resActivities) const {
69 for (const auto &[idx, argType] :
70 llvm::enumerate(callee.getArgumentTypes())) {
71 if (callee.getArgAttr(idx, "enzyme.const") || inactiveArgs)
72 argActivities[idx] = enzyme::Activity::enzyme_const;
73 else
74 argActivities[idx] = getDefaultActivity(argType);
75 }
76
77 for (const auto &[idx, resType] :
78 llvm::enumerate(callee.getResultTypes())) {
79 if (duplicatedRet)
80 resActivities[idx] = (enzyme::Activity::enzyme_dup);
81 else
82 resActivities[idx] = getDefaultActivity(resType);
83 }
84 }
85
86 void inferArgActivitiesFromEnzymeAutodiff(
87 FunctionOpInterface callee, CallOpInterface autodiff_call,
88 MutableArrayRef<enzyme::Activity> argActivities,
89 MutableArrayRef<enzyme::Activity> resultActivities) {
90 unsigned argIdx = 1;
91 for (const auto &[paramIdx, paramType] :
92 llvm::enumerate(callee.getArgumentTypes())) {
93 Value arg = autodiff_call.getArgOperands()[argIdx];
94 if (auto loadOp =
95 dyn_cast_if_present<LLVM::LoadOp>(arg.getDefiningOp())) {
96 if (auto addressOf = dyn_cast_if_present<LLVM::AddressOfOp>(
97 loadOp.getAddr().getDefiningOp())) {
98 if (addressOf.getGlobalName() == "enzyme_const") {
99 argActivities[paramIdx] = enzyme::Activity::enzyme_const;
100 } else if (addressOf.getGlobalName() == "enzyme_dup") {
101 argActivities[paramIdx] = enzyme::Activity::enzyme_dup;
102 // Skip the shadow
103 argIdx++;
104 } else if (addressOf.getGlobalName() == "enzyme_dupnoneed") {
105 argActivities[paramIdx] = enzyme::Activity::enzyme_dupnoneed;
106 // Skip the shadow
107 argIdx++;
108 }
109 }
110 // Skip the enzyme_* annotation
111 argIdx++;
112 } else {
113 argActivities[paramIdx] =
114 llvm::TypeSwitch<Type, enzyme::Activity>(paramType)
115 .Case<FloatType, ComplexType>(
116 [](auto type) { return enzyme::Activity::enzyme_active; })
117 .Case<LLVM::LLVMPointerType, MemRefType>([&](auto type) {
118 // Skip the shadow
119 argIdx++;
120 return enzyme::Activity::enzyme_dup;
121 })
122 .Default(
123 [](Type type) { return enzyme::Activity::enzyme_const; });
124 }
125 argIdx++;
126 }
127
128 for (const auto &[resIdx, resType] :
129 llvm::enumerate(callee.getResultTypes())) {
130 resultActivities[resIdx] =
131 llvm::TypeSwitch<Type, enzyme::Activity>(resType)
132 .Case<FloatType, ComplexType>(
133 [](auto type) { return enzyme::Activity::enzyme_active; })
134 .Default(
135 [](Type type) { return enzyme::Activity::enzyme_const; });
136 }
137 }
138
139 void runActivityAnalysis(const enzyme::ActivityPrinterConfig &config,
140 FunctionOpInterface callee,
141 ArrayRef<enzyme::Activity> argActivities,
142 ArrayRef<enzyme::Activity> resultActivities,
143 bool print = true) {
144 if (config.relative) {
145 enzyme::runActivityAnnotations(callee, argActivities, config);
146 } else if (config.dataflow) {
147 enzyme::runDataFlowActivityAnalysis(callee, argActivities,
148 /*print=*/true, verbose, annotate);
149 } else {
150
151 SmallPtrSet<Block *, 4> blocksNotForAnalysis;
152
154 SmallPtrSet<mlir::Value, 1> constant_values;
155 SmallPtrSet<mlir::Value, 1> activevals_;
156 for (auto &&[arg, act] :
157 llvm::zip(callee.getFunctionBody().getArguments(), argActivities)) {
158 if (act == enzyme::Activity::enzyme_const)
159 constant_values.insert(arg);
160 else
161 activevals_.insert(arg);
162 }
163 SmallVector<DIFFE_TYPE> ReturnActivity;
164 for (auto act : resultActivities) {
165 if (act != enzyme::Activity::enzyme_const)
166 ReturnActivity.push_back(DIFFE_TYPE::DUP_ARG);
167 else
168 ReturnActivity.push_back(DIFFE_TYPE::CONSTANT);
169 }
170
171 DenseMap<Operation *, bool> readOnlyCache;
172 enzyme::ActivityAnalyzer activityAnalyzer(blocksNotForAnalysis,
173 readOnlyCache, constant_values,
174 activevals_, ReturnActivity);
175
176 callee.walk([&](Operation *op) {
177
178 });
179 MLIRContext *ctx = callee.getContext();
180 callee.walk([&](Operation *op) {
181 if (print)
182 llvm::outs() << " Operation: " << *op << "\n";
183 for (auto &reg : op->getRegions()) {
184 for (auto &blk : reg.getBlocks()) {
185 for (auto &arg : blk.getArguments()) {
186 bool icv = activityAnalyzer.isConstantValue(TR, arg);
187 if (annotate)
188 op->setAttr("enzyme.arg_icv" +
189 std::to_string(arg.getArgNumber()),
190 BoolAttr::get(ctx, icv));
191 if (print)
192 llvm::outs() << " arg: " << arg << " icv=" << icv << "\n";
193 }
194 }
195 }
196
197 bool ici = activityAnalyzer.isConstantOperation(TR, op);
198 if (annotate)
199 op->setAttr("enzyme.ici", BoolAttr::get(ctx, ici));
200 if (print)
201 llvm::outs() << " op ici=" << ici << "\n";
202
203 for (auto res : op->getResults()) {
204 bool icv = activityAnalyzer.isConstantValue(TR, res);
205 if (annotate)
206 op->setAttr("enzyme.res_icv" +
207 std::to_string(res.getResultNumber()),
208 BoolAttr::get(ctx, icv));
209 if (print)
210 llvm::outs() << " res: " << res << " icv=" << icv << "\n";
211 }
212 });
213 }
214 }
215
216 void runOnOperation() override {
218 config.dataflow = dataflow;
219 config.relative = relative;
220 config.annotate = annotate;
221 config.inferFromAutodiff = inferFromAutodiff;
222 config.verbose = verbose;
223
224 auto moduleOp = cast<ModuleOp>(getOperation());
225
226 if (inferFromAutodiff) {
227 // Infer the activity attributes from the __enzyme_autodiff call
228 Operation *autodiff_decl = moduleOp.lookupSymbol("__enzyme_autodiff");
229 if (!autodiff_decl) {
230 for (auto &subOp : *moduleOp.getBody()) {
231 if (auto func = dyn_cast<FunctionOpInterface>(&subOp)) {
232 if (func.getName().contains("__enzyme_autodiff")) {
233 autodiff_decl = &subOp;
234 break;
235 }
236 }
237 }
238 }
239 if (!autodiff_decl) {
240 moduleOp.emitError("Failed to find __enzyme_autodiff symbol");
241 return signalPassFailure();
242 }
243 auto uses = SymbolTable::getSymbolUses(autodiff_decl, moduleOp);
244 assert(uses && "failed to find symbol uses of autodiff decl");
245
246 for (SymbolTable::SymbolUse use : *uses) {
247 auto autodiff_call = cast<CallOpInterface>(use.getUser());
248 FlatSymbolRefAttr calleeAttr =
249 cast<LLVM::AddressOfOp>(
250 autodiff_call.getArgOperands().front().getDefiningOp())
251 .getGlobalNameAttr();
252 auto callee =
253 cast<FunctionOpInterface>(moduleOp.lookupSymbol(calleeAttr));
254
255 SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
256 resultActivities{callee.getNumResults()};
257 // Populate the argument activities based on either the type or the
258 // supplied annotation. First argument is the callee
259 inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call,
260 argActivities, resultActivities);
261 runActivityAnalysis(config, callee, argActivities, resultActivities);
262 }
263 return;
264 }
265
266 if (funcsToAnalyze.empty()) {
267 moduleOp.walk([this, config](FunctionOpInterface callee) {
268 if (callee.isExternal() || callee.isPrivate())
269 return;
270
271 SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
272 resultActivities{callee.getNumResults()};
273 initializeArgAndResActivities(callee, argActivities, resultActivities);
274
275 runActivityAnalysis(config, callee, argActivities, resultActivities);
276 });
277 return;
278 }
279
280 for (std::string funcName : funcsToAnalyze) {
281 Operation *op = moduleOp.lookupSymbol(funcName);
282 if (!op) {
283 continue;
284 }
285
286 if (!isa<FunctionOpInterface>(op)) {
287 moduleOp.emitError()
288 << "Operation " << funcName << " was not a FunctionOpInterface";
289 return signalPassFailure();
290 }
291
292 auto callee = cast<FunctionOpInterface>(op);
293 SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
294 resultActivities{callee.getNumResults()};
295 initializeArgAndResActivities(callee, argActivities, resultActivities);
296
297 runActivityAnalysis(config, callee, argActivities, resultActivities);
298 }
299 }
300};
301} // namespace
Helper class to analyze the differential activity.
bool annotate
Annotate the IR with activity information for every operation.
bool inferFromAutodiff
Infer the starting argument state from an __enzyme_autodiff call.
bool verbose
Output extra information for debugging.
bool dataflow
Whether to use the data-flow based algorithm or the classic activity analysis.
bool relative
Use function summaries.
void runDataFlowActivityAnalysis(FunctionOpInterface callee, ArrayRef< enzyme::Activity > argumentActivity, bool print=false, bool verbose=false, bool annotate=false)
void runActivityAnnotations(FunctionOpInterface callee, ArrayRef< enzyme::Activity > argActivities={}, const ActivityPrinterConfig &config=ActivityPrinterConfig())