Enzyme main
Loading...
Searching...
No Matches
EnzymeClang.cpp
Go to the documentation of this file.
1//===- EnzymeClang.cpp - Automatic Differentiation Transformation Pass ----===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains a clang plugin for Enzyme.
22//
23//===----------------------------------------------------------------------===//
24
25#include "clang/AST/Attr.h"
26#include "clang/AST/DeclGroup.h"
27#include "clang/AST/RecursiveASTVisitor.h"
28#include "clang/Basic/FileManager.h"
29#include "clang/Basic/MacroBuilder.h"
30#include "clang/Frontend/CompilerInstance.h"
31#include "clang/Frontend/FrontendAction.h"
32#include "clang/Frontend/FrontendPluginRegistry.h"
33#include "clang/Lex/HeaderSearch.h"
34#include "clang/Lex/PreprocessorOptions.h"
35#include "clang/Sema/Sema.h"
36#include "clang/Sema/SemaDiagnostic.h"
37
38#include "../Utils.h"
39
40#include "bundled_includes.h"
41
42using namespace clang;
43
44#if LLVM_VERSION_MAJOR >= 18
45constexpr auto StructKind = clang::TagTypeKind::Struct;
46#else
47constexpr auto StructKind = clang::TagTypeKind::TTK_Struct;
48#endif
49
50template <typename ConsumerType>
51class EnzymeAction final : public clang::PluginASTAction {
52protected:
53 std::unique_ptr<clang::ASTConsumer>
54 CreateASTConsumer(clang::CompilerInstance &CI,
55 llvm::StringRef InFile) override {
56 return std::unique_ptr<clang::ASTConsumer>(new ConsumerType(CI));
57 }
58
59 bool ParseArgs(const clang::CompilerInstance &CI,
60 const std::vector<std::string> &args) override {
61 return true;
62 }
63
64 PluginASTAction::ActionType getActionType() override {
65 return AddBeforeMainAction;
66 }
67};
68
69void MakeGlobalOfFn(FunctionDecl *FD, CompilerInstance &CI) {
70 // if (FD->isLateTemplateParsed()) return;
71 // TODO save any type info into string like attribute
72}
73
74struct Visitor : public RecursiveASTVisitor<Visitor> {
75 CompilerInstance &CI;
76 Visitor(CompilerInstance &CI) : CI(CI) {}
77 bool VisitFunctionDecl(FunctionDecl *FD) {
78 MakeGlobalOfFn(FD, CI);
79 return true;
80 }
81};
82
83#if LLVM_VERSION_MAJOR >= 18
84extern "C" void registerEnzyme(llvm::PassBuilder &PB);
85#endif
86
87class EnzymePlugin final : public clang::ASTConsumer {
88 clang::CompilerInstance &CI;
89
90public:
91 EnzymePlugin(clang::CompilerInstance &CI) : CI(CI) {
92
93 FrontendOptions &Opts = CI.getFrontendOpts();
94 CodeGenOptions &CGOpts = CI.getCodeGenOpts();
95 auto PluginName = "ClangEnzyme-" + std::to_string(LLVM_VERSION_MAJOR);
96 bool contains = false;
97#if LLVM_VERSION_MAJOR < 18
98 std::string pluginPath;
99#endif
100 for (auto P : Opts.Plugins)
101 if (endsWith(llvm::sys::path::stem(P), PluginName)) {
102#if LLVM_VERSION_MAJOR < 18
103 pluginPath = P;
104#endif
105 for (auto passPlugin : CGOpts.PassPlugins) {
106 if (endsWith(llvm::sys::path::stem(passPlugin), PluginName)) {
107 contains = true;
108 break;
109 }
110 }
111 }
112
113 if (!contains) {
114#if LLVM_VERSION_MAJOR >= 18
115 CGOpts.PassBuilderCallbacks.push_back(registerEnzyme);
116#else
117 CGOpts.PassPlugins.push_back(pluginPath);
118#endif
119 }
120 CI.getPreprocessorOpts().Includes.push_back("/enzyme/enzyme/version");
121
122 std::string PredefineBuffer;
123 PredefineBuffer.reserve(4080);
124 llvm::raw_string_ostream Predefines(PredefineBuffer);
125 Predefines << CI.getPreprocessor().getPredefines();
126 MacroBuilder Builder(Predefines);
127 Builder.defineMacro("ENZYME_VERSION_MAJOR",
128 std::to_string(ENZYME_VERSION_MAJOR));
129 Builder.defineMacro("ENZYME_VERSION_MINOR",
130 std::to_string(ENZYME_VERSION_MINOR));
131 Builder.defineMacro("ENZYME_VERSION_PATCH",
132 std::to_string(ENZYME_VERSION_PATCH));
133 CI.getPreprocessor().setPredefines(Predefines.str());
134
135 auto baseFS = &CI.getFileManager().getVirtualFileSystem();
136 llvm::vfs::OverlayFileSystem *fuseFS(
137 new llvm::vfs::OverlayFileSystem(baseFS));
138 IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> fs(
139 new llvm::vfs::InMemoryFileSystem());
140
141 struct tm y2k = {};
142
143 y2k.tm_hour = 0;
144 y2k.tm_min = 0;
145 y2k.tm_sec = 0;
146 y2k.tm_year = 100;
147 y2k.tm_mon = 0;
148 y2k.tm_mday = 1;
149 time_t timer = mktime(&y2k);
150 for (const auto &pair : include_headers) {
151 fs->addFile(StringRef(pair[0]), timer,
152 llvm::MemoryBuffer::getMemBuffer(
153 StringRef(pair[1]), StringRef(pair[0]),
154 /*RequiresNullTerminator*/ true));
155 }
156
157 fuseFS->pushOverlay(fs);
158 fuseFS->pushOverlay(baseFS);
159 CI.getFileManager().setVirtualFileSystem(fuseFS);
160
161 auto DE = CI.getFileManager().getDirectoryRef("/enzymeroot");
162 assert(DE);
163 auto DL = DirectoryLookup(*DE, SrcMgr::C_User,
164 /*isFramework=*/false);
165 CI.getPreprocessor().getHeaderSearchInfo().AddSearchPath(DL,
166 /*isAngled=*/true);
167 }
169 void HandleTranslationUnit(ASTContext &context) override {}
170 bool HandleTopLevelDecl(clang::DeclGroupRef dg) override {
171 using namespace clang;
172 DeclGroupRef::iterator it;
173
174 // Visitor v(CI);
175 // Forcibly require emission of all libdevice
176 for (it = dg.begin(); it != dg.end(); ++it) {
177 // v.TraverseDecl(*it);
178 if (auto FD = dyn_cast<FunctionDecl>(*it)) {
179 if (!FD->hasAttr<clang::CUDADeviceAttr>())
180 continue;
181
182 if (!FD->getIdentifier())
183 continue;
184 if (!StringRef(FD->getLocation().printToString(CI.getSourceManager()))
185 .contains("/__clang_cuda_math.h") &&
186 !StringRef(FD->getLocation().printToString(CI.getSourceManager()))
187 .contains("/__clang_hip_math.h"))
188 continue;
189
190 FD->addAttr(UsedAttr::CreateImplicit(CI.getASTContext()));
191 }
192 if (auto FD = dyn_cast<VarDecl>(*it)) {
194 }
195 }
196 return true;
197 }
198 void HandleCXXStaticMemberVarInstantiation(clang::VarDecl *V) override {
199 if (!V->getIdentifier())
200 return;
201 auto name = V->getName();
202 if (!(name.contains("__enzyme_inactive_global") ||
203 name.contains("__enzyme_inactivefn") ||
204 name.contains("__enzyme_shouldrecompute") ||
205 name.contains("__enzyme_function_like") ||
206 name.contains("__enzyme_allocation_like") ||
207 name.contains("__enzyme_register_gradient") ||
208 name.contains("__enzyme_register_derivative") ||
209 name.contains("__enzyme_register_splitderivative")))
210 return;
211
212 V->addAttr(clang::UsedAttr::CreateImplicit(CI.getASTContext()));
213 return;
214 }
215};
216
217// register the PluginASTAction in the registry.
218static clang::FrontendPluginRegistry::Add<EnzymeAction<EnzymePlugin>>
219 X("enzyme", "Enzyme Plugin");
220
221#if LLVM_VERSION_MAJOR > 10
222namespace {
223
224struct EnzymeFunctionLikeAttrInfo : public ParsedAttrInfo {
225 EnzymeFunctionLikeAttrInfo() {
226 OptArgs = 1;
227 // GNU-style __attribute__(("example")) and C++/C2x-style [[example]] and
228 // [[plugin::example]] supported.
229 static constexpr Spelling S[] = {
230 {ParsedAttr::AS_GNU, "enzyme_function_like"},
231#if LLVM_VERSION_MAJOR > 17
232 {ParsedAttr::AS_C23, "enzyme_function_like"},
233#else
234 {ParsedAttr::AS_C2x, "enzyme_function_like"},
235#endif
236 {ParsedAttr::AS_CXX11, "enzyme_function_like"},
237 {ParsedAttr::AS_CXX11, "enzyme::function_like"}
238 };
239 Spellings = S;
240 }
241
242 bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr,
243 const Decl *D) const override {
244 // This attribute appertains to functions only.
245 if (!isa<FunctionDecl>(D)) {
246 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
247 << Attr << "functions";
248 return false;
249 }
250 return true;
251 }
252
253 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
254 const ParsedAttr &Attr) const override {
255 if (Attr.getNumArgs() != 1) {
256 unsigned ID = S.getDiagnostics().getCustomDiagID(
257 DiagnosticsEngine::Error,
258 "'enzyme_function' attribute requires a single string argument");
259 S.Diag(Attr.getLoc(), ID);
260 return AttributeNotApplied;
261 }
262 auto *Arg0 = Attr.getArgAsExpr(0);
263 StringLiteral *Literal = dyn_cast<StringLiteral>(Arg0->IgnoreParenCasts());
264 if (!Literal) {
265 unsigned ID = S.getDiagnostics().getCustomDiagID(
266 DiagnosticsEngine::Error, "first argument to 'enzyme_function_like' "
267 "attribute must be a string literal");
268 S.Diag(Attr.getLoc(), ID);
269 return AttributeNotApplied;
270 }
271#if LLVM_VERSION_MAJOR >= 12
272 D->addAttr(AnnotateAttr::Create(
273 S.Context, ("enzyme_function_like=" + Literal->getString()).str(),
274 nullptr, 0, Attr.getRange()));
275 return AttributeApplied;
276#else
277 auto FD = cast<FunctionDecl>(D);
278 // if (FD->isLateTemplateParsed()) return;
279 auto &AST = S.getASTContext();
280 DeclContext *declCtx = FD->getDeclContext();
281 for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
282 if (tmpCtx->isRecord()) {
283 declCtx = tmpCtx->getParent();
284 }
285 }
286 auto loc = FD->getLocation();
287 RecordDecl *RD;
288 if (S.getLangOpts().CPlusPlus)
289 RD = CXXRecordDecl::Create(AST, StructKind, declCtx, loc, loc,
290 nullptr); // rId);
291 else
292 RD = RecordDecl::Create(AST, StructKind, declCtx, loc, loc,
293 nullptr); // rId);
294 RD->setAnonymousStructOrUnion(true);
295 RD->setImplicit();
296 RD->startDefinition();
297 auto Tinfo = nullptr;
298 auto Tinfo0 = nullptr;
299 auto FT = AST.getPointerType(FD->getType());
300 auto CharTy = AST.getIntTypeForBitwidth(8, false);
301 auto FD0 = FieldDecl::Create(AST, RD, loc, loc, /*Ud*/ nullptr, FT, Tinfo0,
302 /*expr*/ nullptr, /*mutable*/ true,
303 /*inclassinit*/ ICIS_NoInit);
304 FD0->setAccess(AS_public);
305 RD->addDecl(FD0);
306 auto FD1 = FieldDecl::Create(
307 AST, RD, loc, loc, /*Ud*/ nullptr, AST.getPointerType(CharTy), Tinfo0,
308 /*expr*/ nullptr, /*mutable*/ true, /*inclassinit*/ ICIS_NoInit);
309 FD1->setAccess(AS_public);
310 RD->addDecl(FD1);
311 RD->completeDefinition();
312 assert(RD->getDefinition());
313 auto &Id = AST.Idents.get("__enzyme_function_like_autoreg_" +
314 FD->getNameAsString());
315 auto T = AST.getRecordType(RD);
316 auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, T, Tinfo, SC_None);
317 V->setStorageClass(SC_PrivateExtern);
318 V->addAttr(clang::UsedAttr::CreateImplicit(AST));
319 TemplateArgumentListInfo *TemplateArgs = nullptr;
320 auto DR = DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), loc, FD, false,
321 loc, FD->getType(), ExprValueKind::VK_LValue,
322 FD, TemplateArgs);
323 auto rval = ExprValueKind::VK_PRValue;
324 StringRef cstr = Literal->getString();
325 Expr *exprs[2] = {
326 ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay,
327 DR, nullptr, rval, FPOptionsOverride()),
328 ImplicitCastExpr::Create(
329 AST, AST.getPointerType(CharTy), CastKind::CK_ArrayToPointerDecay,
330 StringLiteral::Create(
331 AST, cstr, stringkind,
332 /*Pascal*/ false,
333 AST.getStringLiteralArrayType(CharTy, cstr.size()), loc),
334 nullptr, rval, FPOptionsOverride())};
335 auto IL = new (AST) InitListExpr(AST, loc, exprs, loc);
336 V->setInit(IL);
337 IL->setType(T);
338 if (IL->isValueDependent()) {
339 unsigned ID = S.getDiagnostics().getCustomDiagID(
340 DiagnosticsEngine::Error, "use of attribute 'enzyme_function_like' "
341 "in a templated context not yet supported");
342 S.Diag(Attr.getLoc(), ID);
343 return AttributeNotApplied;
344 }
345 S.MarkVariableReferenced(loc, V);
346 S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
347 return AttributeApplied;
348#endif
349 }
350};
351
352static ParsedAttrInfoRegistry::Add<EnzymeFunctionLikeAttrInfo>
353 X3("enzyme_function_like", "");
354
355struct EnzymeShouldRecomputeAttrInfo : public ParsedAttrInfo {
356 EnzymeShouldRecomputeAttrInfo() {
357 OptArgs = 1;
358 static constexpr Spelling S[] = {
359 {ParsedAttr::AS_GNU, "enzyme_shouldrecompute"},
360#if LLVM_VERSION_MAJOR > 17
361 {ParsedAttr::AS_C23, "enzyme_shouldrecompute"},
362#else
363 {ParsedAttr::AS_C2x, "enzyme_shouldrecompute"},
364#endif
365 {ParsedAttr::AS_CXX11, "enzyme_shouldrecompute"},
366 {ParsedAttr::AS_CXX11, "enzyme::shouldrecompute"}
367 };
368 Spellings = S;
369 }
370
371 bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr,
372 const Decl *D) const override {
373 // This attribute appertains to functions only.
374 if (isa<FunctionDecl>(D))
375 return true;
376 if (auto VD = dyn_cast<VarDecl>(D)) {
377 if (VD->hasGlobalStorage())
378 return true;
379 }
380 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
381 << Attr << "functions and globals";
382 return false;
383 }
384
385 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
386 const ParsedAttr &Attr) const override {
387 if (Attr.getNumArgs() != 0) {
388 unsigned ID = S.getDiagnostics().getCustomDiagID(
389 DiagnosticsEngine::Error,
390 "'enzyme_inactive' attribute requires zero arguments");
391 S.Diag(Attr.getLoc(), ID);
392 return AttributeNotApplied;
393 }
394 D->addAttr(AnnotateAttr::Create(S.Context, "enzyme_shouldrecompute",
395 nullptr, 0, Attr.getRange()));
396 return AttributeApplied;
397 }
398};
399
400static ParsedAttrInfoRegistry::Add<EnzymeShouldRecomputeAttrInfo>
401 ESR("enzyme_shouldrecompute", "");
402
403struct EnzymeInactiveAttrInfo : public ParsedAttrInfo {
404 EnzymeInactiveAttrInfo() {
405 OptArgs = 1;
406 // GNU-style __attribute__(("example")) and C++/C2x-style [[example]] and
407 // [[plugin::example]] supported.
408 static constexpr Spelling S[] = {
409 {ParsedAttr::AS_GNU, "enzyme_inactive"},
410#if LLVM_VERSION_MAJOR > 17
411 {ParsedAttr::AS_C23, "enzyme_inactive"},
412#else
413 {ParsedAttr::AS_C2x, "enzyme_inactive"},
414#endif
415 {ParsedAttr::AS_CXX11, "enzyme_inactive"},
416 {ParsedAttr::AS_CXX11, "enzyme::inactive"}
417 };
418 Spellings = S;
419 }
420
421 bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr,
422 const Decl *D) const override {
423 // This attribute appertains to functions only.
424 if (isa<FunctionDecl>(D))
425 return true;
426 if (auto VD = dyn_cast<VarDecl>(D)) {
427 if (VD->hasGlobalStorage())
428 return true;
429 }
430 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
431 << Attr << "functions and globals";
432 return false;
433 }
434
435 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
436 const ParsedAttr &Attr) const override {
437 if (Attr.getNumArgs() != 0) {
438 unsigned ID = S.getDiagnostics().getCustomDiagID(
439 DiagnosticsEngine::Error,
440 "'enzyme_inactive' attribute requires zero arguments");
441 S.Diag(Attr.getLoc(), ID);
442 return AttributeNotApplied;
443 }
444
445 auto &AST = S.getASTContext();
446 DeclContext *declCtx = D->getDeclContext();
447 for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
448 if (tmpCtx->isRecord()) {
449 declCtx = tmpCtx->getParent();
450 }
451 }
452 auto loc = D->getLocation();
453 RecordDecl *RD;
454 if (S.getLangOpts().CPlusPlus)
455 RD = CXXRecordDecl::Create(AST, StructKind, declCtx, loc, loc,
456 nullptr); // rId);
457 else
458 RD = RecordDecl::Create(AST, StructKind, declCtx, loc, loc,
459 nullptr); // rId);
460 RD->setAnonymousStructOrUnion(true);
461 RD->setImplicit();
462 RD->startDefinition();
463 auto T = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getType()
464 : cast<VarDecl>(D)->getType();
465 auto Name = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getNameAsString()
466 : cast<VarDecl>(D)->getNameAsString();
467 auto FT = AST.getPointerType(T);
468 auto subname = isa<FunctionDecl>(D) ? "inactivefn" : "inactive_global";
469 auto &Id = AST.Idents.get(
470 (StringRef("__enzyme_") + subname + "_autoreg_" + Name).str());
471 auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT, nullptr, SC_None);
472 V->setStorageClass(SC_PrivateExtern);
473 V->addAttr(clang::UsedAttr::CreateImplicit(AST));
474 TemplateArgumentListInfo *TemplateArgs = nullptr;
475 auto DR = DeclRefExpr::Create(
476 AST, NestedNameSpecifierLoc(), loc, cast<ValueDecl>(D), false, loc, T,
477 ExprValueKind::VK_LValue, cast<NamedDecl>(D), TemplateArgs);
478 auto rval = ExprValueKind::VK_PRValue;
479 Expr *expr = nullptr;
480 if (isa<FunctionDecl>(D)) {
481 expr =
482 ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay,
483 DR, nullptr, rval, FPOptionsOverride());
484 } else {
485 expr =
486 UnaryOperator::Create(AST, DR, UnaryOperatorKind::UO_AddrOf, FT, rval,
487 clang::ExprObjectKind ::OK_Ordinary, loc,
488 /*canoverflow*/ false, FPOptionsOverride());
489 }
490
491 if (expr->isValueDependent()) {
492 unsigned ID = S.getDiagnostics().getCustomDiagID(
493 DiagnosticsEngine::Error, "use of attribute 'enzyme_inactive' "
494 "in a templated context not yet supported");
495 S.Diag(Attr.getLoc(), ID);
496 return AttributeNotApplied;
497 }
498 V->setInit(expr);
499 S.MarkVariableReferenced(loc, V);
500 S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
501 return AttributeApplied;
502 }
503};
504
505static ParsedAttrInfoRegistry::Add<EnzymeInactiveAttrInfo> X4("enzyme_inactive",
506 "");
507
508struct EnzymeElementwiseReadAttrInfo : public ParsedAttrInfo {
509 EnzymeElementwiseReadAttrInfo() {
510 OptArgs = 1;
511 static constexpr Spelling S[] = {
512 {ParsedAttr::AS_GNU, "enzyme_elementwise_read"},
513#if LLVM_VERSION_MAJOR > 17
514 {ParsedAttr::AS_C23, "enzyme_elementwise_read"},
515#else
516 {ParsedAttr::AS_C2x, "enzyme_elementwise_read"},
517#endif
518 {ParsedAttr::AS_CXX11, "enzyme_elementwise_read"},
519 {ParsedAttr::AS_CXX11, "enzyme::elementwise_read"}
520 };
521 Spellings = S;
522 }
523
524 bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr,
525 const Decl *D) const override {
526 if (isa<FunctionDecl>(D))
527 return true;
528 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
529 << Attr << "functions";
530 return false;
531 }
532
533 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
534 const ParsedAttr &Attr) const override {
535 if (Attr.getNumArgs() != 0) {
536 unsigned ID = S.getDiagnostics().getCustomDiagID(
537 DiagnosticsEngine::Error,
538 "'enzyme_elementwise_read' attribute requires zero arguments");
539 S.Diag(Attr.getLoc(), ID);
540 return AttributeNotApplied;
541 }
542 D->addAttr(AnnotateAttr::Create(S.Context, "enzyme_elementwise_read",
543 nullptr, 0, Attr.getRange()));
544 return AttributeApplied;
545 }
546};
547
548static ParsedAttrInfoRegistry::Add<EnzymeElementwiseReadAttrInfo>
549 XElemRead("enzyme_elementwise_read", "");
550
551struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo {
552 EnzymeNoFreeAttrInfo() {
553 OptArgs = 1;
554 // GNU-style __attribute__(("example")) and C++/C2x-style [[example]] and
555 // [[plugin::example]] supported.
556 static constexpr Spelling S[] = {
557 {ParsedAttr::AS_GNU, "enzyme_nofree"},
558#if LLVM_VERSION_MAJOR > 17
559 {ParsedAttr::AS_C23, "enzyme_nofree"},
560#else
561 {ParsedAttr::AS_C2x, "enzyme_nofree"},
562#endif
563 {ParsedAttr::AS_CXX11, "enzyme_nofree"},
564 {ParsedAttr::AS_CXX11, "enzyme::nofree"}
565 };
566 Spellings = S;
567 }
568
569 bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr,
570 const Decl *D) const override {
571 // This attribute appertains to functions only.
572 if (isa<FunctionDecl>(D))
573 return true;
574 if (auto VD = dyn_cast<VarDecl>(D)) {
575 if (VD->hasGlobalStorage())
576 return true;
577 }
578 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
579 << Attr << "functions and globals";
580 return false;
581 }
582
583 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
584 const ParsedAttr &Attr) const override {
585 if (Attr.getNumArgs() != 0) {
586 unsigned ID = S.getDiagnostics().getCustomDiagID(
587 DiagnosticsEngine::Error,
588 "'enzyme_nofree' attribute requires zero arguments");
589 S.Diag(Attr.getLoc(), ID);
590 return AttributeNotApplied;
591 }
592
593 auto &AST = S.getASTContext();
594 DeclContext *declCtx = D->getDeclContext();
595 for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
596 if (tmpCtx->isRecord()) {
597 declCtx = tmpCtx->getParent();
598 }
599 }
600 auto loc = D->getLocation();
601 RecordDecl *RD;
602 if (S.getLangOpts().CPlusPlus)
603 RD = CXXRecordDecl::Create(AST, StructKind, declCtx, loc, loc,
604 nullptr); // rId);
605 else
606 RD = RecordDecl::Create(AST, StructKind, declCtx, loc, loc,
607 nullptr); // rId);
608 RD->setAnonymousStructOrUnion(true);
609 RD->setImplicit();
610 RD->startDefinition();
611 auto T = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getType()
612 : cast<VarDecl>(D)->getType();
613 auto Name = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getNameAsString()
614 : cast<VarDecl>(D)->getNameAsString();
615 auto FT = AST.getPointerType(T);
616 auto &Id = AST.Idents.get(
617 (StringRef("__enzyme_nofree") + "_autoreg_" + Name).str());
618 auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT, nullptr, SC_None);
619 V->setStorageClass(SC_PrivateExtern);
620 V->addAttr(clang::UsedAttr::CreateImplicit(AST));
621 TemplateArgumentListInfo *TemplateArgs = nullptr;
622 auto DR = DeclRefExpr::Create(
623 AST, NestedNameSpecifierLoc(), loc, cast<ValueDecl>(D), false, loc, T,
624 ExprValueKind::VK_LValue, cast<NamedDecl>(D), TemplateArgs);
625 auto rval = ExprValueKind::VK_PRValue;
626 Expr *expr = nullptr;
627 if (isa<FunctionDecl>(D)) {
628 expr =
629 ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay,
630 DR, nullptr, rval, FPOptionsOverride());
631 } else {
632 expr =
633 UnaryOperator::Create(AST, DR, UnaryOperatorKind::UO_AddrOf, FT, rval,
634 clang::ExprObjectKind ::OK_Ordinary, loc,
635 /*canoverflow*/ false, FPOptionsOverride());
636 }
637
638 if (expr->isValueDependent()) {
639 unsigned ID = S.getDiagnostics().getCustomDiagID(
640 DiagnosticsEngine::Error, "use of attribute 'enzyme_nofree' "
641 "in a templated context not yet supported");
642 S.Diag(Attr.getLoc(), ID);
643 return AttributeNotApplied;
644 }
645 V->setInit(expr);
646 S.MarkVariableReferenced(loc, V);
647 S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
648 return AttributeApplied;
649 }
650};
651
652static ParsedAttrInfoRegistry::Add<EnzymeNoFreeAttrInfo> X5("enzyme_nofree",
653 "");
654
655struct EnzymeSparseAccumulateAttrInfo : public ParsedAttrInfo {
656 EnzymeSparseAccumulateAttrInfo() {
657 OptArgs = 1;
658 // GNU-style __attribute__(("example")) and C++/C2x-style [[example]] and
659 // [[plugin::example]] supported.
660 static constexpr Spelling S[] = {
661 {ParsedAttr::AS_GNU, "enzyme_sparse_accumulate"},
662#if LLVM_VERSION_MAJOR > 17
663 {ParsedAttr::AS_C23, "enzyme_sparse_accumulate"},
664#else
665 {ParsedAttr::AS_C2x, "enzyme_sparse_accumulate"},
666#endif
667 {ParsedAttr::AS_CXX11, "enzyme_sparse_accumulate"},
668 {ParsedAttr::AS_CXX11, "enzyme::sparse_accumulate"}
669 };
670 Spellings = S;
671 }
672
673 bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr,
674 const Decl *D) const override {
675 // This attribute appertains to functions only.
676 if (isa<FunctionDecl>(D))
677 return true;
678 S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
679 << Attr << "functions";
680 return false;
681 }
682
683 AttrHandling handleDeclAttribute(Sema &S, Decl *D,
684 const ParsedAttr &Attr) const override {
685 if (Attr.getNumArgs() != 0) {
686 unsigned ID = S.getDiagnostics().getCustomDiagID(
687 DiagnosticsEngine::Error,
688 "'enzyme_sparse_accumulate' attribute requires zero arguments");
689 S.Diag(Attr.getLoc(), ID);
690 return AttributeNotApplied;
691 }
692
693 auto &AST = S.getASTContext();
694 DeclContext *declCtx = D->getDeclContext();
695 for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
696 if (tmpCtx->isRecord()) {
697 declCtx = tmpCtx->getParent();
698 }
699 }
700 auto loc = D->getLocation();
701 RecordDecl *RD;
702 if (S.getLangOpts().CPlusPlus)
703 RD = CXXRecordDecl::Create(AST, StructKind, declCtx, loc, loc,
704 nullptr); // rId);
705 else
706 RD = RecordDecl::Create(AST, StructKind, declCtx, loc, loc,
707 nullptr); // rId);
708 RD->setAnonymousStructOrUnion(true);
709 RD->setImplicit();
710 RD->startDefinition();
711 auto T = cast<FunctionDecl>(D)->getType();
712 auto Name = cast<FunctionDecl>(D)->getNameAsString();
713 auto FT = AST.getPointerType(T);
714 auto &Id = AST.Idents.get(
715 (StringRef("__enzyme_sparse_accumulate") + "_autoreg_" + Name).str());
716 auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT, nullptr, SC_None);
717 V->setStorageClass(SC_PrivateExtern);
718 V->addAttr(clang::UsedAttr::CreateImplicit(AST));
719 TemplateArgumentListInfo *TemplateArgs = nullptr;
720 auto DR = DeclRefExpr::Create(
721 AST, NestedNameSpecifierLoc(), loc, cast<ValueDecl>(D), false, loc, T,
722 ExprValueKind::VK_LValue, cast<NamedDecl>(D), TemplateArgs);
723 auto rval = ExprValueKind::VK_PRValue;
724 Expr *expr = nullptr;
725 expr =
726 ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay,
727 DR, nullptr, rval, FPOptionsOverride());
728
729 if (expr->isValueDependent()) {
730 unsigned ID = S.getDiagnostics().getCustomDiagID(
731 DiagnosticsEngine::Error,
732 "use of attribute 'enzyme_sparse_accumulate' "
733 "in a templated context not yet supported");
734 S.Diag(Attr.getLoc(), ID);
735 return AttributeNotApplied;
736 }
737 V->setInit(expr);
738 S.MarkVariableReferenced(loc, V);
739 S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
740 return AttributeApplied;
741 }
742};
743
744static ParsedAttrInfoRegistry::Add<EnzymeSparseAccumulateAttrInfo>
745 SparseX("enzyme_sparse_accumulate", "");
746} // namespace
747
748#endif
static bool contains(ArrayRef< int > ar, int v)
static clang::FrontendPluginRegistry::Add< EnzymeAction< EnzymePlugin > > X("enzyme", "Enzyme Plugin")
constexpr auto StructKind
void MakeGlobalOfFn(FunctionDecl *FD, CompilerInstance &CI)
void registerEnzyme(llvm::PassBuilder &PB)
Definition Enzyme.cpp:3539
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
static bool endsWith(llvm::StringRef string, llvm::StringRef suffix)
Definition Utils.h:721
bool ParseArgs(const clang::CompilerInstance &CI, const std::vector< std::string > &args) override
std::unique_ptr< clang::ASTConsumer > CreateASTConsumer(clang::CompilerInstance &CI, llvm::StringRef InFile) override
PluginASTAction::ActionType getActionType() override
void HandleTranslationUnit(ASTContext &context) override
EnzymePlugin(clang::CompilerInstance &CI)
bool HandleTopLevelDecl(clang::DeclGroupRef dg) override
void HandleCXXStaticMemberVarInstantiation(clang::VarDecl *V) override
bool VisitFunctionDecl(FunctionDecl *FD)
Visitor(CompilerInstance &CI)
CompilerInstance & CI