21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
24#include "mlir/Interfaces/CallInterfaces.h"
25#include "mlir/Interfaces/FunctionInterfaces.h"
27#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Demangle/Demangle.h"
35#define GEN_PASS_DEF_PRINTACTIVITYANALYSISPASS
36#include "Passes/Passes.h.inc"
42enzyme::Activity getDefaultActivity(Type argType) {
43 if (argType.isIntOrIndex())
44 return enzyme::Activity::enzyme_const;
46 if (isa<FloatType, ComplexType>(argType))
47 return enzyme::Activity::enzyme_active;
49 if (
auto T = dyn_cast<TensorType>(argType))
50 return getDefaultActivity(T.getElementType());
52 if (isa<LLVM::LLVMPointerType, MemRefType>(argType))
53 return enzyme::Activity::enzyme_dup;
55 return enzyme::Activity::enzyme_const;
// Pass that runs Enzyme's activity analysis on functions of a module, printing
// per-operation results to llvm::outs() and attaching `enzyme.*` BoolAttrs to
// the IR (presumably gated by the `verbose` / `annotate` options -- the guards
// are on lines not shown in this listing).
58struct PrintActivityAnalysisPass
59 :
public enzyme::impl::PrintActivityAnalysisPassBase<
60 PrintActivityAnalysisPass> {
61 using PrintActivityAnalysisPassBase::PrintActivityAnalysisPassBase;
// Seeds argument/result activities for `callee` from its signature alone
// (used when no __enzyme_autodiff call site is consulted).
//   - An argument is enzyme_const if it carries the `enzyme.const` argument
//     attribute or the `inactiveArgs` flag is set; otherwise it gets
//     getDefaultActivity(argType).
//   - A result is either enzyme_dup or getDefaultActivity(resType).
// Both output arrays are written by index, so they must already be sized to
// the callee's argument/result counts.
65 void initializeArgAndResActivities(
66 FunctionOpInterface callee,
67 MutableArrayRef<enzyme::Activity> argActivities,
68 MutableArrayRef<enzyme::Activity> resActivities)
const {
69 for (
const auto &[idx, argType] :
70 llvm::enumerate(callee.getArgumentTypes())) {
71 if (callee.getArgAttr(idx,
"enzyme.const") || inactiveArgs)
72 argActivities[idx] = enzyme::Activity::enzyme_const;
// Type-based default (the `else` pairing with the check above is on a line
// not shown in this listing).
74 argActivities[idx] = getDefaultActivity(argType);
77 for (
const auto &[idx, resType] :
78 llvm::enumerate(callee.getResultTypes())) {
// NOTE(review): the condition that selects enzyme_dup over the type-based
// default for results is on a line not shown in this listing.
80 resActivities[idx] = (enzyme::Activity::enzyme_dup);
82 resActivities[idx] = getDefaultActivity(resType);
// Derives activities for `callee` from one __enzyme_autodiff call site.
// For each callee parameter the call operand at `argIdx` is inspected: if it is
// an llvm.load whose address is an llvm.mlir.addressof of the global
// `enzyme_const`, `enzyme_dup` or `enzyme_dupnoneed`, that activity is used.
// A type-based fallback (float/complex -> active, LLVM pointer/memref -> dup,
// anything else -> const) is also present; the condition selecting it is on
// lines not shown in this listing.
// Results are classified by type: float/complex -> active, otherwise const
// (any further TypeSwitch cases are not visible here).
// NOTE(review): `argIdx` is declared and advanced on lines not shown here;
// presumably it skips the callee operand and each activity-marker operand --
// confirm, since an off-by-one would misattribute activities.
86 void inferArgActivitiesFromEnzymeAutodiff(
87 FunctionOpInterface callee, CallOpInterface autodiff_call,
88 MutableArrayRef<enzyme::Activity> argActivities,
89 MutableArrayRef<enzyme::Activity> resultActivities) {
91 for (
const auto &[paramIdx, paramType] :
92 llvm::enumerate(callee.getArgumentTypes())) {
93 Value arg = autodiff_call.getArgOperands()[argIdx];
95 dyn_cast_if_present<LLVM::LoadOp>(arg.getDefiningOp())) {
96 if (
auto addressOf = dyn_cast_if_present<LLVM::AddressOfOp>(
97 loadOp.getAddr().getDefiningOp())) {
98 if (addressOf.getGlobalName() ==
"enzyme_const") {
99 argActivities[paramIdx] = enzyme::Activity::enzyme_const;
100 }
else if (addressOf.getGlobalName() ==
"enzyme_dup") {
101 argActivities[paramIdx] = enzyme::Activity::enzyme_dup;
104 }
else if (addressOf.getGlobalName() ==
"enzyme_dupnoneed") {
105 argActivities[paramIdx] = enzyme::Activity::enzyme_dupnoneed;
// Type-based fallback activity for this parameter.
113 argActivities[paramIdx] =
114 llvm::TypeSwitch<Type, enzyme::Activity>(paramType)
115 .Case<FloatType, ComplexType>(
116 [](
auto type) {
return enzyme::Activity::enzyme_active; })
117 .Case<LLVM::LLVMPointerType, MemRefType>([&](
auto type) {
120 return enzyme::Activity::enzyme_dup;
123 [](Type type) {
return enzyme::Activity::enzyme_const; });
// Result activities depend only on the result type.
128 for (
const auto &[resIdx, resType] :
129 llvm::enumerate(callee.getResultTypes())) {
130 resultActivities[resIdx] =
131 llvm::TypeSwitch<Type, enzyme::Activity>(resType)
132 .Case<FloatType, ComplexType>(
133 [](
auto type) {
return enzyme::Activity::enzyme_active; })
135 [](Type type) {
return enzyme::Activity::enzyme_const; });
// (Tail of the runActivityAnalysis parameter list; the function head and the
// leading `config` parameter are on lines not shown in this listing.)
140 FunctionOpInterface callee,
141 ArrayRef<enzyme::Activity> argActivities,
142 ArrayRef<enzyme::Activity> resultActivities,
148 true, verbose, annotate);
// Classic (non-dataflow) activity analysis follows. NOTE(review): the call
// ending in `true, verbose, annotate);` above is presumably
// runDataFlowActivityAnalysis(callee, argActivities, /*print=*/true, verbose,
// annotate) under `config.dataflow`; its head is not shown.
151 SmallPtrSet<Block *, 4> blocksNotForAnalysis;
// Seed the analyzer from the entry-block arguments: enzyme_const arguments are
// known constant; the rest (via an `else` on a line not shown) are active.
154 SmallPtrSet<mlir::Value, 1> constant_values;
155 SmallPtrSet<mlir::Value, 1> activevals_;
156 for (
auto &&[arg, act] :
157 llvm::zip(callee.getFunctionBody().getArguments(), argActivities)) {
158 if (act == enzyme::Activity::enzyme_const)
159 constant_values.insert(arg);
161 activevals_.insert(arg);
// Translate result activities into DIFFE_TYPE; the per-case push_back is on
// lines not shown in this listing.
163 SmallVector<DIFFE_TYPE> ReturnActivity;
164 for (
auto act : resultActivities) {
165 if (act != enzyme::Activity::enzyme_const)
171 DenseMap<Operation *, bool> readOnlyCache;
173 readOnlyCache, constant_values,
174 activevals_, ReturnActivity);
176 callee.walk([&](Operation *op) {
179 MLIRContext *ctx = callee.getContext();
// Report, for every op: constness of each nested block argument
// ("enzyme.arg_icv<N>"), of the op itself ("enzyme.ici") and of each result
// ("enzyme.res_icv<N>"), both as BoolAttrs and on llvm::outs().
180 callee.walk([&](Operation *op) {
182 llvm::outs() <<
" Operation: " << *op <<
"\n";
// Visit the arguments of every block directly inside this op's regions.
183 for (
auto &region : op->getRegions()) {
184 for (
auto &blk : region.getBlocks()) {
185 for (
auto &arg : blk.getArguments()) {
186 bool icv = activityAnalyzer.isConstantValue(TR, arg);
188 op->setAttr(
"enzyme.arg_icv" +
189 std::to_string(arg.getArgNumber()),
190 BoolAttr::get(ctx, icv));
192 llvm::outs() <<
" arg: " << arg <<
" icv=" << icv <<
"\n";
197 bool ici = activityAnalyzer.isConstantOperation(TR, op);
199 op->setAttr(
"enzyme.ici", BoolAttr::get(ctx, ici));
201 llvm::outs() <<
" op ici=" << ici <<
"\n";
203 for (
auto res : op->getResults()) {
204 bool icv = activityAnalyzer.isConstantValue(TR, res);
206 op->setAttr(
"enzyme.res_icv" +
207 std::to_string(res.getResultNumber()),
208 BoolAttr::get(ctx, icv));
210 llvm::outs() <<
" res: " << res <<
" icv=" << icv <<
"\n";
// Pass entry point. Three ways of choosing what to analyze are visible here:
//   1. `inferFromAutodiff`: locate the __enzyme_autodiff declaration (by exact
//      symbol name, else a function whose name contains it); for every use,
//      seed the differentiated function's activities from that call site.
//   2. `funcsToAnalyze` empty: analyze every function in the module that is
//      neither external nor private, seeded from its signature.
//   3. Otherwise: analyze each function named in `funcsToAnalyze`.
// NOTE(review): how these branches nest, and how `config` is built, is decided
// on lines not shown in this listing.
216 void runOnOperation()
override {
224 auto moduleOp = cast<ModuleOp>(getOperation());
226 if (inferFromAutodiff) {
228 Operation *autodiff_decl = moduleOp.lookupSymbol(
"__enzyme_autodiff");
// Fall back to a substring match, e.g. for mangled/suffixed declarations.
229 if (!autodiff_decl) {
230 for (
auto &subOp : *moduleOp.getBody()) {
231 if (
auto func = dyn_cast<FunctionOpInterface>(&subOp)) {
232 if (func.getName().contains(
"__enzyme_autodiff")) {
233 autodiff_decl = &subOp;
239 if (!autodiff_decl) {
240 moduleOp.emitError(
"Failed to find __enzyme_autodiff symbol");
241 return signalPassFailure();
243 auto uses = SymbolTable::getSymbolUses(autodiff_decl, moduleOp);
244 assert(uses &&
"failed to find symbol uses of autodiff decl");
// The differentiated function is the global referenced by each call's first
// operand. NOTE(review): cast<LLVM::AddressOfOp> and cast<FunctionOpInterface>
// assert (instead of emitting a diagnostic) if that operand is not an
// llvm.mlir.addressof or the symbol lookup does not yield a function.
246 for (SymbolTable::SymbolUse use : *uses) {
247 auto autodiff_call = cast<CallOpInterface>(use.getUser());
248 FlatSymbolRefAttr calleeAttr =
249 cast<LLVM::AddressOfOp>(
250 autodiff_call.getArgOperands().front().getDefiningOp())
251 .getGlobalNameAttr();
253 cast<FunctionOpInterface>(moduleOp.lookupSymbol(calleeAttr));
// NOTE(review): `{n}` brace-init picks the size constructor only because `n`
// does not implicitly convert to enzyme::Activity; `(n)` would state the
// intent unambiguously.
255 SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
256 resultActivities{callee.getNumResults()};
259 inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call,
260 argActivities, resultActivities);
261 runActivityAnalysis(config, callee, argActivities, resultActivities);
266 if (funcsToAnalyze.empty()) {
267 moduleOp.walk([
this, config](FunctionOpInterface callee) {
// Skip declarations and private functions (the skip statement itself is on a
// line not shown in this listing).
268 if (callee.isExternal() || callee.isPrivate())
271 SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
272 resultActivities{callee.getNumResults()};
273 initializeArgAndResActivities(callee, argActivities, resultActivities);
275 runActivityAnalysis(config, callee, argActivities, resultActivities);
280 for (std::string funcName : funcsToAnalyze) {
281 Operation *op = moduleOp.lookupSymbol(funcName);
// NOTE(review): handling of an unknown symbol (null `op`) is presumably on
// lines not shown; isa<> on a null pointer asserts.
286 if (!isa<FunctionOpInterface>(op)) {
288 <<
"Operation " << funcName <<
" was not a FunctionOpInterface";
289 return signalPassFailure();
292 auto callee = cast<FunctionOpInterface>(op);
293 SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
294 resultActivities{callee.getNumResults()};
295 initializeArgAndResActivities(callee, argActivities, resultActivities);
297 runActivityAnalysis(config, callee, argActivities, resultActivities);
Helper class to analyze the differential activity.
bool annotate
Annotate the IR with activity information for every operation.
bool inferFromAutodiff
Infer the starting argument state from an __enzyme_autodiff call.
bool verbose
Output extra information for debugging.
bool dataflow
Whether to use the data-flow based algorithm or the classic activity analysis.
bool relative
Use function summaries.
void runDataFlowActivityAnalysis(FunctionOpInterface callee, ArrayRef< enzyme::Activity > argumentActivity, bool print=false, bool verbose=false, bool annotate=false)
void runActivityAnnotations(FunctionOpInterface callee, ArrayRef< enzyme::Activity > argActivities={}, const ActivityPrinterConfig &config=ActivityPrinterConfig())