Enzyme main
Loading...
Searching...
No Matches
PreserveNVVM.cpp
Go to the documentation of this file.
1//===- PreserveNVVM.cpp - Mark NVVM attributes for preservation. -------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains createPreserveNVVMPass, a transformation pass that marks
22// __nv_* functions as noinline and records the LLVM intrinsic each one
23// implements.
24//
25//===----------------------------------------------------------------------===//
26#include <llvm/Config/llvm-config.h>
27
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/SetVector.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/ADT/StringMap.h"
32
33#include "llvm/ADT/SmallSet.h"
34#include "llvm/IR/Constants.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/GlobalVariable.h"
37#include "llvm/IR/Module.h"
38#include "llvm/Support/raw_ostream.h"
39
40#include "llvm/Support/TimeProfiler.h"
41
42#include "llvm/Pass.h"
43
44#include "llvm/Transforms/Utils.h"
45
46#include <map>
47
48#include "PreserveNVVM.h"
49#include "Utils.h"
50
51using namespace llvm;
52#ifdef DEBUG_TYPE
53#undef DEBUG_TYPE
54#endif
55#define DEBUG_TYPE "preserve-nvvm"
56
57#if LLVM_VERSION_MAJOR >= 14
58#define addAttribute addAttributeAtIndex
59#endif
60
61#ifndef ENZYME_ENABLE_NVVM_ATTRIBUTION
62#define ENZYME_ENABLE_NVVM_ATTRIBUTION 1
63#endif
64
65//! Returns whether changed.
66bool preserveLinkage(bool Begin, Function &F, bool Inlining = true) {
67 if (Begin && !F.hasFnAttribute("prev_fixup")) {
68 F.addFnAttr("prev_fixup");
69 if (F.hasFnAttribute(Attribute::AlwaysInline))
70 F.addFnAttr("prev_always_inline");
71 if (F.hasFnAttribute(Attribute::NoInline))
72 F.addFnAttr("prev_no_inline");
73 if (Inlining) {
74 F.removeFnAttr(Attribute::AlwaysInline);
75 F.addFnAttr(Attribute::NoInline);
76 }
77 F.addFnAttr("prev_linkage", std::to_string(F.getLinkage()));
78 F.setLinkage(Function::LinkageTypes::ExternalLinkage);
79 return true;
80 }
81 return false;
82}
83
84// Return true if the module has a triple indicating an nvptx target, false
85// otherwise.
86bool isTargetNVPTX(llvm::Module &M) {
87#if LLVM_VERSION_MAJOR > 20
88 return M.getTargetTriple().getArch() == Triple::ArchType::nvptx ||
89 M.getTargetTriple().getArch() == Triple::ArchType::nvptx64;
90#else
91 return M.getTargetTriple().find("nvptx") != std::string::npos;
92#endif
93}
94
95template <const char *handlername, DerivativeMode Mode, int numargs>
96static void
97handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g,
98 SmallVectorImpl<GlobalVariable *> &globalsToErase) {
99 if (g.hasInitializer()) {
100 if (auto CA = dyn_cast<ConstantAggregate>(g.getInitializer())) {
101 if (CA->getNumOperands() < numargs) {
102 llvm::errs() << M << "\n";
103 llvm::errs() << "Use of " << handlername
104 << " must be a "
105 "constant of size at least "
106 << numargs << " " << g << "\n";
107 llvm_unreachable(handlername);
108 } else {
109 Function *Fs[numargs];
110 for (size_t i = 0; i < numargs; i++) {
111 Value *V = CA->getOperand(i);
112 while (auto CE = dyn_cast<ConstantExpr>(V)) {
113 V = CE->getOperand(0);
114 }
115 if (auto CA = dyn_cast<ConstantAggregate>(V))
116 V = CA->getOperand(0);
117 while (auto CE = dyn_cast<ConstantExpr>(V)) {
118 V = CE->getOperand(0);
119 }
120 if (auto F = dyn_cast<Function>(V)) {
121 Fs[i] = F;
122 } else {
123 llvm::errs() << M << "\n";
124 llvm::errs() << "Param of " << handlername
125 << " must be a "
126 "function"
127 << g << "\n"
128 << *V << "\n";
129 llvm_unreachable(handlername);
130 }
131 }
132
133 SmallSet<size_t, 1> byref;
134
135 if constexpr (Mode == DerivativeMode::ReverseModeGradient) {
136 assert(numargs >= 3);
137 for (size_t i = numargs; i < CA->getNumOperands(); i++) {
138 Value *V = CA->getOperand(i);
139 while (auto CE = dyn_cast<ConstantExpr>(V)) {
140 V = CE->getOperand(0);
141 }
142 if (auto CA = dyn_cast<ConstantAggregate>(V))
143 V = CA->getOperand(0);
144 while (auto CE = dyn_cast<ConstantExpr>(V)) {
145 V = CE->getOperand(0);
146 }
147 if (auto GV = dyn_cast<GlobalVariable>(V)) {
148 if (GV->isConstant())
149 if (auto C = GV->getInitializer())
150 if (auto CA = dyn_cast<ConstantDataArray>(C))
151 if (CA->getType()->getElementType()->isIntegerTy(8) &&
152 CA->isCString()) {
153
154 auto str = CA->getAsCString();
155 bool legal = startsWith(str, "byref_");
156 size_t argnum = 0;
157 if (legal) {
158 for (size_t i = str.size() - 1, len = strlen("byref_");
159 i >= len; i--) {
160 char c = str[i];
161 if (c < '0' || c > '9') {
162 legal = false;
163 break;
164 }
165 argnum *= 10;
166 argnum += c - '0';
167 }
168 }
169 if (legal) {
170 byref.insert(argnum);
171 continue;
172 }
173 }
174 }
175 llvm::errs() << M << "\n";
176 llvm::errs() << "Use of " << handlername
177 << " possible post args include 'byref_ret'"
178 << "\n";
179 llvm_unreachable(handlername);
180 }
181
182 if (byref.size())
183 for (size_t fn = 1; fn <= 2; fn++) {
184 Function *F = Fs[fn];
185 bool need = false;
186 size_t nonSRetSize = 0;
187 for (size_t i = 0; i < F->arg_size(); i++)
188 if (!F->hasParamAttribute(i, Attribute::StructRet))
189 nonSRetSize++;
190 for (auto r : byref)
191 if (r < nonSRetSize)
192 need = true;
193 if (!need)
194 continue;
195
196 SmallVector<Type *, 3> args;
197 Type *sretTy = nullptr;
198 size_t realidx = 0;
199 size_t i = 0;
200 for (auto &arg : F->args()) {
201 if (!F->hasParamAttribute(i, Attribute::StructRet)) {
202 if (!byref.count(realidx))
203 args.push_back(arg.getType());
204 else {
205 // TODO in opaque pointers
206 Type *subTy = nullptr;
207#if LLVM_VERSION_MAJOR < 17
208 subTy = arg.getType()->getPointerElementType();
209#endif
210 assert(subTy);
211 args.push_back(subTy);
212 }
213 realidx++;
214 } else {
215 llvm::Type *T = nullptr;
216#if LLVM_VERSION_MAJOR > 12
217 T = F->getParamAttribute(i, Attribute::StructRet)
218 .getValueAsType();
219#else
220 T = arg.getType()->getPointerElementType();
221#endif
222 sretTy = T;
223 }
224 i++;
225 }
226 Type *RT = F->getReturnType();
227 if (sretTy) {
228 assert(RT->isVoidTy());
229 RT = sretTy;
230 }
231 FunctionType *FTy =
232 FunctionType::get(RT, args, F->getFunctionType()->isVarArg());
233 Function *NewF =
234 Function::Create(FTy, Function::LinkageTypes::InternalLinkage,
235 "fixbyval_" + F->getName(), F->getParent());
236
237 AllocaInst *AI = nullptr;
238 BasicBlock *BB =
239 BasicBlock::Create(NewF->getContext(), "entry", NewF);
240 IRBuilder<> bb(BB);
241 if (sretTy)
242 AI = bb.CreateAlloca(sretTy);
243 SmallVector<Value *, 3> argVs;
244 auto arg = NewF->arg_begin();
245 realidx = 0;
246 for (size_t i = 0; i < F->arg_size(); i++) {
247 if (!F->hasParamAttribute(i, Attribute::StructRet)) {
248 arg->setName("arg" + Twine(realidx));
249 if (!byref.count(realidx))
250 argVs.push_back(arg);
251 else {
252 auto A = bb.CreateAlloca(arg->getType());
253 bb.CreateStore(arg, A);
254 argVs.push_back(A);
255 }
256 realidx++;
257 ++arg;
258 } else {
259 argVs.push_back(AI);
260 }
261 }
262 auto cal = bb.CreateCall(F, argVs);
263 cal->setCallingConv(F->getCallingConv());
264
265 if (sretTy) {
266 Value *res = bb.CreateLoad(sretTy, AI);
267 bb.CreateRet(res);
268 } else if (!RT->isVoidTy()) {
269 bb.CreateRet(cal);
270 } else
271 bb.CreateRetVoid();
272
273 Fs[fn] = NewF;
274 }
275
276 preserveLinkage(true, *Fs[1], false);
277 Fs[0]->setMetadata(
278 "enzyme_augment",
279 llvm::MDTuple::get(Fs[0]->getContext(),
280 {llvm::ValueAsMetadata::get(Fs[1])}));
281 preserveLinkage(true, *Fs[2], false);
282 Fs[0]->setMetadata(
283 "enzyme_gradient",
284 llvm::MDTuple::get(Fs[0]->getContext(),
285 {llvm::ValueAsMetadata::get(Fs[2])}));
286 } else if (Mode == DerivativeMode::ForwardMode) {
287 assert(numargs == 2);
288 preserveLinkage(true, *Fs[1], false);
289 Fs[0]->setMetadata(
290 "enzyme_derivative",
291 llvm::MDTuple::get(Fs[0]->getContext(),
292 {llvm::ValueAsMetadata::get(Fs[1])}));
293 } else if (Mode == DerivativeMode::ForwardModeSplit) {
294 assert(numargs == 3);
295 preserveLinkage(true, *Fs[1], false);
296 Fs[0]->setMetadata(
297 "enzyme_augment",
298 llvm::MDTuple::get(Fs[0]->getContext(),
299 {llvm::ValueAsMetadata::get(Fs[1])}));
300 preserveLinkage(true, *Fs[2], false);
301 Fs[0]->setMetadata(
302 "enzyme_splitderivative",
303 llvm::MDTuple::get(Fs[0]->getContext(),
304 {llvm::ValueAsMetadata::get(Fs[2])}));
305 } else
306 assert("Unknown mode");
307 }
308 } else if (isTargetNVPTX(M)) {
309 llvm::errs() << M << "\n";
310 llvm::errs() << "Use of " << handlername
311 << " must be a "
312 "constant aggregate "
313 << g << "\n";
314 llvm_unreachable(handlername);
315 }
316 } else {
317 llvm::errs() << M << "\n";
318 llvm::errs() << "Use of " << handlername
319 << " must be a "
320 "constant array of size "
321 << numargs << " " << g << "\n";
322 llvm_unreachable(handlername);
323 }
324 globalsToErase.push_back(&g);
325}
326
327bool preserveNVVM(bool Begin, Module &M) {
328 bool changed = false;
329 constexpr static const char gradient_handler_name[] =
330 "__enzyme_register_gradient";
331 constexpr static const char derivative_handler_name[] =
332 "__enzyme_register_derivative";
333 constexpr static const char splitderivative_handler_name[] =
334 "__enzyme_register_splitderivative";
335
336 if (Begin)
337 if (GlobalVariable *GA = M.getGlobalVariable("llvm.global.annotations")) {
338 if (GA->hasInitializer()) {
339 auto AOp = GA->getInitializer();
340 // all metadata are stored in an array of struct of metadata
341 if (ConstantArray *CA = dyn_cast<ConstantArray>(AOp)) {
342 // so iterate over the operands
343 SmallVector<Constant *, 1> replacements;
344 for (Value *CAOp : CA->operands()) {
345 // get the struct, which holds a pointer to the annotated function
346 // as first field, and the annotation as second field
347 ConstantStruct *CS = dyn_cast<ConstantStruct>(CAOp);
348 if (!CS)
349 continue;
350
351 if (CS->getNumOperands() < 2)
352 continue;
353
354 // the second field is a pointer to a global constant Array that
355 // holds the string
356 GlobalVariable *GAnn =
357 dyn_cast<GlobalVariable>(CS->getOperand(1)->getOperand(0));
358
359 ConstantDataArray *A = nullptr;
360
361 if (GAnn)
362 A = dyn_cast<ConstantDataArray>(GAnn->getOperand(0));
363 else
364 A = dyn_cast<ConstantDataArray>(CS->getOperand(1)->getOperand(0));
365
366 if (!A)
367 continue;
368
369 // we have the annotation! Check it's an epona annotation
370 // and process
371 StringRef AS = A->getAsCString();
372
373 Constant *Val = cast<Constant>(CS->getOperand(0));
374 while (auto CE = dyn_cast<ConstantExpr>(Val))
375 Val = CE->getOperand(0);
376
377 Function *Func = dyn_cast<Function>(Val);
378 GlobalVariable *Glob = dyn_cast<GlobalVariable>(Val);
379
380 if (AS == "enzyme_inactive" && Func) {
381 Func->addAttribute(
382 AttributeList::FunctionIndex,
383 Attribute::get(Func->getContext(), "enzyme_inactive"));
384 changed = true;
385 preserveLinkage(Begin, *Func);
386 replacements.push_back(Constant::getNullValue(CAOp->getType()));
387 continue;
388 }
389
390 if (AS == "enzyme_elementwise_read" && Func) {
391 Func->addAttribute(AttributeList::FunctionIndex,
392 Attribute::get(Func->getContext(),
393 "enzyme_elementwise_read"));
394 changed = true;
395 replacements.push_back(Constant::getNullValue(CAOp->getType()));
396 continue;
397 }
398
399 if (AS == "enzyme_shouldrecompute" && Func) {
400 Func->addAttribute(
401 AttributeList::FunctionIndex,
402 Attribute::get(Func->getContext(), "enzyme_shouldrecompute"));
403 changed = true;
404 replacements.push_back(Constant::getNullValue(CAOp->getType()));
405 continue;
406 }
407
408 if (AS == "enzyme_inactive" && Glob) {
409 Glob->setMetadata("enzyme_inactive",
410 MDNode::get(Glob->getContext(), {}));
411 changed = true;
412 replacements.push_back(Constant::getNullValue(CAOp->getType()));
413 continue;
414 }
415
416 if (AS == "enzyme_nofree" && Func) {
417 Func->addAttribute(
418 AttributeList::FunctionIndex,
419 Attribute::get(Func->getContext(), Attribute::NoFree));
420 changed = true;
421 preserveLinkage(Begin, *Func);
422 replacements.push_back(Constant::getNullValue(CAOp->getType()));
423 continue;
424 }
425
426 if (startsWith(AS, "enzyme_function_like") && Func) {
427 auto val = AS.substr(1 + AS.find('='));
428 Func->addAttribute(
429 AttributeList::FunctionIndex,
430 Attribute::get(Func->getContext(), "enzyme_math", val));
431 changed = true;
432 preserveLinkage(Begin, *Func);
433 replacements.push_back(Constant::getNullValue(CAOp->getType()));
434 continue;
435 }
436
437 if (AS == "enzyme_sparse_accumulate" && Func) {
438 Func->addAttribute(AttributeList::FunctionIndex,
439 Attribute::get(Func->getContext(),
440 "enzyme_sparse_accumulate"));
441 changed = true;
442 preserveLinkage(Begin, *Func);
443 replacements.push_back(Constant::getNullValue(CAOp->getType()));
444 continue;
445 }
446 replacements.push_back(cast<Constant>(CAOp));
447 }
448 GA->setInitializer(ConstantArray::get(CA->getType(), replacements));
449 }
450 }
451 }
452
453 for (GlobalVariable &g : M.globals()) {
454 if (g.getName().contains(gradient_handler_name) ||
455 g.getName().contains(derivative_handler_name) ||
456 g.getName().contains(splitderivative_handler_name) ||
457 g.getName().contains("__enzyme_nofree") ||
458 g.getName().contains("__enzyme_inactivefn") ||
459 g.getName().contains("__enzyme_sparse_accumulate") ||
460 g.getName().contains("__enzyme_function_like") ||
461 g.getName().contains("__enzyme_allocation_like")) {
462 if (g.hasInitializer()) {
463 Value *V = g.getInitializer();
464 while (1) {
465 if (auto CE = dyn_cast<ConstantExpr>(V)) {
466 V = CE->getOperand(0);
467 continue;
468 }
469 if (auto CA = dyn_cast<ConstantAggregate>(V)) {
470 V = CA->getOperand(0);
471 continue;
472 }
473 break;
474 }
475 if (auto F = dyn_cast<Function>(V))
476 changed |= preserveLinkage(Begin, *F);
477 }
478 }
479 }
480 SmallVector<GlobalVariable *, 1> toErase;
481 for (GlobalVariable &g : M.globals()) {
482 if (g.getName().contains(gradient_handler_name)) {
483 handleCustomDerivative<gradient_handler_name,
485 toErase);
486 changed = true;
487 } else if (g.getName().contains(derivative_handler_name)) {
488 handleCustomDerivative<derivative_handler_name,
489 DerivativeMode::ForwardMode, 2>(M, g, toErase);
490 changed = true;
491 } else if (g.getName().contains(splitderivative_handler_name)) {
492 handleCustomDerivative<splitderivative_handler_name,
494 toErase);
495 changed = true;
496 }
497 if (g.getName().contains("__enzyme_inactive_global")) {
498 if (g.hasInitializer()) {
499 Value *V = g.getInitializer();
500 while (1) {
501 if (auto CE = dyn_cast<ConstantExpr>(V)) {
502 V = CE->getOperand(0);
503 continue;
504 }
505 if (auto CA = dyn_cast<ConstantAggregate>(V)) {
506 V = CA->getOperand(0);
507 continue;
508 }
509 break;
510 }
511 if (auto GV = cast<GlobalVariable>(V)) {
512 GV->setMetadata("enzyme_inactive", MDNode::get(g.getContext(), {}));
513 toErase.push_back(&g);
514 changed = true;
515 } else {
516 llvm::errs() << "Param of __enzyme_inactive_global must be a "
517 "global variable"
518 << g << "\n"
519 << *V << "\n";
520 llvm_unreachable("__enzyme_inactive_global");
521 }
522 }
523 }
524 if (g.getName().contains("__enzyme_inactivefn")) {
525 if (g.hasInitializer()) {
526 Value *V = g.getInitializer();
527 while (1) {
528 if (auto CE = dyn_cast<ConstantExpr>(V)) {
529 V = CE->getOperand(0);
530 continue;
531 }
532 if (auto CA = dyn_cast<ConstantAggregate>(V)) {
533 V = CA->getOperand(0);
534 continue;
535 }
536 break;
537 }
538 if (auto F = cast<Function>(V)) {
539 F->addAttribute(AttributeList::FunctionIndex,
540 Attribute::get(g.getContext(), "enzyme_inactive"));
541 toErase.push_back(&g);
542 changed = true;
543 } else {
544 llvm::errs() << "Param of __enzyme_inactivefn must be a "
545 "constant function"
546 << g << "\n"
547 << *V << "\n";
548 llvm_unreachable("__enzyme_inactivefn");
549 }
550 }
551 }
552 if (g.getName().contains("__enzyme_sparse_accumulate")) {
553 if (g.hasInitializer()) {
554 Value *V = g.getInitializer();
555 while (1) {
556 if (auto CE = dyn_cast<ConstantExpr>(V)) {
557 V = CE->getOperand(0);
558 continue;
559 }
560 if (auto CA = dyn_cast<ConstantAggregate>(V)) {
561 V = CA->getOperand(0);
562 continue;
563 }
564 break;
565 }
566 if (auto F = cast<Function>(V)) {
567 F->addAttribute(
568 AttributeList::FunctionIndex,
569 Attribute::get(g.getContext(), "enzyme_sparse_accumulate"));
570 toErase.push_back(&g);
571 changed = true;
572 } else {
573 llvm::errs() << "Param of __enzyme_sparse_accumulate must be a "
574 "constant function"
575 << g << "\n"
576 << *V << "\n";
577 llvm_unreachable("__enzyme_sparse_accumulate");
578 }
579 }
580 }
581 if (g.getName().contains("__enzyme_nofree")) {
582 if (g.hasInitializer()) {
583 Value *V = g.getInitializer();
584 while (1) {
585 if (auto CE = dyn_cast<ConstantExpr>(V)) {
586 V = CE->getOperand(0);
587 continue;
588 }
589 if (auto CA = dyn_cast<ConstantAggregate>(V)) {
590 V = CA->getOperand(0);
591 continue;
592 }
593 break;
594 }
595 if (auto F = cast<Function>(V)) {
596 F->addAttribute(AttributeList::FunctionIndex,
597 Attribute::get(g.getContext(), Attribute::NoFree));
598 toErase.push_back(&g);
599 changed = true;
600 } else {
601 llvm::errs() << "Param of __enzyme_nofree must be a "
602 "constant function"
603 << g << "\n"
604 << *V << "\n";
605 llvm_unreachable("__enzyme_nofree");
606 }
607 }
608 }
609 if (g.getName().contains("__enzyme_function_like")) {
610 if (g.hasInitializer()) {
611 auto CA = dyn_cast<ConstantAggregate>(g.getInitializer());
612 if (!CA || CA->getNumOperands() < 2) {
613 llvm::errs() << "Use of "
614 << "enzyme_function_like"
615 << " must be a "
616 "constant of size at least "
617 << 2 << " " << g << "\n";
618 llvm_unreachable("enzyme_function_like");
619 }
620 Value *V = CA->getOperand(0);
621 Value *name = CA->getOperand(1);
622 while (auto CE = dyn_cast<ConstantExpr>(V)) {
623 V = CE->getOperand(0);
624 }
625 while (auto CE = dyn_cast<ConstantExpr>(name)) {
626 name = CE->getOperand(0);
627 }
628 StringRef nameVal;
629 if (auto GV = dyn_cast<GlobalVariable>(name))
630 if (GV->isConstant())
631 if (auto C = GV->getInitializer())
632 if (auto CA = dyn_cast<ConstantDataArray>(C))
633 if (CA->getType()->getElementType()->isIntegerTy(8) &&
634 CA->isCString())
635 nameVal = CA->getAsCString();
636
637 if (nameVal == "") {
638 llvm::errs() << *name << "\n";
639 llvm::errs() << "Use of "
640 << "enzyme_function_like"
641 << "requires a non-empty function name"
642 << "\n";
643 llvm_unreachable("enzyme_function_like");
644 }
645 if (auto F = cast<Function>(V)) {
646 F->addAttribute(
647 AttributeList::FunctionIndex,
648 Attribute::get(g.getContext(), "enzyme_math", nameVal));
649 toErase.push_back(&g);
650 changed = true;
651 } else {
652 llvm::errs() << "Param of __enzyme_function_like must be a "
653 "constant function"
654 << g << "\n"
655 << *V << "\n";
656 llvm_unreachable("__enzyme_function_like");
657 }
658 }
659 }
660 if (g.getName().contains("__enzyme_allocation_like")) {
661 if (g.hasInitializer()) {
662 auto CA = dyn_cast<ConstantAggregate>(g.getInitializer());
663 if (!CA || CA->getNumOperands() != 4) {
664 llvm::errs() << "Use of "
665 << "enzyme_allocation_like"
666 << " must be a "
667 "constant of size at least "
668 << 4 << " " << g << "\n";
669 llvm_unreachable("enzyme_allocation_like");
670 }
671 Value *V = CA->getOperand(0);
672 Value *name = CA->getOperand(1);
673 while (auto CE = dyn_cast<ConstantExpr>(V)) {
674 V = CE->getOperand(0);
675 }
676 while (auto CE = dyn_cast<ConstantExpr>(name)) {
677 name = CE->getOperand(0);
678 }
679 Value *deallocind = CA->getOperand(2);
680 while (auto CE = dyn_cast<ConstantExpr>(deallocind)) {
681 deallocind = CE->getOperand(0);
682 }
683 Value *deallocfn = CA->getOperand(3);
684 while (auto CE = dyn_cast<ConstantExpr>(deallocfn)) {
685 deallocfn = CE->getOperand(0);
686 }
687 size_t index = 0;
688 if (isa<ConstantPointerNull>(name)) {
689 // An integer 0 may have been implicitly converted to a null pointer
690 index = 0;
691 } else if (auto CI = dyn_cast<ConstantInt>(name)) {
692 index = CI->getZExtValue();
693 } else {
694 llvm::errs() << *name << "\n";
695 llvm::errs() << "Use of "
696 << "enzyme_allocation_like"
697 << "requires an integer index"
698 << "\n";
699 llvm_unreachable("enzyme_allocation_like");
700 }
701
702 StringRef deallocIndStr;
703 bool foundInd = false;
704 if (auto GV = dyn_cast<GlobalVariable>(deallocind))
705 if (GV->isConstant())
706 if (auto C = GV->getInitializer())
707 if (auto CA = dyn_cast<ConstantDataArray>(C))
708 if (CA->getType()->getElementType()->isIntegerTy(8) &&
709 CA->isCString()) {
710 deallocIndStr = CA->getAsCString();
711 foundInd = true;
712 }
713
714 if (!foundInd) {
715 llvm::errs() << *deallocind << "\n";
716 llvm::errs() << "Use of "
717 << "enzyme_allocation_like"
718 << "requires a deallocation index string"
719 << "\n";
720 llvm_unreachable("enzyme_allocation_like");
721 }
722 if (auto F = dyn_cast<Function>(V)) {
723 F->addAttribute(AttributeList::FunctionIndex,
724 Attribute::get(g.getContext(), "enzyme_allocator",
725 std::to_string(index)));
726 } else {
727 llvm::errs() << "Param of __enzyme_allocation_like must be a "
728 "function"
729 << g << "\n"
730 << *V << "\n";
731 llvm_unreachable("__enzyme_allocation_like");
732 }
733 cast<Function>(V)->addAttribute(AttributeList::FunctionIndex,
734 Attribute::get(g.getContext(),
735 "enzyme_deallocator",
736 deallocIndStr));
737
738 if (auto F = dyn_cast<Function>(deallocfn)) {
739 cast<Function>(V)->setMetadata(
740 "enzyme_deallocator_fn",
741 llvm::MDTuple::get(F->getContext(),
742 {llvm::ValueAsMetadata::get(F)}));
743 changed |= preserveLinkage(Begin, *F);
744 } else {
745 llvm::errs() << "Free fn of __enzyme_allocation_like must be a "
746 "function"
747 << g << "\n"
748 << *deallocfn << "\n";
749 llvm_unreachable("__enzyme_allocation_like");
750 }
751 toErase.push_back(&g);
752 changed = true;
753 }
754 }
755 }
756
757 for (auto G : toErase) {
758 for (auto name : {"llvm.used", "llvm.compiler.used"}) {
759 if (auto V = M.getGlobalVariable(name)) {
760 auto C = cast<ConstantArray>(V->getInitializer());
761 SmallVector<Constant *, 1> toKeep;
762 bool found = false;
763 for (unsigned i = 0; i < C->getNumOperands(); i++) {
764 Value *Op = C->getOperand(i)->stripPointerCasts();
765 if (Op == G)
766 found = true;
767 else
768 toKeep.push_back(C->getOperand(i));
769 }
770 if (found) {
771 if (toKeep.size()) {
772 auto CA = ConstantArray::get(
773 ArrayType::get(C->getType()->getElementType(), toKeep.size()),
774 toKeep);
775 GlobalVariable *NGV = new GlobalVariable(
776 CA->getType(), V->isConstant(), V->getLinkage(), CA, "",
777 V->getThreadLocalMode());
778#if LLVM_VERSION_MAJOR > 16
779 V->getParent()->insertGlobalVariable(V->getIterator(), NGV);
780#else
781 V->getParent()->getGlobalList().insert(V->getIterator(), NGV);
782#endif
783 NGV->takeName(V);
784
785 // Nuke the old list, replacing any uses with the new one.
786 if (!V->use_empty()) {
787 Constant *VV = NGV;
788 if (VV->getType() != V->getType())
789 VV = ConstantExpr::getBitCast(VV, V->getType());
790 V->replaceAllUsesWith(VV);
791 }
792 }
793 V->eraseFromParent();
794 }
795 }
796 }
797 changed = true;
798 G->replaceAllUsesWith(ConstantPointerNull::get(G->getType()));
799 G->eraseFromParent();
800 }
801
802 StringMap<std::pair<std::string, std::string>> Implements;
803 for (std::string T : {"", "f"}) {
804 // CUDA
805 // sincos, sinpi, cospi, sincospi, cyl_bessel_i1
806 for (std::string name :
807 {"sin", "cos", "tan", "log2", "exp", "exp2",
808 "exp10", "cosh", "sinh", "tanh", "atan2", "atan",
809 "asin", "acos", "log", "log10", "log1p", "acosh",
810 "asinh", "atanh", "expm1", "hypot", "rhypot", "norm3d",
811 "rnorm3d", "norm4d", "rnorm4d", "norm", "rnorm", "cbrt",
812 "rcbrt", "j0", "j1", "y0", "y1", "yn",
813 "jn", "erf", "erfinv", "erfc", "erfcx", "erfcinv",
814 "normcdfinv", "normcdf", "lgamma", "ldexp", "scalbn", "frexp",
815 "modf", "fmod", "remainder", "remquo", "powi", "tgamma",
816 "round", "fdim", "ilogb", "logb", "isinf", "pow",
817 "sqrt", "finite", "fabs", "fmax"}) {
818 std::string nvname = "__nv_" + name;
819 std::string llname = "llvm." + name + ".";
820 std::string mathname = name;
821
822 if (T == "f") {
823 mathname += "f";
824 nvname += "f";
825 llname += "f32";
826 } else {
827 llname += "f64";
828 }
829
830 Implements[nvname] = std::make_pair(mathname, llname);
831 }
832 // ROCM
833 // sincos, sinpi, cospi, sincospi, cyl_bessel_i1
834 for (std::string name : {"acos", "acosh", "asin",
835 "asinh", "atan2", "atan",
836 "atanh", "cbrt", "ceil",
837 "copysign", "cos", "native_cos",
838 "cosh", "cospi", "i0",
839 "i1", "erfc", "erfcinv",
840 "erfcx", "erf", "erfinv",
841 "exp10", "native_exp10", "exp2",
842 "exp", "native_exp", "expm1",
843 "fabs", "fdim", "floor",
844 "fma", "fmax", "fmin",
845 "fmod", "frexp", "hypot",
846 "ilogb", "isfinite", "isinf",
847 "isnan", "j0", "j1",
848 "ldexp", "lgamma", "log10",
849 "native_log10", "log1p", "log2",
850 "log2", "logb", "log",
851 "native_log", "modf", "nearbyint",
852 "nextafter", "len3", "len4",
853 "ncdf", "ncdfinv", "pow",
854 "pown", "rcbrt", "remainder",
855 "remquo", "rhypot", "rint",
856 "rlen3", "rlen4", "round",
857 "rsqrt", "scalb", "scalbn",
858 "signbit", "sincos", "sincospi",
859 "sin", "native_sin", "sinh",
860 "sinpi", "sqrt", "native_sqrt",
861 "tan", "tanh", "tgamma",
862 "trunc", "y0", "y1"}) {
863 std::string nvname = "__ocml_" + name + "_";
864 std::string llname = "llvm." + name + ".";
865 std::string mathname = name;
866
867 if (T == "f") {
868 mathname += "f";
869 nvname += "f32";
870 llname += "f32";
871 } else {
872 nvname += "f64";
873 llname += "f64";
874 }
875
876 Implements[nvname] = std::make_pair(mathname, llname);
877 }
878 }
879#if ENZYME_ENABLE_NVVM_ATTRIBUTION
880 for (auto &F : llvm::make_early_inc_range(M)) {
881 if (Begin) {
882 changed |= attributeKnownFunctions(F);
883 }
884 }
885#endif
886 for (auto &F : M) {
887 auto found = Implements.find(F.getName());
888 if (found != Implements.end()) {
889 changed = true;
890 if (Begin) {
891 // As a side effect, enforces arguments
892 // cannot be erased.
893 F.addFnAttr("implements", found->second.second);
894 F.addFnAttr("implements2", found->second.first);
895 F.addFnAttr("enzyme_math", found->second.first);
896 changed |= preserveLinkage(Begin, F);
897 }
898 } else if (F.getName() == "_ZL21__internal_float2halffRjS_" ||
899 F.getName() == "_ZL4hlog6__half" ||
900 F.getName() == "_ZL6__hdiv6__halfS_" ||
901 F.getName() == "_ZL12__half2float6__half" ||
902 F.getName() == "_ZL6__habs6__half" ||
903 F.getName() == "_ZL5__hlt6__halfS_" ||
904 F.getName() == "_ZL6__hmul6__halfS_" ||
905 F.getName() == "_ZL6__hadd6__halfS_" ||
906 F.getName() == "_ZL5hsqrt6__half" ||
907 F.getName() == "_ZL6__hsub6__halfS_" ||
908 F.getName() == "_ZL4hexp6__half" ||
909 F.getName() == "_ZL6__hneg6__half" ||
910 F.getName() == "_ZL22__internal_device_hdiv13__nv_bfloat16S_" ||
911 F.getName() ==
912 "_ZL27__internal_sm80_device_hmul13__nv_bfloat16S_" ||
913 F.getName() == "_ZL22__internal_device_hadd13__nv_bfloat16S_" ||
914 F.getName() ==
915 "_ZL27__internal_sm80_device_hsub13__nv_bfloat16S_" ||
916 F.getName() == "_ZL22__internal_device_hneg13__nv_bfloat16" ||
917 F.getName() == "_ZL16__float2bfloat16f" ||
918 F.getName() == "_ZL25__internal_bfloat162floatt" ||
919 F.getName() == "_ZL32__internal_device_bfloat162floatt") {
920 changed = true;
921 if (Begin) {
922 changed |= preserveLinkage(Begin, F);
923 }
924 }
925 if (!Begin && F.hasFnAttribute("prev_fixup")) {
926 changed = true;
927 F.removeFnAttr("prev_fixup");
928 if (F.hasFnAttribute("prev_always_inline")) {
929 F.addFnAttr(Attribute::AlwaysInline);
930 F.removeFnAttr("prev_always_inline");
931 }
932 if (F.hasFnAttribute("prev_no_inline")) {
933 F.removeFnAttr("prev_no_inline");
934 } else {
935 F.removeFnAttr(Attribute::NoInline);
936 }
937 int64_t val;
938 F.getFnAttribute("prev_linkage").getValueAsString().getAsInteger(10, val);
939 F.setLinkage((Function::LinkageTypes)val);
940 }
941 }
942 return changed;
943}
944
945namespace {
946
947class PreserveNVVM final : public ModulePass {
948public:
949 static char ID;
950 bool Begin;
951 PreserveNVVM(bool Begin = true) : ModulePass(ID), Begin(Begin) {}
952
953 void getAnalysisUsage(AnalysisUsage &AU) const override {}
954 bool runOnModule(Module &M) override { return preserveNVVM(Begin, M); }
955};
956
957class PreserveNVVMFn final : public FunctionPass {
958public:
959 static char ID;
960 bool Begin;
961 PreserveNVVMFn(bool Begin = true) : FunctionPass(ID), Begin(Begin) {}
962
963 void getAnalysisUsage(AnalysisUsage &AU) const override {}
964 bool runOnFunction(Function &F) override {
965 return preserveNVVM(Begin, *F.getParent());
966 }
967};
968
969} // namespace
970
971char PreserveNVVM::ID = 0;
972
973char PreserveNVVMFn::ID = 0;
974
975static RegisterPass<PreserveNVVM> X("preserve-nvvm", "Preserve NVVM Pass");
976
977static RegisterPass<PreserveNVVMFn> XFn("preserve-nvvm-fn",
978 "Preserve NVVM Pass");
979
981 return new PreserveNVVM(Begin);
982}
983
984FunctionPass *createPreserveNVVMFnPass(bool Begin) {
985 return new PreserveNVVMFn(Begin);
986}
987
988#include <llvm-c/Core.h>
989#include <llvm-c/Types.h>
990
991#include "llvm/IR/LegacyPassManager.h"
992
993extern "C" void AddPreserveNVVMPass(LLVMPassManagerRef PM, uint8_t Begin) {
994 unwrap(PM)->add(createPreserveNVVMPass((bool)Begin));
995}
996
998PreserveNVVMNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) {
999 bool changed = preserveNVVM(Begin, M);
1000 return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
1001}
1002llvm::AnalysisKey PreserveNVVMNewPM::Key;
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
bool preserveNVVM(bool Begin, Module &M)
bool isTargetNVPTX(llvm::Module &M)
static RegisterPass< PreserveNVVM > X("preserve-nvvm", "Preserve NVVM Pass")
bool preserveLinkage(bool Begin, Function &F, bool Inlining=true)
Returns whether changed.
static void handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, SmallVectorImpl< GlobalVariable * > &globalsToErase)
ModulePass * createPreserveNVVMPass(bool Begin)
static RegisterPass< PreserveNVVMFn > XFn("preserve-nvvm-fn", "Preserve NVVM Pass")
void AddPreserveNVVMPass(LLVMPassManagerRef PM, uint8_t Begin)
FunctionPass * createPreserveNVVMFnPass(bool Begin)
bool attributeKnownFunctions(llvm::Function &F)
Definition Utils.cpp:114
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
Definition Utils.h:713
llvm::PreservedAnalyses Result
Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM)