Enzyme main
Loading...
Searching...
No Matches
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Definition of miscellaneous utilities ------------------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file defines miscellaneous utilities that are used as part of the
22// AD process.
23//
24//===----------------------------------------------------------------------===//
25#include "Utils.h"
26#include "GradientUtils.h"
28
29#if LLVM_VERSION_MAJOR >= 16
30#include "llvm/Analysis/ScalarEvolution.h"
31#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
32#else
33#include "SCEV/ScalarEvolution.h"
34#include "SCEV/ScalarEvolutionExpander.h"
35#endif
36
37#include "TypeAnalysis/TBAA.h"
38#include "llvm/IR/BasicBlock.h"
39#include "llvm/IR/DerivedTypes.h"
40#include "llvm/IR/Function.h"
41#include "llvm/IR/GetElementPtrTypeIterator.h"
42#include "llvm/IR/IRBuilder.h"
43#include "llvm/IR/InlineAsm.h"
44#include "llvm/IR/Module.h"
45#include "llvm/IR/Type.h"
46#include "llvm/IR/Verifier.h"
47
48#if LLVM_VERSION_MAJOR >= 16
49#include "llvm/TargetParser/Triple.h"
50#else
51#include "llvm/ADT/Triple.h"
52#endif
53
54#include "llvm-c/Core.h"
55
56#include "BlasAttributor.inc"
57#include "LibraryFuncs.h"
58
59using namespace llvm;
60
// C-linkage customization hooks and command-line flags.
// Every hook below is a function pointer that starts out as nullptr; code in
// this file tests a hook before using it where the check is visible (see
// ZeroMemory, CreateAllocation, CreateDealloc).
61extern "C" {
// Error-handling hook; it is not invoked in the visible portion of this file.
62LLVMValueRef (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType,
63 const void *, LLVMValueRef,
64 LLVMBuilderRef) = nullptr;
// Allocation hook used by CreateAllocation in place of malloc. The last
// parameter is an out-slot that CreateAllocation reads back as the
// instruction that zeroes the new memory (may be left null).
65LLVMValueRef (*CustomAllocator)(LLVMBuilderRef, LLVMTypeRef,
66 /*Count*/ LLVMValueRef,
67 /*Align*/ LLVMValueRef, uint8_t,
68 LLVMValueRef *) = nullptr;
// Zeroing hook used by ZeroMemory in place of a store of the null value; the
// trailing uint8_t receives ZeroMemory's isTape flag.
69void (*CustomZero)(LLVMBuilderRef, LLVMTypeRef,
70 /*Ptr*/ LLVMValueRef, uint8_t) = nullptr;
// Deallocation hook used by CreateDealloc in place of free.
71LLVMValueRef (*CustomDeallocator)(LLVMBuilderRef, LLVMValueRef) = nullptr;
// Hook used by ErrorIfRuntimeInactive to emit the error path; it is passed the
// builder, the message value and the originating instruction.
72void (*CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef,
73 LLVMValueRef) = nullptr;
// Hook used by PostCacheStore. It returns an array of values whose length is
// written to *size; PostCacheStore releases the array with free().
74LLVMValueRef *(*EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef,
75 uint64_t *size) = nullptr;
// Hook used by getDefaultAnonymousTapeType to supply the tape type.
76LLVMTypeRef (*EnzymeDefaultTapeType)(LLVMContextRef) = nullptr;
// Not invoked in the visible portion of this file.
77LLVMValueRef (*EnzymeUndefinedValueForType)(LLVMModuleRef, LLVMTypeRef,
78 uint8_t) = nullptr;
79
// Not invoked in the visible portion of this file.
80LLVMValueRef (*EnzymeSanitizeDerivatives)(LLVMValueRef, LLVMValueRef toset,
81 LLVMBuilderRef,
82 LLVMValueRef) = nullptr;
83
// Defined in another translation unit.
84extern llvm::cl::opt<bool> EnzymeZeroCache;
85
// Hidden boolean command-line flags; the cl::init(...) argument is each
// flag's default value.
86// default to false because lacpy is slow
87llvm::cl::opt<bool>
88 EnzymeLapackCopy("enzyme-lapack-copy", cl::init(false), cl::Hidden,
89 cl::desc("Use blas copy calls to cache matrices"));
90llvm::cl::opt<bool>
91 EnzymeBlasCopy("enzyme-blas-copy", cl::init(true), cl::Hidden,
92 cl::desc("Use blas copy calls to cache vectors"));
93llvm::cl::opt<bool>
94 EnzymeFastMath("enzyme-fast-math", cl::init(true), cl::Hidden,
95 cl::desc("Use fast math on derivative compuation"));
96llvm::cl::opt<bool> EnzymeMemmoveWarning(
97 "enzyme-memmove-warning", cl::init(true), cl::Hidden,
98 cl::desc("Warn if using memmove implementation as a fallback for memmove"));
99llvm::cl::opt<bool> EnzymeRuntimeError(
100 "enzyme-runtime-error", cl::init(false), cl::Hidden,
101 cl::desc("Emit Runtime errors instead of compile time ones"));
102
103llvm::cl::opt<bool> EnzymeCheckDerivativeNaN(
104 "enzyme-check-nan", cl::init(false), cl::Hidden,
105 cl::desc("Add NaN checks to all derivative intermediate values"));
106
107llvm::cl::opt<bool> EnzymeNonPower2Cache(
108 "enzyme-non-power2-cache", cl::init(false), cl::Hidden,
109 cl::desc("Disable caching of integers which are not a power of 2"));
110}
111
// Alias the short attribute accessors to their *AtIndex counterparts; every
// addAttribute(...) / getAttribute(...) call after this point expands to
// addAttributeAtIndex(...) / getAttributeAtIndex(...).
112#define addAttribute addAttributeAtIndex
113#define getAttribute getAttributeAtIndex
/// Attach attributes to \p F when its name matches one of the functions
/// recognised below (fprintf, memcmp, Enzyme marker functions, MPI, OpenMP,
/// frexp, CUDA driver entry points, ...).
///
/// \param F function to inspect and annotate in place.
/// \returns true when at least one rule below matched, or when
///          attributeTablegen(F) reported a change.
114bool attributeKnownFunctions(llvm::Function &F) {
115 bool changed = false;
// fprintf: mark every pointer argument as not captured.
116 if (F.getName() == "fprintf") {
117 for (auto &arg : F.args()) {
118 if (arg.getType()->isPointerTy()) {
119 addFunctionNoCapture(&F, arg.getArgNo());
120 changed = true;
121 }
122 }
123 }
// Enzyme marker functions are matched by substring, not by exact name.
124 if (F.getName().contains("__enzyme_float") ||
125 F.getName().contains("__enzyme_double") ||
126 F.getName().contains("__enzyme_integer") ||
127 F.getName().contains("__enzyme_pointer") ||
128 F.getName().contains("__enzyme_todense") ||
129 F.getName().contains("__enzyme_ignore_derivatives") ||
130 F.getName().contains("__enzyme_iter") ||
131 F.getName().contains("__enzyme_virtualreverse")) {
132 changed = true;
// On LLVM >= 16 both the only-reads and only-writes setters are applied;
// older versions add the ReadNone attribute instead.
133#if LLVM_VERSION_MAJOR >= 16
134 F.setOnlyReadsMemory();
135 F.setOnlyWritesMemory();
136#else
137 F.addFnAttr(Attribute::ReadNone);
138#endif
// todense / ignore_derivatives keep their pointer arguments unannotated.
139 if (!(F.getName().contains("__enzyme_todense") ||
140 F.getName().contains("__enzyme_ignore_derivatives"))) {
141 for (auto &arg : F.args()) {
142 if (arg.getType()->isPointerTy()) {
143 arg.addAttr(Attribute::ReadNone);
144 addFunctionNoCapture(&F, arg.getArgNo());
145 }
146 }
147 }
148 }
// memcmp: reads only argument memory; the first two parameters are read-only
// when they are pointers.
149 if (F.getName() == "memcmp") {
150 changed = true;
151#if LLVM_VERSION_MAJOR >= 16
152 F.setOnlyAccessesArgMemory();
153 F.setOnlyReadsMemory();
154#else
155 F.addFnAttr(Attribute::ArgMemOnly);
156 F.addFnAttr(Attribute::ReadOnly);
157#endif
158 F.addFnAttr(Attribute::NoUnwind);
159 F.addFnAttr(Attribute::NoRecurse);
160 F.addFnAttr(Attribute::WillReturn);
161 F.addFnAttr(Attribute::NoFree);
162 F.addFnAttr(Attribute::NoSync);
163 for (int i = 0; i < 2; i++)
164 if (F.getFunctionType()->getParamType(i)->isPointerTy()) {
// NOTE(review): the embedded listing numbers skip 165 here, so a statement
// may be missing from this view — confirm against the upstream file.
166 F.addParamAttr(i, Attribute::ReadOnly);
167 }
168 }
169
// libstdc++ std::string::_M_create (Itanium-mangled name): never frees.
170 if (F.getName() ==
171 "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_createERmm") {
172 changed = true;
173 F.addFnAttr(Attribute::NoFree);
174 }
// MPI_Irecv: parameters 0, 2 and 6 are marked write-only when they are
// pointers. PointerABI stays true only if parameters 0 and 6 are pointers;
// only then is the function restricted to inaccessible-or-argument memory.
175 if (F.getName() == "MPI_Irecv" || F.getName() == "PMPI_Irecv") {
176 auto FT = F.getFunctionType();
177 bool PointerABI = true;
178 changed = true;
179 F.addFnAttr(Attribute::NoUnwind);
180 F.addFnAttr(Attribute::NoRecurse);
181 F.addFnAttr(Attribute::WillReturn);
182 F.addFnAttr(Attribute::NoFree);
183 F.addFnAttr(Attribute::NoSync);
184 if (FT->getParamType(0)->isPointerTy()) {
185 F.addParamAttr(0, Attribute::WriteOnly);
186 } else {
187 PointerABI = false;
188 }
189 // OpenMPI vs MPICH
190 if (FT->getParamType(2)->isPointerTy()) {
// NOTE(review): listing number 191 is skipped here — a statement may be
// missing from this view.
192 F.addParamAttr(2, Attribute::WriteOnly);
193 }
194 if (FT->getParamType(6)->isPointerTy()) {
195 F.addParamAttr(6, Attribute::WriteOnly);
196 } else {
197 PointerABI = false;
198 }
199 if (PointerABI) {
200#if LLVM_VERSION_MAJOR >= 16
201 F.setOnlyAccessesInaccessibleMemOrArgMem();
202#else
203 F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
204#endif
205 }
206 }
// From here on the name is taken from getFuncName(&F) rather than
// F.getName() directly.
207 auto name = getFuncName(&F);
// MPI_Isend: same shape as MPI_Irecv, but parameters 0 and 2 are read-only;
// parameter 6 is still write-only.
208 if (name == "MPI_Isend" || name == "PMPI_Isend") {
209 auto FT = F.getFunctionType();
210 bool PointerABI = true;
211 changed = true;
212 F.addFnAttr(Attribute::NoUnwind);
213 F.addFnAttr(Attribute::NoRecurse);
214 F.addFnAttr(Attribute::WillReturn);
215 F.addFnAttr(Attribute::NoFree);
216 F.addFnAttr(Attribute::NoSync);
217 if (FT->getParamType(0)->isPointerTy()) {
218 F.addParamAttr(0, Attribute::ReadOnly);
219 } else {
220 PointerABI = false;
221 }
222 // OpenMPI vs MPICH
223 if (FT->getParamType(2)->isPointerTy()) {
// NOTE(review): listing number 224 is skipped here — a statement may be
// missing from this view.
225 F.addParamAttr(2, Attribute::ReadOnly);
226 }
227 if (FT->getParamType(6)->isPointerTy()) {
228 F.addParamAttr(6, Attribute::WriteOnly);
229 } else {
230 PointerABI = false;
231 }
232 if (PointerABI) {
233#if LLVM_VERSION_MAJOR >= 16
234 F.setOnlyAccessesInaccessibleMemOrArgMem();
235#else
236 F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
237#endif
238 }
239 }
// MPI_Comm_rank / MPI_Comm_size: parameter 0 is read-only and parameter 1 is
// write-only when they are pointers; PointerABI depends on parameter 1 only.
240 if (name == "MPI_Comm_rank" || name == "PMPI_Comm_rank" ||
241 name == "MPI_Comm_size" || name == "PMPI_Comm_size") {
242 auto FT = F.getFunctionType();
243 bool PointerABI = true;
244 changed = true;
245 F.addFnAttr(Attribute::NoUnwind);
246 F.addFnAttr(Attribute::NoRecurse);
247 F.addFnAttr(Attribute::WillReturn);
248 F.addFnAttr(Attribute::NoFree);
249 F.addFnAttr(Attribute::NoSync);
250
251 // OpenMPI vs MPICH
252 if (FT->getParamType(0)->isPointerTy()) {
// NOTE(review): listing number 253 is skipped here — a statement may be
// missing from this view.
254 F.addParamAttr(0, Attribute::ReadOnly);
255 }
256 if (FT->getParamType(1)->isPointerTy()) {
257 F.addParamAttr(1, Attribute::WriteOnly);
// NOTE(review): listing number 258 is skipped here — a statement may be
// missing from this view.
259 } else {
260 PointerABI = false;
261 }
262 if (PointerABI) {
263#if LLVM_VERSION_MAJOR >= 16
264 F.setOnlyAccessesInaccessibleMemOrArgMem();
265#else
266 F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
267#endif
268 }
269 }
// MPI_Wait: parameter 1 is write-only when it is a pointer.
270 if (name == "MPI_Wait" || name == "PMPI_Wait") {
271 changed = true;
272 F.addFnAttr(Attribute::NoUnwind);
273 F.addFnAttr(Attribute::NoRecurse);
274 F.addFnAttr(Attribute::WillReturn);
275 F.addFnAttr(Attribute::NoFree);
276 F.addFnAttr(Attribute::NoSync);
277 if (F.getFunctionType()->getParamType(0)->isPointerTy()) {
// NOTE(review): this branch is empty as shown; listing number 278 is skipped,
// so its body may be missing from this view.
279 }
280 if (F.getFunctionType()->getParamType(1)->isPointerTy()) {
281 F.addParamAttr(1, Attribute::WriteOnly);
// NOTE(review): listing number 282 is skipped here — a statement may be
// missing from this view.
283 }
284 }
// MPI_Waitall: parameter 2 is write-only when it is a pointer.
285 if (name == "MPI_Waitall" || name == "PMPI_Waitall") {
286 changed = true;
287 F.addFnAttr(Attribute::NoUnwind);
288 F.addFnAttr(Attribute::NoRecurse);
289 F.addFnAttr(Attribute::WillReturn);
290 F.addFnAttr(Attribute::NoFree);
291 F.addFnAttr(Attribute::NoSync);
292 if (F.getFunctionType()->getParamType(1)->isPointerTy()) {
// NOTE(review): this branch is empty as shown; listing number 293 is skipped,
// so its body may be missing from this view.
294 }
295 if (F.getFunctionType()->getParamType(2)->isPointerTy()) {
296 F.addParamAttr(2, Attribute::WriteOnly);
// NOTE(review): listing number 297 is skipped here — a statement may be
// missing from this view.
298 }
299 }
// This map is rebuilt on every call of the function.
300 // Map of MPI function name to the arg index of its type argument
301 std::map<std::string, int> MPI_TYPE_ARGS = {
302 {"MPI_Send", 2}, {"MPI_Ssend", 2}, {"MPI_Bsend", 2},
303 {"MPI_Recv", 2}, {"MPI_Brecv", 2}, {"PMPI_Send", 2},
304 {"PMPI_Ssend", 2}, {"PMPI_Bsend", 2}, {"PMPI_Recv", 2},
305 {"PMPI_Brecv", 2},
306
307 {"MPI_Isend", 2}, {"MPI_Irecv", 2}, {"PMPI_Isend", 2},
308 {"PMPI_Irecv", 2},
309
310 {"MPI_Reduce", 3}, {"PMPI_Reduce", 3},
311
312 {"MPI_Allreduce", 3}, {"PMPI_Allreduce", 3}};
// For each direct call of F, look through constant expressions on the
// type argument; if it resolves to the global "ompi_mpi_cxx_bool" the call
// site itself (not F) gets the "enzyme_inactive" string attribute.
313 {
314 auto found = MPI_TYPE_ARGS.find(name.str());
315 if (found != MPI_TYPE_ARGS.end()) {
316 for (auto user : F.users()) {
317 if (auto CI = dyn_cast<CallBase>(user))
318 if (CI->getCalledFunction() == &F) {
319 if (Constant *C =
320 dyn_cast<Constant>(CI->getArgOperand(found->second))) {
// Strip wrapping constant expressions by following operand 0.
321 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
322 C = CE->getOperand(0);
323 }
324 if (auto GV = dyn_cast<GlobalVariable>(C)) {
325 if (GV->getName() == "ompi_mpi_cxx_bool") {
326 changed = true;
327 CI->addAttribute(
328 AttributeList::FunctionIndex,
329 Attribute::get(CI->getContext(), "enzyme_inactive"));
330 }
331 }
332 }
333 }
334 }
335 }
336 }
337
// OpenMP thread queries: read-only access to inaccessible memory.
338 if (F.getName() == "omp_get_max_threads" ||
339 F.getName() == "omp_get_thread_num") {
340 changed = true;
341#if LLVM_VERSION_MAJOR >= 16
342 F.setOnlyAccessesInaccessibleMemory();
343 F.setOnlyReadsMemory();
344#else
345 F.addFnAttr(Attribute::InaccessibleMemOnly);
346 F.addFnAttr(Attribute::ReadOnly);
347#endif
348 }
// frexp family: touches only argument memory; parameter 1 is write-only.
349 if (F.getName() == "frexp" || F.getName() == "frexpf" ||
350 F.getName() == "frexpl") {
351 changed = true;
352#if LLVM_VERSION_MAJOR >= 16
353 F.setOnlyAccessesArgMemory();
354#else
355 F.addFnAttr(Attribute::ArgMemOnly);
356#endif
357 F.addParamAttr(1, Attribute::WriteOnly);
358 }
// These three get the same memory attributes as the Enzyme marker functions
// above.
359 if (F.getName() == "__fd_sincos_1" || F.getName() == "__fd_cos_1" ||
360 F.getName() == "__mth_i_ipowi") {
361 changed = true;
362#if LLVM_VERSION_MAJOR >= 16
363 F.setOnlyReadsMemory();
364 F.setOnlyWritesMemory();
365#else
366 F.addFnAttr(Attribute::ReadNone);
367#endif
368 }
369
// Exact-name list of functions tagged with the string attribute
// "enzyme_no_escaping_allocation".
370 const char *NonEscapingFns[] = {
371 "julia.ptls_states",
372 "julia.get_pgcstack",
373 "lgamma_r",
374 "memcmp",
375 "_ZNSt6chrono3_V212steady_clock3nowEv",
376 "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_"
377 "createERmm",
378 "_ZNKSt8__detail20_Prime_rehash_policy14_M_need_rehashEmmm",
379 "fprintf",
380 "fwrite",
381 "fputc",
382 "strtol",
383 "getenv",
384 "memchr",
385 "cublasSetMathMode",
386 "cublasSetStream_v2",
387 "cuMemPoolTrimTo",
388 "cuDeviceGetMemPool",
389 "cuStreamSynchronize",
390 "cuStreamDestroy",
391 "cuStreamQuery",
392 "cuCtxGetCurrent",
393 "cuDeviceGet",
394 "cuDeviceGetName",
395 "cuDriverGetVersion",
396 "cudaRuntimeGetVersion",
397 "cuDeviceGetCount",
398 "cuMemPoolGetAttribute",
399 "cuMemGetInfo_v2",
400 "cuDeviceGetAttribute",
401 "cuDevicePrimaryCtxRetain",
402 };
403 for (auto fname : NonEscapingFns)
404 if (name == fname) {
405 changed = true;
406 F.addAttribute(
407 AttributeList::FunctionIndex,
408 Attribute::get(F.getContext(), "enzyme_no_escaping_allocation"));
409 }
// Finally apply the rules from attributeTablegen (defined outside this view).
410 changed |= attributeTablegen(F);
411 return changed;
412}
413
414void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj,
415 bool isTape) {
416 if (CustomZero) {
417 CustomZero(wrap(&Builder), wrap(T), wrap(obj), isTape);
418 } else {
419 Builder.CreateStore(Constant::getNullValue(T), obj);
420 }
421}
422
423llvm::SmallVector<llvm::Instruction *, 2> PostCacheStore(llvm::StoreInst *SI,
424 llvm::IRBuilder<> &B) {
425 SmallVector<llvm::Instruction *, 2> res;
427 uint64_t size = 0;
428 auto ptr = EnzymePostCacheStore(wrap(SI), wrap(&B), &size);
429 for (size_t i = 0; i < size; i++) {
430 res.push_back(cast<Instruction>(unwrap(ptr[i])));
431 }
432 free(ptr);
433 }
434 return res;
435}
436
437llvm::PointerType *getDefaultAnonymousTapeType(llvm::LLVMContext &C) {
439 return cast<PointerType>(unwrap(EnzymeDefaultTapeType(wrap(&C))));
440 return getInt8PtrTy(C);
441}
442
/// Get (or create and define) the module-level helper that grows a buffer
/// geometrically.
///
/// The helper has signature `allocType f(allocType ptr, i64 size, i64 tsize)`.
/// It returns \c ptr unchanged unless \c size is 1 or one more than a power
/// of two; in that case it returns a buffer of
/// `tsize << (64 - ctlz(size))` bytes that holds the previous contents.
///
/// \param M        module that receives the helper.
/// \param newFunc  function used for a throw-away trial allocation and as
///                 the source of the "enzymejl_world" attribute.
/// \param ZeroInit whether newly added bytes should be zeroed.
/// \param RT       element type passed to CreateAllocation.
/// \returns the helper function (already defined on return).
443Function *getOrInsertExponentialAllocator(Module &M, Function *newFunc,
444 bool ZeroInit, llvm::Type *RT) {
445 bool custom = true;
446 llvm::PointerType *allocType;
// Trial allocation in a temporary block of newFunc: it reveals which call
// CreateAllocation emits and what pointer type that call returns. The block
// is erased again at the end of this scope.
447 {
448 auto i64 = Type::getInt64Ty(newFunc->getContext());
449 BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", newFunc);
450 IRBuilder<> B(BB);
// Placeholder value standing in for the element count.
451 auto P = B.CreatePHI(i64, 1);
452 CallInst *malloccall;
453 Instruction *SubZero = nullptr;
454 CreateAllocation(B, RT, P, "tapemem", &malloccall, &SubZero)->getType();
// "custom" stays true unless the emitted call targets a function named
// "malloc".
455 if (auto F = getFunctionFromCall(malloccall)) {
456 custom = F->getName() != "malloc";
457 }
458 allocType = cast<PointerType>(malloccall->getType());
// Zero-initialisation is dropped when the trial allocation produced no
// zeroing instruction.
459 if (ZeroInit && !SubZero)
460 ZeroInit = false;
461 BB->eraseFromParent();
462 }
463
464 Type *types[] = {allocType, Type::getInt64Ty(M.getContext()),
465 Type::getInt64Ty(M.getContext())};
// The helper's name encodes the zeroing mode and, for custom allocators, the
// address of RT printed as an integer.
466 std::string name = "__enzyme_exponentialallocation";
467 if (ZeroInit)
468 name += "zero";
469 if (custom)
470 name += ".custom@" + std::to_string((size_t)RT);
471
472 FunctionType *FT = FunctionType::get(allocType, types, false);
473 AttributeList AL;
// Forward the "enzymejl_world" function attribute of newFunc, if present.
474 if (newFunc->hasFnAttribute("enzymejl_world")) {
475 AL = AL.addFnAttribute(newFunc->getContext(),
476 newFunc->getFnAttribute("enzymejl_world"));
477 }
478 Function *F = cast<Function>(M.getOrInsertFunction(name, FT, AL).getCallee());
479
// Already defined by an earlier call: reuse it.
480 if (!F->empty())
481 return F;
482
483 F->setLinkage(Function::LinkageTypes::InternalLinkage);
484 F->addFnAttr(Attribute::AlwaysInline);
485 F->addFnAttr(Attribute::NoUnwind);
486 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
487 BasicBlock *grow = BasicBlock::Create(M.getContext(), "grow", F);
488 BasicBlock *ok = BasicBlock::Create(M.getContext(), "ok", F);
489
490 IRBuilder<> B(entry);
491
// The three arguments are reached by pointer arithmetic on the first one.
492 Argument *ptr = F->arg_begin();
493 ptr->setName("ptr");
494 Argument *size = ptr + 1;
495 size->setName("size");
496 Argument *tsize = size + 1;
497 tsize->setName("tsize");
498
// hasOne: the lowest bit of size is set.
499 Value *hasOne = B.CreateICmpNE(
500 B.CreateAnd(size, ConstantInt::get(size->getType(), 1, false)),
501 ConstantInt::get(size->getType(), 0, false));
502 auto popCnt = getIntrinsicDeclaration(&M, Intrinsic::ctpop, {types[1]});
503
// Grow when popcount(size) < 3 and the low bit is set, i.e. when size is 1
// or 2^k + 1; otherwise fall through to "ok" with the original pointer.
504 B.CreateCondBr(
505 B.CreateAnd(B.CreateICmpULT(B.CreateCall(popCnt, {size}),
506 ConstantInt::get(types[1], 3, false)),
507 hasOne),
508 grow, ok);
509
510 B.SetInsertPoint(grow);
511
// next = tsize << (64 - ctlz(size)); the subtraction carries nuw and nsw.
512 auto lz =
513 B.CreateCall(getIntrinsicDeclaration(&M, Intrinsic::ctlz, {types[1]}),
514 {size, ConstantInt::getTrue(M.getContext())});
515 Value *next =
516 B.CreateShl(tsize, B.CreateSub(ConstantInt::get(types[1], 64, false), lz,
517 "", true, true));
518
519 Value *gVal;
520
// prevSize = (size == 1) ? 0 : next >> 1 — the byte size before this growth.
521 Value *prevSize =
522 B.CreateSelect(B.CreateICmpEQ(size, ConstantInt::get(size->getType(), 1)),
523 ConstantInt::get(next->getType(), 0),
524 B.CreateLShr(next, ConstantInt::get(next->getType(), 1)));
525
// realloc is never used when targeting nvptx / nvptx64.
526 auto Arch = llvm::Triple(M.getTargetTriple()).getArch();
527 bool forceMalloc = Arch == Triple::nvptx || Arch == Triple::nvptx64;
528
529 if (!custom && !forceMalloc) {
// Plain malloc case: grow in place with realloc(ptr, next).
530 auto reallocF = M.getOrInsertFunction("realloc", allocType, allocType,
531 Type::getInt64Ty(M.getContext()));
532
533 Value *args[] = {B.CreatePointerCast(ptr, allocType), next};
534 gVal = B.CreateCall(reallocF, args);
535 } else {
// Custom allocator (or nvptx): allocate a fresh buffer of next / sizeof(RT)
// elements and copy prevSize bytes over. This local tsize (the allocation
// size of RT in bytes) shadows the function argument of the same name.
// NOTE(review): no deallocation of the old buffer is emitted in this branch
// — confirm that is intended.
536 Value *tsize = ConstantInt::get(
537 next->getType(),
538 newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(RT) / 8);
539 auto elSize = B.CreateUDiv(next, tsize, "", /*isExact*/ true);
540 Instruction *SubZero = nullptr;
541 gVal = CreateAllocation(B, RT, elSize, "", nullptr, &SubZero);
542
543 Type *bTy =
544 PointerType::get(Type::getInt8Ty(gVal->getContext()),
545 cast<PointerType>(gVal->getType())->getAddressSpace());
546 gVal = B.CreatePointerCast(gVal, bTy);
547 auto pVal = B.CreatePointerCast(ptr, gVal->getType());
548
549 Value *margs[] = {gVal, pVal, prevSize,
550 ConstantInt::getFalse(M.getContext())};
551 Type *tys[] = {margs[0]->getType(), margs[1]->getType(),
552 margs[2]->getType()};
// Despite its name, memsetF is the memcpy intrinsic here.
553 auto memsetF = getIntrinsicDeclaration(&M, Intrinsic::memcpy, tys);
554 B.CreateCall(memsetF, margs);
// If the allocation produced its own zeroing instruction, narrow it to the
// newly added tail: operand 0 is advanced by prevSize bytes and operand 2 is
// set to next - prevSize. The separate memset below is then skipped.
555 if (SubZero) {
556 ZeroInit = false;
557 IRBuilder<> BB(SubZero);
558 Value *zeroSize = BB.CreateSub(next, prevSize);
559 Value *tmp = SubZero->getOperand(0);
560 Type *tmpT = tmp->getType();
561 tmp = BB.CreatePointerCast(tmp, bTy);
562 tmp = BB.CreateInBoundsGEP(Type::getInt8Ty(tmp->getContext()), tmp,
563 prevSize);
564 tmp = BB.CreatePointerCast(tmp, tmpT);
565 SubZero->setOperand(0, tmp);
566 SubZero->setOperand(2, zeroSize);
567 }
568 }
569
// Zero only the newly added bytes [prevSize, next).
570 if (ZeroInit) {
571 Value *zeroSize = B.CreateSub(next, prevSize);
572
573 Value *margs[] = {B.CreateInBoundsGEP(B.getInt8Ty(), gVal, prevSize),
574 B.getInt8(0), zeroSize, B.getFalse()};
575 Type *tys[] = {margs[0]->getType(), margs[2]->getType()};
576 auto memsetF = getIntrinsicDeclaration(&M, Intrinsic::memset, tys);
577 B.CreateCall(memsetF, margs);
578 }
579 gVal = B.CreatePointerCast(gVal, ptr->getType());
580
// Merge: the grown pointer when coming from "grow", the original otherwise.
581 B.CreateBr(ok);
582 B.SetInsertPoint(ok);
583 auto phi = B.CreatePHI(ptr->getType(), 2);
584 phi->addIncoming(gVal, grow);
585 phi->addIncoming(ptr, entry);
586 B.CreateRet(phi);
587 return F;
588}
589
590llvm::Value *CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev,
591 llvm::Type *T, llvm::Value *OuterCount,
592 llvm::Value *InnerCount,
593 const llvm::Twine &Name,
594 llvm::CallInst **caller, bool ZeroMem) {
595 auto newFunc = B.GetInsertBlock()->getParent();
596
597 Value *tsize = ConstantInt::get(
598 InnerCount->getType(),
599 newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(T) / 8);
600
601 Value *idxs[] = {
602 /*ptr*/
603 prev,
604 /*incrementing value to increase when it goes past a power of two*/
605 OuterCount,
606 /*buffer size (element x subloops)*/
607 B.CreateMul(tsize, InnerCount, "", /*NUW*/ true,
608 /*NSW*/ true)};
609
610 auto realloccall =
611 B.CreateCall(getOrInsertExponentialAllocator(*newFunc->getParent(),
612 newFunc, ZeroMem, T),
613 idxs, Name);
614 if (caller)
615 *caller = realloccall;
616 return realloccall;
617}
618
619Value *CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count,
620 const Twine &Name, CallInst **caller,
621 Instruction **ZeroMem, bool isDefault) {
622 Value *res;
623 auto &M = *Builder.GetInsertBlock()->getParent()->getParent();
624 auto AlignI = M.getDataLayout().getTypeAllocSizeInBits(T) / 8;
625 auto Align = ConstantInt::get(Count->getType(), AlignI);
626 CallInst *malloccall = nullptr;
627 if (CustomAllocator) {
628 LLVMValueRef wzeromem = nullptr;
629 res = unwrap(CustomAllocator(wrap(&Builder), wrap(T), wrap(Count),
630 wrap(Align), isDefault,
631 ZeroMem ? &wzeromem : nullptr));
632 if (isa<UndefValue>(res))
633 return res;
634 if (isa<Constant>(res))
635 return res;
636 if (auto I = dyn_cast<Instruction>(res))
637 I->setName(Name);
638
639 malloccall = dyn_cast<CallInst>(res);
640 if (malloccall == nullptr) {
641 malloccall = cast<CallInst>(cast<Instruction>(res)->getOperand(0));
642 }
643 if (ZeroMem) {
644 *ZeroMem = cast_or_null<Instruction>(unwrap(wzeromem));
645 ZeroMem = nullptr;
646 }
647 } else {
648#if LLVM_VERSION_MAJOR > 17
649 res =
650 Builder.CreateMalloc(Count->getType(), T, Align, Count, nullptr, Name);
651#else
652 if (Builder.GetInsertPoint() == Builder.GetInsertBlock()->end()) {
653 res = CallInst::CreateMalloc(Builder.GetInsertBlock(), Count->getType(),
654 T, Align, Count, nullptr, Name);
655 Builder.SetInsertPoint(Builder.GetInsertBlock());
656 } else {
657 res = CallInst::CreateMalloc(&*Builder.GetInsertPoint(), Count->getType(),
658 T, Align, Count, nullptr, Name);
659 }
660 if (!cast<Instruction>(res)->getParent())
661 Builder.Insert(cast<Instruction>(res));
662#endif
663
664 malloccall = dyn_cast<CallInst>(res);
665 if (malloccall == nullptr) {
666 malloccall = cast<CallInst>(cast<Instruction>(res)->getOperand(0));
667 }
668
669 // Assert computation of size of array doesn't wrap
670 if (auto BI = dyn_cast<BinaryOperator>(malloccall->getArgOperand(0))) {
671 if (BI->getOpcode() == BinaryOperator::Mul) {
672 if ((BI->getOperand(0) == Align && BI->getOperand(1) == Count) ||
673 (BI->getOperand(1) == Align && BI->getOperand(0) == Count))
674 BI->setHasNoSignedWrap(true);
675 BI->setHasNoUnsignedWrap(true);
676 }
677 }
678
679 if (auto ci = dyn_cast<ConstantInt>(Count)) {
680#if LLVM_VERSION_MAJOR >= 14
681 malloccall->addDereferenceableRetAttr(ci->getLimitedValue() * AlignI);
682#if !defined(FLANG) && !defined(ROCM)
683 AttrBuilder B(ci->getContext());
684#else
685 AttrBuilder B;
686#endif
687 B.addDereferenceableOrNullAttr(ci->getLimitedValue() * AlignI);
688 malloccall->setAttributes(malloccall->getAttributes().addRetAttributes(
689 malloccall->getContext(), B));
690#else
691 malloccall->addDereferenceableAttr(llvm::AttributeList::ReturnIndex,
692 ci->getLimitedValue() * AlignI);
693 malloccall->addDereferenceableOrNullAttr(llvm::AttributeList::ReturnIndex,
694 ci->getLimitedValue() * AlignI);
695#endif
696 // malloccall->removeAttribute(llvm::AttributeList::ReturnIndex,
697 // Attribute::DereferenceableOrNull);
698 }
699#if LLVM_VERSION_MAJOR >= 14
700 malloccall->addAttributeAtIndex(AttributeList::ReturnIndex,
701 Attribute::NoAlias);
702 malloccall->addAttributeAtIndex(AttributeList::ReturnIndex,
703 Attribute::NonNull);
704#else
705 malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
706 malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
707#endif
708 }
709 if (caller) {
710 *caller = malloccall;
711 }
712 if (ZeroMem) {
713 auto PT = cast<PointerType>(malloccall->getType());
714 Value *tozero = malloccall;
715
716 bool needsCast = false;
717#if LLVM_VERSION_MAJOR < 17
718#if LLVM_VERSION_MAJOR >= 15
719 if (PT->getContext().supportsTypedPointers()) {
720#endif
721 needsCast = !PT->getPointerElementType()->isIntegerTy(8);
722#if LLVM_VERSION_MAJOR >= 15
723 }
724#endif
725#endif
726 if (needsCast)
727 tozero = Builder.CreatePointerCast(
728 tozero, PointerType::get(Type::getInt8Ty(PT->getContext()),
729 PT->getAddressSpace()));
730 Value *args[] = {
731 tozero, ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0),
732 Builder.CreateMul(Align, Count, "", true, true),
733 ConstantInt::getFalse(malloccall->getContext())};
734 Type *tys[] = {args[0]->getType(), args[2]->getType()};
735
736 *ZeroMem = Builder.CreateCall(
737 getIntrinsicDeclaration(&M, Intrinsic::memset, tys), args);
738 }
739 return res;
740}
741
742CallInst *CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree) {
743 CallInst *res = nullptr;
744
745 if (CustomDeallocator) {
746 res = dyn_cast_or_null<CallInst>(
747 unwrap(CustomDeallocator(wrap(&Builder), wrap(ToFree))));
748 } else {
749
750 ToFree =
751 Builder.CreatePointerCast(ToFree, getInt8PtrTy(ToFree->getContext()));
752#if LLVM_VERSION_MAJOR > 17
753 res = cast<CallInst>(Builder.CreateFree(ToFree));
754#else
755 if (Builder.GetInsertPoint() == Builder.GetInsertBlock()->end()) {
756 res = cast<CallInst>(
757 CallInst::CreateFree(ToFree, Builder.GetInsertBlock()));
758 Builder.SetInsertPoint(Builder.GetInsertBlock());
759 } else {
760 res = cast<CallInst>(
761 CallInst::CreateFree(ToFree, &*Builder.GetInsertPoint()));
762 }
763 if (!cast<Instruction>(res)->getParent())
764 Builder.Insert(cast<Instruction>(res));
765#endif
766#if LLVM_VERSION_MAJOR >= 14
767 res->addAttributeAtIndex(AttributeList::FirstArgIndex, Attribute::NonNull);
768#else
769 res->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull);
770#endif
771 }
772 return res;
773}
774
775EnzymeWarning::EnzymeWarning(const llvm::Twine &RemarkName,
776 const llvm::DiagnosticLocation &Loc,
777 const llvm::Instruction *CodeRegion)
778 : EnzymeWarning(RemarkName, Loc, CodeRegion->getParent()->getParent()) {}
779
780EnzymeWarning::EnzymeWarning(const llvm::Twine &RemarkName,
781 const llvm::DiagnosticLocation &Loc,
782 const llvm::Function *CodeRegion)
783 : DiagnosticInfoUnsupported(*CodeRegion, RemarkName, Loc, DS_Warning) {}
784
785EnzymeFailure::EnzymeFailure(const llvm::Twine &RemarkName,
786 const llvm::DiagnosticLocation &Loc,
787 const llvm::Instruction *CodeRegion)
788 : EnzymeFailure(RemarkName, Loc, CodeRegion->getParent()->getParent()) {}
789
790EnzymeFailure::EnzymeFailure(const llvm::Twine &RemarkName,
791 const llvm::DiagnosticLocation &Loc,
792 const llvm::Function *CodeRegion)
793 : DiagnosticInfoUnsupported(*CodeRegion, RemarkName, Loc) {}
794
795/// Convert a floating type to a string
796static inline std::string tofltstr(Type *T) {
797 if (auto VT = dyn_cast<VectorType>(T)) {
798#if LLVM_VERSION_MAJOR >= 12
799 auto len = VT->getElementCount().getFixedValue();
800#else
801 auto len = VT->getNumElements();
802#endif
803 return "vec" + std::to_string(len) + tofltstr(VT->getElementType());
804 }
805 switch (T->getTypeID()) {
806 case Type::HalfTyID:
807 return "half";
808 case Type::FloatTyID:
809 return "float";
810 case Type::DoubleTyID:
811 return "double";
812 case Type::X86_FP80TyID:
813 return "x87d";
814 case Type::BFloatTyID:
815 return "bf16";
816 case Type::FP128TyID:
817 return "quad";
818 case Type::PPC_FP128TyID:
819 return "ppcddouble";
820 default:
821 llvm_unreachable("Invalid floating type");
822 }
823}
824
825Constant *getString(Module &M, StringRef Str) {
826 llvm::Constant *s = llvm::ConstantDataArray::getString(M.getContext(), Str);
827 auto *gv = new llvm::GlobalVariable(
828 M, s->getType(), true, llvm::GlobalValue::PrivateLinkage, s, ".str");
829 gv->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
830 Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(M.getContext()), 0),
831 ConstantInt::get(Type::getInt32Ty(M.getContext()), 0)};
832 return ConstantExpr::getInBoundsGetElementPtr(s->getType(), gv, Idxs);
833}
834
835void emit_backtrace(llvm::Instruction *inst, llvm::raw_ostream &ss) {
836 SmallPtrSet<llvm::Instruction *, 8> visited;
837 while (true) {
838 if (visited.contains(inst))
839 break;
840 visited.insert(inst);
841
842 // Print debug info for this instruction
843 if (auto dbgLoc = inst->getDebugLoc()) {
844 auto *loc = dbgLoc.get();
845 while (loc) {
846 if (auto *scope = loc->getScope()) {
847 StringRef name = scope->getName();
848 // Remove trailing semicolons (Julia-style function name decoration)
849 while (!name.empty() && name.back() == ';')
850 name = name.drop_back();
851 if (auto *file = scope->getFile()) {
852 StringRef dir = file->getDirectory();
853 StringRef fn = file->getFilename();
854 ss << " in '" << name << "' at ";
855 if (!dir.empty())
856 ss << dir << "/";
857 ss << fn << ":" << loc->getLine() << "\n";
858 } else {
859 ss << " in '" << name << "' at unknown:" << loc->getLine() << "\n";
860 }
861 }
862 loc = loc->getInlinedAt();
863 }
864 }
865
866 // Move up the call chain
867 Function *f = inst->getParent()->getParent();
868
869 // Collect callers with debug info
870 SmallVector<CallInst *, 4> callersWithDbg;
871 for (auto *U : f->users()) {
872 auto *CI = dyn_cast<CallInst>(U);
873 if (!CI)
874 continue;
875 if (!CI->getDebugLoc())
876 continue;
877 callersWithDbg.push_back(CI);
878 }
879
880 if (callersWithDbg.empty())
881 break;
882
883 // Deduplicate by debug location MDNode
884 SmallVector<CallInst *, 4> uniqueCallSites;
885 SmallPtrSet<const MDNode *, 4> seenMD;
886 for (auto *CI : callersWithDbg) {
887 if (seenMD.insert(CI->getDebugLoc().getAsMDNode()).second)
888 uniqueCallSites.push_back(CI);
889 }
890
891 if (uniqueCallSites.size() > 1) {
892 ss << " (multiple call sites)\n";
893 break;
894 } else if (uniqueCallSites.size() == 1) {
895 inst = uniqueCallSites[0];
896 continue;
897 }
898 break;
899 }
900}
901
/// Emit a call to a helper that compares \p primal and \p shadow (both cast
/// to i8 pointers) and takes an error path when they are equal.
///
/// \param B       builder positioned where the check is emitted.
/// \param primal  primal pointer value.
/// \param shadow  shadow pointer value.
/// \param Message text embedded as a string constant and passed to the helper.
/// \param loc     debug location assigned to the emitted call.
/// \param orig    originating instruction; passed to the custom error hook
///                and to emit_backtrace.
///
/// NOTE(review): the embedded listing numbers skip 907, 923-924, 942 and 967
/// in this function, and the braces visible here do not balance. Some lines
/// are therefore missing from this view — confirm against the upstream file.
902void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal,
903 llvm::Value *shadow, const char *Message,
904 llvm::DebugLoc &&loc, llvm::Instruction *orig) {
905 Module &M = *B.GetInsertBlock()->getParent()->getParent();
906 std::string name = "__enzyme_runtimeinactiveerr";
// NOTE(review): listing number 907 is skipped; the closing brace at 911
// suggests a guarding condition was here. As shown, a process-wide counter
// is appended to the helper name.
908 static int count = 0;
909 name += std::to_string(count);
910 count++;
911 }
// Helper signature: void(i8*, i8*, i8*).
912 FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()),
913 {getInt8PtrTy(M.getContext()),
914 getInt8PtrTy(M.getContext()),
915 getInt8PtrTy(M.getContext())},
916 false);
917
918 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
919
// Define the helper body only the first time this name is requested.
920 if (F->empty()) {
921 F->setLinkage(Function::LinkageTypes::InternalLinkage);
922 F->addFnAttr(Attribute::AlwaysInline);
// NOTE(review): listing numbers 923-924 are skipped here.
925
926 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
927 BasicBlock *error = BasicBlock::Create(M.getContext(), "error", F);
928 BasicBlock *end = BasicBlock::Create(M.getContext(), "end", F);
929
// These locals are the helper's own arguments; `shadow` shadows the
// parameter of the enclosing function inside this scope.
930 auto prim = F->arg_begin();
931 prim->setName("primal");
932 auto shadow = prim + 1;
933 shadow->setName("shadow");
934 auto msg = prim + 2;
935 msg->setName("msg");
936
// primal == shadow is the error condition.
937 IRBuilder<> EB(entry);
938 EB.CreateCondBr(EB.CreateICmpEQ(prim, shadow), error, end);
939
940 EB.SetInsertPoint(error);
941
// NOTE(review): listing number 942 is skipped; the `else` at 944 implies an
// `if` opened here (presumably testing the hook — confirm).
943 CustomRuntimeInactiveError(wrap(&EB), wrap(msg), wrap(orig));
944 } else {
// Fallback error path: puts(msg); exit(1).
945 FunctionType *FT =
946 FunctionType::get(Type::getInt32Ty(M.getContext()),
947 {getInt8PtrTy(M.getContext())}, false);
948
949 auto PutsF = M.getOrInsertFunction("puts", FT);
950 EB.CreateCall(PutsF, msg);
951
952 FunctionType *FT2 =
953 FunctionType::get(Type::getVoidTy(M.getContext()),
954 {Type::getInt32Ty(M.getContext())}, false);
955
956 auto ExitF = M.getOrInsertFunction("exit", FT2);
957 EB.CreateCall(ExitF,
958 ConstantInt::get(Type::getInt32Ty(M.getContext()), 1));
959 }
960 EB.CreateUnreachable();
961
962 EB.SetInsertPoint(end);
963 EB.CreateRetVoid();
964 }
965
966 std::string Message2 = Message;
// NOTE(review): listing number 967 is skipped; the closing brace at 973
// suggests a guarding condition was here. As shown, the message is extended
// with the backtrace produced by emit_backtrace.
968 std::string str;
969 raw_string_ostream ss(str);
970 ss << Message << "\n";
971 emit_backtrace(orig, ss);
972 Message2 = ss.str();
973 }
// Call the helper with (primal, shadow, message) and attach the caller's
// debug location.
974 Value *args[] = {B.CreatePointerCast(primal, getInt8PtrTy(M.getContext())),
975 B.CreatePointerCast(shadow, getInt8PtrTy(M.getContext())),
976 getString(M, Message2)};
977 auto call = B.CreateCall(F, args);
978 call->setDebugLoc(loc);
979}
980
981Type *BlasInfo::fpType(LLVMContext &ctx, bool to_scalar) const {
982 if (floatType == "d" || floatType == "D") {
983 return Type::getDoubleTy(ctx);
984 } else if (floatType == "s" || floatType == "S") {
985 return Type::getFloatTy(ctx);
986 } else if (floatType == "c" || floatType == "C") {
987 if (to_scalar)
988 return Type::getFloatTy(ctx);
989 return VectorType::get(Type::getFloatTy(ctx), 2, false);
990 } else if (floatType == "z" || floatType == "Z") {
991 if (to_scalar)
992 return Type::getDoubleTy(ctx);
993 return VectorType::get(Type::getDoubleTy(ctx), 2, false);
994 } else {
995 assert(false && "Unreachable");
996 return nullptr;
997 }
998}
999
1000IntegerType *BlasInfo::intType(LLVMContext &ctx) const {
1001 if (is64)
1002 return IntegerType::get(ctx, 64);
1003 else
1004 return IntegerType::get(ctx, 32);
1005}
1006
1007/// Create function for type that is equivalent to memcpy but adds to
1008/// destination rather than a direct copy; dst, src, numelems
1009Function *getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType,
1010 unsigned dstalign,
1011 unsigned srcalign,
1012 unsigned dstaddr, unsigned srcaddr,
1013 unsigned bitwidth) {
1014 assert(elementType->isFloatingPointTy());
1015 std::string name = "__enzyme_memcpy";
1016 if (bitwidth != 64)
1017 name += std::to_string(bitwidth);
1018 name += "add_" + tofltstr(elementType) + "da" + std::to_string(dstalign) +
1019 "sa" + std::to_string(srcalign);
1020 if (dstaddr)
1021 name += "dadd" + std::to_string(dstaddr);
1022 if (srcaddr)
1023 name += "sadd" + std::to_string(srcaddr);
1024 FunctionType *FT =
1025 FunctionType::get(Type::getVoidTy(M.getContext()),
1026 {PointerType::get(elementType, dstaddr),
1027 PointerType::get(elementType, srcaddr),
1028 IntegerType::get(M.getContext(), bitwidth)},
1029 false);
1030
1031 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
1032
1033 if (!F->empty())
1034 return F;
1035
1036 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1037#if LLVM_VERSION_MAJOR >= 16
1038 F->setOnlyAccessesArgMemory();
1039#else
1040 F->addFnAttr(Attribute::ArgMemOnly);
1041#endif
1042 F->addFnAttr(Attribute::NoUnwind);
1043 F->addFnAttr(Attribute::AlwaysInline);
1046
1047 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
1048 BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F);
1049 BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F);
1050
1051 auto dst = F->arg_begin();
1052 dst->setName("dst");
1053 auto src = dst + 1;
1054 src->setName("src");
1055 auto num = src + 1;
1056 num->setName("num");
1057
1058 {
1059 IRBuilder<> B(entry);
1060 B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)),
1061 end, body);
1062 }
1063
1064 auto elSize = (M.getDataLayout().getTypeSizeInBits(elementType) + 7) / 8;
1065 {
1066 IRBuilder<> B(body);
1067 B.setFastMathFlags(getFast());
1068 PHINode *idx = B.CreatePHI(num->getType(), 2, "idx");
1069 idx->addIncoming(ConstantInt::get(num->getType(), 0), entry);
1070
1071 Value *dsti = B.CreateInBoundsGEP(elementType, dst, idx, "dst.i");
1072 LoadInst *dstl = B.CreateLoad(elementType, dsti, "dst.i.l");
1073 StoreInst *dsts = B.CreateStore(Constant::getNullValue(elementType), dsti);
1074
1075 if (dstalign) {
1076 // If the element size is already aligned to current alignment, do nothing
1077 // e.g. elsize = double = 8, dstalign = 2
1078 if (elSize % dstalign == 0) {
1079
1080 } else if (dstalign % elSize == 0) {
1081 // Otherwise if the dst alignment is a multiple of the element size,
1082 // use the element size as the new alignment. e.g. elsize = double = 8
1083 // and alignment = 16
1084 dstalign = elSize;
1085 } else {
1086 // else alignment only applies for first element, and we lose after all
1087 // other iterattions, assume nothing
1088 dstalign = 1;
1089 }
1090 }
1091
1092 if (srcalign) {
1093 // If the element size is already aligned to current alignment, do nothing
1094 // e.g. elsize = double = 8, dstalign = 2
1095 if (elSize % srcalign == 0) {
1096
1097 } else if (srcalign % elSize == 0) {
1098 // Otherwise if the dst alignment is a multiple of the element size,
1099 // use the element size as the new alignment. e.g. elsize = double = 8
1100 // and alignment = 16
1101 srcalign = elSize;
1102 } else {
1103 // else alignment only applies for first element, and we lose after all
1104 // other iterattions, assume nothing
1105 srcalign = 1;
1106 }
1107 }
1108
1109 if (dstalign) {
1110 dstl->setAlignment(Align(dstalign));
1111 dsts->setAlignment(Align(dstalign));
1112 }
1113
1114 Value *srci = B.CreateInBoundsGEP(elementType, src, idx, "src.i");
1115 LoadInst *srcl = B.CreateLoad(elementType, srci, "src.i.l");
1116 StoreInst *srcs = B.CreateStore(B.CreateFAdd(srcl, dstl), srci);
1117 if (srcalign) {
1118 srcl->setAlignment(Align(srcalign));
1119 srcs->setAlignment(Align(srcalign));
1120 }
1121
1122 Value *next =
1123 B.CreateNUWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next");
1124 idx->addIncoming(next, body);
1125 B.CreateCondBr(B.CreateICmpEQ(num, next), end, body);
1126 }
1127
1128 {
1129 IRBuilder<> B(end);
1130 B.CreateRetVoid();
1131 }
1132 return F;
1133}
1134
1135Value *lookup_with_layout(IRBuilder<> &B, Type *fpType, Value *layout,
1136 Value *const base, Value *lda, Value *row,
1137 Value *col) {
1138 Type *intType = row->getType();
1139 Value *is_row_maj =
1140 layout ? B.CreateICmpEQ(layout, ConstantInt::get(layout->getType(), 101))
1141 : B.getFalse();
1142 Value *offset = nullptr;
1143 if (col) {
1144 offset = B.CreateMul(
1145 row, CreateSelect(B, is_row_maj, lda, ConstantInt::get(intType, 1)));
1146 offset = B.CreateAdd(
1147 offset,
1148 B.CreateMul(col, CreateSelect(B, is_row_maj,
1149 ConstantInt::get(intType, 1), lda)));
1150 } else {
1151 offset = B.CreateMul(row, lda);
1152 }
1153 if (!base)
1154 return offset;
1155
1156 Value *ptr = base;
1157 if (base->getType()->isIntegerTy())
1158 ptr = B.CreateIntToPtr(ptr, getUnqual(fpType));
1159
1160#if LLVM_VERSION_MAJOR < 17
1161#if LLVM_VERSION_MAJOR >= 15
1162 if (ptr->getContext().supportsTypedPointers()) {
1163#endif
1164 if (fpType != ptr->getType()->getPointerElementType()) {
1165 ptr = B.CreatePointerCast(
1166 ptr,
1167 PointerType::get(
1168 fpType, cast<PointerType>(ptr->getType())->getAddressSpace()));
1169 }
1170#if LLVM_VERSION_MAJOR >= 15
1171 }
1172#endif
1173#endif
1174 ptr = B.CreateGEP(fpType, ptr, offset);
1175
1176 if (base->getType()->isIntegerTy()) {
1177 ptr = B.CreatePtrToInt(ptr, base->getType());
1178 } else if (ptr->getType() != base->getType()) {
1179 ptr = B.CreatePointerCast(ptr, base->getType());
1180 }
1181 return ptr;
1182}
1183
// Emit a call, at builder B, to a module-local helper named
// "__enzyme_copy_lower_to_upper<floatType><prefix><suffix>" and, if that
// helper has no body yet, generate the body.
//
// The helper takes ([layout,] islower, A, lda, N). Its loop runs an index i
// from 0 until i + 1 == N - 1 and, on every iteration, calls the BLAS
// "<prefix><floatType>copy" routine with a length of (N - 1) - i and two
// addresses inside A obtained from lookup_with_layout: one at
// (islower ? i + 1 : i, islower ? i : i + 1) and one with the two indices
// swapped. The loop is skipped entirely when N - 1 <= 0.
//
// Parameters:
//   B        builder positioned where the call to the helper is inserted
//   fpType   element type handed to lookup_with_layout
//   blas     supplies prefix / floatType / suffix for the symbol names
//   byRef    forwarded to to_blas_callconv
//   layout   optional; when non-null it becomes the helper's first argument
//   islower, A, lda, N   forwarded unchanged as the helper's arguments
1184void copy_lower_to_upper(llvm::IRBuilder<> &B, llvm::Type *fpType,
1185 BlasInfo blas, bool byRef, llvm::Value *layout,
1186 llvm::Value *islower, llvm::Value *A, llvm::Value *lda,
1187 llvm::Value *N) {
1188
// cuBLAS "v2" entry points: the suffix is dropped from the copy routine name.
1189 const bool cublasv2 =
1190 blas.prefix == "cublas" && StringRef(blas.suffix).contains("v2");
1191
1192 const bool cublas = blas.prefix == "cublas";
1193 auto &M = *B.GetInsertBlock()->getParent()->getParent();
1194
1195 llvm::Type *intType = N->getType();
1196 // add spmv diag update call if not already present
1197 auto fnc_name = "__enzyme_copy_lower_to_upper" + blas.floatType +
1198 blas.prefix + blas.suffix;
1199
// Helper signature mirrors the call-site operand types; layout, when
// present, is prepended.
1200 SmallVector<Type *, 1> tys = {islower->getType(), A->getType(),
1201 lda->getType(), N->getType()};
1202 if (layout)
1203 tys.insert(tys.begin(), layout->getType());
1204 auto ltuFT = FunctionType::get(B.getVoidTy(), tys, false);
1205
1206 auto F0 = M.getOrInsertFunction(fnc_name, ltuFT);
1207
1208 SmallVector<Value *, 1> args = {islower, A, lda, N};
1209 if (layout)
1210 args.insert(args.begin(), layout);
// The call is emitted unconditionally; the body is only generated below when
// the helper is still empty.
1211 auto C = B.CreateCall(F0, args);
1212 auto F = getFunctionFromCall(C);
1213 assert(F);
1214 if (!F->empty()) {
1215 return;
1216 }
1217
1218 // now add the implementation for the call
1219 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1220#if LLVM_VERSION_MAJOR >= 16
1221 F->setOnlyAccessesArgMemory();
1222#else
1223 F->addFnAttr(Attribute::ArgMemOnly);
1224#endif
1225 F->addFnAttr(Attribute::NoUnwind);
1226 F->addFnAttr(Attribute::AlwaysInline);
// A sits at index 1 when layout is present, 0 otherwise... plus islower.
1227 if (A->getType()->isPointerTy())
1228 addFunctionNoCapture(F, 1 + ((bool)layout));
1229
1230 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
1231 BasicBlock *loop = BasicBlock::Create(M.getContext(), "loop", F);
1232 BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F);
1233
// Name the helper's formal arguments in declaration order.
1234 auto arg = F->arg_begin();
1235 Argument *layoutarg = nullptr;
1236 if (layout) {
1237 layoutarg = arg;
1238 layoutarg->setName("layout");
1239 arg++;
1240 }
1241 auto islowerarg = arg;
1242 islowerarg->setName("islower");
1243 arg++;
1244 auto Aarg = arg;
1245 Aarg->setName("A");
1246 arg++;
1247 auto ldaarg = arg;
1248 ldaarg->setName("lda");
1249 arg++;
1250 auto Narg = arg;
1251 Narg->setName("N");
1252
1253 IRBuilder<> EB(entry);
1254
1255 auto one = ConstantInt::get(intType, 1);
1256 auto zero = ConstantInt::get(intType, 0);
1257
1258 Value *N_minus_1 = EB.CreateSub(Narg, one);
1259
1260 IRBuilder<> LB(loop);
1261
// Loop induction variable; the increment carries both nuw and nsw flags.
1262 auto i = LB.CreatePHI(intType, 2);
1263 i->addIncoming(zero, entry);
1264 auto i_plus_one = LB.CreateAdd(i, one, "", true, true);
1265 i->addIncoming(i_plus_one, loop);
1266
// Five operands for the BLAS copy call: length, first address, first stride,
// second address, second stride (copyTys below reads indices 0..4).
1267 Value *copyArgs[] = {
1268 to_blas_callconv(LB, LB.CreateSub(N_minus_1, i), byRef, cublas, nullptr,
1269 EB),
1270 lookup_with_layout(LB, fpType, layoutarg, Aarg, ldaarg,
1271 CreateSelect(LB, islowerarg, i_plus_one, i),
1272 CreateSelect(LB, islowerarg, i, i_plus_one)),
// NOTE(review): the listing's numbering skips a line here; the bare `LB,`
// below presumably continues a `to_blas_callconv(` call like the first
// element above -- confirm against the original file.
1274 LB,
1275 lookup_with_layout(LB, fpType, layoutarg, nullptr, ldaarg,
1276 CreateSelect(LB, islowerarg, one, zero),
1277 CreateSelect(LB, islowerarg, zero, one)),
1278 byRef, cublas, nullptr, EB),
1279 lookup_with_layout(LB, fpType, layoutarg, Aarg, ldaarg,
1280 CreateSelect(LB, islowerarg, i, i_plus_one),
1281 CreateSelect(LB, islowerarg, i_plus_one, i)),
// NOTE(review): same gap as above -- a line is missing before this `LB,`.
1283 LB,
1284 lookup_with_layout(LB, fpType, layoutarg, nullptr, ldaarg,
1285 CreateSelect(LB, islowerarg, zero, one),
1286 CreateSelect(LB, islowerarg, one, zero)),
1287 byRef, cublas, nullptr, EB)};
1288
1289 Type *copyTys[] = {copyArgs[0]->getType(), copyArgs[1]->getType(),
1290 copyArgs[2]->getType(), copyArgs[3]->getType(),
1291 copyArgs[4]->getType()};
1292
1293 FunctionType *FT = FunctionType::get(B.getVoidTy(), copyTys, false);
1294
1295 auto copy_name = std::string(blas.prefix) + blas.floatType + "copy" +
1296 (cublasv2 ? "" : blas.suffix);
1297
1298 auto copyfn = M.getOrInsertFunction(copy_name, FT);
1299 LB.CreateCall(copyfn, copyArgs);
// NOTE(review): the statement guarded by this `if` is missing from the
// listing (numbering skips one line); as shown, the `if` would guard the
// loop's conditional branch, which cannot be the intent -- confirm.
1300 if (auto F = GetFunctionFromValue(copyfn.getCallee()))
1302 LB.CreateCondBr(LB.CreateICmpEQ(i_plus_one, N_minus_1), end, loop);
1303
// Entry terminator is emitted last: skip the loop when N - 1 <= 0 (signed).
1304 EB.CreateCondBr(EB.CreateICmpSLE(N_minus_1, zero), end, loop);
1305 {
1306 IRBuilder<> B(end);
1307 B.CreateRetVoid();
1308 }
1309
// Sanity-check the generated helper; abort with the IR dumped on failure.
1310 if (llvm::verifyFunction(*F, &llvm::errs())) {
1311 llvm::errs() << *F << "\n";
1312 report_fatal_error("helper function failed verification");
1313 }
1314}
1315
// Emit, at builder B, a call to the BLAS copy routine selected by `blas`.
//
// The callee is named "<prefix><floatType>copy<suffix>"; the suffix is left
// off when the prefix is "cublas" and the suffix contains "v2". It is
// declared in M (if not already present) with return type `copy_retty` and
// parameter types taken from `args`, then called with `args` and `bundles`.
1316void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
1317 llvm::ArrayRef<llvm::Value *> args,
1318 llvm::Type *copy_retty,
1319 llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
1320 const bool cublasv2 =
1321 blas.prefix == "cublas" && StringRef(blas.suffix).contains("v2");
1322 auto copy_name = std::string(blas.prefix) + blas.floatType + "copy" +
1323 (cublasv2 ? "" : blas.suffix);
1324
// Parameter types are derived from the actual operands.
1325 SmallVector<Type *, 1> tys;
1326 for (auto arg : args)
1327 tys.push_back(arg->getType());
1328
1329 FunctionType *FT = FunctionType::get(copy_retty, tys, false);
1330 auto fn = M.getOrInsertFunction(copy_name, FT);
1331 B.CreateCall(fn, args, bundles);
// NOTE(review): the statement guarded by this `if` is missing from the
// listing (its numbering skips one line) -- restore it from the original file.
1332 if (auto F = GetFunctionFromValue(fn.getCallee()))
1334}
1335
// Emit, at builder B, a call to the LAPACK routine
// "<prefix><floatType>lacpy<suffix>" selected by `blas`.
//
// The callee is declared in M (if not already present) as returning void with
// parameter types taken from `args`, then called with `args` and `bundles`.
1336void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M,
1337 BlasInfo blas, llvm::ArrayRef<llvm::Value *> args,
1338 llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
1339 auto copy_name =
1340 std::string(blas.prefix) + blas.floatType + "lacpy" + blas.suffix;
1341
// Parameter types are derived from the actual operands.
1342 SmallVector<Type *, 1> tys;
1343 for (auto arg : args)
1344 tys.push_back(arg->getType());
1345
1346 auto FT = FunctionType::get(Type::getVoidTy(M.getContext()), tys, false);
1347 auto fn = M.getOrInsertFunction(copy_name, FT);
1348 B.CreateCall(fn, args, bundles);
1349
// NOTE(review): the statement guarded by this `if` is missing from the
// listing (its numbering skips one line) -- restore it from the original file.
1350 if (auto F = GetFunctionFromValue(fn.getCallee()))
1352}
1353
1354void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas,
1355 IntegerType *IT, Type *BlasCT, Type *BlasFPT,
1356 Type *BlasPT, Type *BlasIT, Type *fpTy,
1357 ArrayRef<Value *> args,
1358 ArrayRef<OperandBundleDef> bundles, bool byRef,
1359 bool julia_decl) {
1360 // add spmv diag update call if not already present
1361 auto fnc_name = "__enzyme_spmv_diag" + blas.floatType + blas.suffix;
1362
1363 // spmvDiagHelper(uplo, n, alpha, x, incx, ya, incy, APa)
1364 auto FDiagUpdateT = FunctionType::get(
1365 B.getVoidTy(),
1366 {BlasCT, BlasIT, BlasFPT, BlasPT, BlasIT, BlasPT, BlasIT, BlasPT}, false);
1367 Function *F =
1368 cast<Function>(M.getOrInsertFunction(fnc_name, FDiagUpdateT).getCallee());
1369
1370 if (!F->empty()) {
1371 B.CreateCall(F, args, bundles);
1372 return;
1373 }
1374
1375 // now add the implementation for the call
1376 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1377#if LLVM_VERSION_MAJOR >= 16
1378 F->setOnlyAccessesArgMemory();
1379#else
1380 F->addFnAttr(Attribute::ArgMemOnly);
1381#endif
1382 F->addFnAttr(Attribute::NoUnwind);
1383 F->addFnAttr(Attribute::AlwaysInline);
1384 if (!julia_decl) {
1388 F->addParamAttr(3, Attribute::NoAlias);
1389 F->addParamAttr(5, Attribute::NoAlias);
1390 F->addParamAttr(7, Attribute::NoAlias);
1391 F->addParamAttr(3, Attribute::ReadOnly);
1392 F->addParamAttr(5, Attribute::ReadOnly);
1393 if (byRef) {
1395 F->addParamAttr(2, Attribute::NoAlias);
1396 F->addParamAttr(2, Attribute::ReadOnly);
1397 }
1398 }
1399
1400 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
1401 BasicBlock *init = BasicBlock::Create(M.getContext(), "init", F);
1402 BasicBlock *uper_code = BasicBlock::Create(M.getContext(), "uper", F);
1403 BasicBlock *lower_code = BasicBlock::Create(M.getContext(), "lower", F);
1404 BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F);
1405
1406 // spmvDiagHelper(uplo, n, alpha, x, incx, ya, incy, APa)
1407 auto blasuplo = F->arg_begin();
1408 blasuplo->setName("blasuplo");
1409 auto blasn = blasuplo + 1;
1410 blasn->setName("blasn");
1411 auto blasalpha = blasn + 1;
1412 blasalpha->setName("blasalpha");
1413 auto blasx = blasalpha + 1;
1414 blasx->setName("blasx");
1415 auto blasincx = blasx + 1;
1416 blasincx->setName("blasincx");
1417 auto blasdy = blasx + 1;
1418 blasdy->setName("blasdy");
1419 auto blasincy = blasdy + 1;
1420 blasincy->setName("blasincy");
1421 auto blasdAP = blasincy + 1;
1422 blasdAP->setName("blasdAP");
1423
1424 // TODO: consider cblas_layout
1425
1426 // https://dl.acm.org/doi/pdf/10.1145/3382191
1427 // Following example is Fortran based, thus 1 indexed
1428 // if(uplo == 'u' .or. uplo == 'U') then
1429 // k = 0
1430 // do i = 1,n
1431 // k = k+i
1432 // APa(k) = APa(k) - alpha*x(1 + (i-1)*incx)*ya(1 + (i-1)*incy)
1433 // end do
1434 // else
1435 // k = 1
1436 // do i = 1,n
1437 // APa(k) = APa(k) - alpha*x(1 + (i-1)*incx)*ya(1 + (i-1)*incy)
1438 // k = k+n-i+1
1439 // end do
1440 // end if
1441 {
1442 IRBuilder<> B1(entry);
1443 Value *n = load_if_ref(B1, IT, blasn, byRef);
1444 Value *incx = load_if_ref(B1, IT, blasincx, byRef);
1445 Value *incy = load_if_ref(B1, IT, blasincy, byRef);
1446 Value *alpha = blasalpha;
1447 if (byRef) {
1448 auto VP = B1.CreatePointerCast(
1449 blasalpha,
1450 PointerType::get(
1451 fpTy,
1452 cast<PointerType>(blasalpha->getType())->getAddressSpace()));
1453 alpha = B1.CreateLoad(fpTy, VP);
1454 }
1455 Value *is_l = is_lower(B1, blasuplo, byRef, /*cublas*/ false);
1456 B1.CreateCondBr(B1.CreateICmpEQ(n, ConstantInt::get(IT, 0)), end, init);
1457
1458 IRBuilder<> B2(init);
1459 Value *xfloat = B2.CreatePointerCast(
1460 blasx,
1461 PointerType::get(
1462 fpTy, cast<PointerType>(blasx->getType())->getAddressSpace()));
1463 Value *dyfloat = B2.CreatePointerCast(
1464 blasdy,
1465 PointerType::get(
1466 fpTy, cast<PointerType>(blasdy->getType())->getAddressSpace()));
1467 Value *dAPfloat = B2.CreatePointerCast(
1468 blasdAP,
1469 PointerType::get(
1470 fpTy, cast<PointerType>(blasdAP->getType())->getAddressSpace()));
1471 B2.CreateCondBr(is_l, lower_code, uper_code);
1472
1473 IRBuilder<> B3(uper_code);
1474 B3.setFastMathFlags(getFast());
1475 {
1476 PHINode *iter = B3.CreatePHI(IT, 2, "iteration");
1477 PHINode *kval = B3.CreatePHI(IT, 2, "k");
1478 iter->addIncoming(ConstantInt::get(IT, 0), init);
1479 kval->addIncoming(ConstantInt::get(IT, 0), init);
1480 Value *iternext =
1481 B3.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next");
1482 // 0, 2, 5, 9, 14, 20, 27, 35, 44, 54, ... are diag elements
1483 Value *kvalnext = B3.CreateAdd(kval, iternext, "k.next");
1484 iter->addIncoming(iternext, uper_code);
1485 kval->addIncoming(kvalnext, uper_code);
1486
1487 Value *xidx = B3.CreateNUWMul(iter, incx, "x.idx");
1488 Value *yidx = B3.CreateNUWMul(iter, incy, "y.idx");
1489 Value *x = B3.CreateInBoundsGEP(fpTy, xfloat, xidx, "x.ptr");
1490 Value *y = B3.CreateInBoundsGEP(fpTy, dyfloat, yidx, "y.ptr");
1491 Value *xval = B3.CreateLoad(fpTy, x, "x.val");
1492 Value *yval = B3.CreateLoad(fpTy, y, "y.val");
1493 Value *xy = B3.CreateFMul(xval, yval, "xy");
1494 Value *xyalpha = B3.CreateFMul(xy, alpha, "xy.alpha");
1495 Value *kptr = B3.CreateInBoundsGEP(fpTy, dAPfloat, kval, "k.ptr");
1496 Value *kvalloaded = B3.CreateLoad(fpTy, kptr, "k.val");
1497 Value *kvalnew = B3.CreateFSub(kvalloaded, xyalpha, "k.val.new");
1498 B3.CreateStore(kvalnew, kptr);
1499
1500 B3.CreateCondBr(B3.CreateICmpEQ(iternext, n), end, uper_code);
1501 }
1502
1503 IRBuilder<> B4(lower_code);
1504 B4.setFastMathFlags(getFast());
1505 {
1506 PHINode *iter = B4.CreatePHI(IT, 2, "iteration");
1507 PHINode *kval = B4.CreatePHI(IT, 2, "k");
1508 iter->addIncoming(ConstantInt::get(IT, 0), init);
1509 kval->addIncoming(ConstantInt::get(IT, 0), init);
1510 Value *iternext =
1511 B4.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next");
1512 Value *ktmp = B4.CreateAdd(n, ConstantInt::get(IT, 1), "tmp.val");
1513 Value *ktmp2 = B4.CreateSub(ktmp, iternext, "tmp.val.other");
1514 Value *kvalnext = B4.CreateAdd(kval, ktmp2, "k.next");
1515 iter->addIncoming(iternext, lower_code);
1516 kval->addIncoming(kvalnext, lower_code);
1517
1518 Value *xidx = B4.CreateNUWMul(iter, incx, "x.idx");
1519 Value *yidx = B4.CreateNUWMul(iter, incy, "y.idx");
1520 Value *x = B4.CreateInBoundsGEP(fpTy, xfloat, xidx, "x.ptr");
1521 Value *y = B4.CreateInBoundsGEP(fpTy, dyfloat, yidx, "y.ptr");
1522 Value *xval = B4.CreateLoad(fpTy, x, "x.val");
1523 Value *yval = B4.CreateLoad(fpTy, y, "y.val");
1524 Value *xy = B4.CreateFMul(xval, yval, "xy");
1525 Value *xyalpha = B4.CreateFMul(xy, alpha, "xy.alpha");
1526 Value *kptr = B4.CreateInBoundsGEP(fpTy, dAPfloat, kval, "k.ptr");
1527 Value *kvalloaded = B4.CreateLoad(fpTy, kptr, "k.val");
1528 Value *kvalnew = B4.CreateFSub(kvalloaded, xyalpha, "k.val.new");
1529 B4.CreateStore(kvalnew, kptr);
1530
1531 B4.CreateCondBr(B4.CreateICmpEQ(iternext, n), end, lower_code);
1532 }
1533
1534 IRBuilder<> B5(end);
1535 B5.CreateRetVoid();
1536 }
1537 B.CreateCall(F, args, bundles);
1538 return;
1539}
1540
1541llvm::CallInst *
1542getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
1543 IntegerType *IT, Type *BlasPT, Type *BlasIT, Type *fpTy,
1544 llvm::ArrayRef<llvm::Value *> args,
1545 const llvm::ArrayRef<llvm::OperandBundleDef> bundles,
1546 bool byRef, bool cublas, bool julia_decl) {
1547 assert(fpTy->isFloatingPointTy());
1548
1549 // add inner_prod call if not already present
1550 std::string prod_name = "__enzyme_inner_prod" + blas.floatType + blas.suffix;
1551 auto FInnerProdT =
1552 FunctionType::get(fpTy, {BlasIT, BlasIT, BlasPT, BlasIT, BlasPT}, false);
1553 Function *F =
1554 cast<Function>(M.getOrInsertFunction(prod_name, FInnerProdT).getCallee());
1555
1556 if (!F->empty())
1557 return B.CreateCall(F, args, bundles);
1558
1559 // add dot call if not already present
1560 std::string dot_name = blas.prefix + blas.floatType + "dot" + blas.suffix;
1561 auto FDotT =
1562 FunctionType::get(fpTy, {BlasIT, BlasPT, BlasIT, BlasPT, BlasIT}, false);
1563 auto FDot = M.getOrInsertFunction(dot_name, FDotT);
1564
1565 // now add the implementation for the inner_prod call
1566 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1567#if LLVM_VERSION_MAJOR >= 16
1568 F->setOnlyAccessesArgMemory();
1569 F->setOnlyReadsMemory();
1570#else
1571 F->addFnAttr(Attribute::ArgMemOnly);
1572 F->addFnAttr(Attribute::ReadOnly);
1573#endif
1574 F->addFnAttr(Attribute::NoUnwind);
1575 F->addFnAttr(Attribute::AlwaysInline);
1576 if (!julia_decl) {
1579 F->addParamAttr(2, Attribute::NoAlias);
1580 F->addParamAttr(4, Attribute::NoAlias);
1581 F->addParamAttr(2, Attribute::ReadOnly);
1582 F->addParamAttr(4, Attribute::ReadOnly);
1583 }
1584
1585 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
1586 BasicBlock *init = BasicBlock::Create(M.getContext(), "init.idx", F);
1587 BasicBlock *fastPath = BasicBlock::Create(M.getContext(), "fast.path", F);
1588 BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F);
1589 BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F);
1590
1591 // This is the .td declaration which we need to match
1592 // No need to support ld for the second matrix, as it will
1593 // always be based on a matrix which we allocated (contiguous)
1594 //(FrobInnerProd<> $m, $n, adj<"C">, $ldc, use<"AB">)
1595
1596 auto blasm = F->arg_begin();
1597 blasm->setName("blasm");
1598 auto blasn = blasm + 1;
1599 blasn->setName("blasn");
1600 auto matA = blasn + 1;
1601 matA->setName("A");
1602 auto blaslda = matA + 1;
1603 blaslda->setName("lda");
1604 auto matB = blaslda + 1;
1605 matB->setName("B");
1606
1607 {
1608 IRBuilder<> B1(entry);
1609 Value *blasOne = to_blas_callconv(B1, ConstantInt::get(IT, 1), byRef,
1610 cublas, nullptr, B1, "constant.one");
1611
1612 if (blasOne->getType() != BlasIT)
1613 blasOne = B1.CreatePointerCast(blasOne, BlasIT, "intcast.constant.one");
1614
1615 Value *m = load_if_ref(B1, IT, blasm, byRef);
1616 Value *n = load_if_ref(B1, IT, blasn, byRef);
1617 Value *size = B1.CreateNUWMul(m, n, "mat.size");
1618 Value *blasSize = to_blas_callconv(
1619 B1, size, byRef, cublas, julia_decl ? IT : nullptr, B1, "mat.size");
1620
1621 if (blasSize->getType() != BlasIT)
1622 blasSize = B1.CreatePointerCast(blasSize, BlasIT, "intcast.mat.size");
1623 B1.CreateCondBr(B1.CreateICmpEQ(size, ConstantInt::get(IT, 0)), end, init);
1624
1625 IRBuilder<> B2(init);
1626 B2.setFastMathFlags(getFast());
1627 Value *lda = load_if_ref(B2, IT, blaslda, byRef);
1628 Value *Afloat = B2.CreatePointerCast(
1629 matA, PointerType::get(
1630 fpTy, cast<PointerType>(matA->getType())->getAddressSpace()));
1631 Value *Bfloat = B2.CreatePointerCast(
1632 matB, PointerType::get(
1633 fpTy, cast<PointerType>(matB->getType())->getAddressSpace()));
1634 B2.CreateCondBr(B2.CreateICmpEQ(m, lda), fastPath, body);
1635
1636 // our second matrix is always continuos, by construction.
1637 // If our first matrix is continuous too (lda == m), then we can
1638 // use a single dot call.
1639 IRBuilder<> B3(fastPath);
1640 B3.setFastMathFlags(getFast());
1641 Value *blasA = B3.CreatePointerCast(matA, BlasPT);
1642 Value *blasB = B3.CreatePointerCast(matB, BlasPT);
1643 Value *fastSum =
1644 B3.CreateCall(FDot, {blasSize, blasA, blasOne, blasB, blasOne});
1645 B3.CreateBr(end);
1646
1647 IRBuilder<> B4(body);
1648 B4.setFastMathFlags(getFast());
1649 PHINode *Aidx = B4.CreatePHI(IT, 2, "Aidx");
1650 PHINode *Bidx = B4.CreatePHI(IT, 2, "Bidx");
1651 PHINode *iter = B4.CreatePHI(IT, 2, "iteration");
1652 PHINode *sum = B4.CreatePHI(fpTy, 2, "sum");
1653 Aidx->addIncoming(ConstantInt::get(IT, 0), init);
1654 Bidx->addIncoming(ConstantInt::get(IT, 0), init);
1655 iter->addIncoming(ConstantInt::get(IT, 0), init);
1656 sum->addIncoming(ConstantFP::get(fpTy, 0.0), init);
1657
1658 Value *Ai = B4.CreateInBoundsGEP(fpTy, Afloat, Aidx, "A.i");
1659 Value *Bi = B4.CreateInBoundsGEP(fpTy, Bfloat, Bidx, "B.i");
1660 Value *AiDot = B4.CreatePointerCast(Ai, BlasPT);
1661 Value *BiDot = B4.CreatePointerCast(Bi, BlasPT);
1662 Value *newDot =
1663 B4.CreateCall(FDot, {blasm, AiDot, blasOne, BiDot, blasOne});
1664
1665 Value *Anext = B4.CreateNUWAdd(Aidx, lda, "Aidx.next");
1666 Value *Bnext = B4.CreateNUWAdd(Aidx, m, "Bidx.next");
1667 Value *iternext = B4.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next");
1668 Value *sumnext = B4.CreateFAdd(sum, newDot);
1669
1670 iter->addIncoming(iternext, body);
1671 Aidx->addIncoming(Anext, body);
1672 Bidx->addIncoming(Bnext, body);
1673 sum->addIncoming(sumnext, body);
1674
1675 B4.CreateCondBr(B4.CreateICmpEQ(iter, n), end, body);
1676
1677 IRBuilder<> B5(end);
1678 PHINode *res = B5.CreatePHI(fpTy, 3, "res");
1679 res->addIncoming(ConstantFP::get(fpTy, 0.0), entry);
1680 res->addIncoming(sum, body);
1681 res->addIncoming(fastSum, fastPath);
1682 B5.CreateRet(res);
1683 }
1684
1685 auto res = B.CreateCall(F, args, bundles);
1686 if (auto F = GetFunctionFromValue(FDot.getCallee()))
1688 return res;
1689}
1690
1691Function *getOrInsertMemcpyStrided(Module &M, Type *elementType, PointerType *T,
1692 Type *IT, unsigned dstalign,
1693 unsigned srcalign) {
1694 assert(elementType->isFloatingPointTy());
1695 std::string name = "__enzyme_memcpy_" + tofltstr(elementType) + "_" +
1696 std::to_string(cast<IntegerType>(IT)->getBitWidth()) +
1697 "_da" + std::to_string(dstalign) + "sa" +
1698 std::to_string(srcalign) + "stride";
1699 FunctionType *FT =
1700 FunctionType::get(Type::getVoidTy(M.getContext()), {T, T, IT, IT}, false);
1701
1702 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
1703
1704 if (!F->empty())
1705 return F;
1706
1707 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1708#if LLVM_VERSION_MAJOR >= 16
1709 F->setOnlyAccessesArgMemory();
1710#else
1711 F->addFnAttr(Attribute::ArgMemOnly);
1712#endif
1713 F->addFnAttr(Attribute::NoUnwind);
1714 F->addFnAttr(Attribute::AlwaysInline);
1716 F->addParamAttr(0, Attribute::NoAlias);
1718 F->addParamAttr(1, Attribute::NoAlias);
1719 F->addParamAttr(0, Attribute::WriteOnly);
1720 F->addParamAttr(1, Attribute::ReadOnly);
1721
1722 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
1723 BasicBlock *init = BasicBlock::Create(M.getContext(), "init.idx", F);
1724 BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F);
1725 BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F);
1726
1727 auto dst = F->arg_begin();
1728 dst->setName("dst");
1729 auto src = dst + 1;
1730 src->setName("src");
1731 auto num = src + 1;
1732 num->setName("num");
1733 auto stride = num + 1;
1734 stride->setName("stride");
1735
1736 {
1737 IRBuilder<> B(entry);
1738 B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)),
1739 end, init);
1740 }
1741
1742 {
1743 IRBuilder<> B2(init);
1744 B2.setFastMathFlags(getFast());
1745 Value *a = B2.CreateNSWSub(ConstantInt::get(num->getType(), 1), num, "a");
1746 Value *negidx = B2.CreateNSWMul(a, stride, "negidx");
1747 // Value *negidx =
1748 // B2.CreateNSWAdd(b, ConstantInt::get(num->getType(), 1),
1749 // "negidx");
1750 Value *isneg =
1751 B2.CreateICmpSLT(stride, ConstantInt::get(num->getType(), 0), "is.neg");
1752 Value *startidx = B2.CreateSelect(
1753 isneg, negidx, ConstantInt::get(num->getType(), 0), "startidx");
1754 B2.CreateBr(body);
1755 //}
1756
1757 //{
1758 IRBuilder<> B(body);
1759 B.setFastMathFlags(getFast());
1760 PHINode *idx = B.CreatePHI(num->getType(), 2, "idx");
1761 PHINode *sidx = B.CreatePHI(num->getType(), 2, "sidx");
1762 idx->addIncoming(ConstantInt::get(num->getType(), 0), init);
1763 sidx->addIncoming(startidx, init);
1764
1765 Value *dsti = B.CreateInBoundsGEP(elementType, dst, idx, "dst.i");
1766 Value *srci = B.CreateInBoundsGEP(elementType, src, sidx, "src.i");
1767 LoadInst *srcl = B.CreateLoad(elementType, srci, "src.i.l");
1768 StoreInst *dsts = B.CreateStore(srcl, dsti);
1769
1770 if (dstalign) {
1771 dsts->setAlignment(Align(dstalign));
1772 }
1773 if (srcalign) {
1774 srcl->setAlignment(Align(srcalign));
1775 }
1776
1777 Value *next =
1778 B.CreateNSWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next");
1779 Value *snext = B.CreateNSWAdd(sidx, stride, "sidx.next");
1780 idx->addIncoming(next, body);
1781 sidx->addIncoming(snext, body);
1782 B.CreateCondBr(B.CreateICmpEQ(num, next), end, body);
1783 }
1784
1785 {
1786 IRBuilder<> B(end);
1787 B.CreateRetVoid();
1788 }
1789
1790 return F;
1791}
1792
1793Function *getOrInsertMemcpyMat(Module &Mod, Type *elementType, PointerType *PT,
1794 IntegerType *IT, unsigned dstalign,
1795 unsigned srcalign) {
1796 assert(elementType->isFPOrFPVectorTy());
1797#if LLVM_VERSION_MAJOR < 17
1798#if LLVM_VERSION_MAJOR >= 15
1799 if (Mod.getContext().supportsTypedPointers()) {
1800#endif
1801#if LLVM_VERSION_MAJOR >= 13
1802 if (!PT->isOpaquePointerTy())
1803#endif
1804 assert(PT->getPointerElementType() == elementType);
1805#if LLVM_VERSION_MAJOR >= 15
1806 }
1807#endif
1808#endif
1809 std::string name = "__enzyme_memcpy_" + tofltstr(elementType) + "_mat_" +
1810 std::to_string(cast<IntegerType>(IT)->getBitWidth());
1811 FunctionType *FT = FunctionType::get(Type::getVoidTy(Mod.getContext()),
1812 {PT, PT, IT, IT, IT}, false);
1813
1814 Function *F = cast<Function>(Mod.getOrInsertFunction(name, FT).getCallee());
1815
1816 if (!F->empty())
1817 return F;
1818
1819 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1820#if LLVM_VERSION_MAJOR >= 16
1821 F->setOnlyAccessesArgMemory();
1822#else
1823 F->addFnAttr(Attribute::ArgMemOnly);
1824#endif
1825 F->addFnAttr(Attribute::NoUnwind);
1826 F->addFnAttr(Attribute::AlwaysInline);
1828 F->addParamAttr(0, Attribute::NoAlias);
1830 F->addParamAttr(1, Attribute::NoAlias);
1831 F->addParamAttr(0, Attribute::WriteOnly);
1832 F->addParamAttr(1, Attribute::ReadOnly);
1833
1834 BasicBlock *entry = BasicBlock::Create(F->getContext(), "entry", F);
1835 BasicBlock *init = BasicBlock::Create(F->getContext(), "init.idx", F);
1836 BasicBlock *body = BasicBlock::Create(F->getContext(), "for.body", F);
1837 BasicBlock *initend = BasicBlock::Create(F->getContext(), "init.end", F);
1838 BasicBlock *end = BasicBlock::Create(F->getContext(), "for.end", F);
1839
1840 auto dst = F->arg_begin();
1841 dst->setName("dst");
1842 auto src = dst + 1;
1843 src->setName("src");
1844 auto M = src + 1;
1845 M->setName("M");
1846 auto N = M + 1;
1847 N->setName("N");
1848 auto LDA = N + 1;
1849 LDA->setName("LDA");
1850
1851 {
1852 IRBuilder<> B(entry);
1853 Value *l0 = B.CreateICmpEQ(M, ConstantInt::get(IT, 0));
1854 Value *l1 = B.CreateICmpEQ(N, ConstantInt::get(IT, 0));
1855 // Don't copy a 0*0 matrix
1856 B.CreateCondBr(B.CreateOr(l0, l1), end, init);
1857 }
1858
1859 PHINode *j;
1860 {
1861 IRBuilder<> B(init);
1862 j = B.CreatePHI(IT, 2, "j");
1863 j->addIncoming(ConstantInt::get(IT, 0), entry);
1864 B.CreateBr(body);
1865 }
1866
1867 {
1868 IRBuilder<> B(body);
1869 PHINode *i = B.CreatePHI(IT, 2, "i");
1870 i->addIncoming(ConstantInt::get(IT, 0), init);
1871
1872 Value *dsti = B.CreateInBoundsGEP(
1873 elementType, dst,
1874 B.CreateAdd(i, B.CreateMul(j, M, "", true, true), "", true, true),
1875 "dst.i");
1876 Value *srci = B.CreateInBoundsGEP(
1877 elementType, src,
1878 B.CreateAdd(i, B.CreateMul(j, LDA, "", true, true), "", true, true),
1879 "dst.i");
1880 LoadInst *srcl = B.CreateLoad(elementType, srci, "src.i.l");
1881
1882 StoreInst *dsts = B.CreateStore(srcl, dsti);
1883
1884 if (dstalign) {
1885 dsts->setAlignment(Align(dstalign));
1886 }
1887 if (srcalign) {
1888 srcl->setAlignment(Align(srcalign));
1889 }
1890
1891 Value *nexti =
1892 B.CreateAdd(i, ConstantInt::get(IT, 1), "i.next", true, true);
1893 i->addIncoming(nexti, body);
1894 B.CreateCondBr(B.CreateICmpEQ(nexti, M), initend, body);
1895 }
1896
1897 {
1898 IRBuilder<> B(initend);
1899 Value *nextj =
1900 B.CreateAdd(j, ConstantInt::get(IT, 1), "j.next", true, true);
1901 j->addIncoming(nextj, initend);
1902 B.CreateCondBr(B.CreateICmpEQ(nextj, N), end, init);
1903 }
1904
1905 {
1906 IRBuilder<> B(end);
1907 B.CreateRetVoid();
1908 }
1909
1910 return F;
1911}
1912
1914 Module &Mod, Type *elementType, PointerType *PT, IntegerType *IT,
1915 IntegerType *CT, unsigned dstalign, unsigned srcalign, bool zeroSrc) {
  // Get or create a module-local helper of the form
  //   void __enzyme_dmemcpy_<tofltstr(elementType)>_mat_<bitwidth(IT)>[_zero](
  //       CT uplo, IT M, IT N, PT dst, IT ldst, PT src, IT lsrc)
  // whose generated body accumulates src into dst element by element:
  //   dst[i + j * ldst] += src[i + j * lsrc]
  // with j counting up from 0 until j + 1 == N, and i covering a range that
  // depends on uplo:
  //   'U'           : i from 0 until i + 1 == min(j + 1, M)
  //   'L'           : i from j until i + 1 == M
  //   anything else : i from 0 until i + 1 == M
  // When zeroSrc is true, each src element is overwritten with a null value
  // after it has been read (and the helper name gains the "_zero" suffix).
  // A non-zero dstalign / srcalign is applied as the alignment of the dst /
  // src memory accesses.
  // NOTE(review): the line carrying this function's name and return type is
  // absent from this listing; the body returns the Function *F built below.
1916 assert(elementType->isFPOrFPVectorTy());
  // With typed pointers (only possible before LLVM 17) the pointee of PT is
  // expected to be elementType.
1917#if LLVM_VERSION_MAJOR < 17
1918#if LLVM_VERSION_MAJOR >= 15
1919 if (Mod.getContext().supportsTypedPointers()) {
1920#endif
1921#if LLVM_VERSION_MAJOR >= 13
1922 if (!PT->isOpaquePointerTy())
1923#endif
1924 assert(PT->getPointerElementType() == elementType);
1925#if LLVM_VERSION_MAJOR >= 15
1926 }
1927#endif
1928#endif
1929 std::string name = "__enzyme_dmemcpy_" + tofltstr(elementType) + "_mat_" +
1930 std::to_string(cast<IntegerType>(IT)->getBitWidth()) +
1931 (zeroSrc ? "_zero" : "");
1932 FunctionType *FT = FunctionType::get(Type::getVoidTy(Mod.getContext()),
1933 {CT, IT, IT, PT, IT, PT, IT}, false);
1934
1935 Function *F = cast<Function>(Mod.getOrInsertFunction(name, FT).getCallee());
1936
  // A non-empty body means the helper was already generated in this module.
1937 if (!F->empty())
1938 return F;
1939
1940 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1941#if LLVM_VERSION_MAJOR >= 16
1942 F->setOnlyAccessesArgMemory();
1943#else
1944 F->addFnAttr(Attribute::ArgMemOnly);
1945#endif
1946 F->addFnAttr(Attribute::NoUnwind);
1947 F->addFnAttr(Attribute::AlwaysInline);
  // Parameters 3 and 5 are the dst and src pointers.
1948 F->addParamAttr(3, Attribute::NoAlias);
1949 F->addParamAttr(5, Attribute::NoAlias);
1950
1951 BasicBlock *entry = BasicBlock::Create(F->getContext(), "entry", F);
1952 BasicBlock *swtch = BasicBlock::Create(F->getContext(), "swtch", F);
1953 BasicBlock *Ginit = BasicBlock::Create(F->getContext(), "Ginit.idx", F);
1954 BasicBlock *Uinit = BasicBlock::Create(F->getContext(), "Uinit.idx", F);
1955 BasicBlock *Linit = BasicBlock::Create(F->getContext(), "Linit.idx", F);
1956 BasicBlock *end = BasicBlock::Create(F->getContext(), "for.end", F);
1957
  // Name the seven parameters in declaration order.
1958 auto uplo = F->arg_begin();
1959 uplo->setName("uplo");
1960 auto M = uplo + 1;
1961 M->setName("M");
1962 auto N = M + 1;
1963 N->setName("N");
1964
1965 auto dst = N + 1;
1966 dst->setName("dst");
1967 auto ldst = dst + 1;
1968 ldst->setName("ldst");
1969 auto src = ldst + 1;
1970 src->setName("src");
1971 auto lsrc = src + 1;
1972 lsrc->setName("lsrc");
1973
  // entry: return immediately when either dimension is zero.
1974 {
1975 IRBuilder<> B(entry);
1976 Value *l0 = B.CreateICmpEQ(M, ConstantInt::get(IT, 0));
1977 Value *l1 = B.CreateICmpEQ(N, ConstantInt::get(IT, 0));
1978 // Don't copy a 0*0 matrix
1979 B.CreateCondBr(B.CreateOr(l0, l1), end, swtch);
1980 }
1981
  // swtch: dispatch on uplo; every value other than 'U' / 'L' takes the
  // Ginit path.
1982 {
1983 IRBuilder<> B(swtch);
1984 auto swtchT = B.CreateSwitch(uplo, Ginit);
1985 swtchT->addCase(ConstantInt::get(CT, 'U'), Uinit);
1986 swtchT->addCase(ConstantInt::get(CT, 'L'), Linit);
1987 }
1988
  // Emit one two-level loop nest per variant; the variant letter prefixes the
  // names of the blocks and values created for it.
1989 std::pair<char, BasicBlock *> todo[] = {
1990 {'G', Ginit}, {'U', Uinit}, {'L', Linit}};
1991 for (auto &&[direction, init] : todo) {
1992
1993 std::string dir(1, direction);
1994 BasicBlock *body = BasicBlock::Create(F->getContext(), dir + "for.body", F);
1995 BasicBlock *initend =
1996 BasicBlock::Create(F->getContext(), dir + "init.end", F);
1997
  // Default inner range: i starts at 0 and the loop exits when i + 1 == M.
1998 Value *istart = ConstantInt::get(IT, 0);
1999 Value *iend = M;
2000
  // Outer-loop header: j starts at 0 on entry from swtch.
2001 PHINode *j;
2002 {
2003 IRBuilder<> B(init);
2004 j = B.CreatePHI(IT, 2, dir + "j");
2005 j->addIncoming(ConstantInt::get(IT, 0), swtch);
2006
  // 'L' starts the inner loop at i = j; 'U' ends it at min(j + 1, M).
  // NOTE(review): the inner loop only tests i + 1 == iend, so for 'L' with
  // j >= M that exit test is not met by counting up from j — confirm that
  // callers guarantee N <= M on this path.
2007 if (direction == 'L') {
2008 istart = j;
2009 } else if (direction == 'U') {
2010 auto jp1 = B.CreateAdd(j, ConstantInt::get(IT, 1), "", true, true);
2011 iend = B.CreateSelect(B.CreateICmpULT(jp1, M), jp1, M);
2012 }
2013
2014 B.CreateBr(body);
2015 }
2016
  // Inner-loop body: the body runs before the exit test, so at least one
  // element is processed for every j.
2017 {
2018 IRBuilder<> B(body);
2019 PHINode *i = B.CreatePHI(IT, 2, dir + "i");
2020 i->addIncoming(istart, init);
2021
  // Element addresses: base + (i + j * leading dimension); the index
  // arithmetic is created with both no-wrap flags set.
2022 Value *srci = B.CreateInBoundsGEP(
2023 elementType, src,
2024 B.CreateAdd(i, B.CreateMul(j, lsrc, "", true, true), "", true, true),
2025 dir + "src.i");
2026
2027 Value *dsti = B.CreateInBoundsGEP(
2028 elementType, dst,
2029 B.CreateAdd(i, B.CreateMul(j, ldst, "", true, true), "", true, true),
2030 dir + "dst.i");
  // dst element = src element + dst element.
2031 LoadInst *srcl = B.CreateLoad(elementType, srci, dir + "src.i.l");
2032 LoadInst *dstl = B.CreateLoad(elementType, dsti, dir + "dst.i.l");
2033 auto res = B.CreateFAdd(srcl, dstl);
2034 StoreInst *dsts = B.CreateStore(res, dsti);
  // srcs stays null unless zeroSrc requests clearing the source element.
2035 StoreInst *srcs = nullptr;
2036 if (zeroSrc)
2037 srcs = B.CreateStore(Constant::getNullValue(res->getType()), srci);
  // An alignment of 0 leaves the accesses with their default alignment.
2038 if (dstalign) {
2039 dsts->setAlignment(Align(dstalign));
2040 dstl->setAlignment(Align(dstalign));
2041 }
2042 if (srcalign) {
2043 if (zeroSrc)
2044 srcs->setAlignment(Align(srcalign));
2045 srcl->setAlignment(Align(srcalign));
2046 }
2047
  // Advance i; leave the inner loop once the incremented index equals iend.
2048 Value *nexti =
2049 B.CreateAdd(i, ConstantInt::get(IT, 1), dir + "i.next", true, true);
2050 i->addIncoming(nexti, body);
2051 B.CreateCondBr(B.CreateICmpEQ(nexti, iend), initend, body);
2052 }
2053
  // Outer-loop latch: advance j; finish once the incremented index equals N.
2054 {
2055 IRBuilder<> B(initend);
2056 Value *nextj =
2057 B.CreateAdd(j, ConstantInt::get(IT, 1), dir + "j.next", true, true);
2058 j->addIncoming(nextj, initend);
2059 B.CreateCondBr(B.CreateICmpEQ(nextj, N), end, init);
2060 }
2061 }
2062
  // Shared exit block of all three variants and of the empty-matrix check.
2063 {
2064 IRBuilder<> B(end);
2065 B.CreateRetVoid();
2066 }
2067
2068 return F;
2069}
2070
// Stand-in for a differential memmove helper: no memmove-specific code is
// generated. The visible body writes a warning to llvm::errs() and forwards
// every argument unchanged to getOrInsertDifferentialFloatMemcpy, returning
// whatever that call returns.
// NOTE(review): one source line is absent from this listing directly before
// the llvm::errs() statement, so the warning may be guarded by a condition
// that is not visible here — confirm against the full source.
2071// TODO implement differential memmove
2072Function *
2073getOrInsertDifferentialFloatMemmove(Module &M, Type *T, unsigned dstalign,
2074 unsigned srcalign, unsigned dstaddr,
2075 unsigned srcaddr, unsigned bitwidth) {
2077 llvm::errs()
2078 << "warning: didn't implement memmove, using memcpy as fallback "
2079 "which can result in errors\n";
2080 return getOrInsertDifferentialFloatMemcpy(M, T, dstalign, srcalign, dstaddr,
2081 srcaddr, bitwidth);
2082}
2083
// Get or create the module-local helper `__enzyme_checked_free_<width>`
// (suffixed with `_<callee name>` when the callee of `call` is not named
// "free").
//
// Helper parameters: one primal pointer of type Ty, then `width` shadow
// pointers of type Ty, then the types of the operands of `call` from index 1
// onwards. The helper returns void.
//
// Generated control flow:
//   entry: branch to free0 when primal != first shadow, otherwise to end
//   free0: call the callee of `call` on the first shadow
//   free1: (width > 1 only) call the callee on each later shadow; reached
//          from free0 only when every adjacent pair of shadows differs
//   end:   return
// Each emitted call reuses the function type, attributes and calling
// convention of `call`.
2084Function *getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty,
2085 unsigned width) {
  // Capture everything needed to re-issue the original deallocation call.
2086 FunctionType *FreeTy = call->getFunctionType();
2087 Value *Free = call->getCalledOperand();
2088 AttributeList FreeAttributes = call->getAttributes();
2089 CallingConv::ID CallingConvention = call->getCallingConv();
2090
2091 std::string name = "__enzyme_checked_free_" + std::to_string(width);
2092
  // Keep helpers for differently named deallocation functions apart.
2093 auto callname = getFuncNameFromCall(call);
2094 if (callname != "free")
2095 name += "_" + callname.str();
2096
  // Parameter list: primal, `width` shadows, then the trailing call operands.
2097 SmallVector<Type *, 3> types;
2098 types.push_back(Ty);
2099 for (unsigned i = 0; i < width; i++) {
2100 types.push_back(Ty);
2101 }
2102#if LLVM_VERSION_MAJOR >= 14
2103 for (size_t i = 1; i < call->arg_size(); i++)
2104#else
2105 for (size_t i = 1; i < call->getNumArgOperands(); i++)
2106#endif
2107 {
2108 types.push_back(call->getArgOperand(i)->getType());
2109 }
2110
2111 FunctionType *FT =
2112 FunctionType::get(Type::getVoidTy(M.getContext()), types, false);
2113
2114 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
2115
  // A non-empty body means the helper was already generated in this module.
2116 if (!F->empty())
2117 return F;
2118
2119 F->setLinkage(Function::LinkageTypes::InternalLinkage);
2120#if LLVM_VERSION_MAJOR >= 16
2121 F->setOnlyAccessesArgMemory();
2122#else
2123 F->addFnAttr(Attribute::ArgMemOnly);
2124#endif
2125 F->addFnAttr(Attribute::NoUnwind);
2126 F->addFnAttr(Attribute::AlwaysInline);
2127
2128 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
2129 BasicBlock *free0 = BasicBlock::Create(M.getContext(), "free0", F);
2130 BasicBlock *end = BasicBlock::Create(M.getContext(), "end", F);
2131
2132 IRBuilder<> EntryBuilder(entry);
2133 IRBuilder<> Free0Builder(free0);
2134 IRBuilder<> EndBuilder(end);
2135
2136 auto primal = F->arg_begin();
2137 Argument *first_shadow = F->arg_begin() + 1;
  // NOTE(review): two source lines are absent from this listing at this
  // point; whatever they do to `primal` / `first_shadow` is not visible here.
2140
  // Only call the deallocation function when the first shadow is a different
  // pointer than the primal.
2141 Value *isNotEqual = EntryBuilder.CreateICmpNE(primal, first_shadow);
2142 EntryBuilder.CreateCondBr(isNotEqual, free0, end);
2143
  // Call operands: the pointer to release, followed by the helper parameters
  // that mirror the trailing operands of `call` (they sit after the primal
  // and the `width` shadows, i.e. at parameter index width + i).
2144 SmallVector<Value *, 1> args = {first_shadow};
2145#if LLVM_VERSION_MAJOR >= 14
2146 for (size_t i = 1; i < call->arg_size(); i++)
2147#else
2148 for (size_t i = 1; i < call->getNumArgOperands(); i++)
2149#endif
2150 {
2151 args.push_back(F->arg_begin() + width + i);
2152 }
2153
2154 CallInst *CI = Free0Builder.CreateCall(FreeTy, Free, args);
2155 CI->setAttributes(FreeAttributes);
2156 CI->setCallingConv(CallingConvention);
2157
2158 if (width > 1) {
  // checkResult accumulates (in free0) the conjunction of
  // shadow[i] != shadow[i + 1] over all adjacent shadow pairs.
2159 Value *checkResult = nullptr;
2160 BasicBlock *free1 = BasicBlock::Create(M.getContext(), "free1", F);
2161 IRBuilder<> Free1Builder(free1);
2162
2163 for (unsigned i = 0; i < width; i++) {
2164 addFunctionNoCapture(F, i + 1);
2165 Argument *shadow = F->arg_begin() + i + 1;
2166
2167 if (i < width - 1) {
2168 Argument *nextShadow = F->arg_begin() + i + 2;
2169 Value *isNotEqual = Free0Builder.CreateICmpNE(shadow, nextShadow);
2170 checkResult = checkResult
2171 ? Free0Builder.CreateAnd(isNotEqual, checkResult)
2172 : isNotEqual;
2173
  // free1 releases each later shadow, reusing the trailing operands.
2174 args[0] = nextShadow;
2175 CallInst *CI = Free1Builder.CreateCall(FreeTy, Free, args);
2176 CI->setAttributes(FreeAttributes);
2177 CI->setCallingConv(CallingConvention);
2178 }
2179 }
  // Enter free1 only when all adjacent shadow pairs compared unequal.
2180 Free0Builder.CreateCondBr(checkResult, free1, end);
2181 Free1Builder.CreateBr(end);
2182 } else {
2183 Free0Builder.CreateBr(end);
2184 }
2185
2186 EndBuilder.CreateRetVoid();
2187
2188 return F;
2189}
2190
2191/// Create function to computer nearest power of two
2192llvm::Value *nextPowerOfTwo(llvm::IRBuilder<> &B, llvm::Value *V) {
2193 assert(V->getType()->isIntegerTy());
2194 IntegerType *T = cast<IntegerType>(V->getType());
2195 V = B.CreateAdd(V, ConstantInt::get(T, -1));
2196 for (size_t i = 1; i < T->getBitWidth(); i *= 2) {
2197 V = B.CreateOr(V, B.CreateLShr(V, ConstantInt::get(T, i)));
2198 }
2199 V = B.CreateAdd(V, ConstantInt::get(T, 1));
2200 return V;
2201}
2202
2203llvm::Function *getOrInsertDifferentialWaitallSave(llvm::Module &M,
2204 ArrayRef<llvm::Type *> T,
2205 PointerType *reqType) {
2206 std::string name = "__enzyme_differential_waitall_save";
2207 FunctionType *FT = FunctionType::get(getUnqual(reqType), T, false);
2208 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
2209
2210 if (!F->empty())
2211 return F;
2212
2213 F->setLinkage(Function::LinkageTypes::InternalLinkage);
2214 F->addFnAttr(Attribute::NoUnwind);
2215 F->addFnAttr(Attribute::AlwaysInline);
2216
2217 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
2218
2219 auto buff = F->arg_begin();
2220 buff->setName("count");
2221 Value *count = buff;
2222 Value *req = buff + 1;
2223 req->setName("req");
2224 Value *dreq = buff + 2;
2225 dreq->setName("dreq");
2226
2227 IRBuilder<> B(entry);
2228 count = B.CreateZExtOrTrunc(count, Type::getInt64Ty(entry->getContext()));
2229
2230 auto ret = CreateAllocation(B, reqType, count);
2231
2232 BasicBlock *loopBlock = BasicBlock::Create(M.getContext(), "loop", F);
2233 BasicBlock *endBlock = BasicBlock::Create(M.getContext(), "end", F);
2234
2235 B.CreateCondBr(B.CreateICmpEQ(count, ConstantInt::get(count->getType(), 0)),
2236 endBlock, loopBlock);
2237
2238 B.SetInsertPoint(loopBlock);
2239 auto idx = B.CreatePHI(count->getType(), 2);
2240 idx->addIncoming(ConstantInt::get(count->getType(), 0), entry);
2241 auto inc = B.CreateAdd(idx, ConstantInt::get(count->getType(), 1));
2242 idx->addIncoming(inc, loopBlock);
2243
2244 Type *reqT = reqType; // req->getType()->getPointerElementType();
2245 Value *idxs[] = {idx};
2246 Value *ireq = B.CreateInBoundsGEP(reqT, req, idxs);
2247 Value *idreq = B.CreateInBoundsGEP(reqT, dreq, idxs);
2248 Value *iout = B.CreateInBoundsGEP(reqType, ret, idxs);
2249 Value *isNull = nullptr;
2250 if (auto GV = M.getNamedValue("ompi_request_null")) {
2251 Value *reql = B.CreatePointerCast(ireq, getUnqual(GV->getType()));
2252 reql = B.CreateLoad(GV->getType(), reql);
2253 isNull = B.CreateICmpEQ(reql, GV);
2254 }
2255
2256 idreq = B.CreatePointerCast(idreq, getUnqual(reqType));
2257 Value *d_reqp = B.CreateLoad(reqType, idreq);
2258 if (isNull)
2259 d_reqp = B.CreateSelect(isNull, Constant::getNullValue(d_reqp->getType()),
2260 d_reqp);
2261
2262 B.CreateStore(d_reqp, iout);
2263
2264 B.CreateCondBr(B.CreateICmpEQ(inc, count), endBlock, loopBlock);
2265
2266 B.SetInsertPoint(endBlock);
2267 B.CreateRet(ret);
2268 return F;
2269}
2270
2271llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M,
2272 ArrayRef<llvm::Type *> T,
2273 Type *reqType,
2274 StringRef caller) {
2275 llvm::SmallVector<llvm::Type *, 4> types(T.begin(), T.end());
2276 types.push_back(reqType);
2277
2278 auto &&[prefix, _, postfix] = tripleSplitDollar(caller);
2279
2280 std::string name = "__enzyme_differential_mpi_wait";
2281 if (prefix.size() != 0 || postfix.size() != 0) {
2282 name = (Twine(name) + "$" + prefix + "$" + postfix).str();
2283 }
2284 FunctionType *FT =
2285 FunctionType::get(Type::getVoidTy(M.getContext()), types, false);
2286 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
2287
2288 if (!F->empty())
2289 return F;
2290
2291 F->setLinkage(Function::LinkageTypes::InternalLinkage);
2292 F->addFnAttr(Attribute::NoUnwind);
2293 F->addFnAttr(Attribute::AlwaysInline);
2294
2295 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
2296 BasicBlock *isend = BasicBlock::Create(M.getContext(), "invertISend", F);
2297 BasicBlock *irecv = BasicBlock::Create(M.getContext(), "invertIRecv", F);
2298
2299#if 0
2300 /*0 */getInt8PtrTy(call.getContext())
2301 /*1 */i64
2302 /*2 */getInt8PtrTy(call.getContext())
2303 /*3 */i64
2304 /*4 */i64
2305 /*5 */getInt8PtrTy(call.getContext())
2306 /*6 */Type::getInt8Ty(call.getContext())
2307#endif
2308
2309 auto buff = F->arg_begin();
2310 buff->setName("buf");
2311 Value *buf = buff;
2312 Value *count = buff + 1;
2313 count->setName("count");
2314 Value *datatype = buff + 2;
2315 datatype->setName("datatype");
2316 Value *source = buff + 3;
2317 source->setName("source");
2318 Value *tag = buff + 4;
2319 tag->setName("tag");
2320 Value *comm = buff + 5;
2321 comm->setName("comm");
2322 Value *fn = buff + 6;
2323 fn->setName("fn");
2324 Value *d_req = buff + 7;
2325 d_req->setName("d_req");
2326
2327 auto isendfn = M.getFunction(getRenamedPerCallingConv(caller, "MPI_Isend"));
2328 assert(isendfn);
2329 // TODO: what if Isend not defined, but Irecv is?
2330 FunctionType *FuT = isendfn->getFunctionType();
2331
2332 auto irecvfn = cast<Function>(
2333 M.getOrInsertFunction(getRenamedPerCallingConv(caller, "MPI_Irecv"), FuT)
2334 .getCallee());
2335 assert(irecvfn);
2336
2337 IRBuilder<> B(entry);
2338 auto arg = isendfn->arg_begin();
2339 if (arg->getType()->isIntegerTy())
2340 buf = B.CreatePtrToInt(buf, arg->getType());
2341 arg++;
2342 count = B.CreateZExtOrTrunc(count, arg->getType());
2343 arg++;
2344 datatype = B.CreatePointerCast(datatype, arg->getType());
2345 arg++;
2346 source = B.CreateZExtOrTrunc(source, arg->getType());
2347 arg++;
2348 tag = B.CreateZExtOrTrunc(tag, arg->getType());
2349 arg++;
2350 comm = B.CreatePointerCast(comm, arg->getType());
2351 arg++;
2352 if (arg->getType()->isIntegerTy())
2353 d_req = B.CreatePtrToInt(d_req, arg->getType());
2354 Value *args[] = {
2355 buf, count, datatype, source, tag, comm, d_req,
2356 };
2357
2358 B.CreateCondBr(B.CreateICmpEQ(fn, ConstantInt::get(fn->getType(),
2359 (int)MPI_CallType::ISEND)),
2360 isend, irecv);
2361
2362 {
2363 B.SetInsertPoint(isend);
2364 auto fcall = B.CreateCall(irecvfn, args);
2365 fcall->setCallingConv(isendfn->getCallingConv());
2366 B.CreateRetVoid();
2367 }
2368
2369 {
2370 B.SetInsertPoint(irecv);
2371 auto fcall = B.CreateCall(isendfn, args);
2372 fcall->setCallingConv(isendfn->getCallingConv());
2373 B.CreateRetVoid();
2374 }
2375 return F;
2376}
2377
// Produce, at the insertion point of B2, the value of a module-level variable
// named "__enzyme_mpi_sum" + CT.str(), creating the supporting IR on first
// use.
//
// If a global variable of that name already exists, only a load of it is
// emitted. Otherwise this function creates:
//   * `<name>_run`: void(FlT *src, FlT *dst, intType *lenp, OpPtr), whose body
//     loads len from lenp and performs dst[i] = src[i] + dst[i] for i counting
//     up from 0 until i + 1 == len (no iterations when len == 0);
//   * the global `<name>` of type OpType (initialized to undef) and the i1
//     flag `<name>_initd` (initialized to false);
//   * `<name>initializer`: when the flag is still false, calls MPI_Op_create
//     with (`<name>_run`, 1, `<name>`) and sets the flag.
// Finally a call to the initializer and a load of `<name>` are emitted through
// B2, and the load is returned. CT must be a floating-point concrete type.
2378llvm::Value *getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr,
2379 llvm::Type *OpType, ConcreteType CT,
2380 llvm::Type *intType, IRBuilder<> &B2) {
2381 std::string name = "__enzyme_mpi_sum" + CT.str();
2382 assert(CT.isFloat());
  // FlT is the value CT.isFloat() returns; it is used below as the element
  // type of the loads, stores and GEPs.
2383 auto FlT = CT.isFloat();
2384
  // Fast path: everything was already created by an earlier call.
2385 if (auto Glob = M.getGlobalVariable(name)) {
2386 return B2.CreateLoad(Glob->getValueType(), Glob);
2387 }
2388
2389 llvm::Type *types[] = {getUnqual(FlT), getUnqual(FlT), getUnqual(intType),
2390 OpPtr};
2391 FunctionType *FuT =
2392 FunctionType::get(Type::getVoidTy(M.getContext()), types, false);
2393 Function *F =
2394 cast<Function>(M.getOrInsertFunction(name + "_run", FuT).getCallee());
2395
2396 F->setLinkage(Function::LinkageTypes::InternalLinkage);
2397#if LLVM_VERSION_MAJOR >= 16
2398 F->setOnlyAccessesArgMemory();
2399#else
2400 F->addFnAttr(Attribute::ArgMemOnly);
2401#endif
2402 F->addFnAttr(Attribute::NoUnwind);
2403 F->addFnAttr(Attribute::AlwaysInline);
  // NOTE(review): several source lines are absent from this listing between
  // the attribute calls below (after the AlwaysInline line and around the
  // ReadOnly / ReadNone lines); their effect is not visible here.
2405 F->addParamAttr(0, Attribute::ReadOnly);
2408 F->addParamAttr(2, Attribute::ReadOnly);
2410 F->addParamAttr(3, Attribute::ReadNone);
2411
2412 BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
2413 BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F);
2414 BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F);
2415
  // The fourth parameter (of type OpPtr) is not referenced by the body.
2416 auto src = F->arg_begin();
2417 src->setName("src");
2418 auto dst = src + 1;
2419 dst->setName("dst");
2420 auto lenp = dst + 1;
2421 lenp->setName("lenp");
2422 Value *len;
2423 // TODO consider using datatype arg and asserting same size as assumed
2424 // by type analysis
2425
  // entry: load the element count and skip the loop when it is zero.
2426 {
2427 IRBuilder<> B(entry);
2428 len = B.CreateLoad(intType, lenp);
2429 B.CreateCondBr(B.CreateICmpEQ(len, ConstantInt::get(len->getType(), 0)),
2430 end, body);
2431 }
2432
  // for.body: dst[idx] = src[idx] + dst[idx]; floating-point instructions
  // created by this builder carry the flags returned by getFast().
2433 {
2434 IRBuilder<> B(body);
2435 B.setFastMathFlags(getFast());
2436 PHINode *idx = B.CreatePHI(len->getType(), 2, "idx");
2437 idx->addIncoming(ConstantInt::get(len->getType(), 0), entry);
2438
2439 Value *dsti = B.CreateInBoundsGEP(FlT, dst, idx, "dst.i");
2440 LoadInst *dstl = B.CreateLoad(FlT, dsti, "dst.i.l");
2441
2442 Value *srci = B.CreateInBoundsGEP(FlT, src, idx, "src.i");
2443 LoadInst *srcl = B.CreateLoad(FlT, srci, "src.i.l");
2444 B.CreateStore(B.CreateFAdd(srcl, dstl), dsti);
2445
  // Advance the index; exit once the incremented index equals len.
2446 Value *next =
2447 B.CreateNUWAdd(idx, ConstantInt::get(len->getType(), 1), "idx.next");
2448 idx->addIncoming(next, body);
2449 B.CreateCondBr(B.CreateICmpEQ(len, next), end, body);
2450 }
2451
2452 {
2453 IRBuilder<> B(end);
2454 B.CreateRetVoid();
2455 }
2456
  // Function type used for the MPI_Op_create call:
  // intType(i8 pointer, intType, OpPtr).
2457 llvm::Type *rtypes[] = {getInt8PtrTy(M.getContext()), intType, OpPtr};
2458 FunctionType *RFT = FunctionType::get(intType, rtypes, false);
2459
  // Declare MPI_Op_create when the module has no value of that name;
  // otherwise bitcast the existing value to a pointer to the type above.
2460 Constant *RF = M.getNamedValue("MPI_Op_create");
2461 if (!RF) {
2462 RF =
2463 cast<Function>(M.getOrInsertFunction("MPI_Op_create", RFT).getCallee());
2464 } else {
2465 RF = ConstantExpr::getBitCast(RF, getUnqual(RFT));
2466 }
2467
  // Storage for the created operation handle; undef until the initializer
  // has run.
2468 GlobalVariable *GV =
2469 new GlobalVariable(M, OpType, false, GlobalVariable::InternalLinkage,
2470 UndefValue::get(OpType), name);
2471
  // Flag recording whether MPI_Op_create has already been called.
2472 Type *i1Ty = Type::getInt1Ty(M.getContext());
2473 GlobalVariable *initD = new GlobalVariable(
2474 M, i1Ty, false, GlobalVariable::InternalLinkage,
2475 ConstantInt::getFalse(M.getContext()), name + "_initd");
2476
2477 // Finish initializing mpi sum
2478 // https://www.mpich.org/static/docs/v3.2/www3/MPI_Op_create.html
2479 FunctionType *IFT = FunctionType::get(Type::getVoidTy(M.getContext()),
2480 ArrayRef<Type *>(), false);
2481 Function *initializerFunction = cast<Function>(
2482 M.getOrInsertFunction(name + "initializer", IFT).getCallee());
2483
2484 initializerFunction->setLinkage(Function::LinkageTypes::InternalLinkage);
2485 initializerFunction->addFnAttr(Attribute::NoUnwind);
2486
  // Initializer body: entry tests the flag; run performs the one-time
  // registration and sets the flag; end returns.
2487 {
2488 BasicBlock *entry =
2489 BasicBlock::Create(M.getContext(), "entry", initializerFunction);
2490 BasicBlock *run =
2491 BasicBlock::Create(M.getContext(), "run", initializerFunction);
2492 BasicBlock *end =
2493 BasicBlock::Create(M.getContext(), "end", initializerFunction);
2494 IRBuilder<> B(entry);
2495
2496 B.CreateCondBr(B.CreateLoad(initD->getValueType(), initD), end, run);
2497
2498 B.SetInsertPoint(run);
  // Arguments: the `<name>_run` function, the constant 1, and the address of
  // the handle global, each cast to the matching MPI_Op_create parameter type.
2499 Value *args[] = {ConstantExpr::getPointerCast(F, rtypes[0]),
2500 ConstantInt::get(rtypes[1], 1, false),
2501 ConstantExpr::getPointerCast(GV, rtypes[2])};
2502 B.CreateCall(RFT, RF, args);
2503 B.CreateStore(ConstantInt::getTrue(M.getContext()), initD);
2504 B.CreateBr(end);
2505 B.SetInsertPoint(end);
2506 B.CreateRetVoid();
2507 }
2508
  // At the caller's insertion point: run the (flag-guarded) initializer, then
  // read the handle.
2509 B2.CreateCall(M.getFunction(name + "initializer"));
2510 return B2.CreateLoad(GV->getValueType(), GV);
2511}
2512
2513void mayExecuteAfter(llvm::SmallVectorImpl<llvm::Instruction *> &results,
2514 llvm::Instruction *inst,
2515 const llvm::SmallPtrSetImpl<Instruction *> &stores,
2516 const llvm::Loop *region) {
2517 using namespace llvm;
2518 std::map<BasicBlock *, SmallVector<Instruction *, 1>> maybeBlocks;
2519 BasicBlock *instBlk = inst->getParent();
2520 for (auto store : stores) {
2521 BasicBlock *storeBlk = store->getParent();
2522 if (instBlk == storeBlk) {
2523 // if store doesn't come before, exit.
2524
2525 if (store != inst) {
2526 BasicBlock::const_iterator It = storeBlk->begin();
2527 for (; &*It != store && &*It != inst; ++It)
2528 /*empty*/;
2529 // if inst comes first (e.g. before store) in the
2530 // block, return true
2531 if (&*It == inst) {
2532 results.push_back(store);
2533 }
2534 }
2535 maybeBlocks[storeBlk].push_back(store);
2536 } else {
2537 maybeBlocks[storeBlk].push_back(store);
2538 }
2539 }
2540
2541 if (maybeBlocks.size() == 0)
2542 return;
2543
2544 llvm::SmallVector<BasicBlock *, 2> todo;
2545 for (auto B : successors(instBlk)) {
2546 if (region && region->getHeader() == B) {
2547 continue;
2548 }
2549 todo.push_back(B);
2550 }
2551
2552 SmallPtrSet<BasicBlock *, 2> seen;
2553 while (todo.size()) {
2554 auto cur = todo.back();
2555 todo.pop_back();
2556 if (seen.count(cur))
2557 continue;
2558 seen.insert(cur);
2559 auto found = maybeBlocks.find(cur);
2560 if (found != maybeBlocks.end()) {
2561 for (auto store : found->second)
2562 results.push_back(store);
2563 maybeBlocks.erase(found);
2564 }
2565 for (auto B : successors(cur)) {
2566 if (region && region->getHeader() == B) {
2567 continue;
2568 }
2569 todo.push_back(B);
2570 }
2571 }
2572}
2573
2575 llvm::ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT,
2576 llvm::Instruction *maybeReader, const llvm::SCEV *LoadStart,
2577 const llvm::SCEV *LoadEnd, llvm::Instruction *maybeWriter,
2578 const llvm::SCEV *StoreStart, const llvm::SCEV *StoreEnd,
2579 llvm::Loop *scope) {
2580 // The store may either occur directly after the load in the current loop
2581 // nest, or prior to the load in a subsequent iteration of the loop nest
2582 // Generally:
2583 // L0 -> scope -> L1 -> L2 -> L3 -> load_L4 -> load_L5 ... Load
2584 // \-> store_L4 -> store_L5 ... Store
2585 // We begin by finding the common ancestor of the two loops, which may
2586 // be none.
2587 Loop *anc = getAncestor(LI.getLoopFor(maybeReader->getParent()),
2588 LI.getLoopFor(maybeWriter->getParent()));
2589
2590 // The surrounding scope must contain the ancestor
2591 if (scope) {
2592 assert(anc);
2593 assert(scope == anc || scope->contains(anc));
2594 }
2595
2596 // Consider the case where the load and store don't share any common loops.
2597 // That is to say, there's no loops in [scope, ancestor) we need to consider
2598 // having a store in a later iteration overwrite the load of a previous
2599 // iteration.
2600 //
2601 // An example of this overwriting would be a "left shift"
2602 // for (int j = 1; j<N; j++) {
2603 // load A[j]
2604 // store A[j-1]
2605 // }
2606 //
2607 // Ignoring such ancestors, if we compare the two regions to have no direct
2608 // overlap we can return that it doesn't overwrite memory if the regions
2609 // don't overlap at any level of region expansion. That is to say, we can
2610 // expand the start or end, for any loop to be the worst case scenario
2611 // given the loop bounds.
2612 //
2613 // However, now let us consider the case where there are surrounding loops.
2614 // If the storing boundary is represented by an induction variable of one
2615 // of these common loops, we must conseratively expand it all the way to the
2616 // end. We will also mark the loops we may expand. If we encounter all
2617 // intervening loops in this fashion, and it is proven safe in these cases,
2618 // the region does not overlap. However, if we don't encounter all surrounding
2619 // loops in our induction expansion, we may simply be repeating the write
2620 // which we should also ensure we say the region may overlap (due to the
2621 // repetition).
2622 //
2623 // Since we also have a Loop scope, we can ignore any common loops at the
2624 // scope level or above
2625
2626 /// We force all ranges for all loops in range ... [scope, anc], .... cur
2627 /// to expand the number of iterations
2628
2629 SmallPtrSet<const Loop *, 1> visitedAncestors;
2630 auto skipLoop = [&](const Loop *L) {
2631 assert(L);
2632 if (scope && L->contains(scope))
2633 return false;
2634
2635 if (anc && (anc == L || anc->contains(L))) {
2636 visitedAncestors.insert(L);
2637 return true;
2638 }
2639 return false;
2640 };
2641
2642 // Check the boounds of an [... endprev][startnext ...] for potential
2643 // overlaps. The boolean EndIsStore is true of the EndPev represents
2644 // the store and should have its loops expanded, or if that should
2645 // apply to StartNed.
2646 auto hasOverlap = [&](const SCEV *EndPrev, const SCEV *StartNext,
2647 bool EndIsStore) {
2648 for (auto slim = StartNext; slim != SE.getCouldNotCompute();) {
2649 bool sskip = false;
2650 if (!EndIsStore)
2651 if (auto startL = dyn_cast<SCEVAddRecExpr>(slim))
2652 if (skipLoop(startL->getLoop()) &&
2653 SE.isKnownNonPositive(startL->getStepRecurrence(SE))) {
2654 sskip = true;
2655 }
2656
2657 if (!sskip)
2658 for (auto elim = EndPrev; elim != SE.getCouldNotCompute();) {
2659 {
2660
2661 bool eskip = false;
2662 if (EndIsStore)
2663 if (auto endL = dyn_cast<SCEVAddRecExpr>(elim)) {
2664 if (skipLoop(endL->getLoop()) &&
2665 SE.isKnownNonNegative(endL->getStepRecurrence(SE))) {
2666 eskip = true;
2667 }
2668 }
2669
2670 // Moreover because otherwise SE cannot "groupScevByComplexity"
2671 // we need to ensure that if both slim/elim are AddRecv
2672 // they must be in the same loop, or one loop must dominate
2673 // the other.
2674 if (!eskip) {
2675
2676 if (auto endL = dyn_cast<SCEVAddRecExpr>(elim)) {
2677 auto EH = endL->getLoop()->getHeader();
2678 if (auto startL = dyn_cast<SCEVAddRecExpr>(slim)) {
2679 auto SH = startL->getLoop()->getHeader();
2680 if (EH != SH && !DT.dominates(EH, SH) &&
2681 !DT.dominates(SH, EH))
2682 eskip = true;
2683 }
2684 }
2685 }
2686 if (!eskip) {
2687 auto sub = SE.getMinusSCEV(slim, elim);
2688 if (sub != SE.getCouldNotCompute() && SE.isKnownNonNegative(sub))
2689 return false;
2690 }
2691 }
2692
2693 if (auto endL = dyn_cast<SCEVAddRecExpr>(elim)) {
2694 if (SE.isKnownNonPositive(endL->getStepRecurrence(SE))) {
2695 elim = endL->getStart();
2696 continue;
2697 } else if (SE.isKnownNonNegative(endL->getStepRecurrence(SE))) {
2698#if LLVM_VERSION_MAJOR >= 12
2699 auto ebd = SE.getSymbolicMaxBackedgeTakenCount(endL->getLoop());
2700#else
2701 auto ebd = SE.getBackedgeTakenCount(endL->getLoop());
2702#endif
2703 if (ebd == SE.getCouldNotCompute())
2704 break;
2705 elim = endL->evaluateAtIteration(ebd, SE);
2706 continue;
2707 }
2708 }
2709 break;
2710 }
2711
2712 if (auto startL = dyn_cast<SCEVAddRecExpr>(slim)) {
2713 if (SE.isKnownNonNegative(startL->getStepRecurrence(SE))) {
2714 slim = startL->getStart();
2715 continue;
2716 } else if (SE.isKnownNonPositive(startL->getStepRecurrence(SE))) {
2717#if LLVM_VERSION_MAJOR >= 12
2718 auto sbd = SE.getSymbolicMaxBackedgeTakenCount(startL->getLoop());
2719#else
2720 auto sbd = SE.getBackedgeTakenCount(startL->getLoop());
2721#endif
2722 if (sbd == SE.getCouldNotCompute())
2723 break;
2724 slim = startL->evaluateAtIteration(sbd, SE);
2725 continue;
2726 }
2727 }
2728 break;
2729 }
2730 return true;
2731 };
2732
2733 // There is no overwrite if either the stores all occur before the loads
2734 // [S, S+Size][start load, L+Size]
2735 visitedAncestors.clear();
2736 if (!hasOverlap(StoreEnd, LoadStart, /*EndIsStore*/ true)) {
2737 // We must have seen all common loops as induction variables
2738 // to be legal, lest we have a repetition of the store.
2739 bool legal = true;
2740 for (const Loop *L = anc; anc != scope; anc = anc->getParentLoop()) {
2741 if (!visitedAncestors.count(L))
2742 legal = false;
2743 }
2744 if (legal)
2745 return false;
2746 }
2747
2748 // There is no overwrite if either the loads all occur before the stores
2749 // [start load, L+Size] [S, S+Size]
2750 visitedAncestors.clear();
2751 if (!hasOverlap(LoadEnd, StoreStart, /*EndIsStore*/ false)) {
2752 // We must have seen all common loops as induction variables
2753 // to be legal, lest we have a repetition of the store.
2754 bool legal = true;
2755 for (const Loop *L = anc; anc != scope; anc = anc->getParentLoop()) {
2756 if (!visitedAncestors.count(L))
2757 legal = false;
2758 }
2759 if (legal)
2760 return false;
2761 }
2762 return true;
2763}
2764
2765bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
2766 llvm::TargetLibraryInfo &TLI, ScalarEvolution &SE,
2767 llvm::LoopInfo &LI, llvm::DominatorTree &DT,
2768 llvm::Instruction *maybeReader,
2769 llvm::Instruction *maybeWriter,
2770 llvm::Loop *scope) {
2771 using namespace llvm;
2772 if (!writesToMemoryReadBy(TR, AA, TLI, maybeReader, maybeWriter))
2773 return false;
2774 const SCEV *LoadBegin = SE.getCouldNotCompute();
2775 const SCEV *LoadEnd = SE.getCouldNotCompute();
2776
2777 const SCEV *StoreBegin = SE.getCouldNotCompute();
2778 const SCEV *StoreEnd = SE.getCouldNotCompute();
2779
2780 Value *loadPtr = nullptr;
2781 Value *storePtr = nullptr;
2782 if (auto LI = dyn_cast<LoadInst>(maybeReader)) {
2783 loadPtr = LI->getPointerOperand();
2784 LoadBegin = SE.getSCEV(LI->getPointerOperand());
2785 if (LoadBegin != SE.getCouldNotCompute() &&
2786 !LoadBegin->getType()->isIntegerTy()) {
2787 auto &DL = maybeWriter->getModule()->getDataLayout();
2788 auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType()))
2789 ->getBitWidth();
2790#if LLVM_VERSION_MAJOR >= 18
2791 auto TS = SE.getConstant(
2792 APInt(width, (int64_t)DL.getTypeStoreSize(LI->getType())));
2793#else
2794 auto TS = SE.getConstant(
2795 APInt(width, DL.getTypeStoreSize(LI->getType()).getFixedSize()));
2796#endif
2797 LoadEnd = SE.getAddExpr(LoadBegin, TS);
2798 }
2799 }
2800 if (auto SI = dyn_cast<StoreInst>(maybeWriter)) {
2801 storePtr = SI->getPointerOperand();
2802 StoreBegin = SE.getSCEV(SI->getPointerOperand());
2803 if (StoreBegin != SE.getCouldNotCompute() &&
2804 !StoreBegin->getType()->isIntegerTy()) {
2805 auto &DL = maybeWriter->getModule()->getDataLayout();
2806 auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
2807 ->getBitWidth();
2808#if LLVM_VERSION_MAJOR >= 18
2809 auto TS =
2810 SE.getConstant(APInt(width, (int64_t)DL.getTypeStoreSize(
2811 SI->getValueOperand()->getType())));
2812#else
2813 auto TS = SE.getConstant(
2814 APInt(width, DL.getTypeStoreSize(SI->getValueOperand()->getType())
2815 .getFixedSize()));
2816#endif
2817 StoreEnd = SE.getAddExpr(StoreBegin, TS);
2818 }
2819 }
2820 if (auto MS = dyn_cast<MemSetInst>(maybeWriter)) {
2821 storePtr = MS->getArgOperand(0);
2822 StoreBegin = SE.getSCEV(MS->getArgOperand(0));
2823 if (StoreBegin != SE.getCouldNotCompute() &&
2824 !StoreBegin->getType()->isIntegerTy()) {
2825 if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
2826 auto &DL = MS->getModule()->getDataLayout();
2827 auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
2828 ->getBitWidth();
2829 auto TS =
2830 SE.getConstant(APInt(width, Len->getValue().getLimitedValue()));
2831 StoreEnd = SE.getAddExpr(StoreBegin, TS);
2832 }
2833 }
2834 }
2835 if (auto MS = dyn_cast<MemTransferInst>(maybeWriter)) {
2836 storePtr = MS->getArgOperand(0);
2837 StoreBegin = SE.getSCEV(MS->getArgOperand(0));
2838 if (StoreBegin != SE.getCouldNotCompute() &&
2839 !StoreBegin->getType()->isIntegerTy()) {
2840 if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
2841 auto &DL = MS->getModule()->getDataLayout();
2842 auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
2843 ->getBitWidth();
2844 auto TS =
2845 SE.getConstant(APInt(width, Len->getValue().getLimitedValue()));
2846 StoreEnd = SE.getAddExpr(StoreBegin, TS);
2847 }
2848 }
2849 }
2850 if (auto MS = dyn_cast<MemTransferInst>(maybeReader)) {
2851 loadPtr = MS->getArgOperand(1);
2852 LoadBegin = SE.getSCEV(MS->getArgOperand(1));
2853 if (LoadBegin != SE.getCouldNotCompute() &&
2854 !LoadBegin->getType()->isIntegerTy()) {
2855 if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
2856 auto &DL = MS->getModule()->getDataLayout();
2857 auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType()))
2858 ->getBitWidth();
2859 auto TS =
2860 SE.getConstant(APInt(width, Len->getValue().getLimitedValue()));
2861 LoadEnd = SE.getAddExpr(LoadBegin, TS);
2862 }
2863 }
2864 }
2865
2866 if (loadPtr && storePtr)
2867 if (auto alias =
2868 arePointersGuaranteedNoAlias(TLI, AA, LI, loadPtr, storePtr, true))
2869 if (*alias)
2870 return false;
2871
2872 if (!overwritesToMemoryReadByLoop(SE, LI, DT, maybeReader, LoadBegin, LoadEnd,
2873 maybeWriter, StoreBegin, StoreEnd, scope))
2874 return false;
2875
2876 return true;
2877}
2878
2879/// Return whether maybeReader can read from memory written to by maybeWriter
///
/// The function works in three phases:
///  1. Early exits for readers that are stores or fences, and for writer /
///     reader calls and invokes whose callee name (or intrinsic ID, or inline
///     asm string) is on a hard-coded list; each of these returns false.
///  2. For a load reader and a store writer, an optional type-based check
///     using `TR` that returns false when the two concrete types differ.
///  3. A final alias-analysis query on `AA`, selected by the kind of the
///     reader first and the kind of the writer second.
///
/// \param TR          Type-analysis results; may be null, in which case the
///                    type-based check in phase 2 is skipped.
/// \param AA          Alias analysis used for all mod/ref queries.
/// \param TLI         Passed to the allocation / deallocation name checks.
/// \param maybeReader Instruction that may read memory.
/// \param maybeWriter Instruction that may write memory.
/// \return false if one of the checks rules the dependence out, otherwise
///         the result of the alias-analysis query.
///
/// Both instructions must be in the same function (asserted). Any
/// reader/writer combination not handled below reaches llvm_unreachable.
2880bool writesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
2881 llvm::TargetLibraryInfo &TLI,
2882 llvm::Instruction *maybeReader,
2883 llvm::Instruction *maybeWriter) {
2884 assert(maybeReader->getParent()->getParent() ==
2885 maybeWriter->getParent()->getParent());
2886 using namespace llvm;
 // Stores and fences are never treated as readers here.
2887 if (isa<StoreInst>(maybeReader))
2888 return false;
2889 if (isa<FenceInst>(maybeReader)) {
2890 return false;
2891 }
 // Writer is a call: filter by callee name / intrinsic / inline asm.
2892 if (auto call = dyn_cast<CallInst>(maybeWriter)) {
2893 StringRef funcName = getFuncNameFromCall(call);
2894
2895 if (isDebugFunction(call->getCalledFunction()))
2896 return false;
2897
2898 if (isCertainPrint(funcName) || isAllocationFunction(funcName, TLI) ||
2899 isDeallocationFunction(funcName, TLI)) {
2900 return false;
2901 }
2902
2903 if (isMemFreeLibMFunction(funcName)) {
2904 return false;
2905 }
2906 if (funcName == "jl_array_copy" || funcName == "ijl_array_copy")
2907 return false;
2908
2909 if (funcName == "jl_genericmemory_copy_slice" ||
2910 funcName == "ijl_genericmemory_copy_slice")
2911 return false;
2912
2913 if (funcName == "jl_new_array" || funcName == "ijl_new_array")
2914 return false;
2915
2916 if (funcName == "julia.safepoint")
2917 return false;
2918
2919 if (funcName == "jl_idtable_rehash" || funcName == "ijl_idtable_rehash")
2920 return false;
2921
2922 // Send only writes to inaccessible mem
2923 if (funcName == "MPI_Send" || funcName == "PMPI_Send") {
2924 return false;
2925 }
2926 // Wait only overwrites memory in the status and request.
2927 if (funcName == "MPI_Wait" || funcName == "PMPI_Wait" ||
2928 funcName == "MPI_Waitall" || funcName == "PMPI_Waitall") {
2929#if LLVM_VERSION_MAJOR > 11
2930 auto loc = LocationSize::afterPointer();
2931#else
2932 auto loc = MemoryLocation::UnknownSize;
2933#endif
 // Waitall's argument list is shifted by one relative to Wait, so the
 // request is operand `off` and the status is operand `off + 1`.
2934 size_t off = (funcName == "MPI_Wait" || funcName == "PMPI_Wait") ? 0 : 1;
2935 // No alias with status
2936 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(off + 1),
2937 loc))) {
2938 // No alias with request
2939 if (!isRefSet(AA.getModRefInfo(maybeReader,
2940 call->getArgOperand(off + 0), loc)))
2941 return false;
 // Query the TBAA-derived type of the reader at index -1.
2942 auto R = parseTBAA(
2943 *maybeReader,
2944 maybeReader->getParent()->getParent()->getParent()->getDataLayout(),
2945 nullptr)[{-1}];
2946 // Could still conflict with the mpi_request unless a non pointer
2947 // type.
2948 if (R != BaseType::Unknown && R != BaseType::Anything &&
2949 R != BaseType::Pointer)
2950 return false;
2951 }
2952 }
2953 // Isend only writes to inaccessible mem and request.
2954 if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
2955 auto R = parseTBAA(
2956 *maybeReader,
2957 maybeReader->getParent()->getParent()->getParent()->getDataLayout(),
2958 nullptr)[{-1}];
2959 // Could still conflict with the mpi_request, unless either
2960 // synchronous, or a non pointer type.
2961 if (R != BaseType::Unknown && R != BaseType::Anything &&
2962 R != BaseType::Pointer)
2963 return false;
2964#if LLVM_VERSION_MAJOR > 11
2965 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6),
2966 LocationSize::afterPointer())))
2967 return false;
2968#else
2969 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6),
2970 MemoryLocation::UnknownSize)))
2971 return false;
2972#endif
 // NOTE(review): every path through this branch returns false, so the
 // type and request-alias checks above have no effect on the result.
 // Confirm whether the unconditional return is intended.
2973 return false;
2974 }
2975 if (funcName == "MPI_Irecv" || funcName == "PMPI_Irecv" ||
2976 funcName == "MPI_Recv" || funcName == "PMPI_Recv") {
 // NOTE(review): listing line 2977 is absent from this rendering; the
 // declaration of `type` used below is presumably on it -- verify against
 // the real source file.
 // Operand 2 is inspected (through any constant-expression wrappers) for
 // the globals named "ompi_mpi_double" / "ompi_mpi_float".
2978 if (Constant *C = dyn_cast<Constant>(call->getArgOperand(2))) {
2979 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
2980 C = CE->getOperand(0);
2981 }
2982 if (auto GV = dyn_cast<GlobalVariable>(C)) {
2983 if (GV->getName() == "ompi_mpi_double") {
2984 type = ConcreteType(Type::getDoubleTy(C->getContext()));
2985 } else if (GV->getName() == "ompi_mpi_float") {
2986 type = ConcreteType(Type::getFloatTy(C->getContext()));
2987 }
2988 }
2989 }
2990 if (type.isKnown()) {
2991 auto R = parseTBAA(
2992 *maybeReader,
2993 maybeReader->getParent()->getParent()->getParent()->getDataLayout(),
2994 nullptr)[{-1}];
2995 if (R.isKnown() && type != R) {
2996 // Could still conflict with the mpi_request, unless either
2997 // synchronous, or a non pointer type.
2998 if (funcName == "MPI_Recv" || funcName == "PMPI_Recv" ||
 // NOTE(review): listing line 2999 (the rest of this condition) is absent
 // from this rendering -- verify against the real source file.
3000 return false;
3001#if LLVM_VERSION_MAJOR > 11
3002 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6),
3003 LocationSize::afterPointer())))
3004 return false;
3005#else
3006 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6),
3007 MemoryLocation::UnknownSize)))
3008 return false;
3009#endif
3010 }
3011 }
3012 }
 // Intrinsics that are never treated as conflicting writers.
3013 if (auto II = dyn_cast<IntrinsicInst>(call)) {
3014 if (II->getIntrinsicID() == Intrinsic::stacksave)
3015 return false;
3016 if (II->getIntrinsicID() == Intrinsic::stackrestore)
3017 return false;
3018 if (II->getIntrinsicID() == Intrinsic::trap)
3019 return false;
3020#if LLVM_VERSION_MAJOR >= 13
3021 if (II->getIntrinsicID() == Intrinsic::experimental_noalias_scope_decl)
3022 return false;
3023#endif
3024 }
3025
 // Inline asm whose string contains "exit" is ignored as a writer.
3026 if (auto iasm = dyn_cast<InlineAsm>(call->getCalledOperand())) {
3027 if (StringRef(iasm->getAsmString()).contains("exit"))
3028 return false;
3029 }
3030 }
 // Reader is a call: a shorter version of the same name/intrinsic filter.
3031 if (auto call = dyn_cast<CallInst>(maybeReader)) {
3032 StringRef funcName = getFuncNameFromCall(call);
3033
3034 if (isDebugFunction(call->getCalledFunction()))
3035 return false;
3036
3037 if (isAllocationFunction(funcName, TLI) ||
3038 isDeallocationFunction(funcName, TLI)) {
3039 return false;
3040 }
3041
3042 if (isMemFreeLibMFunction(funcName)) {
3043 return false;
3044 }
3045
3046 if (auto II = dyn_cast<IntrinsicInst>(call)) {
3047 if (II->getIntrinsicID() == Intrinsic::stacksave)
3048 return false;
3049 if (II->getIntrinsicID() == Intrinsic::stackrestore)
3050 return false;
3051 if (II->getIntrinsicID() == Intrinsic::trap)
3052 return false;
3053#if LLVM_VERSION_MAJOR >= 13
3054 if (II->getIntrinsicID() == Intrinsic::experimental_noalias_scope_decl)
3055 return false;
3056#endif
3057 }
3058 }
 // Writer is an invoke: name filter (no MPI or intrinsic handling here).
3059 if (auto call = dyn_cast<InvokeInst>(maybeWriter)) {
3060 StringRef funcName = getFuncNameFromCall(call);
3061
3062 if (isDebugFunction(call->getCalledFunction()))
3063 return false;
3064
3065 if (isAllocationFunction(funcName, TLI) ||
3066 isDeallocationFunction(funcName, TLI)) {
3067 return false;
3068 }
3069
3070 if (isMemFreeLibMFunction(funcName)) {
3071 return false;
3072 }
3073 if (funcName == "jl_array_copy" || funcName == "ijl_array_copy")
3074 return false;
3075
3076 if (funcName == "jl_genericmemory_copy_slice" ||
3077 funcName == "ijl_genericmemory_copy_slice")
3078 return false;
3079
3080 if (funcName == "jl_idtable_rehash" || funcName == "ijl_idtable_rehash")
3081 return false;
3082
3083 if (auto iasm = dyn_cast<InlineAsm>(call->getCalledOperand())) {
3084 if (StringRef(iasm->getAsmString()).contains("exit"))
3085 return false;
3086 }
3087 }
 // Reader is an invoke.
3088 if (auto call = dyn_cast<InvokeInst>(maybeReader)) {
3089 StringRef funcName = getFuncNameFromCall(call);
3090
3091 if (isDebugFunction(call->getCalledFunction()))
3092 return false;
3093
3094 if (isAllocationFunction(funcName, TLI) ||
3095 isDeallocationFunction(funcName, TLI)) {
3096 return false;
3097 }
3098
3099 if (isMemFreeLibMFunction(funcName)) {
3100 return false;
3101 }
3102 }
 // Anything that survived the filters must actually write / read memory.
3103 assert(maybeWriter->mayWriteToMemory());
3104 assert(maybeReader->mayReadFromMemory());
3105
 // Load reader: try type-based disambiguation against a store writer, then
 // fall back to asking whether the writer may modify the loaded location.
3106 if (auto li = dyn_cast<LoadInst>(maybeReader)) {
3107 if (TR) {
3108 auto TT = TR->query(li)[{-1}];
3109 if (TT != BaseType::Unknown && TT != BaseType::Anything) {
3110 if (auto si = dyn_cast<StoreInst>(maybeWriter)) {
 // Compare against the type of the stored value ...
3111 auto TT2 = TR->query(si->getValueOperand())[{-1}];
3112 if (TT2 != BaseType::Unknown && TT2 != BaseType::Anything) {
3113 if (TT != TT2)
3114 return false;
3115 }
 // ... and against the type recorded for the store's destination over
 // the byte length of the stored value (bit size rounded up).
3116 auto &dl = li->getParent()->getParent()->getParent()->getDataLayout();
3117 auto len =
3118 (dl.getTypeSizeInBits(si->getValueOperand()->getType()) + 7) / 8;
3119 TT2 = TR->query(si->getPointerOperand()).Lookup(len, dl)[{-1}];
3120 if (TT2 != BaseType::Unknown && TT2 != BaseType::Anything) {
3121 if (TT != TT2)
3122 return false;
3123 }
3124 }
3125 }
3126 }
3127 return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(li)));
3128 }
 // Other readers with a single known location: may the writer modify it?
3129 if (auto rmw = dyn_cast<AtomicRMWInst>(maybeReader)) {
3130 return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(rmw)));
3131 }
3132 if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeReader)) {
3133 return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(xch)));
3134 }
3135 if (auto mti = dyn_cast<MemTransferInst>(maybeReader)) {
3136 return isModSet(
3137 AA.getModRefInfo(maybeWriter, MemoryLocation::getForSource(mti)));
3138 }
3139
 // Writers with a single known location: may the reader reference it?
3140 if (auto si = dyn_cast<StoreInst>(maybeWriter)) {
3141 return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(si)));
3142 }
3143 if (auto rmw = dyn_cast<AtomicRMWInst>(maybeWriter)) {
3144 return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(rmw)));
3145 }
3146 if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeWriter)) {
3147 return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(xch)));
3148 }
3149 if (auto mti = dyn_cast<MemIntrinsic>(maybeWriter)) {
3150 return isRefSet(
3151 AA.getModRefInfo(maybeReader, MemoryLocation::getForDest(mti)));
3152 }
3153
 // Call / invoke reader against an arbitrary writer: any mod or ref
 // relation between the two counts as a conflict.
3154 if (auto cb = dyn_cast<CallInst>(maybeReader)) {
3155 return isModOrRefSet(AA.getModRefInfo(maybeWriter, cb));
3156 }
3157 if (auto cb = dyn_cast<InvokeInst>(maybeReader)) {
3158 return isModOrRefSet(AA.getModRefInfo(maybeWriter, cb));
3159 }
 // No rule matched: print both instructions and abort.
3160 llvm::errs() << " maybeReader: " << *maybeReader
3161 << " maybeWriter: " << *maybeWriter << "\n";
3162 llvm_unreachable("unknown inst2");
3163}
3164
3165// Find the base pointer of ptr and the offset in bytes from the start of
3166// the returned base pointer to this value.
3167AllocaInst *getBaseAndOffset(Value *ptr, size_t &offset) {
3168 offset = 0;
3169 while (true) {
3170 if (auto CI = dyn_cast<CastInst>(ptr)) {
3171 ptr = CI->getOperand(0);
3172 continue;
3173 }
3174 if (auto CI = dyn_cast<GetElementPtrInst>(ptr)) {
3175 auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout();
3176#if LLVM_VERSION_MAJOR >= 20
3177 SmallMapVector<Value *, APInt, 4> VariableOffsets;
3178#else
3179 MapVector<Value *, APInt> VariableOffsets;
3180#endif
3181 auto width = sizeof(size_t) * 8;
3182 APInt Offset(width, 0);
3183 bool success = collectOffset(cast<GEPOperator>(CI), DL, width,
3184 VariableOffsets, Offset);
3185 if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) {
3186 return nullptr;
3187 }
3188 offset += Offset.getZExtValue();
3189 ptr = CI->getOperand(0);
3190 continue;
3191 }
3192 if (isa<AllocaInst>(ptr)) {
3193 break;
3194 }
3195 if (auto LI = dyn_cast<LoadInst>(ptr)) {
3196 if (auto S = simplifyLoad(LI)) {
3197 ptr = S;
3198 continue;
3199 }
3200 }
3201 return nullptr;
3202 }
3203 return cast<AllocaInst>(ptr);
3204}
3205
3206// Find all user instructions of AI, returning tuples of <instruction, value,
3207// byte offet from AI> Unlike a simple get users, this will recurse through any
3208// constant gep offsets and casts
3209SmallVector<std::tuple<Instruction *, Value *, size_t>, 1>
3210findAllUsersOf(Value *AI) {
3211 SmallVector<std::pair<Value *, size_t>, 1> todo;
3212 todo.emplace_back(AI, 0);
3213
3214 SmallVector<std::tuple<Instruction *, Value *, size_t>, 1> users;
3215 while (todo.size()) {
3216 auto pair = todo.pop_back_val();
3217 Value *ptr = pair.first;
3218 size_t suboff = pair.second;
3219
3220 for (auto U : ptr->users()) {
3221 if (auto CI = dyn_cast<CastInst>(U)) {
3222 todo.emplace_back(CI, suboff);
3223 continue;
3224 }
3225 if (auto CI = dyn_cast<GetElementPtrInst>(U)) {
3226 auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout();
3227#if LLVM_VERSION_MAJOR >= 20
3228 SmallMapVector<Value *, APInt, 4> VariableOffsets;
3229#else
3230 MapVector<Value *, APInt> VariableOffsets;
3231#endif
3232 auto width = sizeof(size_t) * 8;
3233 APInt Offset(width, 0);
3234 bool success = collectOffset(cast<GEPOperator>(CI), DL, width,
3235 VariableOffsets, Offset);
3236
3237 if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) {
3238 users.emplace_back(cast<Instruction>(U), ptr, suboff);
3239 continue;
3240 }
3241 todo.emplace_back(CI, suboff + Offset.getZExtValue());
3242 continue;
3243 }
3244 users.emplace_back(cast<Instruction>(U), ptr, suboff);
3245 continue;
3246 }
3247 }
3248 return users;
3249}
3250
3251// Given a pointer, find all values of size `valSz` which could be loaded from
3252// that pointer when indexed at offset. If it is impossible to guarantee that
3253// the set contains all such values, set legal to false
3254SmallVector<std::pair<Value *, size_t>, 1>
3255getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, size_t valSz,
3256 bool &legal) {
3257 SmallVector<std::pair<Value *, size_t>, 1> options;
3258
3259 auto todo = findAllUsersOf(ptr0);
3260 std::set<std::tuple<Instruction *, Value *, size_t>> seen;
3261
3262 while (todo.size()) {
3263 auto pair = todo.pop_back_val();
3264 if (seen.count(pair))
3265 continue;
3266 seen.insert(pair);
3267 Instruction *U = std::get<0>(pair);
3268 Value *ptr = std::get<1>(pair);
3269 size_t suboff = std::get<2>(pair);
3270
3271 // Read only users do not set the memory inside of ptr
3272 if (isa<LoadInst>(U)) {
3273 continue;
3274 }
3275 if (auto MTI = dyn_cast<MemTransferInst>(U))
3276 if (MTI->getOperand(0) != ptr) {
3277 continue;
3278 }
3279 if (auto I = dyn_cast<Instruction>(U)) {
3280 if (!I->mayWriteToMemory() && I->getType()->isVoidTy())
3281 continue;
3282 }
3283
3284 if (auto SI = dyn_cast<StoreInst>(U)) {
3285 auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout();
3286
3287 // We are storing into the ptr
3288 if (SI->getPointerOperand() == ptr) {
3289 auto storeSz =
3290 (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) /
3291 8;
3292 // If store is before the load would start
3293 if (storeSz + suboff <= offset)
3294 continue;
3295 // if store starts after load would start
3296 if (offset + valSz <= suboff)
3297 continue;
3298
3299 if (valSz <= storeSz) {
3300 assert(offset >= suboff);
3301 options.emplace_back(SI->getValueOperand(), offset - suboff);
3302 continue;
3303 }
3304 }
3305
3306 // We capture our pointer of interest, if it is stored into an alloca,
3307 // all loads of said alloca would potentially store into.
3308 if (SI->getValueOperand() == ptr) {
3309 if (suboff == 0) {
3310 size_t mid_offset = 0;
3311 if (auto AI2 =
3312 getBaseAndOffset(SI->getPointerOperand(), mid_offset)) {
3313 bool sublegal = true;
3314 auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8;
3315 auto subPtrs =
3316 getAllLoadedValuesFrom(AI2, mid_offset, ptrSz, sublegal);
3317 if (!sublegal) {
3318 legal = false;
3319 return options;
3320 }
3321 for (auto &&[subPtr, subOff] : subPtrs) {
3322 if (subOff != 0)
3323 return options;
3324 for (const auto &pair3 : findAllUsersOf(subPtr)) {
3325 todo.emplace_back(std::move(pair3));
3326 }
3327 }
3328 continue;
3329 }
3330 }
3331 }
3332 }
3333
3334 if (auto II = dyn_cast<IntrinsicInst>(U)) {
3335 if (II->getCalledFunction()->getName() == "llvm.enzyme.lifetime_start" ||
3336 II->getCalledFunction()->getName() == "llvm.enzyme.lifetime_end")
3337 continue;
3338 if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
3339 II->getIntrinsicID() == Intrinsic::lifetime_end)
3340 continue;
3341 }
3342
3343 // If we copy into the ptr at a location that includes the offset, consider
3344 // all sub uses
3345 if (auto MTI = dyn_cast<MemTransferInst>(U)) {
3346 if (auto CI = dyn_cast<ConstantInt>(MTI->getLength())) {
3347 if (MTI->getOperand(0) == ptr) {
3348 auto storeSz = CI->getValue();
3349
3350 // If store is before the load would start
3351 if ((storeSz + suboff).ule(offset))
3352 continue;
3353
3354 // if store starts after load would start
3355 if (offset + valSz <= suboff)
3356 continue;
3357
3358 if (suboff == 0 && CI->getValue().uge(offset + valSz)) {
3359 size_t midoffset = 0;
3360 auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset);
3361 if (!AI2) {
3362 legal = false;
3363 return options;
3364 }
3365 if (midoffset != 0) {
3366 legal = false;
3367 return options;
3368 }
3369 for (const auto &pair3 : findAllUsersOf(AI2)) {
3370 todo.emplace_back(std::move(pair3));
3371 }
3372 continue;
3373 }
3374 }
3375 }
3376 }
3377
3378 legal = false;
3379 return options;
3380 }
3381
3382 return options;
3383}
3384
3385// Perform mem2reg/sroa to identify the innermost value being represented.
3386Value *simplifyLoad(Value *V, size_t valSz, size_t preOffset) {
3387 if (auto LI = dyn_cast<LoadInst>(V)) {
3388 if (valSz == 0) {
3389 auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout();
3390 valSz = (DL.getTypeSizeInBits(LI->getType()) + 7) / 8;
3391 }
3392
3393 Value *ptr = LI->getPointerOperand();
3394 size_t offset = 0;
3395
3396 if (auto ptr2 = simplifyLoad(ptr)) {
3397 ptr = ptr2;
3398 }
3399 auto AI = getBaseAndOffset(ptr, offset);
3400 if (!AI) {
3401 return nullptr;
3402 }
3403 offset += preOffset;
3404
3405 bool legal = true;
3406 auto opts = getAllLoadedValuesFrom(AI, offset, valSz, legal);
3407
3408 if (!legal) {
3409 return nullptr;
3410 }
3411 std::set<Value *> res;
3412 for (auto &&[opt, startOff] : opts) {
3413 Value *v2 = simplifyLoad(opt, valSz, startOff);
3414 if (v2)
3415 res.insert(v2);
3416 else
3417 res.insert(opt);
3418 }
3419 if (res.size() != 1) {
3420 return nullptr;
3421 }
3422 Value *retval = *res.begin();
3423 return retval;
3424 }
3425 if (auto EVI = dyn_cast<ExtractValueInst>(V)) {
3426 IRBuilder<> B(EVI);
3427 auto em =
3428 GradientUtils::extractMeta(B, EVI->getAggregateOperand(),
3429 EVI->getIndices(), "", /*fallback*/ false);
3430 if (em != nullptr) {
3431 if (auto SL2 = simplifyLoad(em, valSz))
3432 em = SL2;
3433 return em;
3434 }
3435 if (auto LI = dyn_cast<LoadInst>(EVI->getAggregateOperand())) {
3436 auto offset = preOffset;
3437
3438 auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout();
3439 SmallVector<Value *, 4> vec;
3440 vec.push_back(ConstantInt::get(Type::getInt64Ty(EVI->getContext()), 0));
3441 for (auto ind : EVI->getIndices()) {
3442 vec.push_back(
3443 ConstantInt::get(Type::getInt32Ty(EVI->getContext()), ind));
3444 }
3445 auto ud = UndefValue::get(getUnqual(EVI->getOperand(0)->getType()));
3446 auto g2 =
3447 GetElementPtrInst::Create(EVI->getOperand(0)->getType(), ud, vec);
3448 APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
3449 g2->accumulateConstantOffset(DL, ai);
3450 // Using destructor rather than eraseFromParent
3451 // as g2 has no parent
3452 delete g2;
3453
3454 offset += (size_t)ai.getLimitedValue();
3455
3456 if (valSz == 0) {
3457 auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout();
3458 valSz = (DL.getTypeSizeInBits(EVI->getType()) + 7) / 8;
3459 }
3460 return simplifyLoad(LI, valSz, offset);
3461 }
3462 }
3463 return nullptr;
3464}
3465
3466Value *GetFunctionValFromValue(Value *fn) {
3467 while (!isa<Function>(fn)) {
3468 if (auto ci = dyn_cast<CastInst>(fn)) {
3469 fn = ci->getOperand(0);
3470 continue;
3471 }
3472 if (auto ci = dyn_cast<ConstantExpr>(fn)) {
3473 if (ci->isCast()) {
3474 fn = ci->getOperand(0);
3475 continue;
3476 }
3477 }
3478 if (auto ci = dyn_cast<BlockAddress>(fn)) {
3479 fn = ci->getFunction();
3480 continue;
3481 }
3482 if (auto *GA = dyn_cast<GlobalAlias>(fn)) {
3483 fn = GA->getAliasee();
3484 continue;
3485 }
3486 if (auto *Call = dyn_cast<CallInst>(fn)) {
3487 if (auto F = Call->getCalledFunction()) {
3488 SmallPtrSet<Value *, 1> ret;
3489 for (auto &BB : *F) {
3490 if (auto RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
3491 ret.insert(RI->getReturnValue());
3492 }
3493 }
3494 if (ret.size() == 1) {
3495 auto val = *ret.begin();
3496 val = GetFunctionValFromValue(val);
3497 if (isa<Constant>(val)) {
3498 fn = val;
3499 continue;
3500 }
3501 if (auto arg = dyn_cast<Argument>(val)) {
3502 fn = Call->getArgOperand(arg->getArgNo());
3503 continue;
3504 }
3505 }
3506 }
3507 }
3508 if (auto *Call = dyn_cast<InvokeInst>(fn)) {
3509 if (auto F = Call->getCalledFunction()) {
3510 SmallPtrSet<Value *, 1> ret;
3511 for (auto &BB : *F) {
3512 if (auto RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
3513 ret.insert(RI->getReturnValue());
3514 }
3515 }
3516 if (ret.size() == 1) {
3517 auto val = *ret.begin();
3518 while (isa<LoadInst>(val)) {
3519 auto v2 = simplifyLoad(val);
3520 if (v2) {
3521 val = v2;
3522 continue;
3523 }
3524 break;
3525 }
3526 if (isa<Constant>(val)) {
3527 fn = val;
3528 continue;
3529 }
3530 if (auto arg = dyn_cast<Argument>(val)) {
3531 fn = Call->getArgOperand(arg->getArgNo());
3532 continue;
3533 }
3534 }
3535 }
3536 }
3537 if (auto S = simplifyLoad(fn)) {
3538 fn = S;
3539 continue;
3540 }
3541 break;
3542 }
3543
3544 return fn;
3545}
3546
3547Function *GetFunctionFromValue(Value *fn) {
3548 return dyn_cast<Function>(GetFunctionValFromValue(fn));
3549}
3550
3551Function *getFirstFunctionDefinition(Module &M) {
3552 for (auto &F : M) {
3553 if (!F.isDeclaration()) {
3554 return &F;
3555 }
3556 }
3557 return nullptr;
3558}
3559
3560#if LLVM_VERSION_MAJOR >= 16
3561std::optional<BlasInfo> extractBLAS(llvm::StringRef in)
3562#else
3563llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
3564#endif
3565{
3566 const char *extractable[] = {
3567 "dot", "scal", "axpy", "gemv", "gemm", "spmv", "syrk", "nrm2",
3568 "trmm", "trmv", "symm", "potrf", "potrs", "copy", "spmv", "syr2k",
3569 "potrs", "getrf", "getrs", "trtrs", "getri", "symv", "lacpy", "trsv",
3570 };
3571 const char *floatType[] = {"s", "d", "c", "z"};
3572 const char *prefixes[] = {"" /*Fortran*/, "cblas_"};
3573 const char *suffixes[] = {"", "_", "64_", "_64_"};
3574 for (auto t : floatType) {
3575 for (auto f : extractable) {
3576 for (auto p : prefixes) {
3577 for (auto s : suffixes) {
3578 if (in == (Twine(p) + t + f + s).str()) {
3579 bool is64 = llvm::StringRef(s).contains("64");
3580 return BlasInfo{
3581 t, p, s, f, is64,
3582 };
3583 }
3584 }
3585 }
3586 }
3587 }
3588 // c interface to cublas
3589 const char *cuCFloatType[] = {"S", "D", "C", "Z"};
3590 const char *cuFFloatType[] = {"s", "d", "c", "z"};
3591 const char *cuCPrefixes[] = {"cublas"};
3592 const char *cuSuffixes[] = {"", "_v2", "_64", "_v2_64"};
3593 for (auto t : llvm::enumerate(cuCFloatType)) {
3594 for (auto f : extractable) {
3595 for (auto p : cuCPrefixes) {
3596 for (auto s : cuSuffixes) {
3597 if (in == (Twine(p) + t.value() + f + s).str()) {
3598 bool is64 = llvm::StringRef(s).contains("64");
3599 return BlasInfo{
3600 t.value(), p, s, f, is64,
3601 };
3602 }
3603 }
3604 }
3605 }
3606 }
3607 // Fortran interface to cublas
3608 const char *cuFPrefixes[] = {"cublas_"};
3609 for (auto t : cuFFloatType) {
3610 for (auto f : extractable) {
3611 for (auto p : cuFPrefixes) {
3612 if (in == (Twine(p) + t + f).str()) {
3613 return BlasInfo{
3614 t, p, "", f, false,
3615 };
3616 }
3617 }
3618 }
3619 }
3620 return {};
3621}
3622
3623llvm::Constant *getUndefinedValueForType(llvm::Module &M, llvm::Type *T,
3624 bool forceZero) {
3626 return cast<Constant>(
3627 unwrap(EnzymeUndefinedValueForType(wrap(&M), wrap(T), forceZero)));
3628 else if (EnzymeZeroCache || forceZero)
3629 return Constant::getNullValue(T);
3630 else
3631 return UndefValue::get(T);
3632}
3633
3634llvm::Value *SanitizeDerivatives(llvm::Value *val, llvm::Value *toset,
3635 llvm::IRBuilder<> &BuilderM,
3636 llvm::Value *mask) {
3637 if (EnzymeCheckDerivativeNaN && toset->getType()->isFPOrFPVectorTy()) {
3638 auto current_bb = BuilderM.GetInsertBlock();
3639 auto fn = current_bb->getParent();
3640 auto mod = fn->getParent();
3641 auto &Context = mod->getContext();
3642
3643 std::string type_str;
3644 llvm::raw_string_ostream type_ss(type_str);
3645 toset->getType()->print(type_ss);
3646 std::string fn_name = "__enzyme_sanitize_nan_" + type_str;
3647
3648 llvm::FunctionType *SanitizeFT = llvm::FunctionType::get(
3649 llvm::Type::getVoidTy(Context),
3650 {toset->getType(), getInt8PtrTy(Context)}, false);
3651
3652 auto SanitizeFCallee = mod->getOrInsertFunction(fn_name, SanitizeFT);
3653 llvm::Function *SanitizeF =
3654 llvm::cast<llvm::Function>(SanitizeFCallee.getCallee());
3655
3656 if (SanitizeF->empty()) {
3657 SanitizeF->setLinkage(Function::LinkageTypes::InternalLinkage);
3658 llvm::BasicBlock *entry =
3659 llvm::BasicBlock::Create(Context, "entry", SanitizeF);
3660 llvm::BasicBlock *good =
3661 llvm::BasicBlock::Create(Context, "good", SanitizeF);
3662 llvm::BasicBlock *bad =
3663 llvm::BasicBlock::Create(Context, "bad", SanitizeF);
3664
3665 llvm::IRBuilder<> B(entry);
3666 llvm::Value *inp = SanitizeF->getArg(0);
3667 llvm::Value *msg_ptr = SanitizeF->getArg(1);
3668
3669 llvm::Value *cmp = B.CreateFCmpUNO(inp, inp);
3670 if (auto VT = llvm::dyn_cast<llvm::VectorType>(inp->getType())) {
3671#if LLVM_VERSION_MAJOR >= 12
3672 unsigned len = VT->getElementCount().getKnownMinValue();
3673#else
3674 unsigned len = VT->getNumElements();
3675#endif
3676 llvm::Value *res = B.CreateExtractElement(cmp, (uint64_t)0);
3677 for (unsigned i = 1; i < len; ++i) {
3678 res = B.CreateOr(res, B.CreateExtractElement(cmp, (uint64_t)i));
3679 }
3680 cmp = res;
3681 }
3682 B.CreateCondBr(cmp, bad, good);
3683
3684 B.SetInsertPoint(good);
3685 B.CreateRetVoid();
3686
3687 B.SetInsertPoint(bad);
3688 if (CustomErrorHandler) {
3689 CustomErrorHandler("NaN Error", wrap(inp), ErrorType::NaNError, nullptr,
3690 wrap(msg_ptr), wrap(&B));
3691 } else {
3692 llvm::FunctionType *PutsFT = llvm::FunctionType::get(
3693 llvm::Type::getInt32Ty(Context), {getInt8PtrTy(Context)}, false);
3694 auto PutsF = mod->getOrInsertFunction("puts", PutsFT);
3695 B.CreateCall(PutsF, msg_ptr);
3696
3697 llvm::FunctionType *ExitFT =
3698 llvm::FunctionType::get(llvm::Type::getVoidTy(Context),
3699 {llvm::Type::getInt32Ty(Context)}, false);
3700 auto ExitF = mod->getOrInsertFunction("exit", ExitFT);
3701 B.CreateCall(
3702 ExitF, llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 1));
3703 }
3704 B.CreateUnreachable();
3705 }
3706
3707 std::string stringv = "Enzyme: Found nan while computing derivative of ";
3708 if (val) {
3709 std::string str;
3710 llvm::raw_string_ostream ss(str);
3711 if (auto inst = llvm::dyn_cast<llvm::Instruction>(val)) {
3712 ss << *inst << "\n";
3713 emit_backtrace(inst, ss);
3714 } else {
3715 ss << *val << "\n";
3716 }
3717 stringv += ss.str();
3718 } else {
3719 stringv += "\n";
3720 }
3721
3722 BuilderM.CreateCall(SanitizeFCallee, {toset, getString(*mod, stringv)});
3723 }
3724
3726 return unwrap(EnzymeSanitizeDerivatives(wrap(val), wrap(toset),
3727 wrap(&BuilderM), wrap(mask)));
3728 return toset;
3729}
3730
3731llvm::FastMathFlags getFast() {
3732 llvm::FastMathFlags f;
3733 if (EnzymeFastMath)
3734 f.set();
3735 return f;
3736}
3737
3738void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty,
3739 llvm::SmallVectorImpl<llvm::Value *> &cacheValues,
3740 llvm::IRBuilder<> &BuilderZ, const Twine &name) {
3741 if (!cache_arg)
3742 return;
3743 if (!arg->getType()->isPointerTy()) {
3744 assert(arg->getType() == ty);
3745 cacheValues.push_back(arg);
3746 return;
3747 }
3748#if LLVM_VERSION_MAJOR < 17
3749 auto PT = cast<PointerType>(arg->getType());
3750#if LLVM_VERSION_MAJOR <= 14
3751 if (PT->getElementType() != ty)
3752 arg = BuilderZ.CreatePointerCast(
3753 arg, PointerType::get(ty, PT->getAddressSpace()), "pcld." + name);
3754#else
3755 auto PT2 = PointerType::get(ty, PT->getAddressSpace());
3756 if (!PT->isOpaqueOrPointeeTypeMatches(PT2))
3757 arg = BuilderZ.CreatePointerCast(
3758 arg, PointerType::get(ty, PT->getAddressSpace()), "pcld." + name);
3759#endif
3760#endif
3761 arg = BuilderZ.CreateLoad(ty, arg, "avld." + name);
3762 cacheValues.push_back(arg);
3763}
3764
3765// julia_decl null means not julia decl, otherwise it is the integer type needed
3766// to cast to
3767llvm::Value *to_blas_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef,
3768 bool cublas, IntegerType *julia_decl,
3769 IRBuilder<> &entryBuilder,
3770 llvm::Twine const &name) {
3771 if (!byRef)
3772 return V;
3773
3774 Value *allocV =
3775 entryBuilder.CreateAlloca(V->getType(), nullptr, "byref." + name);
3776 B.CreateStore(V, allocV);
3777
3778 if (julia_decl)
3779 allocV = B.CreatePointerCast(allocV, getInt8PtrTy(V->getContext()),
3780 "intcast." + name);
3781
3782 return allocV;
3783}
3784llvm::Value *to_blas_fp_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef,
3785 Type *fpTy, IRBuilder<> &entryBuilder,
3786 llvm::Twine const &name) {
3787 if (!byRef)
3788 return V;
3789
3790 Value *allocV =
3791 entryBuilder.CreateAlloca(V->getType(), nullptr, "byref." + name);
3792 B.CreateStore(V, allocV);
3793
3794 if (fpTy)
3795 allocV = B.CreatePointerCast(allocV, fpTy, "fpcast." + name);
3796
3797 return allocV;
3798}
3799
3800Value *is_lower(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas) {
3801 if (cublas) {
3802 Value *isNormal = nullptr;
3803 isNormal = B.CreateICmpEQ(
3804 uplo, ConstantInt::get(uplo->getType(),
3805 /*cublasFillMode_t::CUBLAS_FILL_MODE_LOWER*/ 0));
3806 return isNormal;
3807 }
3808 if (auto CI = dyn_cast<ConstantInt>(uplo)) {
3809 if (CI->getValue() == 'L' || CI->getValue() == 'l')
3810 return ConstantInt::getTrue(B.getContext());
3811 if (CI->getValue() == 'U' || CI->getValue() == 'u')
3812 return ConstantInt::getFalse(B.getContext());
3813 }
3814 if (byRef) {
3815 // can't inspect opaque ptr, so assume 8 (Julia)
3816 IntegerType *charTy = IntegerType::get(uplo->getContext(), 8);
3817 uplo = B.CreateLoad(charTy, uplo, "loaded.trans");
3818
3819 auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'L'));
3820 auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'l'));
3821 // fortran blas
3822 return B.CreateOr(isl, isL);
3823 } else {
3824 // we can inspect scalars
3825 auto capi = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 122));
3826 // TODO we really should just return capi, but for sake of consistency,
3827 // we will accept either here.
3828 auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'L'));
3829 auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'l'));
3830 return B.CreateOr(capi, B.CreateOr(isl, isL));
3831 }
3832}
3833
3834Value *is_nonunit(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas) {
3835 if (cublas) {
3836 Value *isNormal = nullptr;
3837 isNormal =
3838 B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
3839 /*CUBLAS_DIAG_NON_UNIT*/ 0));
3840 return isNormal;
3841 }
3842 if (auto CI = dyn_cast<ConstantInt>(uplo)) {
3843 if (CI->getValue() == 'N' || CI->getValue() == 'n')
3844 return ConstantInt::getTrue(B.getContext());
3845 if (CI->getValue() == 'U' || CI->getValue() == 'u')
3846 return ConstantInt::getFalse(B.getContext());
3847 }
3848 if (byRef) {
3849 // can't inspect opaque ptr, so assume 8 (Julia)
3850 IntegerType *charTy = IntegerType::get(uplo->getContext(), 8);
3851 uplo = B.CreateLoad(charTy, uplo, "loaded.nonunit");
3852
3853 auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'N'));
3854 auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'n'));
3855 // fortran blas
3856 return B.CreateOr(isl, isL);
3857 } else {
3858 // we can inspect scalars
3859 auto capi = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 131));
3860 // TODO we really should just return capi, but for sake of consistency,
3861 // we will accept either here.
3862 auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'N'));
3863 auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'n'));
3864 return B.CreateOr(capi, B.CreateOr(isl, isL));
3865 }
3866}
3867
3868llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef,
3869 bool cublas) {
3870 if (cublas) {
3871 Value *isNormal = nullptr;
3872 isNormal = B.CreateICmpEQ(
3873 trans, ConstantInt::get(trans->getType(),
3874 /*cublasOperation_t::CUBLAS_OP_N*/ 0));
3875 return isNormal;
3876 }
3877 // Explicitly support 'N' always, since we use in the rule infra
3878 if (auto CI = dyn_cast<ConstantInt>(trans)) {
3879 if (CI->getValue() == 'N' || CI->getValue() == 'n')
3880 return ConstantInt::getTrue(
3881 B.getContext()); //(Type::getInt1Ty(B.getContext()), true);
3882 }
3883 if (byRef) {
3884 // can't inspect opaque ptr, so assume 8 (Julia)
3885 IntegerType *charTy = IntegerType::get(trans->getContext(), 8);
3886 trans = B.CreateLoad(charTy, trans, "loaded.trans");
3887
3888 auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N'));
3889 auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'));
3890 // fortran blas
3891 return B.CreateOr(isn, isN);
3892 } else {
3893 // TODO we really should just return capi, but for sake of consistency,
3894 // we will accept either here.
3895 // we can inspect scalars
3896 auto capi = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
3897 auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N'));
3898 auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'));
3899 // fortran blas
3900 return B.CreateOr(capi, B.CreateOr(isn, isN));
3901 }
3902}
3903
3904llvm::Value *is_left(IRBuilder<> &B, llvm::Value *side, bool byRef,
3905 bool cublas) {
3906 if (cublas) {
3907 Value *isNormal = nullptr;
3908 isNormal = B.CreateICmpEQ(
3909 side, ConstantInt::get(side->getType(),
3910 /*cublasSideMode_t::CUBLAS_SIDE_LEFT*/ 0));
3911 return isNormal;
3912 }
3913 // Explicitly support 'L'/'R' always, since we use in the rule infra
3914 if (auto CI = dyn_cast<ConstantInt>(side)) {
3915 if (CI->getValue() == 'L' || CI->getValue() == 'l')
3916 return ConstantInt::getTrue(B.getContext());
3917 if (CI->getValue() == 'R' || CI->getValue() == 'r')
3918 return ConstantInt::getFalse(B.getContext());
3919 }
3920 if (byRef) {
3921 // can't inspect opaque ptr, so assume 8 (Julia)
3922 IntegerType *charTy = IntegerType::get(side->getContext(), 8);
3923 side = B.CreateLoad(charTy, side, "loaded.side");
3924
3925 auto isL = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'L'));
3926 auto isl = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'l'));
3927 // fortran blas
3928 return B.CreateOr(isl, isL);
3929 } else {
3930 // TODO we really should just return capi, but for sake of consistency,
3931 // we will accept either here.
3932 // we can inspect scalars
3933 auto capi = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 141));
3934 auto isL = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'L'));
3935 auto isl = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'l'));
3936 // fortran blas
3937 return B.CreateOr(capi, B.CreateOr(isl, isL));
3938 }
3939}
3940
3941// Ok. Here we are.
3942// netlib declares trans args as something out of
3943// N,n,T,t,C,c, represented as 8 bit chars.
3944// However, if we ask openBlas c ABI,
3945// it is one of the following 32 bit integers values:
3946// enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
3947llvm::Value *transpose(std::string floatType, IRBuilder<> &B, llvm::Value *V,
3948 bool cublas) {
3949 llvm::Type *T = V->getType();
3950 if (cublas) {
3951 auto isT1 = B.CreateICmpEQ(V, ConstantInt::get(T, 1));
3952 auto isT0 = B.CreateICmpEQ(V, ConstantInt::get(T, 0));
3953 return B.CreateSelect(isT1, ConstantInt::get(V->getType(), 0),
3954 B.CreateSelect(isT0,
3955 ConstantInt::get(V->getType(), 1),
3956 ConstantInt::get(V->getType(), 42)));
3957 } else if (T->isIntegerTy(8)) {
3958 if (floatType == "z" || floatType == "c") {
3959 auto isn = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'n'));
3960 auto sel1 = B.CreateSelect(isn, ConstantInt::get(V->getType(), 'c'),
3961 ConstantInt::get(V->getType(), 0));
3962
3963 auto isN = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'N'));
3964 auto sel2 =
3965 B.CreateSelect(isN, ConstantInt::get(V->getType(), 'C'), sel1);
3966
3967 auto ist = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'c'));
3968 auto sel3 =
3969 B.CreateSelect(ist, ConstantInt::get(V->getType(), 'n'), sel2);
3970
3971 auto isT = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'C'));
3972 return B.CreateSelect(isT, ConstantInt::get(V->getType(), 'N'), sel3);
3973 } else {
3974 // the base case here of 'C' or 'c' becomes simply 'N'
3975 auto isn = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'n'));
3976 auto sel1 = B.CreateSelect(isn, ConstantInt::get(V->getType(), 't'),
3977 ConstantInt::get(V->getType(), 'N'));
3978
3979 auto isN = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'N'));
3980 auto sel2 =
3981 B.CreateSelect(isN, ConstantInt::get(V->getType(), 'T'), sel1);
3982
3983 auto ist = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 't'));
3984 auto sel3 =
3985 B.CreateSelect(ist, ConstantInt::get(V->getType(), 'n'), sel2);
3986
3987 auto isT = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'T'));
3988 return B.CreateSelect(isT, ConstantInt::get(V->getType(), 'N'), sel3);
3989 }
3990
3991 } else if (T->isIntegerTy(32)) {
3992 auto is111 = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111));
3993 auto sel1 = B.CreateSelect(
3994 B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 112)),
3995 ConstantInt::get(V->getType(), 111), ConstantInt::get(V->getType(), 0));
3996 return B.CreateSelect(is111, ConstantInt::get(V->getType(), 112), sel1);
3997 } else {
3998 std::string s;
3999 llvm::raw_string_ostream ss(s);
4000 ss << "cannot handle unknown trans blas value\n" << V;
4001 if (CustomErrorHandler) {
4002 CustomErrorHandler(ss.str().c_str(), nullptr, ErrorType::NoDerivative,
4003 nullptr, nullptr, nullptr);
4004 } else {
4005 EmitFailure("unknown trans blas value", B.getCurrentDebugLocation(),
4006 B.GetInsertBlock()->getParent(), ss.str());
4007 }
4008 return V;
4009 }
4010}
4011
4012// Implement the following logic to get the width of a matrix
4013// if (cache_A) {
4014// ld_A = (arg_transa == 'N') ? arg_k : arg_m;
4015// } else {
4016// ld_A = arg_lda;
4017// }
4018llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B,
4019 llvm::ArrayRef<llvm::Value *> trans,
4020 llvm::Value *arg_ld, llvm::Value *dim1,
4021 llvm::Value *dim2, bool cacheMat, bool byRef,
4022 bool cublas) {
4023 if (!cacheMat)
4024 return arg_ld;
4025
4026 assert(trans.size() == 1);
4027
4028 llvm::Value *width =
4029 CreateSelect(B, is_normal(B, trans[0], byRef, cublas), dim2, dim1);
4030
4031 return width;
4032}
4033
4034llvm::Value *transpose(std::string floatType, llvm::IRBuilder<> &B,
4035 llvm::Value *V, bool byRef, bool cublas,
4036 llvm::IntegerType *julia_decl,
4037 llvm::IRBuilder<> &entryBuilder,
4038 const llvm::Twine &name) {
4039
4040 if (!byRef) {
4041 // Explicitly support 'N' always, since we use in the rule infra
4042 if (auto CI = dyn_cast<ConstantInt>(V)) {
4043 if (floatType == "c" || floatType == "z") {
4044 if (CI->getValue() == 'N')
4045 return ConstantInt::get(CI->getType(), 'C');
4046 if (CI->getValue() == 'c')
4047 return ConstantInt::get(CI->getType(), 'c');
4048 } else {
4049 if (CI->getValue() == 'N')
4050 return ConstantInt::get(CI->getType(), 'T');
4051 if (CI->getValue() == 'n')
4052 return ConstantInt::get(CI->getType(), 't');
4053 }
4054 }
4055
4056 // cblas
4057 if (!cublas)
4058 return B.CreateSelect(
4059 B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111)),
4060 ConstantInt::get(V->getType(), 112),
4061 ConstantInt::get(V->getType(), 111));
4062 }
4063
4064 if (byRef) {
4065 auto charType = IntegerType::get(V->getContext(), 8);
4066 V = B.CreateLoad(charType, V, "ld." + name);
4067 }
4068
4069 V = transpose(floatType, B, V, cublas);
4070
4071 return to_blas_callconv(B, V, byRef, cublas, julia_decl, entryBuilder,
4072 "transpose." + name);
4073}
4074
4075llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType,
4076 llvm::Value *V, bool byRef) {
4077 if (!byRef)
4078 return V;
4079
4080 if (V->getType()->isIntegerTy())
4081 V = B.CreateIntToPtr(V, getUnqual(intType));
4082 else
4083 V = B.CreatePointerCast(
4084 V, PointerType::get(
4085 intType, cast<PointerType>(V->getType())->getAddressSpace()));
4086 return B.CreateLoad(intType, V);
4087}
4088
4089SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,
4090 ArrayRef<llvm::Value *> transA,
4091 bool byRef, bool cublas) {
4092 assert(transA.size() == 1);
4093 auto trans = transA[0];
4094 if (byRef) {
4095 auto charType = IntegerType::get(trans->getContext(), 8);
4096 trans = B.CreateLoad(charType, trans, "ld.row.trans");
4097 }
4098
4099 Value *cond = nullptr;
4100 if (!cublas) {
4101
4102 if (!byRef) {
4103 cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
4104 } else {
4105 auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'));
4106 auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N'));
4107 cond = B.CreateOr(isN, isn);
4108 }
4109 } else {
4110 // CUBLAS_OP_N = 0, CUBLAS_OP_T = 1, CUBLAS_OP_C = 2
4111 // TODO: verify
4112 cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0));
4113 }
4114 return {cond};
4115}
4116SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,
4117 ArrayRef<llvm::Value *> transA,
4118 ArrayRef<llvm::Value *> row,
4119 ArrayRef<llvm::Value *> col,
4120 bool byRef, bool cublas) {
4121 auto conds = get_blas_row(B, transA, byRef, cublas);
4122 assert(row.size() == col.size());
4123 SmallVector<Value *, 1> toreturn;
4124 for (size_t i = 0; i < row.size(); i++) {
4125 auto lhs = row[i];
4126 auto rhs = col[i];
4127 if (lhs->getType() != rhs->getType())
4128 rhs = B.CreatePointerCast(rhs, lhs->getType());
4129 toreturn.push_back(B.CreateSelect(conds[0], lhs, rhs));
4130 }
4131 return toreturn;
4132}
4133
4134// return how many Special pointers are in T (count > 0),
4135// and if there is anything else in T (all == false)
// NOTE(review): listing line 4136 (presumably the definition header that
// introduces `T`, `count`, `all` and `derived`) is absent from this excerpt;
// confirm against the original file.
4137 if (isa<PointerType>(T)) {
// A special pointer contributes one to the count; it is marked as derived
// when its address space is not AddressSpace::Tracked.
4138 if (isSpecialPtr(T)) {
4139 count++;
4140 if (T->getPointerAddressSpace() != AddressSpace::Tracked)
4141 derived = true;
4142 }
4143 } else if (isa<StructType>(T) || isa<ArrayType>(T) || isa<VectorType>(T)) {
// Aggregates: recurse into every subtype and merge the results.
4144 for (Type *ElT : T->subtypes()) {
4145 auto sub = CountTrackedPointers(ElT);
4146 count += sub.count;
4147 all &= sub.all;
4148 derived |= sub.derived;
4149 }
// Arrays and vectors list their element type once, so scale the count by the
// number of elements (the known minimum for vectors on LLVM >= 12).
4150 if (isa<ArrayType>(T))
4151 count *= cast<ArrayType>(T)->getNumElements();
4152 else if (isa<VectorType>(T)) {
4153#if LLVM_VERSION_MAJOR >= 12
4154 count *= cast<VectorType>(T)->getElementCount().getKnownMinValue();
4155#else
4156 count *= cast<VectorType>(T)->getNumElements();
4157#endif
4158 }
4159 }
// A type without any counted pointer cannot consist only of them.
4160 if (count == 0)
4161 all = false;
4162}
4163
4164#if LLVM_VERSION_MAJOR >= 20
4165bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth,
4166 SmallMapVector<Value *, APInt, 4> &VariableOffsets,
4167 APInt &ConstantOffset)
4168#else
4169bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth,
4170 MapVector<Value *, APInt> &VariableOffsets,
4171 APInt &ConstantOffset)
4172#endif
4173{
4174#if LLVM_VERSION_MAJOR >= 13
4175 return gep->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset);
4176#else
4177 assert(BitWidth == DL.getIndexSizeInBits(gep->getPointerAddressSpace()) &&
4178 "The offset bit width does not match DL specification.");
4179
4180 auto CollectConstantOffset = [&](APInt Index, uint64_t Size) {
4181 Index = Index.sextOrTrunc(BitWidth);
4182 APInt IndexedSize = APInt(BitWidth, Size);
4183 ConstantOffset += Index * IndexedSize;
4184 };
4185
4186 for (gep_type_iterator GTI = gep_type_begin(gep), GTE = gep_type_end(gep);
4187 GTI != GTE; ++GTI) {
4188 // Scalable vectors are multiplied by a runtime constant.
4189 bool ScalableType = isa<ScalableVectorType>(GTI.getIndexedType());
4190
4191 Value *V = GTI.getOperand();
4192 StructType *STy = GTI.getStructTypeOrNull();
4193 // Handle ConstantInt if possible.
4194 if (auto ConstOffset = dyn_cast<ConstantInt>(V)) {
4195 if (ConstOffset->isZero())
4196 continue;
4197 // If the type is scalable and the constant is not zero (vscale * n * 0 =
4198 // 0) bailout.
4199 // TODO: If the runtime value is accessible at any point before DWARF
4200 // emission, then we could potentially keep a forward reference to it
4201 // in the debug value to be filled in later.
4202 if (ScalableType)
4203 return false;
4204 // Handle a struct index, which adds its field offset to the pointer.
4205 if (STy) {
4206 unsigned ElementIdx = ConstOffset->getZExtValue();
4207 const StructLayout *SL = DL.getStructLayout(STy);
4208 // Element offset is in bytes.
4209 CollectConstantOffset(APInt(BitWidth, SL->getElementOffset(ElementIdx)),
4210 1);
4211 continue;
4212 }
4213 CollectConstantOffset(ConstOffset->getValue(),
4214 DL.getTypeAllocSize(GTI.getIndexedType()));
4215 continue;
4216 }
4217
4218 if (STy || ScalableType)
4219 return false;
4220 APInt IndexedSize =
4221 APInt(BitWidth, DL.getTypeAllocSize(GTI.getIndexedType()));
4222 // Insert an initial offset of 0 for V iff none exists already, then
4223 // increment the offset by IndexedSize.
4224 if (IndexedSize != 0) {
4225 VariableOffsets.insert({V, APInt(BitWidth, 0)});
4226 VariableOffsets[V] += IndexedSize;
4227 }
4228 }
4229 return true;
4230#endif
4231}
4232
4233llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B,
4234 llvm::Intrinsic::ID ID, llvm::Type *RetTy,
4235 llvm::ArrayRef<llvm::Value *> Args,
4236 llvm::Instruction *FMFSource,
4237 const llvm::Twine &Name) {
4238#if LLVM_VERSION_MAJOR >= 16
4239 llvm::CallInst *nres = B.CreateIntrinsic(RetTy, ID, Args, FMFSource, Name);
4240#else
4241 SmallVector<Intrinsic::IITDescriptor, 1> Table;
4242 Intrinsic::getIntrinsicInfoTableEntries(ID, Table);
4243 ArrayRef<Intrinsic::IITDescriptor> TableRef(Table);
4244
4245 SmallVector<Type *, 2> ArgTys;
4246 ArgTys.reserve(Args.size());
4247 for (auto &I : Args)
4248 ArgTys.push_back(I->getType());
4249 FunctionType *FTy = FunctionType::get(RetTy, ArgTys, false);
4250 SmallVector<Type *, 2> OverloadTys;
4251 Intrinsic::MatchIntrinsicTypesResult Res =
4252 matchIntrinsicSignature(FTy, TableRef, OverloadTys);
4253 (void)Res;
4254 assert(Res == Intrinsic::MatchIntrinsicTypes_Match && TableRef.empty() &&
4255 "Wrong types for intrinsic!");
4256 Function *Fn = Intrinsic::getDeclaration(B.GetInsertPoint()->getModule(), ID,
4257 OverloadTys);
4258 CallInst *nres = B.CreateCall(Fn, Args, {}, Name);
4259 if (FMFSource)
4260 nres->copyFastMathFlags(FMFSource);
4261#endif
4262 return nres;
4263}
4264
4265/* Bithack to compute 1 ulp as follows:
4266double ulp(double res) {
4267 double nres = res;
4268 (*(uint64_t*)&nres) = 0x1 ^ *(uint64_t*)&nres;
4269 return abs(nres - res);
4270}
4271*/
4272llvm::Value *get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res) {
4273 auto ty = res->getType();
4274 unsigned tsize = builder.GetInsertBlock()
4275 ->getParent()
4276 ->getParent()
4277 ->getDataLayout()
4278 .getTypeSizeInBits(ty);
4279
4280 auto ity = IntegerType::get(ty->getContext(), tsize);
4281
4282 auto as_int = builder.CreateBitCast(res, ity);
4283 auto masked = builder.CreateXor(as_int, ConstantInt::get(ity, 1));
4284 auto neighbor = builder.CreateBitCast(masked, ty);
4285
4286 auto diff = builder.CreateFSub(res, neighbor);
4287
4288 auto absres = builder.CreateIntrinsic(Intrinsic::fabs,
4289 ArrayRef<Type *>(diff->getType()),
4290 ArrayRef<Value *>(diff));
4291
4292 return absres;
4293}
4294
// Report that no derivative can be produced for `inst`.
// Three strategies, tried in order:
//  1. a registered CustomErrorHandler is called and its result is returned;
//  2. with EnzymeRuntimeError set, calls to puts(message + backtrace) and
//     exit(1) are emitted through Builder2 and nullptr is returned;
//  3. otherwise a "NoDerivative" failure is emitted and nullptr is returned.
4295llvm::Value *EmitNoDerivativeError(const std::string &message,
4296 llvm::Instruction &inst,
4297 GradientUtils *gutils,
4298 llvm::IRBuilder<> &Builder2,
4299 llvm::Value *condition) {
4300 if (CustomErrorHandler) {
4301 return unwrap(CustomErrorHandler(message.c_str(), wrap(&inst),
// NOTE(review): listing line 4302 (the arguments between wrap(&inst) and
// wrap(condition)) is absent from this excerpt; restore it from the original
// file.
4303 wrap(condition), wrap(&Builder2)));
4304 } else if (EnzymeRuntimeError) {
4305 auto &M = *inst.getParent()->getParent()->getParent();
// int puts(i8*)
4306 FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()),
4307 {getInt8PtrTy(M.getContext())}, false);
4308 std::string str;
4309 raw_string_ostream ss(str);
4310 ss << message << "\n";
4311 emit_backtrace(&inst, ss);
4312 auto msg = getString(M, ss.str());
4313 ;
4314 auto PutsF = M.getOrInsertFunction("puts", FT);
4315 Builder2.CreateCall(PutsF, msg);
4316
// void exit(i32)
4317 FunctionType *FT2 =
4318 FunctionType::get(Type::getVoidTy(M.getContext()),
4319 {Type::getInt32Ty(M.getContext())}, false);
4320
4321 auto ExitF = M.getOrInsertFunction("exit", FT2);
4322 Builder2.CreateCall(ExitF,
4323 ConstantInt::get(Type::getInt32Ty(M.getContext()), 1));
4324 return nullptr;
4325 } else {
// For this particular message the type results are dumped as well.
4326 if (StringRef(message).contains("cannot handle above cast")) {
4327 gutils->TR.dump();
4328 }
4329 EmitFailure("NoDerivative", inst.getDebugLoc(), &inst, message);
4330 return nullptr;
4331 }
4332}
4333
4334bool EmitNoDerivativeError(const std::string &message, Value *todiff,
4335 RequestContext &context) {
4336 Value *toshow = todiff;
4337 if (context.req) {
4338 toshow = context.req;
4339 }
4340 if (CustomErrorHandler) {
4341 CustomErrorHandler(message.c_str(), wrap(toshow), ErrorType::NoDerivative,
4342 nullptr, wrap(todiff), wrap(context.ip));
4343 return true;
4344 } else if (context.ip && EnzymeRuntimeError) {
4345 auto &M = *context.ip->GetInsertBlock()->getParent()->getParent();
4346 FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()),
4347 {getInt8PtrTy(M.getContext())}, false);
4348 std::string str;
4349 raw_string_ostream ss(str);
4350 ss << message << "\n";
4351 if (auto inst = dyn_cast<Instruction>(todiff))
4352 emit_backtrace(inst, ss);
4353 auto msg = getString(M, ss.str());
4354 auto PutsF = M.getOrInsertFunction("puts", FT);
4355 context.ip->CreateCall(PutsF, msg);
4356
4357 FunctionType *FT2 =
4358 FunctionType::get(Type::getVoidTy(M.getContext()),
4359 {Type::getInt32Ty(M.getContext())}, false);
4360
4361 auto ExitF = M.getOrInsertFunction("exit", FT2);
4362 context.ip->CreateCall(
4363 ExitF, ConstantInt::get(Type::getInt32Ty(M.getContext()), 1));
4364 return true;
4365 } else if (context.req) {
4366 EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
4367 message);
4368 return true;
4369 } else if (auto arg = dyn_cast<Instruction>(todiff)) {
4370 auto loc = arg->getDebugLoc();
4371 EmitFailure("NoDerivative", loc, arg, message);
4372 return true;
4373 }
4374 return false;
4375}
4376
4377void EmitNoTypeError(const std::string &message, llvm::Instruction &inst,
4378 GradientUtils *gutils, llvm::IRBuilder<> &Builder2) {
4379 if (CustomErrorHandler) {
4380 CustomErrorHandler(message.c_str(), wrap(&inst), ErrorType::NoType,
4381 gutils->TR.analyzer, nullptr, wrap(&Builder2));
4382 } else if (EnzymeRuntimeError) {
4383 auto &M = *inst.getParent()->getParent()->getParent();
4384 FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()),
4385 {getInt8PtrTy(M.getContext())}, false);
4386 std::string str;
4387 raw_string_ostream ss(str);
4388 ss << message << "\n";
4389 emit_backtrace(&inst, ss);
4390 auto msg = getString(M, ss.str());
4391 auto PutsF = M.getOrInsertFunction("puts", FT);
4392 Builder2.CreateCall(PutsF, msg);
4393
4394 FunctionType *FT2 =
4395 FunctionType::get(Type::getVoidTy(M.getContext()),
4396 {Type::getInt32Ty(M.getContext())}, false);
4397
4398 auto ExitF = M.getOrInsertFunction("exit", FT2);
4399 Builder2.CreateCall(ExitF,
4400 ConstantInt::get(Type::getInt32Ty(M.getContext()), 1));
4401 } else {
4402 std::string str;
4403 raw_string_ostream ss(str);
4404 ss << message << "\n";
4405 gutils->TR.dump(ss);
4406 EmitFailure("CannotDeduceType", inst.getDebugLoc(), &inst, ss.str());
4407 }
4408}
4409
// Decode a metadata node made of (type-name string, integer) operand pairs
// into a list of (float type or null, start, length) tuples, fusing adjacent
// entries whose types can legally be merged.
4410std::vector<std::tuple<llvm::Type *, size_t, size_t>>
4411parseTrueType(const llvm::MDNode *md, DerivativeMode Mode, bool const_src) {
// Step 1: read the operand pairs as (ConcreteType, integer).
4412 std::vector<std::pair<ConcreteType, size_t>> parsed;
4413 for (size_t i = 0; i < md->getNumOperands(); i += 2) {
4414 ConcreteType base(
4415 llvm::cast<llvm::MDString>(md->getOperand(i))->getString(),
4416 md->getContext());
4417 auto size = llvm::cast<llvm::ConstantInt>(
4418 llvm::cast<llvm::ConstantAsMetadata>(md->getOperand(i + 1))
4419 ->getValue())
4420 ->getSExtValue();
4421 parsed.emplace_back(base, size);
4422 }
4423
// Step 2: greedily extend each run while the next entry may be merged in.
4424 std::vector<std::tuple<llvm::Type *, size_t, size_t>> toIterate;
4425 size_t idx = 0;
4426 while (idx < parsed.size()) {
4427
4428 auto dt = parsed[idx].first;
4429 size_t start = parsed[idx].second;
// Sentinel end used when the run extends to the last parsed entry.
4430 size_t end = 0x0fffffff;
4431 for (idx = idx + 1; idx < parsed.size(); ++idx) {
4432 bool Legal = true;
4433 auto tmp = dt;
4434 auto next = parsed[idx].first;
4435 tmp.checkedOrIn(next, /*PointerIntSame*/ true, Legal);
4436 // Prevent fusion of {Anything, Float} since anything is an int rule
4437 // but float requires zeroing.
4438 if ((dt == BaseType::Anything &&
4439 (next != BaseType::Anything && next.isKnown())) ||
4440 (next == BaseType::Anything &&
4441 (dt != BaseType::Anything && dt.isKnown())))
4442 Legal = false;
4443 if (!Legal) {
4444 if (Mode == DerivativeMode::ForwardMode ||
// NOTE(review): listing line 4445 (the rest of this condition and its opening
// brace) is absent from this excerpt; restore it from the original file.
4446 // if both are floats (of any type), forward mode is the same.
4447 // + [potentially zero if const, otherwise copy]
4448 // if both are int/pointer (of any type), also the same
4449 // + copy
4450 // if known non-constant, also the same
4451 // + copy
4452 if ((parsed[idx].first.isFloat() == nullptr) ==
4453 (parsed[idx - 1].first.isFloat() == nullptr)) {
4454 Legal = true;
4455 }
4456 if (const_src) {
4457 Legal = true;
4458 }
4459 }
// Still not mergeable: the current run ends where the next entry starts.
4460 if (!Legal) {
4461 end = parsed[idx].second;
4462 break;
4463 }
4464 } else
4465 dt = tmp;
4466 }
4467 assert(dt.isKnown());
4468 toIterate.emplace_back(dt.isFloat(), start, end - start);
4469 }
4470 return toIterate;
4471}
4472
4473void dumpModule(llvm::Module *mod) { llvm::errs() << *mod << "\n"; }
4474
4475void dumpValue(llvm::Value *val) { llvm::errs() << *val << "\n"; }
4476
4477void dumpBlock(llvm::BasicBlock *blk) { llvm::errs() << *blk << "\n"; }
4478
4479void dumpType(llvm::Type *ty) { llvm::errs() << *ty << "\n"; }
4480
4482
4483bool isNVLoad(const llvm::Value *V) {
4484 auto II = dyn_cast<IntrinsicInst>(V);
4485 if (!II)
4486 return false;
4487 switch (II->getIntrinsicID()) {
4488 case Intrinsic::nvvm_ldu_global_i:
4489 case Intrinsic::nvvm_ldu_global_p:
4490 case Intrinsic::nvvm_ldu_global_f:
4491#if LLVM_VERSION_MAJOR < 20
4492 case Intrinsic::nvvm_ldg_global_i:
4493 case Intrinsic::nvvm_ldg_global_p:
4494 case Intrinsic::nvvm_ldg_global_f:
4495#endif
4496 return true;
4497 default:
4498 return false;
4499 }
4500 return false;
4501}
4502
4503bool notCapturedBefore(llvm::Value *V, Instruction *inst,
4504 size_t checkLoadCaptures) {
4505 Instruction *VI = dyn_cast<Instruction>(V);
4506 if (!VI)
4507 VI = &*inst->getParent()->getParent()->getEntryBlock().begin();
4508 else
4509 VI = VI->getNextNode();
4510 SmallPtrSet<BasicBlock *, 1> regionBetween;
4511 if (inst) {
4512 SmallVector<BasicBlock *, 1> todo;
4513 todo.push_back(VI->getParent());
4514 while (todo.size()) {
4515 auto cur = todo.pop_back_val();
4516 if (regionBetween.count(cur))
4517 continue;
4518 regionBetween.insert(cur);
4519 if (cur == inst->getParent())
4520 continue;
4521 for (auto BB : successors(cur))
4522 todo.push_back(BB);
4523 }
4524 }
4525 SmallVector<std::tuple<Instruction *, size_t, Value *>, 1> todo;
4526 for (auto U : V->users()) {
4527 todo.emplace_back(cast<Instruction>(U), checkLoadCaptures, V);
4528 }
4529 std::set<std::tuple<Value *, size_t, Value *>> seen;
4530 while (todo.size()) {
4531 auto pair = todo.pop_back_val();
4532 if (seen.count(pair))
4533 continue;
4534 seen.insert(pair);
4535 auto UI = std::get<0>(pair);
4536 auto level = std::get<1>(pair);
4537 auto prev = std::get<2>(pair);
4538 if (inst) {
4539 if (!regionBetween.count(UI->getParent()))
4540 continue;
4541 if (UI->getParent() == VI->getParent()) {
4542 if (UI->comesBefore(VI))
4543 continue;
4544 }
4545 if (UI->getParent() == inst->getParent())
4546 if (inst->comesBefore(UI))
4547 continue;
4548 }
4549
4550 if (isPointerArithmeticInst(UI, /*includephi*/ true,
4551 /*includebin*/ true)) {
4552 for (auto U2 : UI->users()) {
4553 auto UI2 = cast<Instruction>(U2);
4554 todo.emplace_back(UI2, level, UI);
4555 }
4556 continue;
4557 }
4558
4559 if (isa<MemSetInst>(UI))
4560 continue;
4561
4562 if (isa<MemTransferInst>(UI)) {
4563 if (level == 0)
4564 continue;
4565 if (UI->getOperand(1) != prev)
4566 continue;
4567 }
4568
4569 if (auto CI = dyn_cast<CallBase>(UI)) {
4570#if LLVM_VERSION_MAJOR >= 14
4571 for (size_t i = 0, size = CI->arg_size(); i < size; i++)
4572#else
4573 for (size_t i = 0, size = CI->getNumArgOperands(); i < size; i++)
4574#endif
4575 {
4576 if (prev == CI->getArgOperand(i)) {
4577 if (isNoCapture(CI, i) && level == 0)
4578 continue;
4579 return false;
4580 }
4581 }
4582 return true;
4583 }
4584
4585 if (isa<CmpInst>(UI)) {
4586 continue;
4587 }
4588 if (isa<LoadInst>(UI)) {
4589 if (level) {
4590 for (auto U2 : UI->users()) {
4591 auto UI2 = cast<Instruction>(U2);
4592 todo.emplace_back(UI2, level - 1, UI);
4593 }
4594 }
4595 continue;
4596 }
4597 // storing into it.
4598 if (auto SI = dyn_cast<StoreInst>(UI)) {
4599 if (SI->getValueOperand() != prev) {
4600 continue;
4601 }
4602 }
4603 return false;
4604 }
4605 return true;
4606}
4607
4608bool notCaptured(llvm::Value *V) { return notCapturedBefore(V, nullptr, 0); }
4609
4610// Return true if guaranteed not to alias
4611// Return false if guaranteed to alias [with possible offset depending on flag].
4612// Return {} if no information is given.
4613#if LLVM_VERSION_MAJOR >= 16
4614std::optional<bool>
4615#else
4616llvm::Optional<bool>
4617#endif
4618arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA,
4619 llvm::LoopInfo &LI, llvm::Value *op0,
4620 llvm::Value *op1, bool offsetAllowed) {
// Compare the underlying base objects rather than the operands themselves.
4621 auto lhs = getBaseObject(op0, offsetAllowed);
4622 auto rhs = getBaseObject(op1, offsetAllowed);
4623
4624 if (lhs == rhs) {
4625 return false;
4626 }
// A null constant cannot alias an instruction carrying !nonnull metadata.
4627 if (auto i1 = dyn_cast<Instruction>(op1))
4628 if (isa<ConstantPointerNull>(op0) &&
4629 hasMetadata(i1, LLVMContext::MD_nonnull)) {
4630 return true;
4631 }
4632 if (auto i0 = dyn_cast<Instruction>(op0))
4633 if (isa<ConstantPointerNull>(op1) &&
4634 hasMetadata(i0, LLVMContext::MD_nonnull)) {
4635 return true;
4636 }
4637
// NOTE(review): this bails out only when *neither* base is a pointer; confirm
// whether `||` was intended.
4638 if (!lhs->getType()->isPointerTy() && !rhs->getType()->isPointerTy())
4639 return {};
4640
4641 bool noalias_lhs = isNoAlias(lhs);
4642 bool noalias_rhs = isNoAlias(rhs);
4643
4644 bool noalias[2] = {noalias_lhs, noalias_rhs};
4645
// Run the same reasoning twice, once with each base in the `start` role.
4646 for (int i = 0; i < 2; i++) {
4647 Value *start = (i == 0) ? lhs : rhs;
4648 Value *end = (i == 0) ? rhs : lhs;
// Case 1: `start` is noalias. It is then distinct from another noalias base,
// from an argument, and from an instruction before which it was not captured.
4649 if (noalias[i]) {
4650 if (noalias[1 - i]) {
4651 return true;
4652 }
4653 if (isa<Argument>(end)) {
4654 return true;
4655 }
4656 if (auto endi = dyn_cast<Instruction>(end)) {
4657 if (notCapturedBefore(start, endi, 0)) {
4658 return true;
4659 }
4660 }
4661 }
// Case 2: `start` is a pointer loaded out of an allocation call.
4662 if (auto ld = dyn_cast<LoadInst>(start)) {
4663 auto base = getBaseObject(ld->getOperand(0), /*offsetAllowed*/ false);
4664 if (isAllocationCall(base, TLI)) {
4665 if (isa<Argument>(end))
4666 return true;
4667 if (auto endi = dyn_cast<Instruction>(end))
4668 if (isNoAlias(end) || (notCapturedBefore(start, endi, 1))) {
4669 Instruction *starti = dyn_cast<Instruction>(start);
4670 if (!starti) {
4671 if (!isa<Argument>(start))
4672 continue;
4673 starti =
4674 &cast<Argument>(start)->getParent()->getEntryBlock().front();
4675 }
4676
// Look for a write between starti and endi that may modify what `ld` read.
4677 bool overwritten = false;
// NOTE(review): listing line 4678 (the call that receives the arguments and
// lambda below) is absent from this excerpt; restore it from the original
// file.
4679 LI, starti, endi, [&](Instruction *I) -> bool {
4680 if (!I->mayWriteToMemory())
4681 return /*earlyBreak*/ false;
4682
4683 if (writesToMemoryReadBy(nullptr, AA, TLI,
4684 /*maybeReader*/ ld,
4685 /*maybeWriter*/ I)) {
4686 overwritten = true;
4687 return /*earlyBreak*/ true;
4688 }
4689 return /*earlyBreak*/ false;
4690 });
4691
4692 if (!overwritten) {
4693 return true;
4694 }
4695 }
4696 }
4697 }
4698 }
4699
// Nothing could be proven either way.
4700 return {};
4701}
4702
4703static Value *constantInBoundsGEPHelper(llvm::IRBuilder<> &B, llvm::Type *type,
4704 llvm::Value *value,
4705 ArrayRef<unsigned> path) {
4706 SmallVector<Value *, 2> vals;
4707 vals.push_back(ConstantInt::get(B.getInt64Ty(), 0));
4708 for (auto v : path) {
4709 vals.push_back(ConstantInt::get(B.getInt32Ty(), v));
4710 }
4711 return B.CreateInBoundsGEP(type, value, vals);
4712}
4713
// Walk every special pointer contained in `jltype` (in field order) and move
// it between the sret value/storage `sret` and consecutive slots of `rootRet`,
// starting at slot `rootOffset`; what exactly is moved depends on `direction`.
// Returns `val`, which starts out as `sret` and is updated by the directions
// that insert values.
// NOTE(review): several listing lines of this definition (most `case` labels
// of both switches and the condition guarding the write-barrier block) are
// absent from this excerpt; the comments below only describe what is visible.
4714llvm::Value *moveSRetToFromRoots(llvm::IRBuilder<> &B, llvm::Type *jltype,
4715 llvm::Value *sret, llvm::Type *root_ty,
4716 llvm::Value *rootRet, size_t rootOffset,
4717 SRetRootMovement direction) {
// Worklist of (type, index path from the top-level type).
4718 std::deque<std::pair<llvm::Type *, std::vector<unsigned>>> todo = {
4719 {jltype, {}}};
4720 SmallVector<Value *> extracted;
4721 Value *val = sret;
4722 auto rootOffset0 = rootOffset;
4723 while (!todo.empty()) {
4724 auto cur = std::move(todo[0]);
4725 todo.pop_front();
4726 auto path = std::move(cur.second);
4727 auto ty = cur.first;
4728
// Leaf: only special pointers occupy a root slot.
4729 if (auto PT = dyn_cast<PointerType>(ty)) {
4730 if (!isSpecialPtr(PT))
4731 continue;
4732
// First switch: compute the address of the current root slot.
4733 Value *loc = nullptr;
4734 switch (direction) {
4739 loc = constantInBoundsGEPHelper(B, root_ty, rootRet, rootOffset);
4740 break;
4741 default:
4742 llvm_unreachable("Unhandled");
4743 }
// Second switch: perform the actual move for this pointer.
4744 switch (direction) {
4746 Value *outloc = constantInBoundsGEPHelper(B, jltype, sret, path);
4747 outloc = B.CreateLoad(ty, outloc);
4748 B.CreateStore(outloc, loc);
4749 break;
4750 }
4752 Value *outloc = GradientUtils::extractMeta(B, sret, path);
4753 outloc = B.CreatePointerCast(
4754 outloc, PointerType::get(StructType::get(outloc->getContext(), {}),
4755 Tracked));
4756 B.CreateStore(outloc, loc);
4757 break;
4758 }
4760 loc = B.CreateLoad(ty, loc);
4761 val = B.CreateInsertValue(val, loc, path);
4762 break;
4763 }
4766 *B.GetInsertBlock()->getParent()->getParent(), ty, false);
4767 val = B.CreateInsertValue(val, loc, path);
4768 break;
4769 }
4771 Value *outloc = constantInBoundsGEPHelper(B, jltype, sret, path);
4772 loc = B.CreateLoad(ty, loc);
// Values stored back are remembered for the write barrier emitted below.
4773 extracted.push_back(loc);
4774 B.CreateStore(loc, outloc);
4775 break;
4776 }
4777 default:
4778 llvm_unreachable("Unhandled");
4779 break;
4780 }
4781
// Each special pointer consumes exactly one root slot.
4782 rootOffset += 1;
4783 continue;
4784 }
4785
// Aggregates: children are pushed to the front in reverse order, so they are
// popped (and assigned root slots) in ascending index order.
4786 if (auto AT = dyn_cast<ArrayType>(ty)) {
4787 for (size_t i = 0, E = AT->getNumElements(); i < E; i++) {
4788 std::vector<unsigned> path2(path);
4789 path2.push_back(E - 1 - i);
4790 todo.emplace_front(AT->getElementType(), path2);
4791 }
4792 continue;
4793 }
4794
4795 if (auto VT = dyn_cast<VectorType>(ty)) {
4796 for (size_t i = 0, E = VT->getElementCount().getKnownMinValue(); i < E;
4797 i++) {
4798 std::vector<unsigned> path2(path);
4799 path2.push_back(E - 1 - i);
4800 todo.emplace_front(VT->getElementType(), path2);
4801 }
4802 continue;
4803 }
4804
4805 if (auto ST = dyn_cast<StructType>(ty)) {
4806 for (size_t i = 0, E = ST->getNumElements(); i < E; i++) {
4807 std::vector<unsigned> path2(path);
4808 path2.push_back(E - 1 - i);
4809 todo.emplace_front(ST->getTypeAtIndex(E - 1 - i), path2);
4810 }
4811 continue;
4812 }
4813 }
4814
// When the base object of `sret` lives in address space 10 and pointers were
// stored into it, emit a call to "julia.write_barrier" with the object
// followed by the stored pointers.
4816 auto obj = getBaseObject(sret);
4817 auto PT = cast<PointerType>(obj->getType());
4818 assert(PT->getAddressSpace() == 0 || PT->getAddressSpace() == 10);
4819 if (PT->getAddressSpace() == 10 && extracted.size()) {
4820 extracted.insert(extracted.begin(), obj);
4821 auto JLT = PointerType::get(StructType::get(PT->getContext(), {}), 10);
4822 auto FT = FunctionType::get(JLT, {}, true);
4823 auto wb =
4824 B.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
4825 "julia.write_barrier", FT);
4826 assert(obj->getType() == JLT);
4827 B.CreateCall(wb, extracted);
4828 }
4829 }
4830
// Sanity check: the number of root slots used must match the number of
// tracked pointers counted for the type.
4831 CountTrackedPointers tracked(jltype);
4832 assert(rootOffset - rootOffset0 == tracked.count);
4833
4834 return val;
4835}
4836
4837void copyNonJLValueInto(llvm::IRBuilder<> &B, llvm::Type *curType,
4838 llvm::Type *dstType, llvm::Value *dst,
4839 llvm::ArrayRef<unsigned> dstPrefix0,
4840 llvm::Type *srcType, llvm::Value *src,
4841 llvm::ArrayRef<unsigned> srcPrefix0, bool shouldZero) {
4842 std::deque<
4843 std::tuple<llvm::Type *, std::vector<unsigned>, std::vector<unsigned>>>
4844 todo = {{curType,
4845 std::vector<unsigned>(dstPrefix0.begin(), dstPrefix0.end()),
4846 std::vector<unsigned>(srcPrefix0.begin(), srcPrefix0.end())}};
4847
4848 auto &M = *B.GetInsertBlock()->getParent()->getParent();
4849
4850 size_t numRootsSeen = 0;
4851
4852 while (!todo.empty()) {
4853 auto cur = std::move(todo[0]);
4854 auto &&[ty, dstPrefix, srcPrefix] = cur;
4855 todo.pop_front();
4856
4857 if (auto PT = dyn_cast<PointerType>(ty)) {
4858 if (PT->getAddressSpace() == 10) {
4859 numRootsSeen++;
4860 if (shouldZero) {
4861 Value *out = dst;
4862 if (dstPrefix.size() > 0)
4863 out = constantInBoundsGEPHelper(B, dstType, out, dstPrefix);
4864 B.CreateStore(getUndefinedValueForType(M, ty), out);
4865 }
4866 }
4867 // We don't actually need pointers either here
4868 continue;
4869 }
4870
4871 if (auto AT = dyn_cast<ArrayType>(ty)) {
4872 for (size_t i = 0, E = AT->getNumElements(); i < E; i++) {
4873 std::vector<unsigned> nextDst(dstPrefix);
4874 std::vector<unsigned> nextSrc(srcPrefix);
4875 nextDst.push_back(E - 1 - i);
4876 nextSrc.push_back(E - 1 - i);
4877 todo.emplace_front(AT->getElementType(), std::move(nextDst),
4878 std::move(nextSrc));
4879 }
4880 continue;
4881 }
4882
4883 if (auto ST = dyn_cast<StructType>(ty)) {
4884 for (size_t i = 0, E = ST->getNumElements(); i < E; i++) {
4885 std::vector<unsigned> nextDst(dstPrefix);
4886 std::vector<unsigned> nextSrc(srcPrefix);
4887 nextDst.push_back(E - 1 - i);
4888 nextSrc.push_back(E - 1 - i);
4889 todo.emplace_front(ST->getElementType(E - 1 - i), std::move(nextDst),
4890 std::move(nextSrc));
4891 }
4892 continue;
4893 }
4894
4895 Value *out = dst;
4896 if (dstPrefix.size() > 0)
4897 out = constantInBoundsGEPHelper(B, dstType, out, dstPrefix);
4898
4899 Value *in = src;
4900 if (srcPrefix.size() > 0)
4901 in = constantInBoundsGEPHelper(B, srcType, in, srcPrefix);
4902
4903 auto ld = B.CreateLoad(ty, in);
4904 B.CreateStore(ld, out);
4905 }
4906
4907 CountTrackedPointers tracked(curType);
4908 assert(numRootsSeen == tracked.count);
4909 (void)tracked;
4910 (void)numRootsSeen;
4911}
4912
4913llvm::SmallVector<llvm::Value *, 1> getJuliaObjects(llvm::Value *v,
4914 llvm::IRBuilder<> &B) {
4915 std::deque<Value *> todo = {v};
4916 SmallVector<Value *, 1> done;
4917 while (todo.size()) {
4918 auto cur = todo.front();
4919 todo.pop_front();
4920 auto T = cur->getType();
4921 if (!anyJuliaObjects(T)) {
4922 continue;
4923 }
4924 if (isSpecialPtr(T)) {
4925 done.push_back(cur);
4926 continue;
4927 }
4928 if (auto ST = dyn_cast<StructType>(T)) {
4929 for (size_t i = 0, E = ST->getNumElements(); i < E; i++) {
4930 auto T2 = ST->getElementType(E - 1 - i);
4931 if (anyJuliaObjects(T2)) {
4932 auto V2 = B.CreateExtractValue(cur, E - 1 - i);
4933 todo.push_front(V2);
4934 }
4935 }
4936 continue;
4937 }
4938 if (auto AT = dyn_cast<ArrayType>(T)) {
4939 for (size_t i = 0, E = AT->getNumElements(); i < E; i++) {
4940 todo.push_front(B.CreateExtractValue(cur, E - 1 - i));
4941 }
4942 continue;
4943 }
4944 if (auto VT = dyn_cast<VectorType>(T)) {
4945 assert(!VT->getElementCount().isScalable());
4946 size_t numElems = VT->getElementCount().getKnownMinValue();
4947 for (size_t i = 0; i < numElems; i++) {
4948 todo.push_front(B.CreateExtractElement(cur, numElems - 1 - i));
4949 }
4950 continue;
4951 }
4952 llvm_unreachable("unknown source of julia type");
4953 }
4954 return done;
4955}
static bool contains(ArrayRef< int > ar, int v)
static bool isDeallocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory deallocation function For updating below one ...
static bool isAllocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory allocation function For updating below one sh...
static bool isAllocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI)
static Operation * getFunctionFromCall(CallOpInterface iface)
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
static TypeTree parseTBAA(TBAAStructTypeNode AccessType, llvm::Instruction &I, const llvm::DataLayout &DL, std::shared_ptr< llvm::ModuleSlotTracker > MST)
Given a TBAA access node return the corresponding TypeTree This includes recursively parsing the acce...
Definition TBAA.h:439
static bool isMemFreeLibMFunction(llvm::StringRef str, llvm::Intrinsic::ID *ID=nullptr)
llvm::Value * get_cached_mat_width(llvm::IRBuilder<> &B, llvm::ArrayRef< llvm::Value * > trans, llvm::Value *arg_ld, llvm::Value *dim1, llvm::Value *dim2, bool cacheMat, bool byRef, bool cublas)
Definition Utils.cpp:4018
void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty, llvm::SmallVectorImpl< llvm::Value * > &cacheValues, llvm::IRBuilder<> &BuilderZ, const Twine &name)
Definition Utils.cpp:3738
llvm::Optional< bool > arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA, llvm::LoopInfo &LI, llvm::Value *op0, llvm::Value *op1, bool offsetAllowed)
Definition Utils.cpp:4618
Function * getOrInsertDifferentialFloatMemmove(Module &M, Type *T, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth)
Definition Utils.cpp:2073
llvm::Value * load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType, llvm::Value *V, bool byRef)
Definition Utils.cpp:4075
CallInst * CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree)
Definition Utils.cpp:742
llvm::Value * CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev, llvm::Type *T, llvm::Value *OuterCount, llvm::Value *InnerCount, const llvm::Twine &Name, llvm::CallInst **caller, bool ZeroMem)
Definition Utils.cpp:590
AllocaInst * getBaseAndOffset(Value *ptr, size_t &offset)
Definition Utils.cpp:3167
Value * lookup_with_layout(IRBuilder<> &B, Type *fpType, Value *layout, Value *const base, Value *lda, Value *row, Value *col)
Definition Utils.cpp:1135
llvm::cl::opt< bool > EnzymeZeroCache
llvm::Value * to_blas_fp_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, Type *fpTy, IRBuilder<> &entryBuilder, llvm::Twine const &name)
Definition Utils.cpp:3784
void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, bool isTape)
Definition Utils.cpp:414
llvm::SmallVector< llvm::Instruction *, 2 > PostCacheStore(llvm::StoreInst *SI, llvm::IRBuilder<> &B)
Definition Utils.cpp:423
Function * getFirstFunctionDefinition(Module &M)
Definition Utils.cpp:3551
llvm::cl::opt< bool > EnzymeFastMath("enzyme-fast-math", cl::init(true), cl::Hidden, cl::desc("Use fast math on derivative compuation"))
llvm::Function * getOrInsertDifferentialMPI_Wait(llvm::Module &M, ArrayRef< llvm::Type * > T, Type *reqType, StringRef caller)
Definition Utils.cpp:2271
llvm::PointerType * getDefaultAnonymousTapeType(llvm::LLVMContext &C)
Definition Utils.cpp:437
llvm::Value * getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr, llvm::Type *OpType, ConcreteType CT, llvm::Type *intType, IRBuilder<> &B2)
Definition Utils.cpp:2378
llvm::Function * getOrInsertDifferentialWaitallSave(llvm::Module &M, ArrayRef< llvm::Type * > T, PointerType *reqType)
Definition Utils.cpp:2203
Value * simplifyLoad(Value *V, size_t valSz, size_t preOffset)
Definition Utils.cpp:3386
bool attributeKnownFunctions(llvm::Function &F)
Definition Utils.cpp:114
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &B, llvm::Intrinsic::ID ID, llvm::Type *RetTy, llvm::ArrayRef< llvm::Value * > Args, llvm::Instruction *FMFSource, const llvm::Twine &Name)
Definition Utils.cpp:4233
llvm::Value * get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res)
Definition Utils.cpp:4272
void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef< llvm::Value * > args, llvm::ArrayRef< llvm::OperandBundleDef > bundles)
Create function for type that performs memcpy using lapack copy.
Definition Utils.cpp:1336
LLVMValueRef(* EnzymeUndefinedValueForType)(LLVMModuleRef, LLVMTypeRef, uint8_t)
Definition Utils.cpp:77
Function * getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth)
Create function for type that is equivalent to memcpy but adds to destination rather than a direct co...
Definition Utils.cpp:1009
llvm::Value * nextPowerOfTwo(llvm::IRBuilder<> &B, llvm::Value *V)
Create function to compute nearest power of two.
Definition Utils.cpp:2192
Function * getOrInsertMemcpyMat(Module &Mod, Type *elementType, PointerType *PT, IntegerType *IT, unsigned dstalign, unsigned srcalign)
Definition Utils.cpp:1793
llvm::Value * transpose(std::string floatType, IRBuilder<> &B, llvm::Value *V, bool cublas)
Definition Utils.cpp:3947
void EmitNoTypeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &Builder2)
Definition Utils.cpp:4377
Value * is_lower(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas)
Definition Utils.cpp:3800
LLVMValueRef(* CustomDeallocator)(LLVMBuilderRef, LLVMValueRef)
Definition Utils.cpp:71
llvm::cl::opt< bool > EnzymeCheckDerivativeNaN("enzyme-check-nan", cl::init(false), cl::Hidden, cl::desc("Add NaN checks to all derivative intermediate values"))
llvm::Optional< BlasInfo > extractBLAS(llvm::StringRef in)
Definition Utils.cpp:3563
Value * CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count, const Twine &Name, CallInst **caller, Instruction **ZeroMem, bool isDefault)
Definition Utils.cpp:619
LLVMValueRef(* CustomAllocator)(LLVMBuilderRef, LLVMTypeRef, LLVMValueRef, LLVMValueRef, uint8_t, LLVMValueRef *)
Definition Utils.cpp:65
void dumpBlock(llvm::BasicBlock *blk)
Definition Utils.cpp:4477
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
Definition Utils.cpp:62
Function * getOrInsertExponentialAllocator(Module &M, Function *newFunc, bool ZeroInit, llvm::Type *RT)
Definition Utils.cpp:443
bool notCaptured(llvm::Value *V)
Check whether the value is captured.
Definition Utils.cpp:4608
Value * GetFunctionValFromValue(Value *fn)
Definition Utils.cpp:3466
bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT, llvm::Instruction *maybeReader, llvm::Instruction *maybeWriter, llvm::Loop *scope)
Definition Utils.cpp:2765
bool isNVLoad(const llvm::Value *V)
Definition Utils.cpp:4483
llvm::Value * is_left(IRBuilder<> &B, llvm::Value *side, bool byRef, bool cublas)
Definition Utils.cpp:3904
llvm::Value * EmitNoDerivativeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &Builder2, llvm::Value *condition)
Definition Utils.cpp:4295
Constant * getString(Module &M, StringRef Str)
Definition Utils.cpp:825
llvm::cl::opt< bool > EnzymeMemmoveWarning("enzyme-memmove-warning", cl::init(true), cl::Hidden, cl::desc("Warn if using memmove implementation as a fallback for memmove"))
void dumpTypeResults(TypeResults &TR)
Definition Utils.cpp:4481
bool writesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::Instruction *maybeReader, llvm::Instruction *maybeWriter)
Return whether maybeReader can read from memory written to by maybeWriter.
Definition Utils.cpp:2880
Function * getOrInsertDifferentialFloatMemcpyMat(Module &Mod, Type *elementType, PointerType *PT, IntegerType *IT, IntegerType *CT, unsigned dstalign, unsigned srcalign, bool zeroSrc)
Definition Utils.cpp:1913
SmallVector< std::pair< Value *, size_t >, 1 > getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, size_t valSz, bool &legal)
Definition Utils.cpp:3255
void dumpType(llvm::Type *ty)
Definition Utils.cpp:4479
void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef< llvm::Value * > args, llvm::Type *copy_retty, llvm::ArrayRef< llvm::OperandBundleDef > bundles)
Create function for type that performs memcpy with a stride using blas copy.
Definition Utils.cpp:1316
void(* CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef, LLVMValueRef)
Definition Utils.cpp:72
LLVMValueRef *(* EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef, uint64_t *size)
Definition Utils.cpp:74
llvm::Value * moveSRetToFromRoots(llvm::IRBuilder<> &B, llvm::Type *jltype, llvm::Value *sret, llvm::Type *root_ty, llvm::Value *rootRet, size_t rootOffset, SRetRootMovement direction)
Definition Utils.cpp:4714
void copy_lower_to_upper(llvm::IRBuilder<> &B, llvm::Type *fpType, BlasInfo blas, bool byRef, llvm::Value *layout, llvm::Value *islower, llvm::Value *A, llvm::Value *lda, llvm::Value *N)
Definition Utils.cpp:1184
void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Value *shadow, const char *Message, llvm::DebugLoc &&loc, llvm::Instruction *orig)
Definition Utils.cpp:902
SmallVector< llvm::Value *, 1 > get_blas_row(llvm::IRBuilder<> &B, ArrayRef< llvm::Value * > transA, bool byRef, bool cublas)
Definition Utils.cpp:4089
Value * is_nonunit(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas)
Definition Utils.cpp:3834
Function * GetFunctionFromValue(Value *fn)
Definition Utils.cpp:3547
llvm::SmallVector< llvm::Value *, 1 > getJuliaObjects(llvm::Value *v, llvm::IRBuilder<> &B)
Definition Utils.cpp:4913
void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas, IntegerType *IT, Type *BlasCT, Type *BlasFPT, Type *BlasPT, Type *BlasIT, Type *fpTy, ArrayRef< Value * > args, ArrayRef< OperandBundleDef > bundles, bool byRef, bool julia_decl)
Definition Utils.cpp:1354
bool overwritesToMemoryReadByLoop(llvm::ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT, llvm::Instruction *maybeReader, const llvm::SCEV *LoadStart, const llvm::SCEV *LoadEnd, llvm::Instruction *maybeWriter, const llvm::SCEV *StoreStart, const llvm::SCEV *StoreEnd, llvm::Loop *scope)
Definition Utils.cpp:2574
void dumpValue(llvm::Value *val)
Definition Utils.cpp:4475
LLVMTypeRef(* EnzymeDefaultTapeType)(LLVMContextRef)
Definition Utils.cpp:76
void emit_backtrace(llvm::Instruction *inst, llvm::raw_ostream &ss)
Definition Utils.cpp:835
void mayExecuteAfter(llvm::SmallVectorImpl< llvm::Instruction * > &results, llvm::Instruction *inst, const llvm::SmallPtrSetImpl< Instruction * > &stores, const llvm::Loop *region)
Definition Utils.cpp:2513
llvm::Value * to_blas_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, bool cublas, IntegerType *julia_decl, IRBuilder<> &entryBuilder, llvm::Twine const &name)
Definition Utils.cpp:3767
LLVMValueRef(* EnzymeSanitizeDerivatives)(LLVMValueRef, LLVMValueRef toset, LLVMBuilderRef, LLVMValueRef)
Definition Utils.cpp:80
llvm::cl::opt< bool > EnzymeRuntimeError("enzyme-runtime-error", cl::init(false), cl::Hidden, cl::desc("Emit Runtime errors instead of compile time ones"))
llvm::CallInst * getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, IntegerType *IT, Type *BlasPT, Type *BlasIT, Type *fpTy, llvm::ArrayRef< llvm::Value * > args, const llvm::ArrayRef< llvm::OperandBundleDef > bundles, bool byRef, bool cublas, bool julia_decl)
Definition Utils.cpp:1542
SmallVector< std::tuple< Instruction *, Value *, size_t >, 1 > findAllUsersOf(Value *AI)
Definition Utils.cpp:3210
Function * getOrInsertMemcpyStrided(Module &M, Type *elementType, PointerType *T, Type *IT, unsigned dstalign, unsigned srcalign)
Definition Utils.cpp:1691
static Value * constantInBoundsGEPHelper(llvm::IRBuilder<> &B, llvm::Type *type, llvm::Value *value, ArrayRef< unsigned > path)
Definition Utils.cpp:4703
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
Definition Utils.cpp:3731
void copyNonJLValueInto(llvm::IRBuilder<> &B, llvm::Type *curType, llvm::Type *dstType, llvm::Value *dst, llvm::ArrayRef< unsigned > dstPrefix0, llvm::Type *srcType, llvm::Value *src, llvm::ArrayRef< unsigned > srcPrefix0, bool shouldZero)
Definition Utils.cpp:4837
bool notCapturedBefore(llvm::Value *V, Instruction *inst, size_t checkLoadCaptures)
Definition Utils.cpp:4503
static std::string tofltstr(Type *T)
Convert a floating type to a string.
Definition Utils.cpp:796
void dumpModule(llvm::Module *mod)
Definition Utils.cpp:4473
void(* CustomZero)(LLVMBuilderRef, LLVMTypeRef, LLVMValueRef, uint8_t)
Definition Utils.cpp:69
llvm::Constant * getUndefinedValueForType(llvm::Module &M, llvm::Type *T, bool forceZero)
Definition Utils.cpp:3623
Function * getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty, unsigned width)
Definition Utils.cpp:2084
std::vector< std::tuple< llvm::Type *, size_t, size_t > > parseTrueType(const llvm::MDNode *md, DerivativeMode Mode, bool const_src)
Definition Utils.cpp:4411
llvm::Value * is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef, bool cublas)
Definition Utils.cpp:3868
llvm::Value * SanitizeDerivatives(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM, llvm::Value *mask)
Definition Utils.cpp:3634
bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, MapVector< Value *, APInt > &VariableOffsets, APInt &ConstantOffset)
Definition Utils.cpp:4169
static llvm::StringRef getFuncName(llvm::Function *called)
Definition Utils.h:1260
static std::string getRenamedPerCallingConv(llvm::StringRef caller, llvm::StringRef callee)
Definition Utils.h:2413
llvm::cl::opt< bool > EnzymeBlasCopy
static llvm::Loop * getAncestor(llvm::Loop *R1, llvm::Loop *R2)
Definition Utils.h:1053
static bool anyJuliaObjects(llvm::Type *T)
Definition Utils.h:2524
@ Args
Return is a struct of all args.
static bool isNoAlias(const llvm::CallBase *call)
Definition Utils.h:1858
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
static bool isCertainPrint(const llvm::StringRef name)
Definition Utils.h:729
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
Definition Utils.h:1174
static llvm::Function * getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id, llvm::ArrayRef< llvm::Type * > Tys={})
Definition Utils.h:2263
static bool isNoCapture(const llvm::CallBase *call, size_t idx)
Definition Utils.h:1840
static void addFunctionNoCapture(llvm::Function *call, size_t idx)
Definition Utils.h:2299
static bool isDebugFunction(llvm::Function *called)
Definition Utils.h:690
static bool isPointerArithmeticInst(const llvm::Value *V, bool includephi=true, bool includebin=true)
Definition Utils.h:1456
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
Definition Utils.h:1507
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
Definition Utils.h:339
static bool isSpecialPtr(llvm::Type *Ty)
Definition Utils.h:2354
@ Tracked
Definition Utils.h:2341
static void allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1, llvm::Instruction *inst2, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen between inst1 and inst2 If the function returns ...
Definition Utils.h:1099
ErrorType
Definition Utils.h:77
static std::tuple< llvm::StringRef, llvm::StringRef, llvm::StringRef > tripleSplitDollar(llvm::StringRef caller)
Definition Utils.h:2404
SRetRootMovement
Definition Utils.h:2505
static llvm::Value * CreateSelect(llvm::IRBuilder<> &Builder2, llvm::Value *cmp, llvm::Value *tval, llvm::Value *fval, const llvm::Twine &Name="")
Definition Utils.h:2005
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
Definition Utils.h:1269
llvm::cl::opt< bool > EnzymeLapackCopy
llvm::cl::opt< bool > EnzymeNonPower2Cache
DerivativeMode
Definition Utils.h:390
Concrete SubType of a given value.
bool isKnown() const
Whether this ConcreteType has information (is not unknown)
std::string str() const
Convert the ConcreteType to a string.
llvm::Type * isFloat() const
Return the floating point type, if this is a float.
EnzymeFailure(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion)
Definition Utils.cpp:785
EnzymeWarning(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion)
Definition Utils.cpp:775
static llvm::Value * extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, unsigned off, const llvm::Twine &name="")
Helper routine to extract a nested element from a struct/array. This is.
TypeResults TR
A holder class representing the results of running TypeAnalysis on a given function.
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
TypeAnalyzer * analyzer
TypeTree query(llvm::Value *val) const
The TypeTree of a particular Value.
TypeTree Lookup(size_t len, const llvm::DataLayout &dl) const
Select all submappings whose first index is in range [0, len) and remove the first index.
Definition TypeTree.h:593
llvm::Type * fpType(llvm::LLVMContext &ctx, bool to_scalar=false) const
Definition Utils.cpp:981
bool is64
Definition Utils.h:749
std::string suffix
Definition Utils.h:747
std::string prefix
Definition Utils.h:746
std::string floatType
Definition Utils.h:745
llvm::IntegerType * intType(llvm::LLVMContext &ctx) const
Definition Utils.cpp:1000
CountTrackedPointers(llvm::Type *T)
Definition Utils.cpp:4136
llvm::Instruction * req
llvm::IRBuilder * ip