Enzyme main
Loading...
Searching...
No Matches
TypeAnalysis.cpp
Go to the documentation of this file.
1//===- TypeAnalysis.cpp - Implementation of Type Analysis ------------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains the implementation of Type Analysis, a utility for
22// computing the underlying data type of LLVM values.
23//
24//===----------------------------------------------------------------------===//
25#include <cstdint>
26#include <deque>
27
28#include <llvm/Config/llvm-config.h>
29
30#include "llvm/Demangle/Demangle.h"
31#include "llvm/Demangle/ItaniumDemangle.h"
32
33#include "llvm/IR/Constants.h"
34#include "llvm/IR/Function.h"
35#include "llvm/IR/InstrTypes.h"
36#include "llvm/IR/Instructions.h"
37#include "llvm/IR/IntrinsicInst.h"
38#include "llvm/IR/ModuleSlotTracker.h"
39#include "llvm/IR/Type.h"
40#include "llvm/IR/Value.h"
41
42#include "llvm/IR/InstIterator.h"
43
44#include "llvm/Support/CommandLine.h"
45#include "llvm/Support/TimeProfiler.h"
46#include "llvm/Support/raw_ostream.h"
47
48#include "llvm/ADT/SmallSet.h"
49#include "llvm/ADT/StringMap.h"
50#include "llvm/ADT/StringSet.h"
51
52#include "llvm/IR/InlineAsm.h"
53
54#include "../EnzymeLogic.h"
55#include "../Utils.h"
56#include "TypeAnalysis.h"
57
58#include "../FunctionUtils.h"
59#include "../LibraryFuncs.h"
60
61#include "RustDebugInfo.h"
62#include "TBAA.h"
63
64#include <math.h>
65
66#if LLVM_VERSION_MAJOR >= 14
67#define getAttribute getAttributeAtIndex
68#define hasAttribute hasAttributeAtIndex
69#define addAttribute addAttributeAtIndex
70#endif
71
72using namespace llvm;
73
74extern "C" {
75/// Maximum offset for type trees to keep
76llvm::cl::opt<int> MaxIntOffset("enzyme-max-int-offset", cl::init(100),
77 cl::Hidden,
78 cl::desc("Maximum type tree offset"));
79
80llvm::cl::opt<unsigned> EnzymeMaxTypeDepth("enzyme-max-type-depth", cl::init(6),
81 cl::Hidden,
82 cl::desc("Maximum type tree depth"));
83
84llvm::cl::opt<bool> EnzymePrintType("enzyme-print-type", cl::init(false),
85 cl::Hidden,
86 cl::desc("Print type analysis algorithm"));
87
88llvm::cl::opt<bool> RustTypeRules("enzyme-rust-type", cl::init(false),
89 cl::Hidden,
90 cl::desc("Enable rust-specific type rules"));
91
92llvm::cl::opt<bool> EnzymeStrictAliasing(
93 "enzyme-strict-aliasing", cl::init(true), cl::Hidden,
94 cl::desc("Assume strict aliasing of types / type stability"));
95}
96
97const llvm::StringMap<llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
98 {"sinc", Intrinsic::not_intrinsic},
99 {"sincn", Intrinsic::not_intrinsic},
100 {"cos", Intrinsic::cos},
101 {"sin", Intrinsic::sin},
102 {"tan", Intrinsic::not_intrinsic},
103 {"acos", Intrinsic::not_intrinsic},
104 {"__nv_frcp_rd", Intrinsic::not_intrinsic},
105 {"__nv_frcp_rn", Intrinsic::not_intrinsic},
106 {"__nv_frcp_ru", Intrinsic::not_intrinsic},
107 {"__nv_frcp_rz", Intrinsic::not_intrinsic},
108 {"__nv_drcp_rd", Intrinsic::not_intrinsic},
109 {"__nv_drcp_rn", Intrinsic::not_intrinsic},
110 {"__nv_drcp_ru", Intrinsic::not_intrinsic},
111 {"__nv_drcp_rz", Intrinsic::not_intrinsic},
112 {"asin", Intrinsic::not_intrinsic},
113 {"__nv_asin", Intrinsic::not_intrinsic},
114 {"atan", Intrinsic::not_intrinsic},
115 {"atan2", Intrinsic::not_intrinsic},
116 {"__nv_atan2", Intrinsic::not_intrinsic},
117#if LLVM_VERSION_MAJOR >= 19
118 {"cosh", Intrinsic::cosh},
119 {"sinh", Intrinsic::sinh},
120 {"tanh", Intrinsic::tanh},
121#else
122 {"cosh", Intrinsic::not_intrinsic},
123 {"sinh", Intrinsic::not_intrinsic},
124 {"tanh", Intrinsic::not_intrinsic},
125#endif
126 {"acosh", Intrinsic::not_intrinsic},
127 {"asinh", Intrinsic::not_intrinsic},
128 {"atanh", Intrinsic::not_intrinsic},
129 {"exp", Intrinsic::exp},
130 {"exp2", Intrinsic::exp2},
131 {"exp10", Intrinsic::not_intrinsic},
132 {"log", Intrinsic::log},
133 {"log10", Intrinsic::log10},
134 {"expm1", Intrinsic::not_intrinsic},
135 {"log1p", Intrinsic::not_intrinsic},
136 {"log2", Intrinsic::log2},
137 {"logb", Intrinsic::not_intrinsic},
138 {"pow", Intrinsic::pow},
139 {"sqrt", Intrinsic::sqrt},
140 {"cbrt", Intrinsic::not_intrinsic},
141 {"hypot", Intrinsic::not_intrinsic},
142
143 {"__mulsc3", Intrinsic::not_intrinsic},
144 {"__muldc3", Intrinsic::not_intrinsic},
145 {"__multc3", Intrinsic::not_intrinsic},
146 {"__mulxc3", Intrinsic::not_intrinsic},
147
148 {"__divsc3", Intrinsic::not_intrinsic},
149 {"__divdc3", Intrinsic::not_intrinsic},
150 {"__divtc3", Intrinsic::not_intrinsic},
151 {"__divxc3", Intrinsic::not_intrinsic},
152
153 {"Faddeeva_erf", Intrinsic::not_intrinsic},
154 {"Faddeeva_erfc", Intrinsic::not_intrinsic},
155 {"Faddeeva_erfcx", Intrinsic::not_intrinsic},
156 {"Faddeeva_erfi", Intrinsic::not_intrinsic},
157 {"Faddeeva_dawson", Intrinsic::not_intrinsic},
158 {"Faddeeva_erf_re", Intrinsic::not_intrinsic},
159 {"Faddeeva_erfc_re", Intrinsic::not_intrinsic},
160 {"Faddeeva_erfcx_re", Intrinsic::not_intrinsic},
161 {"Faddeeva_erfi_re", Intrinsic::not_intrinsic},
162 {"Faddeeva_dawson_re", Intrinsic::not_intrinsic},
163 {"erf", Intrinsic::not_intrinsic},
164 {"erfi", Intrinsic::not_intrinsic},
165 {"erfc", Intrinsic::not_intrinsic},
166 {"erfinv", Intrinsic::not_intrinsic},
167
168 {"__fd_sincos_1", Intrinsic::not_intrinsic},
169 {"sincospi", Intrinsic::not_intrinsic},
170 {"cmplx_inv", Intrinsic::not_intrinsic},
171
172 // bessel functions
173 {"j0", Intrinsic::not_intrinsic},
174 {"j1", Intrinsic::not_intrinsic},
175 {"jn", Intrinsic::not_intrinsic},
176 {"y0", Intrinsic::not_intrinsic},
177 {"y1", Intrinsic::not_intrinsic},
178 {"yn", Intrinsic::not_intrinsic},
179 {"tgamma", Intrinsic::not_intrinsic},
180 {"lgamma", Intrinsic::not_intrinsic},
181 {"logabsgamma", Intrinsic::not_intrinsic},
182 {"ceil", Intrinsic::ceil},
183 {"__nv_ceil", Intrinsic::ceil},
184 {"floor", Intrinsic::floor},
185 {"fmod", Intrinsic::not_intrinsic},
186 {"trunc", Intrinsic::trunc},
187 {"round", Intrinsic::round},
188 {"rint", Intrinsic::rint},
189 {"nearbyint", Intrinsic::nearbyint},
190 {"remainder", Intrinsic::not_intrinsic},
191 {"copysign", Intrinsic::copysign},
192 {"nextafter", Intrinsic::not_intrinsic},
193 {"nexttoward", Intrinsic::not_intrinsic},
194 {"fdim", Intrinsic::not_intrinsic},
195 {"fmax", Intrinsic::maxnum},
196 {"fmin", Intrinsic::minnum},
197 {"fabs", Intrinsic::fabs},
198 {"fma", Intrinsic::fma},
199 {"ilogb", Intrinsic::not_intrinsic},
200 {"scalbn", Intrinsic::not_intrinsic},
201 {"scalbln", Intrinsic::not_intrinsic},
202 {"powi", Intrinsic::powi},
203 {"cabs", Intrinsic::not_intrinsic},
204 {"ldexp", Intrinsic::not_intrinsic},
205 {"fmod", Intrinsic::not_intrinsic},
206 {"finite", Intrinsic::not_intrinsic},
207 {"isinf", Intrinsic::not_intrinsic},
208 {"isnan", Intrinsic::not_intrinsic},
209 {"lround", Intrinsic::lround},
210 {"llround", Intrinsic::llround},
211 {"lrint", Intrinsic::lrint},
212 {"llrint", Intrinsic::llrint}};
213
214static bool isItaniumEncoding(StringRef S) {
215 // Itanium encoding requires 1 or 3 leading underscores, followed by 'Z'.
216 return startsWith(S, "_Z") || startsWith(S, "___Z");
217}
218
219bool dontAnalyze(StringRef str) {
220 if (isItaniumEncoding(str)) {
221 if (str.empty())
222 return false;
223
224 ItaniumPartialDemangler Parser;
225 char *data = (char *)malloc(str.size() + 1);
226 memcpy(data, str.data(), str.size());
227 data[str.size()] = 0;
228 bool hasError = Parser.partialDemangle(data);
229 if (hasError) {
230 free(data);
231 return false;
232 }
233
234 // auto basename = Parser.getFunctionBaseName(0, 0);
235 // auto base = Parser.getFunctionDeclContextName(0, 0);
236 // auto fn = Parser.getFunctionName(0, 0);
237 // llvm::errs() << " err: " << base << " - " << basename << " fn - " << fn
238 // << "\n";
239 free(data);
240 }
241 return false;
242}
243
245 uint8_t direction)
246 : MST(EnzymePrintType ? new ModuleSlotTracker(fn.Function->getParent())
247 : nullptr),
248 notForAnalysis(getGuaranteedUnreachable(fn.Function)), intseen(),
249 fntypeinfo(fn), interprocedural(TA), direction(direction), Invalid(false),
250 PHIRecur(false),
251 TLI(TA.Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(*fn.Function)),
252 DT(TA.Logic.PPC.FAM.getResult<DominatorTreeAnalysis>(*fn.Function)),
253 PDT(TA.Logic.PPC.FAM.getResult<PostDominatorTreeAnalysis>(*fn.Function)),
254 LI(TA.Logic.PPC.FAM.getResult<LoopAnalysis>(*fn.Function)),
255 SE(TA.Logic.PPC.FAM.getResult<ScalarEvolutionAnalysis>(*fn.Function)) {
256
257 assert(fntypeinfo.KnownValues.size() ==
258 fntypeinfo.Function->getFunctionType()->getNumParams());
259
260 // Add all instructions in the function
261 for (BasicBlock &BB : *fntypeinfo.Function) {
262 if (notForAnalysis.count(&BB))
263 continue;
264 for (Instruction &I : BB) {
265 workList.insert(&I);
266 }
267 }
268 // Add all operands referenced in the function
269 // This is done to investigate any referenced globals/etc
270 for (BasicBlock &BB : *fntypeinfo.Function) {
271 for (Instruction &I : BB) {
272 for (auto &Op : I.operands()) {
273 addToWorkList(Op);
274 }
275 }
276 }
277}
278
280 const FnTypeInfo &fn, TypeAnalysis &TA,
281 const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &notForAnalysis,
282 const TypeAnalyzer &Prev, uint8_t direction, bool PHIRecur)
283 : MST(Prev.MST),
284 notForAnalysis(notForAnalysis.begin(), notForAnalysis.end()), intseen(),
285 fntypeinfo(fn), interprocedural(TA), direction(direction), Invalid(false),
286 PHIRecur(PHIRecur), TLI(Prev.TLI), DT(Prev.DT), PDT(Prev.PDT),
287 LI(Prev.LI), SE(Prev.SE) {
288 assert(fntypeinfo.KnownValues.size() ==
289 fntypeinfo.Function->getFunctionType()->getNumParams());
290}
291
292static SmallPtrSet<BasicBlock *, 1>
293findLoopIndices(llvm::Value *val, LoopInfo &LI, DominatorTree &DT,
294 SmallPtrSet<PHINode *, 1> &seen) {
295 if (isa<Constant>(val))
296 return {};
297 if (auto CI = dyn_cast<CastInst>(val))
298 return findLoopIndices(CI->getOperand(0), LI, DT, seen);
299 if (auto CI = dyn_cast<UnaryOperator>(val))
300 return findLoopIndices(CI->getOperand(0), LI, DT, seen);
301 if (auto bo = dyn_cast<BinaryOperator>(val)) {
302 auto inset0 = findLoopIndices(bo->getOperand(0), LI, DT, seen);
303 auto inset1 = findLoopIndices(bo->getOperand(1), LI, DT, seen);
304 inset0.insert(inset1.begin(), inset1.end());
305 return inset0;
306 }
307 if (auto LDI = dyn_cast<LoadInst>(val)) {
308 if (auto AI = dyn_cast<AllocaInst>(LDI->getPointerOperand())) {
309 StoreInst *SI = nullptr;
310 bool failed = false;
311 for (auto u : AI->users()) {
312 if (auto SIu = dyn_cast<StoreInst>(u)) {
313 if (SI && SIu->getValueOperand() == AI) {
314 failed = true;
315 break;
316 }
317 SI = SIu;
318 } else if (!isa<LoadInst>(u)) {
319 if (!cast<Instruction>(u)->mayReadOrWriteMemory() &&
320 cast<Instruction>(u)->use_empty())
321 continue;
322 if (auto CI = dyn_cast<CallBase>(u)) {
323 if (auto F = CI->getCalledFunction()) {
324 auto funcName = F->getName();
325 if (funcName == "__kmpc_for_static_init_4" ||
326 funcName == "__kmpc_for_static_init_4u" ||
327 funcName == "__kmpc_for_static_init_8" ||
328 funcName == "__kmpc_for_static_init_8u") {
329 continue;
330 }
331 }
332 }
333 failed = true;
334 break;
335 }
336 }
337 if (SI && !failed && DT.dominates(SI, LDI)) {
338 return findLoopIndices(SI->getValueOperand(), LI, DT, seen);
339 }
340 }
341 }
342 if (auto pn = dyn_cast<PHINode>(val)) {
343 auto L = LI.getLoopFor(pn->getParent());
344 if (L && L->getHeader() == pn->getParent())
345 return {pn->getParent()};
346 if (seen.contains(pn))
347 return {};
348 SmallPtrSet<BasicBlock *, 1> ops;
349 seen.insert(pn);
350 for (unsigned i = 0; i < pn->getNumIncomingValues(); ++i) {
351 auto a = pn->getIncomingValue(i);
352 auto seti = findLoopIndices(a, LI, DT, seen);
353 ops.insert(seti.begin(), seti.end());
354 }
355 return ops;
356 }
357 return {};
358}
359
360std::set<int64_t>
361FnTypeInfo::knownIntegralValues(llvm::Value *val, const DominatorTree &DT,
362 std::map<Value *, std::set<int64_t>> &intseen,
363 ScalarEvolution &SE) const {
364 if (auto constant = dyn_cast<ConstantInt>(val)) {
365#if LLVM_VERSION_MAJOR > 14
366 if (constant->getValue().getSignificantBits() > 64)
367 return {};
368#else
369 if (constant->getValue().getMinSignedBits() > 64)
370 return {};
371#endif
372 return {constant->getSExtValue()};
373 }
374
375 if (isa<ConstantPointerNull>(val)) {
376 return {0};
377 }
378
379 assert(KnownValues.size() == Function->getFunctionType()->getNumParams());
380
381 if (auto arg = dyn_cast<llvm::Argument>(val)) {
382 auto found = KnownValues.find(arg);
383 if (found == KnownValues.end()) {
384 for (const auto &pair : KnownValues) {
385 llvm::errs() << " KnownValues[" << *pair.first << "] - "
386 << pair.first->getParent()->getName() << "\n";
387 }
388 llvm::errs() << " arg: " << *arg << " - " << arg->getParent()->getName()
389 << "\n";
390 }
391 assert(found != KnownValues.end());
392 return found->second;
393 }
394
395 if (intseen.find(val) != intseen.end())
396 return intseen[val];
397 intseen[val] = {};
398
399 if (auto ci = dyn_cast<CastInst>(val)) {
400 intseen[val] = knownIntegralValues(ci->getOperand(0), DT, intseen, SE);
401 }
402
403 auto insert = [&](int64_t v) {
404 if (intseen[val].size() == 0) {
405 intseen[val].insert(v);
406 } else {
407 if (intseen[val].size() == 1) {
408 if (abs(*intseen[val].begin()) > MaxIntOffset) {
409 if (abs(*intseen[val].begin()) > abs(v)) {
410 intseen[val].clear();
411 intseen[val].insert(v);
412 } else {
413 return;
414 }
415 } else {
416 if (abs(v) > MaxIntOffset) {
417 return;
418 } else {
419 intseen[val].insert(v);
420 }
421 }
422 } else {
423 if (abs(v) > MaxIntOffset) {
424 return;
425 } else {
426 intseen[val].insert(v);
427 }
428 }
429 }
430 };
431 if (auto II = dyn_cast<IntrinsicInst>(val)) {
432 switch (II->getIntrinsicID()) {
433#if LLVM_VERSION_MAJOR >= 12
434 case Intrinsic::abs:
435 for (auto val :
436 knownIntegralValues(II->getArgOperand(0), DT, intseen, SE))
437 insert(abs(val));
438 break;
439#endif
440 case Intrinsic::nvvm_read_ptx_sreg_tid_x:
441 case Intrinsic::nvvm_read_ptx_sreg_tid_y:
442 case Intrinsic::nvvm_read_ptx_sreg_tid_z:
443 case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
444 case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
445 case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
446 case Intrinsic::amdgcn_workitem_id_x:
447 case Intrinsic::amdgcn_workitem_id_y:
448 case Intrinsic::amdgcn_workitem_id_z:
449 insert(0);
450 break;
451 default:
452 break;
453 }
454 }
455 if (auto LI = dyn_cast<LoadInst>(val)) {
456 if (auto AI = dyn_cast<AllocaInst>(LI->getPointerOperand())) {
457 StoreInst *SI = nullptr;
458 bool failed = false;
459 for (auto u : AI->users()) {
460 if (auto SIu = dyn_cast<StoreInst>(u)) {
461 if (SI && SIu->getValueOperand() == AI) {
462 failed = true;
463 break;
464 }
465 SI = SIu;
466 } else if (!isa<LoadInst>(u)) {
467 if (!cast<Instruction>(u)->mayReadOrWriteMemory() &&
468 cast<Instruction>(u)->use_empty())
469 continue;
470 if (auto CI = dyn_cast<CallBase>(u)) {
471 if (auto F = CI->getCalledFunction()) {
472 auto funcName = F->getName();
473 if (funcName == "__kmpc_for_static_init_4" ||
474 funcName == "__kmpc_for_static_init_4u" ||
475 funcName == "__kmpc_for_static_init_8" ||
476 funcName == "__kmpc_for_static_init_8u") {
477 continue;
478 }
479 }
480 }
481 failed = true;
482 break;
483 }
484 }
485 if (SI && !failed && DT.dominates(SI, LI)) {
486 for (auto val :
487 knownIntegralValues(SI->getValueOperand(), DT, intseen, SE)) {
488 insert(val);
489 }
490 }
491 }
492 }
493 if (auto pn = dyn_cast<PHINode>(val)) {
494 if (SE.isSCEVable(pn->getType()))
495 if (auto S = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(pn))) {
496 if (auto StartC = dyn_cast<SCEVConstant>(S->getStart())) {
497 auto L = S->getLoop();
498 auto BE = SE.getBackedgeTakenCount(L);
499 if (BE != SE.getCouldNotCompute()) {
500 if (auto Iters = dyn_cast<SCEVConstant>(BE)) {
501 uint64_t ival = Iters->getAPInt().getZExtValue();
502 // If strict aliasing and the loop header does not dominate all
503 // blocks at low optimization levels the last "iteration" will
504 // actually exit leading to one extra backedge that would be wise
505 // to ignore.
507 bool rotated = false;
508 BasicBlock *Latch = L->getLoopLatch();
509 rotated = Latch && L->isLoopExiting(Latch);
510 if (!rotated) {
511 if (ival > 0)
512 ival--;
513 }
514 }
515
516 uint64_t istart = 0;
517
518 if (S->isAffine()) {
519 if (auto StepC = dyn_cast<SCEVConstant>(S->getOperand(1))) {
520 APInt StartI = StartC->getAPInt();
521 APInt A = StepC->getAPInt();
522
523 if (A.sle(-1)) {
524 A = -A;
525 StartI = -StartI;
526 }
527
528 if (A.sge(1)) {
529 if (StartI.sge(MaxIntOffset)) {
530 ival = std::min(ival, (uint64_t)0);
531 } else {
532 ival = std::min(
533 ival,
534 (MaxIntOffset - StartI + A).udiv(A).getZExtValue());
535 }
536
537 if (StartI.slt(-MaxIntOffset)) {
538 istart = std::max(
539 istart,
540 (-MaxIntOffset - StartI).udiv(A).getZExtValue());
541 }
542
543 } else {
544 ival = std::min(ival, (uint64_t)0);
545 }
546 } else {
547 ival = std::min(ival, (uint64_t)0);
548 }
549 }
550
551 for (uint64_t i = istart; i <= ival; i++) {
552 if (auto Val = dyn_cast<SCEVConstant>(S->evaluateAtIteration(
553 SE.getConstant(Iters->getType(), i, /*signed*/ false),
554 SE))) {
555 insert(Val->getAPInt().getSExtValue());
556 }
557 }
558 return intseen[val];
559 }
560 }
561 }
562 }
563
564 for (unsigned i = 0; i < pn->getNumIncomingValues(); ++i) {
565 auto a = pn->getIncomingValue(i);
566 auto b = pn->getIncomingBlock(i);
567
568 // do not consider loop incoming edges
569 if (pn->getParent() == b || DT.dominates(pn, b)) {
570 continue;
571 }
572
573 auto inset = knownIntegralValues(a, DT, intseen, SE);
574
575 // TODO this here is not fully justified yet
576 for (auto pval : inset) {
577 if (pval < 20 && pval > -20) {
578 insert(pval);
579 }
580 }
581
582 // if we are an iteration variable, suppose that it could be zero in that
583 // range
584 // TODO: could actually check the range intercepts 0
585 if (auto bo = dyn_cast<BinaryOperator>(a)) {
586 if (bo->getOperand(0) == pn || bo->getOperand(1) == pn) {
587 if (bo->getOpcode() == BinaryOperator::Add ||
588 bo->getOpcode() == BinaryOperator::Sub) {
589 insert(0);
590 }
591 }
592 }
593 }
594 return intseen[val];
595 }
596
597 if (auto bo = dyn_cast<BinaryOperator>(val)) {
598 auto inset0 = knownIntegralValues(bo->getOperand(0), DT, intseen, SE);
599 auto inset1 = knownIntegralValues(bo->getOperand(1), DT, intseen, SE);
600 if (bo->getOpcode() == BinaryOperator::Mul) {
601
602 if (inset0.size() == 1 || inset1.size() == 1) {
603 for (auto val0 : inset0) {
604 for (auto val1 : inset1) {
605
606 insert(val0 * val1);
607 }
608 }
609 }
610 if (inset0.count(0) || inset1.count(0)) {
611 intseen[val].insert(0);
612 }
613 }
614
615 if (bo->getOpcode() == BinaryOperator::Add) {
616 if (inset0.size() == 1 || inset1.size() == 1) {
617 for (auto val0 : inset0) {
618 for (auto val1 : inset1) {
619 insert(val0 + val1);
620 }
621 }
622 }
623 }
624 if (bo->getOpcode() == BinaryOperator::Sub) {
625 if (inset0.size() == 1 || inset1.size() == 1) {
626 for (auto val0 : inset0) {
627 for (auto val1 : inset1) {
628 insert(val0 - val1);
629 }
630 }
631 }
632 }
633
634 if (bo->getOpcode() == BinaryOperator::SDiv) {
635 if (inset0.size() == 1 || inset1.size() == 1) {
636 for (auto val0 : inset0) {
637 for (auto val1 : inset1) {
638 insert(val0 / val1);
639 }
640 }
641 }
642 }
643
644 if (bo->getOpcode() == BinaryOperator::Shl) {
645 if (inset0.size() == 1 || inset1.size() == 1) {
646 for (auto val0 : inset0) {
647 for (auto val1 : inset1) {
648 insert(val0 << val1);
649 }
650 }
651 }
652 }
653
654 // TODO note C++ doesnt guarantee behavior of >> being arithmetic or logical
655 // and should replace with llvm apint internal
656 if (bo->getOpcode() == BinaryOperator::AShr ||
657 bo->getOpcode() == BinaryOperator::LShr) {
658 if (inset0.size() == 1 || inset1.size() == 1) {
659 for (auto val0 : inset0) {
660 for (auto val1 : inset1) {
661 insert(val0 >> val1);
662 }
663 }
664 }
665 }
666 }
667
668 return intseen[val];
669}
670
671/// Given a constant value, deduce any type information applicable
673 std::map<llvm::Value *, TypeTree> &analysis) {
674 auto found = analysis.find(Val);
675 if (found != analysis.end())
676 return;
677
678 auto &DL = TA.fntypeinfo.Function->getParent()->getDataLayout();
679
680 // Undefined value is an anything everywhere
681 if (isa<UndefValue>(Val) || isa<ConstantAggregateZero>(Val)) {
682 analysis[Val].insert({-1}, BaseType::Anything);
683 return;
684 }
685
686 // Null pointer is a pointer to anything, everywhere
687 if (isa<ConstantPointerNull>(Val)) {
688 TypeTree &Result = analysis[Val];
689 Result.insert({-1}, BaseType::Pointer);
690 Result.insert({-1, -1}, BaseType::Anything);
691 return;
692 }
693
694 // Known pointers are pointers at offset 0
695 if (isa<Function>(Val) || isa<BlockAddress>(Val)) {
696 analysis[Val].insert({-1}, BaseType::Pointer);
697 return;
698 }
699
700 // Any constants == 0 are considered Anything
701 // other floats are assumed to be that type
702 if (auto FP = dyn_cast<ConstantFP>(Val)) {
703 if (FP->isExactlyValue(0.0)) {
704 analysis[Val].insert({-1}, BaseType::Anything);
705 return;
706 }
707 analysis[Val].insert({-1}, ConcreteType(FP->getType()->getScalarType()));
708 return;
709 }
710
711 if (auto ci = dyn_cast<ConstantInt>(Val)) {
712 // Constants in range [1, 4096] are assumed to be integral since
713 // any float or pointers they may represent are ill-formed
714 if (!ci->isNegative() && ci->getLimitedValue() >= 1 &&
715 ci->getLimitedValue() <= 4096) {
716 analysis[Val].insert({-1}, BaseType::Integer);
717 return;
718 }
719
720 // Constants explicitly marked as negative that aren't -1 are considered
721 // integral if >= -4096
722 if (ci->isNegative() && !ci->isMinusOne() && ci->getValue().sge(-4096)) {
723 analysis[Val].insert({-1}, BaseType::Integer);
724 return;
725 }
726
727 // Values of size < 16 (half size) are considered integral
728 // since they cannot possibly represent a float or pointer
729 if (cast<IntegerType>(ci->getType())->getBitWidth() < 16) {
730 analysis[Val].insert({-1}, BaseType::Integer);
731 return;
732 }
733 // All other constant-ints could be any type
734 analysis[Val].insert({-1}, BaseType::Anything);
735 return;
736 }
737
738 // Type of an aggregate is the aggregation of
739 // the subtypes
740 if (auto CA = dyn_cast<ConstantAggregate>(Val)) {
741 TypeTree &Result = analysis[Val];
742 for (unsigned i = 0, size = CA->getNumOperands(); i < size; ++i) {
743 assert(TA.fntypeinfo.Function);
744 auto Op = CA->getOperand(i);
745 // TODO check this for i1 constant aggregates packing/etc
746 auto ObjSize = (TA.fntypeinfo.Function->getParent()
747 ->getDataLayout()
748 .getTypeSizeInBits(Op->getType()) +
749 7) /
750 8;
751
752 Value *vec[2] = {
753 ConstantInt::get(Type::getInt64Ty(Val->getContext()), 0),
754 ConstantInt::get(Type::getInt32Ty(Val->getContext()), i),
755 };
756 auto g2 = GetElementPtrInst::Create(
757 Val->getType(), UndefValue::get(getUnqual(Val->getType())), vec);
758 APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
759 g2->accumulateConstantOffset(DL, ai);
760 // Using destructor rather than eraseFromParent
761 // as g2 has no parent
762 delete g2;
763
764 int Off = (int)ai.getLimitedValue();
765 if (auto VT = dyn_cast<VectorType>(Val->getType()))
766 if (VT->getElementType()->isIntegerTy(1))
767 Off = i / 8;
768
769 getConstantAnalysis(Op, TA, analysis);
770 auto mid = analysis[Op];
771 if (TA.fntypeinfo.Function->getParent()
772 ->getDataLayout()
773 .getTypeSizeInBits(CA->getType()) >= 16) {
774 mid.ReplaceIntWithAnything();
775 }
776
777 Result |= mid.ShiftIndices(DL, /*init offset*/ 0,
778 /*maxSize*/ ObjSize,
779 /*addOffset*/ Off);
780 }
781 Result.CanonicalizeInPlace(
782 (TA.fntypeinfo.Function->getParent()->getDataLayout().getTypeSizeInBits(
783 CA->getType()) +
784 7) /
785 8,
786 DL);
787 return;
788 }
789
790 // Type of an sequence is the aggregation of
791 // the subtypes
792 if (auto CD = dyn_cast<ConstantDataSequential>(Val)) {
793 TypeTree &Result = analysis[Val];
794 for (unsigned i = 0, size = CD->getNumElements(); i < size; ++i) {
795 assert(TA.fntypeinfo.Function);
796 auto Op = CD->getElementAsConstant(i);
797 // TODO check this for i1 constant aggregates packing/etc
798 auto ObjSize = (TA.fntypeinfo.Function->getParent()
799 ->getDataLayout()
800 .getTypeSizeInBits(Op->getType()) +
801 7) /
802 8;
803
804 Value *vec[2] = {
805 ConstantInt::get(Type::getInt64Ty(Val->getContext()), 0),
806 ConstantInt::get(Type::getInt32Ty(Val->getContext()), i),
807 };
808 auto g2 = GetElementPtrInst::Create(
809 Val->getType(), UndefValue::get(getUnqual(Val->getType())), vec);
810 APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
811 g2->accumulateConstantOffset(DL, ai);
812 // Using destructor rather than eraseFromParent
813 // as g2 has no parent
814 delete g2;
815
816 int Off = (int)ai.getLimitedValue();
817
818 getConstantAnalysis(Op, TA, analysis);
819 auto mid = analysis[Op];
820 if (TA.fntypeinfo.Function->getParent()
821 ->getDataLayout()
822 .getTypeSizeInBits(CD->getType()) >= 16) {
823 mid.ReplaceIntWithAnything();
824 }
825 Result |= mid.ShiftIndices(DL, /*init offset*/ 0,
826 /*maxSize*/ ObjSize,
827 /*addOffset*/ Off);
828
829 Result |= mid;
830 }
831 Result.CanonicalizeInPlace(
832 (TA.fntypeinfo.Function->getParent()->getDataLayout().getTypeSizeInBits(
833 CD->getType()) +
834 7) /
835 8,
836 DL);
837 return;
838 }
839
840 // ConstantExprs are handled by considering the
841 // equivalent instruction
842 if (auto CE = dyn_cast<ConstantExpr>(Val)) {
843 if (CE->isCast()) {
844 if (CE->getType()->isPointerTy() && isa<ConstantInt>(CE->getOperand(0))) {
845 analysis[Val] = TypeTree(BaseType::Anything).Only(-1, nullptr);
846 return;
847 }
848 getConstantAnalysis(CE->getOperand(0), TA, analysis);
849 analysis[Val] = analysis[CE->getOperand(0)];
850 return;
851 }
852 if (CE->getOpcode() == Instruction::GetElementPtr) {
853 TA.visitGEPOperator(*cast<GEPOperator>(CE));
854 return;
855 }
856
857 auto I = CE->getAsInstruction();
858 I->insertBefore(TA.fntypeinfo.Function->getEntryBlock().getTerminator());
859
860 // Just analyze this new "instruction" and none of the others
861 {
862 TypeAnalyzer tmpAnalysis(TA.fntypeinfo, TA.interprocedural,
863 TA.notForAnalysis, TA);
864 tmpAnalysis.visit(*I);
865 analysis[Val] = tmpAnalysis.getAnalysis(I);
866
867 if (tmpAnalysis.workList.remove(I)) {
868 TA.workList.insert(CE);
869 }
870 }
871
872 I->eraseFromParent();
873 return;
874 }
875
876 if (auto GV = dyn_cast<GlobalVariable>(Val)) {
877
878 if (GV->getName() == "__cxa_thread_atexit_impl") {
879 analysis[Val] = TypeTree(BaseType::Pointer).Only(-1, nullptr);
880 return;
881 }
882
883 // from julia code
884 if (GV->getName() == "small_typeof" || GV->getName() == "jl_small_typeof") {
885 TypeTree T;
886 T.insert({-1}, BaseType::Pointer);
887 T.insert({-1, -1}, BaseType::Pointer);
888 analysis[Val] = T;
889 return;
890 }
891
892 TypeTree &Result = analysis[Val];
894
895 // A fixed constant global is a pointer to its initializer
896 if (GV->isConstant() && GV->hasInitializer()) {
897 getConstantAnalysis(GV->getInitializer(), TA, analysis);
898 Result |= analysis[GV->getInitializer()].Only(-1, nullptr);
899 return;
900 }
901 if (!isa<StructType>(GV->getValueType()) ||
902 !cast<StructType>(GV->getValueType())->isOpaque()) {
903 auto globalSize = (DL.getTypeSizeInBits(GV->getValueType()) + 7) / 8;
904 // Since halfs are 16bit (2 byte) and pointers are >=32bit (4 byte) any
905 // Single byte object must be integral
906 if (globalSize == 1) {
907 Result.insert({-1, -1}, ConcreteType(BaseType::Integer));
908 return;
909 }
910 }
911
912 // Otherwise, we simply know that this is a pointer, and
913 // not what it is a pointer to
914 return;
915 }
916
917 // No other information can be ascertained
918 analysis[Val] = TypeTree();
919 return;
920}
921
923 // Integers with fewer than 16 bits (size of half)
924 // must be integral, since it cannot possibly represent a float or pointer
925 if (!isa<UndefValue>(Val) && Val->getType()->isIntegerTy() &&
926 cast<IntegerType>(Val->getType())->getBitWidth() < 16)
927 return TypeTree(BaseType::Integer).Only(-1, nullptr);
928 if (auto C = dyn_cast<Constant>(Val)) {
929 getConstantAnalysis(C, *this, analysis);
930 return analysis[Val];
931 }
932
933 // Check that this value is from the function being analyzed
934 if (auto I = dyn_cast<Instruction>(Val)) {
935 if (I->getParent()->getParent() != fntypeinfo.Function) {
936 llvm::errs() << " function: " << *fntypeinfo.Function << "\n";
937 llvm::errs() << " instParent: " << *I->getParent()->getParent() << "\n";
938 llvm::errs() << " inst: " << *I << "\n";
939 }
940 assert(I->getParent()->getParent() == fntypeinfo.Function);
941 }
942 if (auto Arg = dyn_cast<Argument>(Val)) {
943 if (Arg->getParent() != fntypeinfo.Function) {
944 llvm::errs() << " function: " << *fntypeinfo.Function << "\n";
945 llvm::errs() << " argParent: " << *Arg->getParent() << "\n";
946 llvm::errs() << " arg: " << *Arg << "\n";
947 }
948 assert(Arg->getParent() == fntypeinfo.Function);
949 }
950
951 // Return current results
952 if (isa<Argument>(Val) || isa<Instruction>(Val))
953 return analysis[Val];
954
955 // Unhandled/unknown Value
956 llvm::errs() << "Error Unknown Value: " << *Val << "\n";
957 assert(0 && "Error Unknown Value: ");
958 llvm_unreachable("Error Unknown Value: ");
959 // return TypeTree();
960}
961
962void TypeAnalyzer::updateAnalysis(Value *Val, ConcreteType Data,
963 Value *Origin) {
964 updateAnalysis(Val, TypeTree(Data), Origin);
965}
966
967void TypeAnalyzer::updateAnalysis(Value *Val, BaseType Data, Value *Origin) {
968 updateAnalysis(Val, TypeTree(ConcreteType(Data)), Origin);
969}
970
971void TypeAnalyzer::addToWorkList(Value *Val) {
972 // Only consider instructions/arguments
973 if (!isa<Instruction>(Val) && !isa<Argument>(Val) &&
974 !isa<ConstantExpr>(Val) && !isa<GlobalVariable>(Val))
975 return;
976
977 // Verify this value comes from the function being analyzed
978 if (auto I = dyn_cast<Instruction>(Val)) {
979 if (fntypeinfo.Function != I->getParent()->getParent())
980 return;
981 if (notForAnalysis.count(I->getParent()))
982 return;
983 if (fntypeinfo.Function != I->getParent()->getParent()) {
984 llvm::errs() << "function: " << *fntypeinfo.Function << "\n";
985 llvm::errs() << "instf: " << *I->getParent()->getParent() << "\n";
986 llvm::errs() << "inst: " << *I << "\n";
987 }
988 assert(fntypeinfo.Function == I->getParent()->getParent());
989 } else if (auto Arg = dyn_cast<Argument>(Val)) {
990 if (fntypeinfo.Function != Arg->getParent()) {
991 llvm::errs() << "fn: " << *fntypeinfo.Function << "\n";
992 llvm::errs() << "argparen: " << *Arg->getParent() << "\n";
993 llvm::errs() << "val: " << *Arg << "\n";
994 }
995 assert(fntypeinfo.Function == Arg->getParent());
996 }
997
998 // Add to workList
999 workList.insert(Val);
1000}
1001
/// Merge \p Data into the type information recorded for \p Val.
///
/// \p Origin is the value whose visit produced this update. It is used for
/// diagnostics, for the dominance-based filtering below, and to avoid
/// re-queueing the value that triggered the update.
///
/// In order:
///  - void-typed values, ConstantData, Functions, globals carrying
///    "enzyme_ta_norecur" metadata, casts of a ConstantInt and constant GEPs
///    off a null pointer are never updated.
///  - For an instruction \p Val (unless EnzymeStrictAliasing is set), an
///    update coming from an instruction in a different block is dropped. This
///    happens unless PDT (presumably the post-dominator tree) says the
///    origin's block dominates \p Val's block. Exception: an alloca whose
///    users all sit in, or are PDT-dominated by, the origin's block.
///  - For an argument \p Val, an update from an instruction is dropped
///    unless the origin's block is the entry block or PDT-dominates it.
///  - \p Data is canonicalized to \p Val's byte size and merged with
///    checkedOrIn. A contradictory merge either marks this analyzer Invalid
///    (when direction != BOTH) or aborts with a diagnostic.
///  - If anything changed, \p Val, its users (plus PHI users of any
///    BinaryOperator user) and its operands are queued for re-visiting.
void TypeAnalyzer::updateAnalysis(Value *Val, TypeTree Data, Value *Origin) {
  if (Val->getType()->isVoidTy())
    return;
  // ConstantData's and Functions don't have analysis updated
  // We don't do "Constant" as globals are "Constant" types
  if (isa<ConstantData>(Val) || isa<Function>(Val)) {
    return;
  }

  // Globals explicitly opted out of recursive type analysis.
  if (auto GV = dyn_cast<GlobalVariable>(Val)) {
    if (hasMetadata(GV, "enzyme_ta_norecur"))
      return;
  }

  // Constant expressions that carry no useful type: casts of an integer
  // constant, and GEPs computed from a null pointer.
  if (auto CE = dyn_cast<ConstantExpr>(Val)) {
    if (CE->isCast() && isa<ConstantInt>(CE->getOperand(0))) {
      return;
    }
    if (CE->getOpcode() == Instruction::GetElementPtr &&
        isa<ConstantPointerNull>(CE->getOperand(0)))
      return;
  }

  if (auto I = dyn_cast<Instruction>(Val)) {
    if (fntypeinfo.Function != I->getParent()->getParent()) {
      llvm::errs() << "function: " << *fntypeinfo.Function << "\n";
      llvm::errs() << "instf: " << *I->getParent()->getParent() << "\n";
      llvm::errs() << "inst: " << *I << "\n";
    }
    assert(fntypeinfo.Function == I->getParent()->getParent());
    assert(Origin);
    // Without strict aliasing, only accept facts derived from an instruction
    // whose block is I's block or PDT-dominates it. Otherwise the fact
    // might only hold on some paths.
    if (!EnzymeStrictAliasing) {
      if (auto OI = dyn_cast<Instruction>(Origin)) {
        if (OI->getParent() != I->getParent() &&
            !PDT.dominates(OI->getParent(), I->getParent())) {
          // An alloca is still updated if every one of its users is in the
          // origin's block or in a block PDT-dominated by it.
          bool allocationWithAllUsersInBlock = false;
          if (auto AI = dyn_cast<AllocaInst>(I)) {
            allocationWithAllUsersInBlock = true;
            for (auto U : AI->users()) {
              auto P = cast<Instruction>(U)->getParent();
              if (P == OI->getParent())
                continue;
              if (PDT.dominates(OI->getParent(), P))
                continue;
              allocationWithAllUsersInBlock = false;
              break;
            }
          }
          if (!allocationWithAllUsersInBlock) {
            if (EnzymePrintType) {
              llvm::errs() << " skipping update into ";
              I->print(llvm::errs(), *MST);
              llvm::errs() << " of " << Data.str() << " from ";
              OI->print(llvm::errs(), *MST);
              llvm::errs() << "\n";
            }
            return;
          }
        }
      }
    }
  } else if (auto Arg = dyn_cast<Argument>(Val)) {
    assert(fntypeinfo.Function == Arg->getParent());
    // Same filtering for arguments, using the first instruction of the entry
    // block as the argument's position.
    // NOTE(review): unlike the instruction branch above, this branch is not
    // gated on EnzymeStrictAliasing in this copy of the file. Confirm whether
    // that asymmetry is intended.
    if (auto OI = dyn_cast<Instruction>(Origin)) {
      auto I = &*fntypeinfo.Function->getEntryBlock().begin();
      if (OI->getParent() != I->getParent() &&
          !PDT.dominates(OI->getParent(), I->getParent())) {
        if (EnzymePrintType) {
          llvm::errs() << " skipping update into ";
          Arg->print(llvm::errs(), *MST);
          llvm::errs() << " of " << Data.str() << " from ";
          OI->print(llvm::errs(), *MST);
          llvm::errs() << "\n";
        }
        return;
      }
    }
  }

  // Attempt to update the underlying analysis
  bool LegalOr = true;
  // The first time a constant (other than a constant GEP expression) is seen,
  // seed its entry from getConstantAnalysis before merging.
  if (analysis.find(Val) == analysis.end() && isa<Constant>(Val)) {
    if (!isa<ConstantExpr>(Val) ||
        cast<ConstantExpr>(Val)->getOpcode() != Instruction::GetElementPtr)
      getConstantAnalysis(cast<Constant>(Val), *this, analysis);
  }

  // Snapshot for diagnostics (operator[] inserts an empty entry if absent).
  TypeTree prev = analysis[Val];

  // Canonicalize the incoming tree to Val's size in bytes, then merge it in.
  // checkedOrIn clears LegalOr if the merge is contradictory.
  auto &DL = fntypeinfo.Function->getParent()->getDataLayout();
  auto RegSize = (DL.getTypeSizeInBits(Val->getType()) + 7) / 8;
  Data.CanonicalizeInPlace(RegSize, DL);
  bool Changed =
      analysis[Val].checkedOrIn(Data, /*PointerIntSame*/ false, LegalOr);

  // Print the update being made, if requested
  if (EnzymePrintType) {
    llvm::errs() << "updating analysis of val: ";
    Val->print(llvm::errs(), *MST);
    llvm::errs() << " current: " << prev.str() << " new " << Data.str();
    if (Origin) {
      llvm::errs() << " from ";
      Origin->print(llvm::errs(), *MST);
    }
    llvm::errs() << " Changed=" << Changed << " legal=" << LegalOr << "\n";
  }

  if (!LegalOr) {
    // When not analyzing in both directions, a contradiction only marks this
    // analyzer Invalid; the caller can inspect that flag.
    if (direction != BOTH) {
      Invalid = true;
      return;
    }
    std::string str;
    raw_string_ostream ss(str);
    if (!CustomErrorHandler) {
      llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
      llvm::errs() << *fntypeinfo.Function << "\n";
      dump(ss);
    }
    ss << "Illegal updateAnalysis prev:" << prev.str() << " new: " << Data.str()
       << "\n";
    ss << "val: " << *Val;
    if (Origin)
      ss << " origin=" << *Origin;

    if (CustomErrorHandler) {
      // NOTE(review): the opening line of this call is missing from this copy
      // of the file; only its trailing arguments survive, so the statement
      // does not compile as shown. Presumably it invoked CustomErrorHandler
      // with the message, wrap(Val) and ErrorType::IllegalTypeAnalysis,
      // matching the six-argument call made in visitGEPOperator. Restore
      // it from upstream.
                         (void *)this, wrap(Origin), nullptr);
    }
    // Report at the best available location, then terminate.
    if (auto I = dyn_cast<Instruction>(Val)) {
      EmitFailure("IllegalUpdateAnalysis", I->getDebugLoc(), I, ss.str());
      exit(1);
    } else if (auto I = dyn_cast_or_null<Instruction>(Origin)) {
      EmitFailure("IllegalUpdateAnalysis", I->getDebugLoc(), I, ss.str());
      exit(1);
    } else {
      llvm::errs() << ss.str() << "\n";
    }
    report_fatal_error("Performed illegal updateAnalysis");
  }

  if (Changed) {

    // For a sized global, re-normalize the stored tree to the global's value
    // size and re-mark the global itself as a pointer. Origin is set to Val,
    // so Val is not re-queued just below.
    if (auto GV = dyn_cast<GlobalVariable>(Val)) {
      if (GV->getValueType()->isSized()) {
        auto Size = (DL.getTypeSizeInBits(GV->getValueType()) + 7) / 8;
        Data = analysis[Val].Lookup(Size, DL).Only(-1, nullptr);
        Data.insert({-1}, BaseType::Pointer);
        analysis[Val] = Data;
        Origin = Val;
      }
    }
    // Add val so it can explicitly propagate this new info, if able to
    if (Val != Origin)
      addToWorkList(Val);

    // Add users and operands of the value so they can update from the new
    // operand/use
    for (User *U : Val->users()) {
      if (U != Origin) {

        // Users in other functions are not part of this analysis.
        if (auto I = dyn_cast<Instruction>(U)) {
          if (fntypeinfo.Function != I->getParent()->getParent()) {
            continue;
          }
        }

        addToWorkList(U);

        // per the handling of phi's
        if (auto BO = dyn_cast<BinaryOperator>(U)) {
          for (User *U2 : BO->users()) {
            if (isa<PHINode>(U2) && U2 != Origin) {
              addToWorkList(U2);
            }
          }
        }
      }
    }

    if (User *US = dyn_cast<User>(Val)) {
      for (Value *Op : US->operands()) {
        if (Op != Origin) {
          addToWorkList(Op);
        }
      }
    }
  }
}
1192
/// Analyze type info given by the arguments, possibly adding to work queue.
///
/// First merges the caller-supplied argument types from
/// fntypeinfo.Arguments. Then it re-applies each argument's own analysis,
/// since getAnalysis may contribute more information. Finally it does the
/// same for every value returned by a ReturnInst. Each update uses the value
/// itself as its origin.
  // Propagate input type information for arguments
  for (auto &pair : fntypeinfo.Arguments) {
    assert(pair.first->getParent() == fntypeinfo.Function);
    updateAnalysis(pair.first, pair.second, pair.first);
  }

  // Get type and other information about argument
  // getAnalysis may add more information so this
  // is necessary/useful
  for (Argument &Arg : fntypeinfo.Function->args()) {
    updateAnalysis(&Arg, getAnalysis(&Arg), &Arg);
  }

  // Propagate return value type information
  for (BasicBlock &BB : *fntypeinfo.Function) {
    for (Instruction &I : BB) {
      if (ReturnInst *RI = dyn_cast<ReturnInst>(&I)) {
        if (Value *RV = RI->getReturnValue()) {
          // NOTE(review): as visible here, only RV's existing analysis is
          // re-applied. No externally supplied return type is merged in.
          // Confirm that nothing was lost at this point.
          updateAnalysis(RV, getAnalysis(RV), RV);
        }
      }
    }
  }
}
1220
/// Analyze type info given by the TBAA, possibly adding to work queue.
///
/// Visits every instruction of the analyzed function, except those in blocks
/// listed in notForAnalysis, and seeds the analysis from explicit sources:
///  - "enzyme_type" instruction metadata;
///  - "enzyme_type" return and parameter attributes, on the call site and on
///    the directly called function;
///  - calls to functions whose name contains __enzyme_float,
///    __enzyme_double, __enzyme_integer or __enzyme_pointer (looking through
///    a constant cast of the callee). These type the pointee of argument 0.
///  - calls to a fixed list of Julia allocation functions. These are passed
///    to visitCallBase and the rest of the iteration is skipped.
///  - the type produced by parseTBAA, applied to loads, stores, and the call
///    forms handled at the bottom of the loop.
  auto &DL = fntypeinfo.Function->getParent()->getDataLayout();

  for (BasicBlock &BB : *fntypeinfo.Function) {
    if (notForAnalysis.count(&BB))
      continue;
    for (Instruction &I : BB) {
      // Explicit "enzyme_type" metadata. Every entry whose leading offset is
      // not -1 must lie inside the instruction's size in bytes.
      if (auto MD = I.getMetadata("enzyme_type")) {
        auto TT = TypeTree::fromMD(MD);

        auto RegSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8;
        for (const auto &pair : TT.getMapping()) {
          if (pair.first[0] != -1) {
            if ((size_t)pair.first[0] >= RegSize) {
              llvm::errs() << " bad enzyme_type " << TT.str()
                           << " RegSize=" << RegSize << " I:" << I << "\n";
              llvm::report_fatal_error("Canonicalization failed");
            }
          }
        }
        updateAnalysis(&I, TT, &I);
      }

      if (CallBase *call = dyn_cast<CallBase>(&I)) {
#if LLVM_VERSION_MAJOR >= 14
        size_t num_args = call->arg_size();
#else
        size_t num_args = call->getNumArgOperands();
#endif

        // "enzyme_type" attribute on the call site's return value.
        if (call->getAttributes().hasAttribute(AttributeList::ReturnIndex,
                                               "enzyme_type")) {
          auto attr = call->getAttributes().getAttribute(
              AttributeList::ReturnIndex, "enzyme_type");
          auto TT =
              TypeTree::parse(attr.getValueAsString(), call->getContext());

          auto RegSize = I.getType()->isVoidTy()
                             ? 0
                             : (DL.getTypeSizeInBits(I.getType()) + 7) / 8;
          for (const auto &pair : TT.getMapping()) {
            if (pair.first[0] != -1) {
              if ((size_t)pair.first[0] >= RegSize) {
                llvm::errs() << " bad enzyme_type " << TT.str()
                             << " RegSize=" << RegSize << " I:" << I << "\n";
                llvm::report_fatal_error("Canonicalization failed");
              }
            }
          }
          updateAnalysis(call, TT, call);
        }
        // "enzyme_type" attributes on individual call-site arguments.
        for (size_t i = 0; i < num_args; i++) {
          if (call->getAttributes().hasParamAttr(i, "enzyme_type")) {
            auto attr = call->getAttributes().getParamAttr(i, "enzyme_type");
            auto TT =
                TypeTree::parse(attr.getValueAsString(), call->getContext());
            // NOTE(review): RegSize is computed from the call's *return* type
            // (I.getType()), not from argument i's type. For a void call it
            // is 0, so any parameter annotation whose leading offset is not -1
            // aborts with "Canonicalization failed". Likely intended:
            // call->getArgOperand(i)->getType(). Confirm.
            auto RegSize = I.getType()->isVoidTy()
                               ? 0
                               : (DL.getTypeSizeInBits(I.getType()) + 7) / 8;
            for (const auto &pair : TT.getMapping()) {
              if (pair.first[0] != -1) {
                if ((size_t)pair.first[0] >= RegSize) {
                  llvm::errs() << " bad enzyme_type " << TT.str()
                               << " RegSize=" << RegSize << " I:" << I << "\n";
                  llvm::report_fatal_error("Canonicalization failed");
                }
              }
            }
            updateAnalysis(call->getArgOperand(i), TT, call);
          }
        }

        Function *F = call->getCalledFunction();

        // The same annotations, read from the directly called function.
        if (F) {
          if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex,
                                              "enzyme_type")) {
            auto attr = F->getAttributes().getAttribute(
                AttributeList::ReturnIndex, "enzyme_type");
            auto TT =
                TypeTree::parse(attr.getValueAsString(), call->getContext());
            auto RegSize = I.getType()->isVoidTy()
                               ? 0
                               : (DL.getTypeSizeInBits(I.getType()) + 7) / 8;
            for (const auto &pair : TT.getMapping()) {
              if (pair.first[0] != -1) {
                if ((size_t)pair.first[0] >= RegSize) {
                  llvm::errs() << " bad enzyme_type " << TT.str()
                               << " RegSize=" << RegSize << " I:" << I << "\n";
                  llvm::report_fatal_error("Canonicalization failed");
                }
              }
            }
            updateAnalysis(call, TT, call);
          }
          size_t f_num_args = F->arg_size();
          for (size_t i = 0; i < f_num_args; i++) {
            if (F->getAttributes().hasParamAttr(i, "enzyme_type")) {
              auto attr = F->getAttributes().getParamAttr(i, "enzyme_type");
              auto TT =
                  TypeTree::parse(attr.getValueAsString(), call->getContext());
              // NOTE(review): same concern as the call-site parameter case:
              // RegSize comes from the return type, not argument i's type.
              auto RegSize = I.getType()->isVoidTy()
                                 ? 0
                                 : (DL.getTypeSizeInBits(I.getType()) + 7) / 8;
              for (const auto &pair : TT.getMapping()) {
                if (pair.first[0] != -1) {
                  if ((size_t)pair.first[0] >= RegSize) {
                    llvm::errs()
                        << " bad enzyme_type " << TT.str()
                        << " RegSize=" << RegSize << " I:" << I << "\n";
                    llvm::report_fatal_error("Canonicalization failed");
                  }
                }
              }
              updateAnalysis(call->getArgOperand(i), TT, call);
            }
          }
        }

        // Look through a constant cast of the callee so that the __enzyme_*
        // markers below are recognized even when called through a cast.
        if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledOperand())) {
          if (castinst->isCast())
            if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
              F = fn;
            }
        }
        // __enzyme_float(ptr[, bytes]): mark a float at every 4-byte offset
        // below `bytes` (default 1, i.e. offset 0 only). A negative count
        // marks the whole pointee (offset -1) as float.
        if (F && F->getName().contains("__enzyme_float")) {
          assert(num_args == 1 || num_args == 2);
          assert(call->getArgOperand(0)->getType()->isPointerTy());
          TypeTree TT;
          ssize_t num = 1;
          if (num_args == 2) {
            assert(isa<ConstantInt>(call->getArgOperand(1)));
            auto CI = cast<ConstantInt>(call->getArgOperand(1));
            if (CI->isNegative())
              num = -1;
            else
              num = CI->getLimitedValue();
          }
          if (num == -1)
            TT.insert({(int)num}, Type::getFloatTy(call->getContext()));
          else
            for (size_t i = 0; i < (size_t)num; i += 4)
              TT.insert({(int)i}, Type::getFloatTy(call->getContext()));
          TT.insert({}, BaseType::Pointer);
          updateAnalysis(call->getOperand(0), TT.Only(-1, call), call);
        }
        // __enzyme_double(ptr[, bytes]): a double at every 8-byte offset
        // below `bytes` (default 1).
        // NOTE(review): unlike __enzyme_float, this marker and the
        // __enzyme_integer/__enzyme_pointer ones below do not special-case
        // a negative count. getLimitedValue() then yields the zero-extended
        // value, so the loop bound becomes enormous. Confirm negative counts
        // cannot reach here.
        if (F && F->getName().contains("__enzyme_double")) {
          assert(num_args == 1 || num_args == 2);
          assert(call->getArgOperand(0)->getType()->isPointerTy());
          TypeTree TT;
          size_t num = 1;
          if (num_args == 2) {
            assert(isa<ConstantInt>(call->getArgOperand(1)));
            num = cast<ConstantInt>(call->getArgOperand(1))->getLimitedValue();
          }
          for (size_t i = 0; i < num; i += 8)
            TT.insert({(int)i}, Type::getDoubleTy(call->getContext()));
          TT.insert({}, BaseType::Pointer);
          updateAnalysis(call->getOperand(0), TT.Only(-1, call), call);
        }
        // __enzyme_integer(ptr[, bytes]): an integer at every byte offset
        // below `bytes` (default 1).
        if (F && F->getName().contains("__enzyme_integer")) {
          assert(num_args == 1 || num_args == 2);
          assert(call->getArgOperand(0)->getType()->isPointerTy());
          size_t num = 1;
          if (num_args == 2) {
            assert(isa<ConstantInt>(call->getArgOperand(1)));
            num = cast<ConstantInt>(call->getArgOperand(1))->getLimitedValue();
          }
          TypeTree TT;
          for (size_t i = 0; i < num; i++)
            TT.insert({(int)i}, BaseType::Integer);
          TT.insert({}, BaseType::Pointer);
          updateAnalysis(call->getOperand(0), TT.Only(-1, call), call);
        }
        // __enzyme_pointer(ptr[, bytes]): a pointer at every pointer-sized
        // offset below `bytes` (default 1).
        if (F && F->getName().contains("__enzyme_pointer")) {
          assert(num_args == 1 || num_args == 2);
          assert(call->getArgOperand(0)->getType()->isPointerTy());
          TypeTree TT;
          size_t num = 1;
          if (num_args == 2) {
            assert(isa<ConstantInt>(call->getArgOperand(1)));
            num = cast<ConstantInt>(call->getArgOperand(1))->getLimitedValue();
          }
          for (size_t i = 0; i < num;
               i += ((DL.getPointerSizeInBits() + 7) / 8))
            TT.insert({(int)i}, BaseType::Pointer);
          TT.insert({}, BaseType::Pointer);
          updateAnalysis(call->getOperand(0), TT.Only(-1, call), call);
        }
        // Known Julia allocation routines are typed by visitCallBase
        // directly; TBAA handling below is skipped for them.
        if (F) {
          StringSet<> JuliaKnownTypes = {"julia.gc_alloc_obj",
                                         "jl_alloc_array_1d",
                                         "jl_alloc_array_2d",
                                         "jl_alloc_array_3d",
                                         "ijl_alloc_array_1d",
                                         "ijl_alloc_array_2d",
                                         "ijl_alloc_array_3d",
                                         "jl_gc_alloc_typed",
                                         "ijl_gc_alloc_typed",
                                         "jl_alloc_genericmemory",
                                         "ijl_alloc_genericmemory",
                                         "jl_alloc_genericmemory_unchecked",
                                         "ijl_alloc_genericmemory_unchecked",
                                         "jl_new_array",
                                         "ijl_new_array"};
          if (JuliaKnownTypes.count(F->getName())) {
            visitCallBase(*call);
            continue;
          }
        }
      }

      TypeTree vdptr = parseTBAA(I, DL, MST);

      // If we don't have any useful information,
      // don't bother updating
      if (!vdptr.isKnownPastPointer())
        continue;

      if (CallBase *call = dyn_cast<CallBase>(&I)) {
        // memcpy/memmove: both destination and source receive the TBAA type,
        // cut to the largest length knownIntegralValues reports for operand
        // 2 (at least 1).
        if (call->getCalledFunction() &&
            (call->getCalledFunction()->getIntrinsicID() == Intrinsic::memcpy ||
             call->getCalledFunction()->getIntrinsicID() ==
                 Intrinsic::memmove)) {
          int64_t copySize = 1;
          for (auto val : fntypeinfo.knownIntegralValues(call->getOperand(2),
                                                         DT, intseen, SE)) {
            copySize = max(copySize, val);
          }
          TypeTree update =
              vdptr
                  .ShiftIndices(DL, /*init offset*/ 0,
                                /*max size*/ copySize, /*new offset*/ 0)
                  .Only(-1, call);

          updateAnalysis(call->getOperand(0), update, call);
          updateAnalysis(call->getOperand(1), update, call);
          continue;
        } else if (call->getCalledFunction() &&
                   (call->getCalledFunction()->getIntrinsicID() ==
                        Intrinsic::memset ||
                    call->getCalledFunction()->getName() ==
                        "memset_pattern16")) {
          // memset / memset_pattern16: only the destination is typed.
          int64_t copySize = 1;
          for (auto val : fntypeinfo.knownIntegralValues(call->getOperand(2),
                                                         DT, intseen, SE)) {
            copySize = max(copySize, val);
          }
          TypeTree update =
              vdptr
                  .ShiftIndices(DL, /*init offset*/ 0,
                                /*max size*/ copySize, /*new offset*/ 0)
                  .Only(-1, call);

          updateAnalysis(call->getOperand(0), update, call);
          continue;
#if LLVM_VERSION_MAJOR >= 20
        } else if (call->getCalledFunction() &&
                   (call->getCalledFunction()->getIntrinsicID() ==
                    Intrinsic::experimental_memset_pattern)) {
          // Same treatment as memset for the LLVM 20+ pattern intrinsic.
          int64_t copySize = 1;
          for (auto val : fntypeinfo.knownIntegralValues(call->getOperand(2),
                                                         DT, intseen, SE)) {
            copySize = max(copySize, val);
          }
          TypeTree update =
              vdptr
                  .ShiftIndices(DL, /*init offset*/ 0,
                                /*max size*/ copySize, /*new offset*/ 0)
                  .Only(-1, call);

          updateAnalysis(call->getOperand(0), update, call);
          continue;
#endif
        } else if (call->getCalledFunction() &&
                   call->getCalledFunction()->getIntrinsicID() ==
                       Intrinsic::masked_gather) {
          // The gathered vector takes the TBAA type over its byte size.
          auto VT = cast<VectorType>(call->getType());
          auto LoadSize = (DL.getTypeSizeInBits(VT) + 7) / 8;
          TypeTree req = vdptr.Only(-1, call);
          updateAnalysis(call, req.Lookup(LoadSize, DL), call);
          // TODO use mask to propagate up to relevant pointer
        } else if (call->getCalledFunction() &&
                   call->getCalledFunction()->getIntrinsicID() ==
                       Intrinsic::masked_scatter) {
          // TODO use mask to propagate up to relevant pointer
        } else if (call->getCalledFunction() &&
                   call->getCalledFunction()->getIntrinsicID() ==
                       Intrinsic::masked_load) {
          // The loaded vector takes the TBAA type over its byte size.
          auto VT = cast<VectorType>(call->getType());
          auto LoadSize = (DL.getTypeSizeInBits(VT) + 7) / 8;
          TypeTree req = vdptr.Only(-1, call);
          updateAnalysis(call, req.Lookup(LoadSize, DL), call);
          // TODO use mask to propagate up to relevant pointer
        } else if (call->getCalledFunction() &&
                   call->getCalledFunction()->getIntrinsicID() ==
                       Intrinsic::masked_store) {
          // TODO use mask to propagate up to relevant pointer
        } else if (call->getType()->isPointerTy()) {
          // Any other pointer-returning call: the result gets the TBAA type.
          updateAnalysis(call, vdptr.Only(-1, call), call);
        } else {
          llvm::errs() << " unknown tbaa call instruction user inst: " << I
                       << " vdptr: " << vdptr.str() << "\n";
        }
      } else if (auto SI = dyn_cast<StoreInst>(&I)) {
        // Store: the pointer operand and the stored value both get the TBAA
        // type, restricted to the stored value's byte size.
        auto StoreSize =
            (DL.getTypeSizeInBits(SI->getValueOperand()->getType()) + 7) / 8;
        updateAnalysis(SI->getPointerOperand(),
                       vdptr
                           // Don't propagate "Anything" into ptr
                           .PurgeAnything()
                           // Cut off any values outside of store
                           .ShiftIndices(DL, /*init offset*/ 0,
                                         /*max size*/ StoreSize,
                                         /*new offset*/ 0)
                           .Only(-1, SI),
                       SI);
        TypeTree req = vdptr.Only(-1, SI);
        updateAnalysis(SI->getValueOperand(), req.Lookup(StoreSize, DL), SI);
      } else if (auto LI = dyn_cast<LoadInst>(&I)) {
        // Load: the pointer operand and the loaded value both get the TBAA
        // type, restricted to the loaded value's byte size.
        auto LoadSize = (DL.getTypeSizeInBits(LI->getType()) + 7) / 8;
        updateAnalysis(LI->getPointerOperand(),
                       vdptr
                           // Don't propagate "Anything" into ptr
                           .PurgeAnything()
                           // Cut off any values outside of load
                           .ShiftIndices(DL, /*init offset*/ 0,
                                         /*max size*/ LoadSize,
                                         /*new offset*/ 0)
                           .Only(-1, LI),
                       LI);
        TypeTree req = vdptr.Only(-1, LI);
        updateAnalysis(LI, req.Lookup(LoadSize, DL), LI);
      } else {
        llvm::errs() << " inst: " << I << " vdptr: " << vdptr.str() << "\n";
        assert(0 && "unknown tbaa instruction user");
        llvm_unreachable("unknown tbaa instruction user");
      }
    }
  }
}
1563
  // Speculatively type PHI nodes whose analysis is still unknown.
  //
  // For each integer-typed (or FP-typed) PHI, a temporary analyzer is
  // created: it copies this one's state, runs DOWN only, and has PHIRecur
  // set, so this function returns immediately inside it. That analyzer
  // assumes the PHI has that scalar type, revisits the PHI's users and
  // runs. If it does not become Invalid, the result is checked. It is the
  // PHI's type combined (&=) with each incoming value's type, and it must be
  // exactly that scalar type or Anything. In that case every entry of the
  // temporary analysis is merged into this one. The sweep repeats until a
  // full pass commits nothing.
  if (PHIRecur)
    return;
  bool Changed;
  do {
    Changed = false;
    for (BasicBlock &BB : *fntypeinfo.Function) {
      for (Instruction &inst : BB) {
        if (PHINode *phi = dyn_cast<PHINode>(&inst)) {
          if (direction & DOWN && phi->getType()->isIntOrIntVectorTy() &&
              !getAnalysis(phi).isKnown()) {
            // Assume that this is an integer, does that mean we can prove that
            // the incoming operands are integral
            // NOTE(review): the opening of the `tmpAnalysis` declaration is
            // missing from this copy of the file; only its trailing
            // constructor arguments remain on the next lines.
                                      notForAnalysis, *this, DOWN,
                                      /*PHIRecur*/ true);
            tmpAnalysis.intseen = intseen;
            tmpAnalysis.analysis = analysis;
            // NOTE(review): the right-hand side of this assignment is missing
            // from this copy. Presumably it is an Integer TypeTree for the
            // PHI, mirroring the FP branch below.
            tmpAnalysis.analysis[phi] =
            for (auto U : phi->users()) {
              if (auto I = dyn_cast<Instruction>(U)) {
                tmpAnalysis.visit(*I);
              }
            }
            tmpAnalysis.run();
            if (!tmpAnalysis.Invalid) {
              TypeTree Result = tmpAnalysis.getAnalysis(phi);
              for (auto &op : phi->incoming_values()) {
                Result &= tmpAnalysis.getAnalysis(op);
              }
              // Commit only if the PHI and all incoming values agree.
              if (Result == TypeTree(BaseType::Integer).Only(-1, phi) ||
                  Result == TypeTree(BaseType::Anything).Only(-1, phi)) {
                updateAnalysis(phi, Result, phi);
                for (auto &pair : tmpAnalysis.analysis) {
                  updateAnalysis(pair.first, pair.second, phi);
                }
                Changed = true;
              }
            }
          }

          if (direction & DOWN && phi->getType()->isFPOrFPVectorTy() &&
              !getAnalysis(phi).isKnown()) {
            // Assume that this is a floating-point value of the PHI's scalar
            // type; does that mean we can prove the incoming operands have
            // that same floating-point type?
            // NOTE(review): as above, the opening of the `tmpAnalysis`
            // declaration is missing from this copy of the file.
                                      notForAnalysis, *this, DOWN,
                                      /*PHIRecur*/ true);
            tmpAnalysis.intseen = intseen;
            tmpAnalysis.analysis = analysis;
            tmpAnalysis.analysis[phi] =
                TypeTree(phi->getType()->getScalarType()).Only(-1, phi);
            for (auto U : phi->users()) {
              if (auto I = dyn_cast<Instruction>(U)) {
                tmpAnalysis.visit(*I);
              }
            }
            tmpAnalysis.run();
            if (!tmpAnalysis.Invalid) {
              TypeTree Result = tmpAnalysis.getAnalysis(phi);
              for (auto &op : phi->incoming_values()) {
                Result &= tmpAnalysis.getAnalysis(op);
              }
              // Commit only if the PHI and all incoming values agree.
              if (Result ==
                      TypeTree(phi->getType()->getScalarType()).Only(-1, phi) ||
                  Result == TypeTree(BaseType::Anything).Only(-1, phi)) {
                updateAnalysis(phi, Result, phi);
                for (auto &pair : tmpAnalysis.analysis) {
                  updateAnalysis(pair.first, pair.second, phi);
                }
                Changed = true;
              }
            }
          }
        }
      }
    }
  } while (Changed);
  return;
}
1646
1648
  TimeTraceScope timeScope("Type Analysis", fntypeinfo.Function->getName());

  // This function runs a full round of type analysis.
  // This works by doing two stages of analysis,
  // with a "deduced integer types for unused" values
  // sandwiched in-between. This is done because we only
  // perform that check for values without types.
  //
  // For performance reasons in each round of type analysis
  // only analyze any call instances after all other potential
  // updates have been done. This is to minimize the number
  // of expensive interprocedural analyses
  std::deque<CallBase *> pendingCalls;

  // Stage 1: drain the work list. Calls to functions that have a body are
  // deferred, subject to the CustomRules comparison below. Once the work
  // list is empty (or Invalid is set), one deferred call is visited and the
  // loop repeats. Deferred calls are therefore still visited one at a time
  // even after Invalid has been set.
  do {
    while (!Invalid && workList.size()) {
      auto todo = *workList.begin();
      workList.erase(workList.begin());
      if (auto call = dyn_cast<CallBase>(todo)) {
        StringRef funcName = getFuncNameFromCall(call);
        auto ci = getFunctionFromCall(call);
        if (ci && !ci->empty()) {
          // NOTE(review): the right-hand side of this comparison and the
          // opening brace of its body are missing from this copy of the
          // file (presumably a comparison against
          // interprocedural.CustomRules.end(), i.e. "no custom rule").
          if (interprocedural.CustomRules.find(funcName) ==
            pendingCalls.push_back(call);
            continue;
          }
        }
      }
      visitValue(*todo);
    }

    if (pendingCalls.size() > 0) {
      auto todo = pendingCalls.front();
      pendingCalls.pop_front();
      visitValue(*todo);
      continue;
    } else
      break;

  } while (1);

  // NOTE(review): the in-between step described at the top of this function
  // is not present in this copy of the file. Presumably a single call was
  // made here.

  // Stage 2: same loop, but every call (with or without a body) is deferred
  // until the rest of the work list has been drained.
  do {

    while (!Invalid && workList.size()) {
      auto todo = *workList.begin();
      workList.erase(workList.begin());
      if (auto ci = dyn_cast<CallBase>(todo)) {
        pendingCalls.push_back(ci);
        continue;
      }
      visitValue(*todo);
    }

    if (pendingCalls.size() > 0) {
      auto todo = pendingCalls.front();
      pendingCalls.pop_front();
      visitValue(*todo);
      continue;
    } else
      break;

  } while (1);
}
1715
  // Dispatch one work-list item to the appropriate visitor.
  // Constant expressions get their own handler.
  if (auto CE = dyn_cast<ConstantExpr>(&val)) {
    visitConstantExpr(*CE);
  }

  // Nothing further is done for any constant (including the expression
  // handled just above).
  if (isa<Constant>(&val)) {
    return;
  }

  // Arguments pass this filter, but only instructions are handed to visit()
  // below.
  if (!isa<Argument>(&val) && !isa<Instruction>(&val))
    return;

  // fneg is typed here directly: both the operand and the result receive the
  // operand's scalar floating-point type. It never reaches visit().
  if (auto *FPMO = dyn_cast<FPMathOperator>(&val)) {
    if (FPMO->getOpcode() == Instruction::FNeg) {
      Value *op = FPMO->getOperand(0);
      auto ty = op->getType()->getScalarType();
      assert(ty->isFloatingPointTy());
      // NOTE(review): `dt` is never used.
      ConcreteType dt(ty);
      updateAnalysis(op, TypeTree(ty).Only(-1, nullptr),
                     cast<Instruction>(&val));
      updateAnalysis(FPMO, TypeTree(ty).Only(-1, nullptr),
                     cast<Instruction>(&val));
      return;
    }
  }

  if (auto inst = dyn_cast<Instruction>(&val)) {
    visit(*inst);
  }
}
1746
1747void TypeAnalyzer::visitConstantExpr(ConstantExpr &CE) {
1748 if (CE.isCast()) {
1749 if (direction & DOWN)
1750 updateAnalysis(&CE, getAnalysis(CE.getOperand(0)), &CE);
1751 if (direction & UP)
1752 updateAnalysis(CE.getOperand(0), getAnalysis(&CE), &CE);
1753 return;
1754 }
1755 if (CE.getOpcode() == Instruction::GetElementPtr) {
1756 visitGEPOperator(*cast<GEPOperator>(&CE));
1757 return;
1758 }
1759 auto I = CE.getAsInstruction();
1760 I->insertBefore(fntypeinfo.Function->getEntryBlock().getTerminator());
1761 analysis[I] = analysis[&CE];
1762 visit(*I);
1763 updateAnalysis(&CE, analysis[I], &CE);
1764 analysis.erase(I);
1765 if (workList.remove(I)) {
1766 workList.insert(&CE);
1767 }
1768 I->eraseFromParent();
1769}
1770
/// A comparison always produces an integer result. When propagating UP,
/// each operand receives the other operand's Inner0() type (with Anything
/// purged), since both operands of a compare share one LLVM type.
void TypeAnalyzer::visitCmpInst(CmpInst &cmp) {
  // No directionality check needed as always true
  updateAnalysis(&cmp, TypeTree(BaseType::Integer).Only(-1, &cmp), &cmp);
  if (direction & UP) {
    // NOTE(review): the `updateAnalysis(` that opens each of the two calls
    // below is missing from this copy of the file; only their argument lists
    // remain.
        cmp.getOperand(0),
        TypeTree(getAnalysis(cmp.getOperand(1)).Inner0().PurgeAnything())
            .Only(-1, &cmp),
        &cmp);
        cmp.getOperand(1),
        TypeTree(getAnalysis(cmp.getOperand(0)).Inner0().PurgeAnything())
            .Only(-1, &cmp),
        &cmp);
  }
}
1787
  // Alloca: the array-size operand is an integer and the result is a
  // pointer. When the element count is a constant, the pointee information
  // already known for the alloca (looked up over the allocated byte size) is
  // kept in the pointer type that gets re-applied.
  // No directionality check needed as always true
  updateAnalysis(I.getArraySize(), TypeTree(BaseType::Integer).Only(-1, &I),
                 &I);

  auto ptr = TypeTree(BaseType::Pointer);

  if (auto CI = dyn_cast<ConstantInt>(I.getArraySize())) {
    auto &DL = I.getParent()->getParent()->getParent()->getDataLayout();
    // NOTE(review): this parses as (count * (bits + 7)) / 8. The total bit
    // count is rounded, not each element, and the type's bit size is used
    // rather than its allocation size (which includes padding). Confirm
    // whether DL.getTypeAllocSize was intended.
    auto LoadSize = CI->getZExtValue() *
                    (DL.getTypeSizeInBits(I.getAllocatedType()) + 7) / 8;
    // Only propagate mappings in range that aren't "Anything" into the pointer
    // NOTE(review): despite the comment above, no PurgeAnything() is applied
    // here.
    ptr |= getAnalysis(&I).Lookup(LoadSize, DL);
  }
  updateAnalysis(&I, ptr.Only(-1, &I), &I);
}
1804
  // Load: types flow between the loaded value and the pointee of the pointer
  // operand, restricted to the loaded value's size in bytes.
  auto &DL = I.getParent()->getParent()->getParent()->getDataLayout();
  auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8;

  // UP: what is known about the loaded value (minus Anything, limited to
  // offsets [0, LoadSize)) becomes pointee information for the pointer.
  if (direction & UP) {
    // Only propagate mappings in range that aren't "Anything" into the pointer
    auto ptr = getAnalysis(&I).PurgeAnything().ShiftIndices(
        DL, /*start*/ 0, LoadSize, /*addOffset*/ 0);
    updateAnalysis(I.getOperand(0), ptr.Only(-1, &I), &I);
  }
  // DOWN: the loaded value takes the pointer's pointee type over LoadSize
  // bytes.
  if (direction & DOWN)
    updateAnalysis(&I, getAnalysis(I.getOperand(0)).Lookup(LoadSize, DL), &I);
}
1819
  // Store: types flow between the stored value and the pointee of the
  // pointer operand, restricted to the stored value's size in bytes. Both
  // updates happen only when direction includes UP; a store produces no
  // value, so there is no DOWN update.
  auto &DL = I.getParent()->getParent()->getParent()->getDataLayout();
  auto StoreSize =
      (DL.getTypeSizeInBits(I.getValueOperand()->getType()) + 7) / 8;

  // Rust specific rule, if storing an integer equal to the alignment
  // of a store, assuming nothing (or assume it is a pointer)
  // https://doc.rust-lang.org/src/core/ptr/non_null.rs.html#70-78
  // Such stores are ignored entirely: no update in either direction.
  if (RustTypeRules)
    if (auto CI = dyn_cast<ConstantInt>(I.getValueOperand())) {
      auto alignment = I.getAlign().value();

      if (CI->getLimitedValue() == alignment) {
        return;
      }
    }

  // Only propagate mappings in range that aren't "Anything" into the pointer
  auto ptr = TypeTree(BaseType::Pointer);
  auto purged = getAnalysis(I.getValueOperand())
                    .PurgeAnything()
                    .ShiftIndices(DL, /*start*/ 0, StoreSize, /*addOffset*/ 0)
                    .ReplaceMinus();
  ptr |= purged;

  if (direction & UP) {
    updateAnalysis(I.getPointerOperand(), ptr.Only(-1, &I), &I);

    // Note that we also must purge anything from ptr => value in case we store
    // to a nullptr which has type [-1, -1]: Anything. While storing to a
    // nullptr is obviously bad, this doesn't mean the value we're storing is an
    // Anything
    updateAnalysis(I.getValueOperand(),
                   getAnalysis(I.getPointerOperand())
                       .PurgeAnything()
                       .Lookup(StoreSize, DL),
                   &I);
  }
}
1859
1860// Give a list of sets representing the legal set of values at a given index
1861// return a set of all possible combinations of those values
1862template <typename T>
1863std::set<SmallVector<T, 4>> getSet(ArrayRef<std::set<T>> todo, size_t idx) {
1864 assert(idx < todo.size());
1865 std::set<SmallVector<T, 4>> out;
1866 if (idx == 0) {
1867 for (auto val : todo[0]) {
1868 out.insert({val});
1869 }
1870 return out;
1871 }
1872
1873 auto old = getSet(todo, idx - 1);
1874 for (const auto &oldv : old) {
1875 for (auto val : todo[idx]) {
1876 auto nex = oldv;
1877 nex.push_back(val);
1878 out.insert(nex);
1879 }
1880 }
1881 return out;
1882}
1883
1884void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
1885 visitGEPOperator(*cast<GEPOperator>(&gep));
1886}
1887
1888void TypeAnalyzer::visitGEPOperator(GEPOperator &gep) {
1889 auto inst = dyn_cast<Instruction>(&gep);
1890 if (isa<UndefValue>(gep.getPointerOperand())) {
1891 updateAnalysis(&gep, TypeTree(BaseType::Anything).Only(-1, inst), &gep);
1892 return;
1893 }
1894 if (isa<ConstantPointerNull>(gep.getPointerOperand())) {
1895 bool nonZero = false;
1896 bool legal = true;
1897 for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) {
1898 auto ind = I->get();
1899 if (auto CI = dyn_cast<ConstantInt>(ind)) {
1900 if (!CI->isZero()) {
1901 nonZero = true;
1902 continue;
1903 }
1904 }
1905 auto CT = getAnalysis(ind).Inner0();
1906 if (CT == BaseType::Integer) {
1907 continue;
1908 }
1909 legal = false;
1910 break;
1911 }
1912 if (legal && nonZero) {
1913 updateAnalysis(&gep, TypeTree(BaseType::Integer).Only(-1, inst), &gep);
1914 return;
1915 }
1916 }
1917 if (auto GV = dyn_cast<GlobalVariable>(gep.getPointerOperand())) {
1918 // from julia code, do not propagate int to operands
1919 if (GV->getName() == "small_typeof" || GV->getName() == "jl_small_typeof") {
1920 TypeTree T;
1921 T.insert({-1}, BaseType::Pointer);
1922 T.insert({-1, -1}, BaseType::Pointer);
1923 updateAnalysis(&gep, T, &gep);
1924 return;
1925 }
1926 }
1927
1928 if (gep.idx_begin() == gep.idx_end()) {
1929 if (direction & DOWN)
1930 updateAnalysis(&gep, getAnalysis(gep.getPointerOperand()), &gep);
1931 if (direction & UP)
1932 updateAnalysis(gep.getPointerOperand(), getAnalysis(&gep), &gep);
1933 return;
1934 }
1935
1936 auto &DL = fntypeinfo.Function->getParent()->getDataLayout();
1937
1938 auto pointerAnalysis = getAnalysis(gep.getPointerOperand());
1939
1940 // If we know that the pointer operand is indeed a pointer, then the indicies
1941 // must be integers Note that we can't do this if we don't know the pointer
1942 // operand is a pointer since doing 1[pointer] is legal
1943 // sadly this still may not work since (nullptr)[fn] => fn where fn is
1944 // pointer and not int (whereas nullptr is a pointer) However if we are
1945 // inbounds you are only allowed to have nullptr[0] or nullptr[nullptr],
1946 // making this valid
1947 // Assuming nullptr[nullptr] doesn't occur in practice, the following
1948 // is valid. We could make it always valid by checking the pointer
1949 // operand explicitly is a pointer.
1950 if (direction & UP) {
1951 bool has_non_const_idx = false;
1952 for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) {
1953 auto ind = I->get();
1954 if (!isa<ConstantInt>(ind)) {
1955 has_non_const_idx = true;
1956 break;
1957 }
1958 }
1959
1960 if (has_non_const_idx &&
1961 (gep.isInBounds() ||
1963 pointerAnalysis.Inner0() == BaseType::Pointer &&
1964 getAnalysis(&gep).Inner0() == BaseType::Pointer))) {
1965 for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) {
1966 auto ind = I->get();
1967 updateAnalysis(ind, TypeTree(BaseType::Integer).Only(-1, inst), &gep);
1968 }
1969 }
1970 }
1971
1972 // If one of these is known to be a pointer, propagate it if either in bounds
1973 // or all operands are integral/unknown
1974 bool pointerPropagate = gep.isInBounds();
1975 if (!pointerPropagate) {
1976 bool allIntegral = true;
1977 for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) {
1978 auto ind = I->get();
1979 auto CT = getAnalysis(ind).Inner0();
1980 if (CT != BaseType::Integer && CT != BaseType::Anything) {
1981 allIntegral = false;
1982 break;
1983 }
1984 }
1985 if (allIntegral)
1986 pointerPropagate = true;
1987 }
1988
1989 if (!pointerPropagate)
1990 return;
1991
1992 if (direction & DOWN) {
1993 bool legal = true;
1994 auto keepMinus = pointerAnalysis.KeepMinusOne(legal);
1995 if (!legal) {
1997 CustomErrorHandler("Could not keep minus one", wrap(&gep),
1998 ErrorType::IllegalTypeAnalysis, this, nullptr,
1999 nullptr);
2000 else {
2001 dump();
2002 llvm::errs() << " could not perform minus one for gep'd: " << gep
2003 << "\n";
2004 }
2005 }
2006 updateAnalysis(&gep, keepMinus, &gep);
2007 // Don't propagate pointer type when the input pointer is null
2008 if (!isa<ConstantPointerNull>(gep.getPointerOperand())) {
2009 updateAnalysis(&gep, TypeTree(pointerAnalysis.Inner0()).Only(-1, inst),
2010 &gep);
2011 }
2012 }
2013 if (direction & UP)
2014 updateAnalysis(gep.getPointerOperand(),
2015 TypeTree(getAnalysis(&gep).Inner0()).Only(-1, inst), &gep);
2016
2017 TypeTree upTree;
2018 TypeTree downTree;
2019
2020 TypeTree gepData0;
2021 TypeTree pointerData0;
2022 if (direction & UP)
2023 gepData0 = getAnalysis(&gep).Data0();
2024 if (direction & DOWN)
2025 pointerData0 = pointerAnalysis.Data0();
2026
2027 auto BitWidth = DL.getIndexSizeInBits(gep.getPointerAddressSpace());
2028
2029 APInt constOffset(BitWidth, 0);
2030
2031#if LLVM_VERSION_MAJOR >= 20
2032 SmallMapVector<Value *, APInt, 4> VariableOffsets;
2033#else
2034 MapVector<Value *, APInt> VariableOffsets;
2035#endif
2036 bool legalOffset =
2037 collectOffset(&gep, DL, BitWidth, VariableOffsets, constOffset);
2038 (void)legalOffset;
2039 assert(legalOffset);
2040
2041 SmallVector<std::set<int>, 4> idnext;
2042
2043 SmallPtrSet<BasicBlock *, 1> previousLoopInductionHeaders;
2044 {
2045 Value *ptr = gep.getPointerOperand();
2046 while (true) {
2047 if (auto gepop = dyn_cast<GEPOperator>(ptr)) {
2048 for (auto I = gepop->idx_begin(), E = gepop->idx_end(); I != E; I++) {
2049 SmallPtrSet<PHINode *, 1> seen;
2050 for (auto loopInd : findLoopIndices(*I, LI, DT, seen)) {
2051 previousLoopInductionHeaders.insert(loopInd);
2052 }
2053 }
2054 ptr = gepop->getPointerOperand();
2055 continue;
2056 }
2057 if (auto CI = dyn_cast<CastInst>(ptr)) {
2058 ptr = CI->getOperand(0);
2059 continue;
2060 }
2061 break;
2062 }
2063 }
2064
2065 for (auto &pair : VariableOffsets) {
2066 auto a = pair.first;
2067 auto iset = fntypeinfo.knownIntegralValues(a, DT, intseen, SE);
2068 std::set<int> vset;
2069 for (auto i : iset) {
2070 // Don't consider negative indices of gep
2071 if (i < 0)
2072 continue;
2073 vset.insert(i);
2074 }
2075 if (vset.size() == 0)
2076 return;
2077
2078 // If seen the same variable before with > 1 option, we will accidentally
2079 // do an offset for [option1, option2] * oldOffset + [option1, option2] *
2080 // newOffset
2081 // instead of [option1, option2] * (oldOffset + newOffset).
2082 // In this case abort
2083 // TODO, in the future, mutually compute the offset together.
2084 if (vset.size() != 1) {
2085 SmallPtrSet<PHINode *, 1> seen;
2086 for (auto loopInd : findLoopIndices(pair.first, LI, DT, seen))
2087 if (previousLoopInductionHeaders.count(loopInd))
2088 return;
2089 }
2090 idnext.push_back(vset);
2091 }
2092
2093 // Stores pair ([whether first offset is zero], offset)
2094 std::vector<std::pair<bool, int>> offsets;
2095 Value *firstIdx = *gep.idx_begin();
2096 if (VariableOffsets.size() == 0) {
2097 bool firstIsZero = cast<ConstantInt>(firstIdx)->getLimitedValue() == 0;
2098 offsets.emplace_back(firstIsZero, (int)constOffset.getLimitedValue());
2099 } else {
2100 bool firstIsZero = false;
2101 if (auto CI = dyn_cast<ConstantInt>(firstIdx))
2102 firstIsZero = CI->getLimitedValue() == 0;
2103 for (auto vec : getSet<int>(idnext, idnext.size() - 1)) {
2104 APInt nextOffset = constOffset;
2105 for (auto [varpair, const_value] : llvm::zip(VariableOffsets, vec)) {
2106 nextOffset += varpair.second * const_value;
2107 if (varpair.first == firstIdx)
2108 firstIsZero = const_value == 0;
2109 }
2110 offsets.emplace_back(firstIsZero, (int)nextOffset.getLimitedValue());
2111 }
2112 }
2113
2114 bool seenIdx = false;
2115
2116 for (auto [firstIsZero, off] : offsets) {
2117 // TODO also allow negative offsets
2118 if (off < 0)
2119 continue;
2120
2121 int maxSize = -1;
2122 if (firstIsZero) {
2123 maxSize = DL.getTypeAllocSizeInBits(gep.getResultElementType()) / 8;
2124 }
2125
2126 if (direction & DOWN) {
2127 auto shft =
2128 pointerData0.ShiftIndices(DL, /*init offset*/ off,
2129 /*max size*/ maxSize, /*newoffset*/ 0);
2130 if (seenIdx)
2131 downTree &= shft;
2132 else
2133 downTree = shft;
2134 }
2135
2136 if (direction & UP) {
2137 auto shft = gepData0.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1,
2138 /*new offset*/ off);
2139 if (seenIdx)
2140 upTree |= shft;
2141 else
2142 upTree = shft;
2143 }
2144 seenIdx = true;
2145 }
2146 if (direction & DOWN)
2147 updateAnalysis(&gep, downTree.Only(-1, inst), &gep);
2148 if (direction & UP)
2149 updateAnalysis(gep.getPointerOperand(), upTree.Only(-1, inst), &gep);
2150}
2151
2152void TypeAnalyzer::visitPHINode(PHINode &phi) {
2153 if (direction & UP) {
2154 TypeTree upVal = getAnalysis(&phi);
2155 // only propagate anything's up if there is one
2156 // incoming value
2157 Value *seen = phi.getIncomingValue(0);
2158 for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) {
2159 if (seen != phi.getIncomingValue(i)) {
2160 seen = nullptr;
2161 break;
2162 }
2163 }
2164
2165 if (!seen) {
2166 upVal = upVal.PurgeAnything();
2167 }
2168
2169 if (EnzymeStrictAliasing || seen) {
2170 auto L = LI.getLoopFor(phi.getParent());
2171 bool isHeader = L && L->getHeader() == phi.getParent();
2172 for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) {
2173 if (!isHeader || !L->contains(phi.getIncomingBlock(i))) {
2174 updateAnalysis(phi.getIncomingValue(i), upVal, &phi);
2175 }
2176 }
2177 } else {
2178 if (EnzymePrintType) {
2179 for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) {
2180 llvm::errs() << " skipping update into ";
2181 phi.getIncomingValue(i)->print(llvm::errs(), *MST);
2182 llvm::errs() << " of " << upVal.str() << " from ";
2183 phi.print(llvm::errs(), *MST);
2184 llvm::errs() << "\n";
2185 }
2186 }
2187 }
2188 }
2189
2190 assert(phi.getNumIncomingValues() > 0);
2191
2192 // TODO generalize this (and for recursive, etc)
2193
2194 for (int i = 0; i < 2; i++) {
2195
2196 std::deque<Value *> vals;
2197 std::set<Value *> seen{&phi};
2198 for (auto &op : phi.incoming_values()) {
2199 vals.push_back(op);
2200 }
2201 SmallVector<BinaryOperator *, 4> bos;
2202
2203 // Unique values that propagate into this phi
2204 SmallVector<Value *, 4> UniqueValues;
2205
2206 while (vals.size()) {
2207 Value *todo = vals.front();
2208 vals.pop_front();
2209
2210 if (auto bo = dyn_cast<BinaryOperator>(todo)) {
2211 if (bo->getOpcode() == BinaryOperator::Add) {
2212 if (isa<Constant>(bo->getOperand(0))) {
2213 bos.push_back(bo);
2214 todo = bo->getOperand(1);
2215 }
2216 if (isa<Constant>(bo->getOperand(1))) {
2217 bos.push_back(bo);
2218 todo = bo->getOperand(0);
2219 }
2220 }
2221 }
2222
2223 if (seen.count(todo))
2224 continue;
2225 seen.insert(todo);
2226
2227 if (auto nphi = dyn_cast<PHINode>(todo)) {
2228 if (i == 0) {
2229 for (auto &op : nphi->incoming_values()) {
2230 vals.push_back(op);
2231 }
2232 continue;
2233 }
2234 }
2235 if (auto sel = dyn_cast<SelectInst>(todo)) {
2236 vals.push_back(sel->getOperand(1));
2237 vals.push_back(sel->getOperand(2));
2238 continue;
2239 }
2240 UniqueValues.push_back(todo);
2241 }
2242
2243 TypeTree PhiTypes;
2244 bool set = false;
2245
2246 for (size_t i = 0, size = UniqueValues.size(); i < size; ++i) {
2247 TypeTree newData = getAnalysis(UniqueValues[i]);
2248 if (UniqueValues.size() == 2) {
2249 if (auto BO = dyn_cast<BinaryOperator>(UniqueValues[i])) {
2250 if (BO->getOpcode() == BinaryOperator::Add ||
2251 BO->getOpcode() == BinaryOperator::Mul) {
2252 TypeTree otherData = getAnalysis(UniqueValues[1 - i]);
2253 // If we are adding/muling to a constant to derive this, we can
2254 // assume it to be an integer rather than Anything
2255 if (isa<Constant>(UniqueValues[1 - i])) {
2256 otherData = TypeTree(BaseType::Integer).Only(-1, &phi);
2257 }
2258 if (BO->getOperand(0) == &phi) {
2259 set = true;
2260 PhiTypes = otherData;
2261 bool Legal = true;
2262 PhiTypes.binopIn(Legal, getAnalysis(BO->getOperand(1)),
2263 BO->getOpcode());
2264 if (!Legal) {
2265 std::string str;
2266 raw_string_ostream ss(str);
2267 if (!CustomErrorHandler) {
2268 llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
2269 llvm::errs() << *fntypeinfo.Function << "\n";
2270 dump(ss);
2271 }
2272 ss << "Illegal updateBinop Analysis " << *BO << "\n";
2273 ss << "Illegal binopIn(0): " << *BO
2274 << " lhs: " << PhiTypes.str()
2275 << " rhs: " << getAnalysis(BO->getOperand(0)).str() << "\n";
2276 if (CustomErrorHandler) {
2277 CustomErrorHandler(str.c_str(), wrap(BO),
2279 (void *)this, wrap(BO), nullptr);
2280 }
2281 EmitFailure("IllegalUpdateAnalysis", BO->getDebugLoc(), BO,
2282 ss.str());
2283 report_fatal_error("Performed illegal updateAnalysis");
2284 }
2285 break;
2286 } else if (BO->getOperand(1) == &phi) {
2287 set = true;
2288 PhiTypes = getAnalysis(BO->getOperand(0));
2289 bool Legal = true;
2290 PhiTypes.binopIn(Legal, otherData, BO->getOpcode());
2291 if (!Legal) {
2292 std::string str;
2293 raw_string_ostream ss(str);
2294 if (!CustomErrorHandler) {
2295 llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
2296 llvm::errs() << *fntypeinfo.Function << "\n";
2297 dump(ss);
2298 }
2299 ss << "Illegal updateBinop Analysis " << *BO << "\n";
2300 ss << "Illegal binopIn(1): " << *BO
2301 << " lhs: " << PhiTypes.str() << " rhs: " << otherData.str()
2302 << "\n";
2303 if (CustomErrorHandler) {
2304 CustomErrorHandler(str.c_str(), wrap(BO),
2306 (void *)this, wrap(BO), nullptr);
2307 }
2308 EmitFailure("IllegalUpdateAnalysis", BO->getDebugLoc(), BO,
2309 ss.str());
2310 report_fatal_error("Performed illegal updateAnalysis");
2311 }
2312 break;
2313 }
2314 } else if (BO->getOpcode() == BinaryOperator::Sub) {
2315 // Repeated subtraction from a type X yields the type X back
2316 TypeTree otherData = getAnalysis(UniqueValues[1 - i]);
2317 // If we are subtracting from a constant to derive this, we can
2318 // assume it to be an integer rather than Anything
2319 if (isa<Constant>(UniqueValues[1 - i])) {
2320 otherData = TypeTree(BaseType::Integer).Only(-1, &phi);
2321 }
2322 if (BO->getOperand(0) == &phi) {
2323 set = true;
2324 PhiTypes = otherData;
2325 break;
2326 }
2327 }
2328 }
2329 }
2330 if (set) {
2331 PhiTypes &= newData;
2332 // TODO consider the or of anything (see selectinst)
2333 // however, this cannot be done yet for risk of turning
2334 // phi's that add floats into anything
2335 // PhiTypes |= newData.JustAnything();
2336 } else {
2337 set = true;
2338 PhiTypes = newData;
2339 }
2340 }
2341
2342 assert(set);
2343 // If we are only add / sub / etc to derive a value based off 0
2344 // we can start by assuming the type of 0 is integer rather
2345 // than assuming it could be anything (per null)
2346 if (bos.size() > 0 && UniqueValues.size() == 1 &&
2347 isa<ConstantInt>(UniqueValues[0]) &&
2348 (cast<ConstantInt>(UniqueValues[0])->isZero() ||
2349 cast<ConstantInt>(UniqueValues[0])->isOne())) {
2350 PhiTypes = TypeTree(BaseType::Integer).Only(-1, &phi);
2351 }
2352 for (BinaryOperator *bo : bos) {
2353 TypeTree vd1 = isa<Constant>(bo->getOperand(0))
2354 ? getAnalysis(bo->getOperand(0)).Data0()
2355 : PhiTypes.Data0();
2356 TypeTree vd2 = isa<Constant>(bo->getOperand(1))
2357 ? getAnalysis(bo->getOperand(1)).Data0()
2358 : PhiTypes.Data0();
2359 bool Legal = true;
2360 vd1.binopIn(Legal, vd2, bo->getOpcode());
2361 if (!Legal) {
2362 std::string str;
2363 raw_string_ostream ss(str);
2364 if (!CustomErrorHandler) {
2365 llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
2366 llvm::errs() << *fntypeinfo.Function << "\n";
2367 dump(ss);
2368 }
2369 ss << "Illegal updateBinop Analysis " << *bo << "\n";
2370 ss << "Illegal binopIn(consts): " << *bo << " lhs: " << vd1.str()
2371 << " rhs: " << vd2.str() << "\n";
2372 if (CustomErrorHandler) {
2373 CustomErrorHandler(str.c_str(), wrap(bo),
2374 ErrorType::IllegalTypeAnalysis, (void *)this,
2375 wrap(bo), nullptr);
2376 }
2377 EmitFailure("IllegalUpdateAnalysis", bo->getDebugLoc(), bo, ss.str());
2378 report_fatal_error("Performed illegal updateAnalysis");
2379 }
2380 PhiTypes &= vd1.Only(bo->getType()->isIntegerTy() ? -1 : 0, &phi);
2381 }
2382
2383 if (direction & DOWN) {
2384 if (phi.getType()->isIntOrIntVectorTy() &&
2385 PhiTypes.Inner0() == BaseType::Anything) {
2386 if (mustRemainInteger(&phi)) {
2387 PhiTypes = TypeTree(BaseType::Integer).Only(-1, &phi);
2388 }
2389 }
2390 updateAnalysis(&phi, PhiTypes, &phi);
2391 }
2392 }
2393}
2394
2396 auto &DL = fntypeinfo.Function->getParent()->getDataLayout();
2397 size_t inSize = (DL.getTypeSizeInBits(I.getOperand(0)->getType()) + 7) / 8;
2398 size_t outSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8;
2399 if (direction & DOWN)
2400 if (outSize != 1)
2401 updateAnalysis(&I,
2402 getAnalysis(I.getOperand(0))
2403 .ShiftIndices(DL, /*off*/ 0, inSize, /*addOffset*/ 0)
2404 .ShiftIndices(DL, /*off*/ 0, outSize, /*addOffset*/ 0),
2405 &I);
2406 // Don't propagate up a trunc float -> i8
2407 if (direction & UP)
2408 if (outSize != 1 || inSize == 1)
2410 I.getOperand(0),
2411 getAnalysis(&I).ShiftIndices(DL, /*off*/ 0, outSize, /*addOffset*/ 0),
2412 &I);
2413}
2414
2416 if (direction & DOWN) {
2417 TypeTree Result;
2418 if (cast<IntegerType>(I.getOperand(0)->getType()->getScalarType())
2419 ->getBitWidth() == 1) {
2420 Result = TypeTree(BaseType::Anything).Only(-1, &I);
2421 } else {
2422 Result = getAnalysis(I.getOperand(0));
2423 }
2424
2425 if (I.getType()->isIntOrIntVectorTy() &&
2426 Result.Inner0() == BaseType::Anything) {
2427 if (mustRemainInteger(&I)) {
2428 Result = TypeTree(BaseType::Integer).Only(-1, &I);
2429 }
2430 }
2431 updateAnalysis(&I, Result, &I);
2432 }
2433 if (direction & UP) {
2434 updateAnalysis(I.getOperand(0), getAnalysis(&I), &I);
2435 }
2436}
2437
2439 // This is only legal on integer types [not pointers per sign]
2440 // nor floatings points. Likewise, there's no direction check
2441 // necessary since this is always valid.
2442 updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I);
2443 updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I), &I);
2444}
2445
2446void TypeAnalyzer::visitAddrSpaceCastInst(AddrSpaceCastInst &I) {
2447 if (direction & DOWN)
2448 updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I);
2449 if (direction & UP)
2450 updateAnalysis(I.getOperand(0), getAnalysis(&I), &I);
2451}
2452
2454 // No direction check as always true
2456 &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I),
2457 &I);
2459 I.getOperand(0),
2460 TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType()))
2461 .Only(-1, &I),
2462 &I);
2463}
2464
2466 // No direction check as always true
2468 &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I),
2469 &I);
2471 I.getOperand(0),
2472 TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType()))
2473 .Only(-1, &I),
2474 &I);
2475}
2476
2478 // No direction check as always true
2479 updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I);
2481 I.getOperand(0),
2482 TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType()))
2483 .Only(-1, &I),
2484 &I);
2485}
2486
2488 // No direction check as always true
2489 updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I);
2491 I.getOperand(0),
2492 TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType()))
2493 .Only(-1, &I),
2494 &I);
2495}
2496
2498 // No direction check as always true
2499 updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I), &I);
2501 &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I),
2502 &I);
2503}
2504
2506 // No direction check as always true
2507 updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I), &I);
2509 &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I),
2510 &I);
2511}
2512
2513void TypeAnalyzer::visitPtrToIntInst(PtrToIntInst &I) {
2514 // Note it is illegal to assume here that either is a pointer or an int
2515 if (direction & DOWN)
2516 updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I);
2517 if (direction & UP)
2518 updateAnalysis(I.getOperand(0), getAnalysis(&I), &I);
2519}
2520
2521void TypeAnalyzer::visitIntToPtrInst(IntToPtrInst &I) {
2522 // Note it is illegal to assume here that either is a pointer or an int
2523 if (direction & DOWN) {
2524 if (isa<ConstantInt>(I.getOperand(0))) {
2525 updateAnalysis(&I, TypeTree(BaseType::Anything).Only(-1, &I), &I);
2526 } else {
2527 updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I);
2528 }
2529 }
2530 if (direction & UP)
2531 updateAnalysis(I.getOperand(0), getAnalysis(&I), &I);
2532}
2533
2534void TypeAnalyzer::visitFreezeInst(FreezeInst &I) {
2535 if (direction & DOWN)
2536 updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I);
2537 if (direction & UP)
2538 updateAnalysis(I.getOperand(0), getAnalysis(&I), &I);
2539}
2540
2542 if (direction & DOWN)
2543 updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I);
2544 if (direction & UP)
2545 updateAnalysis(I.getOperand(0), getAnalysis(&I), &I);
2546}
2547
2549 if (direction & UP) {
2550 auto Data = getAnalysis(&I).PurgeAnything();
2551 if (EnzymeStrictAliasing || (I.getTrueValue() == I.getFalseValue())) {
2552 updateAnalysis(I.getTrueValue(), Data, &I);
2553 updateAnalysis(I.getFalseValue(), Data, &I);
2554 } else {
2555 if (EnzymePrintType) {
2556 llvm::errs() << " skipping update into ";
2557 I.getTrueValue()->print(llvm::errs(), *MST);
2558 llvm::errs() << " of " << Data.str() << " from ";
2559 I.print(llvm::errs(), *MST);
2560 llvm::errs() << "\n";
2561 llvm::errs() << " skipping update into ";
2562 I.getFalseValue()->print(llvm::errs(), *MST);
2563 llvm::errs() << " of " << Data.str() << " from ";
2564 I.print(llvm::errs(), *MST);
2565 llvm::errs() << "\n";
2566 }
2567 }
2568 }
2569 if (direction & DOWN) {
2570 // special case for min/max result is still that operand [even if something
2571 // is 0]
2572 if (auto cmpI = dyn_cast<CmpInst>(I.getCondition())) {
2573 // is relational equiv to not is equality
2574 if (!cmpI->isEquality())
2575 if ((cmpI->getOperand(0) == I.getTrueValue() &&
2576 cmpI->getOperand(1) == I.getFalseValue()) ||
2577 (cmpI->getOperand(1) == I.getTrueValue() &&
2578 cmpI->getOperand(0) == I.getFalseValue())) {
2579 auto vd = getAnalysis(I.getTrueValue()).Inner0();
2580 vd &= getAnalysis(I.getFalseValue()).Inner0();
2581 if (vd.isKnown()) {
2582 updateAnalysis(&I, TypeTree(vd).Only(-1, &I), &I);
2583 return;
2584 }
2585 }
2586 }
2587 // If getTrueValue and getFalseValue are the same type (per the and)
2588 // it is safe to assume the result is as well
2589 TypeTree vd = getAnalysis(I.getTrueValue()).PurgeAnything();
2590 vd &= getAnalysis(I.getFalseValue()).PurgeAnything();
2591
2592 // A regular and operation, however is not sufficient. One of the operands
2593 // could be anything whereas the other is concrete, resulting in the
2594 // concrete type (e.g. select true, anything(0), integer(i64)) This is not
2595 // correct as the result of the select could always be anything (e.g. if it
2596 // is a pointer). As a result, explicitly or in any anything values
2597 // TODO this should be propagated elsewhere as well (specifically returns,
2598 // phi)
2599 TypeTree any = getAnalysis(I.getTrueValue()).JustAnything();
2600 any &= getAnalysis(I.getFalseValue()).JustAnything();
2601 vd |= any;
2602 updateAnalysis(&I, vd, &I);
2603 }
2604}
2605
2606void TypeAnalyzer::visitExtractElementInst(ExtractElementInst &I) {
2607 updateAnalysis(I.getIndexOperand(), BaseType::Integer, &I);
2608
2609 auto &dl = fntypeinfo.Function->getParent()->getDataLayout();
2610 VectorType *vecType = cast<VectorType>(I.getVectorOperand()->getType());
2611
2612 size_t bitsize = dl.getTypeSizeInBits(vecType->getElementType());
2613 size_t size = (bitsize + 7) / 8;
2614
2615 if (auto CI = dyn_cast<ConstantInt>(I.getIndexOperand())) {
2616 size_t off = (CI->getZExtValue() * bitsize) / 8;
2617
2618 if (direction & DOWN)
2619 updateAnalysis(&I,
2620 getAnalysis(I.getVectorOperand())
2621 .ShiftIndices(dl, off, size, /*addOffset*/ 0),
2622 &I);
2623
2624 if (direction & UP)
2625 updateAnalysis(I.getVectorOperand(),
2626 getAnalysis(&I).ShiftIndices(dl, 0, size, off), &I);
2627
2628 } else {
2629 if (direction & DOWN) {
2630 TypeTree vecAnalysis = getAnalysis(I.getVectorOperand());
2631 // TODO merge of anythings (see selectinst)
2632 TypeTree res = vecAnalysis.Lookup(size, dl);
2633 updateAnalysis(&I, res.Only(-1, &I), &I);
2634 }
2635 if (direction & UP) {
2636 // propagated upward to unknown location, no analysis
2637 // can be updated
2638 }
2639 }
2640}
2641
2642void TypeAnalyzer::visitInsertElementInst(InsertElementInst &I) {
2643 updateAnalysis(I.getOperand(2), TypeTree(BaseType::Integer).Only(-1, &I), &I);
2644
2645 auto &dl = fntypeinfo.Function->getParent()->getDataLayout();
2646 VectorType *vecType = cast<VectorType>(I.getOperand(0)->getType());
2647 if (vecType->getElementType()->isIntegerTy(1)) {
2648 if (direction & UP) {
2649 updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I),
2650 &I);
2651 updateAnalysis(I.getOperand(1), TypeTree(BaseType::Integer).Only(-1, &I),
2652 &I);
2653 }
2654 if (direction & DOWN) {
2655 updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I);
2656 }
2657 return;
2658 }
2659#if LLVM_VERSION_MAJOR >= 12
2660 assert(!vecType->getElementCount().isScalable());
2661 size_t numElems = vecType->getElementCount().getKnownMinValue();
2662#else
2663 size_t numElems = vecType->getNumElements();
2664#endif
2665 size_t size = (dl.getTypeSizeInBits(vecType->getElementType()) + 7) / 8;
2666 size_t vecSize = (dl.getTypeSizeInBits(vecType) + 7) / 8;
2667
2668 if (auto CI = dyn_cast<ConstantInt>(I.getOperand(2))) {
2669 size_t off = CI->getZExtValue() * size;
2670
2671 if (direction & UP)
2672 updateAnalysis(I.getOperand(0),
2673 getAnalysis(&I).Clear(off, off + size, vecSize), &I);
2674
2675 if (direction & UP)
2676 updateAnalysis(I.getOperand(1),
2677 getAnalysis(&I).ShiftIndices(dl, off, size, 0), &I);
2678
2679 if (direction & DOWN) {
2680 auto new_res =
2681 getAnalysis(I.getOperand(0)).Clear(off, off + size, vecSize);
2682 auto shifted =
2683 getAnalysis(I.getOperand(1)).ShiftIndices(dl, 0, size, off);
2684 new_res |= shifted;
2685 updateAnalysis(&I, new_res, &I);
2686 }
2687 } else {
2688 if (direction & DOWN) {
2689 auto new_res = getAnalysis(I.getOperand(0));
2690 auto inserted = getAnalysis(I.getOperand(1));
2691 // TODO merge of anythings (see selectinst)
2692 for (size_t i = 0; i < numElems; ++i)
2693 new_res &= inserted.ShiftIndices(dl, 0, size, size * i);
2694 updateAnalysis(&I, new_res, &I);
2695 }
2696 }
2697}
2698
2699void TypeAnalyzer::visitShuffleVectorInst(ShuffleVectorInst &I) {
2700 // See selectinst type propagation rule for a description
2701 // of the ncessity and correctness of this rule.
2702 VectorType *resType = cast<VectorType>(I.getType());
2703
2704 auto &dl = fntypeinfo.Function->getParent()->getDataLayout();
2705
2706 const size_t lhs = 0;
2707 const size_t rhs = 1;
2708
2709#if LLVM_VERSION_MAJOR >= 12
2710 assert(!cast<VectorType>(I.getOperand(lhs)->getType())
2711 ->getElementCount()
2712 .isScalable());
2713 size_t numFirst = cast<VectorType>(I.getOperand(lhs)->getType())
2714 ->getElementCount()
2715 .getKnownMinValue();
2716#else
2717 size_t numFirst =
2718 cast<VectorType>(I.getOperand(lhs)->getType())->getNumElements();
2719#endif
2720 size_t size = (dl.getTypeSizeInBits(resType->getElementType()) + 7) / 8;
2721
2722 auto mask = I.getShuffleMask();
2723
2724 TypeTree result; // = getAnalysis(&I);
2725 for (size_t i = 0; i < mask.size(); ++i) {
2726 int newOff;
2727 {
2728 Value *vec[2] = {ConstantInt::get(Type::getInt64Ty(I.getContext()), 0),
2729 ConstantInt::get(Type::getInt64Ty(I.getContext()), i)};
2730 auto ud = UndefValue::get(getUnqual(I.getOperand(0)->getType()));
2731 auto g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec);
2732 APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
2733 g2->accumulateConstantOffset(dl, ai);
2734 // Using destructor rather than eraseFromParent
2735 // as g2 has no parent
2736 delete g2;
2737 newOff = (int)ai.getLimitedValue();
2738 // there is a bug in LLVM, this is the correct offset
2739 if (cast<VectorType>(I.getOperand(lhs)->getType())
2740 ->getElementType()
2741 ->isIntegerTy(1)) {
2742 newOff = i / 8;
2743 }
2744 }
2745#if LLVM_VERSION_MAJOR > 16
2746 if (mask[i] == PoisonMaskElem)
2747#elif LLVM_VERSION_MAJOR >= 12
2748 if (mask[i] == UndefMaskElem)
2749#else
2750 if (mask[i] == -1)
2751#endif
2752 {
2753 if (direction & DOWN) {
2754 result |= TypeTree(BaseType::Anything)
2755 .Only(-1, &I)
2756 .ShiftIndices(dl, 0, size, newOff);
2757 }
2758 } else {
2759 if ((size_t)mask[i] < numFirst) {
2760 Value *vec[2] = {
2761 ConstantInt::get(Type::getInt64Ty(I.getContext()), 0),
2762 ConstantInt::get(Type::getInt64Ty(I.getContext()), mask[i])};
2763 auto ud = UndefValue::get(getUnqual(I.getOperand(0)->getType()));
2764 auto g2 =
2765 GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec);
2766 APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
2767 g2->accumulateConstantOffset(dl, ai);
2768 // Using destructor rather than eraseFromParent
2769 // as g2 has no parent
2770 int oldOff = (int)ai.getLimitedValue();
2771 // there is a bug in LLVM, this is the correct offset
2772 if (cast<VectorType>(I.getOperand(lhs)->getType())
2773 ->getElementType()
2774 ->isIntegerTy(1)) {
2775 oldOff = mask[i] / 8;
2776 }
2777 delete g2;
2778 if (direction & UP) {
2779 updateAnalysis(I.getOperand(lhs),
2780 getAnalysis(&I).ShiftIndices(dl, newOff, size, oldOff),
2781 &I);
2782 }
2783 if (direction & DOWN) {
2784 result |= getAnalysis(I.getOperand(lhs))
2785 .ShiftIndices(dl, oldOff, size, newOff);
2786 }
2787 } else {
2788 Value *vec[2] = {ConstantInt::get(Type::getInt64Ty(I.getContext()), 0),
2789 ConstantInt::get(Type::getInt64Ty(I.getContext()),
2790 mask[i] - numFirst)};
2791 auto ud = UndefValue::get(getUnqual(I.getOperand(0)->getType()));
2792 auto g2 =
2793 GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec);
2794 APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
2795 g2->accumulateConstantOffset(dl, ai);
2796 // Using destructor rather than eraseFromParent
2797 // as g2 has no parent
2798 int oldOff = (int)ai.getLimitedValue();
2799 // there is a bug in LLVM, this is the correct offset
2800 if (cast<VectorType>(I.getOperand(lhs)->getType())
2801 ->getElementType()
2802 ->isIntegerTy(1)) {
2803 oldOff = (mask[i] - numFirst) / 8;
2804 }
2805 delete g2;
2806 if (direction & UP) {
2807 updateAnalysis(I.getOperand(rhs),
2808 getAnalysis(&I).ShiftIndices(dl, newOff, size, oldOff),
2809 &I);
2810 }
2811 if (direction & DOWN) {
2812 result |= getAnalysis(I.getOperand(rhs))
2813 .ShiftIndices(dl, oldOff, size, newOff);
2814 }
2815 }
2816 }
2817 }
2818
2819 if (direction & DOWN) {
2820 updateAnalysis(&I, result, &I);
2821 }
2822}
2823
2824void TypeAnalyzer::visitExtractValueInst(ExtractValueInst &I) {
2825 auto &dl = fntypeinfo.Function->getParent()->getDataLayout();
2826 SmallVector<Value *, 4> vec;
2827 vec.push_back(ConstantInt::get(Type::getInt64Ty(I.getContext()), 0));
2828 for (auto ind : I.indices()) {
2829 vec.push_back(ConstantInt::get(Type::getInt32Ty(I.getContext()), ind));
2830 }
2831 auto ud = UndefValue::get(getUnqual(I.getOperand(0)->getType()));
2832 auto g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec);
2833 APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
2834 g2->accumulateConstantOffset(dl, ai);
2835 // Using destructor rather than eraseFromParent
2836 // as g2 has no parent
2837 delete g2;
2838
2839 int off = (int)ai.getLimitedValue();
2840 int size = dl.getTypeSizeInBits(I.getType()) / 8;
2841
2842 if (direction & DOWN)
2843 updateAnalysis(&I,
2844 getAnalysis(I.getOperand(0))
2845 .ShiftIndices(dl, off, size, /*addOffset*/ 0),
2846 &I);
2847
2848 if (direction & UP)
2849 updateAnalysis(I.getOperand(0),
2850 getAnalysis(&I).ShiftIndices(dl, 0, size, off), &I);
2851}
2852
2853void TypeAnalyzer::visitInsertValueInst(InsertValueInst &I) {
2854 auto &dl = fntypeinfo.Function->getParent()->getDataLayout();
2855 SmallVector<Value *, 4> vec = {
2856 ConstantInt::get(Type::getInt64Ty(I.getContext()), 0)};
2857 for (auto ind : I.indices()) {
2858 vec.push_back(ConstantInt::get(Type::getInt32Ty(I.getContext()), ind));
2859 }
2860 auto ud = UndefValue::get(getUnqual(I.getOperand(0)->getType()));
2861 auto g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec);
2862 APInt ai(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
2863 g2->accumulateConstantOffset(dl, ai);
2864 delete g2;
2865 // Using destructor rather than eraseFromParent
2866 // as g2 has no parent
2867
2868 // Compute the offset at the next logical element [e.g. adding 1 to the last
2869 // index, carrying the value on overflow]
2870 for (ssize_t i = vec.size() - 1; i >= 0; i--) {
2871 auto CI = cast<ConstantInt>(vec[i]);
2872 auto val = CI->getZExtValue();
2873 if (i == 0) {
2874 vec[i] = ConstantInt::get(CI->getType(), val + 1);
2875 break;
2876 }
2877 auto subTy = GetElementPtrInst::getIndexedType(
2878 I.getOperand(0)->getType(), ArrayRef<Value *>(vec).slice(0, i));
2879 if (auto ST = dyn_cast<StructType>(subTy)) {
2880 if (val + 1 == ST->getNumElements()) {
2881 vec.erase(vec.begin() + i, vec.end());
2882 continue;
2883 }
2884 vec[i] = ConstantInt::get(CI->getType(), val + 1);
2885 break;
2886 } else {
2887 auto AT = cast<ArrayType>(subTy);
2888 if (val + 1 == AT->getNumElements()) {
2889 vec.erase(vec.begin() + i, vec.end());
2890 continue;
2891 }
2892 vec[i] = ConstantInt::get(CI->getType(), val + 1);
2893 break;
2894 }
2895 }
2896 g2 = GetElementPtrInst::Create(I.getOperand(0)->getType(), ud, vec);
2897 APInt aiend(dl.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
2898 g2->accumulateConstantOffset(dl, aiend);
2899 delete g2;
2900
2901 int off = (int)ai.getLimitedValue();
2902
2903 int agg_size = (dl.getTypeSizeInBits(I.getType()) + 7) / 8;
2904 int ins_size = (int)(aiend - ai).getLimitedValue();
2905 int ins2_size =
2906 (dl.getTypeSizeInBits(I.getInsertedValueOperand()->getType()) + 7) / 8;
2907
2908 if (direction & UP)
2909 updateAnalysis(I.getAggregateOperand(),
2910 getAnalysis(&I).Clear(off, off + ins_size, agg_size), &I);
2911 if (direction & UP)
2912 updateAnalysis(I.getInsertedValueOperand(),
2913 getAnalysis(&I).ShiftIndices(dl, off, ins2_size, 0), &I);
2914 auto new_res =
2915 getAnalysis(I.getAggregateOperand()).Clear(off, off + ins_size, agg_size);
2916 auto shifted = getAnalysis(I.getInsertedValueOperand())
2917 .ShiftIndices(dl, 0, ins_size, off);
2918 new_res |= shifted;
2919 if (direction & DOWN)
2920 updateAnalysis(&I, new_res, &I);
2921}
2922
2923void TypeAnalyzer::dump(llvm::raw_ostream &ss) {
2924 ss << "<analysis>\n";
2925 // We don't care about correct MD node numbering here.
2926 ModuleSlotTracker MST(fntypeinfo.Function->getParent(),
2927 /*ShouldInitializeAllMetadata*/ false);
2928 for (auto &pair : analysis) {
2929 if (auto F = dyn_cast<Function>(pair.first))
2930 ss << "@" << F->getName();
2931 else
2932 pair.first->print(ss, MST);
2933 ss << ": " << pair.second.str()
2934 << ", intvals: " << to_string(knownIntegralValues(pair.first)) << "\n";
2935 }
2936 ss << "</analysis>\n";
2937}
2938
2939void TypeAnalyzer::visitAtomicRMWInst(llvm::AtomicRMWInst &I) {
2940 Value *Args[2] = {nullptr, I.getOperand(1)};
2941 TypeTree Ret = getAnalysis(&I);
2942 auto &DL = I.getParent()->getParent()->getParent()->getDataLayout();
2943 auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8;
2944 TypeTree LHS = getAnalysis(I.getOperand(0)).Lookup(LoadSize, DL);
2945 TypeTree RHS = getAnalysis(I.getOperand(1));
2946
2947 switch (I.getOperation()) {
2948 case AtomicRMWInst::Xchg: {
2949 auto tmp = LHS;
2950 LHS = RHS;
2951 RHS = tmp;
2952 bool Legal = true;
2953 LHS.checkedOrIn(Ret, /*PointerIntSame*/ false, Legal);
2954 if (!Legal) {
2955 dump();
2956 llvm::errs() << I << "\n";
2957 llvm::errs() << "Illegal orIn: " << LHS.str() << " right: " << Ret.str()
2958 << "\n";
2959 llvm::errs() << *I.getOperand(0) << " "
2960 << getAnalysis(I.getOperand(0)).str() << "\n";
2961 llvm::errs() << *I.getOperand(1) << " "
2962 << getAnalysis(I.getOperand(1)).str() << "\n";
2963 assert(0 && "Performed illegal visitAtomicRMWInst::orIn");
2964 llvm_unreachable("Performed illegal visitAtomicRMWInst::orIn");
2965 }
2966 Ret = tmp;
2967 break;
2968 }
2969 case AtomicRMWInst::Add:
2970 visitBinaryOperation(DL, I.getType(), BinaryOperator::Add, Args, Ret, LHS,
2971 RHS, &I);
2972 break;
2973 case AtomicRMWInst::Sub:
2974 visitBinaryOperation(DL, I.getType(), BinaryOperator::Sub, Args, Ret, LHS,
2975 RHS, &I);
2976 break;
2977 case AtomicRMWInst::And:
2978 visitBinaryOperation(DL, I.getType(), BinaryOperator::And, Args, Ret, LHS,
2979 RHS, &I);
2980 break;
2981 case AtomicRMWInst::Or:
2982 visitBinaryOperation(DL, I.getType(), BinaryOperator::Or, Args, Ret, LHS,
2983 RHS, &I);
2984 break;
2985 case AtomicRMWInst::Xor:
2986 visitBinaryOperation(DL, I.getType(), BinaryOperator::Xor, Args, Ret, LHS,
2987 RHS, &I);
2988 break;
2989 case AtomicRMWInst::FAdd:
2990 visitBinaryOperation(DL, I.getType(), BinaryOperator::FAdd, Args, Ret, LHS,
2991 RHS, &I);
2992 break;
2993 case AtomicRMWInst::FSub:
2994 visitBinaryOperation(DL, I.getType(), BinaryOperator::FSub, Args, Ret, LHS,
2995 RHS, &I);
2996 break;
2997 case AtomicRMWInst::Max:
2998 case AtomicRMWInst::Min:
2999 case AtomicRMWInst::UMax:
3000 case AtomicRMWInst::UMin:
3001 case AtomicRMWInst::Nand:
3002 default:
3003 break;
3004 }
3005
3006 if (direction & UP) {
3007 TypeTree ptr = LHS.PurgeAnything()
3008 .ShiftIndices(DL, /*start*/ 0, LoadSize, /*addOffset*/ 0)
3009 .Only(-1, &I);
3010 ptr.insert({-1}, BaseType::Pointer);
3011 updateAnalysis(I.getOperand(0), ptr, &I);
3012 updateAnalysis(I.getOperand(1), RHS, &I);
3013 }
3014
3015 if (direction & DOWN) {
3016 if (Ret[{-1}] == BaseType::Anything && LHS[{-1}] != BaseType::Anything)
3017 Ret = LHS;
3018 if (I.getType()->isIntOrIntVectorTy() && Ret[{-1}] == BaseType::Anything) {
3019 if (mustRemainInteger(&I)) {
3020 Ret = TypeTree(BaseType::Integer).Only(-1, &I);
3021 }
3022 }
3023 updateAnalysis(&I, Ret, &I);
3024 }
3025}
3026
3027void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T,
3028 llvm::Instruction::BinaryOps Opcode,
3029 Value *Args[2], TypeTree &Ret,
3030 TypeTree &LHS, TypeTree &RHS,
3031 Instruction *origin) {
3032 if (Opcode == BinaryOperator::FAdd || Opcode == BinaryOperator::FSub ||
3033 Opcode == BinaryOperator::FMul || Opcode == BinaryOperator::FDiv ||
3034 Opcode == BinaryOperator::FRem) {
3035 auto ty = T->getScalarType();
3036 assert(ty->isFloatingPointTy());
3037 ConcreteType dt(ty);
3038 if (direction & UP) {
3039 bool LegalOr = true;
3040 auto Data = TypeTree(dt).Only(-1, nullptr);
3041 LHS.checkedOrIn(Data, /*PointerIntSame*/ false, LegalOr);
3042 if (CustomErrorHandler && !LegalOr) {
3043 std::string str;
3044 raw_string_ostream ss(str);
3045 ss << "Illegal updateAnalysis prev:" << LHS.str()
3046 << " new: " << Data.str() << "\n";
3047 ss << "val: " << *Args[0];
3048 ss << "origin: " << *origin;
3049 CustomErrorHandler(str.c_str(), wrap(Args[0]),
3050 ErrorType::IllegalTypeAnalysis, (void *)this,
3051 wrap(origin), nullptr);
3052 }
3053 RHS.checkedOrIn(Data, /*PointerIntSame*/ false, LegalOr);
3054 if (CustomErrorHandler && !LegalOr) {
3055 std::string str;
3056 raw_string_ostream ss(str);
3057 ss << "Illegal updateAnalysis prev:" << RHS.str()
3058 << " new: " << Data.str() << "\n";
3059 ss << "val: " << *Args[1];
3060 ss << "origin: " << *origin;
3061 CustomErrorHandler(str.c_str(), wrap(Args[1]),
3062 ErrorType::IllegalTypeAnalysis, (void *)this,
3063 wrap(origin), nullptr);
3064 }
3065 }
3066 if (direction & DOWN)
3067 Ret |= TypeTree(dt).Only(-1, nullptr);
3068 } else {
3069 auto size = (dl.getTypeSizeInBits(T) + 7) / 8;
3070 auto AnalysisLHS = LHS.Data0();
3071 auto AnalysisRHS = RHS.Data0();
3072 auto AnalysisRet = Ret.Data0();
3073
3074 switch (Opcode) {
3075 case BinaryOperator::Sub:
3076 // ptr - ptr => int and int - int => int; thus int = a - b says only that
3077 // these are equal ptr - int => ptr and int - ptr => ptr; thus
3078 // howerver we do not want to propagate underlying ptr types since it's
3079 // legal to subtract unrelated pointer
3080 if (direction & UP) {
3081 if (AnalysisRet[{}] == BaseType::Integer) {
3082 LHS |= TypeTree(AnalysisRHS[{}]).PurgeAnything().Only(-1, nullptr);
3083 RHS |= TypeTree(AnalysisLHS[{}]).PurgeAnything().Only(-1, nullptr);
3084 }
3085 if (AnalysisRet[{}] == BaseType::Pointer) {
3086 if (AnalysisLHS[{}] == BaseType::Pointer) {
3087 RHS |= TypeTree(BaseType::Integer).Only(-1, nullptr);
3088 }
3089 if (AnalysisRHS[{}] == BaseType::Integer) {
3090 LHS |= TypeTree(BaseType::Pointer).Only(-1, nullptr);
3091 }
3092 }
3093 }
3094 break;
3095
3096 case BinaryOperator::Add:
3097 case BinaryOperator::Mul:
3098 // if a + b or a * b == int, then a and b must be ints
3099 if (direction & UP) {
3100 if (AnalysisRet[{}] == BaseType::Integer) {
3101 LHS.orIn({-1}, BaseType::Integer);
3102 RHS.orIn({-1}, BaseType::Integer);
3103 }
3104 }
3105 break;
3106
3107 case BinaryOperator::Xor:
3108 if (direction & UP)
3109 for (int i = 0; i < 2; ++i) {
3110 Type *FT = nullptr;
3111 if (!(FT = Ret.IsAllFloat(size, dl)))
3112 continue;
3113 // If ^ against 0b10000000000, the result is a float
3114 bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl);
3115 if (validXor) {
3116 bool Legal = true;
3117 ((i == 0) ? RHS : LHS)
3118 .checkedOrIn(TypeTree(FT).Only(-1, nullptr),
3119 /*pointerintsame*/ false, Legal);
3120
3121 if (!Legal) {
3122 std::string str;
3123 raw_string_ostream ss(str);
3124 if (!CustomErrorHandler) {
3125 llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
3126 llvm::errs() << *fntypeinfo.Function << "\n";
3127 dump(ss);
3128 }
3129 ss << "Illegal updateBinop (xor up) Analysis " << *origin << "\n";
3130 ss << " (i=" << i << ") " << (i == 0 ? "RHS" : "LHS") << " "
3131 << ((i == 0) ? RHS : LHS).str() << " FT from ret: " << *FT
3132 << "\n";
3133 if (CustomErrorHandler) {
3134 CustomErrorHandler(str.c_str(), wrap(origin),
3135 ErrorType::IllegalTypeAnalysis, (void *)this,
3136 wrap(origin), nullptr);
3137 }
3138 EmitFailure("IllegalUpdateAnalysis", origin->getDebugLoc(),
3139 origin, ss.str());
3140 report_fatal_error("Performed illegal updateAnalysis");
3141 }
3142 }
3143 }
3144 break;
3145 case BinaryOperator::Or:
3146 for (int i = 0; i < 2; ++i) {
3147 Type *FT = nullptr;
3148 if (!(FT = Ret.IsAllFloat(size, dl)))
3149 continue;
3150 // If | against a number only or'ing the exponent, the result is a float
3151 bool validXor = false;
3152 if (auto CIT = dyn_cast_or_null<ConstantInt>(Args[i])) {
3153 if (dl.getTypeSizeInBits(FT) != dl.getTypeSizeInBits(CIT->getType()))
3154 continue;
3155 auto CI = CIT->getValue();
3156#if LLVM_VERSION_MAJOR > 16
3157 if (CI.isZero())
3158#else
3159 if (CI.isNullValue())
3160#endif
3161 {
3162 validXor = true;
3163 } else if (
3164 !CI.isNegative() &&
3165 ((FT->isFloatTy()
3166#if LLVM_VERSION_MAJOR > 16
3167 && (CI & ~0b01111111100000000000000000000000ULL).isZero()
3168#else
3169 && (CI & ~0b01111111100000000000000000000000ULL).isNullValue()
3170#endif
3171 ) ||
3172 (FT->isDoubleTy()
3173#if LLVM_VERSION_MAJOR > 16
3174 &&
3175 (CI &
3176 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3177 .isZero()
3178#else
3179 &&
3180 (CI &
3181 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3182 .isNullValue()
3183#endif
3184 ))) {
3185 validXor = true;
3186 }
3187 } else if (auto CV = dyn_cast_or_null<ConstantVector>(Args[i])) {
3188 validXor = true;
3189 if (dl.getTypeSizeInBits(FT) !=
3190 dl.getTypeSizeInBits(CV->getOperand(i)->getType()))
3191 continue;
3192 for (size_t i = 0, end = CV->getNumOperands(); i < end; ++i) {
3193 auto CI = dyn_cast<ConstantInt>(CV->getOperand(i))->getValue();
3194
3195#if LLVM_VERSION_MAJOR > 16
3196 if (CI.isZero())
3197#else
3198 if (CI.isNullValue())
3199#endif
3200 {
3201 } else if (
3202 !CI.isNegative() &&
3203 ((FT->isFloatTy()
3204#if LLVM_VERSION_MAJOR > 16
3205 && (CI & ~0b01111111100000000000000000000000ULL).isZero()
3206#else
3207 && (CI & ~0b01111111100000000000000000000000ULL).isNullValue()
3208#endif
3209 ) ||
3210 (FT->isDoubleTy()
3211#if LLVM_VERSION_MAJOR > 16
3212 &&
3213 (CI &
3214 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3215 .isZero()
3216#else
3217 &&
3218 (CI &
3219 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3220 .isNullValue()
3221#endif
3222 ))) {
3223 } else
3224 validXor = false;
3225 }
3226 } else if (auto CV = dyn_cast_or_null<ConstantDataVector>(Args[i])) {
3227 validXor = true;
3228 if (dl.getTypeSizeInBits(FT) !=
3229 dl.getTypeSizeInBits(CV->getElementType()))
3230 continue;
3231 for (size_t i = 0, end = CV->getNumElements(); i < end; ++i) {
3232 auto CI = CV->getElementAsAPInt(i);
3233#if LLVM_VERSION_MAJOR > 16
3234 if (CI.isZero())
3235#else
3236 if (CI.isNullValue())
3237#endif
3238 {
3239 } else if (
3240 !CI.isNegative() &&
3241 ((FT->isFloatTy()
3242#if LLVM_VERSION_MAJOR > 16
3243 && (CI & ~0b01111111100000000000000000000000ULL).isZero()
3244#else
3245 && (CI & ~0b01111111100000000000000000000000ULL).isNullValue()
3246#endif
3247 ) ||
3248 (FT->isDoubleTy()
3249#if LLVM_VERSION_MAJOR > 16
3250 &&
3251 (CI &
3252 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3253 .isZero()
3254#else
3255 &&
3256 (CI &
3257 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3258 .isNullValue()
3259#endif
3260 ))) {
3261 } else
3262 validXor = false;
3263 }
3264 }
3265 if (validXor) {
3266 ((i == 0) ? RHS : LHS) |= TypeTree(FT).Only(-1, nullptr);
3267 }
3268 }
3269 break;
3270 default:
3271 break;
3272 }
3273
3274 if (direction & DOWN) {
3275 TypeTree Result = AnalysisLHS;
3276 bool Legal = true;
3277 Result.binopIn(Legal, AnalysisRHS, Opcode);
3278 if (!Legal) {
3279 std::string str;
3280 raw_string_ostream ss(str);
3281 if (!CustomErrorHandler) {
3282 llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
3283 llvm::errs() << *fntypeinfo.Function << "\n";
3284 dump(ss);
3285 }
3286 ss << "Illegal updateBinop Analysis " << *origin << "\n";
3287 ss << "Illegal binopIn(down): " << Opcode << " lhs: " << Result.str()
3288 << " rhs: " << AnalysisRHS.str() << "\n";
3289 if (CustomErrorHandler) {
3290 CustomErrorHandler(str.c_str(), wrap(origin),
3291 ErrorType::IllegalTypeAnalysis, (void *)this,
3292 wrap(origin), nullptr);
3293 }
3294 EmitFailure("IllegalUpdateAnalysis", origin->getDebugLoc(), origin,
3295 ss.str());
3296 report_fatal_error("Performed illegal updateAnalysis");
3297 }
3298 if (Opcode == BinaryOperator::And) {
3299 for (int i = 0; i < 2; ++i) {
3300 if (Args[i])
3301 for (auto andval :
3302 fntypeinfo.knownIntegralValues(Args[i], DT, intseen, SE)) {
3303 if (andval <= 16 && andval >= 0) {
3304 Result = TypeTree(BaseType::Integer);
3305 } else if (andval < 0 && andval >= -64) {
3306 // If a small negative number, this just masks off the lower
3307 // bits in this case we can say that this is the same as the
3308 // other operand
3309 Result = (i == 0 ? AnalysisRHS : AnalysisLHS);
3310 }
3311 }
3312 // If we and a constant against an integer, the result remains an
3313 // integer
3314 if (Args[i] && isa<ConstantInt>(Args[i]) &&
3315 (i == 0 ? AnalysisRHS : AnalysisLHS).Inner0() ==
3317 Result = TypeTree(BaseType::Integer);
3318 }
3319 }
3320 } else if (Opcode == BinaryOperator::Add ||
3321 Opcode == BinaryOperator::Sub) {
3322 for (int i = 0; i < 2; ++i) {
3323 if (i == 1 || Opcode == BinaryOperator::Add)
3324 if (auto CI = dyn_cast_or_null<ConstantInt>(Args[i])) {
3325 if (CI->isNegative() || CI->isZero() ||
3326 CI->getLimitedValue() <= 4096) {
3327 // If add/sub with zero, small, or negative number, the result
3328 // is equal to the type of the other operand (and we don't need
3329 // to assume this was an "anything")
3330 Result = (i == 0 ? AnalysisRHS : AnalysisLHS);
3331 }
3332 }
3333 }
3334 } else if (Opcode == BinaryOperator::Mul) {
3335 for (int i = 0; i < 2; ++i) {
3336 // If we mul a constant against an integer, the result remains an
3337 // integer
3338 if (Args[i] && isa<ConstantInt>(Args[i]) &&
3339 (i == 0 ? AnalysisRHS : AnalysisLHS)[{}] == BaseType::Integer) {
3340 Result = TypeTree(BaseType::Integer);
3341 }
3342 }
3343 } else if (Opcode == BinaryOperator::URem) {
3344 if (auto CI = dyn_cast_or_null<ConstantInt>(Args[1])) {
3345 // If rem with a small integer, the result is also a small integer
3346 if (CI->getLimitedValue() <= 4096) {
3347 Result = TypeTree(BaseType::Integer);
3348 }
3349 }
3350 } else if (Opcode == BinaryOperator::Xor) {
3351 for (int i = 0; i < 2; ++i) {
3352 Type *FT;
3353 if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size, dl)))
3354 continue;
3355 // If ^ against 0b10000000000, the result is a float
3356 bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl);
3357 if (validXor) {
3358 Result = ConcreteType(FT);
3359 }
3360 }
3361 } else if (Opcode == BinaryOperator::Or) {
3362 for (int i = 0; i < 2; ++i) {
3363 Type *FT;
3364 if (!(FT = (i == 0 ? RHS : LHS).IsAllFloat(size, dl)))
3365 continue;
3366 // If & against 0b10000000000, the result is a float
3367 bool validXor = false;
3368 if (auto CIT = dyn_cast_or_null<ConstantInt>(Args[i])) {
3369 if (dl.getTypeSizeInBits(FT) !=
3370 dl.getTypeSizeInBits(CIT->getType()))
3371 continue;
3372 auto CI = CIT->getValue();
3373#if LLVM_VERSION_MAJOR > 16
3374 if (CI.isZero())
3375#else
3376 if (CI.isNullValue())
3377#endif
3378 {
3379 validXor = true;
3380 } else if (
3381 !CI.isNegative() &&
3382 ((FT->isFloatTy()
3383#if LLVM_VERSION_MAJOR > 16
3384 && (CI & ~0b01111111100000000000000000000000ULL).isZero()
3385#else
3386 && (CI & ~0b01111111100000000000000000000000ULL).isNullValue()
3387#endif
3388 ) ||
3389 (FT->isDoubleTy()
3390#if LLVM_VERSION_MAJOR > 16
3391 &&
3392 (CI &
3393 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3394 .isZero()
3395#else
3396 &&
3397 (CI &
3398 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3399 .isNullValue()
3400#endif
3401 ))) {
3402 validXor = true;
3403 }
3404 } else if (auto CV = dyn_cast_or_null<ConstantVector>(Args[i])) {
3405 validXor = true;
3406 if (dl.getTypeSizeInBits(FT) !=
3407 dl.getTypeSizeInBits(CV->getOperand(i)->getType()))
3408 continue;
3409 for (size_t i = 0, end = CV->getNumOperands(); i < end; ++i) {
3410 auto CI = dyn_cast<ConstantInt>(CV->getOperand(i))->getValue();
3411#if LLVM_VERSION_MAJOR > 16
3412 if (CI.isZero())
3413#else
3414 if (CI.isNullValue())
3415#endif
3416 {
3417 } else if (
3418 !CI.isNegative() &&
3419 ((FT->isFloatTy()
3420#if LLVM_VERSION_MAJOR > 16
3421 && (CI & ~0b01111111100000000000000000000000ULL).isZero()
3422#else
3423 &&
3424 (CI & ~0b01111111100000000000000000000000ULL).isNullValue()
3425#endif
3426 ) ||
3427 (FT->isDoubleTy()
3428#if LLVM_VERSION_MAJOR > 16
3429 &&
3430 (CI &
3431 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3432 .isZero()
3433#else
3434 &&
3435 (CI &
3436 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3437 .isNullValue()
3438#endif
3439 ))) {
3440 } else
3441 validXor = false;
3442 }
3443 } else if (auto CV = dyn_cast_or_null<ConstantDataVector>(Args[i])) {
3444 validXor = true;
3445 if (dl.getTypeSizeInBits(FT) !=
3446 dl.getTypeSizeInBits(CV->getElementType()))
3447 continue;
3448 for (size_t i = 0, end = CV->getNumElements(); i < end; ++i) {
3449 auto CI = CV->getElementAsAPInt(i);
3450#if LLVM_VERSION_MAJOR > 16
3451 if (CI.isZero())
3452#else
3453 if (CI.isNullValue())
3454#endif
3455 {
3456 } else if (
3457 !CI.isNegative() &&
3458 ((FT->isFloatTy()
3459#if LLVM_VERSION_MAJOR > 16
3460 && (CI & ~0b01111111100000000000000000000000ULL).isZero()
3461#else
3462 &&
3463 (CI & ~0b01111111100000000000000000000000ULL).isNullValue()
3464#endif
3465 ) ||
3466 (FT->isDoubleTy()
3467#if LLVM_VERSION_MAJOR > 16
3468 &&
3469 (CI &
3470 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3471 .isZero()
3472#else
3473 &&
3474 (CI &
3475 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
3476 .isNullValue()
3477#endif
3478 ))) {
3479 } else
3480 validXor = false;
3481 }
3482 }
3483 if (validXor) {
3484 Result = ConcreteType(FT);
3485 }
3486 }
3487 }
3488
3489 Ret = Result.Only(-1, nullptr);
3490 }
3491 }
3492}
3493void TypeAnalyzer::visitBinaryOperator(BinaryOperator &I) {
3494 Value *Args[2] = {I.getOperand(0), I.getOperand(1)};
3495 TypeTree Ret = getAnalysis(&I);
3496 TypeTree LHS = getAnalysis(I.getOperand(0));
3497 TypeTree RHS = getAnalysis(I.getOperand(1));
3498 auto &DL = I.getParent()->getParent()->getParent()->getDataLayout();
3499 visitBinaryOperation(DL, I.getType(), I.getOpcode(), Args, Ret, LHS, RHS, &I);
3500
3501 if (direction & UP) {
3502 updateAnalysis(I.getOperand(0), LHS, &I);
3503 updateAnalysis(I.getOperand(1), RHS, &I);
3504 }
3505
3506 if (direction & DOWN) {
3507 if (I.getType()->isIntOrIntVectorTy() && Ret[{-1}] == BaseType::Anything) {
3508 if (mustRemainInteger(&I)) {
3509 Ret = TypeTree(BaseType::Integer).Only(-1, &I);
3510 }
3511 }
3512 updateAnalysis(&I, Ret, &I);
3513 }
3514}
3515
3516void TypeAnalyzer::visitMemTransferInst(llvm::MemTransferInst &MTI) {
3518}
3519
3520void TypeAnalyzer::visitMemTransferCommon(llvm::CallBase &MTI) {
3521 if (MTI.getType()->isIntegerTy()) {
3522 updateAnalysis(&MTI, TypeTree(BaseType::Integer).Only(-1, &MTI), &MTI);
3523 }
3524
3525 if (!(direction & UP))
3526 return;
3527
3528 // If memcpy / memmove of pointer, we can propagate type information from src
3529 // to dst up to the length and vice versa
3530 size_t sz = 1;
3531 for (auto val :
3532 fntypeinfo.knownIntegralValues(MTI.getArgOperand(2), DT, intseen, SE)) {
3533 if (val >= 0) {
3534 sz = max(sz, (size_t)val);
3535 }
3536 }
3537
3538 auto &dl = MTI.getParent()->getParent()->getParent()->getDataLayout();
3539 TypeTree res = getAnalysis(MTI.getArgOperand(0))
3540 .PurgeAnything()
3541 .Data0()
3542 .ShiftIndices(dl, 0, sz, 0);
3543 TypeTree res2 = getAnalysis(MTI.getArgOperand(1))
3544 .PurgeAnything()
3545 .Data0()
3546 .ShiftIndices(dl, 0, sz, 0);
3547
3548 bool Legal = true;
3549 res.checkedOrIn(res2, /*PointerIntSame*/ false, Legal);
3550 if (!Legal) {
3551 std::string str;
3552 raw_string_ostream ss(str);
3553 if (!CustomErrorHandler) {
3554 llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
3555 llvm::errs() << *fntypeinfo.Function << "\n";
3556 dump(ss);
3557 }
3558 ss << "Illegal updateMemTransfer Analysis " << MTI << "\n";
3559 ss << "Illegal orIn: " << res.str() << " right: " << res2.str() << "\n";
3560 ss << *MTI.getArgOperand(0) << " "
3561 << getAnalysis(MTI.getArgOperand(0)).str() << "\n";
3562 ss << *MTI.getArgOperand(1) << " "
3563 << getAnalysis(MTI.getArgOperand(1)).str() << "\n";
3564
3565 if (CustomErrorHandler) {
3566 CustomErrorHandler(str.c_str(), wrap(&MTI),
3567 ErrorType::IllegalTypeAnalysis, (void *)this,
3568 wrap(&MTI), nullptr);
3569 }
3570 EmitFailure("IllegalUpdateAnalysis", MTI.getDebugLoc(), &MTI, ss.str());
3571 report_fatal_error("Performed illegal updateAnalysis");
3572 }
3573 res.insert({}, BaseType::Pointer);
3574 res = res.Only(-1, &MTI);
3575 updateAnalysis(MTI.getArgOperand(0), res, &MTI);
3576 updateAnalysis(MTI.getArgOperand(1), res, &MTI);
3577#if LLVM_VERSION_MAJOR >= 14
3578 for (unsigned i = 2; i < MTI.arg_size(); ++i)
3579#else
3580 for (unsigned i = 2; i < MTI.getNumArgOperands(); ++i)
3581#endif
3582 {
3583 updateAnalysis(MTI.getArgOperand(i),
3584 TypeTree(BaseType::Integer).Only(-1, &MTI), &MTI);
3585 }
3586}
3587
3588void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
3589 switch (I.getIntrinsicID()) {
3590 case Intrinsic::ctpop:
3591 case Intrinsic::ctlz:
3592 case Intrinsic::cttz:
3593 case Intrinsic::nvvm_read_ptx_sreg_tid_x:
3594 case Intrinsic::nvvm_read_ptx_sreg_tid_y:
3595 case Intrinsic::nvvm_read_ptx_sreg_tid_z:
3596 case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
3597 case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
3598 case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
3599 case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
3600 case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
3601 case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
3602 case Intrinsic::nvvm_read_ptx_sreg_nctaid_x:
3603 case Intrinsic::nvvm_read_ptx_sreg_nctaid_y:
3604 case Intrinsic::nvvm_read_ptx_sreg_nctaid_z:
3605 case Intrinsic::nvvm_read_ptx_sreg_warpsize:
3606 case Intrinsic::amdgcn_workitem_id_x:
3607 case Intrinsic::amdgcn_workitem_id_y:
3608 case Intrinsic::amdgcn_workitem_id_z:
3609 // No direction check as always valid
3610 updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I);
3611 return;
3612
3613#if LLVM_VERSION_MAJOR < 22
3614 case Intrinsic::nvvm_barrier0_popc:
3615 case Intrinsic::nvvm_barrier0_and:
3616 case Intrinsic::nvvm_barrier0_or:
3617#else
3618 case Intrinsic::nvvm_barrier_cta_red_and_aligned_all:
3619 case Intrinsic::nvvm_barrier_cta_red_and_aligned_count:
3620 case Intrinsic::nvvm_barrier_cta_red_or_aligned_all:
3621 case Intrinsic::nvvm_barrier_cta_red_or_aligned_count:
3622 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_all:
3623 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_count:
3624#endif
3625 // No direction check as always valid
3626 updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I);
3627 updateAnalysis(I.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &I),
3628 &I);
3629 return;
3630
3631 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
3632 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
3633 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
3634 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
3635 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
3636 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
3637 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
3638 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
3639 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
3640 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
3641 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
3642 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: {
3643 TypeTree TT;
3644 TT.insert({-1}, BaseType::Pointer);
3645 TT.insert({-1, 0}, Type::getFloatTy(I.getContext()));
3646 updateAnalysis(I.getOperand(0), TT, &I);
3647 for (int i = 1; i <= 9; i++)
3649 I.getOperand(i),
3650 TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I),
3651 &I);
3652 return;
3653 }
3654
3655 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
3656 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
3657 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
3658 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
3659 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
3660 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
3661 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
3662 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
3663 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
3664 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
3665 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
3666 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
3667 TypeTree TT;
3668 TT.insert({-1}, BaseType::Pointer);
3669 TT.insert({-1, 0}, Type::getHalfTy(I.getContext()));
3670 updateAnalysis(I.getOperand(0), TT, &I);
3671 for (int i = 1; i <= 9; i++)
3673 I.getOperand(i),
3674 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3675 &I);
3676 return;
3677 }
3678
3679 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
3680 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
3681 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
3682 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
3683 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
3684 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
3685 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
3686 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
3687 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
3688 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
3689 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
3690 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: {
3691 TypeTree TT;
3692 TT.insert({-1}, BaseType::Pointer);
3693 TT.insert({-1, 0}, Type::getFloatTy(I.getContext()));
3694 updateAnalysis(I.getOperand(0), TT, &I);
3696 &I,
3697 TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I),
3698 &I);
3699 return;
3700 }
3701
3702 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
3703 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
3704 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
3705 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
3706 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
3707 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
3708 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
3709 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
3710 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
3711 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
3712 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
3713 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
3714 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
3715 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
3716 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
3717 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
3718 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
3719 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
3720 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
3721 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
3722 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
3723 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
3724 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
3725 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
3726 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
3727 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
3728 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
3729 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
3730 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
3731 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
3732 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
3733 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride:
3734 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
3735 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
3736 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
3737 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
3738 TypeTree TT;
3739 TT.insert({-1}, BaseType::Pointer);
3740 TT.insert({-1, 0}, Type::getHalfTy(I.getContext()));
3741 updateAnalysis(I.getOperand(0), TT, &I);
3743 &I,
3744 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3745 &I);
3746 return;
3747 }
3748
3749 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
3750 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
3751 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
3752 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
3753 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
3754 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
3755 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
3756 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
3757 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
3758 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
3759 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
3760 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
3761 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
3762 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
3763 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
3764 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
3765 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
3766 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
3767 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
3768 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
3769 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
3770 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
3771 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
3772 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
3773 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
3774 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
3775 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
3776 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
3777 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
3778 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
3779 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
3780 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
3781 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
3782 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
3783 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
3784 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
3785 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
3786 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
3787 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
3788 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
3789 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
3790 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
3791 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
3792 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
3793 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
3794 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
3795 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
3796 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
3797 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
3798 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
3799 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
3800 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
3801 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
3802 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
3803 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
3804 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
3805 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
3806 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
3807 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
3808 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride:
3809 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
3810 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
3811 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
3812 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
3813 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
3814 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
3815 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
3816 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
3817 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
3818 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
3819 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
3820 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
3821 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
3822 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
3823 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
3824 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
3825 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
3826 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
3827 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
3828 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: {
3829 // TODO
3830 return;
3831 }
3832
3833 case Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_f16_f16:
3834 case Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_f16_f16:
3835 case Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_f16_f16:
3836 case Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16:
3837 case Intrinsic::nvvm_wmma_m32n8k16_mma_col_col_f16_f16:
3838 case Intrinsic::nvvm_wmma_m32n8k16_mma_col_row_f16_f16:
3839 case Intrinsic::nvvm_wmma_m32n8k16_mma_row_col_f16_f16:
3840 case Intrinsic::nvvm_wmma_m32n8k16_mma_row_row_f16_f16:
3841 case Intrinsic::nvvm_wmma_m8n32k16_mma_col_col_f16_f16:
3842 case Intrinsic::nvvm_wmma_m8n32k16_mma_col_row_f16_f16:
3843 case Intrinsic::nvvm_wmma_m8n32k16_mma_row_col_f16_f16:
3844 case Intrinsic::nvvm_wmma_m8n32k16_mma_row_row_f16_f16: {
3845 for (int i = 0; i < 16; i++)
3847 I.getOperand(i),
3848 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3849 &I);
3850 for (int i = 16; i < 16 + 8; i++)
3852 I.getOperand(i),
3853 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3854 &I);
3856 &I,
3857 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3858 &I);
3859 return;
3860 }
3861
3862 case Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_f16_f32:
3863 case Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_f16_f32:
3864 case Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_f16_f32:
3865 case Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f32:
3866 case Intrinsic::nvvm_wmma_m32n8k16_mma_col_col_f16_f32:
3867 case Intrinsic::nvvm_wmma_m32n8k16_mma_col_row_f16_f32:
3868 case Intrinsic::nvvm_wmma_m32n8k16_mma_row_col_f16_f32:
3869 case Intrinsic::nvvm_wmma_m32n8k16_mma_row_row_f16_f32:
3870 case Intrinsic::nvvm_wmma_m8n32k16_mma_col_col_f16_f32:
3871 case Intrinsic::nvvm_wmma_m8n32k16_mma_col_row_f16_f32:
3872 case Intrinsic::nvvm_wmma_m8n32k16_mma_row_col_f16_f32:
3873 case Intrinsic::nvvm_wmma_m8n32k16_mma_row_row_f16_f32: {
3874 for (int i = 0; i < 16; i++)
3876 I.getOperand(i),
3877 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3878 &I);
3879 for (int i = 16; i < 16 + 8; i++)
3881 I.getOperand(i),
3882 TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I),
3883 &I);
3885 &I,
3886 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3887 &I);
3888 return;
3889 }
3890
3891 case Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_f32_f16:
3892 case Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_f32_f16:
3893 case Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_f32_f16:
3894 case Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f16:
3895 case Intrinsic::nvvm_wmma_m32n8k16_mma_col_col_f32_f16:
3896 case Intrinsic::nvvm_wmma_m32n8k16_mma_col_row_f32_f16:
3897 case Intrinsic::nvvm_wmma_m32n8k16_mma_row_col_f32_f16:
3898 case Intrinsic::nvvm_wmma_m32n8k16_mma_row_row_f32_f16:
3899 case Intrinsic::nvvm_wmma_m8n32k16_mma_col_col_f32_f16:
3900 case Intrinsic::nvvm_wmma_m8n32k16_mma_col_row_f32_f16:
3901 case Intrinsic::nvvm_wmma_m8n32k16_mma_row_col_f32_f16:
3902 case Intrinsic::nvvm_wmma_m8n32k16_mma_row_row_f32_f16: {
3903 for (int i = 0; i < 16; i++)
3905 I.getOperand(i),
3906 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3907 &I);
3908 for (int i = 16; i < 16 + 8; i++)
3910 I.getOperand(i),
3911 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3912 &I);
3914 &I,
3915 TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I),
3916 &I);
3917 return;
3918 }
3919
3920 case Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_f32_f32:
3921 case Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_f32_f32:
3922 case Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_f32_f32:
3923 case Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32:
3924 case Intrinsic::nvvm_wmma_m32n8k16_mma_col_col_f32_f32:
3925 case Intrinsic::nvvm_wmma_m32n8k16_mma_col_row_f32_f32:
3926 case Intrinsic::nvvm_wmma_m32n8k16_mma_row_col_f32_f32:
3927 case Intrinsic::nvvm_wmma_m32n8k16_mma_row_row_f32_f32:
3928 case Intrinsic::nvvm_wmma_m8n32k16_mma_col_col_f32_f32:
3929 case Intrinsic::nvvm_wmma_m8n32k16_mma_col_row_f32_f32:
3930 case Intrinsic::nvvm_wmma_m8n32k16_mma_row_col_f32_f32:
3931 case Intrinsic::nvvm_wmma_m8n32k16_mma_row_row_f32_f32: {
3932 for (int i = 0; i < 16; i++)
3934 I.getOperand(i),
3935 TypeTree(ConcreteType(Type::getHalfTy(I.getContext()))).Only(-1, &I),
3936 &I);
3937 for (int i = 16; i < 16 + 8; i++)
3939 I.getOperand(i),
3940 TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I),
3941 &I);
3943 &I,
3944 TypeTree(ConcreteType(Type::getFloatTy(I.getContext()))).Only(-1, &I),
3945 &I);
3946 return;
3947 }
3948
3949#if LLVM_VERSION_MAJOR < 20
3950 case Intrinsic::nvvm_ldg_global_i:
3951 case Intrinsic::nvvm_ldg_global_p:
3952 case Intrinsic::nvvm_ldg_global_f:
3953#endif
3954 case Intrinsic::nvvm_ldu_global_i:
3955 case Intrinsic::nvvm_ldu_global_p:
3956 case Intrinsic::nvvm_ldu_global_f: {
3957 auto &DL = I.getParent()->getParent()->getParent()->getDataLayout();
3958 auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8;
3959
3960 if (direction & UP) {
3963 DL, /*start*/ 0, LoadSize, /*addOffset*/ 0);
3964 updateAnalysis(I.getOperand(0), ptr.Only(-1, &I), &I);
3965 }
3966 if (direction & DOWN)
3967 updateAnalysis(&I, getAnalysis(I.getOperand(0)).Lookup(LoadSize, DL), &I);
3968 return;
3969 }
3970
3971 case Intrinsic::log:
3972 case Intrinsic::log2:
3973 case Intrinsic::log10:
3974 case Intrinsic::exp:
3975 case Intrinsic::exp2:
3976 case Intrinsic::sin:
3977 case Intrinsic::cos:
3978#if LLVM_VERSION_MAJOR >= 19
3979 case Intrinsic::sinh:
3980 case Intrinsic::cosh:
3981 case Intrinsic::tanh:
3982#endif
3983 case Intrinsic::floor:
3984 case Intrinsic::ceil:
3985 case Intrinsic::trunc:
3986 case Intrinsic::rint:
3987 case Intrinsic::nearbyint:
3988 case Intrinsic::round:
3989 case Intrinsic::sqrt:
3990#if LLVM_VERSION_MAJOR >= 21
3991 case Intrinsic::nvvm_fabs:
3992 case Intrinsic::nvvm_fabs_ftz:
3993#else
3994 case Intrinsic::nvvm_fabs_f:
3995 case Intrinsic::nvvm_fabs_d:
3996 case Intrinsic::nvvm_fabs_ftz_f:
3997#endif
3998 case Intrinsic::fabs:
3999 // No direction check as always valid
4001 &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I),
4002 &I);
4003 // No direction check as always valid
4005 I.getOperand(0),
4006 TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType()))
4007 .Only(-1, &I),
4008 &I);
4009 return;
4010
4011 case Intrinsic::fmuladd:
4012 case Intrinsic::fma:
4013 // No direction check as always valid
4015 &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I),
4016 &I);
4017 // No direction check as always valid
4019 I.getOperand(0),
4020 TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType()))
4021 .Only(-1, &I),
4022 &I);
4023 // No direction check as always valid
4025 I.getOperand(1),
4026 TypeTree(ConcreteType(I.getOperand(1)->getType()->getScalarType()))
4027 .Only(-1, &I),
4028 &I);
4029 // No direction check as always valid
4031 I.getOperand(2),
4032 TypeTree(ConcreteType(I.getOperand(2)->getType()->getScalarType()))
4033 .Only(-1, &I),
4034 &I);
4035 return;
4036
4037 case Intrinsic::powi:
4038 // No direction check as always valid
4040 &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I),
4041 &I);
4042 // No direction check as always valid
4044 I.getOperand(0),
4045 TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType()))
4046 .Only(-1, &I),
4047 &I);
4048 // No direction check as always valid
4049 updateAnalysis(I.getOperand(1), TypeTree(BaseType::Integer).Only(-1, &I),
4050 &I);
4051 return;
4052
4053#if LLVM_VERSION_MAJOR >= 12
4054 case Intrinsic::vector_reduce_fadd:
4055 case Intrinsic::vector_reduce_fmul:
4056#else
4057 case Intrinsic::experimental_vector_reduce_v2_fadd:
4058 case Intrinsic::experimental_vector_reduce_v2_fmul:
4059#endif
4060 case Intrinsic::copysign:
4061 case Intrinsic::maxnum:
4062 case Intrinsic::minnum:
4063#if LLVM_VERSION_MAJOR >= 15
4064 case Intrinsic::maximum:
4065 case Intrinsic::minimum:
4066#endif
4067 case Intrinsic::nvvm_fmax_f:
4068 case Intrinsic::nvvm_fmax_d:
4069 case Intrinsic::nvvm_fmax_ftz_f:
4070 case Intrinsic::nvvm_fmin_f:
4071 case Intrinsic::nvvm_fmin_d:
4072 case Intrinsic::nvvm_fmin_ftz_f:
4073 case Intrinsic::pow:
4074 // No direction check as always valid
4076 &I, TypeTree(ConcreteType(I.getType()->getScalarType())).Only(-1, &I),
4077 &I);
4078 // No direction check as always valid
4080 I.getOperand(0),
4081 TypeTree(ConcreteType(I.getOperand(0)->getType()->getScalarType()))
4082 .Only(-1, &I),
4083 &I);
4084 // No direction check as always valid
4086 I.getOperand(1),
4087 TypeTree(ConcreteType(I.getOperand(1)->getType()->getScalarType()))
4088 .Only(-1, &I),
4089 &I);
4090 return;
4091#if LLVM_VERSION_MAJOR >= 12
4092 case Intrinsic::smax:
4093 case Intrinsic::smin:
4094 case Intrinsic::umax:
4095 case Intrinsic::umin:
4096 if (direction & UP) {
4097 auto returnType = getAnalysis(&I)[{-1}];
4098 if (returnType == BaseType::Integer || returnType == BaseType::Pointer) {
4099 updateAnalysis(I.getOperand(0), TypeTree(returnType).Only(-1, &I), &I);
4100 updateAnalysis(I.getOperand(1), TypeTree(returnType).Only(-1, &I), &I);
4101 }
4102 }
4103 if (direction & DOWN) {
4104 auto opType0 = getAnalysis(I.getOperand(0))[{-1}];
4105 auto opType1 = getAnalysis(I.getOperand(1))[{-1}];
4106 if (opType0 == opType1 &&
4107 (opType0 == BaseType::Integer || opType0 == BaseType::Pointer)) {
4108 updateAnalysis(&I, TypeTree(opType0).Only(-1, &I), &I);
4109 } else if (opType0 == BaseType::Integer &&
4110 opType1 == BaseType::Anything) {
4111 updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I);
4112 } else if (opType1 == BaseType::Integer &&
4113 opType0 == BaseType::Anything) {
4114 updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I);
4115 }
4116 }
4117 return;
4118#endif
4119 case Intrinsic::umul_with_overflow:
4120 case Intrinsic::smul_with_overflow:
4121 case Intrinsic::ssub_with_overflow:
4122 case Intrinsic::usub_with_overflow:
4123 case Intrinsic::sadd_with_overflow:
4124 case Intrinsic::uadd_with_overflow: {
4125 // val, bool
4126 auto analysis = getAnalysis(&I).Data0();
4127
4128 BinaryOperator::BinaryOps opcode;
4129 // TODO update to use better rules in regular binop
4130 switch (I.getIntrinsicID()) {
4131 case Intrinsic::ssub_with_overflow:
4132 case Intrinsic::usub_with_overflow: {
4133 // TODO propagate this info
4134 // ptr - ptr => int and int - int => int; thus int = a - b says only that
4135 // these are equal ptr - int => ptr and int - ptr => ptr; thus
4137 opcode = BinaryOperator::Sub;
4138 break;
4139 }
4140
4141 case Intrinsic::smul_with_overflow:
4142 case Intrinsic::umul_with_overflow: {
4143 opcode = BinaryOperator::Mul;
4144 // if a + b or a * b == int, then a and b must be ints
4145 analysis = analysis.JustInt();
4146 break;
4147 }
4148 case Intrinsic::sadd_with_overflow:
4149 case Intrinsic::uadd_with_overflow: {
4150 opcode = BinaryOperator::Add;
4151 // if a + b or a * b == int, then a and b must be ints
4152 analysis = analysis.JustInt();
4153 break;
4154 }
4155 default:
4156 llvm_unreachable("unknown binary operator");
4157 }
4158
4159 // TODO update with newer binop protocol (see binop)
4160 if (direction & UP)
4161 updateAnalysis(I.getOperand(0), analysis.Only(-1, &I), &I);
4162 if (direction & UP)
4163 updateAnalysis(I.getOperand(1), analysis.Only(-1, &I), &I);
4164
4165 TypeTree vd = getAnalysis(I.getOperand(0)).Data0();
4166 bool Legal = true;
4167 vd.binopIn(Legal, getAnalysis(I.getOperand(1)).Data0(), opcode);
4168 if (!Legal) {
4169 std::string str;
4170 raw_string_ostream ss(str);
4171 if (!CustomErrorHandler) {
4172 llvm::errs() << *fntypeinfo.Function->getParent() << "\n";
4173 llvm::errs() << *fntypeinfo.Function << "\n";
4174 dump(ss);
4175 }
4176 ss << "Illegal updateBinopIntr Analysis " << I << "\n";
4177 ss << "Illegal binopIn(intr): " << I << " lhs: " << vd.str()
4178 << " rhs: " << getAnalysis(I.getOperand(1)).str() << "\n";
4179 if (CustomErrorHandler) {
4180 CustomErrorHandler(str.c_str(), wrap(&I),
4181 ErrorType::IllegalTypeAnalysis, (void *)this,
4182 wrap(&I), nullptr);
4183 }
4184 EmitFailure("IllegalUpdateAnalysis", I.getDebugLoc(), &I, ss.str());
4185 report_fatal_error("Performed illegal updateAnalysis");
4186 }
4187 auto &dl = I.getParent()->getParent()->getParent()->getDataLayout();
4188 int sz = (dl.getTypeSizeInBits(I.getOperand(0)->getType()) + 7) / 8;
4189 TypeTree overall = vd.Only(-1, &I).ShiftIndices(dl, 0, sz, 0);
4190
4191 int sz2 = (dl.getTypeSizeInBits(I.getType()) + 7) / 8;
4192 auto btree = TypeTree(BaseType::Integer)
4193 .Only(-1, &I)
4194 .ShiftIndices(dl, 0, sz2 - sz, sz);
4195 overall |= btree;
4196
4197 if (direction & DOWN)
4198 updateAnalysis(&I, overall, &I);
4199 return;
4200 }
4201 default:
4202 return;
4203 }
4204}
4205
/// Compile-time mapping from a C/C++ type `T` to a TypeAnalysis type.
///
/// Each explicit specialization below supplies a static
/// `analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)` that updates
/// the analysis of `val` to reflect `T`. The primary template is deliberately
/// left empty. Naming `TypeHandler<T>::analyzeType` for a `T` without a
/// specialization is therefore a compile-time error rather than a silent
/// no-op.
template <class T> struct TypeHandler {};
4211
4212template <> struct TypeHandler<double> {
4213 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4214 TA.updateAnalysis(
4215 val,
4216 TypeTree(ConcreteType(Type::getDoubleTy(call.getContext())))
4217 .Only(-1, &call),
4218 &call);
4219 }
4220};
4221
4222template <> struct TypeHandler<float> {
4223 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4224 TA.updateAnalysis(
4225 val,
4226 TypeTree(ConcreteType(Type::getFloatTy(call.getContext())))
4227 .Only(-1, &call),
4228 &call);
4229 }
4230};
4231
4232template <> struct TypeHandler<long double> {
4233 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4234 TA.updateAnalysis(
4235 val,
4236 TypeTree(ConcreteType(Type::getX86_FP80Ty(call.getContext())))
4237 .Only(-1, &call),
4238 &call);
4239 }
4240};
4241
4242#if defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)
4243template <> struct TypeHandler<__float128> {
4244 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4245 TA.updateAnalysis(
4246 val,
4247 TypeTree(ConcreteType(Type::getFP128Ty(call.getContext())))
4248 .Only(-1, &call),
4249 &call);
4250 }
4251};
4252#endif
4253
/// Pointer-to-floating-point handlers. Each builds a TypeTree `vd` holding
/// the element's LLVM floating-point type placed at index 0
/// (`.Only(0, &call)`), then records `vd.Only(-1, &call)` for `val`.
// NOTE(review): one line of each body (source lines 4257, 4265, 4274, 4283)
// is absent from this listing. Presumably it also marks `vd` itself as a
// pointer before the update; confirm against the real source.
/// `double *`: the element type is LLVM `double`.
4254template <> struct TypeHandler<double *> {
4255 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4256 TypeTree vd = TypeTree(Type::getDoubleTy(call.getContext())).Only(0, &call);
4258 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4259 }
4260};
4261
/// `float *`: the element type is LLVM `float`.
4262template <> struct TypeHandler<float *> {
4263 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4264 TypeTree vd = TypeTree(Type::getFloatTy(call.getContext())).Only(0, &call);
4266 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4267 }
4268};
4269
/// `long double *`: the element type is hard-coded to LLVM `x86_fp80`.
// NOTE(review): `long double` is only x86_fp80 on x86 targets; it is fp128
// on e.g. AArch64 Linux and double under MSVC. The pointee type cannot be
// recovered from `val` here, so this assumption should be revisited.
4270template <> struct TypeHandler<long double *> {
4271 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4272 TypeTree vd =
4273 TypeTree(Type::getX86_FP80Ty(call.getContext())).Only(0, &call);
4275 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4276 }
4277};
4278
// `__float128 *`: only compiled when the host compiler provides __float128;
// the element type is LLVM `fp128`.
4279#if defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)
4280template <> struct TypeHandler<__float128 *> {
4281 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4282 TypeTree vd = TypeTree(Type::getFP128Ty(call.getContext())).Only(0, &call);
4284 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4285 }
4286};
4287#endif
4288
4289template <> struct TypeHandler<void> {
4290 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {}
4291};
4292
/// `void *` and integer handlers. Each specialization records
/// `vd.Only(-1, &call)` for `val`, where `vd` is a TypeTree describing the
/// C++ type.
///   - `void *` and the scalar integer handlers (`int`, `unsigned int`,
///     `long int`, ...): `vd` is declared on a line absent from this listing.
///   - Pointer-to-integer handlers: `vd` starts as `BaseType::Integer` placed
///     at index 0 (`.Only(0, &call)`), followed by one line absent from this
///     listing before the update.
// NOTE(review): the absent lines are 4295, 4302, 4310, 4317, 4325, 4332,
// 4340, 4347, 4355, 4362, 4370, 4377 and 4385. Presumably they declare `vd`
// as `TypeTree(BaseType::Pointer)` (void *) / `TypeTree(BaseType::Integer)`
// (scalars), and add a pointer marker to `vd` for the pointer cases. Confirm
// against the real source.
/// `void *`: opaque pointer.
4293template <> struct TypeHandler<void *> {
4294 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4296 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4297 }
4298};
4299
/// `int`: scalar integer.
4300template <> struct TypeHandler<int> {
4301 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4303 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4304 }
4305};
4306
/// `int *`: pointer whose index-0 element is an integer.
4307template <> struct TypeHandler<int *> {
4308 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4309 TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call);
4311 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4312 }
4313};
4314
/// `unsigned int`: scalar integer.
4315template <> struct TypeHandler<unsigned int> {
4316 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4318 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4319 }
4320};
4321
/// `unsigned int *`: pointer whose index-0 element is an integer.
4322template <> struct TypeHandler<unsigned int *> {
4323 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4324 TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call);
4326 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4327 }
4328};
4329
/// `long int`: scalar integer.
4330template <> struct TypeHandler<long int> {
4331 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4333 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4334 }
4335};
4336
/// `long int *`: pointer whose index-0 element is an integer.
4337template <> struct TypeHandler<long int *> {
4338 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4339 TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call);
4341 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4342 }
4343};
4344
/// `long unsigned int`: scalar integer.
4345template <> struct TypeHandler<long unsigned int> {
4346 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4348 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4349 }
4350};
4351
/// `long unsigned int *`: pointer whose index-0 element is an integer.
4352template <> struct TypeHandler<long unsigned int *> {
4353 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4354 TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call);
4356 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4357 }
4358};
4359
/// `long long int`: scalar integer.
4360template <> struct TypeHandler<long long int> {
4361 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4363 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4364 }
4365};
4366
/// `long long int *`: pointer whose index-0 element is an integer.
4367template <> struct TypeHandler<long long int *> {
4368 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4369 TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call);
4371 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4372 }
4373};
4374
/// `long long unsigned int`: scalar integer.
4375template <> struct TypeHandler<long long unsigned int> {
4376 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4378 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4379 }
4380};
4381
/// `long long unsigned int *`: pointer whose index-0 element is an integer.
4382template <> struct TypeHandler<long long unsigned int *> {
4383 static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA) {
4384 TypeTree vd = TypeTree(BaseType::Integer).Only(0, &call);
4386 TA.updateAnalysis(val, vd.Only(-1, &call), &call);
4387 }
4388};
4389
4390template <typename... Arg0> struct FunctionArgumentIterator {
4391 static void analyzeFuncTypesHelper(unsigned idx, CallBase &call,
4392 TypeAnalyzer &TA) {}
4393};
4394
/// Recursive step of the compile-time argument walk. It types operand `idx`
/// of `call` using TypeHandler<Arg0>, and presumably then continues with the
/// remaining `Args...` at the next operand index.
// NOTE(review): source lines 4396 and 4400 are absent from this listing.
// 4396 is presumably the `struct FunctionArgumentIterator<Arg0, Args...> {`
// head of this partial specialization. 4400 is presumably the start of the
// recursive `analyzeFuncTypesHelper(idx + 1, call, ...)` call, whose trailing
// `TA);` is on 4401. Confirm against the real source.
4395template <typename Arg0, typename... Args>
4397 static void analyzeFuncTypesHelper(unsigned idx, CallBase &call,
4398 TypeAnalyzer &TA) {
// Type the current operand according to its declared C++ parameter type.
4399 TypeHandler<Arg0>::analyzeType(call.getOperand(idx), call, TA);
4401 TA);
4402 }
4403};
4404
/// Types `call` from an explicit return type `RT` and parameter types
/// `Args...`. It is invoked as `analyzeFuncTypesNoFn<RT, Args...>(call, TA)`
/// by analyzeFuncTypes and by the CONSIDER2 macro.
// NOTE(review): the declaration and body of this template (source lines
// 4406-4409) are absent from this listing. Presumably it applies
// TypeHandler<RT> to the call itself and FunctionArgumentIterator<Args...> to
// its operands starting at index 0. Confirm against the real source.
4405template <typename RT, typename... Args>
4410
4411template <typename RT, typename... Args>
4412void analyzeFuncTypes(RT (*fn)(Args...), CallBase &call, TypeAnalyzer &TA) {
4413 analyzeFuncTypesNoFn<RT, Args...>(call, TA);
4414}
4415
/// Type-analysis rule for the `llvm.intel.subscript` intrinsic. visitCallBase
/// dispatches here by name.
///
/// Operand layout as used below:
///   - Operand 3 is the base pointer (`ptrArgIndex`).
///   - Operands 0, 1, 2 and 4 are integers.
///   - Operand 0 (the array axis) takes no part in the offset computation.
///   - Operands 1, 2 and 4 feed `baseIndex`, `stride` and `index`
///     respectively (assuming getSet yields one value per `idnext` entry, in
///     order). The candidate offset from pointer to result is
///     `stride * (index - baseIndex)`.
///
/// DOWN: the result receives the pointer's `KeepMinusOne` tree and its
/// `Inner0()` type. It also receives the pointer's `Data0()` shifted back by
/// each candidate offset, intersected (`&=`) across all candidates.
///
/// UP: operands 0, 1, 2 and 4 are marked Integer and the pointer receives the
/// result's `Inner0()` type. The pointer also receives the result's `Data0()`
/// shifted forward by each candidate offset, unioned (`|=`) across
/// candidates.
///
/// The offset-based updates only happen when each of operands 1, 2 and 4 has
/// at least one known non-negative integral value; otherwise the function
/// returns after the updates above. Negative candidate offsets are skipped.
/// If every candidate is skipped, the up/down trees stay default-constructed.
// NOTE(review): source lines 4445, 4494 and 4529 are absent from this
// listing. The dangling `else` suggests 4445 is a guard such as
// `if (CustomErrorHandler)`. 4494 and 4529 presumably guard the
// `pointerData0` assignment and the final result update on
// `TA.direction & TypeAnalyzer::DOWN`. Confirm against the real source.
4416void analyzeIntelSubscriptIntrinsic(IntrinsicInst &II, TypeAnalyzer &TA) {
4417 assert(isIntelSubscriptIntrinsic(II));
4418#if LLVM_VERSION_MAJOR >= 14
4419 assert(II.arg_size() == 5);
4420#else
4421 assert(II.getNumArgOperands() == 5);
4422#endif
4423
4424 constexpr size_t idxArgsIndices[4] = {0, 1, 2, 4};
4425 constexpr size_t ptrArgIndex = 3;
4426
4427 // Update analysis of index parameters
4428
4429 if (TA.direction & TypeAnalyzer::UP) {
4430 for (auto i : idxArgsIndices) {
4431 auto idx = II.getOperand(i);
4432 TA.updateAnalysis(idx, TypeTree(BaseType::Integer).Only(-1, &II), &II);
4433 }
4434 }
4435
4436 // Update analysis of ptr parameter
4437
4438 auto &DL = TA.fntypeinfo.Function->getParent()->getDataLayout();
4439 auto pointerAnalysis = TA.getAnalysis(II.getOperand(ptrArgIndex));
4440
4441 if (TA.direction & TypeAnalyzer::DOWN) {
4442 bool legal = true;
4443 auto keepMinus = pointerAnalysis.KeepMinusOne(legal);
// On failure the problem is reported, but no early return is visible:
// `keepMinus` is still applied to the result below.
4444 if (!legal) {
4446 CustomErrorHandler("Could not keep minus one", wrap(&II),
4447 ErrorType::IllegalTypeAnalysis, &TA, nullptr,
4448 nullptr);
4449 else {
4450 TA.dump();
4451 llvm::errs()
4452 << " could not perform minus one for llvm.intel.subscript'd: " << II
4453 << "\n";
4454 }
4455 }
4456 TA.updateAnalysis(&II, keepMinus, &II);
4457 TA.updateAnalysis(&II, TypeTree(pointerAnalysis.Inner0()).Only(-1, &II),
4458 &II);
4459 }
4460
4461 if (TA.direction & TypeAnalyzer::UP) {
4462 TA.updateAnalysis(II.getOperand(ptrArgIndex),
4463 TypeTree(TA.getAnalysis(&II).Inner0()).Only(-1, &II),
4464 &II);
4465 }
4466
// Collect the known non-negative values of operands 1, 2 and 4; bail out if
// any of them has none, since no offset can then be formed.
4467 SmallVector<std::set<int64_t>, 4> idnext;
4468 // The first operand is used to denote the axis of a multidimensional array,
4469 // but it is not used for address calculation, and so we skip it here.
4470 constexpr size_t offsetCalculationIndices[3] = {1, 2, 4};
4471 for (auto i : offsetCalculationIndices) {
4472 auto idx = II.getOperand(i);
4473 auto iset = TA.knownIntegralValues(idx);
4474 std::set<int64_t> vset;
// NOTE(review): this inner `i` shadows the outer loop's operand index `i`.
4475 for (auto i : iset) {
4476 // Don't consider negative indices of llvm.intel.subscript
4477 if (i < 0)
4478 continue;
4479 vset.insert(i);
4480 }
4481 idnext.push_back(vset);
4482 if (idnext.back().size() == 0)
4483 return;
4484 }
4485 assert(idnext.size() != 0);
4486
4487 TypeTree upTree;
4488 TypeTree downTree;
4489
4490 TypeTree intrinsicData0;
4491 TypeTree pointerData0;
4492 if (TA.direction & TypeAnalyzer::UP)
4493 intrinsicData0 = TA.getAnalysis(&II).Data0();
4495 pointerData0 = pointerAnalysis.Data0();
4496
// The first accepted candidate seeds each tree; later ones are combined into
// it (intersection going down, union going up).
4497 bool firstLoop = true;
4498
4499 for (auto vec : getSet<int64_t>(idnext, idnext.size() - 1)) {
4500 auto baseIndex = vec[0];
4501 auto stride = vec[1];
4502 auto index = vec[2];
4503
// NOTE(review): the 64-bit product is narrowed to `int`; very large strides
// or indices could overflow or be truncated before the `offset < 0` check.
4504 int offset = static_cast<int>(stride * (index - baseIndex));
4505 if (offset < 0) {
4506 continue; // The intrinsic doesn't handle negative offsets
4507 }
4508
4509 if (TA.direction & TypeAnalyzer::DOWN) {
4510 auto shft = pointerData0.ShiftIndices(DL, /*init offset*/ offset,
4511 /*max size*/ -1, /*newoffset*/ 0);
4512 if (firstLoop)
4513 downTree = shft;
4514 else
4515 downTree &= shft;
4516 }
4517
4518 if (TA.direction & TypeAnalyzer::UP) {
4519 auto shft =
4520 intrinsicData0.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1,
4521 /*new offset*/ offset);
4522 if (firstLoop)
4523 upTree = shft;
4524 else
4525 upTree |= shft;
4526 }
4527 firstLoop = false;
4528 }
4530 TA.updateAnalysis(&II, downTree.Only(-1, &II), &II);
4531 if (TA.direction & TypeAnalyzer::UP)
4532 TA.updateAnalysis(II.getOperand(ptrArgIndex), upTree.Only(-1, &II), &II);
4533}
4534
4535void TypeAnalyzer::visitCallBase(CallBase &call) {
4536 assert(fntypeinfo.KnownValues.size() ==
4537 fntypeinfo.Function->getFunctionType()->getNumParams());
4538
4539 if (auto iasm = dyn_cast<InlineAsm>(call.getCalledOperand())) {
4540 // NO direction check as always valid
4541 if (StringRef(iasm->getAsmString()).contains("cpuid")) {
4542 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4543#if LLVM_VERSION_MAJOR >= 14
4544 for (auto &arg : call.args())
4545#else
4546 for (auto &arg : call.arg_operands())
4547#endif
4548 {
4549 updateAnalysis(arg, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4550 }
4551 }
4552 }
4553
4554 if (call.hasFnAttr("enzyme_ta_norecur"))
4555 return;
4556
4557 Function *ci = getFunctionFromCall(&call);
4558
4559 if (ci) {
4560 if (ci->getAttributes().hasAttribute(AttributeList::FunctionIndex,
4561 "enzyme_ta_norecur"))
4562 return;
4563
4564 StringRef funcName = getFuncNameFromCall(&call);
4565
4566 auto blasMetaData = extractBLAS(funcName);
4567 if (blasMetaData) {
4568 BlasInfo blas = *blasMetaData;
4569#include "BlasTA.inc"
4570 }
4571
4572 // clang-format off
4573 const char* NoTARecurStartsWith[] = {
4574 "std::__u::basic_ostream<wchar_t, std::__u::char_traits<wchar_t>>& std::__u::operator<<",
4575 };
4576 // clang-format on
4577 {
4578 std::string demangledName = llvm::demangle(funcName.str());
4579 // replace all '> >' with '>>'
4580 size_t start = 0;
4581 while ((start = demangledName.find("> >", start)) != std::string::npos) {
4582 demangledName.replace(start, 3, ">>");
4583 }
4584 for (auto Name : NoTARecurStartsWith)
4585 if (startsWith(demangledName, Name))
4586 return;
4587 }
4588
4589 // Manual TT specification is non-interprocedural and already handled once
4590 // at the start.
4591
4592 // When compiling Enzyme against standard LLVM, and not Intel's
4593 // modified version of LLVM, the intrinsic `llvm.intel.subscript` is
4594 // not fully understood by LLVM. One of the results of this is that the
4595 // visitor dispatches to visitCallBase, rather than visitIntrinsicInst, when
4596 // presented with the intrinsic - hence why we are handling it here.
4597 if (startsWith(funcName, "llvm.intel.subscript")) {
4598 assert(isa<IntrinsicInst>(call));
4599 analyzeIntelSubscriptIntrinsic(cast<IntrinsicInst>(call), *this);
4600 return;
4601 }
4602
4603#define CONSIDER(fn) \
4604 if (funcName == #fn) { \
4605 analyzeFuncTypes(::fn, call, *this); \
4606 return; \
4607 }
4608
4609#define CONSIDER2(fn, ...) \
4610 if (funcName == #fn) { \
4611 analyzeFuncTypesNoFn<__VA_ARGS__>(call, *this); \
4612 return; \
4613 }
4614
4615 auto customrule = interprocedural.CustomRules.find(funcName);
4616 if (customrule != interprocedural.CustomRules.end()) {
4617 auto returnAnalysis = getAnalysis(&call);
4618 SmallVector<TypeTree, 4> args;
4619 SmallVector<std::set<int64_t>, 4> knownValues;
4620#if LLVM_VERSION_MAJOR >= 14
4621 for (auto &arg : call.args())
4622#else
4623 for (auto &arg : call.arg_operands())
4624#endif
4625 {
4626 args.push_back(getAnalysis(arg));
4627 knownValues.push_back(
4628 fntypeinfo.knownIntegralValues((Value *)arg, DT, intseen, SE));
4629 }
4630
4631 bool err = customrule->second(direction, returnAnalysis, args,
4632 knownValues, &call, this);
4633 if (err) {
4634 Invalid = true;
4635 return;
4636 }
4637 updateAnalysis(&call, returnAnalysis, &call);
4638 size_t argnum = 0;
4639#if LLVM_VERSION_MAJOR >= 14
4640 for (auto &arg : call.args())
4641#else
4642 for (auto &arg : call.arg_operands())
4643#endif
4644 {
4645 updateAnalysis(arg, args[argnum], &call);
4646 argnum++;
4647 }
4648 return;
4649 }
4650
4651 // All these are always valid => no direction check
4652 // CONSIDER(malloc)
4653 // TODO consider handling other allocation functions integer inputs
4654 if (startsWith(funcName, "_ZN3std2io5stdio6_print") ||
4655 startsWith(funcName, "_ZN4core3fmt")) {
4656 return;
4657 }
4658 /// GEMM
4659 if (funcName == "dgemm_64" || funcName == "dgemm_64_" ||
4660 funcName == "dgemm" || funcName == "dgemm_") {
4661 TypeTree ptrint;
4662 ptrint.insert({-1}, BaseType::Pointer);
4663 ptrint.insert({-1, 0}, BaseType::Integer);
4664 // transa, transb, m, n, k, lda, ldb, ldc
4665 for (int i : {0, 1, 2, 3, 4, 7, 9, 12})
4666 updateAnalysis(call.getArgOperand(i), ptrint, &call);
4667
4668 TypeTree ptrdbl;
4669 ptrdbl.insert({-1}, BaseType::Pointer);
4670 ptrdbl.insert({-1, 0}, Type::getDoubleTy(call.getContext()));
4671
4672 // alpha, a, b, beta, c
4673 for (int i : {5, 6, 8, 10, 11})
4674 updateAnalysis(call.getArgOperand(i), ptrdbl, &call);
4675 return;
4676 }
4677
4678 if (funcName == "__kmpc_fork_call") {
4679 Function *fn = dyn_cast<Function>(call.getArgOperand(2));
4680
4681 if (auto castinst = dyn_cast<ConstantExpr>(call.getArgOperand(2)))
4682 if (castinst->isCast())
4683 fn = dyn_cast<Function>(castinst->getOperand(0));
4684
4685 if (fn) {
4686#if LLVM_VERSION_MAJOR >= 14
4687 if (call.arg_size() - 3 != fn->getFunctionType()->getNumParams() - 2)
4688 return;
4689#else
4690 if (call.getNumArgOperands() - 3 !=
4691 fn->getFunctionType()->getNumParams() - 2)
4692 return;
4693#endif
4694
4695 if (direction & UP) {
4696 FnTypeInfo typeInfo(fn);
4697
4698 TypeTree IntPtr;
4699 IntPtr.insert({-1, -1}, BaseType::Integer);
4700 IntPtr.insert({-1}, BaseType::Pointer);
4701
4702 int argnum = 0;
4703 for (auto &arg : fn->args()) {
4704 if (argnum <= 1) {
4705 typeInfo.Arguments.insert(
4706 std::pair<Argument *, TypeTree>(&arg, IntPtr));
4707 typeInfo.KnownValues.insert(
4708 std::pair<Argument *, std::set<int64_t>>(&arg, {0}));
4709 } else {
4710 typeInfo.Arguments.insert(std::pair<Argument *, TypeTree>(
4711 &arg, getAnalysis(call.getArgOperand(argnum - 2 + 3))));
4712 std::set<int64_t> bounded;
4713 for (auto v : fntypeinfo.knownIntegralValues(
4714 call.getArgOperand(argnum - 2 + 3), DT, intseen, SE)) {
4715 if (abs(v) > MaxIntOffset)
4716 continue;
4717 bounded.insert(v);
4718 }
4719 typeInfo.KnownValues.insert(
4720 std::pair<Argument *, std::set<int64_t>>(&arg, bounded));
4721 }
4722
4723 ++argnum;
4724 }
4725
4726 if (EnzymePrintType) {
4727 llvm::errs() << " starting omp IPO of ";
4728 call.print(llvm::errs(), *MST);
4729 llvm::errs() << "\n";
4730 }
4731
4732 auto a = fn->arg_begin();
4733 ++a;
4734 ++a;
4736#if LLVM_VERSION_MAJOR >= 14
4737 for (unsigned i = 3; i < call.arg_size(); ++i)
4738#else
4739 for (unsigned i = 3; i < call.getNumArgOperands(); ++i)
4740#endif
4741 {
4742 auto dt = STR.query(a);
4743 updateAnalysis(call.getArgOperand(i), dt, &call);
4744 ++a;
4745 }
4746 }
4747 }
4748 return;
4749 }
4750 if (funcName == "__kmpc_for_static_init_4" ||
4751 funcName == "__kmpc_for_static_init_4u" ||
4752 funcName == "__kmpc_for_static_init_8" ||
4753 funcName == "__kmpc_for_static_init_8u") {
4754 TypeTree ptrint;
4755 ptrint.insert({-1}, BaseType::Pointer);
4756 size_t numBytes = 4;
4757 if (funcName == "__kmpc_for_static_init_8" ||
4758 funcName == "__kmpc_for_static_init_8u")
4759 numBytes = 8;
4760 for (size_t i = 0; i < numBytes; i++)
4761 ptrint.insert({-1, (int)i}, BaseType::Integer);
4762 updateAnalysis(call.getArgOperand(3), ptrint, &call);
4763 updateAnalysis(call.getArgOperand(4), ptrint, &call);
4764 updateAnalysis(call.getArgOperand(5), ptrint, &call);
4765 updateAnalysis(call.getArgOperand(6), ptrint, &call);
4766 updateAnalysis(call.getArgOperand(7),
4767 TypeTree(BaseType::Integer).Only(-1, &call), &call);
4768 updateAnalysis(call.getArgOperand(8),
4769 TypeTree(BaseType::Integer).Only(-1, &call), &call);
4770 return;
4771 }
4772 if (funcName == "omp_get_max_threads" || funcName == "omp_get_thread_num" ||
4773 funcName == "omp_get_num_threads" ||
4774 funcName == "__kmpc_global_thread_num") {
4775 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4776 return;
4777 }
4778 if (funcName == "_ZNSt6localeC1Ev") {
4779 TypeTree ptrint;
4780 ptrint.insert({-1}, BaseType::Pointer);
4781 ptrint.insert({-1, 0}, BaseType::Integer);
4782 updateAnalysis(call.getOperand(0), ptrint, &call);
4783 return;
4784 }
4785
4786 if (startsWith(funcName, "_ZNKSt3__14hash")) {
4787 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4788 return;
4789 }
4790
4791 if (startsWith(funcName, "_ZNKSt3__112basic_string") ||
4792 startsWith(funcName, "_ZNSt3__112basic_string") ||
4793 startsWith(funcName, "_ZNSt3__112__hash_table") ||
4794 startsWith(funcName, "_ZNKSt3__115basic_stringbuf")) {
4795 return;
4796 }
4797
4798 if (funcName == "__dynamic_cast" ||
4799 funcName == "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base" ||
4800 funcName == "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base" ||
4801 funcName == "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base" ||
4802 funcName == "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base") {
4803 updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call);
4804 return;
4805 }
4806 if (funcName == "memcmp") {
4807 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4808 updateAnalysis(call.getOperand(0),
4809 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
4810 updateAnalysis(call.getOperand(1),
4811 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
4812 updateAnalysis(call.getOperand(2),
4813 TypeTree(BaseType::Integer).Only(-1, &call), &call);
4814 return;
4815 }
4816
4817 /// CUDA
4818 if (funcName == "cuDeviceGet") {
4819 // cuResult
4820 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4821 updateAnalysis(call.getOperand(0),
4822 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
4823 updateAnalysis(call.getOperand(1),
4824 TypeTree(BaseType::Integer).Only(-1, &call), &call);
4825 return;
4826 }
4827 if (funcName == "cuDeviceGetName") {
4828 // cuResult
4829 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4830 updateAnalysis(call.getOperand(0),
4831 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
4832 updateAnalysis(call.getOperand(1),
4833 TypeTree(BaseType::Integer).Only(-1, &call), &call);
4834 return;
4835 }
4836 if (funcName == "cudaRuntimeGetVersion" ||
4837 funcName == "cuDriverGetVersion" || funcName == "cuDeviceGetCount") {
4838 // cuResult
4839 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4840 TypeTree ptrint;
4841 ptrint.insert({-1}, BaseType::Pointer);
4842 ptrint.insert({-1, 0}, BaseType::Integer);
4843 updateAnalysis(call.getOperand(0), ptrint, &call);
4844 return;
4845 }
4846 if (funcName == "cuMemGetInfo_v2") {
4847 // cuResult
4848 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4849 TypeTree ptrint;
4850 ptrint.insert({-1}, BaseType::Pointer);
4851 ptrint.insert({-1, 0}, BaseType::Integer);
4852 updateAnalysis(call.getOperand(0), ptrint, &call);
4853 updateAnalysis(call.getOperand(1), ptrint, &call);
4854 return;
4855 }
4856 if (funcName == "cuDevicePrimaryCtxRetain" ||
4857 funcName == "cuCtxGetCurrent") {
4858 // cuResult
4859 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4860 updateAnalysis(call.getOperand(0),
4861 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
4862 return;
4863 }
4864 if (funcName == "cuStreamQuery") {
4865 // cuResult
4866 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4867 return;
4868 }
4869 if (funcName == "cuMemAllocAsync" || funcName == "cuMemAlloc" ||
4870 funcName == "cuMemAlloc_v2" || funcName == "cudaMalloc" ||
4871 funcName == "cudaMallocAsync" || funcName == "cudaMallocHost" ||
4872 funcName == "cudaMallocFromPoolAsync") {
4873 TypeTree ptrptr;
4874 ptrptr.insert({-1}, BaseType::Pointer);
4875 ptrptr.insert({-1, 0}, BaseType::Pointer);
4876 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4877 updateAnalysis(call.getOperand(0), ptrptr, &call);
4878 updateAnalysis(call.getOperand(1),
4879 TypeTree(BaseType::Integer).Only(-1, &call), &call);
4880 return;
4881 }
4882 if (funcName == "jl_hrtime" || funcName == "ijl_hrtime") {
4883 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4884 return;
4885 }
4886 if (funcName == "jl_get_task_tid" || funcName == "ijl_get_task_tid") {
4887 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4888 return;
4889 }
4890 if (funcName == "jl_get_binding_or_error" ||
4891 funcName == "ijl_get_binding_or_error") {
4892 updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call);
4893 return;
4894 }
4895 if (funcName == "julia.gc_loaded") {
4896 if (direction & UP)
4897 updateAnalysis(call.getArgOperand(1), getAnalysis(&call), &call);
4898 if (direction & DOWN)
4899 updateAnalysis(&call, getAnalysis(call.getArgOperand(1)), &call);
4900 return;
4901 }
4902 if (funcName == "julia.pointer_from_objref") {
4903 if (direction & UP)
4904 updateAnalysis(call.getArgOperand(0), getAnalysis(&call), &call);
4905 if (direction & DOWN)
4906 updateAnalysis(&call, getAnalysis(call.getArgOperand(0)), &call);
4907 return;
4908 }
4909 if (funcName == "_ZNSt6chrono3_V212steady_clock3nowEv") {
4910 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4911 return;
4912 }
4913
4914 /// MPI
4915 if (startsWith(funcName, "PMPI_"))
4916 funcName = funcName.substr(1);
4917 if (funcName == "MPI_Init") {
4918 TypeTree ptrint;
4919 ptrint.insert({-1}, BaseType::Pointer);
4920 ptrint.insert({-1, 0}, BaseType::Integer);
4921 updateAnalysis(call.getOperand(0), ptrint, &call);
4922 TypeTree ptrptrptr;
4923 ptrptrptr.insert({-1}, BaseType::Pointer);
4924 ptrptrptr.insert({-1, -1}, BaseType::Pointer);
4925 ptrptrptr.insert({-1, -1, 0}, BaseType::Pointer);
4926 updateAnalysis(call.getOperand(1), ptrptrptr, &call);
4927 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4928 return;
4929 }
4930 if (funcName == "MPI_Comm_size" || funcName == "MPI_Comm_rank" ||
4931 funcName == "MPI_Get_processor_name") {
4932 TypeTree ptrint;
4933 ptrint.insert({-1}, BaseType::Pointer);
4934 ptrint.insert({-1, 0}, BaseType::Integer);
4935 updateAnalysis(call.getOperand(1), ptrint, &call);
4936 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4937 return;
4938 }
4939 if (funcName == "MPI_Barrier" || funcName == "MPI_Finalize") {
4940 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4941 return;
4942 }
4943 if (funcName == "MPI_Send" || funcName == "MPI_Ssend" ||
4944 funcName == "MPI_Bsend" || funcName == "MPI_Recv" ||
4945 funcName == "MPI_Brecv" || funcName == "PMPI_Send" ||
4946 funcName == "PMPI_Ssend" || funcName == "PMPI_Bsend" ||
4947 funcName == "PMPI_Recv" || funcName == "PMPI_Brecv") {
4949
4950 if (Constant *C = dyn_cast<Constant>(call.getOperand(2))) {
4951 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
4952 C = CE->getOperand(0);
4953 }
4954 if (auto GV = dyn_cast<GlobalVariable>(C)) {
4955 if (GV->getName() == "ompi_mpi_double") {
4956 buf.insert({0}, Type::getDoubleTy(C->getContext()));
4957 } else if (GV->getName() == "ompi_mpi_float") {
4958 buf.insert({0}, Type::getFloatTy(C->getContext()));
4959 } else if (GV->getName() == "ompi_mpi_cxx_bool") {
4960 buf.insert({0}, BaseType::Integer);
4961 }
4962 } else if (auto CI = dyn_cast<ConstantInt>(C)) {
4963 // MPICH
4964 if (CI->getValue() == 1275070475) {
4965 buf.insert({0}, Type::getDoubleTy(C->getContext()));
4966 } else if (CI->getValue() == 1275069450) {
4967 buf.insert({0}, Type::getFloatTy(C->getContext()));
4968 }
4969 }
4970 }
4971 updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call);
4972 updateAnalysis(call.getOperand(1),
4973 TypeTree(BaseType::Integer).Only(-1, &call), &call);
4974 updateAnalysis(call.getOperand(3),
4975 TypeTree(BaseType::Integer).Only(-1, &call), &call);
4976 updateAnalysis(call.getOperand(4),
4977 TypeTree(BaseType::Integer).Only(-1, &call), &call);
4978 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
4979 return;
4980 }
4981 if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" ||
4982 funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") {
4984
4985 if (Constant *C = dyn_cast<Constant>(call.getOperand(2))) {
4986 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
4987 C = CE->getOperand(0);
4988 }
4989 if (auto GV = dyn_cast<GlobalVariable>(C)) {
4990 if (GV->getName() == "ompi_mpi_double") {
4991 buf.insert({0}, Type::getDoubleTy(C->getContext()));
4992 } else if (GV->getName() == "ompi_mpi_float") {
4993 buf.insert({0}, Type::getFloatTy(C->getContext()));
4994 } else if (GV->getName() == "ompi_mpi_cxx_bool") {
4995 buf.insert({0}, BaseType::Integer);
4996 }
4997 } else if (auto CI = dyn_cast<ConstantInt>(C)) {
4998 // MPICH
4999 if (CI->getValue() == 1275070475) {
5000 buf.insert({0}, Type::getDoubleTy(C->getContext()));
5001 } else if (CI->getValue() == 1275069450) {
5002 buf.insert({0}, Type::getFloatTy(C->getContext()));
5003 }
5004 }
5005 }
5006 updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call);
5007 updateAnalysis(call.getOperand(1),
5008 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5009 updateAnalysis(call.getOperand(3),
5010 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5011 updateAnalysis(call.getOperand(4),
5012 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5013 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5014 updateAnalysis(call.getOperand(6),
5015 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5016 return;
5017 }
5018 if (funcName == "MPI_Wait") {
5019 updateAnalysis(call.getOperand(0),
5020 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5021 updateAnalysis(call.getOperand(1),
5022 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5023 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5024 return;
5025 }
5026 if (funcName == "MPI_Waitany") {
5027 updateAnalysis(call.getOperand(0),
5028 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5029 updateAnalysis(call.getOperand(1),
5030 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5031 updateAnalysis(call.getOperand(2),
5032 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5033 updateAnalysis(call.getOperand(3),
5034 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5035 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5036 return;
5037 }
5038 if (funcName == "MPI_Waitall") {
5039 updateAnalysis(call.getOperand(0),
5040 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5041 updateAnalysis(call.getOperand(1),
5042 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5043 updateAnalysis(call.getOperand(2),
5044 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5045 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5046 return;
5047 }
5048 if (funcName == "MPI_Bcast") {
5049 updateAnalysis(call.getOperand(0),
5050 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5051 updateAnalysis(call.getOperand(1),
5052 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5053 updateAnalysis(call.getOperand(3),
5054 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5055 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5056 return;
5057 }
5058 if (funcName == "MPI_Reduce" || funcName == "PMPI_Reduce") {
5060
5061 if (Constant *C = dyn_cast<Constant>(call.getOperand(3))) {
5062 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
5063 C = CE->getOperand(0);
5064 }
5065 if (auto GV = dyn_cast<GlobalVariable>(C)) {
5066 if (GV->getName() == "ompi_mpi_double") {
5067 buf.insert({0}, Type::getDoubleTy(C->getContext()));
5068 } else if (GV->getName() == "ompi_mpi_float") {
5069 buf.insert({0}, Type::getFloatTy(C->getContext()));
5070 } else if (GV->getName() == "ompi_mpi_cxx_bool") {
5071 buf.insert({0}, BaseType::Integer);
5072 }
5073 } else if (auto CI = dyn_cast<ConstantInt>(C)) {
5074 // MPICH
5075 if (CI->getValue() == 1275070475) {
5076 buf.insert({0}, Type::getDoubleTy(C->getContext()));
5077 } else if (CI->getValue() == 1275069450) {
5078 buf.insert({0}, Type::getFloatTy(C->getContext()));
5079 }
5080 }
5081 }
5082 // int MPI_Reduce(const void *sendbuf, void *recvbuf, int count,
5083 // MPI_Datatype datatype,
5084 // MPI_Op op, int root, MPI_Comm comm)
5085 // sendbuf
5086 updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call);
5087 // recvbuf
5088 updateAnalysis(call.getOperand(1), buf.Only(-1, &call), &call);
5089 // count
5090 updateAnalysis(call.getOperand(2),
5091 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5092 // datatype
5093 // op
5094 // comm
5095 // result
5096 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5097 return;
5098 }
5099 if (funcName == "MPI_Allreduce" || funcName == "PMPI_Allreduce") {
5101
5102 if (Constant *C = dyn_cast<Constant>(call.getOperand(3))) {
5103 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
5104 C = CE->getOperand(0);
5105 }
5106 if (auto GV = dyn_cast<GlobalVariable>(C)) {
5107 if (GV->getName() == "ompi_mpi_double") {
5108 buf.insert({0}, Type::getDoubleTy(C->getContext()));
5109 } else if (GV->getName() == "ompi_mpi_float") {
5110 buf.insert({0}, Type::getFloatTy(C->getContext()));
5111 } else if (GV->getName() == "ompi_mpi_cxx_bool") {
5112 buf.insert({0}, BaseType::Integer);
5113 }
5114 } else if (auto CI = dyn_cast<ConstantInt>(C)) {
5115 // MPICH
5116 if (CI->getValue() == 1275070475) {
5117 buf.insert({0}, Type::getDoubleTy(C->getContext()));
5118 } else if (CI->getValue() == 1275069450) {
5119 buf.insert({0}, Type::getFloatTy(C->getContext()));
5120 }
5121 }
5122 }
5123 // int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
5124 // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
5125 // sendbuf
5126 updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call);
5127 // recvbuf
5128 updateAnalysis(call.getOperand(1), buf.Only(-1, &call), &call);
5129 // count
5130 updateAnalysis(call.getOperand(2),
5131 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5132 // datatype
5133 // op
5134 // comm
5135 // result
5136 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5137 return;
5138 }
5139 if (funcName == "MPI_Sendrecv_replace") {
5140 updateAnalysis(call.getOperand(0),
5141 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5142 updateAnalysis(call.getOperand(1),
5143 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5144 updateAnalysis(call.getOperand(3),
5145 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5146 updateAnalysis(call.getOperand(4),
5147 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5148 updateAnalysis(call.getOperand(5),
5149 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5150 updateAnalysis(call.getOperand(6),
5151 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5152 updateAnalysis(call.getOperand(8),
5153 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5154 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5155 return;
5156 }
5157 if (funcName == "MPI_Sendrecv") {
5158 updateAnalysis(call.getOperand(0),
5159 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5160 updateAnalysis(call.getOperand(1),
5161 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5162 updateAnalysis(call.getOperand(3),
5163 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5164 updateAnalysis(call.getOperand(4),
5165 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5166 updateAnalysis(call.getOperand(5),
5167 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5168 updateAnalysis(call.getOperand(6),
5169 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5170 updateAnalysis(call.getOperand(7),
5171 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5172 updateAnalysis(call.getOperand(8),
5173 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5174 updateAnalysis(call.getOperand(9),
5175 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5176 updateAnalysis(call.getOperand(11),
5177 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5178 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5179 return;
5180 }
5181 if (funcName == "MPI_Gather" || funcName == "MPI_Scatter") {
5182 updateAnalysis(call.getOperand(0),
5183 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5184 updateAnalysis(call.getOperand(1),
5185 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5186 updateAnalysis(call.getOperand(3),
5187 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5188 updateAnalysis(call.getOperand(4),
5189 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5190 updateAnalysis(call.getOperand(6),
5191 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5192 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5193 return;
5194 }
5195 if (funcName == "MPI_Allgather") {
5196 updateAnalysis(call.getOperand(0),
5197 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5198 updateAnalysis(call.getOperand(1),
5199 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5200 updateAnalysis(call.getOperand(3),
5201 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5202 updateAnalysis(call.getOperand(4),
5203 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5204 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5205 return;
5206 }
5207 /// END MPI
5208
5209 // Prob Prog
5210 if (ci->hasFnAttribute("enzyme_notypeanalysis")) {
5211 return;
5212 }
5213
5214 if (funcName == "memcpy" || funcName == "memmove") {
5215 // TODO have this call common mem transfer to copy data
5217 return;
5218 }
5219 if (funcName == "posix_memalign") {
5220 TypeTree ptrptr;
5221 ptrptr.insert({-1}, BaseType::Pointer);
5222 ptrptr.insert({-1, 0}, BaseType::Pointer);
5223 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5224 updateAnalysis(call.getOperand(0), ptrptr, &call);
5225 updateAnalysis(call.getOperand(1),
5226 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5227 updateAnalysis(call.getOperand(2),
5228 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5229 return;
5230 }
5231 if (funcName == "calloc") {
5232 updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5233 updateAnalysis(call.getOperand(0),
5234 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5235 updateAnalysis(call.getOperand(1),
5236 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5237 return;
5238 }
5239 if (auto opidx = getAllocationIndexFromCall(&call)) {
5240 auto ptr = TypeTree(BaseType::Pointer);
5241 unsigned index = (size_t)*opidx;
5242 if (auto CI = dyn_cast<ConstantInt>(call.getOperand(index))) {
5243 auto &DL = call.getParent()->getParent()->getParent()->getDataLayout();
5244 auto LoadSize = CI->getZExtValue();
5245 // Only propagate mappings in range that aren't "Anything" into the
5246 // pointer
5247 ptr |= getAnalysis(&call).Lookup(LoadSize, DL);
5248 }
5249 updateAnalysis(&call, ptr.Only(-1, &call), &call);
5250 updateAnalysis(call.getOperand(index),
5251 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5252 return;
5253 }
5254 if (funcName == "malloc") {
5255 auto ptr = TypeTree(BaseType::Pointer);
5256 if (auto CI = dyn_cast<ConstantInt>(call.getOperand(0))) {
5257 auto &DL = call.getParent()->getParent()->getParent()->getDataLayout();
5258 auto LoadSize = CI->getZExtValue();
5259 // Only propagate mappings in range that aren't "Anything" into the
5260 // pointer
5261 ptr |= getAnalysis(&call).Lookup(LoadSize, DL);
5262 }
5263 updateAnalysis(&call, ptr.Only(-1, &call), &call);
5264 updateAnalysis(call.getOperand(0),
5265 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5266 return;
5267 }
5268 if (funcName == "__size_returning_new_experiment") {
5269 auto ptr = TypeTree(BaseType::Pointer);
5270 auto &DL = call.getParent()->getParent()->getParent()->getDataLayout();
5271 if (auto CI = dyn_cast<ConstantInt>(call.getOperand(0))) {
5272 auto LoadSize = CI->getZExtValue();
5273 // Only propagate mappings in range that aren't "Anything" into the
5274 // pointer
5275 ptr |= getAnalysis(&call).Lookup(LoadSize, DL);
5276 }
5277 ptr = ptr.Only(0, &call);
5278 ptr |= TypeTree(BaseType::Integer).Only(DL.getPointerSize(), &call);
5279 updateAnalysis(&call, ptr.Only(0, &call), &call);
5280 updateAnalysis(call.getOperand(0),
5281 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5282 return;
5283 }
5284 if (funcName == "julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" ||
5285 funcName == "ijl_gc_alloc_typed") {
5286 auto ptr = TypeTree(BaseType::Pointer);
5287 if (auto CI = dyn_cast<ConstantInt>(call.getOperand(1))) {
5288 auto &DL = call.getParent()->getParent()->getParent()->getDataLayout();
5289 auto LoadSize = CI->getZExtValue();
5290 // Only propagate mappings in range that aren't "Anything" into the
5291 // pointer
5292 ptr |= getAnalysis(&call).Lookup(LoadSize, DL);
5293 }
5294 updateAnalysis(&call, ptr.Only(-1, &call), &call);
5295 updateAnalysis(call.getOperand(1),
5296 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5297 return;
5298 }
5299 if (funcName == "julia.except_enter" || funcName == "ijl_excstack_state" ||
5300 funcName == "jl_excstack_state") {
5301 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5302 return;
5303 }
5304 if (funcName == "jl_array_copy" || funcName == "ijl_array_copy" ||
5305 funcName == "jl_inactive_inout" ||
5306 funcName == "jl_genericmemory_copy_slice" ||
5307 funcName == "ijl_genericmemory_copy_slice") {
5308 if (direction & DOWN)
5309 updateAnalysis(&call, getAnalysis(call.getOperand(0)), &call);
5310 if (direction & UP)
5311 updateAnalysis(call.getOperand(0), getAnalysis(&call), &call);
5312 return;
5313 }
5314
5315 if (isAllocationFunction(funcName, TLI)) {
5316 size_t Idx = 0;
5317 for (auto &Arg : ci->args()) {
5318 if (Arg.getType()->isIntegerTy()) {
5319 updateAnalysis(call.getOperand(Idx),
5320 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5321 }
5322 Idx++;
5323 }
5324 assert(ci->getReturnType()->isPointerTy());
5325 updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5326 return;
5327 }
5328 if (funcName == "malloc_usable_size" || funcName == "malloc_size" ||
5329 funcName == "_msize") {
5330 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5331 updateAnalysis(call.getOperand(0),
5332 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5333 return;
5334 }
5335 if (funcName == "realloc") {
5336 size_t sz = 1;
5337 for (auto val : fntypeinfo.knownIntegralValues(call.getArgOperand(1), DT,
5338 intseen, SE)) {
5339 if (val >= 0) {
5340 sz = max(sz, (size_t)val);
5341 }
5342 }
5343
5344 auto &dl = call.getParent()->getParent()->getParent()->getDataLayout();
5345 TypeTree res = getAnalysis(call.getArgOperand(0))
5346 .PurgeAnything()
5347 .Data0()
5348 .ShiftIndices(dl, 0, sz, 0);
5349 TypeTree res2 =
5350 getAnalysis(&call).PurgeAnything().Data0().ShiftIndices(dl, 0, sz, 0);
5351
5352 res.orIn(res2, /*PointerIntSame*/ false);
5353 res.insert({}, BaseType::Pointer);
5354 res = res.Only(-1, &call);
5355 if (direction & DOWN) {
5356 updateAnalysis(&call, res, &call);
5357 }
5358 if (direction & UP) {
5359 updateAnalysis(call.getOperand(0), res, &call);
5360 }
5361 return;
5362 }
5363 if (funcName == "sigaction") {
5364 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5365 updateAnalysis(call.getOperand(0),
5366 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5367 updateAnalysis(call.getOperand(1),
5368 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5369 updateAnalysis(call.getOperand(2),
5370 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5371 return;
5372 }
5373 if (funcName == "mmap") {
5374 updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5375 updateAnalysis(call.getOperand(0),
5376 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5377 updateAnalysis(call.getOperand(1),
5378 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5379 updateAnalysis(call.getOperand(2),
5380 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5381 updateAnalysis(call.getOperand(3),
5382 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5383 updateAnalysis(call.getOperand(4),
5384 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5385 updateAnalysis(call.getOperand(5),
5386 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5387 return;
5388 }
5389 if (funcName == "munmap") {
5390 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5391 updateAnalysis(call.getOperand(0),
5392 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5393 updateAnalysis(call.getOperand(1),
5394 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5395 return;
5396 }
5397 if (funcName == "pthread_mutex_lock" ||
5398 funcName == "pthread_mutex_trylock" ||
5399 funcName == "pthread_rwlock_rdlock" ||
5400 funcName == "pthread_rwlock_unlock" ||
5401 funcName == "pthread_attr_init" || funcName == "pthread_attr_destroy" ||
5402 funcName == "pthread_rwlock_unlock" ||
5403 funcName == "pthread_mutex_unlock") {
5404 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5405 updateAnalysis(call.getOperand(0),
5406 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5407 return;
5408 }
5409 if (isDeallocationFunction(funcName, TLI)) {
5410 size_t Idx = 0;
5411 for (auto &Arg : ci->args()) {
5412 if (Arg.getType()->isIntegerTy()) {
5413 updateAnalysis(call.getOperand(Idx),
5414 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5415 }
5416 if (Arg.getType()->isPointerTy()) {
5417 updateAnalysis(call.getOperand(Idx),
5418 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5419 }
5420 Idx++;
5421 }
5422 if (!ci->getReturnType()->isVoidTy()) {
5423 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call),
5424 &call);
5425 return;
5426 }
5427 assert(ci->getReturnType()->isVoidTy());
5428 return;
5429 }
5430 if (funcName == "memchr" || funcName == "memrchr") {
5431 updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5432 updateAnalysis(call.getOperand(0),
5433 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5434 updateAnalysis(call.getOperand(2),
5435 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5436 return;
5437 }
5438 if (funcName == "strlen") {
5439 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5440 updateAnalysis(call.getOperand(0),
5441 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5442 return;
5443 }
5444 if (funcName == "strcmp") {
5445 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5446 updateAnalysis(call.getOperand(0),
5447 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5448 updateAnalysis(call.getOperand(1),
5449 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5450 return;
5451 }
5452 if (funcName == "bcmp") {
5453 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5454 updateAnalysis(call.getOperand(0),
5455 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5456 updateAnalysis(call.getOperand(1),
5457 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5458 updateAnalysis(call.getOperand(2),
5459 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5460 return;
5461 }
5462 if (funcName == "getcwd") {
5463 updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5464 updateAnalysis(call.getOperand(0),
5465 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5466 updateAnalysis(call.getOperand(1),
5467 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5468 return;
5469 }
5470 if (funcName == "sysconf") {
5471 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5472 updateAnalysis(call.getOperand(0),
5473 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5474 return;
5475 }
5476 if (funcName == "dladdr") {
5477 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5478 updateAnalysis(call.getOperand(0),
5479 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5480 updateAnalysis(call.getOperand(1),
5481 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5482 return;
5483 }
5484 if (funcName == "__errno_location") {
5485 TypeTree ptrint;
5486 ptrint.insert({-1, -1}, BaseType::Integer);
5487 ptrint.insert({-1}, BaseType::Pointer);
5488 updateAnalysis(&call, ptrint, &call);
5489 return;
5490 }
5491 if (funcName == "getenv") {
5492 TypeTree ptrint;
5493 ptrint.insert({-1, -1}, BaseType::Integer);
5494 ptrint.insert({-1}, BaseType::Pointer);
5495 updateAnalysis(&call, ptrint, &call);
5496 updateAnalysis(call.getOperand(0),
5497 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5498 return;
5499 }
5500 if (funcName == "getcwd") {
5501 updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5502 updateAnalysis(call.getOperand(0),
5503 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5504 updateAnalysis(call.getOperand(1),
5505 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5506 return;
5507 }
5508 if (funcName == "mprotect") {
5509 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5510 updateAnalysis(call.getOperand(0),
5511 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5512 updateAnalysis(call.getOperand(1),
5513 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5514 updateAnalysis(call.getOperand(2),
5515 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5516 return;
5517 }
5518 if (funcName == "memcmp") {
5519 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5520 updateAnalysis(call.getOperand(0),
5521 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5522 updateAnalysis(call.getOperand(1),
5523 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5524 updateAnalysis(call.getOperand(2),
5525 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5526 return;
5527 }
5528 if (funcName == "signal") {
5529 updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5530 updateAnalysis(call.getOperand(0),
5531 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5532 updateAnalysis(call.getOperand(1),
5533 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5534 return;
5535 }
5536 if (funcName == "write" || funcName == "read" || funcName == "writev" ||
5537 funcName == "readv") {
5538 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5539 // FD type not going to be defined here
5540 // updateAnalysis(call.getOperand(0),
5541 // TypeTree(BaseType::Pointer).Only(-1),
5542 // &call);
5543 updateAnalysis(call.getOperand(1),
5544 TypeTree(BaseType::Pointer).Only(-1, &call), &call);
5545 updateAnalysis(call.getOperand(2),
5546 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5547 return;
5548 }
5549 if (funcName == "gsl_sf_legendre_array_e") {
5550 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5551 return;
5552 }
5553
5554 // CONSIDER(__lgamma_r_finite)
5555
5556 CONSIDER2(frexp, double, double, int *)
5557 CONSIDER(frexpf)
5558 CONSIDER(frexpl)
5559 CONSIDER2(ldexp, double, double, int)
5560 CONSIDER2(modf, double, double, double *)
5561 CONSIDER(modff)
5562 CONSIDER(modfl)
5563
5564 CONSIDER2(remquo, double, double, double, int *)
5565 CONSIDER(remquof)
5566 CONSIDER(remquol)
5567
5568 if (isMemFreeLibMFunction(funcName)) {
5569#if LLVM_VERSION_MAJOR >= 14
5570 for (size_t i = 0; i < call.arg_size(); ++i)
5571#else
5572 for (size_t i = 0; i < call.getNumArgOperands(); ++i)
5573#endif
5574 {
5575 Type *T = call.getArgOperand(i)->getType();
5576 if (T->isFloatingPointTy()) {
5578 call.getArgOperand(i),
5580 call.getArgOperand(i)->getType()->getScalarType()))
5581 .Only(-1, &call),
5582 &call);
5583 } else if (T->isIntegerTy()) {
5584 updateAnalysis(call.getArgOperand(i),
5585 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5586 } else if (auto ST = dyn_cast<StructType>(T)) {
5587 assert(ST->getNumElements() >= 1);
5588 for (size_t i = 1; i < ST->getNumElements(); ++i) {
5589 assert(ST->getTypeAtIndex((unsigned)0) == ST->getTypeAtIndex(i));
5590 }
5591 if (ST->getTypeAtIndex((unsigned)0)->isFloatingPointTy())
5593 call.getArgOperand(i),
5595 ST->getTypeAtIndex((unsigned)0)->getScalarType()))
5596 .Only(-1, &call),
5597 &call);
5598 else if (ST->getTypeAtIndex((unsigned)0)->isIntegerTy()) {
5599 updateAnalysis(call.getArgOperand(i),
5600 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5601 } else {
5602 llvm::errs() << *T << " - " << call << "\n";
5603 llvm_unreachable("Unknown type for libm");
5604 }
5605 } else if (auto AT = dyn_cast<ArrayType>(T)) {
5606 assert(AT->getNumElements() >= 1);
5607 if (AT->getElementType()->isFloatingPointTy())
5609 call.getArgOperand(i),
5610 TypeTree(ConcreteType(AT->getElementType()->getScalarType()))
5611 .Only(-1, &call),
5612 &call);
5613 else if (AT->getElementType()->isIntegerTy()) {
5614 updateAnalysis(call.getArgOperand(i),
5615 TypeTree(BaseType::Integer).Only(-1, &call), &call);
5616 } else {
5617 llvm::errs() << *T << " - " << call << "\n";
5618 llvm_unreachable("Unknown type for libm");
5619 }
5620 } else {
5621 llvm::errs() << *T << " - " << call << "\n";
5622 llvm_unreachable("Unknown type for libm");
5623 }
5624 }
5625 Type *T = call.getType();
5626 if (T->isFloatingPointTy()) {
5627 updateAnalysis(&call,
5628 TypeTree(ConcreteType(call.getType()->getScalarType()))
5629 .Only(-1, &call),
5630 &call);
5631 } else if (T->isIntegerTy()) {
5632 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call),
5633 &call);
5634 } else if (T->isVoidTy()) {
5635 } else if (auto ST = dyn_cast<StructType>(T)) {
5636 assert(ST->getNumElements() >= 1);
5637 TypeTree TT;
5638 auto &DL = call.getParent()->getParent()->getParent()->getDataLayout();
5639 for (size_t i = 0; i < ST->getNumElements(); ++i) {
5640 auto T = ST->getTypeAtIndex(i);
5642
5643 Value *vec[2] = {
5644 ConstantInt::get(Type::getInt64Ty(call.getContext()), 0),
5645 ConstantInt::get(Type::getInt32Ty(call.getContext()), i)};
5646 auto ud = UndefValue::get(getUnqual(ST));
5647 auto g2 = GetElementPtrInst::Create(ST, ud, vec);
5648 APInt ai(DL.getIndexSizeInBits(0), 0);
5649 g2->accumulateConstantOffset(DL, ai);
5650 delete g2;
5651 size_t Offset = ai.getZExtValue();
5652
5653 size_t nextOffset;
5654 if (i + 1 == ST->getNumElements())
5655 nextOffset = (DL.getTypeSizeInBits(ST) + 7) / 8;
5656 else {
5657 Value *vec[2] = {
5658 ConstantInt::get(Type::getInt64Ty(call.getContext()), 0),
5659 ConstantInt::get(Type::getInt32Ty(call.getContext()), i + 1)};
5660 auto ud = UndefValue::get(getUnqual(ST));
5661 auto g2 = GetElementPtrInst::Create(ST, ud, vec);
5662 APInt ai(DL.getIndexSizeInBits(0), 0);
5663 g2->accumulateConstantOffset(DL, ai);
5664 delete g2;
5665 nextOffset = ai.getZExtValue();
5666 }
5667
5668 if (T->isFloatingPointTy()) {
5669 CT = T;
5670 } else if (T->isIntegerTy()) {
5671 CT = BaseType::Integer;
5672 }
5673 if (CT != BaseType::Unknown) {
5674 TypeTree mid = TypeTree(CT).Only(-1, &call);
5675 TT |= mid.ShiftIndices(DL, /*init offset*/ 0,
5676 /*maxSize*/ nextOffset - Offset,
5677 /*addOffset*/ Offset);
5678 }
5679 }
5680 auto Size = (DL.getTypeSizeInBits(ST) + 7) / 8;
5681 TT.CanonicalizeInPlace(Size, DL);
5682 updateAnalysis(&call, TT, &call);
5683 } else if (auto AT = dyn_cast<ArrayType>(T)) {
5684 assert(AT->getNumElements() >= 1);
5685 if (AT->getElementType()->isFloatingPointTy())
5687 &call,
5688 TypeTree(ConcreteType(AT->getElementType()->getScalarType()))
5689 .Only(-1, &call),
5690 &call);
5691 else {
5692 llvm::errs() << *T << " - " << call << "\n";
5693 llvm_unreachable("Unknown type for libm");
5694 }
5695 } else {
5696 llvm::errs() << *T << " - " << call << "\n";
5697 llvm_unreachable("Unknown type for libm");
5698 }
5699 return;
5700 }
5701 if (funcName == "__lgamma_r_finite") {
5703 call.getArgOperand(0),
5704 TypeTree(ConcreteType(Type::getDoubleTy(call.getContext())))
5705 .Only(-1, &call),
5706 &call);
5707 updateAnalysis(call.getArgOperand(1),
5708 TypeTree(BaseType::Integer).Only(0, &call).Only(-1, &call),
5709 &call);
5711 &call,
5712 TypeTree(ConcreteType(Type::getDoubleTy(call.getContext())))
5713 .Only(-1, &call),
5714 &call);
5715 }
5716 if (funcName == "__fd_sincos_1" || funcName == "__fd_sincos_1f" ||
5717 funcName == "__fd_sincos_1l") {
5718 updateAnalysis(call.getArgOperand(0),
5719 TypeTree(ConcreteType(call.getArgOperand(0)->getType()))
5720 .Only(-1, &call),
5721 &call);
5722 updateAnalysis(&call,
5723 TypeTree(ConcreteType(call.getArgOperand(0)->getType()))
5724 .Only(-1, &call),
5725 &call);
5726 }
5727 if (funcName == "frexp" || funcName == "frexpf" || funcName == "frexpl") {
5728
5730 &call, TypeTree(ConcreteType(call.getType())).Only(-1, &call), &call);
5731 updateAnalysis(call.getOperand(0),
5732 TypeTree(ConcreteType(call.getType())).Only(-1, &call),
5733 &call);
5735 size_t objSize = 1;
5736
5737#if LLVM_VERSION_MAJOR < 17
5738 auto &DL = fntypeinfo.Function->getParent()->getDataLayout();
5739 objSize = DL.getTypeSizeInBits(
5740 call.getOperand(1)->getType()->getPointerElementType()) /
5741 8;
5742#endif
5743 for (size_t i = 0; i < objSize; ++i) {
5744 ival.insert({(int)i}, BaseType::Integer);
5745 }
5746 updateAnalysis(call.getOperand(1), ival.Only(-1, &call), &call);
5747 return;
5748 }
5749
5750 if (funcName == "__cxa_guard_acquire" || funcName == "printf" ||
5751 funcName == "vprintf" || funcName == "puts" || funcName == "fputc" ||
5752 funcName == "fprintf") {
5753 updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5754 }
5755
5756 if (dontAnalyze(funcName))
5757 return;
5758
5759 if (!ci->empty() && !hasMetadata(ci, "enzyme_gradient") &&
5760 !hasMetadata(ci, "enzyme_derivative")) {
5761 visitIPOCall(call, *ci);
5762 }
5763 }
5764}
5765
5767 bool set = false;
5768 TypeTree vd;
5769 for (BasicBlock &BB : *fntypeinfo.Function) {
5770 for (auto &inst : BB) {
5771 if (auto ri = dyn_cast<ReturnInst>(&inst)) {
5772 if (auto rv = ri->getReturnValue()) {
5773 if (set == false) {
5774 set = true;
5775 vd = getAnalysis(rv);
5776 continue;
5777 }
5778 vd &= getAnalysis(rv);
5779 // TODO insert the selectinst anything propagation here
5780 // however this needs to be done simultaneously with preventing
5781 // anything from propagating up through the return value (if there
5782 // are multiple possible returns)
5783 }
5784 }
5785 }
5786 }
5787 return vd;
5788}
5789
5790/// Helper function that calculates whether a given value must only be
5791/// an integer and cannot be cast/stored to be used as a ptr/integer
5792bool TypeAnalyzer::mustRemainInteger(Value *val, bool *returned) {
5793 std::map<Value *, std::pair<bool, bool>> &seen = mriseen;
5794 const DataLayout &DL = fntypeinfo.Function->getParent()->getDataLayout();
5795 if (seen.find(val) != seen.end()) {
5796 if (returned)
5797 *returned |= seen[val].second;
5798 return seen[val].first;
5799 }
5800 seen[val] = std::make_pair(true, false);
5801 for (auto u : val->users()) {
5802 if (auto SI = dyn_cast<StoreInst>(u)) {
5803 if (parseTBAA(*SI, DL, MST).Inner0().isIntegral())
5804 continue;
5805 seen[val].first = false;
5806 continue;
5807 }
5808 if (isa<CastInst>(u)) {
5809 if (!u->getType()->isIntOrIntVectorTy()) {
5810 seen[val].first = false;
5811 continue;
5812 } else if (!mustRemainInteger(u, returned)) {
5813 seen[val].first = false;
5814 seen[val].second |= seen[u].second;
5815 continue;
5816 } else
5817 continue;
5818 }
5819 if (isa<BinaryOperator>(u) || isa<IntrinsicInst>(u) || isa<PHINode>(u) ||
5820#if LLVM_VERSION_MAJOR <= 17
5821 isa<UDivOperator>(u) || isa<SDivOperator>(u) ||
5822#endif
5823 isa<LShrOperator>(u) || isa<AShrOperator>(u) || isa<AddOperator>(u) ||
5824 isa<MulOperator>(u) || isa<ShlOperator>(u)) {
5825 if (!mustRemainInteger(u, returned)) {
5826 seen[val].first = false;
5827 seen[val].second |= seen[u].second;
5828 }
5829 continue;
5830 }
5831 if (auto gep = dyn_cast<GetElementPtrInst>(u)) {
5832 if (gep->isInBounds() && gep->getPointerOperand() != val) {
5833 continue;
5834 }
5835 }
5836 if (returned && isa<ReturnInst>(u)) {
5837 *returned = true;
5838 seen[val].second = true;
5839 continue;
5840 }
5841 if (auto CI = dyn_cast<CallBase>(u)) {
5842 if (auto F = CI->getCalledFunction()) {
5843 if (!F->empty()) {
5844 int argnum = 0;
5845 bool subreturned = false;
5846 for (auto &arg : F->args()) {
5847 if (CI->getArgOperand(argnum) == val &&
5848 !mustRemainInteger(&arg, &subreturned)) {
5849 seen[val].first = false;
5850 seen[val].second |= seen[&arg].second;
5851 continue;
5852 }
5853 ++argnum;
5854 }
5855 if (subreturned && !mustRemainInteger(CI, returned)) {
5856 seen[val].first = false;
5857 seen[val].second |= seen[CI].second;
5858 continue;
5859 }
5860 continue;
5861 }
5862 }
5863 }
5864 if (isa<CmpInst>(u))
5865 continue;
5866 seen[val].first = false;
5867 seen[val].second = true;
5868 }
5869 if (returned && seen[val].second)
5870 *returned = true;
5871 return seen[val].first;
5872}
5873
5874FnTypeInfo TypeAnalyzer::getCallInfo(CallBase &call, Function &fn) {
5875 FnTypeInfo typeInfo(&fn);
5876
5877 size_t argnum = 0;
5878 for (auto &arg : fn.args()) {
5879 if (argnum >= call.arg_size()) {
5880 typeInfo.Arguments.insert(
5881 std::pair<Argument *, TypeTree>(&arg, TypeTree()));
5882 std::set<int64_t> bounded;
5883 typeInfo.KnownValues.insert(
5884 std::pair<Argument *, std::set<int64_t>>(&arg, bounded));
5885 ++argnum;
5886 continue;
5887 }
5888 auto dt = getAnalysis(call.getArgOperand(argnum));
5889 if (arg.getType()->isIntOrIntVectorTy() &&
5890 dt.Inner0() == BaseType::Anything) {
5891 if (mustRemainInteger(&arg)) {
5892 dt = TypeTree(BaseType::Integer).Only(-1, &call);
5893 }
5894 }
5895 typeInfo.Arguments.insert(std::pair<Argument *, TypeTree>(&arg, dt));
5896 std::set<int64_t> bounded;
5897 for (auto v : fntypeinfo.knownIntegralValues(call.getArgOperand(argnum), DT,
5898 intseen, SE)) {
5899 if (abs(v) > MaxIntOffset)
5900 continue;
5901 bounded.insert(v);
5902 }
5903 typeInfo.KnownValues.insert(
5904 std::pair<Argument *, std::set<int64_t>>(&arg, bounded));
5905 ++argnum;
5906 }
5907
5908 typeInfo.Return = getAnalysis(&call);
5909 return typeInfo;
5910}
5911
5912void TypeAnalyzer::visitIPOCall(CallBase &call, Function &fn) {
5913#if LLVM_VERSION_MAJOR >= 14
5914 if (call.arg_size() != fn.getFunctionType()->getNumParams())
5915 return;
5916#else
5917 if (call.getNumArgOperands() != fn.getFunctionType()->getNumParams())
5918 return;
5919#endif
5920
5921 assert(fntypeinfo.KnownValues.size() ==
5922 fntypeinfo.Function->getFunctionType()->getNumParams());
5923
5924 bool hasDown = direction & DOWN;
5925 bool hasUp = direction & UP;
5926
5927 if (hasDown) {
5928 if (call.getType()->isVoidTy())
5929 hasDown = false;
5930 else {
5931 if (getAnalysis(&call).IsFullyDetermined())
5932 hasDown = false;
5933 }
5934 }
5935 if (hasUp) {
5936 bool unknown = false;
5937#if LLVM_VERSION_MAJOR >= 14
5938 for (auto &arg : call.args())
5939#else
5940 for (auto &arg : call.arg_operands())
5941#endif
5942 {
5943 if (isa<ConstantData>(arg))
5944 continue;
5945 if (!getAnalysis(arg).IsFullyDetermined()) {
5946 unknown = true;
5947 break;
5948 }
5949 }
5950 if (!unknown)
5951 hasUp = false;
5952 }
5953
5954 // Fast path where all information has already been derived
5955 if (!hasUp && !hasDown)
5956 return;
5957
5958 FnTypeInfo typeInfo = getCallInfo(call, fn);
5959 typeInfo = preventTypeAnalysisLoops(typeInfo, call.getParent()->getParent());
5960
5961 if (EnzymePrintType) {
5962 llvm::errs() << " starting IPO of ";
5963 call.print(llvm::errs(), *MST);
5964 llvm::errs() << "\n";
5965 }
5966
5968
5969 if (EnzymePrintType) {
5970 llvm::errs() << " ending IPO of ";
5971 call.print(llvm::errs(), *MST);
5972 llvm::errs() << "\n";
5973 }
5974
5975 if (hasUp) {
5976 auto a = fn.arg_begin();
5977#if LLVM_VERSION_MAJOR >= 14
5978 for (auto &arg : call.args())
5979#else
5980 for (auto &arg : call.arg_operands())
5981#endif
5982 {
5983 auto dt = STR.query(a);
5984 if (EnzymePrintType) {
5985 llvm::errs() << " updating ";
5986 arg->print(llvm::errs(), *MST);
5987 llvm::errs() << " = " << dt.str() << " via IPO of ";
5988 call.print(llvm::errs(), *MST);
5989 llvm::errs() << " arg ";
5990 a->print(llvm::errs(), *MST);
5991 llvm::errs() << "\n";
5992 }
5993 updateAnalysis(arg, dt, &call);
5994 ++a;
5995 }
5996 }
5997
5998 if (hasDown) {
5999 TypeTree vd = STR.getReturnAnalysis();
6000 if (call.getType()->isIntOrIntVectorTy() &&
6001 vd.Inner0() == BaseType::Anything) {
6002 bool returned = false;
6003 if (mustRemainInteger(&call, &returned) && !returned) {
6004 vd = TypeTree(BaseType::Integer).Only(-1, &call);
6005 }
6006 }
6007 updateAnalysis(&call, vd, &call);
6008 }
6009}
6010
6012 assert(fn.KnownValues.size() ==
6013 fn.Function->getFunctionType()->getNumParams());
6014 assert(fn.Function);
6015 auto found = analyzedFunctions.find(fn);
6016 if (found != analyzedFunctions.end()) {
6017 auto &analysis = *found->second;
6018 if (analysis.fntypeinfo.Function != fn.Function) {
6019 llvm::errs() << " queryFunc: " << *fn.Function << "\n";
6020 llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.Function
6021 << "\n";
6022 }
6023 assert(analysis.fntypeinfo.Function == fn.Function);
6024
6025 return TypeResults(analysis);
6026 }
6027
6028 if (fn.Function->empty())
6029 return TypeResults(nullptr);
6030
6031 auto res = analyzedFunctions.emplace(fn, new TypeAnalyzer(fn, *this));
6032 auto &analysis = *res.first->second;
6033
6034 if (EnzymePrintType) {
6035 llvm::errs() << "analyzing function " << fn.Function->getName() << "\n";
6036 for (auto &pair : fn.Arguments) {
6037 llvm::errs() << " + knowndata: ";
6038 pair.first->print(llvm::errs(), *analysis.MST);
6039 llvm::errs() << " : " << pair.second.str();
6040 auto found = fn.KnownValues.find(pair.first);
6041 if (found != fn.KnownValues.end()) {
6042 llvm::errs() << " - " << to_string(found->second);
6043 }
6044 llvm::errs() << "\n";
6045 }
6046 llvm::errs() << " + retdata: " << fn.Return.str() << "\n";
6047 }
6048
6049 analysis.prepareArgs();
6050 if (RustTypeRules) {
6051 analysis.considerRustDebugInfo();
6052 }
6053 analysis.considerTBAA();
6054 analysis.run();
6055
6056 if (analysis.fntypeinfo.Function != fn.Function) {
6057 llvm::errs() << " queryFunc: " << *fn.Function << "\n";
6058 llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.Function << "\n";
6059 }
6060 assert(analysis.fntypeinfo.Function == fn.Function);
6061
6062 {
6063 auto &analysis = *analyzedFunctions.find(fn)->second;
6064 if (analysis.fntypeinfo.Function != fn.Function) {
6065 llvm::errs() << " queryFunc: " << *fn.Function << "\n";
6066 llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.Function
6067 << "\n";
6068 }
6069 assert(analysis.fntypeinfo.Function == fn.Function);
6070 }
6071
6072 // Store the steady state result (if changed) to avoid
6073 // a second analysis later.
6074 analyzedFunctions.emplace(TypeResults(analysis).getAnalyzedTypeInfo(),
6075 res.first->second);
6076
6077 return TypeResults(analysis);
6078}
6079
6080TypeResults::TypeResults(TypeAnalyzer &analyzer) : analyzer(&analyzer) {}
6081TypeResults::TypeResults(std::nullptr_t) : analyzer(nullptr) {}
6082
6085 for (auto &arg : analyzer->fntypeinfo.Function->args()) {
6086 res.Arguments.insert(std::pair<Argument *, TypeTree>(&arg, query(&arg)));
6087 }
6088 res.Return = getReturnAnalysis();
6090 return res;
6091}
6092
6093FnTypeInfo TypeResults::getCallInfo(CallBase &CI, Function &fn) const {
6094 return analyzer->getCallInfo(CI, fn);
6095}
6096
6097TypeTree TypeResults::query(Value *val) const {
6098#ifndef NDEBUG
6099 if (auto inst = dyn_cast<Instruction>(val)) {
6100 assert(inst->getParent()->getParent() == analyzer->fntypeinfo.Function);
6101 }
6102 if (auto arg = dyn_cast<Argument>(val)) {
6103 assert(arg->getParent() == analyzer->fntypeinfo.Function);
6104 }
6105#endif
6106 return analyzer->getAnalysis(val);
6107}
6108
6109// Returns last non-padding/alignment location of the corresponding subtype T.
6110size_t skippedBytes(SmallSet<size_t, 8> &offs, Type *T, const DataLayout &DL,
6111 size_t offset = 0) {
6112 auto ST = dyn_cast<StructType>(T);
6113 if (!ST)
6114 return (DL.getTypeSizeInBits(T) + 7) / 8;
6115
6116 auto SL = DL.getStructLayout(ST);
6117 size_t prevOff = 0;
6118 for (size_t idx = 0; idx < ST->getNumElements(); idx++) {
6119 auto off = SL->getElementOffset(idx);
6120 if (off > prevOff)
6121 for (size_t i = prevOff; i < off; i++)
6122 offs.insert(offset + i);
6123 size_t subSize = skippedBytes(offs, ST->getElementType(idx), DL, prevOff);
6124 prevOff = off + subSize;
6125 }
6126 return prevOff;
6127}
6128
6129bool TypeResults::allFloat(Value *val) const {
6130 assert(val);
6131 assert(val->getType());
6132 auto q = query(val);
6133 auto dt = q[{-1}];
6134 if (dt != BaseType::Anything && dt != BaseType::Unknown)
6135 return dt.isFloat();
6136
6137 if (val->getType()->isTokenTy() || val->getType()->isVoidTy())
6138 return false;
6139 auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
6140 SmallSet<size_t, 8> offs;
6141 size_t ObjSize = skippedBytes(offs, val->getType(), dl);
6142
6143 for (size_t i = 0; i < ObjSize;) {
6144 dt = q[{(int)i}];
6145 if (auto FT = dt.isFloat()) {
6146 i += (dl.getTypeSizeInBits(FT) + 7) / 8;
6147 continue;
6148 }
6149 if (offs.count(i)) {
6150 i++;
6151 continue;
6152 }
6153 return false;
6154 }
6155 return true;
6156}
6157
6158bool TypeResults::anyFloat(Value *val, bool anythingIsFloat) const {
6159 assert(val);
6160 assert(val->getType());
6161 auto q = query(val);
6162 auto dt = q[{-1}];
6163 if (!anythingIsFloat && dt == BaseType::Anything)
6164 return false;
6165 if (dt != BaseType::Anything && dt != BaseType::Unknown)
6166 return dt.isFloat();
6167
6168 if (val->getType()->isTokenTy() || val->getType()->isVoidTy())
6169 return false;
6170 auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
6171 SmallSet<size_t, 8> offs;
6172 size_t ObjSize = skippedBytes(offs, val->getType(), dl);
6173
6174 for (size_t i = 0; i < ObjSize;) {
6175 dt = q[{(int)i}];
6176 if (dt == BaseType::Integer) {
6177 i++;
6178 continue;
6179 }
6180 if (!anythingIsFloat && dt == BaseType::Integer) {
6181 i++;
6182 continue;
6183 }
6184 if (dt == BaseType::Pointer) {
6185 i += dl.getPointerSize(0);
6186 continue;
6187 }
6188 if (offs.count(i)) {
6189 i++;
6190 continue;
6191 }
6192 return true;
6193 }
6194 return false;
6195}
6196
6197bool TypeResults::anyPointer(Value *val) const {
6198 assert(val);
6199 assert(val->getType());
6200 auto q = query(val);
6201 auto dt = q[{-1}];
6202 if (dt != BaseType::Anything && dt != BaseType::Unknown)
6203 return dt == BaseType::Pointer;
6204 if (val->getType()->isTokenTy() || val->getType()->isVoidTy())
6205 return false;
6206
6207 auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
6208 SmallSet<size_t, 8> offs;
6209 size_t ObjSize = skippedBytes(offs, val->getType(), dl);
6210
6211 for (size_t i = 0; i < ObjSize;) {
6212 dt = q[{(int)i}];
6213 if (dt == BaseType::Integer) {
6214 i++;
6215 continue;
6216 }
6217 if (auto FT = dt.isFloat()) {
6218 i += (dl.getTypeSizeInBits(FT) + 7) / 8;
6219 continue;
6220 }
6221 if (offs.count(i)) {
6222 i++;
6223 continue;
6224 }
6225 return true;
6226 }
6227 return false;
6228}
6229
6230void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer->dump(ss); }
6231
6232ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound,
6233 bool pointerIntSame) const {
6234 assert(val);
6235 assert(val->getType());
6236 auto q = query(val);
6237 auto dt = q[{0}];
6238 /*
6239 size_t ObjSize = 1;
6240 if (val->getType()->isSized())
6241 ObjSize = (fn.Function->getParent()->getDataLayout().getTypeSizeInBits(
6242 val->getType()) +7) / 8;
6243 */
6244 dt.orIn(q[{-1}], pointerIntSame);
6245 for (size_t i = 1; i < num; ++i) {
6246 dt.orIn(q[{(int)i}], pointerIntSame);
6247 }
6248
6249 if (errIfNotFound && (!dt.isKnown() || dt == BaseType::Anything)) {
6250 if (auto inst = dyn_cast<Instruction>(val)) {
6251 llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n";
6252 llvm::errs() << *inst->getParent()->getParent() << "\n";
6253 for (auto &pair : analyzer->analysis) {
6254 llvm::errs() << "val: " << *pair.first << " - " << pair.second.str()
6255 << "\n";
6256 }
6257 }
6258 llvm::errs() << "could not deduce type of integer " << *val << "\n";
6259 assert(0 && "could not deduce type of integer");
6260 }
6261 return dt;
6262}
6263
6264Type *TypeResults::addingType(size_t num, Value *val, size_t start) const {
6265 assert(val);
6266 assert(val->getType());
6267 auto q = query(val);
6268 Type *ty = q[{-1}].isFloat();
6269 for (size_t i = start; i < num; ++i) {
6270 auto ty2 = q[{(int)i}].isFloat();
6271 if (ty) {
6272 if (ty2)
6273 assert(ty == ty2);
6274 } else {
6275 ty = ty2;
6276 }
6277 }
6278 return ty;
6279}
6280
6281ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I,
6282 bool errIfNotFound,
6283 bool pointerIntSame) const {
6284 assert(val);
6285 assert(val->getType());
6286 auto q = query(val).Data0();
6287 if (!(val->getType()->isPointerTy() || q[{}] == BaseType::Pointer)) {
6288 llvm::errs() << *analyzer->fntypeinfo.Function << "\n";
6289 dump();
6290 llvm::errs() << "val: " << *val << "\n";
6291 }
6292 assert(val->getType()->isPointerTy() || q[{}] == BaseType::Pointer);
6293
6294 auto dt = q[{-1}];
6295 for (size_t i = 0; i < num; ++i) {
6296 bool Legal = true;
6297 dt.checkedOrIn(q[{(int)i}], pointerIntSame, Legal);
6298 if (!Legal) {
6299 std::string str;
6300 raw_string_ostream ss(str);
6301 ss << "Illegal firstPointer, num: " << num << " q: " << q.str() << "\n";
6302 ss << " at " << *val << " from " << *I << "\n";
6303 if (CustomErrorHandler) {
6305 &analyzer, nullptr, nullptr);
6306 }
6307 llvm::errs() << ss.str() << "\n";
6308 llvm_unreachable("Illegal firstPointer");
6309 }
6310 }
6311
6312 if (errIfNotFound && (!dt.isKnown() || dt == BaseType::Anything)) {
6313 auto &res = *analyzer;
6314 if (auto inst = dyn_cast<Instruction>(val)) {
6315 llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n";
6316 llvm::errs() << *inst->getParent()->getParent() << "\n";
6317 for (auto &pair : res.analysis) {
6318 if (auto in = dyn_cast<Instruction>(pair.first)) {
6319 if (in->getParent()->getParent() != inst->getParent()->getParent()) {
6320 llvm::errs() << "inf: " << *in->getParent()->getParent() << "\n";
6321 llvm::errs() << "instf: " << *inst->getParent()->getParent()
6322 << "\n";
6323 llvm::errs() << "in: " << *in << "\n";
6324 llvm::errs() << "inst: " << *inst << "\n";
6325 }
6326 assert(in->getParent()->getParent() ==
6327 inst->getParent()->getParent());
6328 }
6329 llvm::errs() << "val: " << *pair.first << " - " << pair.second.str()
6330 << " int: " +
6331 to_string(res.knownIntegralValues(pair.first))
6332 << "\n";
6333 }
6334 }
6335 if (auto arg = dyn_cast<Argument>(val)) {
6336 llvm::errs() << *arg->getParent() << "\n";
6337 for (auto &pair : res.analysis) {
6338#ifndef NDEBUG
6339 if (auto in = dyn_cast<Instruction>(pair.first))
6340 assert(in->getParent()->getParent() == arg->getParent());
6341#endif
6342 llvm::errs() << "val: " << *pair.first << " - " << pair.second.str()
6343 << " int: " +
6344 to_string(res.knownIntegralValues(pair.first))
6345 << "\n";
6346 }
6347 }
6348 llvm::errs() << "fn: " << *analyzer->fntypeinfo.Function << "\n";
6349 dump();
6350 llvm::errs() << "could not deduce type of integer " << *val
6351 << " num:" << num << " q:" << q.str() << " \n";
6352
6353 llvm::DiagnosticLocation loc =
6354 analyzer->fntypeinfo.Function->getSubprogram();
6355 Instruction *codeLoc =
6356 &*analyzer->fntypeinfo.Function->getEntryBlock().begin();
6357 if (auto inst = dyn_cast<Instruction>(val)) {
6358 loc = inst->getDebugLoc();
6359 codeLoc = inst;
6360 }
6361 EmitFailure("CannotDeduceType", loc, codeLoc,
6362 "failed to deduce type of value ", *val);
6363
6364 assert(0 && "could not deduce type of integer");
6365 }
6366 return dt;
6367}
6368
6369/// Parse the debug info generated by rustc and retrieve useful type info if
6370/// possible
6372 DataLayout DL = fntypeinfo.Function->getParent()->getDataLayout();
6373 for (BasicBlock &BB : *fntypeinfo.Function) {
6374 for (Instruction &I : BB) {
6375 if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(&I)) {
6376 TypeTree TT = parseDIType(*DDI, DL);
6377 if (!TT.isKnown()) {
6378 continue;
6379 }
6381 updateAnalysis(DDI->getAddress(), TT.Only(-1, &I), DDI);
6382 }
6383 }
6384 }
6385}
6386
6387TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I,
6388 bool intIsPointer) {
6389 if (ET->isIntOrIntVectorTy()) {
6390 if (intIsPointer)
6391 return TypeTree(BaseType::Pointer).Only(-1, I);
6392 else
6393 return TypeTree(BaseType::Integer).Only(-1, I);
6394 }
6395 if (ET->isFPOrFPVectorTy()) {
6396 return TypeTree(ConcreteType(ET->getScalarType())).Only(-1, I);
6397 }
6398 if (ET->isPointerTy()) {
6399 return TypeTree(BaseType::Pointer).Only(-1, I);
6400 }
6401 if (auto ST = dyn_cast<StructType>(ET)) {
6402 auto &DL = I->getParent()->getParent()->getParent()->getDataLayout();
6403
6404 TypeTree Out;
6405
6406 for (size_t i = 0; i < ST->getNumElements(); i++) {
6407 auto SubT =
6408 defaultTypeTreeForLLVM(ST->getElementType(i), I, intIsPointer);
6409 Value *vec[2] = {
6410 ConstantInt::get(Type::getInt64Ty(I->getContext()), 0),
6411 ConstantInt::get(Type::getInt32Ty(I->getContext()), i),
6412 };
6413 auto g2 =
6414 GetElementPtrInst::Create(ST, UndefValue::get(getUnqual(ST)), vec);
6415 APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
6416 g2->accumulateConstantOffset(DL, ai);
6417 // Using destructor rather than eraseFromParent
6418 // as g2 has no parent
6419 delete g2;
6420
6421 auto size = (DL.getTypeSizeInBits(ST->getElementType(i)) + 7) / 8;
6422 int Off = (int)ai.getLimitedValue();
6423 Out |= SubT.ShiftIndices(DL, 0, size, Off);
6424 }
6425 return Out;
6426 }
6427 if (auto AT = dyn_cast<ArrayType>(ET)) {
6428 auto SubT = defaultTypeTreeForLLVM(AT->getElementType(), I, intIsPointer);
6429 auto &DL = I->getParent()->getParent()->getParent()->getDataLayout();
6430
6431 TypeTree Out;
6432 for (size_t i = 0; i < AT->getNumElements(); i++) {
6433 Value *vec[2] = {
6434 ConstantInt::get(Type::getInt64Ty(I->getContext()), 0),
6435 ConstantInt::get(Type::getInt32Ty(I->getContext()), i),
6436 };
6437 auto g2 =
6438 GetElementPtrInst::Create(AT, UndefValue::get(getUnqual(AT)), vec);
6439 APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
6440 g2->accumulateConstantOffset(DL, ai);
6441 // Using destructor rather than eraseFromParent
6442 // as g2 has no parent
6443 delete g2;
6444
6445 int Off = (int)ai.getLimitedValue();
6446 auto size = (DL.getTypeSizeInBits(AT->getElementType()) + 7) / 8;
6447 Out |= SubT.ShiftIndices(DL, 0, size, Off);
6448 }
6449 return Out;
6450 }
6451 if (auto AT = dyn_cast<VectorType>(ET)) {
6452#if LLVM_VERSION_MAJOR >= 12
6453 assert(!AT->getElementCount().isScalable());
6454 size_t numElems = AT->getElementCount().getKnownMinValue();
6455#else
6456 size_t numElems = AT->getNumElements();
6457#endif
6458 auto SubT = defaultTypeTreeForLLVM(AT->getElementType(), I, intIsPointer);
6459 auto &DL = I->getParent()->getParent()->getParent()->getDataLayout();
6460
6461 TypeTree Out;
6462 for (size_t i = 0; i < numElems; i++) {
6463 Value *vec[2] = {
6464 ConstantInt::get(Type::getInt64Ty(I->getContext()), 0),
6465 ConstantInt::get(Type::getInt32Ty(I->getContext()), i),
6466 };
6467 auto g2 =
6468 GetElementPtrInst::Create(AT, UndefValue::get(getUnqual(AT)), vec);
6469 APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
6470 g2->accumulateConstantOffset(DL, ai);
6471 // Using destructor rather than eraseFromParent
6472 // as g2 has no parent
6473 delete g2;
6474
6475 int Off = (int)ai.getLimitedValue();
6476 auto size = (DL.getTypeSizeInBits(AT->getElementType()) + 7) / 8;
6477 Out |= SubT.ShiftIndices(DL, 0, size, Off);
6478 }
6479 return Out;
6480 }
6481 // Unhandled/unknown Type
6482 llvm::errs() << "Error Unknown Type: " << *ET << "\n";
6483 assert(0 && "Error Unknown Type: ");
6484 llvm_unreachable("Error Unknown Type: ");
6485 // return TypeTree();
6486}
6487
6488Function *TypeResults::getFunction() const {
6490}
6491
6495
6496std::set<int64_t> TypeResults::knownIntegralValues(Value *val) const {
6497 return analyzer->knownIntegralValues(val);
6498}
6499
6500std::set<int64_t> TypeAnalyzer::knownIntegralValues(Value *val) {
6501 return fntypeinfo.knownIntegralValues(val, DT, intseen, SE);
6502}
6503
6505
6507 llvm::Function *todiff) {
6508 FnTypeInfo oldTypeInfo = oldTypeInfo_;
6509 for (auto &pair : oldTypeInfo.KnownValues) {
6510 if (pair.second.size() != 0) {
6511 bool recursiveUse = false;
6512 std::set<std::pair<Value *, Value *>> seen;
6513 SetVector<std::pair<Value *, Value *>> todo;
6514 for (auto user : pair.first->users())
6515 todo.insert(std::make_pair(user, pair.first));
6516 while (todo.size()) {
6517 auto spair = todo.pop_back_val();
6518 if (seen.count(spair))
6519 continue;
6520 seen.insert(spair);
6521 auto [v, prev] = spair;
6522 if (isa<BinaryOperator>(v) || isa<PHINode>(v) || isa<Argument>(v)) {
6523 for (auto user : v->users())
6524 todo.insert(std::make_pair(user, v));
6525 continue;
6526 }
6527 if (auto ci = dyn_cast<CallBase>(v)) {
6528 if (ci->getCalledFunction() == todiff &&
6529 ci->getArgOperand(pair.first->getArgNo()) == prev) {
6530 if (prev == pair.first)
6531 continue;
6532 recursiveUse = true;
6533 break;
6534 }
6535 }
6536 }
6537 if (recursiveUse) {
6538 pair.second.clear();
6539 }
6540 }
6541 }
6542 return oldTypeInfo;
6543}
BaseType
Categories of potential types.
Definition BaseType.h:32
static llvm::SmallPtrSet< llvm::BasicBlock *, 4 > getGuaranteedUnreachable(llvm::Function *F)
static bool isDeallocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory deallocation function For updating below one ...
static bool isAllocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory allocation function For updating below one sh...
static Operation * getFunctionFromCall(CallOpInterface iface)
constexpr const char * to_string(ActivityAnalyzer::UseActivity UA)
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL)
llvm::cl::opt< bool > EnzymePrintType
The following is not taken from LLVM.
static TypeTree parseTBAA(TBAAStructTypeNode AccessType, llvm::Instruction &I, const llvm::DataLayout &DL, std::shared_ptr< llvm::ModuleSlotTracker > MST)
Given a TBAA access node return the corresponding TypeTree This includes recursively parsing the acce...
Definition TBAA.h:439
TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I, bool intIsPointer)
llvm::cl::opt< int > MaxIntOffset("enzyme-max-int-offset", cl::init(100), cl::Hidden, cl::desc("Maximum type tree offset"))
Maximum offset for type trees to keep.
std::set< SmallVector< T, 4 > > getSet(ArrayRef< std::set< T > > todo, size_t idx)
#define CONSIDER(fn)
FnTypeInfo preventTypeAnalysisLoops(const FnTypeInfo &oldTypeInfo_, llvm::Function *todiff)
void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA, std::map< llvm::Value *, TypeTree > &analysis)
Given a constant value, deduce any type information applicable.
#define CONSIDER2(fn,...)
llvm::cl::opt< bool > RustTypeRules("enzyme-rust-type", cl::init(false), cl::Hidden, cl::desc("Enable rust-specific type rules"))
void analyzeFuncTypesNoFn(CallBase &call, TypeAnalyzer &TA)
static SmallPtrSet< BasicBlock *, 1 > findLoopIndices(llvm::Value *val, LoopInfo &LI, DominatorTree &DT, SmallPtrSet< PHINode *, 1 > &seen)
llvm::cl::opt< bool > EnzymeStrictAliasing("enzyme-strict-aliasing", cl::init(true), cl::Hidden, cl::desc("Assume strict aliasing of types / type stability"))
size_t skippedBytes(SmallSet< size_t, 8 > &offs, Type *T, const DataLayout &DL, size_t offset=0)
static bool isItaniumEncoding(StringRef S)
const llvm::StringMap< llvm::Intrinsic::ID > LIBM_FUNCTIONS
void analyzeFuncTypes(RT(*fn)(Args...), CallBase &call, TypeAnalyzer &TA)
void analyzeIntelSubscriptIntrinsic(IntrinsicInst &II, TypeAnalyzer &TA)
bool dontAnalyze(StringRef str)
FnTypeInfo preventTypeAnalysisLoops(const FnTypeInfo &oldTypeInfo_, llvm::Function *todiff)
static bool isMemFreeLibMFunction(llvm::StringRef str, llvm::Intrinsic::ID *ID=nullptr)
llvm::cl::opt< unsigned > EnzymeMaxTypeDepth
llvm::Optional< BlasInfo > extractBLAS(llvm::StringRef in)
Definition Utils.cpp:3563
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
Definition Utils.cpp:62
bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, MapVector< Value *, APInt > &VariableOffsets, APInt &ConstantOffset)
Definition Utils.cpp:4169
static bool isIntelSubscriptIntrinsic(const llvm::IntrinsicInst &II)
Definition Utils.h:1445
static T max(T a, T b)
Pick the maximum value.
Definition Utils.h:262
@ Args
Return is a struct of all args.
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
Definition Utils.h:713
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
Definition Utils.h:339
static bool containsOnlyAtMostTopBit(const llvm::Value *V, llvm::Type *FT, const llvm::DataLayout &dl, llvm::Type **vFT=nullptr)
Definition Utils.h:2047
@ IllegalFirstPointer
@ IllegalTypeAnalysis
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
Definition Utils.h:1269
static llvm::Optional< size_t > getAllocationIndexFromCall(const llvm::CallBase *op)
Definition Utils.h:1318
Concrete SubType of a given value.
ConcreteType PurgeAnything() const
Keep only mappings where the type is not an Anything
Full interprocedural TypeAnalysis.
llvm::StringMap< std::function< bool(int, TypeTree &, llvm::ArrayRef< TypeTree >, llvm::ArrayRef< std::set< int64_t > >, llvm::CallBase *, TypeAnalyzer *)> > CustomRules
Map of custom function call handlers.
std::map< FnTypeInfo, std::shared_ptr< TypeAnalyzer > > analyzedFunctions
Map of possible query states to TypeAnalyzer intermediate results.
void clear()
Clear existing analyses.
TypeResults analyzeFunction(const FnTypeInfo &fn)
Analyze a particular function, returning the results.
Helper class that computes the fixed-point type results of a given function.
void visitMemTransferInst(llvm::MemTransferInst &MTI)
void visitFPToSIInst(llvm::FPToSIInst &I)
llvm::PostDominatorTree & PDT
FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn)
void visitAllocaInst(llvm::AllocaInst &I)
const FnTypeInfo fntypeinfo
Calling context.
void visitExtractValueInst(llvm::ExtractValueInst &I)
void visitSIToFPInst(llvm::SIToFPInst &I)
const llvm::SmallPtrSet< llvm::BasicBlock *, 4 > notForAnalysis
void considerRustDebugInfo()
Parse the debug info generated by rustc and retrieve useful type info if possible.
llvm::DominatorTree & DT
std::set< int64_t > knownIntegralValues(llvm::Value *val)
void visitFPExtInst(llvm::FPExtInst &I)
void visitConstantExpr(llvm::ConstantExpr &CE)
TypeAnalysis & interprocedural
Calling TypeAnalysis to be used in the case of calls to other functions.
void visitGetElementPtrInst(llvm::GetElementPtrInst &gep)
void visitIntToPtrInst(llvm::IntToPtrInst &I)
void prepareArgs()
Analyze type info given by the arguments, possibly adding to work queue.
void visitShuffleVectorInst(llvm::ShuffleVectorInst &I)
llvm::TargetLibraryInfo & TLI
void visitUIToFPInst(llvm::UIToFPInst &I)
static constexpr uint8_t UP
void visitInsertValueInst(llvm::InsertValueInst &I)
void visitPHINode(llvm::PHINode &phi)
void visitValue(llvm::Value &val)
void visitSelectInst(llvm::SelectInst &I)
void visitIPOCall(llvm::CallBase &call, llvm::Function &fn)
std::shared_ptr< llvm::ModuleSlotTracker > MST
Cache of metadata indices, for faster printing.
llvm::SetVector< llvm::Value *, std::deque< llvm::Value * > > workList
List of values which should be re-analyzed now with new information.
static constexpr uint8_t DOWN
void updateAnalysis(llvm::Value *val, BaseType data, llvm::Value *origin)
Add additional information to the Type info of val, re-adding it to the work queue as necessary.
void visitAddrSpaceCastInst(llvm::AddrSpaceCastInst &I)
bool Invalid
Whether an inconsistent update has been found. This will only be set when direction !...
static constexpr uint8_t BOTH
void visitExtractElementInst(llvm::ExtractElementInst &I)
void visitLoadInst(llvm::LoadInst &I)
void visitCmpInst(llvm::CmpInst &I)
uint8_t direction
Directionality of checks.
std::map< llvm::Value *, TypeTree > analysis
Intermediate conservative, but correct Type analysis results.
void visitIntrinsicInst(llvm::IntrinsicInst &II)
void visitSExtInst(llvm::SExtInst &I)
void visitGEPOperator(llvm::GEPOperator &gep)
void considerTBAA()
Analyze type info given by the TBAA, possibly adding to work queue.
void visitCallBase(llvm::CallBase &call)
void visitFPTruncInst(llvm::FPTruncInst &I)
void visitStoreInst(llvm::StoreInst &I)
void visitAtomicRMWInst(llvm::AtomicRMWInst &I)
TypeAnalyzer(TypeAnalysis &TA)
void visitZExtInst(llvm::ZExtInst &I)
void runPHIHypotheses()
Hypothesize that undefined phis are integers and try to prove that they are really integral.
void visitBinaryOperator(llvm::BinaryOperator &I)
void visitMemTransferCommon(llvm::CallBase &MTI)
TypeTree getAnalysis(llvm::Value *Val)
Get the current results for a given value.
void visitBinaryOperation(const llvm::DataLayout &DL, llvm::Type *T, llvm::Instruction::BinaryOps, llvm::Value *Args[2], TypeTree &Ret, TypeTree &LHS, TypeTree &RHS, llvm::Instruction *I)
void visitFPToUIInst(llvm::FPToUIInst &I)
void visitBitCastInst(llvm::BitCastInst &I)
llvm::LoopInfo & LI
void dump(llvm::raw_ostream &ss=llvm::errs())
void visitTruncInst(llvm::TruncInst &I)
void visitPtrToIntInst(llvm::PtrToIntInst &I)
llvm::ScalarEvolution & SE
void run()
Run the interprocedural type analysis starting from this function.
TypeTree getReturnAnalysis()
void visitInsertElementInst(llvm::InsertElementInst &I)
A holder class representing the results of running TypeAnalysis on a given function.
TypeResults(std::nullptr_t)
ConcreteType intType(size_t num, llvm::Value *val, bool errIfNotFound=true, bool pointerIntSame=false) const
llvm::Type * addingType(size_t num, llvm::Value *val, size_t start=0) const
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
bool anyFloat(llvm::Value *val, bool anythingIsFloat=true) const
Whether any part of the top level register can contain a float e.g.
TypeAnalyzer * analyzer
FnTypeInfo getAnalyzedTypeInfo() const
The TypeInfo calling convention.
bool anyPointer(llvm::Value *val) const
Whether any part of the top level register can contain a pointer e.g.
std::set< int64_t > knownIntegralValues(llvm::Value *val) const
The set of values val will take on during this program.
TypeTree getReturnAnalysis() const
The Type of the return.
TypeTree query(llvm::Value *val) const
The TypeTree of a particular Value.
bool allFloat(llvm::Value *val) const
Whether all of the top level register is known to contain float data.
llvm::Function * getFunction() const
FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn) const
ConcreteType firstPointer(size_t num, llvm::Value *val, llvm::Instruction *I, bool errIfNotFound=true, bool pointerIntSame=false) const
Returns whether in the first num bytes there is a pointer, int, float, or none. If pointerIntSame is set...
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Definition TypeTree.h:72
TypeTree Only(int Off, llvm::Instruction *orig) const
Prepend an offset to all mappings.
Definition TypeTree.h:471
TypeTree Data0() const
Peel off the outermost index at offset 0.
Definition TypeTree.h:513
static TypeTree parse(llvm::StringRef str, llvm::LLVMContext &ctx)
Definition TypeTree.h:86
TypeTree ShiftIndices(const llvm::DataLayout &dl, const int offset, const int maxSize, size_t addOffset=0) const
Replace mappings in the range in [offset, offset+maxSize] with those in.
Definition TypeTree.h:840
TypeTree Lookup(size_t len, const llvm::DataLayout &dl) const
Select all submappings whose first index is in range [0, len) and remove the first index.
Definition TypeTree.h:593
ConcreteType Inner0() const
Optimized version of Data0()[{}].
Definition TypeTree.h:548
TypeTree ReplaceMinus() const
Replace -1 with 0.
Definition TypeTree.h:1059
llvm::Type * IsAllFloat(const size_t size, const llvm::DataLayout &dl) const
Definition TypeTree.h:814
static TypeTree fromMD(llvm::MDNode *md)
Definition TypeTree.h:1442
bool isKnownPastPointer() const
Whether this TypeTree knows any non-pointer information.
Definition TypeTree.h:443
TypeTree Clear(size_t start, size_t end, size_t len) const
Remove any mappings in the range [start, end) or [len, inf). This function has special handling for -1...
Definition TypeTree.h:556
TypeTree JustAnything() const
Keep only mappings where the type is an Anything
Definition TypeTree.h:1083
bool checkedOrIn(const std::vector< int > &Seq, ConcreteType RHS, bool PointerIntSame, bool &LegalOr)
Definition TypeTree.h:1108
void CanonicalizeInPlace(size_t len, const llvm::DataLayout &dl)
Given that this tree represents something of at most size len, canonicalize this, creating -1's where...
Definition TypeTree.h:676
std::string str() const
Returns a string representation of this TypeTree.
Definition TypeTree.h:1383
bool isKnown() const
Whether this TypeTree contains any information.
Definition TypeTree.h:431
bool orIn(const std::vector< int > &Seq, ConcreteType RHS, bool PointerIntSame=false)
Definition TypeTree.h:1232
bool insert(const std::vector< int > Seq, ConcreteType CT, bool PointerIntSame=false)
Return if changed.
Definition TypeTree.h:234
bool binopIn(bool &Legal, const TypeTree &RHS, llvm::BinaryOperator::BinaryOps Op)
Set this to the logical binop of itself and RHS, using the Binop Op, returning true if this was changed.
Definition TypeTree.h:1320
TypeTree PurgeAnything() const
Keep only mappings where the type is not an Anything
Definition TypeTree.h:1041
Struct containing all contextual type information for a particular function call.
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
llvm::Function * Function
Function being analyzed.
TypeTree Return
Type of return.
std::set< int64_t > knownIntegralValues(llvm::Value *val, const llvm::DominatorTree &DT, std::map< llvm::Value *, std::set< int64_t > > &intseen, llvm::ScalarEvolution &SE) const
The set of known values val will take.
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to be represented by an argument, if constant.
static void analyzeFuncTypesHelper(unsigned idx, CallBase &call, TypeAnalyzer &TA)
static void analyzeFuncTypesHelper(unsigned idx, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
static void analyzeType(Value *val, CallBase &call, TypeAnalyzer &TA)
This template class is defined to take the templated type T update the analysis of the first argument...