Enzyme main
Loading...
Searching...
No Matches
TypeAnalysisPrinter.cpp
Go to the documentation of this file.
1//===- TypeAnalysisPrinter.cpp - Printer utility pass for Type Analysis ---===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains a utility LLVM pass for printing derived Type Analysis
22// results of a given function.
23//
24//===----------------------------------------------------------------------===//
25#include <llvm/Config/llvm-config.h>
26
27#if LLVM_VERSION_MAJOR >= 16
28#include "llvm/Analysis/ScalarEvolution.h"
29#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
30#else
31#include "SCEV/ScalarEvolution.h"
32#include "SCEV/ScalarEvolutionExpander.h"
33#endif
34
35#include "llvm/ADT/SmallVector.h"
36
37#include "llvm/IR/BasicBlock.h"
38#include "llvm/IR/Constants.h"
39#include "llvm/IR/DebugInfoMetadata.h"
40#include "llvm/IR/Function.h"
41#include "llvm/IR/IRBuilder.h"
42#include "llvm/IR/InstrTypes.h"
43#include "llvm/IR/Instructions.h"
44#include "llvm/IR/MDBuilder.h"
45#include "llvm/IR/Metadata.h"
46
47#include "llvm/Support/Debug.h"
48#include "llvm/Transforms/Scalar.h"
49
50#include "llvm/Analysis/BasicAliasAnalysis.h"
51#include "llvm/Analysis/GlobalsModRef.h"
52#include "llvm/Analysis/ScalarEvolution.h"
53
54#include "llvm/Support/CommandLine.h"
55#include "llvm/Support/ErrorHandling.h"
56
57#include "../EnzymeLogic.h"
58#include "../FunctionUtils.h"
59#include "../Utils.h"
60#include "TypeAnalysis.h"
61#include "TypeAnalysisPrinter.h"
62
63using namespace llvm;
64#ifdef DEBUG_TYPE
65#undef DEBUG_TYPE
66#endif
67#define DEBUG_TYPE "type-analysis-results"
68
69extern "C" {
70/// Function ActivityAnalysis will be starting its run from
71llvm::cl::opt<std::string>
72 EnzymeFunctionToAnalyze("type-analysis-func", cl::init(""), cl::Hidden,
73 cl::desc("Which function to analyze/print"));
74}
75
76namespace {
77bool printTypeAnalyses(llvm::Function &F) {
78
79 if (F.getName() != EnzymeFunctionToAnalyze)
80 return /*changed*/ false;
81
82 FnTypeInfo type_args(&F);
83 for (auto &a : type_args.Function->args()) {
84 TypeTree dt;
85 if (a.getType()->isFPOrFPVectorTy()) {
86 dt = ConcreteType(a.getType()->getScalarType());
87 } else if (a.getType()->isPointerTy()) {
88#if LLVM_VERSION_MAJOR < 17
89#if LLVM_VERSION_MAJOR >= 15
90 if (F.getContext().supportsTypedPointers()) {
91#endif
92 auto et = cast<PointerType>(a.getType())->getPointerElementType();
93 if (et->isFPOrFPVectorTy()) {
94 dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
95 } else if (et->isPointerTy()) {
96 dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
97 }
98#if LLVM_VERSION_MAJOR >= 15
99 }
100#endif
101#endif
103 } else if (a.getType()->isIntOrIntVectorTy()) {
105 }
106 type_args.Arguments.insert(
107 std::pair<Argument *, TypeTree>(&a, dt.Only(-1, nullptr)));
108 // TODO note that here we do NOT propagate constants in type info (and
109 // should consider whether we should)
110 type_args.KnownValues.insert(
111 std::pair<Argument *, std::set<int64_t>>(&a, {}));
112 }
113
114 TypeTree dt;
115 if (F.getReturnType()->isFPOrFPVectorTy()) {
116 dt = ConcreteType(F.getReturnType()->getScalarType());
117 } else if (F.getReturnType()->isPointerTy()) {
118#if LLVM_VERSION_MAJOR < 17
119#if LLVM_VERSION_MAJOR >= 15
120 if (F.getContext().supportsTypedPointers()) {
121#endif
122 auto et = cast<PointerType>(F.getReturnType())->getPointerElementType();
123 if (et->isFPOrFPVectorTy()) {
124 dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
125 } else if (et->isPointerTy()) {
126 dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
127 }
128#if LLVM_VERSION_MAJOR >= 15
129 }
130#endif
131#endif
133 } else if (F.getReturnType()->isIntOrIntVectorTy()) {
135 }
136 type_args.Return = dt.Only(-1, nullptr);
137 EnzymeLogic Logic(false);
138 TypeAnalysis TA(Logic);
139 TA.analyzeFunction(type_args);
140 for (Function &f : *F.getParent()) {
141
142 for (auto &analysis : TA.analyzedFunctions) {
143 if (analysis.first.Function != &f)
144 continue;
145 auto &ta = *analysis.second;
146 llvm::outs() << f.getName() << " - " << analysis.first.Return.str()
147 << " |";
148
149 for (auto &a : f.args()) {
150 llvm::outs() << analysis.first.Arguments.find(&a)->second.str() << ":"
151 << to_string(analysis.first.KnownValues.find(&a)->second)
152 << " ";
153 }
154 llvm::outs() << "\n";
155
156 for (auto &a : f.args()) {
157 llvm::outs() << a << ": " << ta.getAnalysis(&a).str() << "\n";
158 }
159 for (auto &BB : f) {
160 llvm::outs() << BB.getName() << "\n";
161 for (auto &I : BB) {
162 llvm::outs() << I << ": " << ta.getAnalysis(&I).str() << "\n";
163 }
164 }
165 }
166 }
167 return /*changed*/ false;
168}
169
170class TypeAnalysisPrinter final : public ModulePass {
171public:
172 static char ID;
173 TypeAnalysisPrinter() : ModulePass(ID) {}
174
175 bool runOnModule(Module &M) override {
176 // Check if function name is specified
177 if (EnzymeFunctionToAnalyze.empty()) {
178 EmitFailure("NoFunctionSpecified", M,
179 "No function specified for -type-analysis-func");
180 return false;
181 }
182
183 // Check if the specified function exists
184 Function *TargetFunc = M.getFunction(EnzymeFunctionToAnalyze);
185
186 if (!TargetFunc) {
187 EmitFailure("FunctionNotFound", M, "Function '", EnzymeFunctionToAnalyze,
188 "' specified by -type-analysis-func not found in module");
189 return false;
190 }
191
192 // Run analysis only on the target function
193 return printTypeAnalyses(*TargetFunc);
194 }
195
196 void getAnalysisUsage(AnalysisUsage &AU) const override {
197 AU.setPreservesAll();
198 }
199};
200
201} // namespace
202
203char TypeAnalysisPrinter::ID = 0;
204
205static RegisterPass<TypeAnalysisPrinter> X("print-type-analysis",
206 "Print Type Analysis Results");
207
210 llvm::ModuleAnalysisManager &MAM) {
211 // Check if function name is specified
212 if (EnzymeFunctionToAnalyze.empty()) {
213 EmitFailure("NoFunctionSpecified", M,
214 "No function specified for -type-analysis-func");
215 return PreservedAnalyses::all();
216 }
217
218 // Check if the specified function exists
219 Function *TargetFunc = M.getFunction(EnzymeFunctionToAnalyze);
220
221 if (!TargetFunc) {
222 EmitFailure("FunctionNotFound", M, "Function '", EnzymeFunctionToAnalyze,
223 "' specified by -type-analysis-func not found in module");
224 return PreservedAnalyses::all();
225 }
226
227 // Run analysis only on the target function
228 bool changed = printTypeAnalyses(*TargetFunc);
229 return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
230}
231llvm::AnalysisKey TypeAnalysisPrinterNewPM::Key;
constexpr const char * to_string(ActivityAnalyzer::UseActivity UA)
llvm::cl::opt< std::string > EnzymeFunctionToAnalyze("type-analysis-func", cl::init(""), cl::Hidden, cl::desc("Which function to analyze/print"))
Name of the function whose Type Analysis results should be printed.
static RegisterPass< TypeAnalysisPrinter > X("print-type-analysis", "Print Type Analysis Results")
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
Concrete SubType of a given value.
llvm::PreservedAnalyses Result
Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM)
Full interprocedural TypeAnalysis.
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Definition TypeTree.h:72
TypeTree Only(int Off, llvm::Instruction *orig) const
Prepend an offset to all mappings.
Definition TypeTree.h:471
bool insert(const std::vector< int > Seq, ConcreteType CT, bool PointerIntSame=false)
Return if changed.
Definition TypeTree.h:234
Struct containing all contextual type information for a particular function call.