Enzyme main
Loading...
Searching...
No Matches
TypeAnalysis.h
Go to the documentation of this file.
1//===- TypeAnalysis.h - Declaration of Type Analysis ------------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains the declaration of Type Analysis, a utility for
22// computing the underlying data type of LLVM values.
23//
24//===----------------------------------------------------------------------===//
25#ifndef ENZYME_TYPE_ANALYSIS_H
26#define ENZYME_TYPE_ANALYSIS_H 1
27
28#include <cstdint>
29#include <deque>
30
31#include <llvm/Config/llvm-config.h>
32
33#include "llvm/ADT/SetVector.h"
34#include "llvm/ADT/StringMap.h"
35
36#include "llvm/Analysis/TargetLibraryInfo.h"
37
38#if LLVM_VERSION_MAJOR >= 16
39#include "llvm/Analysis/ScalarEvolution.h"
40#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
41#else
42#include "SCEV/ScalarEvolution.h"
43#include "SCEV/ScalarEvolutionExpander.h"
44#endif
45
46#include "llvm/IR/Constants.h"
47#include "llvm/IR/InstVisitor.h"
48#include "llvm/IR/ModuleSlotTracker.h"
49#include "llvm/IR/Type.h"
50#include "llvm/IR/Value.h"
51
52#include "llvm/Analysis/LoopInfo.h"
53#include "llvm/Analysis/PostDominators.h"
54#include "llvm/IR/Dominators.h"
55
56#include "../Utils.h"
57#include "TypeTree.h"
58
59extern const llvm::StringMap<llvm::Intrinsic::ID> LIBM_FUNCTIONS;
60
61static inline bool isMemFreeLibMFunction(llvm::StringRef str,
62 llvm::Intrinsic::ID *ID = nullptr) {
63 llvm::StringRef ogstr = str;
64 if (ID) {
65 if (str == "llvm.enzyme.lifetime_start") {
66 *ID = llvm::Intrinsic::lifetime_start;
67 return false;
68 }
69 if (str == "llvm.enzyme.lifetime_end") {
70 *ID = llvm::Intrinsic::lifetime_end;
71 return false;
72 }
73 }
74 if (startsWith(str, "__") && endsWith(str, "_finite")) {
75 str = str.substr(2, str.size() - 2 - 7);
76 } else if (startsWith(str, "__fd_") && endsWith(str, "_1")) {
77 str = str.substr(5, str.size() - 5 - 2);
78 } else if (startsWith(str, "__nv_")) {
79 str = str.substr(5, str.size() - 5);
80 } else if (startsWith(str, "__ocml_")) {
81 str = str.substr(7, str.size() - 7);
82 }
83 if (LIBM_FUNCTIONS.find(str.str()) != LIBM_FUNCTIONS.end()) {
84 if (ID)
85 *ID = LIBM_FUNCTIONS.find(str.str())->second;
86 return true;
87 }
88 if (endsWith(str, "f") || endsWith(str, "l") ||
89 (startsWith(ogstr, "__nv_") && endsWith(str, "d"))) {
90 if (LIBM_FUNCTIONS.find(str.substr(0, str.size() - 1).str()) !=
91 LIBM_FUNCTIONS.end()) {
92 if (ID)
93 *ID = LIBM_FUNCTIONS.find(str.substr(0, str.size() - 1).str())->second;
94 return true;
95 }
96 }
97 if ((startsWith(ogstr, "__ocml_") &&
98 (endsWith(str, "_f64") || endsWith(str, "_f32")))) {
99 if (LIBM_FUNCTIONS.find(str.substr(0, str.size() - 4).str()) !=
100 LIBM_FUNCTIONS.end()) {
101 if (ID)
102 *ID = LIBM_FUNCTIONS.find(str.substr(0, str.size() - 4).str())->second;
103 return true;
104 }
105 }
106 return false;
107}
108
109/// Struct containing all contextual type information for a
110/// particular function call
112 /// Function being analyzed
113 llvm::Function *Function;
114
115 FnTypeInfo(llvm::Function *fn) : Function(fn) {}
116 FnTypeInfo(const FnTypeInfo &) = default;
119
120 /// Types of arguments
121 std::map<llvm::Argument *, TypeTree> Arguments;
122
123 /// Type of return
125
126 /// The specific constant(s) known to represented by an argument, if constant
127 std::map<llvm::Argument *, std::set<int64_t>> KnownValues;
128
129 /// The set of known values val will take
130 std::set<int64_t>
131 knownIntegralValues(llvm::Value *val, const llvm::DominatorTree &DT,
132 std::map<llvm::Value *, std::set<int64_t>> &intseen,
133 llvm::ScalarEvolution &SE) const;
134};
135
136static inline bool operator<(const FnTypeInfo &lhs, const FnTypeInfo &rhs) {
137
138 if (lhs.Function < rhs.Function)
139 return true;
140 if (rhs.Function < lhs.Function)
141 return false;
142
143 if (lhs.Return < rhs.Return)
144 return true;
145 if (rhs.Return < lhs.Return)
146 return false;
147
148 for (auto &arg : lhs.Function->args()) {
149 {
150 auto foundLHS = lhs.Arguments.find(&arg);
151 assert(foundLHS != lhs.Arguments.end());
152 auto foundRHS = rhs.Arguments.find(&arg);
153 assert(foundRHS != rhs.Arguments.end());
154 if (foundLHS->second < foundRHS->second)
155 return true;
156 if (foundRHS->second < foundLHS->second)
157 return false;
158 }
159
160 {
161 auto foundLHS = lhs.KnownValues.find(&arg);
162 assert(foundLHS != lhs.KnownValues.end());
163 auto foundRHS = rhs.KnownValues.find(&arg);
164 assert(foundRHS != rhs.KnownValues.end());
165 if (foundLHS->second < foundRHS->second)
166 return true;
167 if (foundRHS->second < foundLHS->second)
168 return false;
169 }
170 }
171 // equal;
172 return false;
173}
174
// Forward declarations: these classes are referenced by pointer or reference
// in the declarations that follow, before (or without) their full definitions.
class EnzymeLogic;
class TypeAnalysis;
class TypeAnalyzer;
178
179/// A holder class representing the results of running TypeAnalysis
180/// on a given function
182public:
184
185public:
186 TypeResults(std::nullptr_t);
188 ConcreteType intType(size_t num, llvm::Value *val, bool errIfNotFound = true,
189 bool pointerIntSame = false) const;
190 llvm::Type *addingType(size_t num, llvm::Value *val, size_t start = 0) const;
191
192 /// Returns whether in the first num bytes there is pointer, int, float, or
193 /// none If pointerIntSame is set to true, then consider either as the same
194 /// (and thus mergable)
195 ConcreteType firstPointer(size_t num, llvm::Value *val, llvm::Instruction *I,
196 bool errIfNotFound = true,
197 bool pointerIntSame = false) const;
198
199 /// The TypeTree of a particular Value
200 TypeTree query(llvm::Value *val) const;
201
202 /// Whether any part of the top level register can contain a float
203 /// e.g. { i64, float } can contain a float, but { i64, i8* } would not.
204 // Of course, here we compute with type analysis rather than llvm type
205 // The flag `anythingIsFloat` specifies whether an anything should
206 // be considered a float.
207 bool anyFloat(llvm::Value *val, bool anythingIsFloat = true) const;
208
209 /// Whether all of the top level register is known to contain float data
210 bool allFloat(llvm::Value *val) const;
211
212 /// Whether any part of the top level register can contain a pointer
213 /// e.g. { i64, i8* } can contain a pointer, but { i64, float } would not.
214 // Of course, here we compute with type analysis rather than llvm type
215 bool anyPointer(llvm::Value *val) const;
216
217 /// The TypeInfo calling convention
219
220 /// The Type of the return
222
223 /// Prints all known information
224 void dump(llvm::raw_ostream &ss = llvm::errs()) const;
225
226 /// The set of values val will take on during this program
227 std::set<int64_t> knownIntegralValues(llvm::Value *val) const;
228
229 FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn) const;
230
231 llvm::Function *getFunction() const;
232};
233
234/// Helper class that computes the fixed-point type results of a given function
235class TypeAnalyzer : public llvm::InstVisitor<TypeAnalyzer> {
236public:
237 /// Cache of metadata indices, for faster printing.
238 /// Only initialized if EnzymePrintType is true
239 std::shared_ptr<llvm::ModuleSlotTracker> MST;
240
241 /// List of value's which should be re-analyzed now with new information
242 llvm::SetVector<llvm::Value *, std::deque<llvm::Value *>> workList;
243
244 const llvm::SmallPtrSet<llvm::BasicBlock *, 4> notForAnalysis;
245
246private:
247 /// Tell TypeAnalyzer to reanalyze this value
248 void addToWorkList(llvm::Value *val);
249
250 /// Map of Value to known integer constants that it will take on
251 std::map<llvm::Value *, std::set<int64_t>> intseen;
252
253 std::map<llvm::Value *, std::pair<bool, bool>> mriseen;
254 bool mustRemainInteger(llvm::Value *val, bool *returned = nullptr);
255
256public:
257 /// Calling context
259
260 /// Calling TypeAnalysis to be used in the case of calls to other
261 /// functions
263
264 /// Directionality of checks
265 uint8_t direction;
266
267 /// Whether an inconsistent update has been found
268 /// This will only be set when direction != Both, erring otherwise
270
272
273 // propagate from instruction to operand
274 static constexpr uint8_t UP = 1;
275 // propagate from operand to instruction
276 static constexpr uint8_t DOWN = 2;
277 static constexpr uint8_t BOTH = UP | DOWN;
278
279 /// Intermediate conservative, but correct Type analysis results
280 std::map<llvm::Value *, TypeTree> analysis;
281
282 llvm::TargetLibraryInfo &TLI;
283 llvm::DominatorTree &DT;
284 llvm::PostDominatorTree &PDT;
285
286 llvm::LoopInfo &LI;
287 llvm::ScalarEvolution &SE;
288
289 FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn);
290
292
293 TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA,
294 uint8_t direction = BOTH);
295
296 TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA,
297 const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &notForAnalysis,
298 const TypeAnalyzer &Prev, uint8_t direction = BOTH,
299 bool PHIRecur = false);
300
301 /// Get the current results for a given value
302 TypeTree getAnalysis(llvm::Value *Val);
303
304 /// Add additional information to the Type info of val, readding it to the
305 /// work queue as necessary
306 void updateAnalysis(llvm::Value *val, BaseType data, llvm::Value *origin);
307 void updateAnalysis(llvm::Value *val, ConcreteType data, llvm::Value *origin);
308 void updateAnalysis(llvm::Value *val, TypeTree data, llvm::Value *origin);
309
310 /// Analyze type info given by the arguments, possibly adding to work queue
311 void prepareArgs();
312
313 /// Analyze type info given by the TBAA, possibly adding to work queue
314 void considerTBAA();
315
316 /// Parse the debug info generated by rustc and retrieve useful type info if
317 /// possible
319
320 /// Run the interprocedural type analysis starting from this function
321 void run();
322
323 /// Hypothesize that undefined phi's are integers and try to prove
324 /// that they are really integral
325 void runPHIHypotheses();
326
327 void visitValue(llvm::Value &val);
328
329 void visitConstantExpr(llvm::ConstantExpr &CE);
330
331 void visitCmpInst(llvm::CmpInst &I);
332
333 void visitAllocaInst(llvm::AllocaInst &I);
334
335 void visitLoadInst(llvm::LoadInst &I);
336
337 void visitStoreInst(llvm::StoreInst &I);
338
339 void visitGetElementPtrInst(llvm::GetElementPtrInst &gep);
340
341 void visitGEPOperator(llvm::GEPOperator &gep);
342
343 void visitPHINode(llvm::PHINode &phi);
344
345 void visitTruncInst(llvm::TruncInst &I);
346
347 void visitZExtInst(llvm::ZExtInst &I);
348
349 void visitSExtInst(llvm::SExtInst &I);
350
351 void visitAddrSpaceCastInst(llvm::AddrSpaceCastInst &I);
352
353 void visitFPExtInst(llvm::FPExtInst &I);
354
355 void visitFPTruncInst(llvm::FPTruncInst &I);
356
357 void visitFPToUIInst(llvm::FPToUIInst &I);
358
359 void visitFPToSIInst(llvm::FPToSIInst &I);
360
361 void visitUIToFPInst(llvm::UIToFPInst &I);
362
363 void visitSIToFPInst(llvm::SIToFPInst &I);
364
365 void visitPtrToIntInst(llvm::PtrToIntInst &I);
366
367 void visitIntToPtrInst(llvm::IntToPtrInst &I);
368
369 void visitBitCastInst(llvm::BitCastInst &I);
370
371#if LLVM_VERSION_MAJOR >= 10
372 void visitFreezeInst(llvm::FreezeInst &I);
373#endif
374
375 void visitSelectInst(llvm::SelectInst &I);
376
377 void visitExtractElementInst(llvm::ExtractElementInst &I);
378
379 void visitInsertElementInst(llvm::InsertElementInst &I);
380
381 void visitShuffleVectorInst(llvm::ShuffleVectorInst &I);
382
383 void visitExtractValueInst(llvm::ExtractValueInst &I);
384
385 void visitInsertValueInst(llvm::InsertValueInst &I);
386
387 void visitAtomicRMWInst(llvm::AtomicRMWInst &I);
388
389 void visitBinaryOperator(llvm::BinaryOperator &I);
390 void visitBinaryOperation(const llvm::DataLayout &DL, llvm::Type *T,
391 llvm::Instruction::BinaryOps, llvm::Value *Args[2],
392 TypeTree &Ret, TypeTree &LHS, TypeTree &RHS,
393 llvm::Instruction *I);
394
395 void visitIPOCall(llvm::CallBase &call, llvm::Function &fn);
396
397 void visitCallBase(llvm::CallBase &call);
398
399 void visitMemTransferInst(llvm::MemTransferInst &MTI);
400 void visitMemTransferCommon(llvm::CallBase &MTI);
401
402 void visitIntrinsicInst(llvm::IntrinsicInst &II);
403
405
406 void dump(llvm::raw_ostream &ss = llvm::errs());
407
408 std::set<int64_t> knownIntegralValues(llvm::Value *val);
409
410 // TODO handle fneg on LLVM 10+
411};
412
413/// Full interprocedural TypeAnalysis
415public:
418 /// Map of custom function call handlers
419 llvm::StringMap<
420 std::function<bool(int /*direction*/, TypeTree & /*returnTree*/,
421 llvm::ArrayRef<TypeTree> /*argTrees*/,
422 llvm::ArrayRef<std::set<int64_t>> /*knownValues*/,
423 llvm::CallBase * /*call*/, TypeAnalyzer *)>>
425
426 /// Map of possible query states to TypeAnalyzer intermediate results
427 std::map<FnTypeInfo, std::shared_ptr<TypeAnalyzer>> analyzedFunctions;
428
429 /// Analyze a particular function, returning the results
431
432 /// Clear existing analyses
433 void clear();
434};
435
436TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I,
437 bool intIsPointer = true);
439 llvm::Function *todiff);
440#endif
BaseType
Categories of potential types.
Definition BaseType.h:32
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
FnTypeInfo preventTypeAnalysisLoops(const FnTypeInfo &oldTypeInfo_, llvm::Function *todiff)
TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I, bool intIsPointer=true)
const llvm::StringMap< llvm::Intrinsic::ID > LIBM_FUNCTIONS
static bool isMemFreeLibMFunction(llvm::StringRef str, llvm::Intrinsic::ID *ID=nullptr)
static bool operator<(const FnTypeInfo &lhs, const FnTypeInfo &rhs)
@ Args
Return is a struct of all args.
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
Definition Utils.h:713
static bool endsWith(llvm::StringRef string, llvm::StringRef suffix)
Definition Utils.h:721
Concrete SubType of a given value.
Full interprocedural TypeAnalysis.
llvm::StringMap< std::function< bool(int, TypeTree &, llvm::ArrayRef< TypeTree >, llvm::ArrayRef< std::set< int64_t > >, llvm::CallBase *, TypeAnalyzer *)> > CustomRules
Map of custom function call handlers.
std::map< FnTypeInfo, std::shared_ptr< TypeAnalyzer > > analyzedFunctions
Map of possible query states to TypeAnalyzer intermediate results.
void clear()
Clear existing analyses.
TypeAnalysis(EnzymeLogic &Logic)
TypeResults analyzeFunction(const FnTypeInfo &fn)
Analyze a particular function, returning the results.
EnzymeLogic & Logic
Helper class that computes the fixed-point type results of a given function.
void visitMemTransferInst(llvm::MemTransferInst &MTI)
void visitFPToSIInst(llvm::FPToSIInst &I)
llvm::PostDominatorTree & PDT
FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn)
void visitAllocaInst(llvm::AllocaInst &I)
const FnTypeInfo fntypeinfo
Calling context.
void visitExtractValueInst(llvm::ExtractValueInst &I)
void visitSIToFPInst(llvm::SIToFPInst &I)
const llvm::SmallPtrSet< llvm::BasicBlock *, 4 > notForAnalysis
void considerRustDebugInfo()
Parse the debug info generated by rustc and retrieve useful type info if possible.
llvm::DominatorTree & DT
std::set< int64_t > knownIntegralValues(llvm::Value *val)
void visitFPExtInst(llvm::FPExtInst &I)
void visitConstantExpr(llvm::ConstantExpr &CE)
TypeAnalysis & interprocedural
Calling TypeAnalysis to be used in the case of calls to other functions.
void visitGetElementPtrInst(llvm::GetElementPtrInst &gep)
void visitIntToPtrInst(llvm::IntToPtrInst &I)
void prepareArgs()
Analyze type info given by the arguments, possibly adding to work queue.
void visitShuffleVectorInst(llvm::ShuffleVectorInst &I)
llvm::TargetLibraryInfo & TLI
void visitUIToFPInst(llvm::UIToFPInst &I)
static constexpr uint8_t UP
void visitInsertValueInst(llvm::InsertValueInst &I)
void visitPHINode(llvm::PHINode &phi)
void visitValue(llvm::Value &val)
void visitSelectInst(llvm::SelectInst &I)
void visitIPOCall(llvm::CallBase &call, llvm::Function &fn)
std::shared_ptr< llvm::ModuleSlotTracker > MST
Cache of metadata indices, for faster printing.
llvm::SetVector< llvm::Value *, std::deque< llvm::Value * > > workList
List of values which should be re-analyzed now that new information is available.
static constexpr uint8_t DOWN
void updateAnalysis(llvm::Value *val, BaseType data, llvm::Value *origin)
Add additional information to the type info of val, re-adding it to the work queue as necessary.
void visitAddrSpaceCastInst(llvm::AddrSpaceCastInst &I)
bool Invalid
Whether an inconsistent update has been found. This will only be set when direction != BOTH; otherwise the inconsistency is reported as an error.
static constexpr uint8_t BOTH
void visitExtractElementInst(llvm::ExtractElementInst &I)
void visitLoadInst(llvm::LoadInst &I)
void visitCmpInst(llvm::CmpInst &I)
uint8_t direction
Directionality of checks.
std::map< llvm::Value *, TypeTree > analysis
Intermediate conservative, but correct Type analysis results.
void visitIntrinsicInst(llvm::IntrinsicInst &II)
void visitSExtInst(llvm::SExtInst &I)
void visitGEPOperator(llvm::GEPOperator &gep)
void updateAnalysis(llvm::Value *val, ConcreteType data, llvm::Value *origin)
void considerTBAA()
Analyze type info given by the TBAA, possibly adding to work queue.
void visitCallBase(llvm::CallBase &call)
void visitFPTruncInst(llvm::FPTruncInst &I)
void visitStoreInst(llvm::StoreInst &I)
void visitAtomicRMWInst(llvm::AtomicRMWInst &I)
TypeAnalyzer(TypeAnalysis &TA)
void visitZExtInst(llvm::ZExtInst &I)
void runPHIHypotheses()
Hypothesize that undefined phis are integers and try to prove that they are really integral.
void visitBinaryOperator(llvm::BinaryOperator &I)
void visitMemTransferCommon(llvm::CallBase &MTI)
TypeTree getAnalysis(llvm::Value *Val)
Get the current results for a given value.
void visitBinaryOperation(const llvm::DataLayout &DL, llvm::Type *T, llvm::Instruction::BinaryOps, llvm::Value *Args[2], TypeTree &Ret, TypeTree &LHS, TypeTree &RHS, llvm::Instruction *I)
void visitFPToUIInst(llvm::FPToUIInst &I)
void visitBitCastInst(llvm::BitCastInst &I)
void updateAnalysis(llvm::Value *val, TypeTree data, llvm::Value *origin)
llvm::LoopInfo & LI
void dump(llvm::raw_ostream &ss=llvm::errs())
void visitTruncInst(llvm::TruncInst &I)
void visitPtrToIntInst(llvm::PtrToIntInst &I)
llvm::ScalarEvolution & SE
void run()
Run the interprocedural type analysis starting from this function.
TypeTree getReturnAnalysis()
void visitInsertElementInst(llvm::InsertElementInst &I)
A holder class representing the results of running TypeAnalysis on a given function.
TypeResults(std::nullptr_t)
ConcreteType intType(size_t num, llvm::Value *val, bool errIfNotFound=true, bool pointerIntSame=false) const
llvm::Type * addingType(size_t num, llvm::Value *val, size_t start=0) const
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
bool anyFloat(llvm::Value *val, bool anythingIsFloat=true) const
Whether any part of the top-level register can contain a float, e.g. { i64, float } can contain a float, but { i64, i8* } would not.
TypeAnalyzer * analyzer
FnTypeInfo getAnalyzedTypeInfo() const
The TypeInfo calling convention.
bool anyPointer(llvm::Value *val) const
Whether any part of the top-level register can contain a pointer, e.g. { i64, i8* } can contain a pointer, but { i64, float } would not.
std::set< int64_t > knownIntegralValues(llvm::Value *val) const
The set of values val will take on during this program.
TypeTree getReturnAnalysis() const
The Type of the return.
TypeTree query(llvm::Value *val) const
The TypeTree of a particular Value.
bool allFloat(llvm::Value *val) const
Whether all of the top level register is known to contain float data.
llvm::Function * getFunction() const
FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn) const
ConcreteType firstPointer(size_t num, llvm::Value *val, llvm::Instruction *I, bool errIfNotFound=true, bool pointerIntSame=false) const
Returns whether in the first num bytes there is a pointer, int, float, or none. If pointerIntSame is set to true, pointers and integers are considered the same (and thus mergeable).
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Definition TypeTree.h:72
Struct containing all contextual type information for a particular function call.
FnTypeInfo(llvm::Function *fn)
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
llvm::Function * Function
Function being analyzed.
TypeTree Return
Type of return.
std::set< int64_t > knownIntegralValues(llvm::Value *val, const llvm::DominatorTree &DT, std::map< llvm::Value *, std::set< int64_t > > &intseen, llvm::ScalarEvolution &SE) const
The set of known values val will take.
FnTypeInfo & operator=(FnTypeInfo &)=default
FnTypeInfo(const FnTypeInfo &)=default
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to be represented by an argument, if constant.
FnTypeInfo & operator=(FnTypeInfo &&)=default