Enzyme main
Loading...
Searching...
No Matches
Enzyme.cpp
Go to the documentation of this file.
1//===- Enzyme.cpp - Automatic Differentiation Transformation Pass -------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains Enzyme, a transformation pass that replaces calls
22// to *__enzyme_autodiff* with a call to the derivative of
23// the function passed as the first argument.
24//
25//===----------------------------------------------------------------------===//
26#include <llvm/Config/llvm-config.h>
27#include <memory>
28
29#if LLVM_VERSION_MAJOR >= 16
30#define private public
31#include "llvm/Analysis/ScalarEvolution.h"
32#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
33#undef private
34#else
35#include "SCEV/ScalarEvolution.h"
36#include "SCEV/ScalarEvolutionExpander.h"
37#endif
38
39#include "llvm/ADT/ArrayRef.h"
40#include "llvm/ADT/MapVector.h"
41#include <optional>
42#if LLVM_VERSION_MAJOR <= 16
43#include "llvm/ADT/Optional.h"
44#endif
45#include "llvm/ADT/SetVector.h"
46#include "llvm/ADT/SmallSet.h"
47#include "llvm/ADT/SmallVector.h"
48
49#include "llvm/Passes/PassBuilder.h"
50
51#include "llvm/IR/BasicBlock.h"
52#include "llvm/IR/Constants.h"
53#include "llvm/IR/Function.h"
54#include "llvm/IR/IRBuilder.h"
55#include "llvm/IR/InstrTypes.h"
56#include "llvm/IR/Instructions.h"
57#include "llvm/IR/MDBuilder.h"
58#include "llvm/IR/Metadata.h"
59
60#include "llvm/Analysis/ScalarEvolution.h"
61#include "llvm/Support/Debug.h"
62#include "llvm/Support/ErrorHandling.h"
63#include "llvm/Transforms/Scalar.h"
64
65#include "llvm/Analysis/BasicAliasAnalysis.h"
66#include "llvm/Analysis/GlobalsModRef.h"
67#include "llvm/Analysis/InlineAdvisor.h"
68#include "llvm/Analysis/InlineCost.h"
69#include "llvm/Analysis/ScalarEvolution.h"
70#include "llvm/Analysis/TargetLibraryInfo.h"
71#include "llvm/IR/AbstractCallSite.h"
72#include "llvm/Support/CommandLine.h"
73#include "llvm/Transforms/Utils/BasicBlockUtils.h"
74#include "llvm/Transforms/Utils/Cloning.h"
75
76#include "ActivityAnalysis.h"
77#include "DiffeGradientUtils.h"
78#include "EnzymeLogic.h"
79#include "GradientUtils.h"
80#include "PassUtils.h"
81#include "TraceInterface.h"
82#include "TraceUtils.h"
83#include "Utils.h"
84
85#include "InstructionBatcher.h"
86
87#include "llvm/Transforms/Utils.h"
88
89#include "llvm/Transforms/IPO/Attributor.h"
90#include "llvm/Transforms/IPO/OpenMPOpt.h"
91#include "llvm/Transforms/Utils/Mem2Reg.h"
92
93#include "CApi.h"
94using namespace llvm;
95#ifdef DEBUG_TYPE
96#undef DEBUG_TYPE
97#endif
98#define DEBUG_TYPE "lower-enzyme-intrinsic"
99
100llvm::cl::opt<bool> EnzymeEnable("enzyme-enable", cl::init(true), cl::Hidden,
101 cl::desc("Run the Enzyme pass"));
102
103llvm::cl::opt<bool>
104 EnzymePostOpt("enzyme-postopt", cl::init(false), cl::Hidden,
105 cl::desc("Run enzymepostprocessing optimizations"));
106
107llvm::cl::opt<bool> EnzymeAttributor("enzyme-attributor", cl::init(false),
108 cl::Hidden,
109 cl::desc("Run attributor post Enzyme"));
110
111llvm::cl::opt<bool> EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden,
112 cl::desc("Whether to enable openmp opt"));
113
114llvm::cl::opt<bool> EnzymeDetectReadThrow(
115 "enzyme-detect-readthrow", cl::init(true), cl::Hidden,
116 cl::desc("Run preprocessing detect readonly or throw optimization"));
117
118llvm::cl::opt<std::string> EnzymeTruncateAll(
119 "enzyme-truncate-all", cl::init(""), cl::Hidden,
120 cl::desc(
121 "Truncate all floating point operations. "
122 "E.g. \"64to32\" or \"64to<exponent_width>-<significand_width>\"."));
123
124#define addAttribute addAttributeAtIndex
125#define getAttribute getAttributeAtIndex
126
127namespace {
128static Value *
129castToDiffeFunctionArgType(IRBuilder<> &Builder, llvm::CallInst *CI,
130 llvm::FunctionType *FT, llvm::Type *destType,
131 unsigned int i, DerivativeMode mode,
132 llvm::Value *value, unsigned int truei) {
133 auto res = value;
134 if (auto ptr = dyn_cast<PointerType>(res->getType())) {
135 if (auto PT = dyn_cast<PointerType>(destType)) {
136 if (ptr->getAddressSpace() != PT->getAddressSpace()) {
137#if LLVM_VERSION_MAJOR < 17
138 if (CI->getContext().supportsTypedPointers()) {
139 res = Builder.CreateAddrSpaceCast(
140 res, PointerType::get(ptr->getPointerElementType(),
141 PT->getAddressSpace()));
142 } else {
143 res = Builder.CreateAddrSpaceCast(res, PT);
144 }
145#else
146 res = Builder.CreateAddrSpaceCast(res, PT);
147#endif
148 assert(value);
149 assert(destType);
150 assert(FT);
151 llvm::errs() << "Warning cast(2) __enzyme_autodiff argument " << i
152 << " " << *res << "|" << *res->getType() << " to argument "
153 << truei << " " << *destType << "\n"
154 << "orig: " << *FT << "\n";
155 return res;
156 }
157 }
158 }
159
160 if (!res->getType()->canLosslesslyBitCastTo(destType)) {
161 assert(value);
162 assert(value->getType());
163 assert(destType);
164 assert(FT);
165 auto loc = CI->getDebugLoc();
166 if (auto arg = dyn_cast<Instruction>(res)) {
167 loc = arg->getDebugLoc();
168 }
169 EmitFailure("IllegalArgCast", loc, CI,
170 "Cannot cast __enzyme_autodiff shadow argument ", i, ", found ",
171 *res, ", type ", *res->getType(), " - to arg ", truei, " ",
172 *destType);
173 return nullptr;
174 }
175 return Builder.CreateBitCast(value, destType);
176}
177
178#if LLVM_VERSION_MAJOR > 16
179static std::optional<StringRef> getMetadataName(llvm::Value *res);
180#else
181static Optional<StringRef> getMetadataName(llvm::Value *res);
182#endif
183
184// if all phi arms are (recursively) based on the same metaString, use that
185#if LLVM_VERSION_MAJOR > 16
186static std::optional<StringRef> recursePhiReads(PHINode *val)
187#else
188static Optional<StringRef> recursePhiReads(PHINode *val)
189#endif
190{
191#if LLVM_VERSION_MAJOR > 16
192 std::optional<StringRef> finalMetadata;
193#else
194 Optional<StringRef> finalMetadata;
195#endif
196 SmallVector<PHINode *, 1> todo = {val};
197 SmallSet<PHINode *, 1> done;
198 while (todo.size()) {
199 auto phiInst = todo.back();
200 todo.pop_back();
201 if (done.count(phiInst))
202 continue;
203 done.insert(phiInst);
204 for (unsigned j = 0; j < phiInst->getNumIncomingValues(); ++j) {
205 auto newVal = phiInst->getIncomingValue(j);
206 if (auto phi = dyn_cast<PHINode>(newVal)) {
207 todo.push_back(phi);
208 } else {
209 auto metaString = getMetadataName(newVal);
210 if (metaString) {
211 if (!finalMetadata) {
212 finalMetadata = metaString;
213 } else if (finalMetadata != metaString) {
214 return {};
215 }
216 }
217 }
218 }
219 }
220 return finalMetadata;
221}
222
223#if LLVM_VERSION_MAJOR > 16
224std::optional<StringRef> getMetadataName(llvm::Value *res)
225#else
226Optional<StringRef> getMetadataName(llvm::Value *res)
227#endif
228{
229 if (auto S = simplifyLoad(res))
230 return getMetadataName(S);
231
232 if (auto av = dyn_cast<MetadataAsValue>(res)) {
233 return cast<MDString>(av->getMetadata())->getString();
234 } else if ((isa<LoadInst>(res) || isa<CastInst>(res)) &&
235 isa<GlobalVariable>(cast<Instruction>(res)->getOperand(0))) {
236 GlobalVariable *gv =
237 cast<GlobalVariable>(cast<Instruction>(res)->getOperand(0));
238 return gv->getName();
239 } else if (isa<LoadInst>(res) &&
240 isa<ConstantExpr>(cast<LoadInst>(res)->getOperand(0)) &&
241 cast<ConstantExpr>(cast<LoadInst>(res)->getOperand(0))->isCast() &&
242 isa<GlobalVariable>(
243 cast<ConstantExpr>(cast<LoadInst>(res)->getOperand(0))
244 ->getOperand(0))) {
245 auto gv = cast<GlobalVariable>(
246 cast<ConstantExpr>(cast<LoadInst>(res)->getOperand(0))->getOperand(0));
247 return gv->getName();
248 } else if (auto gv = dyn_cast<GlobalVariable>(res)) {
249 return gv->getName();
250 } else if (isa<ConstantExpr>(res) && cast<ConstantExpr>(res)->isCast() &&
251 isa<GlobalVariable>(cast<ConstantExpr>(res)->getOperand(0))) {
252 auto gv = cast<GlobalVariable>(cast<ConstantExpr>(res)->getOperand(0));
253 return gv->getName();
254 } else if (isa<CastInst>(res) && cast<CastInst>(res) &&
255 isa<AllocaInst>(cast<CastInst>(res)->getOperand(0))) {
256 auto gv = cast<AllocaInst>(cast<CastInst>(res)->getOperand(0));
257 return gv->getName();
258 } else if (auto gv = dyn_cast<AllocaInst>(res)) {
259 return gv->getName();
260 } else if (isa<PHINode>(res)) {
261 return recursePhiReads(cast<PHINode>(res));
262 }
263
264 return {};
265}
266
267static Value *adaptReturnedVector(Value *ret, Value *diffret,
268 IRBuilder<> &Builder, unsigned width) {
269 Type *returnType = ret->getType();
270
271 if (StructType *sty = dyn_cast<StructType>(returnType)) {
272 Value *agg = ConstantAggregateZero::get(sty);
273
274 for (unsigned int i = 0; i < width; i++) {
275 Value *elem = Builder.CreateExtractValue(diffret, {i});
276 if (auto vty = dyn_cast<FixedVectorType>(elem->getType())) {
277 for (unsigned j = 0; j < vty->getNumElements(); ++j) {
278 Value *vecelem = Builder.CreateExtractElement(elem, j);
279 agg = Builder.CreateInsertValue(agg, vecelem, {i * j});
280 }
281 } else {
282 agg = Builder.CreateInsertValue(agg, elem, {i});
283 }
284 }
285 diffret = agg;
286 }
287 return diffret;
288}
289
// Replaces the Enzyme call `CI` with the derivative result `diffret`,
// converting it to the type of `ret` where possible, and erases `CI`.
//
// Parameters:
//   Builder      - builder used to emit the conversion instructions.
//   ret          - value whose type is the type expected by the caller.
//   retElemType  - type substituted for retType when `ret` is a pointer.
//   diffret      - value produced by the generated derivative.
//   CI           - instruction being replaced.
//   mode         - derivative mode; consulted by the later fallbacks.
//
// Returns true when `CI` was replaced and erased; returns false after
// emitting an "IllegalReturnCast" failure when no conversion applies.
//
// NOTE(review): this listing is missing source lines 345, 348 and 360 (the
// numbering jumps over them), so two `if` conditions below are incomplete
// as shown. Restore them from the real source file before relying on this
// text.
290static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret,
291 Type *retElemType, Value *diffret,
292 Instruction *CI, DerivativeMode mode) {
293 Type *retType = ret->getType();
294 Type *diffretType = diffret->getType();
295 auto &DL = CI->getModule()->getDataLayout();
296
// Nothing meaningful to forward: uses of the call become undef.
297 if (diffretType->isEmptyTy() || diffretType->isVoidTy() ||
298 retType->isEmptyTy() || retType->isVoidTy()) {
299 CI->replaceAllUsesWith(UndefValue::get(CI->getType()));
300 CI->eraseFromParent();
301 return true;
302 }
303
// Identical types: forward the derivative result directly.
304 if (retType == diffretType) {
305 CI->replaceAllUsesWith(diffret);
306 CI->eraseFromParent();
307 return true;
308 }
309
// Distinct struct types with identical layout: rebuild field by field.
310 if (auto sretType = dyn_cast<StructType>(retType),
311 diffsretType = dyn_cast<StructType>(diffretType);
312 sretType && diffsretType && sretType->isLayoutIdentical(diffsretType)) {
313 Value *newStruct = UndefValue::get(sretType);
314 for (unsigned int i = 0; i < sretType->getStructNumElements(); i++) {
315 Value *elem = Builder.CreateExtractValue(diffret, {i});
316 newStruct = Builder.CreateInsertValue(newStruct, elem, {i});
317 }
318 CI->replaceAllUsesWith(newStruct);
319 CI->eraseFromParent();
320 return true;
321 }
322
// `ret` is a pointer: the result is stored through it instead of replacing
// uses. From here on retType refers to retElemType.
323 if (isa<PointerType>(retType)) {
324 retType = retElemType;
325
326 if (auto sretType = dyn_cast<StructType>(retType),
327 diffsretType = dyn_cast<StructType>(diffretType);
328 sretType && diffsretType && sretType->isLayoutIdentical(diffsretType)) {
329 for (unsigned int i = 0; i < sretType->getStructNumElements(); i++) {
330 Value *sgep = Builder.CreateStructGEP(retType, ret, i);
331 Builder.CreateStore(Builder.CreateExtractValue(diffret, {i}), sgep);
332 }
333 CI->eraseFromParent();
334 return true;
335 }
336
// The destination is at least as large as the result: store the result
// through a pointer cast to its own type.
337 if (DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) {
338 Builder.CreateStore(
339 diffret, Builder.CreatePointerCast(ret, getUnqual(diffretType)));
340 CI->eraseFromParent();
341 return true;
342 }
343 }
344
// NOTE(review): source line 345 (the opening of this `if` condition) and
// line 348 are absent from the listing. The visible part compares type
// sizes and tests for DerivativeMode::ForwardMode; the body round-trips the
// result through an alloca created in the entry block to reinterpret it as
// retType.
346 DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) ||
347 ((mode == DerivativeMode::ForwardMode ||
349 DL.getTypeSizeInBits(retType) == DL.getTypeSizeInBits(diffretType))) {
350 IRBuilder<> EB(CI->getFunction()->getEntryBlock().getFirstNonPHI());
351 auto AL = EB.CreateAlloca(retType);
352 Builder.CreateStore(diffret,
353 Builder.CreatePointerCast(AL, getUnqual(diffretType)));
354 Value *cload = Builder.CreateLoad(retType, AL);
355 CI->replaceAllUsesWith(cload);
356 CI->eraseFromParent();
357 return true;
358 }
359
// NOTE(review): source line 360 (the opening of this `if` condition) is
// absent from the listing. For an aggregate result, element 0 is tried as
// the replacement value.
361 diffret->getType()->isAggregateType()) {
362 auto diffreti = Builder.CreateExtractValue(diffret, {0});
363 if (diffreti->getType() == retType) {
364 CI->replaceAllUsesWith(diffreti);
365 CI->eraseFromParent();
366 return true;
367 } else if (diffretType == retType) {
368 CI->replaceAllUsesWith(diffret);
369 CI->eraseFromParent();
370 return true;
371 }
372 }
373
// No conversion applied: report both types and their bit sizes.
374 auto diffretsize = DL.getTypeSizeInBits(diffretType);
375 auto retsize = DL.getTypeSizeInBits(retType);
376 EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI,
377 "Cannot cast return type of gradient ", *diffretType, *diffret,
378 " of size ", diffretsize, " bits ", ", to desired type ",
379 *retType, " of size ", retsize, " bits");
380 return false;
381}
382
383class EnzymeBase {
384public:
385 EnzymeLogic Logic;
386 EnzymeBase(bool PostOpt)
387 : Logic(EnzymePostOpt.getNumOccurrences() ? EnzymePostOpt : PostOpt) {
388 // initializeLowerAutodiffIntrinsicPass(*PassRegistry::getPassRegistry());
389 }
390
391 Function *parseFunctionParameter(CallInst *CI) {
392 Value *fn = CI->getArgOperand(0);
393
394 // determine function to differentiate
395 if (CI->hasStructRetAttr()) {
396 fn = CI->getArgOperand(1);
397 }
398
399 Value *ofn = fn;
400 fn = GetFunctionFromValue(fn);
401
402 if (!fn || !isa<Function>(fn)) {
403 assert(ofn);
404 EmitFailure("NoFunctionToDifferentiate", CI->getDebugLoc(), CI,
405 "failed to find fn to differentiate", *CI, " - found - ",
406 *ofn);
407 return nullptr;
408 }
409 if (cast<Function>(fn)->empty()) {
410 EmitFailure("EmptyFunctionToDifferentiate", CI->getDebugLoc(), CI,
411 "failed to find fn to differentiate", *CI, " - found - ",
412 *fn);
413 return nullptr;
414 }
415
416 return cast<Function>(fn);
417 }
418
419#if LLVM_VERSION_MAJOR > 16
420 static std::optional<unsigned> parseWidthParameter(CallInst *CI)
421#else
422 static Optional<unsigned> parseWidthParameter(CallInst *CI)
423#endif
424 {
425 unsigned width = 1;
426
427 for (auto [i, found] = std::tuple{0u, false}; i < CI->arg_size(); ++i) {
428 Value *arg = CI->getArgOperand(i);
429
430 if (auto MDName = getMetadataName(arg)) {
431 if (*MDName == "enzyme_width") {
432 if (found) {
433 EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
434 "vector width declared more than once",
435 *CI->getArgOperand(i), " in", *CI);
436 return {};
437 }
438
439 if (i + 1 >= CI->arg_size()) {
440 EmitFailure("MissingVectorWidth", CI->getDebugLoc(), CI,
441 "constant integer followong enzyme_width is missing",
442 *CI->getArgOperand(i), " in", *CI);
443 return {};
444 }
445
446 Value *width_arg = CI->getArgOperand(i + 1);
447 if (auto cint = dyn_cast<ConstantInt>(width_arg)) {
448 width = cint->getZExtValue();
449 found = true;
450 } else {
451 EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
452 "enzyme_width must be a constant integer",
453 *CI->getArgOperand(i), " in", *CI);
454 return {};
455 }
456
457 if (!found) {
458 EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
459 "illegal enzyme vector argument width ",
460 *CI->getArgOperand(i), " in", *CI);
461 return {};
462 }
463 }
464 }
465 }
466 return width;
467 }
468
469 struct Options {
470 Value *differet;
471 Value *tape;
472 Value *dynamic_interface;
473 Value *trace;
474 Value *observations;
475 Value *likelihood;
476 Value *diffeLikelihood;
477 unsigned width;
478 int allocatedTapeSize;
479 bool freeMemory;
480 bool returnUsed;
481 bool tapeIsPointer;
482 bool differentialReturn;
483 bool diffeTrace;
484 DIFFE_TYPE retType;
485 bool primalReturn;
486 StringSet<> ActiveRandomVariables;
487 std::vector<bool> overwritten_args;
488 bool runtimeActivity;
489 bool strongZero;
490 bool subsequent_calls_may_write;
491 };
492
493#if LLVM_VERSION_MAJOR > 16
494 static std::optional<Options>
495 handleArguments(IRBuilder<> &Builder, CallInst *CI, Function *fn,
496 DerivativeMode mode, bool sizeOnly,
497 std::vector<DIFFE_TYPE> &constants,
498 SmallVectorImpl<Value *> &args, std::map<int, Type *> &byVal)
499#else
500 static Optional<Options>
501 handleArguments(IRBuilder<> &Builder, CallInst *CI, Function *fn,
502 DerivativeMode mode, bool sizeOnly,
503 std::vector<DIFFE_TYPE> &constants,
504 SmallVectorImpl<Value *> &args, std::map<int, Type *> &byVal)
505#endif
506 {
507 FunctionType *FT = fn->getFunctionType();
508
509 Value *differet = nullptr;
510 Value *tape = nullptr;
511 Value *dynamic_interface = nullptr;
512 Value *trace = nullptr;
513 Value *observations = nullptr;
514 Value *likelihood = nullptr;
515 Value *diffeLikelihood = nullptr;
516 unsigned width = 1;
517 int allocatedTapeSize = -1;
518 bool freeMemory = true;
519 bool tapeIsPointer = false;
520 bool diffeTrace = false;
521 unsigned truei = 0;
522 unsigned byRefSize = 0;
523 bool primalReturn = false;
524 bool runtimeActivity = false;
525 bool strongZero = false;
526 bool subsequent_calls_may_write =
530 StringSet<> ActiveRandomVariables;
531
532 DIFFE_TYPE retType = whatType(fn->getReturnType(), mode);
533
534 if (fn->hasParamAttribute(0, Attribute::StructRet)) {
535 Type *Ty = nullptr;
536 Ty = fn->getParamAttribute(0, Attribute::StructRet).getValueAsType();
537 if (whatType(Ty, mode) != DIFFE_TYPE::CONSTANT) {
538 retType = DIFFE_TYPE::DUP_ARG;
539 }
540 }
541
542 bool returnUsed =
543 !fn->getReturnType()->isVoidTy() && !fn->getReturnType()->isEmptyTy();
544
545 bool sret = CI->hasStructRetAttr() ||
546 fn->hasParamAttribute(0, Attribute::StructRet);
547
548 std::vector<bool> overwritten_args(
549 fn->getFunctionType()->getNumParams(),
551
552 for (unsigned i = 1 + sret; i < CI->arg_size(); ++i) {
553 Value *res = CI->getArgOperand(i);
554 auto metaString = getMetadataName(res);
555 // handle metadata
556 if (metaString && startsWith(*metaString, "enzyme_")) {
557 if (*metaString == "enzyme_const_return") {
558 retType = DIFFE_TYPE::CONSTANT;
559 continue;
560 } else if (*metaString == "enzyme_active_return") {
561 retType = DIFFE_TYPE::OUT_DIFF;
562 continue;
563 } else if (*metaString == "enzyme_dup_return") {
564 retType = DIFFE_TYPE::DUP_ARG;
565 continue;
566 } else if (*metaString == "enzyme_noret") {
567 returnUsed = false;
568 continue;
569 } else if (*metaString == "enzyme_primal_return") {
570 primalReturn = true;
571 continue;
572 }
573 }
574 }
575 bool differentialReturn = (mode == DerivativeMode::ReverseModeCombined ||
577 (retType == DIFFE_TYPE::OUT_DIFF);
578
579 // find and handle enzyme_width
580 if (auto parsedWidth = parseWidthParameter(CI)) {
581 width = *parsedWidth;
582 } else {
583 return {};
584 }
585
586 // handle different argument order for struct return.
587 if (fn->hasParamAttribute(0, Attribute::StructRet)) {
588 truei = 1;
589
590 const DataLayout &DL = CI->getParent()->getModule()->getDataLayout();
591 Type *Ty = nullptr;
592 Ty = fn->getParamAttribute(0, Attribute::StructRet).getValueAsType();
593 Type *CTy = nullptr;
594 CTy = CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
595 .getValueAsType();
596 auto FnSize = (DL.getTypeSizeInBits(Ty) / 8);
597 auto CSize = CTy ? (DL.getTypeSizeInBits(CTy) / 8) : 0;
598 auto count = ((mode == DerivativeMode::ForwardMode ||
601 (retType == DIFFE_TYPE::DUP_ARG ||
602 retType == DIFFE_TYPE::DUP_NONEED)) *
603 width +
604 primalReturn;
605 if (CSize < count * FnSize) {
607 "IllegalByRefSize", CI->getDebugLoc(), CI, "Struct return type ",
608 *CTy, " (", CSize, " bytes), not large enough to store ", count,
609 " returns of type ", *Ty, " (", FnSize, " bytes), width=", width,
610 " primal requested=", primalReturn);
611 }
612 Value *primal = nullptr;
613 if (primalReturn) {
614 Value *sretPt = CI->getArgOperand(0);
615 PointerType *pty = cast<PointerType>(sretPt->getType());
616 primal = Builder.CreatePointerCast(
617 sretPt, PointerType::get(Ty, pty->getAddressSpace()));
618 } else {
619 AllocaInst *primalA = new AllocaInst(Ty, DL.getAllocaAddrSpace(),
620 nullptr, DL.getPrefTypeAlign(Ty));
621 primalA->insertBefore(CI);
622 primal = primalA;
623 }
624
625 Value *shadow = nullptr;
626 switch (mode) {
630 if (retType != DIFFE_TYPE::CONSTANT) {
631 Value *sretPt = CI->getArgOperand(0);
632 PointerType *pty = cast<PointerType>(sretPt->getType());
633 auto shadowPtr = Builder.CreatePointerCast(
634 sretPt, PointerType::get(Ty, pty->getAddressSpace()));
635 if (width == 1) {
636 if (primalReturn)
637 shadowPtr = Builder.CreateConstGEP1_64(Ty, shadowPtr, 1);
638 shadow = shadowPtr;
639 } else {
640 Value *acc = UndefValue::get(ArrayType::get(
641 PointerType::get(Ty, pty->getAddressSpace()), width));
642 for (size_t i = 0; i < width; ++i) {
643 Value *elem =
644 Builder.CreateConstGEP1_64(Ty, shadowPtr, i + primalReturn);
645 acc = Builder.CreateInsertValue(acc, elem, i);
646 }
647 shadow = acc;
648 }
649 }
650 break;
651 }
655 if (retType != DIFFE_TYPE::CONSTANT)
656 shadow = CI->getArgOperand(1);
657 sret = true;
658 break;
659 }
660 }
661
662 args.push_back(primal);
663 if (retType != DIFFE_TYPE::CONSTANT)
664 args.push_back(shadow);
665 if (retType == DIFFE_TYPE::DUP_ARG && !primalReturn && isWriteOnly(fn, 0))
666 retType = DIFFE_TYPE::DUP_NONEED;
667 constants.push_back(retType);
668 retType = DIFFE_TYPE::CONSTANT;
669 primalReturn = false;
670 }
671
672 ssize_t interleaved = -1;
673
674 size_t maxsize;
675 maxsize = CI->arg_size();
676 size_t num_args = maxsize;
677 for (unsigned i = 1 + sret; i < maxsize; ++i) {
678 Value *res = CI->getArgOperand(i);
679 auto metaString = getMetadataName(res);
680 if (metaString && startsWith(*metaString, "enzyme_")) {
681 if (*metaString == "enzyme_interleave") {
682 maxsize = i;
683 interleaved = i + 1;
684 break;
685 }
686 }
687 }
688
690
691 for (ssize_t i = 1 + sret; (size_t)i < maxsize; ++i) {
692 Value *res = CI->getArgOperand(i);
693 auto metaString = getMetadataName(res);
694#if LLVM_VERSION_MAJOR > 16
695 std::optional<Value *> batchOffset;
696 std::optional<DIFFE_TYPE> opt_ty;
697#else
698 Optional<Value *> batchOffset;
699 Optional<DIFFE_TYPE> opt_ty;
700#endif
701
702 bool overwritten = !(mode == DerivativeMode::ReverseModeCombined);
703
704 bool skipArg = false;
705
706 // handle metadata
707 while (metaString && startsWith(*metaString, "enzyme_")) {
708 if (*metaString == "enzyme_not_overwritten") {
709 overwritten = false;
710 } else if (*metaString == "enzyme_byref") {
711 ++i;
712 if (!isa<ConstantInt>(CI->getArgOperand(i))) {
713 EmitFailure("IllegalAllocatedSize", CI->getDebugLoc(), CI,
714 "illegal enzyme byref size ", *CI->getArgOperand(i),
715 "in", *CI);
716 return {};
717 }
718 byRefSize = cast<ConstantInt>(CI->getArgOperand(i))->getZExtValue();
719 assert(byRefSize > 0);
720 skipArg = true;
721 break;
722 } else if (*metaString == "enzyme_dup") {
723 opt_ty = DIFFE_TYPE::DUP_ARG;
724 } else if (*metaString == "enzyme_dupv") {
725 opt_ty = DIFFE_TYPE::DUP_ARG;
726 ++i;
727 Value *offset_arg = CI->getArgOperand(i);
728 if (offset_arg->getType()->isIntegerTy()) {
729 batchOffset = offset_arg;
730 } else {
731 EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI,
732 "enzyme_batch must be followd by an integer "
733 "offset.",
734 *CI->getArgOperand(i), " in", *CI);
735 return {};
736 }
737 } else if (*metaString == "enzyme_dupnoneed") {
738 opt_ty = DIFFE_TYPE::DUP_NONEED;
739 } else if (*metaString == "enzyme_dupnoneedv") {
740 opt_ty = DIFFE_TYPE::DUP_NONEED;
741 ++i;
742 Value *offset_arg = CI->getArgOperand(i);
743 if (offset_arg->getType()->isIntegerTy()) {
744 batchOffset = offset_arg;
745 } else {
746 EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI,
747 "enzyme_batch must be followd by an integer "
748 "offset.",
749 *CI->getArgOperand(i), " in", *CI);
750 return {};
751 }
752 } else if (*metaString == "enzyme_out") {
753 opt_ty = DIFFE_TYPE::OUT_DIFF;
754 } else if (*metaString == "enzyme_const") {
755 opt_ty = DIFFE_TYPE::CONSTANT;
756 } else if (*metaString == "enzyme_noret") {
757 skipArg = true;
758 break;
759 } else if (*metaString == "enzyme_allocated") {
760 assert(!sizeOnly);
761 ++i;
762 if (!isa<ConstantInt>(CI->getArgOperand(i))) {
763 EmitFailure("IllegalAllocatedSize", CI->getDebugLoc(), CI,
764 "illegal enzyme allocated size ", *CI->getArgOperand(i),
765 "in", *CI);
766 return {};
767 }
768 allocatedTapeSize =
769 cast<ConstantInt>(CI->getArgOperand(i))->getZExtValue();
770 skipArg = true;
771 break;
772 } else if (*metaString == "enzyme_tape") {
773 assert(!sizeOnly);
774 ++i;
775 tape = CI->getArgOperand(i);
776 tapeIsPointer = true;
777 skipArg = true;
778 break;
779 } else if (*metaString == "enzyme_nofree") {
780 assert(!sizeOnly);
781 freeMemory = false;
782 skipArg = true;
783 break;
784 } else if (*metaString == "enzyme_runtime_activity") {
785 runtimeActivity = true;
786 skipArg = true;
787 break;
788 } else if (*metaString == "enzyme_strong_zero") {
789 strongZero = true;
790 skipArg = true;
791 break;
792 } else if (*metaString == "enzyme_primal_return") {
793 skipArg = true;
794 break;
795 } else if (*metaString == "enzyme_const_return") {
796 skipArg = true;
797 break;
798 } else if (*metaString == "enzyme_active_return") {
799 skipArg = true;
800 break;
801 } else if (*metaString == "enzyme_dup_return") {
802 skipArg = true;
803 break;
804 } else if (*metaString == "enzyme_width") {
805 ++i;
806 skipArg = true;
807 break;
808 } else if (*metaString == "enzyme_interface") {
809 ++i;
810 dynamic_interface = CI->getArgOperand(i);
811 skipArg = true;
812 break;
813 } else if (*metaString == "enzyme_trace") {
814 trace = CI->getArgOperand(++i);
815 opt_ty = DIFFE_TYPE::CONSTANT;
816 skipArg = true;
817 break;
818 } else if (*metaString == "enzyme_duptrace") {
819 trace = CI->getArgOperand(++i);
820 diffeTrace = true;
821 opt_ty = DIFFE_TYPE::CONSTANT;
822 skipArg = true;
823 break;
824 } else if (*metaString == "enzyme_likelihood") {
825 likelihood = CI->getArgOperand(++i);
826 opt_ty = DIFFE_TYPE::CONSTANT;
827 skipArg = true;
828 break;
829 } else if (*metaString == "enzyme_duplikelihood") {
830 likelihood = CI->getArgOperand(++i);
831 diffeLikelihood = CI->getArgOperand(++i);
832 opt_ty = DIFFE_TYPE::DUP_ARG;
833 skipArg = true;
834 break;
835 } else if (*metaString == "enzyme_observations") {
836 observations = CI->getArgOperand(++i);
837 opt_ty = DIFFE_TYPE::CONSTANT;
838 skipArg = true;
839 break;
840 } else if (*metaString == "enzyme_active_rand_var") {
841 Value *string = CI->getArgOperand(++i);
842 StringRef const_string;
843 if (getConstantStringInfo(string, const_string)) {
844 ActiveRandomVariables.insert(const_string);
845 } else {
847 "IllegalStringType", CI->getDebugLoc(), CI,
848 "active variable address must be a compile-time constant", *CI,
849 *metaString);
850 }
851 skipArg = true;
852 break;
853 } else {
854 EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI,
855 "illegal enzyme metadata classification ", *CI,
856 *metaString);
857 return {};
858 }
859 if (sizeOnly) {
860 assert(opt_ty);
861 constants.push_back(*opt_ty);
862 truei++;
863 skipArg = true;
864 break;
865 }
866 ++i;
867 if (i == CI->arg_size()) {
868 EmitFailure("EnzymeCallingError", CI->getDebugLoc(), CI,
869 "Too few arguments to Enzyme call ", *CI);
870 return {};
871 }
872 res = CI->getArgOperand(i);
873 metaString = getMetadataName(res);
874 }
875
876 if (skipArg)
877 continue;
878
879 if (byRefSize) {
880 Type *subTy = nullptr;
881 if (truei < FT->getNumParams()) {
882 subTy = FT->getParamType(i);
883 } else if ((mode == DerivativeMode::ReverseModeGradient ||
885 if (differentialReturn && differet == nullptr) {
886 subTy = FT->getReturnType();
887 }
888 }
889
890 if (!subTy) {
891 EmitFailure("IllegalByVal", CI->getDebugLoc(), CI,
892 "illegal enzyme byval arg", truei, " ", *res);
893 return {};
894 }
895
896 auto &DL = fn->getParent()->getDataLayout();
897 auto BitSize = DL.getTypeSizeInBits(subTy);
898 if (BitSize / 8 != byRefSize) {
899 EmitFailure("IllegalByRefSize", CI->getDebugLoc(), CI,
900 "illegal enzyme pointer type size ", *res, " expected ",
901 byRefSize, " (bytes) actual size ", BitSize,
902 " (bits) in ", *CI);
903 }
904 res = Builder.CreateBitCast(
905 res,
906 PointerType::get(
907 subTy, cast<PointerType>(res->getType())->getAddressSpace()));
908 res = Builder.CreateLoad(subTy, res);
909 byRefSize = 0;
910 }
911
912 if (truei >= FT->getNumParams()) {
913 if (!isa<MetadataAsValue>(res) &&
916 if (differentialReturn && differet == nullptr) {
917 differet = res;
918 if (CI->paramHasAttr(i, Attribute::ByVal)) {
919 Type *T = nullptr;
920 T = CI->getParamAttr(i, Attribute::ByVal).getValueAsType();
921 differet = Builder.CreateLoad(T, differet);
922 }
923 if (differet->getType() != fn->getReturnType())
924 if (auto ST0 = dyn_cast<StructType>(differet->getType()))
925 if (auto ST1 = dyn_cast<StructType>(fn->getReturnType()))
926 if (ST0->isLayoutIdentical(ST1)) {
927 IRBuilder<> B(&Builder.GetInsertBlock()
928 ->getParent()
929 ->getEntryBlock()
930 .front());
931 auto AI = B.CreateAlloca(ST1);
932 Builder.CreateStore(differet, Builder.CreatePointerCast(
933 AI, getUnqual(ST0)));
934 differet = Builder.CreateLoad(ST1, AI);
935 }
936
937 if (differet->getType() !=
938 GradientUtils::getShadowType(fn->getReturnType(), width)) {
939 EmitFailure("BadDiffRet", CI->getDebugLoc(), CI,
940 "Bad DiffRet type ", *differet, " expected ",
941 *fn->getReturnType());
942 return {};
943 }
944 continue;
945 } else if (tape == nullptr) {
946 tape = res;
947 if (CI->paramHasAttr(i, Attribute::ByVal)) {
948 Type *T = nullptr;
949 T = CI->getParamAttr(i, Attribute::ByVal).getValueAsType();
950 tape = Builder.CreateLoad(T, tape);
951 }
952 continue;
953 }
954 }
955 EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
956 "Had too many arguments to __enzyme_autodiff", *CI,
957 " - extra arg - ", *res);
958 return {};
959 }
960 assert(truei < FT->getNumParams());
961 overwritten_args[truei] = overwritten;
962
963 auto PTy = FT->getParamType(truei);
964 DIFFE_TYPE ty =
965 opt_ty ? *opt_ty
966 : ((interleaved == -1) ? whatType(PTy, mode) : last_ty);
967 last_ty = ty;
968
969 constants.push_back(ty);
970
971 assert(truei < FT->getNumParams());
972 // cast primal
973 if (PTy != res->getType()) {
974 if (auto ptr = dyn_cast<PointerType>(res->getType())) {
975 if (auto PT = dyn_cast<PointerType>(PTy)) {
976 if (ptr->getAddressSpace() != PT->getAddressSpace()) {
977#if LLVM_VERSION_MAJOR < 17
978 if (CI->getContext().supportsTypedPointers()) {
979 res = Builder.CreateAddrSpaceCast(
980 res, PointerType::get(ptr->getPointerElementType(),
981 PT->getAddressSpace()));
982 } else {
983 res = Builder.CreateAddrSpaceCast(res, PT);
984 }
985#else
986 res = Builder.CreateAddrSpaceCast(res, PT);
987#endif
988 assert(res);
989 assert(PTy);
990 assert(FT);
991 llvm::errs() << "Warning cast(1) __enzyme_autodiff argument " << i
992 << " " << *res << "|" << *res->getType()
993 << " to argument " << truei << " " << *PTy << "\n"
994 << "orig: " << *FT << "\n";
995 }
996 }
997 }
998 if (res->getType()->canLosslesslyBitCastTo(PTy)) {
999 res = Builder.CreateBitCast(res, PTy);
1000 }
1001 if (res->getType() != PTy && res->getType()->isIntegerTy() &&
1002 PTy->isIntegerTy(1)) {
1003 res = Builder.CreateTrunc(res, PTy);
1004 }
1005 if (res->getType() != PTy) {
1006 auto loc = CI->getDebugLoc();
1007 if (auto arg = dyn_cast<Instruction>(res)) {
1008 loc = arg->getDebugLoc();
1009 }
1010 auto S = simplifyLoad(res);
1011 if (!S)
1012 S = res;
1013 EmitFailure("IllegalArgCast", loc, CI,
1014 "Cannot cast __enzyme_autodiff primal argument ", i,
1015 ", found ", *res, ", type ", *res->getType(),
1016 " (simplified to ", *S, " ) ", " - to arg ", truei, ", ",
1017 *PTy);
1018 return {};
1019 }
1020 }
1021 if (CI->isByValArgument(i)) {
1022 byVal[args.size()] = CI->getParamByValType(i);
1023 }
1024
1025 args.push_back(res);
1026 if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) {
1027 if (interleaved == -1)
1028 ++i;
1029
1030 Value *res = nullptr;
1031#if LLVM_VERSION_MAJOR >= 16
1032 bool batch = batchOffset.has_value();
1033#else
1034 bool batch = batchOffset.hasValue();
1035#endif
1036
1037 for (unsigned v = 0; v < width; ++v) {
1038 if ((size_t)((interleaved == -1) ? i : interleaved) >= num_args) {
1039 EmitFailure("MissingArgShadow", CI->getDebugLoc(), CI,
1040 "__enzyme_autodiff missing argument shadow at index ",
1041 *((interleaved == -1) ? &i : &interleaved),
1042 ", need shadow of type ", *PTy,
1043 " to shadow primal argument ", *args.back(),
1044 " at call ", *CI);
1045 return {};
1046 }
1047
1048 // cast diffe
1049 Value *element =
1050 CI->getArgOperand((interleaved == -1) ? i : interleaved);
1051 if (batch) {
1052 if (auto elementPtrTy = dyn_cast<PointerType>(element->getType())) {
1053 element = Builder.CreateBitCast(
1054 element, PointerType::get(Type::getInt8Ty(CI->getContext()),
1055 elementPtrTy->getAddressSpace()));
1056 element = Builder.CreateGEP(
1057 Type::getInt8Ty(CI->getContext()), element,
1058 Builder.CreateMul(
1059 *batchOffset,
1060 ConstantInt::get((*batchOffset)->getType(), v)));
1061 element = Builder.CreateBitCast(element, elementPtrTy);
1062 } else {
1064 "NonPointerBatch", CI->getDebugLoc(), CI,
1065 "Batched argument at index ",
1066 *((interleaved == -1) ? &i : &interleaved),
1067 " must be of pointer type, found: ", *element->getType());
1068 return {};
1069 }
1070 }
1071 if (PTy != element->getType()) {
1072 element = castToDiffeFunctionArgType(
1073 Builder, CI, FT, PTy, (interleaved == -1) ? i : interleaved,
1074 mode, element, truei);
1075 if (!element) {
1076 return {};
1077 }
1078 }
1079
1080 if (width > 1) {
1081 res =
1082 res ? Builder.CreateInsertValue(res, element, {v})
1083 : Builder.CreateInsertValue(UndefValue::get(ArrayType::get(
1084 element->getType(), width)),
1085 element, {v});
1086
1087 if (v < width - 1 && !batch && (interleaved == -1)) {
1088 ++i;
1089 }
1090
1091 } else {
1092 res = element;
1093 }
1094
1095 if (interleaved != -1)
1096 interleaved++;
1097 }
1098
1099 args.push_back(res);
1100 }
1101
1102 ++truei;
1103 }
1104 if (truei < FT->getNumParams()) {
1105 auto numParams = FT->getNumParams();
1107 "EnzymeInsufficientArgs", CI->getDebugLoc(), CI,
1108 "Insufficient number of args passed to derivative call required ",
1109 numParams, " primal args, found ", truei);
1110 return {};
1111 }
1112
1113 return Options({differet,
1114 tape,
1115 dynamic_interface,
1116 trace,
1117 observations,
1118 likelihood,
1119 diffeLikelihood,
1120 width,
1121 allocatedTapeSize,
1122 freeMemory,
1123 returnUsed,
1124 tapeIsPointer,
1125 differentialReturn,
1126 diffeTrace,
1127 retType,
1128 primalReturn,
1129 ActiveRandomVariables,
1130 overwritten_args,
1131 runtimeActivity,
1132 strongZero,
1133 subsequent_calls_may_write});
1134 }
1135
/// Build the initial FnTypeInfo for `fn` from the LLVM types of its arguments
/// and return value, pass it to `TA.analyzeFunction`, and return the analyzed
/// result.
///
/// \param TA   type analysis used to refine the seeded information.
/// \param fn   function whose signature is inspected.
/// \param mode not referenced by any line visible in this listing.
/// \return the FnTypeInfo obtained from `getAnalyzedTypeInfo()`.
1136 static FnTypeInfo populate_type_args(TypeAnalysis &TA, llvm::Function *fn,
1137 DerivativeMode mode) {
1138 FnTypeInfo type_args(fn);
1139 for (auto &a : type_args.Function->args()) {
1140 TypeTree dt;
// Floating-point scalars and vectors are tagged with their scalar type.
1141 if (a.getType()->isFPOrFPVectorTy()) {
1142 dt = ConcreteType(a.getType()->getScalarType());
1143 } else if (a.getType()->isPointerTy()) {
// Only compiled for LLVM < 17: when the context still has typed pointers,
// seed the tree from the pointee type (floating point or pointer).
1144#if LLVM_VERSION_MAJOR < 17
1145 if (a.getContext().supportsTypedPointers()) {
1146 auto et = a.getType()->getPointerElementType();
1147 if (et->isFPOrFPVectorTy()) {
1148 dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
1149 } else if (et->isPointerTy()) {
1150 dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
1151 }
1152 }
1153#endif
// In every configuration the argument itself is recorded as a pointer.
1154 dt.insert({}, BaseType::Pointer);
1155 } else if (a.getType()->isIntOrIntVectorTy()) {
// NOTE(review): listing line 1156 (the body of this integer branch) is
// missing from this extract; restore it from the original file.
1157 }
// Every argument gets an entry, even when `dt` was left empty above.
1158 type_args.Arguments.insert(
1159 std::pair<Argument *, TypeTree>(&a, dt.Only(-1, nullptr)));
1160 // TODO note that here we do NOT propagate constants in type info (and
1161 // should consider whether we should)
// Each argument starts with an empty set of known constant values.
1162 type_args.KnownValues.insert(
1163 std::pair<Argument *, std::set<int64_t>>(&a, {}));
1164 }
// Only a floating-point return type is seeded; any other return type leaves
// the tree empty before analysis.
1165 TypeTree dt;
1166 if (fn->getReturnType()->isFPOrFPVectorTy()) {
1167 dt = ConcreteType(fn->getReturnType()->getScalarType());
1168 }
1169 type_args.Return = dt.Only(-1, nullptr);
1170
// Replace the seed with whatever the analysis derives from it.
1171 type_args = TA.analyzeFunction(type_args).getAnalyzedTypeInfo();
1172 return type_args;
1173 }
1174
1175 static FloatRepresentation getDefaultFloatRepr(unsigned width) {
1176 switch (width) {
1177 case 16:
1178 return FloatRepresentation(5, 10);
1179 case 32:
1180 return FloatRepresentation(8, 23);
1181 case 64:
1182 return FloatRepresentation(11, 52);
1183 default:
1184 llvm_unreachable("Invalid float width");
1185 }
1186 };
1187
/// Lower a call to __enzyme_truncate_func: ask Logic.CreateTruncateFunc for a
/// truncated version of the function named by the call, pointer-cast the
/// result to the call's type, and replace and erase the call.
///
/// Accepted operand layouts (operand 0 is consumed by parseFunctionParameter):
///   - 3 operands: operand 1 and operand 2 are widths, each mapped through
///     getDefaultFloatRepr ("from" and "to").
///   - 4 operands: operand 1 is the "from" width; operands 2 and 3 are read
///     as `Cto_exponent` and `Cto_significand`.
///
/// \param CI   the call being lowered.
/// \param mode truncation mode, stored in the FloatTruncation and forwarded
///             to Logic.CreateTruncateFunc.
/// \return true if CI was replaced; false if the function could not be
///         parsed, the operand count is wrong, or no function was created.
1188 bool HandleTruncateFunc(CallInst *CI, TruncateMode mode) {
1189 IRBuilder<> Builder(CI);
1190 Function *F = parseFunctionParameter(CI);
1191 if (!F)
1192 return false;
1193 unsigned ArgSize = CI->arg_size();
1194 if (ArgSize != 4 && ArgSize != 3) {
1195 EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
1196 "Had incorrect number of args to __enzyme_truncate_func", *CI,
1197 " - expected 3 or 4");
1198 return false;
1199 }
// NOTE(review): llvm::cast<> does not return null on a type mismatch, so the
// asserts that follow each cast below do not guard against a non-constant
// operand — confirm whether a dyn_cast plus diagnostic is wanted here.
1200 FloatTruncation truncation = [&]() -> FloatTruncation {
1201 if (ArgSize == 3) {
1202 auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
1203 assert(Cfrom);
1204 auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
1205 assert(Cto);
1206 return FloatTruncation(
1207 getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()),
1208 getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()),
1209 mode);
1210 } else if (ArgSize == 4) {
1211 auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
1212 assert(Cfrom);
1213 auto Cto_exponent = cast<ConstantInt>(CI->getArgOperand(2));
1214 assert(Cto_exponent);
1215 auto Cto_significand = cast<ConstantInt>(CI->getArgOperand(3));
1216 assert(Cto_significand);
1217 return FloatTruncation(
1218 getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()),
// NOTE(review): listing line 1219 is missing from this extract; the two
// values below are presumably the arguments of a FloatRepresentation
// constructor call — restore from the original file.
1220 (unsigned)Cto_exponent->getValue().getZExtValue(),
1221 (unsigned)Cto_significand->getValue().getZExtValue()),
1222 mode);
1223 }
// ArgSize was already restricted to 3 or 4 above.
1224 llvm_unreachable("??");
1225 }();
1226
1227 RequestContext context(CI, &Builder);
1228 llvm::Value *res = Logic.CreateTruncateFunc(context, F, truncation, mode);
1229 if (!res)
1230 return false;
// The call's result type may differ from the created function's type, so the
// replacement is pointer-cast to the call's type first.
1231 res = Builder.CreatePointerCast(res, CI->getType());
1232 CI->replaceAllUsesWith(res);
1233 CI->eraseFromParent();
1234 return true;
1235 }
1236
1237 bool HandleTruncateValue(CallInst *CI, bool isTruncate) {
1238 IRBuilder<> Builder(CI);
1239 if (CI->arg_size() != 3) {
1240 EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
1241 "Had incorrect number of args to __enzyme_truncate_value",
1242 *CI, " - expected 3");
1243 return false;
1244 }
1245 auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
1246 assert(Cfrom);
1247 auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
1248 assert(Cto);
1249 auto Addr = CI->getArgOperand(0);
1250 RequestContext context(CI, &Builder);
1251 bool res = Logic.CreateTruncateValue(
1252 context, Addr,
1253 getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()),
1254 getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()),
1255 isTruncate);
1256 if (!res)
1257 return false;
1258 return true;
1259 }
1260
/// Lower a call to __enzyme_batch: gather the call's operands into the
/// argument list of a batched function created by Logic.CreateBatch, emit a
/// call to it, and replace the original call with the result.
///
/// Operands may be preceded by enzyme_* metadata markers:
///   - enzyme_scalar / enzyme_vector choose the BATCH_TYPE of the operand
///     that follows the marker;
///   - enzyme_buffer must be followed by an integer offset, which is recorded
///     in `batchOffset` and later used as a per-lane i8 GEP stride;
///   - enzyme_width and the operand after it are skipped in this loop (the
///     width itself comes from parseWidthParameter).
/// For a VECTOR operand with width > 1 the per-lane values are packed into an
/// array of `width` elements; with width == 1 the single element is used.
///
/// \param CI the __enzyme_batch call being lowered.
/// \return true on success; false if parsing fails or no function is created.
1261 bool HandleBatch(CallInst *CI) {
1262 unsigned width = 1;
1263 unsigned truei = 0;
1264 std::map<unsigned, Value *> batchOffset;
1265 SmallVector<Value *, 4> args;
1266 SmallVector<BATCH_TYPE, 4> arg_types;
1267 IRBuilder<> Builder(CI);
1268 Function *F = parseFunctionParameter(CI);
1269 if (!F)
1270 return false;
1271
1272 assert(F);
1273 FunctionType *FT = F->getFunctionType();
1274
1275 // find and handle enzyme_width
1276 if (auto parsedWidth = parseWidthParameter(CI)) {
1277 width = *parsedWidth;
1278 } else {
1279 return false;
1280 }
1281
1282 // handle different argument order for struct return.
1283 bool sret =
1284 CI->hasStructRetAttr() || F->hasParamAttribute(0, Attribute::StructRet);
1285
// When the target itself takes an sret parameter, the call's operand 0 is
// forwarded as that parameter and primal parameter numbering starts at 1.
1286 if (F->hasParamAttribute(0, Attribute::StructRet)) {
1287 truei = 1;
1288 Value *sretPt = CI->getArgOperand(0);
1289
1290 args.push_back(sretPt);
1291 arg_types.push_back(BATCH_TYPE::VECTOR);
1292 }
1293
// `i` walks the call operands (after the function operand and, if present,
// the sret operand); `truei` walks the target function's parameters.
1294 for (unsigned i = 1 + sret; i < CI->arg_size(); ++i) {
1295 Value *res = CI->getArgOperand(i);
1296
1297 if (truei >= FT->getNumParams()) {
1298 EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
1299 "Had too many arguments to __enzyme_batch", *CI,
1300 " - extra arg - ", *res);
1301 return false;
1302 }
1303 assert(truei < FT->getNumParams());
1304 auto PTy = FT->getParamType(truei);
1305
// NOTE(review): listing line 1306 is missing from this extract; it
// presumably declares `ty`, which is assigned and read below — restore from
// the original file.
1307 auto metaString = getMetadataName(res);
1308
1309 // handle metadata
1310 if (metaString && startsWith(*metaString, "enzyme_")) {
1311 if (*metaString == "enzyme_scalar") {
1312 ty = BATCH_TYPE::SCALAR;
1313 } else if (*metaString == "enzyme_vector") {
1314 ty = BATCH_TYPE::VECTOR;
1315 } else if (*metaString == "enzyme_buffer") {
1316 ty = BATCH_TYPE::VECTOR;
// NOTE(review): `i` is advanced and used with getArgOperand without checking
// it against CI->arg_size() — confirm a trailing marker cannot reach here.
1317 ++i;
1318 Value *offset_arg = CI->getArgOperand(i);
1319 if (offset_arg->getType()->isIntegerTy()) {
// NOTE(review): the offset is stored under key `i + 1`, but the lookup in
// the vector-wrapping code below uses `i - 1`. After `continue` the loop
// increments `i` once, so the next operand is visited at (offset index + 1)
// — verify that the two keys are meant to line up.
1320 batchOffset[i + 1] = offset_arg;
1321 } else {
1322 EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI,
1323 "enzyme_batch must be followd by an integer "
1324 "offset.",
1325 *CI->getArgOperand(i), " in", *CI);
1326 return false;
1327 }
1328 continue;
1329 } else if (*metaString == "enzyme_width") {
// Skip both the marker and the width operand that follows it.
1330 ++i;
1331 continue;
1332 } else {
1333 EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI,
1334 "illegal enzyme metadata classification ", *CI,
1335 *metaString);
1336 return false;
1337 }
// Reached for enzyme_scalar / enzyme_vector: step past the marker to the
// operand it describes.
1338 ++i;
1339 res = CI->getArgOperand(i);
1340 }
1341
1342 arg_types.push_back(ty);
1343
1344 // wrap vector
1345 if (ty == BATCH_TYPE::VECTOR) {
// This inner `res` shadows the operand read at the top of the loop; it
// accumulates the packed value for this parameter.
1346 Value *res = nullptr;
1347 bool batch = batchOffset.count(i - 1) != 0;
1348
1349 for (unsigned v = 0; v < width; ++v) {
1350 if (i >= CI->arg_size()) {
1351 EmitFailure("MissingVectorArg", CI->getDebugLoc(), CI,
1352 "__enzyme_batch missing vector argument at index ", i,
1353 ", need argument of type ", *PTy, " at call ", *CI);
1354 return false;
1355 }
1356
1357 // vectorize pointer
1358 Value *element = CI->getArgOperand(i);
1359 if (batch) {
// Lane v of a buffer operand is the base pointer advanced by
// (offset * v), computed as an i8 GEP and cast back to the original type.
1360 if (auto elementPtrTy = dyn_cast<PointerType>(element->getType())) {
1361 element = Builder.CreateBitCast(
1362 element, PointerType::get(Type::getInt8Ty(CI->getContext()),
1363 elementPtrTy->getAddressSpace()));
1364 element = Builder.CreateGEP(
1365 Type::getInt8Ty(CI->getContext()), element,
1366 Builder.CreateMul(
1367 batchOffset[i - 1],
1368 ConstantInt::get(batchOffset[i - 1]->getType(), v)));
1369 element = Builder.CreateBitCast(element, elementPtrTy);
1370 } else {
// NOTE(review): this failure path returns without emitting a diagnostic,
// unlike the other failure paths in this function.
1371 return false;
1372 }
1373 }
1374
1375 if (width > 1) {
1376 res =
1377 res ? Builder.CreateInsertValue(res, element, {v})
1378 : Builder.CreateInsertValue(UndefValue::get(ArrayType::get(
1379 element->getType(), width)),
1380 element, {v});
1381
// Non-buffer vectors consume one call operand per lane; buffer operands
// reuse the same operand for every lane.
1382 if (v < width - 1 && !batch) {
1383 ++i;
1384 }
1385
1386 } else {
1387 res = element;
1388 }
1389 }
1390
1391 args.push_back(res);
1392
1393 } else if (ty == BATCH_TYPE::SCALAR) {
1394 args.push_back(res);
1395 }
1396
1397 truei++;
1398 }
1399
// NOTE(review): listing lines 1401-1402 are missing from this extract; they
// hold the remainder of this initializer — restore from the original file.
1400 BATCH_TYPE ret_type = (F->getReturnType()->isVoidTy() || width == 1)
1403
1404 auto newFunc = Logic.CreateBatch(RequestContext(CI, &Builder), F, width,
1405 arg_types, ret_type);
1406
1407 if (!newFunc)
1408 return false;
1409
1410 Value *batch =
1411 Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);
1412
1413 batch = adaptReturnedVector(CI, batch, Builder, width);
1414
// With an sret call the result is delivered through operand 0, whose
// pointee type is taken from the sret attribute.
1415 Value *ret = CI;
1416 Type *retElemType = nullptr;
1417 if (CI->hasStructRetAttr()) {
1418 ret = CI->getArgOperand(0);
1419 retElemType =
1420 CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
1421 .getValueAsType();
1422 }
// NOTE(review): listing line 1424 is missing from this extract; it holds the
// remaining argument(s) and the close of this call — restore from the
// original file.
1423 ReplaceOriginalCall(Builder, ret, retElemType, batch, CI,
1425
1426 return true;
1427 }
1428
/// Generate the derivative of `fn` for the requested mode, emit a call to it
/// in front of `CI`, and replace the original call's result.
///
/// \param CI          instruction being lowered; supplies the insertion
///                    point and debug location.
/// \param CallingConv calling convention set on the generated call.
/// \param ret         value forwarded to adaptReturnedVector and
///                    ReplaceOriginalCall as the destination of the result.
/// \param retElemType forwarded to ReplaceOriginalCall together with `ret`.
/// \param args        arguments for the generated call; this function may
///                    append a differential-return seed and a tape value.
/// \param byVal       argument index -> byval type; re-applied as parameter
///                    attributes on the generated call.
/// \param constants   per-argument DIFFE_TYPE passed to the Logic calls.
/// \param fn          function to differentiate.
/// \param mode        derivative mode selecting which Logic entry point runs.
/// \param options     parsed call options (tape, width, return handling, ...).
/// \param sizeOnly    in the split-mode branches, replace `CI` with a
///                    constant tape size in bytes (0 when no tape is
///                    returned) and stop.
/// \param calls       receives the generated CallInst.
/// \return false when differentiation or argument checking fails; true on
///         the sizeOnly early exits; otherwise `diffret` converted to bool.
1429 bool HandleAutoDiff(Instruction *CI, CallingConv::ID CallingConv, Value *ret,
1430 Type *retElemType, SmallVectorImpl<Value *> &args,
1431 const std::map<int, Type *> &byVal,
1432 const std::vector<DIFFE_TYPE> &constants, Function *fn,
1433 DerivativeMode mode, Options &options, bool sizeOnly,
1434 SmallVectorImpl<CallInst *> &calls) {
1435 auto &differet = options.differet;
1436 auto &tape = options.tape;
1437 auto &width = options.width;
1438 auto &allocatedTapeSize = options.allocatedTapeSize;
1439 auto &freeMemory = options.freeMemory;
1440 auto &returnUsed = options.returnUsed;
1441 auto &tapeIsPointer = options.tapeIsPointer;
1442 auto &differentialReturn = options.differentialReturn;
1443 auto &retType = options.retType;
1444 auto &overwritten_args = options.overwritten_args;
1445 auto primalReturn = options.primalReturn;
1446 auto subsequent_calls_may_write = options.subsequent_calls_may_write;
1447
// AtomicAdd is enabled for the nvptx, nvptx64 and amdgcn target
// architectures.
1448 auto Arch = Triple(CI->getModule()->getTargetTriple()).getArch();
1449 bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
1450 Arch == Triple::amdgcn;
1451
1452 TypeAnalysis TA(Logic);
1453 FnTypeInfo type_args = populate_type_args(TA, fn, mode);
1454
1455 IRBuilder Builder(CI);
1456 RequestContext context(CI, &Builder);
1457
1458 // differentiate fn
1459 Function *newFunc = nullptr;
1460 Type *tapeType = nullptr;
// `aug` is only assigned in the branches that call CreateAugmentedPrimal.
1461 const AugmentedReturn *aug;
1462 switch (mode) {
// NOTE(review): listing lines 1463-1464 are missing from this extract;
// presumably the case label(s) for the non-split forward-mode branch —
// restore from the original file.
1465 if (primalReturn && fn->getReturnType()->isVoidTy()) {
1466 auto fnname = fn->getName();
1467 EmitFailure("PrimalRetOfVoid", CI->getDebugLoc(), CI,
1468 "Requested primal result of void-returning function type ",
1469 *fn->getFunctionType(), " ", fnname, " ", *CI);
1470 } else
1471 newFunc = Logic.CreateForwardDiff(
1472 context, fn, retType, constants, TA,
1473 /*should return*/ primalReturn, mode, freeMemory,
1474 options.runtimeActivity, options.strongZero, width,
1475 /*addedType*/ nullptr, type_args, subsequent_calls_may_write,
1476 overwritten_args,
1477 /*augmented*/ nullptr);
1478 break;
// NOTE(review): listing line 1479 is missing from this extract; presumably
// the case label opening the split forward-mode branch (its closing brace is
// at listing line 1526) — restore from the original file.
// An anonymous tape is forced when no tape allocation size was supplied
// (allocatedTapeSize == -1) and this is not a size-only query.
1480 bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
1481 aug = &Logic.CreateAugmentedPrimal(
1482 context, fn, retType, constants, TA,
1483 /*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args,
1484 subsequent_calls_may_write, overwritten_args, forceAnonymousTape,
1485 options.runtimeActivity, options.strongZero, width,
1486 /*atomicAdd*/ AtomicAdd);
1487 auto &DL = fn->getParent()->getDataLayout();
1488 if (!forceAnonymousTape) {
1489 assert(!aug->tapeType);
// The tape type is either the whole augmented return type (index -1) or one
// element of the returned struct.
1490 if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) {
1491 auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second;
1492 tapeType = (tapeIdx == -1)
1493 ? aug->fn->getReturnType()
1494 : cast<StructType>(aug->fn->getReturnType())
1495 ->getElementType(tapeIdx);
1496 } else {
// No tape is returned: a size query answers 0.
1497 if (sizeOnly) {
1498 CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false));
1499 CI->eraseFromParent();
1500 return true;
1501 }
1502 }
// Size query with a tape: answer its size in bytes.
1503 if (sizeOnly) {
1504 auto size = DL.getTypeSizeInBits(tapeType) / 8;
1505 CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size, false));
1506 CI->eraseFromParent();
1507 return true;
1508 }
// Diagnose a caller-provided tape allocation that is too small; note that
// execution continues after the diagnostic.
1509 if (tapeType &&
1510 DL.getTypeSizeInBits(tapeType) > 8 * (size_t)allocatedTapeSize) {
1511 auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
1512 EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(),
1513 CI, "need ", bytes, " bytes have ", allocatedTapeSize,
1514 " bytes");
1515 }
1516 } else {
1517 tapeType = getInt8PtrTy(fn->getContext());
1518 }
1519 newFunc = Logic.CreateForwardDiff(
1520 context, fn, retType, constants, TA,
1521 /*should return*/ primalReturn, mode, freeMemory,
1522 options.runtimeActivity, options.strongZero, width,
1523 /*addedType*/ tapeType, type_args, subsequent_calls_may_write,
1524 overwritten_args, aug);
1525 break;
1526 }
// NOTE(review): listing line 1527 is missing from this extract; presumably
// the case label for the combined reverse-mode branch — restore from the
// original file.
1528 assert(freeMemory);
1529 newFunc = Logic.CreatePrimalAndGradient(
1530 context,
1531 (ReverseCacheKey){.todiff = fn,
1532 .retType = retType,
1533 .constant_args = constants,
1534 .subsequent_calls_may_write =
1535 subsequent_calls_may_write,
1536 .overwritten_args = overwritten_args,
1537 .returnUsed = primalReturn,
1538 .shadowReturnUsed = false,
1539 .mode = mode,
1540 .width = width,
1541 .freeMemory = freeMemory,
1542 .AtomicAdd = AtomicAdd,
1543 .additionalType = nullptr,
1544 .forceAnonymousTape = false,
1545 .typeInfo = type_args,
1546 .runtimeActivity = options.runtimeActivity,
1547 .strongZero = options.strongZero},
1548 TA, /*augmented*/ nullptr);
1549 break;
// NOTE(review): listing lines 1550-1551 are missing from this extract;
// presumably the case label(s) opening the split reverse-mode branch (its
// closing brace is at listing line 1621) — restore from the original file.
1552 if (primalReturn) {
// NOTE(review): listing line 1553 is missing from this extract; presumably
// the opening of an EmitFailure( call taking the arguments below.
1554 "SplitPrimalRet", CI->getDebugLoc(), CI,
1555 "Option enzyme_primal_return not available in reverse split mode");
1556 }
1557 bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
1558 bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG ||
1559 retType == DIFFE_TYPE::DUP_NONEED);
1560 aug = &Logic.CreateAugmentedPrimal(
1561 context, fn, retType, constants, TA, returnUsed, shadowReturnUsed,
1562 type_args, subsequent_calls_may_write, overwritten_args,
1563 forceAnonymousTape, options.runtimeActivity, options.strongZero,
1564 width,
1565 /*atomicAdd*/ AtomicAdd);
1566 auto &DL = fn->getParent()->getDataLayout();
// Same tape-type / size-query handling as the split forward-mode branch.
1567 if (!forceAnonymousTape) {
1568 assert(!aug->tapeType);
1569 if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) {
1570 auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second;
1571 tapeType = (tapeIdx == -1)
1572 ? aug->fn->getReturnType()
1573 : cast<StructType>(aug->fn->getReturnType())
1574 ->getElementType(tapeIdx);
1575 } else {
1576 if (sizeOnly) {
1577 CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false));
1578 CI->eraseFromParent();
1579 return true;
1580 }
1581 }
1582 if (sizeOnly) {
1583 auto size = DL.getTypeSizeInBits(tapeType) / 8;
1584 CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size, false));
1585 CI->eraseFromParent();
1586 return true;
1587 }
1588 if (tapeType &&
1589 DL.getTypeSizeInBits(tapeType) > 8 * (size_t)allocatedTapeSize) {
1590 auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
1591 EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(),
1592 CI, "need ", bytes, " bytes have ", allocatedTapeSize,
1593 " bytes");
1594 }
1595 } else {
1596 tapeType = getInt8PtrTy(fn->getContext());
1597 }
// NOTE(review): listing line 1598 is missing from this extract; presumably
// the `if` condition selecting the augmented primal itself as the function
// to call — restore from the original file.
1599 newFunc = aug->fn;
1600 else
1601 newFunc = Logic.CreatePrimalAndGradient(
1602 context,
1603 (ReverseCacheKey){.todiff = fn,
1604 .retType = retType,
1605 .constant_args = constants,
1606 .subsequent_calls_may_write =
1607 subsequent_calls_may_write,
1608 .overwritten_args = overwritten_args,
1609 .returnUsed = false,
1610 .shadowReturnUsed = false,
1611 .mode = mode,
1612 .width = width,
1613 .freeMemory = freeMemory,
1614 .AtomicAdd = AtomicAdd,
1615 .additionalType = tapeType,
1616 .forceAnonymousTape = forceAnonymousTape,
1617 .typeInfo = type_args,
1618 .runtimeActivity = options.runtimeActivity,
1619 .strongZero = options.strongZero},
1620 TA, aug);
1621 }
1622 }
1623
// Also reached after the "PrimalRetOfVoid" diagnostic above, which leaves
// newFunc null.
1624 if (!newFunc) {
1625 StringRef n = fn->getName();
1626 EmitFailure("FailedToDifferentiate", fn->getSubprogram(),
1627 &*fn->getEntryBlock().begin(),
1628 "Could not generate derivative function of ", n);
1629 return false;
1630 }
1631
// Append the differential-return seed: the caller-supplied value if any,
// otherwise a constant 1.0 shaped like the return type (scalar/vector float,
// replicated `width` times when width != 1, or a struct/array whose elements
// are all 1.0).
1632 if (differentialReturn) {
1633 if (differet)
1634 args.push_back(differet);
1635 else if (fn->getReturnType()->isFPOrFPVectorTy()) {
1636 Constant *seed = ConstantFP::get(fn->getReturnType(), 1.0);
1637 if (width == 1) {
1638 args.push_back(seed);
1639 } else {
1640 ArrayType *arrayType = ArrayType::get(fn->getReturnType(), width);
1641 args.push_back(ConstantArray::get(
1642 arrayType, SmallVector<Constant *, 3>(width, seed)));
1643 }
1644 } else if (auto ST = dyn_cast<StructType>(fn->getReturnType())) {
1645 SmallVector<Constant *, 2> csts;
1646 for (auto e : ST->elements()) {
1647 csts.push_back(ConstantFP::get(e, 1.0));
1648 }
1649 args.push_back(ConstantStruct::get(ST, csts));
1650 } else if (auto AT = dyn_cast<ArrayType>(fn->getReturnType())) {
1651 SmallVector<Constant *, 2> csts(
1652 AT->getNumElements(), ConstantFP::get(AT->getElementType(), 1.0));
1653 args.push_back(ConstantArray::get(AT, csts));
1654 } else {
1655 auto RT = fn->getReturnType();
1656 EmitFailure("EnzymeCallingError", CI->getDebugLoc(), CI,
1657 "Differential return required for call ", *CI,
1658 " but one of type ", *RT, " could not be auto deduced");
1659 return false;
1660 }
1661 }
1662
// NOTE(review): listing lines 1663-1664 are missing from this extract; they
// hold the start of the `if` condition whose tail is on the next line —
// restore from the original file.
1665 tape && tapeType) {
1666 auto &DL = fn->getParent()->getDataLayout();
// Convert the caller's tape value to `tapeType`: load it through the pointer
// when the tape was passed by pointer, otherwise spill it to an entry-block
// alloca and reload it as `tapeType` when the types differ and it fits.
1667 if (tapeIsPointer) {
1668 tape = Builder.CreateBitCast(
1669 tape, PointerType::get(
1670 tapeType,
1671 cast<PointerType>(tape->getType())->getAddressSpace()));
1672 tape = Builder.CreateLoad(tapeType, tape);
1673 } else if (tapeType != tape->getType() &&
1674 DL.getTypeSizeInBits(tapeType) <=
1675 DL.getTypeSizeInBits(tape->getType())) {
1676 IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front());
1677 auto AL = EB.CreateAlloca(tape->getType());
1678 Builder.CreateStore(tape, AL);
1679 tape = Builder.CreateLoad(
1680 tapeType, Builder.CreatePointerCast(AL, getUnqual(tapeType)));
1681 }
1682 assert(tape->getType() == tapeType);
1683 args.push_back(tape);
1684 }
1685
1686 if (EnzymePrint) {
1687 llvm::errs() << "postfn:\n" << *newFunc << "\n";
1688 }
1689 Builder.setFastMathFlags(getFast());
1690
1691 // call newFunc with the provided arguments.
// The argument count and every argument type are checked against newFunc's
// signature first; on mismatch the call, the function and the arguments are
// dumped to stderr before the diagnostic.
1692 if (args.size() != newFunc->getFunctionType()->getNumParams()) {
1693 llvm::errs() << *CI << "\n";
1694 llvm::errs() << *newFunc << "\n";
1695 for (auto arg : args) {
1696 llvm::errs() << " + " << *arg << "\n";
1697 }
1698 auto modestr = to_string(mode);
// NOTE(review): listing line 1699 is missing from this extract; presumably
// the opening of an EmitFailure( call taking the arguments below.
1700 "TooFewArguments", CI->getDebugLoc(), CI,
1701 "Too few arguments passed to __enzyme_autodiff mode=", modestr);
1702 return false;
1703 }
1704 assert(args.size() == newFunc->getFunctionType()->getNumParams());
1705 for (size_t i = 0; i < args.size(); i++) {
1706 if (args[i]->getType() != newFunc->getFunctionType()->getParamType(i)) {
1707 llvm::errs() << *CI << "\n";
1708 llvm::errs() << *newFunc << "\n";
1709 for (auto arg : args) {
1710 llvm::errs() << " + " << *arg << "\n";
1711 }
1712 auto modestr = to_string(mode);
1713 EmitFailure("BadArgumentType", CI->getDebugLoc(), CI,
1714 "Incorrect argument type passed to __enzyme_autodiff mode=",
1715 modestr, " at index ", i, " expected ",
1716 *newFunc->getFunctionType()->getParamType(i), " found ",
1717 *args[i]->getType());
1718 return false;
1719 }
1720 }
1721 CallInst *diffretc = cast<CallInst>(Builder.CreateCall(newFunc, args));
1722 diffretc->setCallingConv(CallingConv);
1723 diffretc->setDebugLoc(CI->getDebugLoc());
1724
// Re-apply the byval attributes recorded for the original call's arguments.
1725 for (auto &&[attr, ty] : byVal) {
1726 diffretc->addParamAttr(
1727 attr, Attribute::getWithByValType(diffretc->getContext(), ty));
1728 }
1729
1730 Value *diffret = diffretc;
// In ReverseModePrimal with a caller-provided tape pointer, store the tape
// part of the result through that pointer and return only the remainder.
1731 if (mode == DerivativeMode::ReverseModePrimal && tape) {
1732 if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) {
1733 auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second;
1734 tapeType = (tapeIdx == -1) ? aug->fn->getReturnType()
1735 : cast<StructType>(aug->fn->getReturnType())
1736 ->getElementType(tapeIdx);
1737 unsigned idxs[] = {(unsigned)tapeIdx};
1738 Value *tapeRes = (tapeIdx == -1)
1739 ? diffret
1740 : Builder.CreateExtractValue(diffret, idxs);
1741 Builder.CreateStore(
1742 tapeRes,
1743 Builder.CreateBitCast(
1744 tape,
1745 PointerType::get(
1746 tapeRes->getType(),
1747 cast<PointerType>(tape->getType())->getAddressSpace())));
1748 if (tapeIdx != -1) {
// NOTE(review): the rebuilt struct drops element 0 and copies elements
// i + 1, independent of tapeIdx — presumably the tape is always the first
// element here; confirm.
1749 auto ST = cast<StructType>(diffret->getType());
1750 SmallVector<Type *, 2> tys(ST->elements().begin(),
1751 ST->elements().end());
1752 tys.erase(tys.begin());
1753 auto ST0 = StructType::get(ST->getContext(), tys);
1754 Value *out = UndefValue::get(ST0);
1755 for (unsigned i = 0; i < tys.size(); i++) {
1756 out = Builder.CreateInsertValue(
1757 out, Builder.CreateExtractValue(diffret, {i + 1}), {i});
1758 }
1759 diffret = out;
1760 } else {
// The whole return value was the tape: nothing is left to return.
1761 auto ST0 = StructType::get(tape->getContext(), {});
1762 diffret = UndefValue::get(ST0);
1763 }
1764 }
1765 }
1766
1767 // Adapt the returned vector type to the struct type expected by our calling
1768 // convention.
1769 if (width > 1 && !diffret->getType()->isEmptyTy() &&
1770 !diffret->getType()->isVoidTy() &&
1771 (mode == DerivativeMode::ForwardMode ||
// NOTE(review): listing line 1772 is missing from this extract; it holds the
// rest of this condition and the opening brace — restore from the original
// file.
1773
1774 diffret = adaptReturnedVector(ret, diffret, Builder, width);
1775 }
1776
1777 ReplaceOriginalCall(Builder, ret, retElemType, diffret, CI, mode);
1778 calls.push_back(diffretc);
// `diffret` is a Value*; it is implicitly converted to the bool return type.
1779 return diffret;
1780 }
1781
1782 /// Return whether successful
1783 bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode, bool sizeOnly,
1784 SmallVectorImpl<CallInst *> &calls) {
1785
1786 // determine function to differentiate
1787 Function *fn = parseFunctionParameter(CI);
1788 if (!fn)
1789 return false;
1790
1791 IRBuilder<> Builder(CI);
1792
1793 if (EnzymePrint)
1794 llvm::errs() << "prefn:\n" << *fn << "\n";
1795
1796 std::map<int, Type *> byVal;
1797 std::vector<DIFFE_TYPE> constants;
1798 SmallVector<Value *, 2> args;
1799
1800 auto options = handleArguments(Builder, CI, fn, mode, sizeOnly, constants,
1801 args, byVal);
1802
1803 if (!options) {
1804 return false;
1805 }
1806
1807 Value *ret = CI;
1808 Type *retElemType = nullptr;
1809 if (CI->hasStructRetAttr()) {
1810 ret = CI->getArgOperand(0);
1811 retElemType =
1812 CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
1813 .getValueAsType();
1814 }
1815
1816 return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
1817 byVal, constants, fn, mode, *options, sizeOnly,
1818 calls);
1819 }
1820
/// Lower a probabilistic-programming call: parse its arguments, build the
/// traced version of the target function through Logic.CreateTrace, and
/// either call it directly or differentiate it in ReverseModeCombined.
///
/// \param CI    the call being lowered.
/// \param mode  ProbProgMode; Condition additionally passes the observations,
///              and Trace / Condition additionally pass the trace.
/// \param calls forwarded to HandleAutoDiff when differentiation is needed.
/// \return false if the function or arguments cannot be parsed; true after a
///         direct replacement; otherwise the result of HandleAutoDiff.
1821 bool HandleProbProg(CallInst *CI, ProbProgMode mode,
1822 SmallVectorImpl<CallInst *> &calls) {
1823 IRBuilder<> Builder(CI);
1824 Function *F = parseFunctionParameter(CI);
1825 if (!F)
1826 return false;
1827
1828 assert(F);
1829
1830 std::vector<DIFFE_TYPE> constants;
1831 std::map<int, Type *> byVal;
1832 SmallVector<Value *, 4> args;
1833
// Arguments are always parsed as for a combined reverse-mode call.
1834 auto diffeMode = DerivativeMode::ReverseModeCombined;
1835
1836 auto opt = handleArguments(Builder, CI, F, diffeMode, false, constants,
1837 args, byVal);
1838
// `dargs` is the argument list used if the traced function is differentiated;
// it starts as a copy of `args` and receives extra entries below.
1839 SmallVector<Value *, 6> dargs(args.begin(), args.end());
1840
1841#if LLVM_VERSION_MAJOR >= 16
1842 if (!opt.has_value())
1843 return false;
1844#else
1845 if (!opt.hasValue())
1846 return false;
1847#endif
1848
1849 auto dynamic_interface = opt->dynamic_interface;
1850 auto trace = opt->trace;
1851 auto dtrace = opt->diffeTrace;
1852 auto observations = opt->observations;
1853 auto likelihood = opt->likelihood;
1854 auto dlikelihood = opt->diffeLikelihood;
1855
1856 // Interface
1857 bool has_dynamic_interface = dynamic_interface != nullptr;
// NOTE(review): listing line 1859 is missing from this extract; it holds the
// initializer of `needs_interface` — restore from the original file.
1858 bool needs_interface =
1860 std::unique_ptr<TraceInterface> interface;
// A dynamic interface takes precedence; otherwise a static one is created
// only when needed. `interface` may stay null.
1861 if (has_dynamic_interface) {
1862 interface = std::make_unique<DynamicTraceInterface>(dynamic_interface,
1863 CI->getFunction());
1864 } else if (needs_interface) {
1865 interface = std::make_unique<StaticTraceInterface>(F->getParent());
1866 }
1867
1868 // Find sample function
// Every function in F's module whose name contains "__enzyme_sample" or
// "__enzyme_observe" is collected; each must take at least 3 parameters.
1869 SmallPtrSet<Function *, 4> sampleFunctions;
1870 SmallPtrSet<Function *, 4> observeFunctions;
1871 for (auto &func : F->getParent()->functions()) {
1872 if (func.getName().contains("__enzyme_sample")) {
1873 assert(func.getFunctionType()->getNumParams() >= 3);
1874 sampleFunctions.insert(&func);
1875 } else if (func.getName().contains("__enzyme_observe")) {
1876 assert(func.getFunctionType()->getNumParams() >= 3);
1877 observeFunctions.insert(&func);
1878 }
1879 }
1880
1881 assert(!sampleFunctions.empty() || !observeFunctions.empty());
1882
// Differentiation is requested when either a shadow trace or a shadow
// likelihood was supplied.
1883 bool autodiff = dtrace || dlikelihood;
1884 IRBuilder<> AllocaBuilder(CI->getParent()->getFirstNonPHI());
1885
// Without a caller-provided likelihood, allocate a double and initialize it
// to zero right before the call.
1886 if (!likelihood) {
1887 likelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(),
1888 nullptr, "likelihood");
1889 Builder.CreateStore(ConstantFP::getNullValue(Builder.getDoubleTy()),
1890 likelihood);
1891 }
1892 args.push_back(likelihood);
1893
// When differentiating without a caller-provided shadow likelihood, allocate
// one and initialize it to 1.0.
1894 if (autodiff && !dlikelihood) {
1895 dlikelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(),
1896 nullptr, "dlikelihood");
1897 Builder.CreateStore(ConstantFP::get(Builder.getDoubleTy(), 1.0),
1898 dlikelihood);
1899 }
1900
// The likelihood is a duplicated (primal + shadow) argument when
// differentiating and a constant argument otherwise.
1901 if (autodiff) {
1902 dargs.push_back(likelihood);
1903 dargs.push_back(dlikelihood);
1904 constants.push_back(DIFFE_TYPE::DUP_ARG);
1905 opt->overwritten_args.push_back(false);
1906 } else {
1907 constants.push_back(DIFFE_TYPE::CONSTANT);
1908 opt->overwritten_args.push_back(false);
1909 }
1910
1911 if (mode == ProbProgMode::Condition) {
1912 opt->overwritten_args.push_back(false);
1913 args.push_back(observations);
1914 dargs.push_back(observations);
1915 constants.push_back(DIFFE_TYPE::CONSTANT);
1916 }
1917
1918 if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) {
1919 opt->overwritten_args.push_back(false);
1920 args.push_back(trace);
1921 dargs.push_back(trace);
1922 constants.push_back(DIFFE_TYPE::CONSTANT);
1923 }
1924
// NOTE(review): `newFunc` is used below without a null check — confirm
// Logic.CreateTrace cannot fail here.
1925 auto newFunc = Logic.CreateTrace(
1926 RequestContext(CI, &Builder), F, sampleFunctions, observeFunctions,
1927 opt->ActiveRandomVariables, mode, autodiff, interface.get());
1928
// No differentiation requested: replace the call with a direct call to the
// traced function.
1929 if (!autodiff) {
1930 auto call = CallInst::Create(newFunc->getFunctionType(), newFunc, args);
1931 ReplaceInstWithInst(CI, call);
1932 return true;
1933 }
1934
1935 Value *ret = CI;
1936 Type *retElemType = nullptr;
1937 if (CI->hasStructRetAttr()) {
1938 ret = CI->getArgOperand(0);
1939 retElemType =
1940 CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
1941 .getValueAsType();
1942 }
1943
// Differentiate the traced function (not F) using the extended argument list.
1944 bool status = HandleAutoDiff(
1945 CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
1946 newFunc, DerivativeMode::ReverseModeCombined, *opt, false, calls);
1947
1948 return status;
1949 }
1950
1951 bool handleFullModuleTrunc(Function &F) {
1952 if (startsWith(F.getName(), EnzymeFPRTPrefix))
1953 return false;
1954 typedef std::vector<FloatTruncation> TruncationsTy;
1955 static TruncationsTy FullModuleTruncs = []() -> TruncationsTy {
1956 StringRef ConfigStr(EnzymeTruncateAll);
1957 auto Invalid = [=]() {
1958 // TODO emit better diagnostic
1959 llvm::report_fatal_error("error: invalid format for truncation config");
1960 };
1961
1962 // "64" or "11-52"
1963 auto parseFloatRepr = [&]() -> std::optional<FloatRepresentation> {
1964 unsigned Tmp = 0;
1965 if (ConfigStr.consumeInteger(10, Tmp))
1966 return {};
1967 if (ConfigStr.consume_front("-")) {
1968 unsigned Tmp2 = 0;
1969 if (ConfigStr.consumeInteger(10, Tmp2))
1970 Invalid();
1971 return FloatRepresentation(Tmp, Tmp2);
1972 }
1973 return getDefaultFloatRepr(Tmp);
1974 };
1975
1976 // Parse "64to32;32to16;5-10to4-9"
1977 TruncationsTy Tmp;
1978 while (true) {
1979 auto From = parseFloatRepr();
1980 if (!From && !ConfigStr.empty())
1981 Invalid();
1982 if (!From)
1983 break;
1984 if (!ConfigStr.consume_front("to"))
1985 Invalid();
1986 auto To = parseFloatRepr();
1987 if (!To)
1988 Invalid();
1989 Tmp.push_back({*From, *To, TruncOpFullModuleMode});
1990 ConfigStr.consume_front(";");
1991 }
1992 return Tmp;
1993 }();
1994
1995 if (FullModuleTruncs.empty())
1996 return false;
1997
1998 // TODO sort truncations (64to32, then 32to16 will make everything 16)
1999 for (auto Truncation : FullModuleTruncs) {
2000 IRBuilder<> Builder(F.getContext());
2001 RequestContext context(&*F.getEntryBlock().begin(), &Builder);
2002 Function *TruncatedFunc = Logic.CreateTruncateFunc(
2003 context, &F, Truncation, TruncOpFullModuleMode);
2004
2005 ValueToValueMapTy Mapping;
2006 for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args()))
2007 Mapping[&TArg] = &Arg;
2008
2009 // Move the truncated body into the original function
2010 F.deleteBody();
2011#if LLVM_VERSION_MAJOR >= 16
2012 F.splice(F.begin(), TruncatedFunc);
2013#else
2014 F.getBasicBlockList().splice(F.begin(),
2015 TruncatedFunc->getBasicBlockList());
2016#endif
2017 RemapFunction(F, Mapping,
2018 RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
2019 TruncatedFunc->deleteBody();
2020 }
2021 return true;
2022 }
2023
// Lowers the Enzyme entry-point calls found in F and reports whether anything
// was modified.
//
// Steps, in the order they appear below:
//  1. Return early when F is already in `done`, has no body, or when
//     handleFullModuleTrunc(F) reports that it handled F (returns true then).
//  2. Rewrite `invoke`s of Enzyme entry points into a plain call followed by
//     a branch to the normal destination.
//  3. Scan every CallInst: attach memory/parameter attributes for a list of
//     known callee names, and sort Enzyme entry points into per-kind work
//     lists, recursing into the function being differentiated when it is a
//     Function.
//  4. Process the work lists: inactive calls, size queries, derivative
//     requests, virtual reverse, batch, truncation/expansion, prob-prog.
//  5. When Logic.PostOpt is set, try to inline the calls held in `calls`.
//  6. When something changed and EnzymeAttributor is set, run the Attributor
//     (compiled out when FLANG or ROCM is defined).
//
// F:    function to process.
// done: functions already visited; F is inserted into it on first visit.
2024 bool lowerEnzymeCalls(Function &F, std::set<Function *> &done) {
2025 if (done.count(&F))
2026 return false;
2027 done.insert(&F);
2028
2029 if (F.empty())
2030 return false;
2031
2032 if (handleFullModuleTrunc(F))
2033 return true;
2034
2035 bool Changed = false;
2036
// Turn invokes of Enzyme entry points into call + br. The scan further down
// only considers CallInst, so an invoke would otherwise not be picked up.
2037 for (BasicBlock &BB : F)
2038 if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator())) {
2039
2040 Function *Fn = II->getCalledFunction();
2041
// Look through one constant cast wrapped around the callee.
2042 if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledOperand())) {
2043 if (castinst->isCast())
2044 if (auto fn = dyn_cast<Function>(castinst->getOperand(0)))
2045 Fn = fn;
2046 }
2047 if (!Fn)
2048 continue;
2049
// Only invokes whose callee name contains one of these markers are rewritten.
2050 if (!(Fn->getName().contains("__enzyme_float") ||
2051 Fn->getName().contains("__enzyme_double") ||
2052 Fn->getName().contains("__enzyme_integer") ||
2053 Fn->getName().contains("__enzyme_pointer") ||
2054 Fn->getName().contains("__enzyme_virtualreverse") ||
2055 Fn->getName().contains("__enzyme_call_inactive") ||
2056 Fn->getName().contains("__enzyme_autodiff") ||
2057 Fn->getName().contains("__enzyme_fwddiff") ||
2058 Fn->getName().contains("__enzyme_fwdsplit") ||
2059 Fn->getName().contains("__enzyme_augmentfwd") ||
2060 Fn->getName().contains("__enzyme_augmentsize") ||
2061 Fn->getName().contains("__enzyme_reverse") ||
2062 Fn->getName().contains("__enzyme_truncate") ||
2063 Fn->getName().contains("__enzyme_batch") ||
2064 Fn->getName().contains("__enzyme_error_estimate") ||
2065 Fn->getName().contains("__enzyme_trace") ||
2066 Fn->getName().contains("__enzyme_condition")))
2067 continue;
2068
2069 SmallVector<Value *, 16> CallArgs(II->arg_begin(), II->arg_end());
2070 SmallVector<OperandBundleDef, 1> OpBundles;
2071 II->getOperandBundlesAsDefs(OpBundles);
2072 // Insert a normal call instruction...
2073 CallInst *NewCall =
2074 CallInst::Create(II->getFunctionType(), II->getCalledOperand(),
2075 CallArgs, OpBundles, "", II);
2076 NewCall->takeName(II);
2077 NewCall->setCallingConv(II->getCallingConv());
2078 NewCall->setAttributes(II->getAttributes());
2079 NewCall->setDebugLoc(II->getDebugLoc());
2080 II->replaceAllUsesWith(NewCall);
2081
2082 // Insert an unconditional branch to the normal destination.
2083 BranchInst::Create(II->getNormalDest(), II);
2084
2085 // Remove any PHI node entries from the exception destination.
2086 II->getUnwindDest()->removePredecessor(&BB);
2087
2088 II->eraseFromParent();
2089 Changed = true;
2090 }
2091
// Work lists filled by the scan below, one per kind of Enzyme entry point.
2092 MapVector<CallInst *, DerivativeMode> toLower;
2093 MapVector<CallInst *, DerivativeMode> toVirtual;
2094 MapVector<CallInst *, DerivativeMode> toSize;
2095 SmallVector<CallInst *, 4> toBatch;
2096 SmallVector<CallInst *, 4> toTruncateFuncMem;
2097 SmallVector<CallInst *, 4> toTruncateFuncOp;
2098 SmallVector<CallInst *, 4> toTruncateValue;
2099 SmallVector<CallInst *, 4> toExpandValue;
2100 MapVector<CallInst *, ProbProgMode> toProbProg;
2101 SetVector<CallInst *> InactiveCalls;
2102 SetVector<CallInst *> IterCalls;
// The scan restarts from this label after a call whose first argument is a
// select has been split into new blocks (see `goto retry` below); presumably
// because that edit adds blocks to F while F is being iterated.
// NOTE(review): toBatch/toTruncate*/toExpandValue are plain vectors, so a call
// recorded before a restart looks like it would be pushed a second time —
// confirm whether that can happen in practice.
2103 retry:;
2104 for (BasicBlock &BB : F) {
2105 for (Instruction &I : BB) {
2106 CallInst *CI = dyn_cast<CallInst>(&I);
2107
2108 if (!CI)
2109 continue;
2110
2111 Function *Fn = nullptr;
2112
// Strip any chain of constant casts to find the called Function, if there is
// one; Fn stays null otherwise.
2113 Value *FnOp = CI->getCalledOperand();
2114 while (true) {
2115 if ((Fn = dyn_cast<Function>(FnOp)))
2116 break;
2117 if (auto castinst = dyn_cast<ConstantExpr>(FnOp)) {
2118 if (castinst->isCast()) {
2119 FnOp = castinst->getOperand(0);
2120 continue;
2121 }
2122 }
2123 break;
2124 }
2125
2126 if (!Fn)
2127 continue;
2128
2129 size_t num_args = CI->arg_size();
2130
// Attribute annotations keyed on the callee name. Throughout this section the
// LLVM >= 16 branch sets both only-reads and only-writes where the older
// branch adds ReadNone.
2131 if (Fn->getName().contains("__enzyme_todense") ||
2132 Fn->getName().contains("__enzyme_ignore_derivatives")) {
2133#if LLVM_VERSION_MAJOR >= 16
2134 CI->setOnlyReadsMemory();
2135 CI->setOnlyWritesMemory();
2136#else
2137 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2138#endif
2139 }
2140 if (Fn->getName().contains("__enzyme_float")) {
2141#if LLVM_VERSION_MAJOR >= 16
2142 CI->setOnlyReadsMemory();
2143 CI->setOnlyWritesMemory();
2144#else
2145 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2146#endif
2147 for (size_t i = 0; i < num_args; ++i) {
2148 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2149 CI->addParamAttr(i, Attribute::ReadNone);
2150 addCallSiteNoCapture(CI, i);
2151 }
2152 }
2153 }
2154 if (Fn->getName().contains("__enzyme_integer")) {
2155#if LLVM_VERSION_MAJOR >= 16
2156 CI->setOnlyReadsMemory();
2157 CI->setOnlyWritesMemory();
2158#else
2159 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2160#endif
2161 for (size_t i = 0; i < num_args; ++i) {
2162 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2163 CI->addParamAttr(i, Attribute::ReadNone);
2164 addCallSiteNoCapture(CI, i);
2165 }
2166 }
2167 }
2168 if (Fn->getName().contains("__enzyme_double")) {
2169#if LLVM_VERSION_MAJOR >= 16
2170 CI->setOnlyReadsMemory();
2171 CI->setOnlyWritesMemory();
2172#else
2173 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2174#endif
2175 for (size_t i = 0; i < num_args; ++i) {
2176 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2177 CI->addParamAttr(i, Attribute::ReadNone);
2178 addCallSiteNoCapture(CI, i);
2179 }
2180 }
2181 }
2182 if (Fn->getName().contains("__enzyme_pointer")) {
2183#if LLVM_VERSION_MAJOR >= 16
2184 CI->setOnlyReadsMemory();
2185 CI->setOnlyWritesMemory();
2186#else
2187 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2188#endif
2189 for (size_t i = 0; i < num_args; ++i) {
2190 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2191 CI->addParamAttr(i, Attribute::ReadNone);
2192 addCallSiteNoCapture(CI, i);
2193 }
2194 }
2195 }
2196 if (Fn->getName().contains("__enzyme_virtualreverse")) {
2197#if LLVM_VERSION_MAJOR >= 16
2198 CI->setOnlyReadsMemory();
2199 CI->setOnlyWritesMemory();
2200#else
2201 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2202#endif
2203 }
2204 if (Fn->getName().contains("__enzyme_iter")) {
2205#if LLVM_VERSION_MAJOR >= 16
2206 CI->setOnlyReadsMemory();
2207 CI->setOnlyWritesMemory();
2208#else
2209 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2210#endif
2211 }
// Remembered here, rewritten into a direct call after the scan.
2212 if (Fn->getName().contains("__enzyme_call_inactive")) {
2213 InactiveCalls.insert(CI);
2214 }
2215 if (Fn->getName() == "omp_get_max_threads" ||
2216 Fn->getName() == "omp_get_thread_num") {
2217#if LLVM_VERSION_MAJOR >= 16
2218 Fn->setOnlyAccessesInaccessibleMemory();
2219 CI->setOnlyAccessesInaccessibleMemory();
2220 Fn->setOnlyReadsMemory();
2221 CI->setOnlyReadsMemory();
2222#else
2223 Fn->addFnAttr(Attribute::InaccessibleMemOnly);
2224 CI->addAttribute(AttributeList::FunctionIndex,
2225 Attribute::InaccessibleMemOnly);
2226 Fn->addFnAttr(Attribute::ReadOnly);
2227 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
2228#endif
2229 }
// Only applied when the BLAS dot routine is a declaration (no body in this
// module). Arguments 1 and 3 are marked read-only and no-capture.
2230 if ((Fn->getName() == "cblas_ddot" || Fn->getName() == "cblas_sdot") &&
2231 Fn->isDeclaration()) {
2232#if LLVM_VERSION_MAJOR >= 16
2233 Fn->setOnlyAccessesArgMemory();
2234 Fn->setOnlyReadsMemory();
2235 CI->setOnlyReadsMemory();
2236#else
2237 Fn->addFnAttr(Attribute::ArgMemOnly);
2238 Fn->addFnAttr(Attribute::ReadOnly);
2239 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
2240#endif
2241 CI->addParamAttr(1, Attribute::ReadOnly);
2242 addCallSiteNoCapture(CI, 1);
2243 CI->addParamAttr(3, Attribute::ReadOnly);
2244 addCallSiteNoCapture(CI, 3);
2245 }
2246 if (Fn->getName() == "frexp" || Fn->getName() == "frexpf" ||
2247 Fn->getName() == "frexpl") {
2248#if LLVM_VERSION_MAJOR >= 16
2249 CI->setOnlyAccessesArgMemory();
2250#else
2251 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ArgMemOnly);
2252#endif
2253 CI->addParamAttr(1, Attribute::WriteOnly);
2254 }
2255 if (Fn->getName() == "__fd_sincos_1" || Fn->getName() == "__fd_cos_1" ||
2256 Fn->getName() == "__mth_i_ipowi") {
2257#if LLVM_VERSION_MAJOR >= 16
2258 CI->setOnlyReadsMemory();
2259 CI->setOnlyWritesMemory();
2260#else
2261 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2262#endif
2263 }
2264 if (getFuncName(Fn) == "strcmp") {
2265 Fn->addParamAttr(0, Attribute::ReadOnly);
2266 Fn->addParamAttr(1, Attribute::ReadOnly);
2267#if LLVM_VERSION_MAJOR >= 16
2268 Fn->setOnlyReadsMemory();
2269 CI->setOnlyReadsMemory();
2270#else
2271 Fn->addFnAttr(Attribute::ReadOnly);
2272 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
2273#endif
2274 }
// The f90io_* / ftnio_* / f90_* names below are matched exactly; pointer
// arguments at the listed indices get ReadOnly and/or no-capture.
2275 if (Fn->getName() == "f90io_fmtw_end" ||
2276 Fn->getName() == "f90io_unf_end") {
2277#if LLVM_VERSION_MAJOR >= 16
2278 Fn->setOnlyAccessesInaccessibleMemory();
2279 CI->setOnlyAccessesInaccessibleMemory();
2280#else
2281 Fn->addFnAttr(Attribute::InaccessibleMemOnly);
2282 CI->addAttribute(AttributeList::FunctionIndex,
2283 Attribute::InaccessibleMemOnly);
2284#endif
2285 }
2286 if (Fn->getName() == "f90io_open2003a") {
2287#if LLVM_VERSION_MAJOR >= 16
2288 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2289 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2290#else
2291 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2292 CI->addAttribute(AttributeList::FunctionIndex,
2293 Attribute::InaccessibleMemOrArgMemOnly);
2294#endif
2295 for (size_t i : {0, 1, 2, 3, 4, 5, 6, 7, /*8, */ 9, 10, 11, 12, 13}) {
2296 if (i < num_args &&
2297 CI->getArgOperand(i)->getType()->isPointerTy()) {
2298 CI->addParamAttr(i, Attribute::ReadOnly);
2299 }
2300 }
2301 // todo more
2302 for (size_t i : {0, 1}) {
2303 if (i < num_args &&
2304 CI->getArgOperand(i)->getType()->isPointerTy()) {
2305 addCallSiteNoCapture(CI, i);
2306 }
2307 }
2308 }
2309 if (Fn->getName() == "f90io_fmtw_inita") {
2310#if LLVM_VERSION_MAJOR >= 16
2311 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2312 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2313#else
2314 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2315 CI->addAttribute(AttributeList::FunctionIndex,
2316 Attribute::InaccessibleMemOrArgMemOnly);
2317#endif
2318 // todo more
2319 for (size_t i : {0, 2}) {
2320 if (i < num_args &&
2321 CI->getArgOperand(i)->getType()->isPointerTy()) {
2322 CI->addParamAttr(i, Attribute::ReadOnly);
2323 }
2324 }
2325
2326 // todo more
2327 for (size_t i : {0, 2}) {
2328 if (i < num_args &&
2329 CI->getArgOperand(i)->getType()->isPointerTy()) {
2330 addCallSiteNoCapture(CI, i);
2331 }
2332 }
2333 }
2334
2335 if (Fn->getName() == "f90io_unf_init") {
2336#if LLVM_VERSION_MAJOR >= 16
2337 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2338 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2339#else
2340 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2341 CI->addAttribute(AttributeList::FunctionIndex,
2342 Attribute::InaccessibleMemOrArgMemOnly);
2343#endif
2344 // todo more
2345 for (size_t i : {0, 1, 2, 3}) {
2346 if (i < num_args &&
2347 CI->getArgOperand(i)->getType()->isPointerTy()) {
2348 CI->addParamAttr(i, Attribute::ReadOnly);
2349 }
2350 }
2351
2352 // todo more
2353 for (size_t i : {0, 1, 2, 3}) {
2354 if (i < num_args &&
2355 CI->getArgOperand(i)->getType()->isPointerTy()) {
2356 addCallSiteNoCapture(CI, i);
2357 }
2358 }
2359 }
2360
2361 if (Fn->getName() == "f90io_src_info03a") {
2362#if LLVM_VERSION_MAJOR >= 16
2363 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2364 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2365#else
2366 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2367 CI->addAttribute(AttributeList::FunctionIndex,
2368 Attribute::InaccessibleMemOrArgMemOnly);
2369#endif
2370 // todo more
2371 for (size_t i : {0, 1}) {
2372 if (i < num_args &&
2373 CI->getArgOperand(i)->getType()->isPointerTy()) {
2374 CI->addParamAttr(i, Attribute::ReadOnly);
2375 }
2376 }
2377
2378 // todo more
2379 for (size_t i : {0}) {
2380 if (i < num_args &&
2381 CI->getArgOperand(i)->getType()->isPointerTy()) {
2382 addCallSiteNoCapture(CI, i);
2383 }
2384 }
2385 }
2386 if (Fn->getName() == "f90io_sc_d_fmt_write" ||
2387 Fn->getName() == "f90io_sc_i_fmt_write" ||
2388 Fn->getName() == "ftnio_fmt_write64" ||
2389 Fn->getName() == "f90io_fmt_write64_aa" ||
2390 Fn->getName() == "f90io_fmt_writea" ||
2391 Fn->getName() == "f90io_unf_writea" ||
2392 Fn->getName() == "f90_pausea") {
2393#if LLVM_VERSION_MAJOR >= 16
2394 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2395 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2396#else
2397 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2398 CI->addAttribute(AttributeList::FunctionIndex,
2399 Attribute::InaccessibleMemOrArgMemOnly);
2400#endif
2401 for (size_t i = 0; i < num_args; ++i) {
2402 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2403 CI->addParamAttr(i, Attribute::ReadOnly);
2404 addCallSiteNoCapture(CI, i);
2405 }
2406 }
2407 }
2408
// Classify the callee by name. enableEnzyme marks "this is an Enzyme entry
// point"; the remaining flags select which work list the call goes into.
// derivativeMode / probProgMode are only assigned on the branches whose work
// lists later read them.
2409 bool enableEnzyme = false;
2410 bool virtualCall = false;
2411 bool sizeOnly = false;
2412 bool batch = false;
2413 bool truncateFuncOp = false;
2414 bool truncateFuncMem = false;
2415 bool truncateValue = false;
2416 bool expandValue = false;
2417 bool probProg = false;
2418 DerivativeMode derivativeMode;
2419 ProbProgMode probProgMode;
2420 if (Fn->getName().contains("__enzyme_autodiff")) {
2421 enableEnzyme = true;
2422 derivativeMode = DerivativeMode::ReverseModeCombined;
2423 } else if (Fn->getName().contains("__enzyme_fwddiff")) {
2424 enableEnzyme = true;
2425 derivativeMode = DerivativeMode::ForwardMode;
2426 } else if (Fn->getName().contains("__enzyme_error_estimate")) {
2427 enableEnzyme = true;
2428 derivativeMode = DerivativeMode::ForwardModeError;
2429 } else if (Fn->getName().contains("__enzyme_fwdsplit")) {
2430 enableEnzyme = true;
2431 derivativeMode = DerivativeMode::ForwardModeSplit;
2432 } else if (Fn->getName().contains("__enzyme_augmentfwd")) {
2433 enableEnzyme = true;
2434 derivativeMode = DerivativeMode::ReverseModePrimal;
2435 } else if (Fn->getName().contains("__enzyme_augmentsize")) {
2436 enableEnzyme = true;
2437 sizeOnly = true;
2438 derivativeMode = DerivativeMode::ReverseModePrimal;
2439 } else if (Fn->getName().contains("__enzyme_reverse")) {
2440 enableEnzyme = true;
2441 derivativeMode = DerivativeMode::ReverseModeGradient;
2442 } else if (Fn->getName().contains("__enzyme_virtualreverse")) {
2443 enableEnzyme = true;
2444 virtualCall = true;
2445 derivativeMode = DerivativeMode::ReverseModeCombined;
2446 } else if (Fn->getName().contains("__enzyme_batch")) {
2447 enableEnzyme = true;
2448 batch = true;
2449 } else if (Fn->getName().contains("__enzyme_truncate_mem_func")) {
2450 enableEnzyme = true;
2451 truncateFuncMem = true;
2452 } else if (Fn->getName().contains("__enzyme_truncate_op_func")) {
2453 enableEnzyme = true;
2454 truncateFuncOp = true;
2455 } else if (Fn->getName().contains("__enzyme_truncate_mem_value")) {
2456 enableEnzyme = true;
2457 truncateValue = true;
2458 } else if (Fn->getName().contains("__enzyme_expand_mem_value")) {
2459 enableEnzyme = true;
2460 expandValue = true;
2461 } else if (Fn->getName().contains("__enzyme_likelihood")) {
2462 enableEnzyme = true;
2463 probProgMode = ProbProgMode::Likelihood;
2464 probProg = true;
2465 } else if (Fn->getName().contains("__enzyme_trace")) {
2466 enableEnzyme = true;
2467 probProgMode = ProbProgMode::Trace;
2468 probProg = true;
2469 } else if (Fn->getName().contains("__enzyme_condition")) {
2470 enableEnzyme = true;
2471 probProgMode = ProbProgMode::Condition;
2472 probProg = true;
2473 }
2474
2475 if (enableEnzyme) {
2476
// Peel cast instructions, block addresses and constant expressions off the
// first argument to reach the value that names the target function.
2477 Value *fn = CI->getArgOperand(0);
2478 while (auto ci = dyn_cast<CastInst>(fn)) {
2479 fn = ci->getOperand(0);
2480 }
2481 while (auto ci = dyn_cast<BlockAddress>(fn)) {
2482 fn = ci->getFunction();
2483 }
2484 while (auto ci = dyn_cast<ConstantExpr>(fn)) {
2485 fn = ci->getOperand(0);
2486 }
// First argument is a select: split the block at the call, branch on the
// select condition into two new blocks, put one copy of the call in each with
// the true/false value as operand 0, and merge the results with a PHI in the
// continuation block when the call has uses. Then restart the scan.
2487 if (auto si = dyn_cast<SelectInst>(fn)) {
2488 BasicBlock *post = BB.splitBasicBlock(CI);
2489 BasicBlock *sel1 = BasicBlock::Create(BB.getContext(), "sel1", &F);
2490 BasicBlock *sel2 = BasicBlock::Create(BB.getContext(), "sel2", &F);
2491 BB.getTerminator()->eraseFromParent();
2492 IRBuilder<> PB(&BB);
2493 PB.CreateCondBr(si->getCondition(), sel1, sel2);
2494 IRBuilder<> S1(sel1);
2495 auto B1 = S1.CreateBr(post);
2496 CallInst *cloned = cast<CallInst>(CI->clone());
2497 cloned->insertBefore(B1);
2498 cloned->setOperand(0, si->getTrueValue());
2499 IRBuilder<> S2(sel2);
2500 auto B2 = S2.CreateBr(post);
2501 CI->moveBefore(B2);
2502 CI->setOperand(0, si->getFalseValue());
2503 if (CI->getNumUses() != 0) {
2504 IRBuilder<> P(post->getFirstNonPHI());
2505 auto merge = P.CreatePHI(CI->getType(), 2);
2506 merge->addIncoming(cloned, sel1);
2507 merge->addIncoming(CI, sel2);
2508 CI->replaceAllUsesWith(merge);
2509 }
2510 goto retry;
2511 }
// File the call under exactly one work list.
2512 if (virtualCall)
2513 toVirtual[CI] = derivativeMode;
2514 else if (sizeOnly)
2515 toSize[CI] = derivativeMode;
2516 else if (batch)
2517 toBatch.push_back(CI);
2518 else if (truncateFuncOp)
2519 toTruncateFuncOp.push_back(CI);
2520 else if (truncateFuncMem)
2521 toTruncateFuncMem.push_back(CI);
2522 else if (truncateValue)
2523 toTruncateValue.push_back(CI);
2524 else if (expandValue)
2525 toExpandValue.push_back(CI);
2526 else if (probProg) {
2527 toProbProg[CI] = probProgMode;
2528 } else
2529 toLower[CI] = derivativeMode;
2530
2531 if (auto dc = dyn_cast<Function>(fn)) {
2532 // Force postopt on any inner functions in the nested
2533 // AD case.
2534 bool tmp = Logic.PostOpt;
2535 Logic.PostOpt = true;
2536 Changed |= lowerEnzymeCalls(*dc, done);
2537 Logic.PostOpt = tmp;
2538 }
2539 }
2540 }
2541 }
2542
// Replace each __enzyme_call_inactive(fn, args...) with a call fn(args...)
// carrying the string attribute "enzyme_inactive"; the function type is built
// from the call's own result type and the types of arguments 1..N.
2543 for (auto CI : InactiveCalls) {
2544 IRBuilder<> B(CI);
2545 Value *fn = CI->getArgOperand(0);
2546 SmallVector<Value *, 4> Args;
2547 SmallVector<Type *, 4> ArgTypes;
2548 for (size_t i = 1; i < CI->arg_size(); ++i) {
2549 Args.push_back(CI->getArgOperand(i));
2550 ArgTypes.push_back(CI->getArgOperand(i)->getType());
2551 }
2552 auto FT = FunctionType::get(CI->getType(), ArgTypes, /*varargs*/ false);
// NOTE(review): fn is a call argument (a pointer value) while FT is a function
// type, so this inequality looks like it always holds — confirm intent.
2553 if (fn->getType() != FT) {
2554 fn = B.CreatePointerCast(fn, getUnqual(FT));
2555 }
2556 auto Rep = B.CreateCall(FT, fn, Args);
// NOTE(review): unlike the other addAttribute uses above, this one is not
// under an LLVM version guard — confirm it builds on all supported versions.
2557 Rep->addAttribute(AttributeList::FunctionIndex,
2558 Attribute::get(Rep->getContext(), "enzyme_inactive"));
2559 CI->replaceAllUsesWith(Rep);
2560 CI->eraseFromParent();
2561 Changed = true;
2562 }
2563
// Passed to HandleAutoDiffArguments and HandleProbProg below; consumed by the
// PostOpt inlining loop further down.
2564 SmallVector<CallInst *, 1> calls;
2565
2566 // Perform all the size replacements first to create constants
2567 for (auto pair : toSize) {
2568 bool successful = HandleAutoDiffArguments(pair.first, pair.second,
2569 /*sizeOnly*/ true, calls);
2570 Changed = true;
2571 if (!successful)
2572 break;
2573 }
// Changed is set even for a failed request; the loop stops at the first
// failure.
2574 for (auto pair : toLower) {
2575 bool successful = HandleAutoDiffArguments(pair.first, pair.second,
2576 /*sizeOnly*/ false, calls);
2577 Changed = true;
2578 if (!successful)
2579 break;
2580 }
2581
// __enzyme_virtualreverse: the first argument must be a Constant; the call is
// replaced by `val` pointer-cast to the call's type.
2582 for (auto pair : toVirtual) {
2583 auto CI = pair.first;
2584 Constant *fn = dyn_cast<Constant>(CI->getArgOperand(0));
2585 if (!fn) {
2586 EmitFailure("IllegalVirtual", CI->getDebugLoc(), CI,
2587 "Cannot create virtual version of non-constant value ", *CI,
2588 *CI->getArgOperand(0));
// NOTE(review): returns false here even if Changed was already set above.
2589 return false;
2590 }
2591 TypeAnalysis TA(Logic);
2592
2593 auto Arch =
2594 llvm::Triple(
2595 CI->getParent()->getParent()->getParent()->getTargetTriple())
2596 .getArch();
2597
// Atomic accumulation is requested for the nvptx, nvptx64 and amdgcn targets.
2598 bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
2599 Arch == Triple::amdgcn;
2600
// NOTE(review): source line 2602 is absent from this listing. `val`, used
// below, is never declared in the visible text, so the line that opens the
// call whose arguments follow is presumably what is missing.
2601 IRBuilder<> Builder(CI);
2603 RequestContext(CI, &Builder), Logic,
2604 Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F), TA, fn,
2605 pair.second, /*runtimeActivity*/ false, /*strongZero*/ false,
2606 /*width*/ 1, AtomicAdd);
2607 CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType()));
2608 CI->eraseFromParent();
2609 Changed = true;
2610 }
2611
// NOTE(review): the handlers below do not feed into Changed; confirm whether
// their results should be reflected in the return value.
2612 for (auto call : toBatch) {
2613 HandleBatch(call);
2614 }
2615 for (auto call : toTruncateFuncMem) {
2616 HandleTruncateFunc(call, TruncMemMode);
2617 }
2618 for (auto call : toTruncateFuncOp) {
2619 HandleTruncateFunc(call, TruncOpMode);
2620 }
2621 for (auto call : toTruncateValue) {
2622 HandleTruncateValue(call, true);
2623 }
2624 for (auto call : toExpandValue) {
2625 HandleTruncateValue(call, false);
2626 }
2627
2628 for (auto &&[call, mode] : toProbProg) {
2629 HandleProbProg(call, mode, calls);
2630 }
2631
// Post-optimization: work through `calls`, inlining each call whose callee has
// a body when llvm::shouldInline accepts it. After a successful inline the
// direct callers of the enclosing function are queued too.
2632 if (Logic.PostOpt) {
2633 auto Params = llvm::getInlineParams();
2634
2635 llvm::SetVector<CallInst *> Q;
2636 for (auto call : calls)
2637 Q.insert(call);
2638 while (Q.size()) {
2639 auto cur = *Q.begin();
2640 Function *outerFunc = cur->getParent()->getParent();
2641 llvm::OptimizationRemarkEmitter ORE(outerFunc);
2642 Q.erase(Q.begin());
// Note: this local F (the callee) shadows the parameter F inside the branch.
2643 if (auto F = cur->getCalledFunction()) {
2644 if (!F->empty()) {
2645 // Garbage collect AC's created
2646 SmallVector<std::unique_ptr<AssumptionCache>, 2> ACAlloc;
2647 auto getAC = [&](Function &F) -> llvm::AssumptionCache & {
2648 auto AC = std::make_unique<AssumptionCache>(F);
2649 ACAlloc.push_back(std::move(AC));
2650 return *ACAlloc.back();
2651 };
2652 auto GetTLI =
2653 [&](llvm::Function &F) -> const llvm::TargetLibraryInfo & {
2654 return Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F);
2655 };
2656
2657 TargetTransformInfo TTI(F->getParent()->getDataLayout());
2658 auto GetInlineCost = [&](CallBase &CB) {
2659 auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI);
2660 return cst;
2661 };
2662#if LLVM_VERSION_MAJOR >= 20
2663 if (llvm::shouldInline(*cur, TTI, GetInlineCost, ORE))
2664#else
2665 if (llvm::shouldInline(*cur, GetInlineCost, ORE))
2666#endif
2667 {
2668 InlineFunctionInfo IFI;
2669 InlineResult IR = InlineFunction(*cur, IFI);
2670 if (IR.isSuccess()) {
2671 LowerSparsification(outerFunc, /*replaceAll*/ false);
2672 for (auto U : outerFunc->users()) {
2673 if (auto CI = dyn_cast<CallInst>(U)) {
2674 if (CI->getCalledFunction() == outerFunc) {
2675 Q.insert(CI);
2676 }
2677 }
2678 }
2679 }
2680 }
2681 }
2682 }
2683 }
2684 }
2685
2686 if (Changed && EnzymeAttributor) {
2687 // TODO consider enabling when attributor does not delete
2688 // dead internal functions, which invalidates Enzyme's cache
2689 // code left here to re-enable upon Attributor patch
2690
2691#if !defined(FLANG) && !defined(ROCM)
2692
// The Attributor is seeded with every function of the enclosing module, is
// restricted to the abstract attributes listed in `Allowed`, and has function
// deletion switched off.
2693 AnalysisGetter AG(Logic.PPC.FAM);
2694 SetVector<Function *> Functions;
2695 for (Function &F2 : *F.getParent()) {
2696 Functions.insert(&F2);
2697 }
2698
2699 CallGraphUpdater CGUpdater;
2700 BumpPtrAllocator Allocator;
2701 InformationCache InfoCache(*F.getParent(), AG, Allocator,
2702 /* CGSCC */ nullptr);
2703
2704 DenseSet<const char *> Allowed = {
2705 &AAHeapToStack::ID,
2706 &AANoCapture::ID,
2707
2708 &AAMemoryBehavior::ID,
2709 &AAMemoryLocation::ID,
2710 &AANoUnwind::ID,
2711 &AANoSync::ID,
2712 &AANoRecurse::ID,
2713 &AAWillReturn::ID,
2714 &AANoReturn::ID,
2715 &AANonNull::ID,
2716 &AANoAlias::ID,
2717 &AADereferenceable::ID,
2718 &AAAlign::ID,
2719#if LLVM_VERSION_MAJOR < 17
2720 &AAReturnedValues::ID,
2721#endif
2722 &AANoFree::ID,
2723 &AANoUndef::ID,
2724
2725 //&AAValueSimplify::ID,
2726 //&AAReachability::ID,
2727 //&AAValueConstantRange::ID,
2728 //&AAUndefinedBehavior::ID,
2729 //&AAPotentialValues::ID,
2730 };
2731
2732 AttributorConfig aconfig(CGUpdater);
2733 aconfig.Allowed = &Allowed;
2734 aconfig.DeleteFns = false;
2735 Attributor A(Functions, InfoCache, aconfig);
2736 for (Function *F : Functions) {
2737 // Populate the Attributor with abstract attribute opportunities in
2738 // the function and the information cache with IR information.
2739 A.identifyDefaultAbstractAttributes(*F);
2740 }
2741 A.run();
2742#endif
2743 }
2744
2745 return Changed;
2746 }
2747
// Runs Enzyme over module M and returns whether the module was changed.
//
// In order: resets Logic; rewrites f90_mzero8 calls into memset; optionally
// runs OpenMPOpt/Attributor/Promote (Logic.PostOpt && EnzymeOMPOpt); lowers
// Enzyme calls in every defined function via lowerEnzymeCalls; removes the
// leftover __enzyme_float/double/integer/pointer calls and forwards the first
// argument of __enzyme_iter/__enzyme_ignore_derivatives calls; checks and then
// replaces __enzyme_sample/__enzyme_observe calls; erases the functions held
// in Logic.PPC.cache; optionally runs a post-optimization pipeline; finally
// calls LowerSparsification on every defined function.
//
// NOTE(review): the embedded source numbering skips 2752, 2798, 2857, 2880,
// 2899, 2911 and 2943, so a few lines of this method are not visible in this
// listing (apparently a loop body, the head of one `if`, and the opening line
// of several diagnostic calls). Restore them from the real source before
// building.
2748 bool run(Module &M) {
2749 Logic.clear();
2750
// NOTE(review): the body of this loop (source line 2752) is missing here.
2751 for (Function &F : make_early_inc_range(M)) {
2753 }
2754
2755 bool changed = false;
2756 for (Function &F : M) {
2757 if (F.empty())
2758 continue;
2759 for (BasicBlock &BB : F) {
2760 for (Instruction &I : make_early_inc_range(BB)) {
2761 if (auto CI = dyn_cast<CallInst>(&I)) {
// Resolve the callee, looking through one constant cast. This inner F shadows
// the loop variable F.
2762 Function *F = CI->getCalledFunction();
2763 if (auto castinst =
2764 dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
2765 if (castinst->isCast())
2766 if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
2767 F = fn;
2768 }
2769 }
// f90_mzero8(ptr, count) becomes memset(ptr, 0, count * 8).
// NOTE(review): `changed` is not set when this rewrite fires — confirm.
2770 if (F && F->getName() == "f90_mzero8") {
2771 IRBuilder<> B(CI);
2772
2773 Value *args[3];
2774 args[0] = CI->getArgOperand(0);
2775 args[1] = ConstantInt::get(Type::getInt8Ty(M.getContext()), 0);
2776 args[2] = B.CreateMul(
2777 CI->getArgOperand(1),
2778 ConstantInt::get(CI->getArgOperand(1)->getType(), 8));
2779 B.CreateMemSet(args[0], args[1], args[2], MaybeAlign());
2780
2781 CI->eraseFromParent();
2782 }
2783 }
2784 }
2785 }
2786 }
2787
2788 if (Logic.PostOpt && EnzymeOMPOpt) {
2789 OpenMPOptPass().run(M, Logic.PPC.MAM);
2790 /// Attributor is run second time for promoted args to get attributes.
2791 AttributorPass().run(M, Logic.PPC.MAM);
2792 for (auto &F : M)
2793 if (!F.empty())
2794 PromotePass().run(F, Logic.PPC.FAM);
2795 changed = true;
2796 }
2797
// NOTE(review): the line that opens this block (source line 2798) is missing
// from the listing; only its body and closing brace are visible.
2799 changed = true;
2800 }
2801
// `done` is shared across all top-level calls so each function is lowered at
// most once, including those reached recursively.
2802 std::set<Function *> done;
2803 for (Function &F : M) {
2804 if (F.empty())
2805 continue;
2806
2807 changed |= lowerEnzymeCalls(F, done);
2808 }
2809
// Clean-up sweep over what is left after lowering.
2810 for (Function &F : M) {
2811 if (F.empty())
2812 continue;
2813
2814 for (BasicBlock &BB : F) {
2815 for (Instruction &I : make_early_inc_range(BB)) {
2816 if (auto CI = dyn_cast<CallInst>(&I)) {
2817 Function *F = CI->getCalledFunction();
2818 if (auto castinst =
2819 dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
2820 if (castinst->isCast())
2821 if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
2822 F = fn;
2823 }
2824 }
2825 if (F) {
// These calls are simply deleted.
2826 if (F->getName().contains("__enzyme_float") ||
2827 F->getName().contains("__enzyme_double") ||
2828 F->getName().contains("__enzyme_integer") ||
2829 F->getName().contains("__enzyme_pointer")) {
2830 CI->eraseFromParent();
2831 changed = true;
2832 }
// These calls are replaced by their first argument.
2833 if (F->getName().contains("__enzyme_iter") ||
2834 F->getName().contains("__enzyme_ignore_derivatives")) {
2835 CI->replaceAllUsesWith(CI->getArgOperand(0));
2836 CI->eraseFromParent();
2837 changed = true;
2838 }
2839 }
2840 }
2841 }
2842 }
2843 }
2844
// Validate every __enzyme_sample / __enzyme_observe call and collect them for
// replacement below. As used here, a sample call's argument 0 is the sampling
// function, argument 1 the density function, and arguments from index 3 on are
// forwarded; an observe call's argument 0 is the observed value.
// NOTE(review): execution continues after a reported failure, and the results
// of GetFunctionFromValue are dereferenced without a null check — confirm
// that the diagnostics abort or that null cannot occur.
2845 SmallPtrSet<CallInst *, 16> sample_calls;
2846 SmallPtrSet<CallInst *, 16> observe_calls;
2847 for (auto &&func : M) {
2848 for (auto &&BB : func) {
2849 for (auto &&Inst : BB) {
2850 if (auto CI = dyn_cast<CallInst>(&Inst)) {
2851 Function *fun = CI->getCalledFunction();
2852 if (!fun)
2853 continue;
2854
2855 if (fun->getName().contains("__enzyme_sample")) {
// NOTE(review): the later count checks use arg_size(); confirm that
// getNumOperands() is the intended quantity here. The line opening the
// diagnostic call (source line 2857) is missing from the listing.
2856 if (CI->getNumOperands() < 3) {
2858 "IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2859 "Not enough arguments passed to call to __enzyme_sample");
2860 }
2861 Function *samplefn = GetFunctionFromValue(CI->getOperand(0));
2862 unsigned expected =
2863 samplefn->getFunctionType()->getNumParams() + 3;
2864 unsigned actual = CI->arg_size();
2865 if (actual - 3 != samplefn->getFunctionType()->getNumParams()) {
2866 EmitFailure("IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2867 "Illegal number of arguments passed to call to "
2868 "__enzyme_sample.",
2869 " Expected: ", expected, " got: ", actual);
2870 }
2871 Function *pdf = GetFunctionFromValue(CI->getArgOperand(1));
2872
// Each forwarded argument must match the corresponding parameter type of both
// the sampling function and the density function.
2873 for (unsigned i = 0;
2874 i < samplefn->getFunctionType()->getNumParams(); ++i) {
2875 Value *ci_arg = CI->getArgOperand(i + 3);
2876 Value *sample_arg = samplefn->arg_begin() + i;
2877 Value *pdf_arg = pdf->arg_begin() + i;
2878
2879 if (ci_arg->getType() != sample_arg->getType()) {
2881 "IllegalSampleType", CI->getDebugLoc(), CI,
2882 "Type of: ", *ci_arg, " (", *ci_arg->getType(), ")",
2883 " does not match the argument type of the sample "
2884 "function: ",
2885 *samplefn, " at: ", i, " (", *sample_arg->getType(), ")");
2886 }
2887 if (ci_arg->getType() != pdf_arg->getType()) {
2888 EmitFailure("IllegalSampleType", CI->getDebugLoc(), CI,
2889 "Type of: ", *ci_arg, " (", *ci_arg->getType(),
2890 ")",
2891 " does not match the argument type of the "
2892 "density function: ",
2893 *pdf, " at: ", i, " (", *pdf_arg->getType(), ")");
2894 }
2895 }
2896
// The density function's last parameter must have the sampler's return type.
2897 if ((pdf->arg_end() - 1)->getType() !=
2898 samplefn->getReturnType()) {
2900 "IllegalSampleType", CI->getDebugLoc(), CI,
2901 "Return type of ", *samplefn, " (",
2902 *samplefn->getReturnType(), ")",
2903 " does not match the last argument type of the density "
2904 "function: ",
2905 *pdf, " (", *(pdf->arg_end() - 1)->getType(), ")");
2906 }
2907 sample_calls.insert(CI);
2908
2909 } else if (fun->getName().contains("__enzyme_observe")) {
// NOTE(review): this branch handles __enzyme_observe, yet the message below
// names __enzyme_sample — looks like a copy/paste slip in the diagnostic text.
2910 if (CI->getNumOperands() < 3) {
2912 "IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2913 "Not enough arguments passed to call to __enzyme_sample");
2914 }
2915 Value *observed = CI->getOperand(0);
2916 Function *pdf = GetFunctionFromValue(CI->getArgOperand(1));
2917 unsigned expected = pdf->getFunctionType()->getNumParams() - 1;
2918
2919 unsigned actual = CI->arg_size();
2920 if (actual - 3 != expected) {
2921 EmitFailure("IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2922 "Illegal number of arguments passed to call to "
2923 "__enzyme_observe.",
2924 " Expected: ", expected, " got: ", actual);
2925 }
2926
2927 for (unsigned i = 0;
2928 i < pdf->getFunctionType()->getNumParams() - 1; ++i) {
2929 Value *ci_arg = CI->getArgOperand(i + 3);
2930 Value *pdf_arg = pdf->arg_begin() + i;
2931
2932 if (ci_arg->getType() != pdf_arg->getType()) {
2933 EmitFailure("IllegalSampleType", CI->getDebugLoc(), CI,
2934 "Type of: ", *ci_arg, " (", *ci_arg->getType(),
2935 ")",
2936 " does not match the argument type of the "
2937 "density function: ",
2938 *pdf, " at: ", i, " (", *pdf_arg->getType(), ")");
2939 }
2940 }
2941
// The density function's last parameter must have the observed value's type.
2942 if ((pdf->arg_end() - 1)->getType() != observed->getType()) {
2944 "IllegalSampleType", CI->getDebugLoc(), CI,
2945 "Return type of ", *observed, " (", *observed->getType(),
2946 ")",
2947 " does not match the last argument type of the density "
2948 "function: ",
2949 *pdf, " (", *(pdf->arg_end() - 1)->getType(), ")");
2950 }
2951 observe_calls.insert(CI);
2952 }
2953 }
2954 }
2955 }
2956 }
2957
2958 // Replace calls to __enzyme_sample with the actual sample calls after
2959 // running prob prog
2960 for (auto call : sample_calls) {
2961 Function *samplefn = GetFunctionFromValue(call->getArgOperand(0));
2962
// Forward arguments 3..N to the sampling function.
2963 SmallVector<Value *, 2> args;
2964 for (auto it = call->arg_begin() + 3; it != call->arg_end(); it++) {
2965 args.push_back(*it);
2966 }
2967 CallInst *choice =
2968 CallInst::Create(samplefn->getFunctionType(), samplefn, args);
2969
2970 ReplaceInstWithInst(call, choice);
2971 }
2972
// An observe call is replaced by its observed value (argument 0) and erased.
// NOTE(review): neither replacement loop sets `changed` — confirm.
2973 for (auto call : observe_calls) {
2974 Value *observed = call->getArgOperand(0);
2975
2976 if (!call->getType()->isVoidTy())
2977 call->replaceAllUsesWith(observed);
2978 call->eraseFromParent();
2979 }
2980
// Erase every function stored in Logic.PPC.cache, then reset Logic.
2981 for (const auto &pair : Logic.PPC.cache)
2982 pair.second->eraseFromParent();
2983 Logic.clear();
2984
// Post-optimization uses freshly constructed analysis managers rather than the
// ones held in Logic.PPC.
2985 if (changed && Logic.PostOpt) {
2986 TimeTraceScope timeScope("Enzyme PostOpt", M.getName());
2987
2988 PassBuilder PB;
2989 LoopAnalysisManager LAM;
2990 FunctionAnalysisManager FAM;
2991 CGSCCAnalysisManager CGAM;
2992 ModuleAnalysisManager MAM;
2993 PB.registerModuleAnalyses(MAM);
2994 PB.registerFunctionAnalyses(FAM);
2995 PB.registerLoopAnalyses(LAM);
2996 PB.registerCGSCCAnalyses(CGAM);
2997 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
2998 auto PM = PB.buildModuleSimplificationPipeline(OptimizationLevel::O2,
2999 ThinOrFullLTOPhase::None);
3000 PM.run(M, MAM);
3001 if (EnzymeOMPOpt) {
3002 OpenMPOptPass().run(M, MAM);
3003 /// Attributor is run second time for promoted args to get attributes.
3004 AttributorPass().run(M, MAM);
3005 for (auto &F : M)
3006 if (!F.empty())
3007 PromotePass().run(F, FAM);
3008 }
3009 }
3010
3011 for (auto &F : M) {
3012 if (!F.empty())
3013 changed |= LowerSparsification(&F);
3014 }
3015 return changed;
3016 }
3017};
3018
3019class EnzymeOldPM : public EnzymeBase, public ModulePass {
3020public:
3021 static char ID;
3022 EnzymeOldPM(bool PostOpt = false) : EnzymeBase(PostOpt), ModulePass(ID) {}
3023
3024 void getAnalysisUsage(AnalysisUsage &AU) const override {
3025 AU.addRequired<TargetLibraryInfoWrapperPass>();
3026
3027 // AU.addRequiredID(LCSSAID);
3028
3029 // LoopInfo is required to ensure that all loops have preheaders
3030 // AU.addRequired<LoopInfoWrapperPass>();
3031
3032 // AU.addRequiredID(llvm::LoopSimplifyID);//<LoopSimplifyWrapperPass>();
3033 }
3034 bool runOnModule(Module &M) override { return run(M); }
3035};
3036
3037} // namespace
3038
3039char EnzymeOldPM::ID = 0;
3040
3041static RegisterPass<EnzymeOldPM> X("enzyme", "Enzyme Pass");
3042
3043ModulePass *createEnzymePass(bool PostOpt) { return new EnzymeOldPM(PostOpt); }
3044
3045#include <llvm-c/Core.h>
3046#include <llvm-c/Types.h>
3047
3048#include "llvm/IR/LegacyPassManager.h"
3049
3050extern "C" void AddEnzymePass(LLVMPassManagerRef PM) {
3051 unwrap(PM)->add(createEnzymePass(/*PostOpt*/ false));
3052}
3053
3054#if LLVM_VERSION_MAJOR >= 22
3055#include "llvm/Plugins/PassPlugin.h"
3056#else
3057#include "llvm/Passes/PassPlugin.h"
3058#endif
3059
3060class EnzymeNewPM final : public EnzymeBase, public PassParent<EnzymeNewPM> {
3062
3063private:
3064 static llvm::AnalysisKey Key;
3065
3066public:
3067 using Result = llvm::PreservedAnalyses;
3068 EnzymeNewPM(bool PostOpt = false) : EnzymeBase(PostOpt) {}
3069
3070 Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) {
3071 return EnzymeBase::run(M) ? PreservedAnalyses::none()
3072 : PreservedAnalyses::all();
3073 }
3074
3075 static bool isRequired() { return true; }
3076};
3077
3078#undef DEBUG_TYPE
3079AnalysisKey EnzymeNewPM::Key;
3080
3082#include "JLInstSimplify.h"
3083#include "PreserveNVVM.h"
3084#include "SimpleGVN.h"
3086#include "llvm/Passes/PassBuilder.h"
3087#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
3088#include "llvm/Transforms/IPO/AlwaysInliner.h"
3089#include "llvm/Transforms/IPO/CalledValuePropagation.h"
3090#include "llvm/Transforms/IPO/ConstantMerge.h"
3091#include "llvm/Transforms/IPO/CrossDSOCFI.h"
3092#include "llvm/Transforms/IPO/DeadArgumentElimination.h"
3093#include "llvm/Transforms/IPO/FunctionAttrs.h"
3094#include "llvm/Transforms/IPO/GlobalDCE.h"
3095#include "llvm/Transforms/IPO/GlobalOpt.h"
3096#include "llvm/Transforms/IPO/GlobalSplit.h"
3097#include "llvm/Transforms/IPO/InferFunctionAttrs.h"
3098#include "llvm/Transforms/IPO/SCCP.h"
3099#include "llvm/Transforms/InstCombine/InstCombine.h"
3100#include "llvm/Transforms/Scalar/CallSiteSplitting.h"
3101#include "llvm/Transforms/Scalar/EarlyCSE.h"
3102#include "llvm/Transforms/Scalar/Float2Int.h"
3103#include "llvm/Transforms/Scalar/GVN.h"
3104#include "llvm/Transforms/Scalar/LoopDeletion.h"
3105#include "llvm/Transforms/Scalar/LoopRotation.h"
3106#include "llvm/Transforms/Scalar/LoopUnrollPass.h"
3107#include "llvm/Transforms/Scalar/SROA.h"
3108// #include "llvm/Transforms/IPO/MemProfContextDisambiguation.h"
3109#include "llvm/Transforms/IPO/ArgumentPromotion.h"
3110#include "llvm/Transforms/Scalar/ConstraintElimination.h"
3111#include "llvm/Transforms/Scalar/DeadStoreElimination.h"
3112#include "llvm/Transforms/Scalar/JumpThreading.h"
3113#include "llvm/Transforms/Scalar/MemCpyOptimizer.h"
3114#include "llvm/Transforms/Scalar/NewGVN.h"
3115#include "llvm/Transforms/Scalar/TailRecursionElimination.h"
3116#if LLVM_VERSION_MAJOR >= 17
3117#include "llvm/Transforms/Utils/MoveAutoInit.h"
3118#endif
3119#include "llvm/Transforms/Scalar/IndVarSimplify.h"
3120#include "llvm/Transforms/Scalar/LICM.h"
3121#include "llvm/Transforms/Scalar/LoopFlatten.h"
3122#include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h"
3123
3124static InlineParams getInlineParamsFromOptLevel(OptimizationLevel Level) {
3125#if LLVM_VERSION_MAJOR >= 23
3126 return getInlineParams(Level.getSpeedupLevel());
3127#else
3128 return getInlineParams(Level.getSpeedupLevel(), Level.getSizeLevel());
3129#endif
3130}
3131
3132#include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h"
3133#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
3134namespace llvm {
3135extern cl::opt<unsigned> SetLicmMssaNoAccForPromotionCap;
3136extern cl::opt<unsigned> SetLicmMssaOptCap;
3137#define EnableLoopFlatten false
3138#define EagerlyInvalidateAnalyses false
3139#define RunNewGVN false
3140#define EnableConstraintElimination true
3141#define UseInlineAdvisor InliningAdvisorMode::Default
3142#define EnableMemProfContextDisambiguation false
3143// extern cl::opt<bool> EnableMatrix;
3144#define EnableMatrix false
3145#define EnableModuleInliner false
3146} // namespace llvm
3147
3148void augmentPassBuilder(llvm::PassBuilder &PB) {
3149
3150 auto prePass = [](ModulePassManager &MPM, OptimizationLevel Level) {
3151 FunctionPassManager OptimizePM;
3152 OptimizePM.addPass(Float2IntPass());
3153 OptimizePM.addPass(LowerConstantIntrinsicsPass());
3154
3155 if (EnableMatrix) {
3156 OptimizePM.addPass(LowerMatrixIntrinsicsPass());
3157 OptimizePM.addPass(EarlyCSEPass());
3158 }
3159
3160 LoopPassManager LPM;
3161 bool LTOPreLink = false;
3162 // First rotate loops that may have been un-rotated by prior passes.
3163 // Disable header duplication at -Oz.
3164#if LLVM_VERSION_MAJOR >= 23
3165 LPM.addPass(LoopRotatePass(/*EnableLoopHeaderDuplication=*/true, LTOPreLink,
3166 /*CheckExitCount=*/true));
3167#else
3168 LPM.addPass(LoopRotatePass(Level != OptimizationLevel::Oz, LTOPreLink));
3169#endif
3170 // Some loops may have become dead by now. Try to delete them.
3171 // FIXME: see discussion in https://reviews.llvm.org/D112851,
3172 // this may need to be revisited once we run GVN before
3173 // loop deletion in the simplification pipeline.
3174 LPM.addPass(LoopDeletionPass());
3175
3176 LPM.addPass(llvm::LoopFullUnrollPass());
3177 OptimizePM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM)));
3178
3179 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizePM)));
3180 };
3181
3182#if LLVM_VERSION_MAJOR >= 20
3183 auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level,
3184 ThinOrFullLTOPhase)
3185#else
3186 auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level)
3187#endif
3188 {
3189 MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true));
3190
3191 if (!EnzymeEnable)
3192 return;
3193
3194 if (Level != OptimizationLevel::O0)
3195 prePass(MPM, Level);
3196 MPM.addPass(llvm::AlwaysInlinerPass());
3197 FunctionPassManager OptimizerPM;
3198 FunctionPassManager OptimizerPM2;
3199#if LLVM_VERSION_MAJOR >= 16
3200 OptimizerPM.addPass(llvm::GVNPass());
3201 OptimizerPM.addPass(llvm::SROAPass(llvm::SROAOptions::PreserveCFG));
3202#else
3203 OptimizerPM.addPass(llvm::GVNPass());
3204 OptimizerPM.addPass(llvm::SROAPass());
3205#endif
3206 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM)));
3207 MPM.addPass(EnzymeNewPM(/*PostOpt=*/true));
3208 MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false));
3209#if LLVM_VERSION_MAJOR >= 16
3210 OptimizerPM2.addPass(llvm::GVNPass());
3211 OptimizerPM2.addPass(llvm::SROAPass(llvm::SROAOptions::PreserveCFG));
3212#else
3213 OptimizerPM2.addPass(llvm::GVNPass());
3214 OptimizerPM2.addPass(llvm::SROAPass());
3215#endif
3216
3217 LoopPassManager LPM1;
3218 LPM1.addPass(LoopDeletionPass());
3219 OptimizerPM2.addPass(createFunctionToLoopPassAdaptor(std::move(LPM1)));
3220
3221 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM2)));
3222 MPM.addPass(GlobalOptPass());
3223 };
3224 // TODO need for perf reasons to move Enzyme pass to the pre vectorization.
3225 PB.registerOptimizerEarlyEPCallback(loadPass);
3226
3227 auto loadNVVM = [](ModulePassManager &MPM, OptimizationLevel) {
3228 MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true));
3229 };
3230
3231 // We should register at vectorizer start for consistency, however,
3232 // that requires a functionpass, and we have a modulepass.
3233 // PB.registerVectorizerStartEPCallback(loadPass);
3234 PB.registerPipelineStartEPCallback(loadNVVM);
3235 PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadNVVM);
3236
3237 auto preLTOPass = [](ModulePassManager &MPM, OptimizationLevel Level) {
3238 // Create a function that performs CFI checks for cross-DSO calls with
3239 // targets in the current module.
3240 MPM.addPass(CrossDSOCFIPass());
3241
3242 if (Level == OptimizationLevel::O0) {
3243 return;
3244 }
3245
3246 // Try to run OpenMP optimizations, quick no-op if no OpenMP metadata
3247 // present.
3248#if LLVM_VERSION_MAJOR >= 16
3249 MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink));
3250#else
3251 MPM.addPass(OpenMPOptPass());
3252#endif
3253
3254 // Remove unused virtual tables to improve the quality of code
3255 // generated by whole-program devirtualization and bitset lowering.
3256 MPM.addPass(GlobalDCEPass());
3257
3258 // Do basic inference of function attributes from known properties of
3259 // system libraries and other oracles.
3260 MPM.addPass(InferFunctionAttrsPass());
3261
3262 if (Level.getSpeedupLevel() > 1) {
3263 MPM.addPass(createModuleToFunctionPassAdaptor(CallSiteSplittingPass(),
3265
3266 // Indirect call promotion. This should promote all the targets that
3267 // are left by the earlier promotion pass that promotes intra-module
3268 // targets. This two-step promotion is to save the compile time. For
3269 // LTO, it should produce the same result as if we only do promotion
3270 // here.
3271 // MPM.addPass(PGOIndirectCallPromotion(
3272 // true /* InLTO */, PGOOpt && PGOOpt->Action ==
3273 // PGOOptions::SampleUse));
3274
3275 // Propagate constants at call sites into the functions they call.
3276 // This opens opportunities for globalopt (and inlining) by
3277 // substituting function pointers passed as arguments to direct uses
3278 // of functions.
3279#if LLVM_VERSION_MAJOR >= 23
3280 MPM.addPass(IPSCCPPass(IPSCCPOptions(/*AllowFuncSpec=*/true)));
3281#elif LLVM_VERSION_MAJOR >= 16
3282 MPM.addPass(IPSCCPPass(IPSCCPOptions(/*AllowFuncSpec=*/
3283 Level != OptimizationLevel::Os &&
3284 Level != OptimizationLevel::Oz)));
3285#else
3286 MPM.addPass(IPSCCPPass());
3287#endif
3288
3289 // Attach metadata to indirect call sites indicating the set of
3290 // functions they may target at run-time. This should follow IPSCCP.
3291 MPM.addPass(CalledValuePropagationPass());
3292 }
3293
3294 // Now deduce any function attributes based in the current code.
3295 MPM.addPass(
3296 createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass()));
3297
3298 // Do RPO function attribute inference across the module to
3299 // forward-propagate attributes where applicable.
3300 // FIXME: Is this really an optimization rather than a
3301 // canonicalization?
3302 MPM.addPass(ReversePostOrderFunctionAttrsPass());
3303
3304 // Use in-range annotations on GEP indices to split globals where
3305 // beneficial.
3306 MPM.addPass(GlobalSplitPass());
3307
3308 // Run whole program optimization of virtual call when the list of
3309 // callees is fixed. MPM.addPass(WholeProgramDevirtPass(ExportSummary,
3310 // nullptr));
3311
3312 // Stop here at -O1.
3313 if (Level == OptimizationLevel::O1) {
3314 return;
3315 }
3316
3317 // Optimize globals to try and fold them into constants.
3318 MPM.addPass(GlobalOptPass());
3319
3320 // Promote any localized globals to SSA registers.
3321 MPM.addPass(createModuleToFunctionPassAdaptor(PromotePass()));
3322
3323 // Linking modules together can lead to duplicate global constant,
3324 // only keep one copy of each constant.
3325 MPM.addPass(ConstantMergePass());
3326
3327 // Remove unused arguments from functions.
3328 MPM.addPass(DeadArgumentEliminationPass());
3329
3330 // Reduce the code after globalopt and ipsccp. Both can open up
3331 // significant simplification opportunities, and both can propagate
3332 // functions through function pointers. When this happens, we often
3333 // have to resolve varargs calls, etc, so let instcombine do this.
3334 FunctionPassManager PeepholeFPM;
3335 PeepholeFPM.addPass(InstCombinePass());
3336 if (Level.getSpeedupLevel() > 1)
3337 PeepholeFPM.addPass(AggressiveInstCombinePass());
3338
3339 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(PeepholeFPM),
3341
3342 // Note: historically, the PruneEH pass was run first to deduce
3343 // nounwind and generally clean up exception handling overhead. It
3344 // isn't clear this is valuable as the inliner doesn't currently care
3345 // whether it is inlining an invoke or a call. Run the inliner now.
3346 if (EnableModuleInliner) {
3347 MPM.addPass(ModuleInlinerPass(getInlineParamsFromOptLevel(Level),
3349 ThinOrFullLTOPhase::FullLTOPostLink));
3350 } else {
3351 MPM.addPass(ModuleInlinerWrapperPass(
3353 /* MandatoryFirst */ true,
3354 InlineContext{ThinOrFullLTOPhase::FullLTOPostLink,
3355 InlinePass::CGSCCInliner}));
3356 }
3357
3358 // Perform context disambiguation after inlining, since that would
3359 // reduce the amount of additional cloning required to distinguish the
3360 // allocation contexts. if (EnableMemProfContextDisambiguation)
3361 // MPM.addPass(MemProfContextDisambiguation());
3362
3363 // Optimize globals again after we ran the inliner.
3364 MPM.addPass(GlobalOptPass());
3365
3366 // Run the OpenMPOpt pass again after global optimizations.
3367#if LLVM_VERSION_MAJOR >= 16
3368 MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink));
3369#else
3370 MPM.addPass(OpenMPOptPass());
3371#endif
3372
3373 // Garbage collect dead functions.
3374 MPM.addPass(GlobalDCEPass());
3375
3376 // If we didn't decide to inline a function, check to see if we can
3377 // transform it to pass arguments by value instead of by reference.
3378 MPM.addPass(
3379 createModuleToPostOrderCGSCCPassAdaptor(ArgumentPromotionPass()));
3380
3381 FunctionPassManager FPM;
3382 // The IPO Passes may leave cruft around. Clean up after them.
3383 FPM.addPass(InstCombinePass());
3384
3386 FPM.addPass(ConstraintEliminationPass());
3387
3388 FPM.addPass(JumpThreadingPass());
3389
3390 // Do a post inline PGO instrumentation and use pass. This is a context
3391 // sensitive PGO pass.
3392#if 0
3393 if (PGOOpt) {
3394 if (PGOOpt->CSAction == PGOOptions::CSIRInstr)
3395 addPGOInstrPasses(MPM, Level, /* RunProfileGen */ true,
3396 /* IsCS */ true, PGOOpt->CSProfileGenFile,
3397 PGOOpt->ProfileRemappingFile,
3398 ThinOrFullLTOPhase::FullLTOPostLink, PGOOpt->FS);
3399 else if (PGOOpt->CSAction == PGOOptions::CSIRUse)
3400 addPGOInstrPasses(MPM, Level, /* RunProfileGen */ false,
3401 /* IsCS */ true, PGOOpt->ProfileFile,
3402 PGOOpt->ProfileRemappingFile,
3403 ThinOrFullLTOPhase::FullLTOPostLink, PGOOpt->FS);
3404 }
3405#endif
3406
3407 // Break up allocas
3408#if LLVM_VERSION_MAJOR >= 16
3409 FPM.addPass(SROAPass(SROAOptions::ModifyCFG));
3410#else
3411 FPM.addPass(SROAPass());
3412#endif
3413
3414 // LTO provides additional opportunities for tailcall elimination due
3415 // to link-time inlining, and visibility of nocapture attribute.
3416 FPM.addPass(TailCallElimPass());
3417
3418 // Run a few AA driver optimizations here and now to cleanup the code.
3419 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
3421
3422 MPM.addPass(
3423 createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass()));
3424
3425 // Require the GlobalsAA analysis for the module so we can query it
3426 // within MainFPM.
3427 MPM.addPass(RequireAnalysisPass<GlobalsAA, Module>());
3428 };
3429
3430 auto loadLTO = [preLTOPass, loadPass](ModulePassManager &MPM,
3431 OptimizationLevel Level) {
3432 preLTOPass(MPM, Level);
3433 MPM.addPass(
3434 createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass()));
3435
3436 // Require the GlobalsAA analysis for the module so we can query it
3437 // within MainFPM.
3438 MPM.addPass(RequireAnalysisPass<GlobalsAA, Module>());
3439
3440 // Invalidate AAManager so it can be recreated and pick up the newly
3441 // available GlobalsAA.
3442 MPM.addPass(
3443 createModuleToFunctionPassAdaptor(InvalidateAnalysisPass<AAManager>()));
3444
3445 FunctionPassManager MainFPM;
3446#if LLVM_VERSION_MAJOR >= 22
3447 MainFPM.addPass(createFunctionToLoopPassAdaptor(
3449 /*AllowSpeculation=*/true),
3450 /*USeMemorySSA=*/true));
3451#else
3452 MainFPM.addPass(createFunctionToLoopPassAdaptor(
3454 /*AllowSpeculation=*/true),
3455 /*USeMemorySSA=*/true, /*UseBlockFrequencyInfo=*/false));
3456#endif
3457
3458 if (RunNewGVN)
3459 MainFPM.addPass(NewGVNPass());
3460 else
3461 MainFPM.addPass(GVNPass());
3462
3463 // Remove dead memcpy()'s.
3464 MainFPM.addPass(MemCpyOptPass());
3465
3466 // Nuke dead stores.
3467 MainFPM.addPass(DSEPass());
3468#if LLVM_VERSION_MAJOR >= 17
3469 MainFPM.addPass(MoveAutoInitPass());
3470#endif
3471 MainFPM.addPass(MergedLoadStoreMotionPass());
3472
3473 LoopPassManager LPM;
3474 if (EnableLoopFlatten && Level.getSpeedupLevel() > 1)
3475 LPM.addPass(LoopFlattenPass());
3476 LPM.addPass(IndVarSimplifyPass());
3477 LPM.addPass(LoopDeletionPass());
3478 // FIXME: Add loop interchange.
3479
3480#if LLVM_VERSION_MAJOR >= 20
3481 loadPass(MPM, Level, ThinOrFullLTOPhase::None);
3482#else
3483 loadPass(MPM, Level);
3484#endif
3485 };
3486 PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO);
3487}
3488
3489bool registerFixupJuliaPass(llvm::StringRef Name, llvm::ModulePassManager &MPM);
3490
3491extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
3492 bool augment = false) {
3493 if (augment) {
3495 }
3496 PB.registerPipelineParsingCallback(
3497 [](llvm::StringRef Name, llvm::ModulePassManager &MPM,
3498 llvm::ArrayRef<llvm::PassBuilder::PipelineElement>) {
3499 if (Name == "enzyme") {
3500 MPM.addPass(EnzymeNewPM());
3501 return true;
3502 }
3503 if (registerFixupJuliaPass(Name, MPM)) {
3504 return true;
3505 }
3506 if (Name == "preserve-nvvm") {
3507 MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true));
3508 return true;
3509 }
3510 if (Name == "preserve-nvvm-end") {
3511 MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false));
3512 return true;
3513 }
3514 if (Name == "print-type-analysis") {
3515 MPM.addPass(TypeAnalysisPrinterNewPM());
3516 return true;
3517 }
3518 if (Name == "print-activity-analysis") {
3519 MPM.addPass(ActivityAnalysisPrinterNewPM());
3520 return true;
3521 }
3522 return false;
3523 });
3524 PB.registerPipelineParsingCallback(
3525 [](llvm::StringRef Name, llvm::FunctionPassManager &FPM,
3526 llvm::ArrayRef<llvm::PassBuilder::PipelineElement>) {
3527 if (Name == "jl-inst-simplify") {
3528 FPM.addPass(JLInstSimplifyNewPM());
3529 return true;
3530 }
3531 if (Name == "simple-gvn") {
3532 FPM.addPass(SimpleGVNNewPM());
3533 return true;
3534 }
3535 return false;
3536 });
3537}
3538
3539extern "C" void registerEnzyme(llvm::PassBuilder &PB) {
3540#ifdef ENZYME_RUNPASS
3541 registerEnzymeAndPassPipeline(PB, /*augment*/ true);
3542#else
3543 registerEnzymeAndPassPipeline(PB, /*augment*/ false);
3544#endif
3545}
3546
3547extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
3549 return {LLVM_PLUGIN_API_VERSION, "EnzymeNewPM", "v0.1", registerEnzyme};
3550}
llvm::cl::opt< bool > EnzymeEnable
static void loadPass(const PassManagerBuilder &Builder, legacy::PassManagerBase &PM)
#define EagerlyInvalidateAnalyses
Definition Enzyme.cpp:3138
#define EnableLoopFlatten
Definition Enzyme.cpp:3137
void augmentPassBuilder(llvm::PassBuilder &PB)
Definition Enzyme.cpp:3148
::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK llvmGetPassPluginInfo()
Definition Enzyme.cpp:3548
ModulePass * createEnzymePass(bool PostOpt)
Definition Enzyme.cpp:3043
llvm::cl::opt< bool > EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden, cl::desc("Whether to enable openmp opt"))
static InlineParams getInlineParamsFromOptLevel(OptimizationLevel Level)
Definition Enzyme.cpp:3124
void AddEnzymePass(LLVMPassManagerRef PM)
Definition Enzyme.cpp:3050
llvm::cl::opt< bool > EnzymeAttributor("enzyme-attributor", cl::init(false), cl::Hidden, cl::desc("Run attributor post Enzyme"))
#define EnableModuleInliner
Definition Enzyme.cpp:3145
#define EnableConstraintElimination
Definition Enzyme.cpp:3140
llvm::cl::opt< bool > EnzymeDetectReadThrow("enzyme-detect-readthrow", cl::init(true), cl::Hidden, cl::desc("Run preprocessing detect readonly or throw optimization"))
void registerEnzyme(llvm::PassBuilder &PB)
Definition Enzyme.cpp:3539
#define UseInlineAdvisor
Definition Enzyme.cpp:3141
void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB, bool augment=false)
Definition Enzyme.cpp:3491
llvm::cl::opt< std::string > EnzymeTruncateAll("enzyme-truncate-all", cl::init(""), cl::Hidden, cl::desc("Truncate all floating point operations. " "E.g. \"64to32\" or \"64to<exponent_width>-<significand_width>\"."))
static RegisterPass< EnzymeOldPM > X("enzyme", "Enzyme Pass")
llvm::cl::opt< bool > EnzymePostOpt("enzyme-postopt", cl::init(false), cl::Hidden, cl::desc("Run enzymepostprocessing optimizations"))
#define EnableMatrix
Definition Enzyme.cpp:3144
bool registerFixupJuliaPass(llvm::StringRef Name, llvm::ModulePassManager &MPM)
#define RunNewGVN
Definition Enzyme.cpp:3139
bool DetectReadonlyOrThrow(Module &M)
bool LowerSparsification(llvm::Function *F, bool replaceAll)
Lower __enzyme_todense, returning if changed.
constexpr const char * to_string(ActivityAnalyzer::UseActivity UA)
llvm::cl::opt< bool > EnzymePrint
constexpr char EnzymeFPRTPrefix[]
Definition EnzymeLogic.h:57
TruncateMode
@ TruncOpFullModuleMode
@ TruncOpMode
@ TruncMemMode
Value * simplifyLoad(Value *V, size_t valSz, size_t preOffset)
Definition Utils.cpp:3386
bool attributeKnownFunctions(llvm::Function &F)
Definition Utils.cpp:114
Function * GetFunctionFromValue(Value *fn)
Definition Utils.cpp:3547
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
Definition Utils.cpp:3731
static llvm::StringRef getFuncName(llvm::Function *called)
Definition Utils.h:1260
BATCH_TYPE
Definition Utils.h:385
@ Args
Return is a struct of all args.
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
Definition Utils.h:713
DIFFE_TYPE
Potential differentiable argument classifications.
Definition Utils.h:374
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
Definition Utils.h:1174
ProbProgMode
Definition Utils.h:399
static void addCallSiteNoCapture(llvm::CallBase *call, size_t idx)
Definition Utils.h:2289
static bool isWriteOnly(const llvm::Function *F, ssize_t arg=-1)
Definition Utils.h:1788
DerivativeMode
Definition Utils.h:390
static DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode, bool integersAreConstant, std::set< llvm::Type * > &seen)
Attempt to automatically detect the differentiable classification based off of a given type.
Definition Utils.h:519
return structtype if recursive function
llvm::Function * fn
llvm::Type * tapeType
return structtype if recursive function
std::map< AugmentedStruct, int > returns
Map from information desired from a augmented return to its index in the returned struct.
Concrete SubType of a given value.
llvm::Function * CreateTruncateFunc(RequestContext context, llvm::Function *tobatch, FloatTruncation truncation, TruncateMode mode)
bool CreateTruncateValue(RequestContext context, llvm::Value *addr, FloatRepresentation from, FloatRepresentation to, bool isTruncate)
const AugmentedReturn & CreateAugmentedPrimal(RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, TypeAnalysis &TA, bool returnUsed, bool shadowReturnUsed, const FnTypeInfo &typeInfo, bool subsequent_calls_may_write, const std::vector< bool > _overwritten_args, bool forceAnonymousTape, bool runtimeActivity, bool strongZero, unsigned width, bool AtomicAdd, bool omp=false)
Create an augmented forward pass.
llvm::Function * CreateForwardDiff(RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, TypeAnalysis &TA, bool returnValue, DerivativeMode mode, bool freeMemory, bool runtimeActivity, bool strongZero, unsigned width, llvm::Type *additionalArg, const FnTypeInfo &typeInfo, bool subsequent_calls_may_write, const std::vector< bool > _overwritten_args, const AugmentedReturn *augmented, bool omp=false)
Create the forward (or forward split) mode derivative function.
llvm::Function * CreateTrace(RequestContext context, llvm::Function *totrace, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, const llvm::StringSet<> &ActiveRandomVariables, ProbProgMode mode, bool autodiff, TraceInterface *interface)
Create a traced version of a function context the instruction which requested this trace (or null).
PreProcessCache PPC
llvm::Function * CreatePrimalAndGradient(RequestContext context, const ReverseCacheKey &&key, TypeAnalysis &TA, const AugmentedReturn *augmented, bool omp=false)
Create the reverse pass, or combined forward+reverse derivative function.
llvm::Function * CreateBatch(RequestContext context, llvm::Function *tobatch, unsigned width, llvm::ArrayRef< BATCH_TYPE > arg_types, BATCH_TYPE ret_type)
Create a function batched in its inputs.
bool PostOpt
PostOpt is whether to perform basic optimization of the function after synthesis
static bool isRequired()
Definition Enzyme.cpp:3075
EnzymeNewPM(bool PostOpt=false)
Definition Enzyme.cpp:3068
llvm::PreservedAnalyses Result
Definition Enzyme.cpp:3067
Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM)
Definition Enzyme.cpp:3070
static llvm::Constant * GetOrCreateShadowConstant(RequestContext context, EnzymeLogic &Logic, llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, llvm::Constant *F, DerivativeMode mode, bool runtimeActivity, bool strongZero, unsigned width, bool AtomicAdd)
static llvm::Type * getShadowType(llvm::Type *ty, unsigned width)
llvm::ModuleAnalysisManager MAM
std::map< std::pair< llvm::Function *, DerivativeMode >, llvm::Function * > cache
llvm::FunctionAnalysisManager FAM
Full interprocedural TypeAnalysis.
TypeResults analyzeFunction(const FnTypeInfo &fn)
Analyze a particular function, returning the results.
FnTypeInfo getAnalyzedTypeInfo() const
The TypeInfo calling convention.
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Definition TypeTree.h:72
TypeTree Only(int Off, llvm::Instruction *orig) const
Prepend an offset to all mappings.
Definition TypeTree.h:471
bool insert(const std::vector< int > Seq, ConcreteType CT, bool PointerIntSame=false)
Return if changed.
Definition TypeTree.h:234
cl::opt< unsigned > SetLicmMssaOptCap
cl::opt< unsigned > SetLicmMssaNoAccForPromotionCap
Struct containing all contextual type information for a particular function call.
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
llvm::Function * Function
Function being analyzed.
TypeTree Return
Type of return.
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to represented by an argument, if constant.
todiff is the function to differentiate retType is the activity info of the return.