25#include "clang/AST/Attr.h"
26#include "clang/AST/DeclGroup.h"
27#include "clang/AST/RecursiveASTVisitor.h"
28#include "clang/Basic/FileManager.h"
29#include "clang/Basic/MacroBuilder.h"
30#include "clang/Frontend/CompilerInstance.h"
31#include "clang/Frontend/FrontendAction.h"
32#include "clang/Frontend/FrontendPluginRegistry.h"
33#include "clang/Lex/HeaderSearch.h"
34#include "clang/Lex/PreprocessorOptions.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/SemaDiagnostic.h"
40#include "bundled_includes.h"
// StructKind: the TagTypeKind passed to CXXRecordDecl::Create /
// RecordDecl::Create by the attribute handlers below. It is selected by LLVM
// major version: >= 18 uses TagTypeKind::Struct, older versions use
// TagTypeKind::TTK_Struct.
// NOTE(review): the #else/#endif lines of this conditional are not visible
// here — confirm they exist.
44#if LLVM_VERSION_MAJOR >= 18
45constexpr auto StructKind = clang::TagTypeKind::Struct;
47constexpr auto StructKind = clang::TagTypeKind::TTK_Struct;
// EnzymeAction<ConsumerType>: a plugin action parameterised on the AST
// consumer it installs. The "enzyme" FrontendPluginRegistry entry
// instantiates it as EnzymeAction<EnzymePlugin>.
50template <
typename ConsumerType>
// CreateASTConsumer: constructs a ConsumerType from the CompilerInstance.
53 std::unique_ptr<clang::ASTConsumer>
55 llvm::StringRef InFile)
override {
// NOTE(review): std::make_unique<ConsumerType>(CI) would avoid the naked
// `new`. InFile is not used in the visible lines.
56 return std::unique_ptr<clang::ASTConsumer>(
new ConsumerType(CI));
// Presumably ParseArgs (header line not shown); its argument-handling body is
// not visible here.
60 const std::vector<std::string> &args)
override {
// getActionType: requests PluginASTAction::AddBeforeMainAction, i.e. this
// action is scheduled before clang's main frontend action.
65 return AddBeforeMainAction;
// EnzymePlugin state and constructor body. This is fragmentary: the
// constructor header and several #else/#endif lines are not visible here.
83#if LLVM_VERSION_MAJOR >= 18
// The CompilerInstance the plugin operates on.
88 clang::CompilerInstance &CI;
93 FrontendOptions &Opts = CI.getFrontendOpts();
94 CodeGenOptions &CGOpts = CI.getCodeGenOpts();
// Expected suffix of this plugin's library stem: "ClangEnzyme-<LLVM major>".
95 auto PluginName =
"ClangEnzyme-" + std::to_string(LLVM_VERSION_MAJOR);
97#if LLVM_VERSION_MAJOR < 18
98 std::string pluginPath;
// Find this plugin among the frontend plugins by matching the stem suffix.
100 for (
auto P : Opts.Plugins)
101 if (
endsWith(llvm::sys::path::stem(P), PluginName)) {
102#if LLVM_VERSION_MAJOR < 18
// Check whether a matching pass plugin is already registered.
105 for (
auto passPlugin : CGOpts.PassPlugins) {
106 if (
endsWith(llvm::sys::path::stem(passPlugin), PluginName)) {
114#if LLVM_VERSION_MAJOR >= 18
// Register the located path as an LLVM pass plugin.
117 CGOpts.PassPlugins.push_back(pluginPath);
// Force-include "/enzyme/enzyme/version" in every translation unit.
120 CI.getPreprocessorOpts().Includes.push_back(
"/enzyme/enzyme/version");
// Re-emit the existing predefines, then append the
// ENZYME_VERSION_{MAJOR,MINOR,PATCH} macro definitions.
122 std::string PredefineBuffer;
123 PredefineBuffer.reserve(4080);
124 llvm::raw_string_ostream Predefines(PredefineBuffer);
125 Predefines << CI.getPreprocessor().getPredefines();
126 MacroBuilder Builder(Predefines);
127 Builder.defineMacro(
"ENZYME_VERSION_MAJOR",
128 std::to_string(ENZYME_VERSION_MAJOR));
129 Builder.defineMacro(
"ENZYME_VERSION_MINOR",
130 std::to_string(ENZYME_VERSION_MINOR));
131 Builder.defineMacro(
"ENZYME_VERSION_PATCH",
132 std::to_string(ENZYME_VERSION_PATCH));
133 CI.getPreprocessor().setPredefines(Predefines.str());
// Layer an in-memory filesystem holding the bundled headers over the existing
// VFS.
// NOTE(review): fuseFS is created with a naked `new` and held in a raw pointer
// until setVirtualFileSystem takes it. Holding it in an IntrusiveRefCntPtr
// from the start (as is done for `fs`) would make ownership explicit.
135 auto baseFS = &CI.getFileManager().getVirtualFileSystem();
136 llvm::vfs::OverlayFileSystem *fuseFS(
137 new llvm::vfs::OverlayFileSystem(baseFS));
138 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> fs(
139 new llvm::vfs::InMemoryFileSystem());
// Fixed modification time for the in-memory files. y2k is defined in lines
// not shown here.
149 time_t timer = mktime(&y2k);
// Each include_headers entry (presumably declared in bundled_includes.h) is
// used as {path, contents}: pair[0] is the file name and buffer name, and
// pair[1] is the buffer text.
150 for (
const auto &pair : include_headers) {
151 fs->addFile(StringRef(pair[0]), timer,
152 llvm::MemoryBuffer::getMemBuffer(
153 StringRef(pair[1]), StringRef(pair[0]),
// NOTE(review): baseFS is the overlay's base (constructor argument) and is also
// pushed again after `fs`. OverlayFileSystem lets the most recently pushed layer
// overshadow earlier ones, so real files would shadow the bundled headers.
// Confirm this ordering is intended.
157 fuseFS->pushOverlay(fs);
158 fuseFS->pushOverlay(baseFS);
159 CI.getFileManager().setVirtualFileSystem(fuseFS);
// Add the "/enzymeroot" directory as a user (C_User) header search path.
// NOTE(review): `*DE` is dereferenced below. Any check that the directory
// lookup succeeded would be in lines not shown — confirm it exists.
161 auto DE = CI.getFileManager().getDirectoryRef(
"/enzymeroot");
163 auto DL = DirectoryLookup(*DE, SrcMgr::C_User,
165 CI.getPreprocessor().getHeaderSearchInfo().AddSearchPath(DL,
// HandleTopLevelDecl body (header line not shown): visits every declaration in
// the incoming DeclGroupRef.
171 using namespace clang;
172 DeclGroupRef::iterator it;
176 for (it = dg.begin(); it != dg.end(); ++it) {
// Functions: in the visible lines, an implicit UsedAttr is attached only after
// these checks pass:
//   - the function has CUDADeviceAttr;
//   - it has an identifier;
//   - its printed location contains "/__clang_cuda_math.h" or
//     "/__clang_hip_math.h".
// The statements executed when a check fails are not shown.
178 if (
auto FD = dyn_cast<FunctionDecl>(*it)) {
179 if (!FD->hasAttr<clang::CUDADeviceAttr>())
182 if (!FD->getIdentifier())
// NOTE(review): printToString(...) is evaluated twice per function. Formatting
// it once into a local std::string would avoid the duplicate work.
184 if (!StringRef(FD->getLocation().printToString(CI.getSourceManager()))
185 .contains(
"/__clang_cuda_math.h") &&
186 !StringRef(FD->getLocation().printToString(CI.getSourceManager()))
187 .contains(
"/__clang_hip_math.h"))
190 FD->addAttr(UsedAttr::CreateImplicit(CI.getASTContext()));
// Variables: the handling body is not visible here.
// NOTE(review): the VarDecl is bound to the name `FD`, which reads like a
// FunctionDecl. A name such as `VD` would be clearer.
192 if (
auto FD = dyn_cast<VarDecl>(*it)) {
// Presumably the body of HandleCXXStaticMemberVarInstantiation(clang::VarDecl *V).
// The header is not shown, and `V` is not declared in the visible lines.
// In the visible lines, an implicit UsedAttr is attached only to named
// variables whose name contains one of the prefixes below. The statements
// taken when a check fails are not shown.
// NOTE(review): "__enzyme_nofree" and "__enzyme_sparse_accumulate" are not in
// this list. The attribute handlers in this file do generate globals with those
// prefixes, but they attach UsedAttr themselves. Confirm that user-written
// globals with those prefixes need no handling here.
199 if (!V->getIdentifier())
201 auto name = V->getName();
202 if (!(name.contains(
"__enzyme_inactive_global") ||
203 name.contains(
"__enzyme_inactivefn") ||
204 name.contains(
"__enzyme_shouldrecompute") ||
205 name.contains(
"__enzyme_function_like") ||
206 name.contains(
"__enzyme_allocation_like") ||
207 name.contains(
"__enzyme_register_gradient") ||
208 name.contains(
"__enzyme_register_derivative") ||
209 name.contains(
"__enzyme_register_splitderivative")))
212 V->addAttr(clang::UsedAttr::CreateImplicit(CI.getASTContext()));
// Registers EnzymeAction<EnzymePlugin> with clang's FrontendPluginRegistry
// under the plugin name "enzyme" (description "Enzyme Plugin").
218static clang::FrontendPluginRegistry::Add<EnzymeAction<EnzymePlugin>>
219 X(
"enzyme",
"Enzyme Plugin");
// Custom attribute enzyme_function_like("<name>"), compiled only when
// LLVM_VERSION_MAJOR > 10. The #endif is not visible here.
221#if LLVM_VERSION_MAJOR > 10
224struct EnzymeFunctionLikeAttrInfo :
public ParsedAttrInfo {
225 EnzymeFunctionLikeAttrInfo() {
// Accepted spellings: GNU, C23 (LLVM > 17) or C2x, and two C++11 spellings
// (plain and enzyme:: namespaced).
229 static constexpr Spelling S[] = {
230 {ParsedAttr::AS_GNU,
"enzyme_function_like"},
231#if LLVM_VERSION_MAJOR > 17
232 {ParsedAttr::AS_C23,
"enzyme_function_like"},
234 {ParsedAttr::AS_C2x,
"enzyme_function_like"},
236 {ParsedAttr::AS_CXX11,
"enzyme_function_like"},
237 {ParsedAttr::AS_CXX11,
"enzyme::function_like"}
// Only FunctionDecls are accepted. Anything else gets
// warn_attribute_wrong_decl_type_str ("functions").
242 bool diagAppertainsToDecl(Sema &S,
const ParsedAttr &Attr,
243 const Decl *D)
const override {
245 if (!isa<FunctionDecl>(D)) {
246 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
247 << Attr <<
"functions";
// Requires exactly one argument that, after IgnoreParenCasts, is a
// StringLiteral. Otherwise a custom error is emitted and the attribute is not
// applied.
253 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
254 const ParsedAttr &Attr)
const override {
255 if (Attr.getNumArgs() != 1) {
// NOTE(review): this message names 'enzyme_function', but the attribute is
// spelled 'enzyme_function_like' (see the spellings above and the next
// diagnostic). It should likely read
// "'enzyme_function_like' attribute requires a single string argument".
256 unsigned ID = S.getDiagnostics().getCustomDiagID(
257 DiagnosticsEngine::Error,
258 "'enzyme_function' attribute requires a single string argument");
259 S.Diag(Attr.getLoc(), ID);
260 return AttributeNotApplied;
262 auto *Arg0 = Attr.getArgAsExpr(0);
263 StringLiteral *Literal = dyn_cast<StringLiteral>(Arg0->IgnoreParenCasts());
265 unsigned ID = S.getDiagnostics().getCustomDiagID(
266 DiagnosticsEngine::Error,
"first argument to 'enzyme_function_like' "
267 "attribute must be a string literal");
268 S.Diag(Attr.getLoc(), ID);
269 return AttributeNotApplied;
// LLVM >= 12: record the attribute as AnnotateAttr
// "enzyme_function_like=<name>".
271#if LLVM_VERSION_MAJOR >= 12
272 D->addAttr(AnnotateAttr::Create(
273 S.Context, (
"enzyme_function_like=" + Literal->getString()).str(),
274 nullptr, 0, Attr.getRange()));
275 return AttributeApplied;
// What follows appears to be the fallback for older LLVM (the #else/#endif
// lines are not visible). It:
//   - builds an anonymous struct type with two public fields (a pointer to the
//     function and a char pointer);
//   - declares a private-extern, used global
//     "__enzyme_function_like_autoreg_<fn>" of that type;
//   - builds an InitListExpr {fn-to-pointer decay, string literal <name>};
//   - passes the global to the AST consumer.
277 auto FD = cast<FunctionDecl>(D);
279 auto &AST = S.getASTContext();
// Walk outwards through the enclosing contexts. declCtx ends up as the parent
// of the outermost enclosing record, so the helper global is not declared
// inside a class.
280 DeclContext *declCtx = FD->getDeclContext();
281 for (
auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
282 if (tmpCtx->isRecord()) {
283 declCtx = tmpCtx->getParent();
286 auto loc = FD->getLocation();
288 if (S.getLangOpts().CPlusPlus)
289 RD = CXXRecordDecl::Create(AST,
StructKind, declCtx, loc, loc,
292 RD = RecordDecl::Create(AST,
StructKind, declCtx, loc, loc,
294 RD->setAnonymousStructOrUnion(
true);
296 RD->startDefinition();
297 auto Tinfo =
nullptr;
298 auto Tinfo0 =
nullptr;
299 auto FT = AST.getPointerType(FD->getType());
300 auto CharTy = AST.getIntTypeForBitwidth(8,
false);
301 auto FD0 = FieldDecl::Create(AST, RD, loc, loc,
nullptr, FT, Tinfo0,
304 FD0->setAccess(AS_public);
306 auto FD1 = FieldDecl::Create(
307 AST, RD, loc, loc,
nullptr, AST.getPointerType(CharTy), Tinfo0,
308 nullptr,
true, ICIS_NoInit);
309 FD1->setAccess(AS_public);
311 RD->completeDefinition();
312 assert(RD->getDefinition());
313 auto &Id = AST.Idents.get(
"__enzyme_function_like_autoreg_" +
314 FD->getNameAsString());
315 auto T = AST.getRecordType(RD);
316 auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, T, Tinfo, SC_None);
317 V->setStorageClass(SC_PrivateExtern);
318 V->addAttr(clang::UsedAttr::CreateImplicit(AST));
319 TemplateArgumentListInfo *TemplateArgs =
nullptr;
320 auto DR = DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), loc, FD,
false,
321 loc, FD->getType(), ExprValueKind::VK_LValue,
323 auto rval = ExprValueKind::VK_PRValue;
324 StringRef cstr = Literal->getString();
326 ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay,
327 DR,
nullptr, rval, FPOptionsOverride()),
328 ImplicitCastExpr::Create(
329 AST, AST.getPointerType(CharTy), CastKind::CK_ArrayToPointerDecay,
330 StringLiteral::Create(
331 AST, cstr, stringkind,
333 AST.getStringLiteralArrayType(CharTy, cstr.size()), loc),
334 nullptr, rval, FPOptionsOverride())};
335 auto IL =
new (AST) InitListExpr(AST, loc, exprs, loc);
// Templated (value-dependent) contexts are rejected.
// NOTE(review): this check runs only after RD and V have been built, so those
// nodes are abandoned on the error path. Checking dependence before creating
// them would be cleaner.
338 if (IL->isValueDependent()) {
339 unsigned ID = S.getDiagnostics().getCustomDiagID(
340 DiagnosticsEngine::Error,
"use of attribute 'enzyme_function_like' "
341 "in a templated context not yet supported");
342 S.Diag(Attr.getLoc(), ID);
343 return AttributeNotApplied;
345 S.MarkVariableReferenced(loc, V);
346 S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
347 return AttributeApplied;
// Registers the attribute with clang's ParsedAttrInfoRegistry.
352static ParsedAttrInfoRegistry::Add<EnzymeFunctionLikeAttrInfo>
353 X3(
"enzyme_function_like",
"");
// Custom attribute enzyme_shouldrecompute (no arguments). It is lowered to an
// AnnotateAttr "enzyme_shouldrecompute" on the declaration.
355struct EnzymeShouldRecomputeAttrInfo :
public ParsedAttrInfo {
356 EnzymeShouldRecomputeAttrInfo() {
// Accepted spellings: GNU, C23 (LLVM > 17) or C2x, and two C++11 spellings.
358 static constexpr Spelling S[] = {
359 {ParsedAttr::AS_GNU,
"enzyme_shouldrecompute"},
360#if LLVM_VERSION_MAJOR > 17
361 {ParsedAttr::AS_C23,
"enzyme_shouldrecompute"},
363 {ParsedAttr::AS_C2x,
"enzyme_shouldrecompute"},
365 {ParsedAttr::AS_CXX11,
"enzyme_shouldrecompute"},
366 {ParsedAttr::AS_CXX11,
"enzyme::shouldrecompute"}
// Accepts functions and variables with global storage. Other declarations get
// warn_attribute_wrong_decl_type_str ("functions and globals").
371 bool diagAppertainsToDecl(Sema &S,
const ParsedAttr &Attr,
372 const Decl *D)
const override {
374 if (isa<FunctionDecl>(D))
376 if (
auto VD = dyn_cast<VarDecl>(D)) {
377 if (VD->hasGlobalStorage())
380 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
381 << Attr <<
"functions and globals";
// Rejects any arguments, then adds the annotation.
385 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
386 const ParsedAttr &Attr)
const override {
387 if (Attr.getNumArgs() != 0) {
// NOTE(review): copy-paste defect. This handler is for enzyme_shouldrecompute,
// but the message names 'enzyme_inactive'. It should likely read
// "'enzyme_shouldrecompute' attribute requires zero arguments".
388 unsigned ID = S.getDiagnostics().getCustomDiagID(
389 DiagnosticsEngine::Error,
390 "'enzyme_inactive' attribute requires zero arguments");
391 S.Diag(Attr.getLoc(), ID);
392 return AttributeNotApplied;
394 D->addAttr(AnnotateAttr::Create(S.Context,
"enzyme_shouldrecompute",
395 nullptr, 0, Attr.getRange()));
396 return AttributeApplied;
// Registers the attribute with clang's ParsedAttrInfoRegistry.
400static ParsedAttrInfoRegistry::Add<EnzymeShouldRecomputeAttrInfo>
401 ESR(
"enzyme_shouldrecompute",
"");
// Custom attribute enzyme_inactive (no arguments). For the annotated function or
// global it creates a private-extern, used global named
// "__enzyme_inactivefn_autoreg_<name>" (functions) or
// "__enzyme_inactive_global_autoreg_<name>" (globals). The global holds a
// pointer to the declaration and is handed to the AST consumer.
403struct EnzymeInactiveAttrInfo :
public ParsedAttrInfo {
404 EnzymeInactiveAttrInfo() {
// Accepted spellings: GNU, C23 (LLVM > 17) or C2x, and two C++11 spellings.
408 static constexpr Spelling S[] = {
409 {ParsedAttr::AS_GNU,
"enzyme_inactive"},
410#if LLVM_VERSION_MAJOR > 17
411 {ParsedAttr::AS_C23,
"enzyme_inactive"},
413 {ParsedAttr::AS_C2x,
"enzyme_inactive"},
415 {ParsedAttr::AS_CXX11,
"enzyme_inactive"},
416 {ParsedAttr::AS_CXX11,
"enzyme::inactive"}
// Accepts functions and variables with global storage.
421 bool diagAppertainsToDecl(Sema &S,
const ParsedAttr &Attr,
422 const Decl *D)
const override {
424 if (isa<FunctionDecl>(D))
426 if (
auto VD = dyn_cast<VarDecl>(D)) {
427 if (VD->hasGlobalStorage())
430 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
431 << Attr <<
"functions and globals";
435 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
436 const ParsedAttr &Attr)
const override {
437 if (Attr.getNumArgs() != 0) {
438 unsigned ID = S.getDiagnostics().getCustomDiagID(
439 DiagnosticsEngine::Error,
440 "'enzyme_inactive' attribute requires zero arguments");
441 S.Diag(Attr.getLoc(), ID);
442 return AttributeNotApplied;
445 auto &AST = S.getASTContext();
// declCtx ends up as the parent of the outermost enclosing record, so the
// helper global is not declared inside a class.
446 DeclContext *declCtx = D->getDeclContext();
447 for (
auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
448 if (tmpCtx->isRecord()) {
449 declCtx = tmpCtx->getParent();
452 auto loc = D->getLocation();
// NOTE(review): RD is created and startDefinition() is called, but in the
// visible lines it never gets fields, completeDefinition(), or any use. V is
// typed as a plain pointer (FT). This looks like leftover from the
// enzyme_function_like handler — confirm and remove.
454 if (S.getLangOpts().CPlusPlus)
455 RD = CXXRecordDecl::Create(AST,
StructKind, declCtx, loc, loc,
458 RD = RecordDecl::Create(AST,
StructKind, declCtx, loc, loc,
460 RD->setAnonymousStructOrUnion(
true);
462 RD->startDefinition();
463 auto T = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getType()
464 : cast<VarDecl>(D)->getType();
465 auto Name = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getNameAsString()
466 : cast<VarDecl>(D)->getNameAsString();
467 auto FT = AST.getPointerType(T);
468 auto subname = isa<FunctionDecl>(D) ?
"inactivefn" :
"inactive_global";
469 auto &Id = AST.Idents.get(
470 (StringRef(
"__enzyme_") + subname +
"_autoreg_" + Name).
str());
471 auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT,
nullptr, SC_None);
472 V->setStorageClass(SC_PrivateExtern);
473 V->addAttr(clang::UsedAttr::CreateImplicit(AST));
474 TemplateArgumentListInfo *TemplateArgs =
nullptr;
475 auto DR = DeclRefExpr::Create(
476 AST, NestedNameSpecifierLoc(), loc, cast<ValueDecl>(D),
false, loc, T,
477 ExprValueKind::VK_LValue, cast<NamedDecl>(D), TemplateArgs);
478 auto rval = ExprValueKind::VK_PRValue;
// The pointer expression: function-to-pointer decay for functions, and the
// address-of operator (UO_AddrOf) for globals.
479 Expr *expr =
nullptr;
480 if (isa<FunctionDecl>(D)) {
482 ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay,
483 DR,
nullptr, rval, FPOptionsOverride());
486 UnaryOperator::Create(AST, DR, UnaryOperatorKind::UO_AddrOf, FT, rval,
487 clang::ExprObjectKind ::OK_Ordinary, loc,
488 false, FPOptionsOverride());
// Templated (value-dependent) contexts are rejected.
491 if (expr->isValueDependent()) {
492 unsigned ID = S.getDiagnostics().getCustomDiagID(
493 DiagnosticsEngine::Error,
"use of attribute 'enzyme_inactive' "
494 "in a templated context not yet supported");
495 S.Diag(Attr.getLoc(), ID);
496 return AttributeNotApplied;
499 S.MarkVariableReferenced(loc, V);
500 S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
501 return AttributeApplied;
// Registers the attribute with clang's ParsedAttrInfoRegistry.
505static ParsedAttrInfoRegistry::Add<EnzymeInactiveAttrInfo> X4(
"enzyme_inactive",
// Custom attribute enzyme_elementwise_read (no arguments, functions only). It is
// lowered to an AnnotateAttr "enzyme_elementwise_read".
508struct EnzymeElementwiseReadAttrInfo :
public ParsedAttrInfo {
509 EnzymeElementwiseReadAttrInfo() {
// Accepted spellings: GNU, C23 (LLVM > 17) or C2x, and two C++11 spellings.
511 static constexpr Spelling S[] = {
512 {ParsedAttr::AS_GNU,
"enzyme_elementwise_read"},
513#if LLVM_VERSION_MAJOR > 17
514 {ParsedAttr::AS_C23,
"enzyme_elementwise_read"},
516 {ParsedAttr::AS_C2x,
"enzyme_elementwise_read"},
518 {ParsedAttr::AS_CXX11,
"enzyme_elementwise_read"},
519 {ParsedAttr::AS_CXX11,
"enzyme::elementwise_read"}
// Only FunctionDecls are accepted. Others get
// warn_attribute_wrong_decl_type_str ("functions").
524 bool diagAppertainsToDecl(Sema &S,
const ParsedAttr &Attr,
525 const Decl *D)
const override {
526 if (isa<FunctionDecl>(D))
528 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
529 << Attr <<
"functions";
// Rejects any arguments, then adds the annotation.
533 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
534 const ParsedAttr &Attr)
const override {
535 if (Attr.getNumArgs() != 0) {
536 unsigned ID = S.getDiagnostics().getCustomDiagID(
537 DiagnosticsEngine::Error,
538 "'enzyme_elementwise_read' attribute requires zero arguments");
539 S.Diag(Attr.getLoc(), ID);
540 return AttributeNotApplied;
542 D->addAttr(AnnotateAttr::Create(S.Context,
"enzyme_elementwise_read",
543 nullptr, 0, Attr.getRange()));
544 return AttributeApplied;
// Registers the attribute with clang's ParsedAttrInfoRegistry.
548static ParsedAttrInfoRegistry::Add<EnzymeElementwiseReadAttrInfo>
549 XElemRead(
"enzyme_elementwise_read",
"");
// Custom attribute enzyme_nofree (no arguments, functions or globals). It
// creates a private-extern, used global "__enzyme_nofree_autoreg_<name>"
// holding a pointer to the declaration, and hands that global to the AST
// consumer.
551struct EnzymeNoFreeAttrInfo :
public ParsedAttrInfo {
552 EnzymeNoFreeAttrInfo() {
// Accepted spellings: GNU, C23 (LLVM > 17) or C2x, and two C++11 spellings.
556 static constexpr Spelling S[] = {
557 {ParsedAttr::AS_GNU,
"enzyme_nofree"},
558#if LLVM_VERSION_MAJOR > 17
559 {ParsedAttr::AS_C23,
"enzyme_nofree"},
561 {ParsedAttr::AS_C2x,
"enzyme_nofree"},
563 {ParsedAttr::AS_CXX11,
"enzyme_nofree"},
564 {ParsedAttr::AS_CXX11,
"enzyme::nofree"}
// Accepts functions and variables with global storage.
569 bool diagAppertainsToDecl(Sema &S,
const ParsedAttr &Attr,
570 const Decl *D)
const override {
572 if (isa<FunctionDecl>(D))
574 if (
auto VD = dyn_cast<VarDecl>(D)) {
575 if (VD->hasGlobalStorage())
578 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
579 << Attr <<
"functions and globals";
583 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
584 const ParsedAttr &Attr)
const override {
585 if (Attr.getNumArgs() != 0) {
586 unsigned ID = S.getDiagnostics().getCustomDiagID(
587 DiagnosticsEngine::Error,
588 "'enzyme_nofree' attribute requires zero arguments");
589 S.Diag(Attr.getLoc(), ID);
590 return AttributeNotApplied;
593 auto &AST = S.getASTContext();
// declCtx ends up as the parent of the outermost enclosing record.
594 DeclContext *declCtx = D->getDeclContext();
595 for (
auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
596 if (tmpCtx->isRecord()) {
597 declCtx = tmpCtx->getParent();
600 auto loc = D->getLocation();
// NOTE(review): RD is created and startDefinition() is called, but in the
// visible lines it is never completed or used (V is typed as FT). This looks
// like copy-paste leftover — confirm and remove.
602 if (S.getLangOpts().CPlusPlus)
603 RD = CXXRecordDecl::Create(AST,
StructKind, declCtx, loc, loc,
606 RD = RecordDecl::Create(AST,
StructKind, declCtx, loc, loc,
608 RD->setAnonymousStructOrUnion(
true);
610 RD->startDefinition();
611 auto T = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getType()
612 : cast<VarDecl>(D)->getType();
613 auto Name = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getNameAsString()
614 : cast<VarDecl>(D)->getNameAsString();
615 auto FT = AST.getPointerType(T);
616 auto &Id = AST.Idents.get(
617 (StringRef(
"__enzyme_nofree") +
"_autoreg_" + Name).
str());
618 auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT,
nullptr, SC_None);
619 V->setStorageClass(SC_PrivateExtern);
620 V->addAttr(clang::UsedAttr::CreateImplicit(AST));
621 TemplateArgumentListInfo *TemplateArgs =
nullptr;
622 auto DR = DeclRefExpr::Create(
623 AST, NestedNameSpecifierLoc(), loc, cast<ValueDecl>(D),
false, loc, T,
624 ExprValueKind::VK_LValue, cast<NamedDecl>(D), TemplateArgs);
625 auto rval = ExprValueKind::VK_PRValue;
// The pointer expression: function-to-pointer decay for functions, and the
// address-of operator (UO_AddrOf) for globals.
626 Expr *expr =
nullptr;
627 if (isa<FunctionDecl>(D)) {
629 ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay,
630 DR,
nullptr, rval, FPOptionsOverride());
633 UnaryOperator::Create(AST, DR, UnaryOperatorKind::UO_AddrOf, FT, rval,
634 clang::ExprObjectKind ::OK_Ordinary, loc,
635 false, FPOptionsOverride());
// Templated (value-dependent) contexts are rejected.
638 if (expr->isValueDependent()) {
639 unsigned ID = S.getDiagnostics().getCustomDiagID(
640 DiagnosticsEngine::Error,
"use of attribute 'enzyme_nofree' "
641 "in a templated context not yet supported");
642 S.Diag(Attr.getLoc(), ID);
643 return AttributeNotApplied;
646 S.MarkVariableReferenced(loc, V);
647 S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
648 return AttributeApplied;
// Registers the attribute with clang's ParsedAttrInfoRegistry.
652static ParsedAttrInfoRegistry::Add<EnzymeNoFreeAttrInfo> X5(
"enzyme_nofree",
// Custom attribute enzyme_sparse_accumulate (no arguments, functions only). It
// creates a private-extern, used global
// "__enzyme_sparse_accumulate_autoreg_<fn>" holding a pointer to the function,
// and hands that global to the AST consumer.
655struct EnzymeSparseAccumulateAttrInfo :
public ParsedAttrInfo {
656 EnzymeSparseAccumulateAttrInfo() {
// Accepted spellings: GNU, C23 (LLVM > 17) or C2x, and two C++11 spellings.
660 static constexpr Spelling S[] = {
661 {ParsedAttr::AS_GNU,
"enzyme_sparse_accumulate"},
662#if LLVM_VERSION_MAJOR > 17
663 {ParsedAttr::AS_C23,
"enzyme_sparse_accumulate"},
665 {ParsedAttr::AS_C2x,
"enzyme_sparse_accumulate"},
667 {ParsedAttr::AS_CXX11,
"enzyme_sparse_accumulate"},
668 {ParsedAttr::AS_CXX11,
"enzyme::sparse_accumulate"}
// Only FunctionDecls are accepted. This is why the handler below can use
// cast<FunctionDecl>(D) unconditionally.
673 bool diagAppertainsToDecl(Sema &S,
const ParsedAttr &Attr,
674 const Decl *D)
const override {
676 if (isa<FunctionDecl>(D))
678 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
679 << Attr <<
"functions";
683 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
684 const ParsedAttr &Attr)
const override {
685 if (Attr.getNumArgs() != 0) {
686 unsigned ID = S.getDiagnostics().getCustomDiagID(
687 DiagnosticsEngine::Error,
688 "'enzyme_sparse_accumulate' attribute requires zero arguments");
689 S.Diag(Attr.getLoc(), ID);
690 return AttributeNotApplied;
693 auto &AST = S.getASTContext();
// declCtx ends up as the parent of the outermost enclosing record.
694 DeclContext *declCtx = D->getDeclContext();
695 for (
auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
696 if (tmpCtx->isRecord()) {
697 declCtx = tmpCtx->getParent();
700 auto loc = D->getLocation();
// NOTE(review): RD is created and startDefinition() is called, but in the
// visible lines it is never completed or used (V is typed as FT). This looks
// like copy-paste leftover — confirm and remove.
702 if (S.getLangOpts().CPlusPlus)
703 RD = CXXRecordDecl::Create(AST,
StructKind, declCtx, loc, loc,
706 RD = RecordDecl::Create(AST,
StructKind, declCtx, loc, loc,
708 RD->setAnonymousStructOrUnion(
true);
710 RD->startDefinition();
711 auto T = cast<FunctionDecl>(D)->getType();
712 auto Name = cast<FunctionDecl>(D)->getNameAsString();
713 auto FT = AST.getPointerType(T);
714 auto &Id = AST.Idents.get(
715 (StringRef(
"__enzyme_sparse_accumulate") +
"_autoreg_" + Name).
str());
716 auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT,
nullptr, SC_None);
717 V->setStorageClass(SC_PrivateExtern);
718 V->addAttr(clang::UsedAttr::CreateImplicit(AST));
719 TemplateArgumentListInfo *TemplateArgs =
nullptr;
720 auto DR = DeclRefExpr::Create(
721 AST, NestedNameSpecifierLoc(), loc, cast<ValueDecl>(D),
false, loc, T,
722 ExprValueKind::VK_LValue, cast<NamedDecl>(D), TemplateArgs);
723 auto rval = ExprValueKind::VK_PRValue;
// The pointer expression is the function-to-pointer decay of the function.
724 Expr *expr =
nullptr;
726 ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay,
727 DR,
nullptr, rval, FPOptionsOverride());
// Templated (value-dependent) contexts are rejected.
729 if (expr->isValueDependent()) {
730 unsigned ID = S.getDiagnostics().getCustomDiagID(
731 DiagnosticsEngine::Error,
732 "use of attribute 'enzyme_sparse_accumulate' "
733 "in a templated context not yet supported");
734 S.Diag(Attr.getLoc(), ID);
735 return AttributeNotApplied;
738 S.MarkVariableReferenced(loc, V);
739 S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
740 return AttributeApplied;
// Registers the attribute with clang's ParsedAttrInfoRegistry.
744static ParsedAttrInfoRegistry::Add<EnzymeSparseAccumulateAttrInfo>
745 SparseX(
"enzyme_sparse_accumulate",
"");
// NOTE(review): the lines below are bare declaration signatures with no
// bodies or terminating semicolons. They read like an outline of the
// functions/methods defined elsewhere in this file, not compilable code.
// Confirm they do not belong in the source.
static bool contains(ArrayRef< int > ar, int v)
static clang::FrontendPluginRegistry::Add< EnzymeAction< EnzymePlugin > > X("enzyme", "Enzyme Plugin")
constexpr auto StructKind
void MakeGlobalOfFn(FunctionDecl *FD, CompilerInstance &CI)
void registerEnzyme(llvm::PassBuilder &PB)
static std::string str(AugmentedStruct c)
static bool endsWith(llvm::StringRef string, llvm::StringRef suffix)
bool ParseArgs(const clang::CompilerInstance &CI, const std::vector< std::string > &args) override
std::unique_ptr< clang::ASTConsumer > CreateASTConsumer(clang::CompilerInstance &CI, llvm::StringRef InFile) override
PluginASTAction::ActionType getActionType() override
void HandleTranslationUnit(ASTContext &context) override
EnzymePlugin(clang::CompilerInstance &CI)
bool HandleTopLevelDecl(clang::DeclGroupRef dg) override
void HandleCXXStaticMemberVarInstantiation(clang::VarDecl *V) override
bool VisitFunctionDecl(FunctionDecl *FD)
Visitor(CompilerInstance &CI)