27#ifndef ENZYME_TYPE_ANALYSIS_TBAA_H
28#define ENZYME_TYPE_ANALYSIS_TBAA_H 1
37 if (N->getNumOperands() < 3)
40 if (!llvm::isa<llvm::MDNode>(N->getOperand(0)))
49 MDNodeTy *
Node =
nullptr;
67 if (
Node->getNumOperands() < 2)
69 MDNodeTy *P = llvm::dyn_cast_or_null<MDNodeTy>(
Node->getOperand(1));
80 if (
Node->getNumOperands() < 3)
82 llvm::ConstantInt *CI =
83 llvm::mdconst::dyn_extract<llvm::ConstantInt>(
Node->getOperand(2));
86 return CI->getValue()[0];
113 if (
Node->getNumOperands() < 4)
122 return llvm::dyn_cast_or_null<llvm::MDNode>(
Node->getOperand(0));
126 return llvm::dyn_cast_or_null<llvm::MDNode>(
Node->getOperand(1));
130 return llvm::mdconst::extract<llvm::ConstantInt>(
Node->getOperand(2))
137 return llvm::mdconst::extract<llvm::ConstantInt>(
Node->getOperand(3))
146 if (
Node->getNumOperands() < OpNo + 1)
148 llvm::ConstantInt *CI =
149 llvm::mdconst::dyn_extract<llvm::ConstantInt>(
Node->getOperand(OpNo));
152 return CI->getValue()[0];
168 const llvm::MDNode *
Node =
nullptr;
193 return (
getNode()->getNumOperands() - FirstFieldOpNo) / NumOpsPerField;
199 unsigned OpIndex = FirstFieldOpNo + FieldIndex * NumOpsPerField;
202 llvm::mdconst::extract<llvm::ConstantInt>(
Node->getOperand(OpIndex + 1))
210 unsigned OpIndex = FirstFieldOpNo + FieldIndex * NumOpsPerField;
211 auto *TypeNode = llvm::cast<llvm::MDNode>(
getNode()->getOperand(OpIndex));
221 if (
Node->getNumOperands() < 6)
225 if (
Node->getNumOperands() < 2)
230 if (
Node->getNumOperands() <= 3) {
232 Node->getNumOperands() == 2
234 : llvm::mdconst::extract<llvm::ConstantInt>(
Node->getOperand(2))
238 llvm::dyn_cast_or_null<llvm::MDNode>(
Node->getOperand(1));
247 unsigned FirstFieldOpNo = NewFormat ? 3 : 1;
248 unsigned NumOpsPerField = NewFormat ? 3 : 2;
250 for (
unsigned Idx = FirstFieldOpNo; Idx <
Node->getNumOperands();
251 Idx += NumOpsPerField) {
253 llvm::mdconst::extract<llvm::ConstantInt>(
Node->getOperand(Idx + 1))
256 assert(Idx >= FirstFieldOpNo + NumOpsPerField &&
257 "TBAAStructTypeNode::getField should have an offset match!");
258 TheIdx = Idx - NumOpsPerField;
264 TheIdx =
Node->getNumOperands() - NumOpsPerField;
266 llvm::mdconst::extract<llvm::ConstantInt>(
Node->getOperand(TheIdx + 1))
270 llvm::dyn_cast_or_null<llvm::MDNode>(
Node->getOperand(TheIdx));
283 return llvm::isa<llvm::MDNode>(MD->getOperand(0)) &&
284 MD->getNumOperands() >= 3;
287static inline const llvm::MDNode *
291 if (!AccessType || AccessType->getNumOperands() < 2)
294 llvm::Type *Int64 = llvm::IntegerType::get(AccessType->getContext(), 64);
296 llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int64, 0));
301 uint64_t AccessSize = UINT64_MAX;
302 auto *SizeNode = llvm::ConstantAsMetadata::get(
303 llvm::ConstantInt::get(Int64, AccessSize));
304 llvm::Metadata *Ops[] = {
const_cast<llvm::MDNode *
>(AccessType),
305 const_cast<llvm::MDNode *
>(AccessType), OffsetNode,
307 return llvm::MDNode::get(AccessType->getContext(), Ops);
310 llvm::Metadata *Ops[] = {
const_cast<llvm::MDNode *
>(AccessType),
311 const_cast<llvm::MDNode *
>(AccessType), OffsetNode};
312 return llvm::MDNode::get(AccessType->getContext(), Ops);
317static inline std::string
319 const std::set<std::string> &legalnames) {
321 if (M->getNumOperands() < 1)
323 if (
const llvm::MDString *Tag1 =
324 llvm::dyn_cast<llvm::MDString>(M->getOperand(0))) {
325 return Tag1->getString().str();
339 if (
auto *Id = llvm::dyn_cast<llvm::MDString>(AccessType.
getId())) {
341 if (legalnames.count(Id->getString().str())) {
342 return Id->getString().str();
350 if (
auto *Id = llvm::dyn_cast<llvm::MDString>(AccessType.
getId())) {
352 return Id->getString().str();
357static inline std::string
359 const std::set<std::string> &legalnames) {
360 if (
const llvm::MDNode *M =
361 Inst->getMetadata(llvm::LLVMContext::MD_tbaa_struct)) {
362 for (
unsigned i = 2; i < M->getNumOperands(); i += 3) {
363 if (
const llvm::MDNode *M2 =
364 llvm::dyn_cast<llvm::MDNode>(M->getOperand(i))) {
371 if (
const llvm::MDNode *M = Inst->getMetadata(llvm::LLVMContext::MD_tbaa)) {
388 std::shared_ptr<llvm::ModuleSlotTracker> MST) {
389 if (TypeName ==
"long long" || TypeName ==
"long" || TypeName ==
"int" ||
390 TypeName ==
"bool" || TypeName ==
"jtbaa_arraysize" ||
391 TypeName ==
"jtbaa_arraylen") {
393 llvm::errs() <<
"known tbaa ";
395 I.print(llvm::errs(), *MST);
398 llvm::errs() <<
" " << TypeName <<
"\n";
401 }
else if (TypeName ==
"any pointer" || TypeName ==
"vtable pointer" ||
402 TypeName ==
"jtbaa_tag") {
404 llvm::errs() <<
"known tbaa ";
406 I.print(llvm::errs(), *MST);
409 llvm::errs() <<
" " << TypeName <<
"\n";
412 }
else if (TypeName ==
"float") {
414 llvm::errs() <<
"known tbaa ";
416 I.print(llvm::errs(), *MST);
419 llvm::errs() <<
" " << TypeName <<
"\n";
421 return llvm::Type::getFloatTy(I.getContext());
422 }
else if (TypeName ==
"double") {
424 llvm::errs() <<
"known tbaa ";
426 I.print(llvm::errs(), *MST);
429 llvm::errs() <<
" " << TypeName <<
"\n";
431 return llvm::Type::getDoubleTy(I.getContext());
440 llvm::Instruction &I,
441 const llvm::DataLayout &DL,
442 std::shared_ptr<llvm::ModuleSlotTracker> MST) {
444 if (
auto *Id = llvm::dyn_cast<llvm::MDString>(AccessType.
getId())) {
452 for (
unsigned i = 0, size = AccessType.
getNumFields(); i < size; ++i) {
455 auto SubResult =
parseTBAA(SubAccess, I, DL, MST);
466 const llvm::DataLayout &DL,
467 std::shared_ptr<llvm::ModuleSlotTracker> MST) {
469 if (M->getNumOperands() < 1)
471 if (
const llvm::MDString *Tag1 =
472 llvm::dyn_cast<llvm::MDString>(M->getOperand(0))) {
482 return parseTBAA(AccessType, I, DL, MST);
488 const llvm::DataLayout &DL,
489 std::shared_ptr<llvm::ModuleSlotTracker> MST) {
491 if (
const llvm::MDNode *M =
492 I.getMetadata(llvm::LLVMContext::MD_tbaa_struct)) {
493 for (
unsigned i = 0, size = M->getNumOperands(); i < size; i += 3) {
494 if (
const llvm::MDNode *M2 =
495 llvm::dyn_cast<llvm::MDNode>(M->getOperand(i + 2))) {
496 auto SubResult =
parseTBAA(M2, I, DL, MST);
497 auto Start = llvm::cast<llvm::ConstantInt>(
498 llvm::cast<llvm::ConstantAsMetadata>(M->getOperand(i))
502 llvm::cast<llvm::ConstantInt>(
503 llvm::cast<llvm::ConstantAsMetadata>(M->getOperand(i + 1))
512 if (
const llvm::MDNode *M = I.getMetadata(llvm::LLVMContext::MD_tbaa)) {
llvm::PointerUnion< Operation *, Value > Node
static std::string getAccessNameTBAA(const llvm::MDNode *M, const std::set< std::string > &legalnames)
static const llvm::MDNode * createAccessTag(const llvm::MDNode *AccessType)
static ConcreteType getTypeFromTBAAString(std::string TypeName, llvm::Instruction &I, std::shared_ptr< llvm::ModuleSlotTracker > MST)
Derive the ConcreteType corresponding to the string TypeName. The llvm::Instruction I denotes the cont...
llvm::cl::opt< bool > EnzymePrintType
The following is not taken from LLVM.
static bool isStructPathTBAA(const llvm::MDNode *MD)
Check the first operand of the tbaa tag node; if it is an llvm::MDNode, we treat it as struct-path awa...
static bool isNewFormatTypeNode(const llvm::MDNode *N)
isNewFormatTypeNode - Return true iff the given type node is in the new size-aware format.
static TypeTree parseTBAA(TBAAStructTypeNode AccessType, llvm::Instruction &I, const llvm::DataLayout &DL, std::shared_ptr< llvm::ModuleSlotTracker > MST)
Given a TBAA access node, return the corresponding TypeTree. This includes recursively parsing the acce...
Concrete SubType of a given value.
This is a simple wrapper around an llvm::MDNode which provides a higher-level interface by hiding the...
bool isNewFormat() const
isNewFormat - Return true iff the wrapped type node is in the new size-aware format.
TBAANodeImpl< MDNodeTy > getParent() const
getParent - Get this TBAANode's Alias tree parent.
TBAANodeImpl(MDNodeTy *N)
MDNodeTy * getNode() const
getNode - Get the llvm::MDNode for this TBAANode.
bool isTypeImmutable() const
Test if this TBAANode represents a type for objects which are not modified (by any means) in the cont...
This is a simple wrapper around an llvm::MDNode which provides a higher-level interface by hiding the...
bool isNewFormat() const
isNewFormat - Return true iff the wrapped access tag is in the new size-aware format.
MDNodeTy * getBaseType() const
MDNodeTy * getNode() const
Get the llvm::MDNode for this TBAAStructTagNode.
TBAAStructTagNodeImpl(MDNodeTy *N)
MDNodeTy * getAccessType() const
uint64_t getOffset() const
bool isTypeImmutable() const
Test if this TBAAStructTagNode represents a type for objects which are not modified (by any means) in...
This is a simple wrapper around an llvm::MDNode which provides a higher-level interface by hiding the...
const llvm::MDNode * getNode() const
Get the llvm::MDNode for this TBAAStructTypeNode.
TBAAStructTypeNode getField(uint64_t &Offset) const
Get this TBAAStructTypeNode's field in the type DAG with given offset.
TBAAStructTypeNode getFieldType(unsigned FieldIndex) const
bool isNewFormat() const
isNewFormat - Return true iff the wrapped type node is in the new size-aware format.
TBAAStructTypeNode()=default
llvm::Metadata * getId() const
getId - Return type identifier.
TBAAStructTypeNode(const llvm::MDNode *N)
bool operator==(const TBAAStructTypeNode &Other) const
unsigned getNumFields() const
uint64_t getFieldOffset(unsigned FieldIndex) const
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
TypeTree Only(int Off, llvm::Instruction *orig) const
Prepend an offset to all mappings.
TypeTree ShiftIndices(const llvm::DataLayout &dl, const int offset, const int maxSize, size_t addOffset=0) const
Replace mappings in the range [offset, offset+maxSize] with those in [addOffset, addOffset+maxSize].