Enzyme main
Loading...
Searching...
No Matches
ConcreteType.h
Go to the documentation of this file.
1//===- ConcreteType.h - Underlying SubType used in Type Analysis
2//------------===//
3//
4// Enzyme Project
5//
6// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
7// Exceptions. See https://llvm.org/LICENSE.txt for license information.
8// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9//
10// If using this code in an academic setting, please cite the following:
11// @incollection{enzymeNeurips,
12// title = {Instead of Rewriting Foreign Code for Machine Learning,
13// Automatically Synthesize Fast Gradients},
14// author = {Moses, William S. and Churavy, Valentin},
15// booktitle = {Advances in Neural Information Processing Systems 33},
16// year = {2020},
17// note = {To appear in},
18// }
19//
20//===----------------------------------------------------------------------===//
21//
22// This file contains the implementation of a class representing all
23// potential end SubTypes used in Type Analysis. This ``ConcreteType`` contains
24// the SubType category ``BaseType`` as well as the floating point SubType, if
25// relevant. It also contains several helper utility functions.
26//
27//===----------------------------------------------------------------------===//
28#ifndef ENZYME_TYPE_ANALYSIS_CONCRETE_TYPE_H
29#define ENZYME_TYPE_ANALYSIS_CONCRETE_TYPE_H 1
30
31#include <string>
32
33#include "llvm/IR/InstrTypes.h"
34#include "llvm/IR/Type.h"
35#include "llvm/Support/ErrorHandling.h"
36
37#include "BaseType.h"
38
39/// Concrete SubType of a given value. Consists of a category `BaseType` and the
40/// particular floating point type, if relevant.
42public:
43 /// Category of underlying type
45 /// Floating point type, if relevant, otherwise nullptr
46 llvm::Type *SubType;
47
48 /// Construct a ConcreteType from an existing FloatingPoint Type
49 ConcreteType(llvm::Type *SubType)
51 assert(SubType != nullptr);
52 assert(!llvm::isa<llvm::VectorType>(SubType));
53 if (!SubType->isFloatingPointTy()) {
54 llvm::errs() << " passing in non FP SubType: " << *SubType << "\n";
55 }
56 assert(SubType->isFloatingPointTy());
57 }
58
59 /// Construct a non-floating Concrete type from a BaseType
64
65 /// Construct a ConcreteType from a string
66 /// A Concrete Type's string representation is given by the string of the
67 /// enum If it is a floating point it is given by Float@<specific_type>
68 ConcreteType(llvm::StringRef Str, llvm::LLVMContext &C) {
69 auto Sep = Str.find('@');
70 if (Sep != llvm::StringRef::npos) {
72 assert(Str.substr(0, Sep) == "Float");
73 auto SubName = Str.substr(Sep + 1);
74 if (SubName == "half") {
75 SubType = llvm::Type::getHalfTy(C);
76 } else if (SubName == "float") {
77 SubType = llvm::Type::getFloatTy(C);
78 } else if (SubName == "double") {
79 SubType = llvm::Type::getDoubleTy(C);
80 } else if (SubName == "fp80") {
81 SubType = llvm::Type::getX86_FP80Ty(C);
82 } else if (SubName == "bf16") {
83 SubType = llvm::Type::getBFloatTy(C);
84 } else if (SubName == "fp128") {
85 SubType = llvm::Type::getFP128Ty(C);
86 } else if (SubName == "ppc128") {
87 SubType = llvm::Type::getPPC_FP128Ty(C);
88 } else {
89 llvm_unreachable("unknown data SubType");
90 }
91 } else {
92 SubType = nullptr;
94 }
95 }
96
97 /// Convert the ConcreteType to a string
98 std::string str() const {
99 std::string Result = to_string(SubTypeEnum);
101 if (SubType->isHalfTy()) {
102 Result += "@half";
103 } else if (SubType->isFloatTy()) {
104 Result += "@float";
105 } else if (SubType->isDoubleTy()) {
106 Result += "@double";
107 } else if (SubType->isX86_FP80Ty()) {
108 Result += "@fp80";
109 } else if (SubType->isBFloatTy()) {
110 Result += "@bf16";
111 } else if (SubType->isFP128Ty()) {
112 Result += "@fp128";
113 } else if (SubType->isPPC_FP128Ty()) {
114 Result += "@ppc128";
115 } else {
116 llvm_unreachable("unknown data SubType");
117 }
118 }
119 return Result;
120 }
121
122 /// Whether this ConcreteType has information (is not unknown)
123 bool isKnown() const { return SubTypeEnum != BaseType::Unknown; }
124
125 /// Whether this ConcreteType must an integer
126 bool isIntegral() const { return SubTypeEnum == BaseType::Integer; }
127
128 /// Whether this ConcreteType could be a pointer (SubTypeEnum is unknown or a
129 /// pointer)
135
136 /// Whether this ConcreteType could be a float (SubTypeEnum is unknown or a
137 /// float)
143
144 /// Return the floating point type, if this is a float
145 llvm::Type *isFloat() const { return SubType; }
146
147 /// Return if this is known to be the BaseType BT
148 /// This cannot be called with BaseType::Float as it lacks information
149 bool operator==(const BaseType BT) const {
150 if (BT == BaseType::Float) {
151 assert(0 &&
152 "Cannot do comparision between ConcreteType and BaseType::Float");
153 llvm_unreachable(
154 "Cannot do comparision between ConcreteType and BaseType::Float");
155 }
156 return SubTypeEnum == BT;
157 }
158
159 /// Return if this is known not to be the BaseType BT
160 /// This cannot be called with BaseType::Float as it lacks information
161 bool operator!=(const BaseType BT) const {
162 if (BT == BaseType::Float) {
163 assert(0 &&
164 "Cannot do comparision between ConcreteType and BaseType::Float");
165 llvm_unreachable(
166 "Cannot do comparision between ConcreteType and BaseType::Float");
167 }
168 return SubTypeEnum != BT;
169 }
170
171 /// Return if this is known to be the ConcreteType CT
172 bool operator==(const ConcreteType CT) const {
173 return SubType == CT.SubType && SubTypeEnum == CT.SubTypeEnum;
174 }
175
176 /// Return if this is known not to be the ConcreteType CT
177 bool operator!=(const ConcreteType CT) const { return !(*this == CT); }
178
179 /// Set this to the given ConcreteType, returning true if
180 /// this ConcreteType has changed
181 bool operator=(const ConcreteType CT) {
182 bool changed = false;
183 if (SubTypeEnum != CT.SubTypeEnum)
184 changed = true;
186 if (SubType != CT.SubType)
187 changed = true;
188 SubType = CT.SubType;
189 return changed;
190 }
191
192 /// Set this to the given BaseType, returning true if
193 /// this ConcreteType has changed
194 bool operator=(const BaseType BT) {
195 assert(BT != BaseType::Float);
197 }
198
199 /// Set this to the logical or of itself and CT, returning whether this value
200 /// changed. Setting `PointerIntSame` considers pointers and integers as
201 /// equivalent. If this is an illegal operation, `LegalOr` will be set to false.
202 bool checkedOrIn(const ConcreteType CT, bool PointerIntSame, bool &LegalOr) {
204 return false;
205 }
207 return *this = CT;
208 }
210 return *this = CT;
211 }
212 if (CT.SubTypeEnum == BaseType::Unknown) {
213 return false;
214 }
215 if (CT.SubTypeEnum != SubTypeEnum) {
216 if (PointerIntSame) {
221 return false;
222 }
223 }
224 LegalOr = false;
225 return false;
226 }
227 assert(CT.SubTypeEnum == SubTypeEnum);
228 if (CT.SubType != SubType) {
229 LegalOr = false;
230 return false;
231 }
232 assert(CT.SubType == SubType);
233 return false;
234 }
235
236 /// Set this to the logical or of itself and CT, returning whether this value
237 /// changed Setting `PointerIntSame` considers pointers and integers as
238 /// equivalent This function will error if doing an illegal Operation
239 bool orIn(const ConcreteType CT, bool PointerIntSame) {
240 bool Legal = true;
241 bool Result = checkedOrIn(CT, PointerIntSame, Legal);
242 if (!Legal) {
243 llvm::errs() << "Illegal orIn: " << str() << " right: " << CT.str()
244 << " PointerIntSame=" << PointerIntSame << "\n";
245 assert(0 && "Performed illegal ConcreteType::orIn");
246 llvm_unreachable("Performed illegal ConcreteType::orIn");
247 }
248 return Result;
249 }
250
251 /// Set this to the logical or of itself and CT, returning whether this value
252 /// changed This assumes that pointers and integers are distinct This function
253 /// will error if doing an illegal Operation
254 bool operator|=(const ConcreteType CT) {
255 return orIn(CT, /*pointerIntSame*/ false);
256 }
257
258 /// Set this to the logical and of itself and CT, returning whether this value
259 /// changed. If this and CT are incompatible, the result will be
260 /// BaseType::Unknown.
261 bool andIn(const ConcreteType CT) {
263 return *this = CT;
264 }
266 return false;
267 }
269 return false;
270 }
271 if (CT.SubTypeEnum == BaseType::Unknown) {
272 return *this = CT;
273 }
274
275 if (CT.SubTypeEnum != SubTypeEnum) {
276 return *this = BaseType::Unknown;
277 }
278 if (CT.SubType != SubType) {
279 return *this = BaseType::Unknown;
280 }
281 return false;
282 }
283
284 /// Set this to the logical and of itself and CT, returning whether this value
285 /// changed If this and CT are incompatible, the result will be
286 /// BaseType::Unknown
287 bool operator&=(const ConcreteType CT) { return andIn(CT); }
288
289 /// Keep only mappings where the type is not an `Anything`
292 return BaseType::Unknown;
293 return *this;
294 }
295
296 /// Set this to the logical `binop` of itself and RHS, using the Binop Op,
297 /// returning true if this was changed.
298 /// This function will error on an invalid type combination
299 bool binopIn(bool &Legal, const ConcreteType RHS,
300 llvm::BinaryOperator::BinaryOps Op) {
301 bool Changed = false;
302 using namespace llvm;
303
304 // Anything op Anything => Anything
307 return Changed;
308 }
309
310 // [?] op float => Unknown
314 RHS.isFloat()) ||
315 (isFloat() && (RHS.SubTypeEnum == BaseType::Anything ||
317 RHS.SubTypeEnum == BaseType::Unknown)))) {
319 SubType = nullptr;
320 Changed = true;
321 return Changed;
322 }
323
324 // Unknown op Anything => Unknown
331 Changed = true;
332 }
333 return Changed;
334 }
335
336 // Integer op Integer => Integer
339 return Changed;
340 }
341
342 // Integer op Anything => {Anything, Integer}
347
348 switch (Op) {
349 // The result of these operands mix data between LHS/RHS
350 // Therefore there is some "anything" data in the result
351 case BinaryOperator::Add:
352 case BinaryOperator::Sub:
353 case BinaryOperator::Mul:
354 case BinaryOperator::And:
355 case BinaryOperator::Or:
356 case BinaryOperator::Xor:
359 Changed = true;
360 }
361 break;
362
363 // The result of these operands only use data from LHS
364 case BinaryOperator::UDiv:
365 case BinaryOperator::SDiv:
366 case BinaryOperator::URem:
367 case BinaryOperator::SRem:
368 case BinaryOperator::Shl:
369 case BinaryOperator::AShr:
370 case BinaryOperator::LShr:
371 // No change since we retain data from LHS
372 break;
373 default:
374 Legal = false;
375 return Changed;
376 }
377 return Changed;
378 }
379
380 // Integer op Unknown => Unknown
381 // e.g. pointer + int = pointer and int + int = int
388 Changed = true;
389 }
390 return Changed;
391 }
392
393 // Pointer op Pointer => {Integer, Illegal}
396 switch (Op) {
397 case BinaryOperator::Sub:
399 Changed = true;
400 break;
401 case BinaryOperator::Add:
402 case BinaryOperator::Mul:
403 case BinaryOperator::UDiv:
404 case BinaryOperator::SDiv:
405 case BinaryOperator::URem:
406 case BinaryOperator::SRem:
407 case BinaryOperator::And:
408 case BinaryOperator::Or:
409 case BinaryOperator::Xor:
410 case BinaryOperator::Shl:
411 case BinaryOperator::AShr:
412 case BinaryOperator::LShr:
413 default:
414 Legal = false;
415 return Changed;
416 }
417 return Changed;
418 }
419
420 // Pointer - Unknown => Unknown
421 // This is because Pointer - Pointer => Integer
422 // and Pointer - Integer => Pointer
423 if (Op == BinaryOperator::Sub && SubTypeEnum == BaseType::Pointer &&
426 Changed = true;
427 return Changed;
428 }
429
430 // Pointer op ? => {Pointer, Unknown}
445
446 switch (Op) {
447 case BinaryOperator::Sub:
452 Changed = true;
453 }
454 break;
455 }
456 if (RHS.SubTypeEnum == BaseType::Pointer) {
459 Changed = true;
460 }
461 break;
462 }
463 [[fallthrough]];
464 case BinaryOperator::Add:
465 case BinaryOperator::Mul:
468 Changed = true;
469 }
470 break;
471 case BinaryOperator::UDiv:
472 case BinaryOperator::SDiv:
473 case BinaryOperator::URem:
474 case BinaryOperator::SRem:
475 if (RHS.SubTypeEnum == BaseType::Pointer) {
476 Legal = false;
477 return Changed;
478 } else if (SubTypeEnum != BaseType::Unknown) {
480 Changed = true;
481 }
482 break;
483 case BinaryOperator::And:
484 case BinaryOperator::Or:
485 case BinaryOperator::Xor:
486 case BinaryOperator::Shl:
487 case BinaryOperator::AShr:
488 case BinaryOperator::LShr:
491 Changed = true;
492 }
493 break;
494 default:
495 Legal = false;
496 return Changed;
497 }
498 return Changed;
499 }
500
501 Legal = false;
502 return Changed;
503 }
504
505 /// Compare concrete types for use in map's
506 bool operator<(const ConcreteType dt) const {
507 if (SubTypeEnum == dt.SubTypeEnum) {
508 return SubType < dt.SubType;
509 } else {
510 return SubTypeEnum < dt.SubTypeEnum;
511 }
512 }
513};
514
515// Convert ConcreteType to string
516static inline std::string to_string(const ConcreteType dt) { return dt.str(); }
517
518#endif
static BaseType parseBaseType(llvm::StringRef str)
Convert string to BaseType.
Definition BaseType.h:64
BaseType
Categories of potential types.
Definition BaseType.h:32
static std::string to_string(const ConcreteType dt)
Concrete SubType of a given value.
bool operator|=(const ConcreteType CT)
Set this to the logical or of itself and CT, returning whether this value changed This assumes that p...
bool isIntegral() const
Whether this ConcreteType must be an integer.
bool operator==(const BaseType BT) const
Return if this is known to be the BaseType BT This cannot be called with BaseType::Float as it lacks ...
bool isPossiblePointer() const
Whether this ConcreteType could be a pointer (SubTypeEnum is unknown or a pointer)
bool binopIn(bool &Legal, const ConcreteType RHS, llvm::BinaryOperator::BinaryOps Op)
Set this to the logical binop of itself and RHS, using the Binop Op, returning true if this was chang...
ConcreteType(BaseType SubTypeEnum)
Construct a non-floating Concrete type from a BaseType.
bool operator=(const ConcreteType CT)
Set this to the given ConcreteType, returning true if this ConcreteType has changed.
bool isKnown() const
Whether this ConcreteType has information (is not unknown)
llvm::Type * SubType
Floating point type, if relevant, otherwise nullptr.
bool isPossibleFloat() const
Whether this ConcreteType could be a float (SubTypeEnum is unknown or a float)
bool operator==(const ConcreteType CT) const
Return if this is known to be the ConcreteType CT.
bool operator&=(const ConcreteType CT)
Set this to the logical and of itself and CT, returning whether this value changed If this and CT are...
ConcreteType(llvm::Type *SubType)
Construct a ConcreteType from an existing FloatingPoint Type.
bool operator!=(const ConcreteType CT) const
Return if this is known not to be the ConcreteType CT.
std::string str() const
Convert the ConcreteType to a string.
ConcreteType PurgeAnything() const
Keep only mappings where the type is not an Anything
bool operator!=(const BaseType BT) const
Return if this is known not to be the BaseType BT This cannot be called with BaseType::Float as it la...
BaseType SubTypeEnum
Category of underlying type.
bool checkedOrIn(const ConcreteType CT, bool PointerIntSame, bool &LegalOr)
Set this to the logical or of itself and CT, returning whether this value changed Setting PointerIntS...
ConcreteType(llvm::StringRef Str, llvm::LLVMContext &C)
Construct a ConcreteType from a string A Concrete Type's string representation is given by the string...
bool andIn(const ConcreteType CT)
Set this to the logical and of itself and CT, returning whether this value changed If this and CT are...
bool operator<(const ConcreteType dt) const
Compare concrete types for use in maps.
bool orIn(const ConcreteType CT, bool PointerIntSame)
Set this to the logical or of itself and CT, returning whether this value changed Setting PointerIntS...
llvm::Type * isFloat() const
Return the floating point type, if this is a float.
bool operator=(const BaseType BT)
Set this to the given BaseType, returning true if this ConcreteType has changed.