25#ifndef ENZYME_TYPE_ANALYSIS_H
26#define ENZYME_TYPE_ANALYSIS_H 1
31#include <llvm/Config/llvm-config.h>
33#include "llvm/ADT/SetVector.h"
34#include "llvm/ADT/StringMap.h"
36#include "llvm/Analysis/TargetLibraryInfo.h"
38#if LLVM_VERSION_MAJOR >= 16
39#include "llvm/Analysis/ScalarEvolution.h"
40#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
42#include "SCEV/ScalarEvolution.h"
43#include "SCEV/ScalarEvolutionExpander.h"
46#include "llvm/IR/Constants.h"
47#include "llvm/IR/InstVisitor.h"
48#include "llvm/IR/ModuleSlotTracker.h"
49#include "llvm/IR/Type.h"
50#include "llvm/IR/Value.h"
52#include "llvm/Analysis/LoopInfo.h"
53#include "llvm/Analysis/PostDominators.h"
54#include "llvm/IR/Dominators.h"
62 llvm::Intrinsic::ID *ID =
nullptr) {
63 llvm::StringRef ogstr =
str;
65 if (
str ==
"llvm.enzyme.lifetime_start") {
66 *ID = llvm::Intrinsic::lifetime_start;
69 if (
str ==
"llvm.enzyme.lifetime_end") {
70 *ID = llvm::Intrinsic::lifetime_end;
132 std::map<llvm::Value *, std::set<int64_t>> &intseen,
133 llvm::ScalarEvolution &SE)
const;
148 for (
auto &arg : lhs.
Function->args()) {
150 auto foundLHS = lhs.
Arguments.find(&arg);
152 auto foundRHS = rhs.
Arguments.find(&arg);
154 if (foundLHS->second < foundRHS->second)
156 if (foundRHS->second < foundLHS->second)
165 if (foundLHS->second < foundRHS->second)
167 if (foundRHS->second < foundLHS->second)
189 bool pointerIntSame =
false)
const;
190 llvm::Type *
addingType(
size_t num, llvm::Value *val,
size_t start = 0)
const;
196 bool errIfNotFound =
true,
197 bool pointerIntSame =
false)
const;
207 bool anyFloat(llvm::Value *val,
bool anythingIsFloat =
true)
const;
210 bool allFloat(llvm::Value *val)
const;
224 void dump(llvm::raw_ostream &ss = llvm::errs())
const;
239 std::shared_ptr<llvm::ModuleSlotTracker>
MST;
242 llvm::SetVector<llvm::Value *, std::deque<llvm::Value *>>
workList;
248 void addToWorkList(llvm::Value *val);
251 std::map<llvm::Value *, std::set<int64_t>> intseen;
253 std::map<llvm::Value *, std::pair<bool, bool>> mriseen;
254 bool mustRemainInteger(llvm::Value *val,
bool *returned =
nullptr);
274 static constexpr uint8_t
UP = 1;
276 static constexpr uint8_t
DOWN = 2;
282 llvm::TargetLibraryInfo &
TLI;
283 llvm::DominatorTree &
DT;
284 llvm::PostDominatorTree &
PDT;
287 llvm::ScalarEvolution &
SE;
371#if LLVM_VERSION_MAJOR >= 10
372 void visitFreezeInst(llvm::FreezeInst &I);
391 llvm::Instruction::BinaryOps, llvm::Value *
Args[2],
393 llvm::Instruction *I);
395 void visitIPOCall(llvm::CallBase &call, llvm::Function &fn);
406 void dump(llvm::raw_ostream &ss = llvm::errs());
420 std::function<bool(
int ,
TypeTree & ,
421 llvm::ArrayRef<TypeTree> ,
422 llvm::ArrayRef<std::set<int64_t>> ,
437 bool intIsPointer =
true);
439 llvm::Function *todiff);
BaseType
Categories of potential types.
static std::string str(AugmentedStruct c)
FnTypeInfo preventTypeAnalysisLoops(const FnTypeInfo &oldTypeInfo_, llvm::Function *todiff)
TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I, bool intIsPointer=true)
const llvm::StringMap< llvm::Intrinsic::ID > LIBM_FUNCTIONS
static bool isMemFreeLibMFunction(llvm::StringRef str, llvm::Intrinsic::ID *ID=nullptr)
static bool operator<(const FnTypeInfo &lhs, const FnTypeInfo &rhs)
@ Args
Return is a struct of all args.
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
static bool endsWith(llvm::StringRef string, llvm::StringRef suffix)
Concrete SubType of a given value.
Full interprocedural TypeAnalysis.
llvm::StringMap< std::function< bool(int, TypeTree &, llvm::ArrayRef< TypeTree >, llvm::ArrayRef< std::set< int64_t > >, llvm::CallBase *, TypeAnalyzer *)> > CustomRules
Map of custom function call handlers.
std::map< FnTypeInfo, std::shared_ptr< TypeAnalyzer > > analyzedFunctions
Map of possible query states to TypeAnalyzer intermediate results.
void clear()
Clear existing analyses.
TypeAnalysis(EnzymeLogic &Logic)
TypeResults analyzeFunction(const FnTypeInfo &fn)
Analyze a particular function, returning the results.
Helper class that computes the fixed-point type results of a given function.
void visitMemTransferInst(llvm::MemTransferInst &MTI)
void visitFPToSIInst(llvm::FPToSIInst &I)
llvm::PostDominatorTree & PDT
FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn)
void visitAllocaInst(llvm::AllocaInst &I)
const FnTypeInfo fntypeinfo
Calling context.
void visitExtractValueInst(llvm::ExtractValueInst &I)
void visitSIToFPInst(llvm::SIToFPInst &I)
const llvm::SmallPtrSet< llvm::BasicBlock *, 4 > notForAnalysis
void considerRustDebugInfo()
Parse the debug info generated by rustc and retrieve useful type info if possible.
std::set< int64_t > knownIntegralValues(llvm::Value *val)
void visitFPExtInst(llvm::FPExtInst &I)
void visitConstantExpr(llvm::ConstantExpr &CE)
TypeAnalysis & interprocedural
Calling TypeAnalysis to be used in the case of calls to other functions.
void visitGetElementPtrInst(llvm::GetElementPtrInst &gep)
void visitIntToPtrInst(llvm::IntToPtrInst &I)
void prepareArgs()
Analyze type info given by the arguments, possibly adding to work queue.
void visitShuffleVectorInst(llvm::ShuffleVectorInst &I)
llvm::TargetLibraryInfo & TLI
void visitUIToFPInst(llvm::UIToFPInst &I)
static constexpr uint8_t UP
void visitInsertValueInst(llvm::InsertValueInst &I)
void visitPHINode(llvm::PHINode &phi)
void visitValue(llvm::Value &val)
void visitSelectInst(llvm::SelectInst &I)
void visitIPOCall(llvm::CallBase &call, llvm::Function &fn)
std::shared_ptr< llvm::ModuleSlotTracker > MST
Cache of metadata indices, for faster printing.
llvm::SetVector< llvm::Value *, std::deque< llvm::Value * > > workList
List of value's which should be re-analyzed now with new information.
static constexpr uint8_t DOWN
void updateAnalysis(llvm::Value *val, BaseType data, llvm::Value *origin)
Add additional information to the Type info of val, readding it to the work queue as necessary.
void visitAddrSpaceCastInst(llvm::AddrSpaceCastInst &I)
bool Invalid
Whether an inconsistent update has been found This will only be set when direction !...
static constexpr uint8_t BOTH
void visitExtractElementInst(llvm::ExtractElementInst &I)
void visitLoadInst(llvm::LoadInst &I)
void visitCmpInst(llvm::CmpInst &I)
uint8_t direction
Directionality of checks.
std::map< llvm::Value *, TypeTree > analysis
Intermediate conservative, but correct Type analysis results.
void visitIntrinsicInst(llvm::IntrinsicInst &II)
void visitSExtInst(llvm::SExtInst &I)
void visitGEPOperator(llvm::GEPOperator &gep)
void updateAnalysis(llvm::Value *val, ConcreteType data, llvm::Value *origin)
void considerTBAA()
Analyze type info given by the TBAA, possibly adding to work queue.
void visitCallBase(llvm::CallBase &call)
void visitFPTruncInst(llvm::FPTruncInst &I)
void visitStoreInst(llvm::StoreInst &I)
void visitAtomicRMWInst(llvm::AtomicRMWInst &I)
TypeAnalyzer(TypeAnalysis &TA)
void visitZExtInst(llvm::ZExtInst &I)
void runPHIHypotheses()
Hypothesize that undefined phi's are integers and try to prove that they are really integral.
void visitBinaryOperator(llvm::BinaryOperator &I)
void visitMemTransferCommon(llvm::CallBase &MTI)
TypeTree getAnalysis(llvm::Value *Val)
Get the current results for a given value.
void visitBinaryOperation(const llvm::DataLayout &DL, llvm::Type *T, llvm::Instruction::BinaryOps, llvm::Value *Args[2], TypeTree &Ret, TypeTree &LHS, TypeTree &RHS, llvm::Instruction *I)
void visitFPToUIInst(llvm::FPToUIInst &I)
void visitBitCastInst(llvm::BitCastInst &I)
void updateAnalysis(llvm::Value *val, TypeTree data, llvm::Value *origin)
void dump(llvm::raw_ostream &ss=llvm::errs())
void visitTruncInst(llvm::TruncInst &I)
void visitPtrToIntInst(llvm::PtrToIntInst &I)
llvm::ScalarEvolution & SE
void run()
Run the interprocedural type analysis starting from this function.
TypeTree getReturnAnalysis()
void visitInsertElementInst(llvm::InsertElementInst &I)
A holder class representing the results of running TypeAnalysis on a given function.
TypeResults(std::nullptr_t)
ConcreteType intType(size_t num, llvm::Value *val, bool errIfNotFound=true, bool pointerIntSame=false) const
llvm::Type * addingType(size_t num, llvm::Value *val, size_t start=0) const
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
bool anyFloat(llvm::Value *val, bool anythingIsFloat=true) const
Whether any part of the top level register can contain a float e.g.
FnTypeInfo getAnalyzedTypeInfo() const
The TypeInfo calling convention.
bool anyPointer(llvm::Value *val) const
Whether any part of the top level register can contain a pointer e.g.
std::set< int64_t > knownIntegralValues(llvm::Value *val) const
The set of values val will take on during this program.
TypeTree getReturnAnalysis() const
The Type of the return.
TypeTree query(llvm::Value *val) const
The TypeTree of a particular Value.
bool allFloat(llvm::Value *val) const
Whether all of the top level register is known to contain float data.
llvm::Function * getFunction() const
FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn) const
ConcreteType firstPointer(size_t num, llvm::Value *val, llvm::Instruction *I, bool errIfNotFound=true, bool pointerIntSame=false) const
Returns whether in the first num bytes there is pointer, int, float, or none If pointerIntSame is set...
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Struct containing all contextual type information for a particular function call.
FnTypeInfo(llvm::Function *fn)
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
llvm::Function * Function
Function being analyzed.
TypeTree Return
Type of return.
std::set< int64_t > knownIntegralValues(llvm::Value *val, const llvm::DominatorTree &DT, std::map< llvm::Value *, std::set< int64_t > > &intseen, llvm::ScalarEvolution &SE) const
The set of known values val will take.
FnTypeInfo & operator=(FnTypeInfo &)=default
FnTypeInfo(const FnTypeInfo &)=default
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to represented by an argument, if constant.
FnTypeInfo & operator=(FnTypeInfo &&)=default