Enzyme main
Loading...
Searching...
No Matches
ActivityAnalysisPrinter.cpp
Go to the documentation of this file.
1// ActivityAnalysisPrinter.cpp - Printer utility pass for Activity Analysis =//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains a utility LLVM pass for printing derived Activity Analysis
22// results of a given function.
23//
24//===----------------------------------------------------------------------===//
25#include <llvm/Config/llvm-config.h>
26
27#include "llvm/Analysis/ScalarEvolution.h"
28#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
29
30#include "llvm/ADT/SmallVector.h"
31
32#include "llvm/IR/BasicBlock.h"
33#include "llvm/IR/Constants.h"
34#include "llvm/IR/DebugInfoMetadata.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/IRBuilder.h"
37#include "llvm/IR/InstrTypes.h"
38#include "llvm/IR/Instructions.h"
39#include "llvm/IR/MDBuilder.h"
40#include "llvm/IR/Metadata.h"
41
42#include "llvm/Support/Debug.h"
43#include "llvm/Transforms/Scalar.h"
44
45#include "llvm/Analysis/BasicAliasAnalysis.h"
46#include "llvm/Analysis/GlobalsModRef.h"
47#include "llvm/Analysis/ScalarEvolution.h"
48
49#include "llvm/Support/CommandLine.h"
50#include "llvm/Support/ErrorHandling.h"
51
52#include "ActivityAnalysis.h"
54#include "EnzymeLogic.h"
55#include "FunctionUtils.h"
57#include "Utils.h"
58
59using namespace llvm;
60#ifdef DEBUG_TYPE
61#undef DEBUG_TYPE
62#endif
63#define DEBUG_TYPE "activity-analysis-results"
64
65/// Function TypeAnalysis will be starting its run from
66static llvm::cl::opt<std::string>
67 FunctionToAnalyze("activity-analysis-func", cl::init(""), cl::Hidden,
68 cl::desc("Which function to analyze/print"));
69
70static llvm::cl::opt<bool>
71 InactiveArgs("activity-analysis-inactive-args", cl::init(false), cl::Hidden,
72 cl::desc("Whether all args are inactive"));
73
74static llvm::cl::opt<bool>
75 DuplicatedRet("activity-analysis-duplicated-ret", cl::init(false),
76 cl::Hidden, cl::desc("Whether the return is duplicated"));
77namespace {
78
79bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) {
80 if (F.getName() != FunctionToAnalyze)
81 return /*changed*/ false;
82
83 FnTypeInfo type_args(&F);
84 for (auto &a : type_args.Function->args()) {
85 TypeTree dt;
86 if (a.getType()->isFPOrFPVectorTy()) {
87 dt = ConcreteType(a.getType()->getScalarType());
88 } else if (a.getType()->isPointerTy()) {
89#if LLVM_VERSION_MAJOR < 17
90 if (a.getContext().supportsTypedPointers()) {
91 auto et = a.getType()->getPointerElementType();
92 if (et->isFPOrFPVectorTy()) {
93 dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
94 } else if (et->isPointerTy()) {
95 dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
96 }
97 }
98#endif
99 } else if (a.getType()->isIntOrIntVectorTy()) {
101 }
102 type_args.Arguments.insert(
103 std::pair<Argument *, TypeTree>(&a, dt.Only(-1, nullptr)));
104 // TODO note that here we do NOT propagate constants in type info (and
105 // should consider whether we should)
106 type_args.KnownValues.insert(
107 std::pair<Argument *, std::set<int64_t>>(&a, {}));
108 }
109
110 TypeTree dt;
111 if (F.getReturnType()->isFPOrFPVectorTy()) {
112 dt = ConcreteType(F.getReturnType()->getScalarType());
113 } else if (F.getReturnType()->isPointerTy()) {
114#if LLVM_VERSION_MAJOR < 17
115 if (F.getContext().supportsTypedPointers()) {
116 auto et = F.getReturnType()->getPointerElementType();
117 if (et->isFPOrFPVectorTy()) {
118 dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
119 } else if (et->isPointerTy()) {
120 dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
121 }
122 }
123#endif
124 } else if (F.getReturnType()->isIntOrIntVectorTy()) {
126 }
127 type_args.Return = dt.Only(-1, nullptr);
128
129 EnzymeLogic Logic(false);
130 TypeAnalysis TA(Logic);
131 TypeResults TR = TA.analyzeFunction(type_args);
132
133 llvm::SmallPtrSet<llvm::Value *, 4> ConstantValues;
134 llvm::SmallPtrSet<llvm::Value *, 4> ActiveValues;
135 for (auto &a : type_args.Function->args()) {
136 if (InactiveArgs) {
137 ConstantValues.insert(&a);
138 } else if (a.getType()->isIntOrIntVectorTy()) {
139 ConstantValues.insert(&a);
140 } else {
141 ActiveValues.insert(&a);
142 }
143 }
144
145 DIFFE_TYPE ActiveReturns = F.getReturnType()->isFPOrFPVectorTy()
148 if (DuplicatedRet)
149 ActiveReturns = DIFFE_TYPE::DUP_ARG;
150 SmallPtrSet<BasicBlock *, 4> notForAnalysis(getGuaranteedUnreachable(&F));
151 ActivityAnalyzer ATA(Logic.PPC, Logic.PPC.FAM.getResult<AAManager>(F),
152 notForAnalysis, TLI, ConstantValues, ActiveValues,
153 ActiveReturns);
154
155 for (auto &a : F.args()) {
156 ATA.isConstantValue(TR, &a);
157 llvm::errs().flush();
158 }
159 for (auto &BB : F) {
160 for (auto &I : BB) {
161 ATA.isConstantInstruction(TR, &I);
162 ATA.isConstantValue(TR, &I);
163 llvm::errs().flush();
164 }
165 }
166
167 for (auto &a : F.args()) {
168 bool icv = ATA.isConstantValue(TR, &a);
169 llvm::errs().flush();
170 llvm::outs() << a << ": icv:" << icv << "\n";
171 llvm::outs().flush();
172 }
173 for (auto &BB : F) {
174 llvm::outs() << BB.getName() << "\n";
175 for (auto &I : BB) {
176 bool ici = ATA.isConstantInstruction(TR, &I);
177 bool icv = ATA.isConstantValue(TR, &I);
178 llvm::errs().flush();
179 llvm::outs() << I << ": icv:" << icv << " ici:" << ici << "\n";
180 llvm::outs().flush();
181 }
182 }
183 return /*changed*/ false;
184}
185
186class ActivityAnalysisPrinter final : public ModulePass {
187public:
188 static char ID;
189 ActivityAnalysisPrinter() : ModulePass(ID) {}
190
191 bool runOnModule(Module &M) override {
192 // Check if function name is specified
193 if (FunctionToAnalyze.empty()) {
194 EmitFailure("NoFunctionSpecified", M,
195 "No function specified for -activity-analysis-func");
196 return false;
197 }
198
199 // Check if the specified function exists
200 Function *TargetFunc = M.getFunction(FunctionToAnalyze);
201
202 if (!TargetFunc) {
203 EmitFailure("FunctionNotFound", M, "Function '", FunctionToAnalyze,
204 "' specified by -activity-analysis-func not found in module");
205 return false;
206 }
207
208 // Run analysis only on the target function
209 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(*TargetFunc);
210 return printActivityAnalysis(*TargetFunc, TLI);
211 }
212
213 void getAnalysisUsage(AnalysisUsage &AU) const override {
214 AU.addRequired<TargetLibraryInfoWrapperPass>();
215 AU.setPreservesAll();
216 }
217};
218
219} // namespace
220
221char ActivityAnalysisPrinter::ID = 0;
222
223static RegisterPass<ActivityAnalysisPrinter>
224 X("print-activity-analysis", "Print Activity Analysis Results");
225
228 llvm::ModuleAnalysisManager &MAM) {
229 // Check if function name is specified
230 if (FunctionToAnalyze.empty()) {
231 EmitFailure("NoFunctionSpecified", M,
232 "No function specified for -activity-analysis-func");
233 return PreservedAnalyses::all();
234 }
235
236 // Check if the specified function exists
237 Function *TargetFunc = M.getFunction(FunctionToAnalyze);
238
239 if (!TargetFunc) {
240 EmitFailure("FunctionNotFound", M, "Function '", FunctionToAnalyze,
241 "' specified by -activity-analysis-func not found in module");
242 return PreservedAnalyses::all();
243 }
244
245 // Run analysis only on the target function
246 auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
247 bool changed = printActivityAnalysis(
248 *TargetFunc, FAM.getResult<TargetLibraryAnalysis>(*TargetFunc));
249 return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
250}
251llvm::AnalysisKey ActivityAnalysisPrinterNewPM::Key;
static llvm::cl::opt< bool > DuplicatedRet("activity-analysis-duplicated-ret", cl::init(false), cl::Hidden, cl::desc("Whether the return is duplicated"))
static RegisterPass< ActivityAnalysisPrinter > X("print-activity-analysis", "Print Activity Analysis Results")
static llvm::cl::opt< std::string > FunctionToAnalyze("activity-analysis-func", cl::init(""), cl::Hidden, cl::desc("Which function to analyze/print"))
Name of the function whose activity analysis is computed and printed.
static llvm::cl::opt< bool > InactiveArgs("activity-analysis-inactive-args", cl::init(false), cl::Hidden, cl::desc("Whether all args are inactive"))
static llvm::SmallPtrSet< llvm::BasicBlock *, 4 > getGuaranteedUnreachable(llvm::Function *F)
DIFFE_TYPE
Potential differentiable argument classifications.
Definition Utils.h:374
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM)
Helper class to analyze the differential activity.
Concrete SubType of a given value.
Full interprocedural TypeAnalysis.
A holder class representing the results of running TypeAnalysis on a given function.
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Definition TypeTree.h:72
TypeTree Only(int Off, llvm::Instruction *orig) const
Prepend an offset to all mappings.
Definition TypeTree.h:471
bool insert(const std::vector< int > Seq, ConcreteType CT, bool PointerIntSame=false)
Returns whether the tree changed.
Definition TypeTree.h:234
Struct containing all contextual type information for a particular function call.