Enzyme main
Loading...
Searching...
No Matches
CApi.cpp
Go to the documentation of this file.
1//===- CApi.cpp - Enzyme API exported to C for external use -----------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file defines various utility functions of Enzyme for access via C
22//
23//===----------------------------------------------------------------------===//
24#include "CApi.h"
25#include "Utils.h"
26#if LLVM_VERSION_MAJOR >= 16
27#define private public
28#include "llvm/Analysis/ScalarEvolution.h"
29#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
30#undef private
31#else
32#include "SCEV/ScalarEvolution.h"
33#include "SCEV/ScalarEvolutionExpander.h"
34#endif
35
36#include "DiffeGradientUtils.h"
38#include "EnzymeLogic.h"
39#include "GradientUtils.h"
40#include "LibraryFuncs.h"
41#if LLVM_VERSION_MAJOR >= 16
42#include "llvm/Analysis/TargetLibraryInfo.h"
43#else
44#include "SCEV/TargetLibraryInfo.h"
45#endif
46#include "TraceInterface.h"
47
48// #include "llvm/ADT/Triple.h"
49#include "llvm/Analysis/CallGraph.h"
50#include "llvm/Analysis/GlobalsModRef.h"
51#include "llvm/IR/DIBuilder.h"
52#include "llvm/IR/MDBuilder.h"
53#include "llvm/Transforms/Utils/Cloning.h"
54
55#include "llvm/IR/LegacyPassManager.h"
56#include "llvm/Transforms/IPO/Attributor.h"
57
58#define addAttribute addAttributeAtIndex
59#define removeAttribute removeAttributeAtIndex
60#define getAttribute getAttributeAtIndex
61#define hasAttribute hasAttributeAtIndex
62
63using namespace llvm;
64
65TargetLibraryInfo eunwrap(LLVMTargetLibraryInfoRef P) {
66 return TargetLibraryInfo(*reinterpret_cast<TargetLibraryInfoImpl *>(P));
67}
68
70
74
84
85ConcreteType eunwrap(CConcreteType CDT, llvm::LLVMContext &ctx) {
86 switch (CDT) {
87 case DT_Anything:
88 return BaseType::Anything;
89 case DT_Integer:
90 return BaseType::Integer;
91 case DT_Pointer:
92 return BaseType::Pointer;
93 case DT_Half:
94 return ConcreteType(llvm::Type::getHalfTy(ctx));
95 case DT_Float:
96 return ConcreteType(llvm::Type::getFloatTy(ctx));
97 case DT_Double:
98 return ConcreteType(llvm::Type::getDoubleTy(ctx));
99 case DT_X86_FP80:
100 return ConcreteType(llvm::Type::getX86_FP80Ty(ctx));
101 case DT_BFloat16:
102 return ConcreteType(llvm::Type::getBFloatTy(ctx));
103 case DT_FP128:
104 return ConcreteType(llvm::Type::getFP128Ty(ctx));
105 case DT_Unknown:
106 return BaseType::Unknown;
107 }
108 llvm_unreachable("Unknown concrete type to unwrap");
109}
110
111std::vector<int> eunwrap(IntList IL) {
112 std::vector<int> v;
113 for (size_t i = 0; i < IL.size; i++) {
114 v.push_back((int)IL.data[i]);
115 }
116 return v;
117}
118std::set<int64_t> eunwrap64(IntList IL) {
119 std::set<int64_t> v;
120 for (size_t i = 0; i < IL.size; i++) {
121 v.insert((int64_t)IL.data[i]);
122 }
123 return v;
124}
125TypeTree eunwrap(CTypeTreeRef CTT) { return *(TypeTree *)CTT; }
126
128 if (auto flt = CT.isFloat()) {
129 if (flt->isHalfTy())
130 return DT_Half;
131 if (flt->isFloatTy())
132 return DT_Float;
133 if (flt->isDoubleTy())
134 return DT_Double;
135 if (flt->isX86_FP80Ty())
136 return DT_X86_FP80;
137 if (flt->isBFloatTy())
138 return DT_BFloat16;
139 if (flt->isFP128Ty())
140 return DT_FP128;
141 } else {
142 switch (CT.SubTypeEnum) {
144 return DT_Integer;
146 return DT_Pointer;
148 return DT_Anything;
150 return DT_Unknown;
151 case BaseType::Float:
152 llvm_unreachable("Illegal conversion of concretetype");
153 }
154 }
155 llvm_unreachable("Illegal conversion of concretetype");
156}
157
158IntList ewrap(const std::vector<int> &offsets) {
159 IntList IL;
160 IL.size = offsets.size();
161 IL.data = new int64_t[IL.size];
162 for (size_t i = 0; i < offsets.size(); i++) {
163 IL.data[i] = offsets[i];
164 }
165 return IL;
166}
167
169 return (CTypeTreeRef)(new TypeTree(TT));
170}
171
172FnTypeInfo eunwrap(CFnTypeInfo CTI, llvm::Function *F) {
173 FnTypeInfo FTI(F);
174 // auto &ctx = F->getContext();
175 FTI.Return = eunwrap(CTI.Return);
176
177 size_t argnum = 0;
178 for (auto &arg : F->args()) {
179 FTI.Arguments[&arg] = eunwrap(CTI.Arguments[argnum]);
180 FTI.KnownValues[&arg] = eunwrap64(CTI.KnownValues[argnum]);
181 argnum++;
182 }
183 return FTI;
184}
185
186extern "C" {
187
188void EnzymeSetCLBool(void *ptr, uint8_t val) {
189 auto cl = (llvm::cl::opt<bool> *)ptr;
190 cl->setValue((bool)val);
191}
192
193uint8_t EnzymeGetCLBool(void *ptr) {
194 auto cl = (llvm::cl::opt<bool> *)ptr;
195 return (uint8_t)(bool)cl->getValue();
196}
197
198void EnzymeSetCLInteger(void *ptr, int64_t val) {
199 auto cl = (llvm::cl::opt<int> *)ptr;
200 cl->setValue((int)val);
201}
202
203int64_t EnzymeGetCLInteger(void *ptr) {
204 auto cl = (llvm::cl::opt<int> *)ptr;
205 return (int64_t)cl->getValue();
206}
207
208void EnzymeSetCLString(void *ptr, const char *val) {
209 if (auto *clopt = static_cast<cl::opt<std::string> *>(ptr))
210 clopt->setValue(val);
211}
212
214 return (EnzymeLogicRef)(new EnzymeLogic((bool)PostOpt));
215}
216
217void EnzymeLogicSetExternalContext(EnzymeLogicRef Ref, void *ExternalContext) {
218 eunwrap(Ref).ExternalContext = ExternalContext;
219}
220
222 return eunwrap(Ref).ExternalContext;
223}
224
228
230 LLVMContextRef C, LLVMValueRef getTraceFunction,
231 LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction,
232 LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction,
233 LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction,
234 LLVMValueRef insertChoiceGradientFunction,
235 LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction,
236 LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction,
237 LLVMValueRef hasChoiceFunction) {
239 *unwrap(C), cast<Function>(unwrap(getTraceFunction)),
240 cast<Function>(unwrap(getChoiceFunction)),
241 cast<Function>(unwrap(insertCallFunction)),
242 cast<Function>(unwrap(insertChoiceFunction)),
243 cast<Function>(unwrap(insertArgumentFunction)),
244 cast<Function>(unwrap(insertReturnFunction)),
245 cast<Function>(unwrap(insertFunctionFunction)),
246 cast<Function>(unwrap(insertChoiceGradientFunction)),
247 cast<Function>(unwrap(insertArgumentGradientFunction)),
248 cast<Function>(unwrap(newTraceFunction)),
249 cast<Function>(unwrap(freeTraceFunction)),
250 cast<Function>(unwrap(hasCallFunction)),
251 cast<Function>(unwrap(hasChoiceFunction))));
252};
253
255CreateEnzymeDynamicTraceInterface(LLVMValueRef interface, LLVMValueRef F) {
257 unwrap(interface), cast<Function>(unwrap(F))));
258}
259
260void ClearEnzymeLogic(EnzymeLogicRef Ref) { eunwrap(Ref).clear(); }
261
263 auto &Logic = eunwrap(Ref);
264 for (const auto &pair : Logic.PPC.cache)
265 pair.second->eraseFromParent();
266}
267
268void FreeEnzymeLogic(EnzymeLogicRef Ref) { delete (EnzymeLogic *)Ref; }
269
273
275 char **customRuleNames,
276 CustomRuleType *customRules,
277 size_t numRules) {
278 EnzymeLogic &Logic = eunwrap(Log);
279 TypeAnalysis *TA = new TypeAnalysis(Logic);
280 for (size_t i = 0; i < numRules; i++) {
281 CustomRuleType rule = customRules[i];
282 TA->CustomRules[customRuleNames[i]] =
283 [=](int direction, TypeTree &returnTree, ArrayRef<TypeTree> argTrees,
284 ArrayRef<std::set<int64_t>> knownValues, CallBase *call,
285 TypeAnalyzer *TA) -> uint8_t {
286 CTypeTreeRef creturnTree = (CTypeTreeRef)(&returnTree);
287 CTypeTreeRef *cargs = new CTypeTreeRef[argTrees.size()];
288 IntList *kvs = new IntList[argTrees.size()];
289 for (size_t i = 0; i < argTrees.size(); ++i) {
290 cargs[i] = (CTypeTreeRef)(&(argTrees[i]));
291 kvs[i].size = knownValues[i].size();
292 kvs[i].data = new int64_t[kvs[i].size];
293 size_t j = 0;
294 for (auto val : knownValues[i]) {
295 kvs[i].data[j] = val;
296 j++;
297 }
298 }
299 uint8_t result = rule(direction, creturnTree, cargs, kvs, argTrees.size(),
300 wrap(call), TA);
301 delete[] cargs;
302 for (size_t i = 0; i < argTrees.size(); ++i) {
303 delete[] kvs[i].data;
304 }
305 delete[] kvs;
306 return result;
307 };
308 }
309 return (EnzymeTypeAnalysisRef)TA;
310}
311
313
315 TypeAnalysis *TA = (TypeAnalysis *)TAR;
316 delete TA;
317}
318
322
324 LLVMValueRef F) {
325 FnTypeInfo FTI(eunwrap(CTI, cast<Function>(unwrap(F))));
326 return (void *)((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer;
327}
328
330 return (void *)&G->TR.analyzer;
331}
332
336
337void EnzymeGradientUtilsErase(GradientUtils *G, LLVMValueRef I) {
338 return G->erase(cast<Instruction>(unwrap(I)));
339}
341 LLVMValueRef orig, uint8_t erase) {
342 return G->eraseWithPlaceholder(cast<Instruction>(unwrap(I)),
343 cast<Instruction>(unwrap(orig)),
344 "_replacementABI", erase != 0);
345}
346
348 LLVMValueRef B) {
349 return G->replaceAWithB(unwrap(A), unwrap(B));
350}
351
353 CustomShadowFree FHandle) {
354 shadowHandlers[Name] = [=](IRBuilder<> &B, CallInst *CI,
355 ArrayRef<Value *> Args,
356 GradientUtils *gutils) -> llvm::Value * {
357 SmallVector<LLVMValueRef, 3> refs;
358 for (auto a : Args)
359 refs.push_back(wrap(a));
360 return unwrap(
361 AHandle(wrap(&B), wrap(CI), Args.size(), refs.data(), gutils));
362 };
363 if (FHandle)
364 shadowErasers[Name] = [=](IRBuilder<> &B,
365 Value *ToFree) -> llvm::CallInst * {
366 return cast_or_null<CallInst>(unwrap(FHandle(wrap(&B), wrap(ToFree))));
367 };
368}
369
370void EnzymeRegisterCallHandler(const char *Name,
372 CustomFunctionReverse RevHandle) {
373 auto &pair = customCallHandlers[Name];
374 pair.first = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,
375 Value *&normalReturn, Value *&shadowReturn,
376 Value *&tape) -> bool {
377 LLVMValueRef normalR = wrap(normalReturn);
378 LLVMValueRef shadowR = wrap(shadowReturn);
379 LLVMValueRef tapeR = wrap(tape);
380 uint8_t noMod =
381 FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR);
382 normalReturn = unwrap(normalR);
383 shadowReturn = unwrap(shadowR);
384 tape = unwrap(tapeR);
385 return noMod != 0;
386 };
387 pair.second = [=](IRBuilder<> &B, CallInst *CI, DiffeGradientUtils &gutils,
388 Value *tape) {
389 RevHandle(wrap(&B), wrap(CI), &gutils, wrap(tape));
390 };
391}
392
394 auto &pair = customFwdCallHandlers[Name];
395 pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,
396 Value *&normalReturn, Value *&shadowReturn) -> bool {
397 LLVMValueRef normalR = wrap(normalReturn);
398 LLVMValueRef shadowR = wrap(shadowReturn);
399 uint8_t noMod = FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR);
400 normalReturn = unwrap(normalR);
401 shadowReturn = unwrap(shadowR);
402 return noMod != 0;
403 };
404}
405
407 CustomFunctionDiffUse Handle) {
408 auto &pair = customDiffUseHandlers[Name];
409 pair = [=](const CallInst *CI, const GradientUtils *gutils, const Value *arg,
410 bool isshadow, DerivativeMode mode, bool &useDefault) -> bool {
411 uint8_t useDefaultC = 0;
412 uint8_t noMod = Handle(wrap(CI), gutils, wrap(arg), isshadow,
413 (CDerivativeMode)(mode), &useDefaultC);
414 useDefault = useDefaultC != 0;
415 return noMod != 0;
416 };
417}
418
420 return gutils->runtimeActivity;
421}
422
426
428 return gutils->strongZero;
429}
430
432 return gutils->AtomicAdd;
433}
434
436 return gutils->getWidth();
437}
438
440 LLVMTypeRef T) {
441 return wrap(gutils->getShadowType(unwrap(T)));
442}
443
444LLVMTypeRef EnzymeGetShadowType(uint64_t width, LLVMTypeRef T) {
445 return wrap(GradientUtils::getShadowType(unwrap(T), width));
446}
447
449 LLVMValueRef val) {
450 return wrap(gutils->getNewFromOriginal(unwrap(val)));
451}
452
453void EnzymeReplaceOriginalToNew(GradientUtils *gutils, LLVMValueRef origC,
454 LLVMValueRef repC) {
455 auto orig = cast<Instruction>(unwrap(origC));
456 auto rep = cast<Instruction>(unwrap(repC));
457 auto found = gutils->originalToNewFn.find(orig);
458 assert(found != gutils->originalToNewFn.end());
459 auto newCall = found->second;
460 gutils->originalToNewFn[orig] = rep;
461 gutils->newToOriginalFn.erase(newCall);
462 gutils->newToOriginalFn[rep] = orig;
463}
464
468
471 uint8_t foreignFunction) {
472 return (CDIFFE_TYPE)(G->getDiffeType(unwrap(oval), foreignFunction != 0));
473}
474
477 uint8_t *needsPrimal,
478 uint8_t *needsShadow,
479 CDerivativeMode mode) {
480 bool needsPrimalB;
481 bool needsShadowB;
482 auto res = (CDIFFE_TYPE)(G->getReturnDiffeType(
483 unwrap(oval), &needsPrimalB, &needsShadowB, (DerivativeMode)mode));
484 if (needsPrimal)
485 *needsPrimal = needsPrimalB;
486 if (needsShadow)
487 *needsShadow = needsShadowB;
488 return res;
489}
490
492 LLVMValueRef val,
493 LLVMValueRef orig) {
494 return cast<Instruction>(unwrap(val))
495 ->setDebugLoc(gutils->getNewFromOriginal(
496 cast<Instruction>(unwrap(orig))->getDebugLoc()));
497}
498
499LLVMValueRef EnzymeInsertValue(LLVMBuilderRef B, LLVMValueRef val,
500 LLVMValueRef val2, unsigned *sz, int64_t length,
501 const char *name) {
502 return wrap(unwrap(B)->CreateInsertValue(
503 unwrap(val), unwrap(val2), ArrayRef<unsigned>(sz, sz + length), name));
504}
505
506LLVMValueRef EnzymeGradientUtilsLookup(GradientUtils *gutils, LLVMValueRef val,
507 LLVMBuilderRef B) {
508 return wrap(gutils->lookupM(unwrap(val), *unwrap(B)));
509}
510
512 LLVMValueRef val,
513 LLVMBuilderRef B) {
514 return wrap(gutils->invertPointerM(unwrap(val), *unwrap(B)));
515}
516
518 LLVMValueRef val, LLVMBuilderRef B) {
519 return wrap(gutils->diffe(unwrap(val), *unwrap(B)));
520}
521
523 LLVMValueRef diffe, LLVMBuilderRef B,
524 LLVMTypeRef T) {
525 gutils->addToDiffe(unwrap(val), unwrap(diffe), *unwrap(B), unwrap(T));
526}
527
529 DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal,
530 LLVMTypeRef addingType, unsigned start, unsigned size, LLVMValueRef origptr,
531 LLVMValueRef dif, LLVMBuilderRef BuilderM, unsigned align,
532 LLVMValueRef mask) {
533 MaybeAlign align2;
534 if (align)
535 align2 = MaybeAlign(align);
536 auto inst = cast_or_null<Instruction>(unwrap(orig));
537 gutils->addToInvertedPtrDiffe(inst, unwrap(origVal), unwrap(addingType),
538 start, size, unwrap(origptr), unwrap(dif),
539 *unwrap(BuilderM), align2, unwrap(mask));
540}
541
543 DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal,
544 CTypeTreeRef vd, unsigned LoadSize, LLVMValueRef origptr,
545 LLVMValueRef prediff, LLVMBuilderRef BuilderM, unsigned align,
546 LLVMValueRef premask) {
547 MaybeAlign align2;
548 if (align)
549 align2 = MaybeAlign(align);
550 auto inst = cast_or_null<Instruction>(unwrap(orig));
551 gutils->addToInvertedPtrDiffe(inst, unwrap(origVal), *(TypeTree *)vd,
552 LoadSize, unwrap(origptr), unwrap(prediff),
553 *unwrap(BuilderM), align2, unwrap(premask));
554}
555
556void EnzymeGradientUtilsSetDiffe(DiffeGradientUtils *gutils, LLVMValueRef val,
557 LLVMValueRef diffe, LLVMBuilderRef B) {
558 gutils->setDiffe(unwrap(val), unwrap(diffe), *unwrap(B));
559}
560
562 LLVMValueRef val) {
563 return gutils->isConstantValue(unwrap(val));
564}
565
567 LLVMValueRef val) {
568 return gutils->isConstantInstruction(cast<Instruction>(unwrap(val)));
569}
570
572 return wrap(gutils->inversionAllocs);
573}
574
576 LLVMValueRef orig, uint8_t *data,
577 uint64_t size) {
578 if (gutils->mode == DerivativeMode::ForwardMode ||
580 return 0;
581
582 if (!gutils->overwritten_args_map_ptr)
583 return 0;
584
585 CallInst *call = cast<CallInst>(unwrap(orig));
586
587 assert(gutils->overwritten_args_map_ptr);
588 auto found = gutils->overwritten_args_map_ptr->find(call);
589 if (found == gutils->overwritten_args_map_ptr->end()) {
590 llvm::errs() << " oldFunc " << *gutils->oldFunc << "\n";
591 for (auto &pair : *gutils->overwritten_args_map_ptr) {
592 llvm::errs() << " + " << *pair.first << "\n";
593 }
594 llvm::errs() << " could not find call orig in overwritten_args_map_ptr "
595 << *call << "\n";
596 }
597 assert(found != gutils->overwritten_args_map_ptr->end());
598
599 const std::vector<bool> &overwritten_args = found->second.second;
600
601 if (size != overwritten_args.size()) {
602 llvm::errs() << " orig: " << *call << "\n";
603 llvm::errs() << " size: " << size
604 << " overwritten_args.size(): " << overwritten_args.size()
605 << "\n";
606 }
607 assert(size == overwritten_args.size());
608 for (uint64_t i = 0; i < size; i++) {
609 data[i] = overwritten_args[i];
610 }
611 return 1;
612}
613
615 LLVMValueRef val) {
616 auto v = unwrap(val);
617 TypeTree TT = gutils->TR.query(v);
618 TypeTree *pTT = new TypeTree(TT);
619 return (CTypeTreeRef)pTT;
620}
621
623 gutils->TR.dump();
624}
625
627 GradientUtils *gutils, CDerivativeMode mode, LLVMTypeRef secretty,
628 uint64_t intrinsic, uint64_t dstAlign, uint64_t srcAlign, uint64_t offset,
629 uint8_t dstConstant, LLVMValueRef shadow_dst, uint8_t srcConstant,
630 LLVMValueRef shadow_src, LLVMValueRef length, LLVMValueRef isVolatile,
631 LLVMValueRef MTI, uint8_t allowForward, uint8_t shadowsLookedUp) {
632 auto orig = unwrap(MTI);
633 assert(orig);
634 SubTransferHelper(gutils, (DerivativeMode)mode, unwrap(secretty),
635 (Intrinsic::ID)intrinsic, (unsigned)dstAlign,
636 (unsigned)srcAlign, (unsigned)offset, (bool)dstConstant,
637 unwrap(shadow_dst), (bool)srcConstant, unwrap(shadow_src),
638 unwrap(length), unwrap(isVolatile), cast<CallInst>(orig),
639 (bool)allowForward, (bool)shadowsLookedUp);
640}
641
643 LLVMBasicBlockRef block,
644 const char *name,
645 uint8_t forkCache,
646 uint8_t push) {
647 return wrap(gutils->addReverseBlock(cast<BasicBlock>(unwrap(block)), name,
648 forkCache, push));
649}
650
652 LLVMBasicBlockRef block) {
653 auto endBlock = cast<BasicBlock>(unwrap(block));
654 auto found = gutils->reverseBlockToPrimal.find(endBlock);
655 assert(found != gutils->reverseBlockToPrimal.end());
656 auto &vec = gutils->reverseBlocks[found->second];
657 assert(vec.size());
658 vec.push_back(endBlock);
659}
660
662 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
663 LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
664 size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
665 CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity,
666 uint8_t strongZero, unsigned width, LLVMTypeRef additionalArg,
667 CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write,
668 uint8_t *_overwritten_args, size_t overwritten_args_size,
669 EnzymeAugmentedReturnPtr augmented) {
670 SmallVector<DIFFE_TYPE, 4> nconstant_args((DIFFE_TYPE *)constant_args,
671 (DIFFE_TYPE *)constant_args +
672 constant_args_size);
673 std::vector<bool> overwritten_args;
674 assert(overwritten_args_size == cast<Function>(unwrap(todiff))->arg_size());
675 for (uint64_t i = 0; i < overwritten_args_size; i++) {
676 overwritten_args.push_back(_overwritten_args[i]);
677 }
678 return wrap(eunwrap(Logic).CreateForwardDiff(
679 RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
680 unwrap(request_ip)),
681 cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
682 eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory,
683 runtimeActivity, strongZero, width, unwrap(additionalArg),
684 eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
685 subsequent_calls_may_write, overwritten_args, eunwrap(augmented)));
686}
688 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
689 LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
690 size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
691 uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity,
692 uint8_t strongZero, unsigned width, uint8_t freeMemory,
693 LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo,
694 uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args,
695 size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented,
696 uint8_t AtomicAdd) {
697 std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
698 (DIFFE_TYPE *)constant_args +
699 constant_args_size);
700 std::vector<bool> overwritten_args;
701 assert(overwritten_args_size == cast<Function>(unwrap(todiff))->arg_size());
702 for (uint64_t i = 0; i < overwritten_args_size; i++) {
703 overwritten_args.push_back(_overwritten_args[i]);
704 }
705 return wrap(eunwrap(Logic).CreatePrimalAndGradient(
706 RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
707 unwrap(request_ip)),
709 .todiff = cast<Function>(unwrap(todiff)),
710 .retType = (DIFFE_TYPE)retType,
711 .constant_args = nconstant_args,
712 .subsequent_calls_may_write = (bool)subsequent_calls_may_write,
713 .overwritten_args = overwritten_args,
714 .returnUsed = (bool)returnValue,
715 .shadowReturnUsed = (bool)dretUsed,
716 .mode = (DerivativeMode)mode,
717 .width = width,
718 .freeMemory = (bool)freeMemory,
719 .AtomicAdd = (bool)AtomicAdd,
720 .additionalType = unwrap(additionalArg),
721 .forceAnonymousTape = (bool)forceAnonymousTape,
722 .typeInfo = eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
723 .runtimeActivity = (bool)runtimeActivity,
724 .strongZero = (bool)strongZero},
725 eunwrap(TA), eunwrap(augmented)));
726}
728 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
729 LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
730 size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnUsed,
731 uint8_t shadowReturnUsed, CFnTypeInfo typeInfo,
732 uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args,
733 size_t overwritten_args_size, uint8_t forceAnonymousTape,
734 uint8_t runtimeActivity, uint8_t strongZero, unsigned width,
735 uint8_t AtomicAdd) {
736
737 SmallVector<DIFFE_TYPE, 4> nconstant_args((DIFFE_TYPE *)constant_args,
738 (DIFFE_TYPE *)constant_args +
739 constant_args_size);
740 std::vector<bool> overwritten_args;
741 assert(overwritten_args_size == cast<Function>(unwrap(todiff))->arg_size());
742 for (uint64_t i = 0; i < overwritten_args_size; i++) {
743 overwritten_args.push_back(_overwritten_args[i]);
744 }
745 return ewrap(eunwrap(Logic).CreateAugmentedPrimal(
746 RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
747 unwrap(request_ip)),
748 cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
749 eunwrap(TA), returnUsed, shadowReturnUsed,
750 eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
751 subsequent_calls_may_write, overwritten_args, forceAnonymousTape,
752 runtimeActivity, strongZero, width, AtomicAdd));
753}
754
755LLVMValueRef EnzymeCreateBatch(EnzymeLogicRef Logic, LLVMValueRef request_req,
756 LLVMBuilderRef request_ip, LLVMValueRef tobatch,
757 unsigned width, CBATCH_TYPE *arg_types,
758 size_t arg_types_size, CBATCH_TYPE retType) {
759
760 return wrap(eunwrap(Logic).CreateBatch(
761 RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
762 unwrap(request_ip)),
763 cast<Function>(unwrap(tobatch)), width,
764 ArrayRef<BATCH_TYPE>((BATCH_TYPE *)arg_types,
765 (BATCH_TYPE *)arg_types + arg_types_size),
766 (BATCH_TYPE)retType));
767}
768
769LLVMValueRef EnzymeCreateTrace(
770 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
771 LLVMValueRef totrace, LLVMValueRef *sample_functions,
772 size_t sample_functions_size, LLVMValueRef *observe_functions,
773 size_t observe_functions_size, const char *active_random_variables[],
774 size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff,
775 EnzymeTraceInterfaceRef interface) {
776
777 SmallPtrSet<Function *, 4> SampleFunctions;
778 for (size_t i = 0; i < sample_functions_size; i++) {
779 SampleFunctions.insert(cast<Function>(unwrap(sample_functions[i])));
780 }
781
782 SmallPtrSet<Function *, 4> ObserveFunctions;
783 for (size_t i = 0; i < observe_functions_size; i++) {
784 ObserveFunctions.insert(cast<Function>(unwrap(observe_functions[i])));
785 }
786
787 StringSet<> ActiveRandomVariables;
788 for (size_t i = 0; i < active_random_variables_size; i++) {
789 ActiveRandomVariables.insert(active_random_variables[i]);
790 }
791
792 return wrap(eunwrap(Logic).CreateTrace(
793 RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
794 unwrap(request_ip)),
795 cast<Function>(unwrap(totrace)), SampleFunctions, ObserveFunctions,
796 ActiveRandomVariables, (ProbProgMode)mode, (bool)autodiff,
797 eunwrap(interface)));
798}
799
800LLVMValueRef
802 auto AR = (AugmentedReturn *)ret;
803 return wrap(AR->fn);
804}
805
806LLVMTypeRef
808 auto AR = (AugmentedReturn *)ret;
809 return wrap(AR->tapeType);
810}
811
812LLVMTypeRef
814 auto AR = (AugmentedReturn *)ret;
815 auto found = AR->returns.find(AugmentedStruct::Tape);
816 if (found == AR->returns.end()) {
817 return wrap((Type *)nullptr);
818 }
819 if (found->second == -1) {
820 return wrap(AR->fn->getReturnType());
821 }
822 return wrap(
823 cast<StructType>(AR->fn->getReturnType())->getTypeAtIndex(found->second));
824}
826 uint8_t *existed, size_t len) {
827 assert(len == 3);
828 auto AR = (AugmentedReturn *)ret;
831 for (size_t i = 0; i < len; i++) {
832 auto found = AR->returns.find(todo[i]);
833 if (found != AR->returns.end()) {
834 existed[i] = true;
835 data[i] = (int64_t)found->second;
836 } else {
837 existed[i] = false;
838 }
839 }
840}
841
842static MDNode *extractMDNode(MetadataAsValue *MAV) {
843 Metadata *MD = MAV->getMetadata();
844 assert((isa<MDNode>(MD) || isa<ConstantAsMetadata>(MD)) &&
845 "Expected a metadata node or a canonicalized constant");
846
847 if (MDNode *N = dyn_cast<MDNode>(MD))
848 return N;
849
850 return MDNode::get(MAV->getContext(), MD);
851}
852
854 TypeTree *Ret = new TypeTree();
855 MDNode *N = Val ? extractMDNode(unwrap<MetadataAsValue>(Val)) : nullptr;
856 Ret->insertFromMD(N);
857 return (CTypeTreeRef)N;
858}
859
860LLVMValueRef EnzymeTypeTreeToMD(CTypeTreeRef CTR, LLVMContextRef ctx) {
861 auto MD = ((TypeTree *)CTR)->toMD(*unwrap(ctx));
862 return wrap(MetadataAsValue::get(MD->getContext(), MD));
863}
864
867 return (CTypeTreeRef)(new TypeTree(eunwrap(CT, *unwrap(ctx))));
868}
872void EnzymeFreeTypeTree(CTypeTreeRef CTT) { delete (TypeTree *)CTT; }
874 return *(TypeTree *)dst = *(TypeTree *)src;
875}
877 return ((TypeTree *)dst)->orIn(*(TypeTree *)src, /*PointerIntSame*/ false);
878}
880 uint8_t *legalP) {
881 bool legal = true;
882 bool res =
883 ((TypeTree *)dst)
884 ->checkedOrIn(*(TypeTree *)src, /*PointerIntSame*/ false, legal);
885 *legalP = legal;
886 return res;
887}
888
889void EnzymeTypeTreeOnlyEq(CTypeTreeRef CTT, int64_t x) {
890 // TODO only inst
891 *(TypeTree *)CTT = ((TypeTree *)CTT)->Only(x, nullptr);
892}
894 *(TypeTree *)CTT = ((TypeTree *)CTT)->Data0();
895}
896
897void EnzymeTypeTreeLookupEq(CTypeTreeRef CTT, int64_t size, const char *dl) {
898 *(TypeTree *)CTT = ((TypeTree *)CTT)->Lookup(size, DataLayout(dl));
899}
901 const char *dl) {
902 ((TypeTree *)CTT)->CanonicalizeInPlace(size, DataLayout(dl));
903}
904
906 return ewrap(((TypeTree *)CTT)->Inner0());
907}
908
909void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef CTT, const char *datalayout,
910 int64_t offset, int64_t maxSize,
911 uint64_t addOffset) {
912 DataLayout DL(datalayout);
913 *(TypeTree *)CTT =
914 ((TypeTree *)CTT)->ShiftIndices(DL, offset, maxSize, addOffset);
915}
916void EnzymeTypeTreeInsertEq(CTypeTreeRef CTT, const int64_t *indices,
917 size_t len, CConcreteType ct, LLVMContextRef ctx) {
918 std::vector<int> seq;
919 for (size_t i = 0; i < len; i++) {
920 seq.push_back(indices[i]);
921 }
922 ((TypeTree *)CTT)->insert(seq, eunwrap(ct, *unwrap(ctx)));
923}
925 std::string tmp = ((TypeTree *)src)->str();
926 char *cstr = new char[tmp.length() + 1];
927 std::strcpy(cstr, tmp.c_str());
928
929 return cstr;
930}
931
// TODO deprecated
/// Release a character array with delete[].
void EnzymeTypeTreeToStringFree(const char *cstr) {
  delete[] cstr;
}
934
935const char *EnzymeTypeAnalyzerToString(void *src) {
936 auto TA = (TypeAnalyzer *)src;
937 std::string str;
938 raw_string_ostream ss(str);
939 TA->dump(ss);
940 ss.str();
941 char *cstr = new char[str.length() + 1];
942 std::strcpy(cstr, str.c_str());
943 return cstr;
944}
945
947 void *src) {
948 std::string str;
949 raw_string_ostream ss(str);
950 for (auto z : gutils->invertedPointers) {
951 ss << "available inversion for " << *z.first << " of " << *z.second << "\n";
952 }
953 ss.str();
954 char *cstr = new char[str.length() + 1];
955 std::strcpy(cstr, str.c_str());
956 return cstr;
957}
958
960 GradientUtils *gutils, LLVMValueRef func, LLVMTypeRef funcTy,
961 LLVMValueRef *args_vr, uint64_t args_size, LLVMValueRef orig_vr,
962 CValueType *valTys, uint64_t valTys_size, LLVMBuilderRef B,
963 uint8_t lookup) {
964 auto orig = cast<CallInst>(unwrap(orig_vr));
965
966 ArrayRef<ValueType> ar((ValueType *)valTys, valTys_size);
967
968 IRBuilder<> &BR = *unwrap(B);
969
970 auto Defs = gutils->getInvertedBundles(orig, ar, BR, lookup != 0);
971
972 SmallVector<Value *, 1> args;
973 for (size_t i = 0; i < args_size; i++) {
974 args.push_back(unwrap(args_vr[i]));
975 }
976
977 auto callval = unwrap(func);
978
979 auto res =
980 BR.CreateCall(cast<FunctionType>(unwrap(funcTy)), callval, args, Defs);
981 return wrap(res);
982}
983
/// Release a character array with delete[].
void EnzymeStringFree(const char *cstr) {
  delete[] cstr;
}
985
986void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2,
987 LLVMBuilderRef B) {
988 Instruction *I1 = cast<Instruction>(unwrap(inst1));
989 Instruction *I2 = cast<Instruction>(unwrap(inst2));
990 if (I1 != I2) {
991 if (B != nullptr) {
992 IRBuilder<> &BR = *unwrap(B);
993 if (I1->getIterator() == BR.GetInsertPoint()) {
994 if (I2->getNextNode() == nullptr)
995 BR.SetInsertPoint(I1->getParent());
996 else
997 BR.SetInsertPoint(I1->getNextNode());
998 }
999 }
1000 I1->moveBefore(I2);
1001 }
1002}
1003
1004void EnzymeSetStringMD(LLVMValueRef Inst, const char *Kind, LLVMValueRef Val) {
1005 MDNode *N = Val ? extractMDNode(unwrap<MetadataAsValue>(Val)) : nullptr;
1006 Value *V = unwrap(Inst);
1007 if (auto I = dyn_cast<Instruction>(V))
1008 I->setMetadata(Kind, N);
1009 else
1010 cast<GlobalVariable>(V)->setMetadata(Kind, N);
1011}
1012
1013LLVMValueRef EnzymeGetStringMD(LLVMValueRef Inst, const char *Kind) {
1014 auto *I = unwrap<Instruction>(Inst);
1015 assert(I && "Expected instruction");
1016 if (auto *MD = I->getMetadata(Kind))
1017 return wrap(MetadataAsValue::get(I->getContext(), MD));
1018 return nullptr;
1019}
1020
1021void EnzymeSetMustCache(LLVMValueRef inst1) {
1022 Instruction *I1 = cast<Instruction>(unwrap(inst1));
1023 I1->setMetadata("enzyme_mustcache", MDNode::get(I1->getContext(), {}));
1024}
1025
1026uint8_t EnzymeHasFromStack(LLVMValueRef inst1) {
1027 Instruction *I1 = cast<Instruction>(unwrap(inst1));
1028 return hasMetadata(I1, "enzyme_fromstack") != 0;
1029}
1030
1031void EnzymeCloneFunctionDISubprogramInto(LLVMValueRef NF, LLVMValueRef F) {
1032 auto &OldFunc = *cast<Function>(unwrap(F));
1033 auto &NewFunc = *cast<Function>(unwrap(NF));
1034 auto OldSP = OldFunc.getSubprogram();
1035 if (!OldSP)
1036 return;
1037 DIBuilder DIB(*OldFunc.getParent(), /*AllowUnresolved=*/false,
1038 OldSP->getUnit());
1039 auto SPType = DIB.createSubroutineType(DIB.getOrCreateTypeArray({}));
1040 DISubprogram::DISPFlags SPFlags = DISubprogram::SPFlagDefinition |
1041 DISubprogram::SPFlagOptimized |
1042 DISubprogram::SPFlagLocalToUnit;
1043 auto NewSP = DIB.createFunction(
1044 OldSP->getUnit(), NewFunc.getName(), NewFunc.getName(), OldSP->getFile(),
1045 /*LineNo=*/0, SPType, /*ScopeLine=*/0, DINode::FlagZero, SPFlags);
1046 NewFunc.setSubprogram(NewSP);
1047 DIB.finalizeSubprogram(NewSP);
1048 return;
1049}
1050
1053}
1054
1055void EnzymeDetectReadonlyOrThrow(LLVMModuleRef M) {
1056 DetectReadonlyOrThrow(*unwrap(M));
1057}
1058
1059void EnzymeDumpModuleRef(LLVMModuleRef M) {
1060 llvm::errs() << *unwrap(M) << "\n";
1061}
1062
1063void EnzymeDumpValueRef(LLVMValueRef M) { llvm::errs() << *unwrap(M) << "\n"; }
1064
1065void EnzymeDumpTypeRef(LLVMTypeRef M) { llvm::errs() << *unwrap(M) << "\n"; }
1066
1067static bool runAttributorOnFunctions(InformationCache &InfoCache,
1068 SetVector<Function *> &Functions,
1069 AnalysisGetter &AG,
1070 CallGraphUpdater &CGUpdater,
1071 bool DeleteFns, bool IsModulePass) {
1072 if (Functions.empty())
1073 return false;
1074
1075 // Create an Attributor and initially empty information cache that is filled
1076 // while we identify default attribute opportunities.
1077 AttributorConfig AC(CGUpdater);
1078 AC.RewriteSignatures = false;
1079 AC.IsModulePass = IsModulePass;
1080 AC.DeleteFns = DeleteFns;
1081 Attributor A(Functions, InfoCache, AC);
1082
1083 for (Function *F : Functions) {
1084 // Populate the Attributor with abstract attribute opportunities in the
1085 // function and the information cache with IR information.
1086 A.identifyDefaultAbstractAttributes(*F);
1087 }
1088
1089 ChangeStatus Changed = A.run();
1090
1091 return Changed == ChangeStatus::CHANGED;
1092}
1093
1094extern "C" void RunAttributorOnModule(LLVMModuleRef M0) {
1095 auto &M = *unwrap(M0);
1096 AnalysisGetter AG;
1097 SetVector<Function *> Functions;
1098 for (Function &F : M)
1099 Functions.insert(&F);
1100
1101 CallGraphUpdater CGUpdater;
1102 BumpPtrAllocator Allocator;
1103 InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr);
1104 runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater,
1105 /* DeleteFns*/ true,
1106 /* IsModulePass */ true);
1107}
1108
1110 static char ID;
1111
1113
1114 bool runOnModule(Module &M) override {
1115 if (skipModule(M))
1116 return false;
1117
1118 AnalysisGetter AG;
1119 SetVector<Function *> Functions;
1120 for (Function &F : M)
1121 Functions.insert(&F);
1122
1123 CallGraphUpdater CGUpdater;
1124 BumpPtrAllocator Allocator;
1125 InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr);
1126 return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater,
1127 /* DeleteFns*/ true,
1128 /* IsModulePass */ true);
1129 }
1130
1131 void getAnalysisUsage(AnalysisUsage &AU) const override {
1132 // FIXME: Think about passes we will preserve and add them here.
1133 AU.addRequired<TargetLibraryInfoWrapperPass>();
1134 }
1135};
1136extern "C++" char MyAttributorLegacyPass::ID = 0;
1137void EnzymeAddAttributorLegacyPass(LLVMPassManagerRef PM) {
1138 unwrap(PM)->add(new MyAttributorLegacyPass());
1139}
1140
1141LLVMMetadataRef EnzymeMakeNonConstTBAA(LLVMMetadataRef MD) {
1142 auto M = cast<MDNode>(unwrap(MD));
1143 if (M->getNumOperands() != 4)
1144 return MD;
1145 auto CAM = dyn_cast<ConstantAsMetadata>(M->getOperand(3));
1146 if (!CAM)
1147 return MD;
1148 if (!CAM->getValue()->isOneValue())
1149 return MD;
1150 SmallVector<Metadata *, 4> MDs;
1151 for (auto &M : M->operands())
1152 MDs.push_back(M);
1153 MDs[3] =
1154 ConstantAsMetadata::get(ConstantInt::get(CAM->getValue()->getType(), 0));
1155 return wrap(MDNode::get(M->getContext(), MDs));
1156}
1157void EnzymeCopyMetadata(LLVMValueRef inst1, LLVMValueRef inst2) {
1158 cast<Instruction>(unwrap(inst1))
1159 ->copyMetadata(*cast<Instruction>(unwrap(inst2)));
1160}
1161void EnzymeCopyAlignment(LLVMValueRef inst1, LLVMValueRef inst2) {
1162 cast<AllocaInst>(unwrap(inst1))
1163 ->setAlignment(cast<AllocaInst>(unwrap(inst2))->getAlign());
1164}
1165void EnzymeTakeName(LLVMValueRef inst1, LLVMValueRef inst2) {
1166 unwrap(inst1)->takeName(unwrap(inst2));
1167}
1168
1169LLVMMetadataRef EnzymeAnonymousAliasScopeDomain(const char *str,
1170 LLVMContextRef ctx) {
1171 MDBuilder MDB(*unwrap(ctx));
1172 MDNode *scope = MDB.createAnonymousAliasScopeDomain(str);
1173 return wrap(scope);
1174}
1175LLVMMetadataRef EnzymeAnonymousAliasScope(LLVMMetadataRef domain,
1176 const char *str) {
1177 auto dom = cast<MDNode>(unwrap(domain));
1178 MDBuilder MDB(dom->getContext());
1179 MDNode *scope = MDB.createAnonymousAliasScope(dom, str);
1180 return wrap(scope);
1181}
1182uint8_t EnzymeLowerSparsification(LLVMValueRef F, uint8_t replaceAll) {
1183 return LowerSparsification(cast<Function>(unwrap(F)), replaceAll != 0);
1184}
1185
1186void EnzymeAttributeKnownFunctions(LLVMValueRef FC) {
1187 attributeKnownFunctions(*cast<Function>(unwrap(FC)));
1188}
1189
1190void EnzymeSetCalledFunction(LLVMValueRef C_CI, LLVMValueRef C_F,
1191 uint64_t *argrem, uint64_t num_argrem) {
1192 auto CI = cast<CallInst>(unwrap(C_CI));
1193 auto F = cast<Function>(unwrap(C_F));
1194 auto Attrs = CI->getAttributes();
1195 AttributeList NewAttrs;
1196
1197 if (CI->getType() == F->getReturnType()) {
1198 for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex))
1199 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1200 AttributeList::ReturnIndex, attr);
1201 }
1202 for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1203 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1204 AttributeList::FunctionIndex, attr);
1205
1206 size_t argremsz = 0;
1207 size_t nexti = 0;
1208 SmallVector<Value *, 1> vals;
1209 for (size_t i = 0, end = CI->arg_size(); i < end; i++) {
1210 if (argremsz < num_argrem) {
1211 if (i == argrem[argremsz]) {
1212 argremsz++;
1213 continue;
1214 }
1215 }
1216 for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i))
1217 NewAttrs = NewAttrs.addAttribute(
1218 F->getContext(), AttributeList::FirstArgIndex + nexti, attr);
1219 vals.push_back(CI->getArgOperand(i));
1220 nexti++;
1221 }
1222 assert(argremsz == num_argrem);
1223
1224 IRBuilder<> B(CI);
1225 SmallVector<OperandBundleDef, 1> Bundles;
1226 for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I)
1227 Bundles.emplace_back(CI->getOperandBundleAt(I));
1228 auto NC = B.CreateCall(F, vals, Bundles);
1229 NC->setAttributes(NewAttrs);
1230 NC->copyMetadata(*CI);
1231
1232 if (CI->getType() == F->getReturnType())
1233 CI->replaceAllUsesWith(NC);
1234
1235 if (!NC->getType()->isVoidTy())
1236 NC->takeName(CI);
1237 NC->setCallingConv(CI->getCallingConv());
1238 CI->eraseFromParent();
1239}
1240
1241// clones a function to now miss the return or args
1242LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC,
1243 uint8_t keepReturnU,
1244 uint64_t *argrem,
1245 uint64_t num_argrem) {
1246 auto F = cast<Function>(unwrap(FC));
1247 auto FT = F->getFunctionType();
1248 bool keepReturn = keepReturnU != 0;
1249
1250 size_t argremsz = 0;
1251 size_t nexti = 0;
1252 SmallVector<Type *, 1> types;
1253 auto Attrs = F->getAttributes();
1254 AttributeList NewAttrs;
1255 for (size_t i = 0, end = FT->getNumParams(); i < end; i++) {
1256 if (argremsz < num_argrem) {
1257 if (i == argrem[argremsz]) {
1258 argremsz++;
1259 continue;
1260 }
1261 }
1262 for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i))
1263 NewAttrs = NewAttrs.addAttribute(
1264 F->getContext(), AttributeList::FirstArgIndex + nexti, attr);
1265 types.push_back(F->getFunctionType()->getParamType(i));
1266 nexti++;
1267 }
1268 if (keepReturn) {
1269 for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex))
1270 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1271 AttributeList::ReturnIndex, attr);
1272 }
1273 for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1274 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1275 AttributeList::FunctionIndex, attr);
1276
1277 FunctionType *FTy = FunctionType::get(
1278 keepReturn ? F->getReturnType() : Type::getVoidTy(F->getContext()), types,
1279 FT->isVarArg());
1280
1281 // Create the new function
1282 Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(),
1283 F->getName(), F->getParent());
1284
1285 ValueToValueMapTy VMap;
1286 // Loop over the arguments, copying the names of the mapped arguments over...
1287 nexti = 0;
1288 argremsz = 0;
1289 Function::arg_iterator DestI = NewF->arg_begin();
1290 for (const Argument &I : F->args()) {
1291 if (argremsz < num_argrem) {
1292 if (I.getArgNo() == argrem[argremsz]) {
1293 VMap[&I] = UndefValue::get(I.getType());
1294 argremsz++;
1295 continue;
1296 }
1297 }
1298 DestI->setName(I.getName()); // Copy the name over...
1299 VMap[&I] = &*DestI++; // Add mapping to VMap
1300 }
1301
1302 SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
1303 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
1304 Returns, "", nullptr);
1305
1306 if (!keepReturn) {
1307 for (auto &B : *NewF) {
1308 if (auto RI = dyn_cast<ReturnInst>(B.getTerminator())) {
1309 IRBuilder<> B(RI);
1310 auto NRI = B.CreateRetVoid();
1311 NRI->copyMetadata(*RI);
1312 RI->eraseFromParent();
1313 }
1314 }
1315 }
1316 NewF->setAttributes(NewAttrs);
1317 if (!keepReturn)
1318 for (auto &Arg : NewF->args())
1319 Arg.removeAttr(Attribute::Returned);
1320 SmallVector<std::pair<unsigned, MDNode *>, 1> MD;
1321 F->getAllMetadata(MD);
1322 for (auto pair : MD)
1323 if (pair.first != LLVMContext::MD_dbg)
1324 NewF->addMetadata(pair.first, *pair.second);
1325 NewF->takeName(F);
1326 NewF->setCallingConv(F->getCallingConv());
1327 if (!keepReturn)
1328 NewF->addFnAttr("enzyme_retremove", "");
1329
1330 if (num_argrem) {
1331 SmallVector<uint64_t, 1> previdx;
1332 if (Attrs.hasAttribute(AttributeList::FunctionIndex, "enzyme_parmremove")) {
1333 auto attr =
1334 Attrs.getAttribute(AttributeList::FunctionIndex, "enzyme_parmremove");
1335 auto prevstr = attr.getValueAsString();
1336 SmallVector<StringRef, 1> sub;
1337 prevstr.split(sub, ",");
1338 for (auto s : sub) {
1339 uint64_t ival;
1340 bool b = s.getAsInteger(10, ival);
1341 (void)b;
1342 assert(!b);
1343 previdx.push_back(ival);
1344 }
1345 }
1346 SmallVector<uint64_t, 1> nextidx;
1347 for (size_t i = 0; i < num_argrem; i++) {
1348 auto val = argrem[i];
1349 nextidx.push_back(val);
1350 }
1351
1352 size_t prevcnt = 0;
1353 size_t nextcnt = 0;
1354 SmallVector<uint64_t, 1> out;
1355 while (prevcnt < previdx.size() && nextcnt < nextidx.size()) {
1356 if (previdx[prevcnt] <= nextidx[nextcnt] + prevcnt) {
1357 out.push_back(previdx[prevcnt]);
1358 prevcnt++;
1359 } else {
1360 out.push_back(nextidx[nextcnt] + prevcnt);
1361 nextcnt++;
1362 }
1363 }
1364 while (prevcnt < previdx.size()) {
1365 out.push_back(previdx[prevcnt]);
1366 prevcnt++;
1367 }
1368 while (nextcnt < nextidx.size()) {
1369 out.push_back(nextidx[nextcnt] + prevcnt);
1370 nextcnt++;
1371 }
1372
1373 std::string remstr;
1374 for (auto arg : out) {
1375 if (remstr.size())
1376 remstr += ",";
1377 remstr += std::to_string(arg);
1378 }
1379
1380 NewF->addFnAttr("enzyme_parmremove", remstr);
1381 }
1382 return wrap(NewF);
1383}
1384LLVMTypeRef EnzymeAllocaType(LLVMValueRef V) {
1385 return wrap(cast<AllocaInst>(unwrap(V))->getAllocatedType());
1386}
1387LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r,
1388 LLVMTypeRef T_r) {
1389 IRBuilder<> &B = *unwrap(B_r);
1390 auto T = cast<IntegerType>(unwrap(T_r));
1391 auto width = T->getBitWidth();
1392 auto uw = unwrap(V_r);
1393 GEPOperator *gep = isa<GetElementPtrInst>(uw)
1394 ? cast<GEPOperator>(cast<GetElementPtrInst>(uw))
1395 : cast<GEPOperator>(cast<ConstantExpr>(uw));
1396 auto &DL = B.GetInsertBlock()->getParent()->getParent()->getDataLayout();
1397
1398#if LLVM_VERSION_MAJOR >= 20
1399 SmallMapVector<Value *, APInt, 4> VariableOffsets;
1400#else
1401 MapVector<Value *, APInt> VariableOffsets;
1402#endif
1403 APInt Offset(width, 0);
1404 bool success = collectOffset(gep, DL, width, VariableOffsets, Offset);
1405 (void)success;
1406 assert(success);
1407 Value *start = ConstantInt::get(T, Offset);
1408 for (auto &pair : VariableOffsets)
1409 start = B.CreateAdd(
1410 start, B.CreateMul(pair.first, ConstantInt::get(T, pair.second)));
1411 return wrap(start);
1412}
1413}
1414
1415extern "C" {
1416
1417LLVMValueRef EnzymeBuildExtractValue(LLVMBuilderRef B, LLVMValueRef AggVal,
1418 unsigned *Index, unsigned Size,
1419 const char *Name) {
1420 return wrap(unwrap(B)->CreateExtractValue(
1421 unwrap(AggVal), ArrayRef<unsigned>(Index, Size), Name));
1422}
1423
1424LLVMValueRef EnzymeBuildInsertValue(LLVMBuilderRef B, LLVMValueRef AggVal,
1425 LLVMValueRef EltVal, unsigned *Index,
1426 unsigned Size, const char *Name) {
1427 return wrap(unwrap(B)->CreateInsertValue(
1428 unwrap(AggVal), unwrap(EltVal), ArrayRef<unsigned>(Index, Size), Name));
1429}
1430}
int64_t EnzymeGetCLInteger(void *ptr)
Definition CApi.cpp:203
void EnzymeGradientUtilsAddToInvertedPointerDiffeTT(DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal, CTypeTreeRef vd, unsigned LoadSize, LLVMValueRef origptr, LLVMValueRef prediff, LLVMBuilderRef BuilderM, unsigned align, LLVMValueRef premask)
Definition CApi.cpp:542
void EnzymeGradientUtilsSubTransferHelper(GradientUtils *gutils, CDerivativeMode mode, LLVMTypeRef secretty, uint64_t intrinsic, uint64_t dstAlign, uint64_t srcAlign, uint64_t offset, uint8_t dstConstant, LLVMValueRef shadow_dst, uint8_t srcConstant, LLVMValueRef shadow_src, LLVMValueRef length, LLVMValueRef isVolatile, LLVMValueRef MTI, uint8_t allowForward, uint8_t shadowsLookedUp)
Definition CApi.cpp:626
EnzymeTypeAnalysisRef EnzymeGetTypeAnalysisFromTypeAnalyzer(void *TAR)
Definition CApi.cpp:333
void EnzymeAttributeKnownFunctions(LLVMValueRef FC)
Definition CApi.cpp:1186
void EnzymeTypeTreeLookupEq(CTypeTreeRef CTT, int64_t size, const char *dl)
Definition CApi.cpp:897
LLVMMetadataRef EnzymeAnonymousAliasScopeDomain(const char *str, LLVMContextRef ctx)
Definition CApi.cpp:1169
void * EnzymeGradientUtilsGetExternalContext(GradientUtils *gutils)
Definition CApi.cpp:423
uint8_t EnzymeGradientUtilsIsConstantValue(GradientUtils *gutils, LLVMValueRef val)
Definition CApi.cpp:561
LLVMValueRef EnzymeInsertValue(LLVMBuilderRef B, LLVMValueRef val, LLVMValueRef val2, unsigned *sz, int64_t length, const char *name)
Definition CApi.cpp:499
void EnzymeSetCalledFunction(LLVMValueRef C_CI, LLVMValueRef C_F, uint64_t *argrem, uint64_t num_argrem)
Definition CApi.cpp:1190
CDIFFE_TYPE EnzymeGradientUtilsGetReturnDiffeType(GradientUtils *G, LLVMValueRef oval, uint8_t *needsPrimal, uint8_t *needsShadow, CDerivativeMode mode)
Definition CApi.cpp:476
void EnzymeStringFree(const char *cstr)
Definition CApi.cpp:984
void EnzymeCopyAlignment(LLVMValueRef inst1, LLVMValueRef inst2)
Definition CApi.cpp:1161
void EnzymeAddAttributorLegacyPass(LLVMPassManagerRef PM)
Definition CApi.cpp:1137
CTypeTreeRef EnzymeNewTypeTreeTR(CTypeTreeRef CTR)
Definition CApi.cpp:869
void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR)
Definition CApi.cpp:314
const char * EnzymeGradientUtilsInvertedPointersToString(GradientUtils *gutils, void *src)
Definition CApi.cpp:946
LLVMBasicBlockRef EnzymeGradientUtilsAddReverseBlock(GradientUtils *gutils, LLVMBasicBlockRef block, const char *name, uint8_t forkCache, uint8_t push)
Definition CApi.cpp:642
void EnzymeTakeName(LLVMValueRef inst1, LLVMValueRef inst2)
Definition CApi.cpp:1165
void EnzymeTypeTreeData0Eq(CTypeTreeRef CTT)
Definition CApi.cpp:893
EnzymeTraceInterfaceRef CreateEnzymeDynamicTraceInterface(LLVMValueRef interface, LLVMValueRef F)
Definition CApi.cpp:255
static bool runAttributorOnFunctions(InformationCache &InfoCache, SetVector< Function * > &Functions, AnalysisGetter &AG, CallGraphUpdater &CGUpdater, bool DeleteFns, bool IsModulePass)
Definition CApi.cpp:1067
LLVMValueRef EnzymeGetStringMD(LLVMValueRef Inst, const char *Kind)
Definition CApi.cpp:1013
void EnzymeFreeTypeTree(CTypeTreeRef CTT)
Definition CApi.cpp:872
EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M)
Definition CApi.cpp:225
static MDNode * extractMDNode(MetadataAsValue *MAV)
Definition CApi.cpp:842
LLVMTypeRef EnzymeExtractUnderlyingTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret)
Definition CApi.cpp:807
void EnzymeSetCLInteger(void *ptr, int64_t val)
Definition CApi.cpp:198
void ClearTypeAnalysis(EnzymeTypeAnalysisRef TAR)
Definition CApi.cpp:312
CConcreteType EnzymeTypeTreeInner0(CTypeTreeRef CTT)
Definition CApi.cpp:905
void EnzymeTypeTreeToStringFree(const char *cstr)
Definition CApi.cpp:933
EnzymeLogicRef CreateEnzymeLogic(uint8_t PostOpt)
Definition CApi.cpp:213
void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle)
Definition CApi.cpp:393
uint8_t EnzymeGetCLBool(void *ptr)
Definition CApi.cpp:193
void EnzymeDumpValueRef(LLVMValueRef M)
Definition CApi.cpp:1063
void FreeEnzymeLogic(EnzymeLogicRef Ref)
Definition CApi.cpp:268
void EnzymeGradientUtilsSetDebugLocFromOriginal(GradientUtils *gutils, LLVMValueRef val, LLVMValueRef orig)
Definition CApi.cpp:491
LLVMValueRef EnzymeGradientUtilsLookup(GradientUtils *gutils, LLVMValueRef val, LLVMBuilderRef B)
Definition CApi.cpp:506
const char * EnzymeTypeAnalyzerToString(void *src)
Definition CApi.cpp:935
CDIFFE_TYPE EnzymeGradientUtilsGetDiffeType(GradientUtils *G, LLVMValueRef oval, uint8_t foreignFunction)
Definition CApi.cpp:470
uint8_t EnzymeLowerSparsification(LLVMValueRef F, uint8_t replaceAll)
Definition CApi.cpp:1182
LLVMMetadataRef EnzymeAnonymousAliasScope(LLVMMetadataRef domain, const char *str)
Definition CApi.cpp:1175
void EnzymeReplaceOriginalToNew(GradientUtils *gutils, LLVMValueRef origC, LLVMValueRef repC)
Definition CApi.cpp:453
void EnzymeSetStringMD(LLVMValueRef Inst, const char *Kind, LLVMValueRef Val)
Definition CApi.cpp:1004
uint8_t EnzymeGradientUtilsGetAtomicAdd(GradientUtils *gutils)
Definition CApi.cpp:431
LLVMValueRef EnzymeGradientUtilsCallWithInvertedBundles(GradientUtils *gutils, LLVMValueRef func, LLVMTypeRef funcTy, LLVMValueRef *args_vr, uint64_t args_size, LLVMValueRef orig_vr, CValueType *valTys, uint64_t valTys_size, LLVMBuilderRef B, uint8_t lookup)
Definition CApi.cpp:959
void EnzymeDetectReadonlyOrThrow(LLVMModuleRef M)
Definition CApi.cpp:1055
LLVMValueRef EnzymeBuildExtractValue(LLVMBuilderRef B, LLVMValueRef AggVal, unsigned *Index, unsigned Size, const char *Name)
Definition CApi.cpp:1417
void EnzymeLogicSetExternalContext(EnzymeLogicRef Ref, void *ExternalContext)
Definition CApi.cpp:217
LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils, LLVMValueRef val)
Definition CApi.cpp:448
void EnzymeTypeTreeInsertEq(CTypeTreeRef CTT, const int64_t *indices, size_t len, CConcreteType ct, LLVMContextRef ctx)
Definition CApi.cpp:916
uint8_t EnzymeGradientUtilsGetRuntimeActivity(GradientUtils *gutils)
Definition CApi.cpp:419
void EnzymeDumpTypeRef(LLVMTypeRef M)
Definition CApi.cpp:1065
void EnzymeCloneFunctionDISubprogramInto(LLVMValueRef NF, LLVMValueRef F)
Definition CApi.cpp:1031
void FreeTraceInterface(EnzymeTraceInterfaceRef Ref)
Definition CApi.cpp:270
void EnzymeLogicErasePreprocessedFunctions(EnzymeLogicRef Ref)
Definition CApi.cpp:262
CTypeTreeRef EnzymeNewTypeTree()
Definition CApi.cpp:865
void EnzymeGradientUtilsSetReverseBlock(GradientUtils *gutils, LLVMBasicBlockRef block)
Definition CApi.cpp:651
LLVMValueRef EnzymeCreateBatch(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef tobatch, unsigned width, CBATCH_TYPE *arg_types, size_t arg_types_size, CBATCH_TYPE retType)
Definition CApi.cpp:755
void * EnzymeGradientUtilsTypeAnalyzer(GradientUtils *G)
Definition CApi.cpp:329
uint8_t EnzymeGradientUtilsIsConstantInstruction(GradientUtils *gutils, LLVMValueRef val)
Definition CApi.cpp:566
LLVMTypeRef EnzymeGradientUtilsGetShadowType(GradientUtils *gutils, LLVMTypeRef T)
Definition CApi.cpp:439
void RunAttributorOnModule(LLVMModuleRef M0)
Definition CApi.cpp:1094
void EnzymeRegisterCallHandler(const char *Name, CustomAugmentedFunctionForward FwdHandle, CustomFunctionReverse RevHandle)
Definition CApi.cpp:370
LLVMValueRef EnzymeCreateForwardDiff(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity, uint8_t strongZero, unsigned width, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented)
Definition CApi.cpp:661
void EnzymeGradientUtilsDumpTypeResults(GradientUtils *gutils)
Definition CApi.cpp:622
uint8_t EnzymeSetTypeTree(CTypeTreeRef dst, CTypeTreeRef src)
Definition CApi.cpp:873
uint8_t EnzymeGradientUtilsGetStrongZero(GradientUtils *gutils)
Definition CApi.cpp:427
EnzymeLogicRef EnzymeTypeAnalysisGetLogic(EnzymeTypeAnalysisRef TAR)
Definition CApi.cpp:319
LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils)
Definition CApi.cpp:571
void EnzymeGradientUtilsAddToDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, LLVMValueRef diffe, LLVMBuilderRef B, LLVMTypeRef T)
Definition CApi.cpp:522
LLVMValueRef EnzymeCreatePrimalAndGradient(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity, uint8_t strongZero, unsigned width, uint8_t freeMemory, LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd)
Definition CApi.cpp:687
uint8_t EnzymeGradientUtilsGetUncacheableArgs(GradientUtils *gutils, LLVMValueRef orig, uint8_t *data, uint64_t size)
Definition CApi.cpp:575
LLVMValueRef EnzymeCreateTrace(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef totrace, LLVMValueRef *sample_functions, size_t sample_functions_size, LLVMValueRef *observe_functions, size_t observe_functions_size, const char *active_random_variables[], size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff, EnzymeTraceInterfaceRef interface)
Definition CApi.cpp:769
void EnzymeTypeTreeOnlyEq(CTypeTreeRef CTT, int64_t x)
Definition CApi.cpp:889
TargetLibraryInfo eunwrap(LLVMTargetLibraryInfoRef P)
Definition CApi.cpp:65
LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r, LLVMTypeRef T_r)
Definition CApi.cpp:1387
void EnzymeSetMustCache(LLVMValueRef inst1)
Definition CApi.cpp:1021
void EnzymeDumpModuleRef(LLVMModuleRef M)
Definition CApi.cpp:1059
void EnzymeGradientUtilsEraseWithPlaceholder(GradientUtils *G, LLVMValueRef I, LLVMValueRef orig, uint8_t erase)
Definition CApi.cpp:340
LLVMTypeRef EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret)
Definition CApi.cpp:813
CTypeTreeRef EnzymeTypeTreeFromMD(LLVMValueRef Val)
Definition CApi.cpp:853
CDerivativeMode EnzymeGradientUtilsGetMode(GradientUtils *gutils)
Definition CApi.cpp:465
LLVMValueRef EnzymeGradientUtilsDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, LLVMBuilderRef B)
Definition CApi.cpp:517
EnzymeAugmentedReturnPtr ewrap(const AugmentedReturn &AR)
Definition CApi.cpp:81
CTypeTreeRef EnzymeNewTypeTreeCT(CConcreteType CT, LLVMContextRef ctx)
Definition CApi.cpp:866
EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface(LLVMContextRef C, LLVMValueRef getTraceFunction, LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction, LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction, LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction, LLVMValueRef insertChoiceGradientFunction, LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction, LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction, LLVMValueRef hasChoiceFunction)
Definition CApi.cpp:229
void EnzymeGradientUtilsAddToInvertedPointerDiffe(DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal, LLVMTypeRef addingType, unsigned start, unsigned size, LLVMValueRef origptr, LLVMValueRef dif, LLVMBuilderRef BuilderM, unsigned align, LLVMValueRef mask)
Definition CApi.cpp:528
uint8_t EnzymeMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src)
Definition CApi.cpp:876
LLVMValueRef EnzymeBuildInsertValue(LLVMBuilderRef B, LLVMValueRef AggVal, LLVMValueRef EltVal, unsigned *Index, unsigned Size, const char *Name)
Definition CApi.cpp:1424
void EnzymeGradientUtilsErase(GradientUtils *G, LLVMValueRef I)
Definition CApi.cpp:337
uint8_t EnzymeHasFromStack(LLVMValueRef inst1)
Definition CApi.cpp:1026
EnzymeTypeAnalysisRef CreateTypeAnalysis(EnzymeLogicRef Log, char **customRuleNames, CustomRuleType *customRules, size_t numRules)
Definition CApi.cpp:274
LLVMValueRef EnzymeTypeTreeToMD(CTypeTreeRef CTR, LLVMContextRef ctx)
Definition CApi.cpp:860
void * EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI, LLVMValueRef F)
Definition CApi.cpp:323
LLVMMetadataRef EnzymeMakeNonConstTBAA(LLVMMetadataRef MD)
Definition CApi.cpp:1141
void EnzymeGradientUtilsSetDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, LLVMValueRef diffe, LLVMBuilderRef B)
Definition CApi.cpp:556
void EnzymeSetCLBool(void *ptr, uint8_t val)
Definition CApi.cpp:188
LLVMValueRef EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret)
Definition CApi.cpp:801
LLVMValueRef EnzymeGradientUtilsInvertPointer(GradientUtils *gutils, LLVMValueRef val, LLVMBuilderRef B)
Definition CApi.cpp:511
void EnzymeSetCLString(void *ptr, const char *val)
Definition CApi.cpp:208
std::set< int64_t > eunwrap64(IntList IL)
Definition CApi.cpp:118
void EnzymeExtractReturnInfo(EnzymeAugmentedReturnPtr ret, int64_t *data, uint8_t *existed, size_t len)
Definition CApi.cpp:825
CTypeTreeRef EnzymeGradientUtilsAllocAndGetTypeTree(GradientUtils *gutils, LLVMValueRef val)
Definition CApi.cpp:614
void EnzymeGradientUtilsReplaceAWithB(GradientUtils *G, LLVMValueRef A, LLVMValueRef B)
Definition CApi.cpp:347
void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef CTT, const char *datalayout, int64_t offset, int64_t maxSize, uint64_t addOffset)
Definition CApi.cpp:909
void EnzymeCopyMetadata(LLVMValueRef inst1, LLVMValueRef inst2)
Definition CApi.cpp:1157
uint8_t EnzymeCheckedMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src, uint8_t *legalP)
Definition CApi.cpp:879
void * EnzymeLogicGetExternalContext(EnzymeLogicRef Ref)
Definition CApi.cpp:221
void EnzymeReplaceFunctionImplementation(LLVMModuleRef M)
Definition CApi.cpp:1051
void EnzymeTypeTreeCanonicalizeInPlace(CTypeTreeRef CTT, int64_t size, const char *dl)
Definition CApi.cpp:900
void EnzymeRegisterDiffUseCallHandler(char *Name, CustomFunctionDiffUse Handle)
Definition CApi.cpp:406
const char * EnzymeTypeTreeToString(CTypeTreeRef src)
Definition CApi.cpp:924
void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, CustomShadowFree FHandle)
Definition CApi.cpp:352
uint64_t EnzymeGradientUtilsGetWidth(GradientUtils *gutils)
Definition CApi.cpp:435
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed, CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, size_t overwritten_args_size, uint8_t forceAnonymousTape, uint8_t runtimeActivity, uint8_t strongZero, unsigned width, uint8_t AtomicAdd)
Definition CApi.cpp:727
LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC, uint8_t keepReturnU, uint64_t *argrem, uint64_t num_argrem)
Definition CApi.cpp:1242
LLVMTypeRef EnzymeGetShadowType(uint64_t width, LLVMTypeRef T)
Definition CApi.cpp:444
LLVMTypeRef EnzymeAllocaType(LLVMValueRef V)
Definition CApi.cpp:1384
void ClearEnzymeLogic(EnzymeLogicRef Ref)
Definition CApi.cpp:260
void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2, LLVMBuilderRef B)
Definition CApi.cpp:986
uint8_t(* CustomFunctionForward)(LLVMBuilderRef, LLVMValueRef, GradientUtils *, LLVMValueRef *, LLVMValueRef *)
Definition CApi.h:198
struct EnzymeOpaqueTypeAnalysis * EnzymeTypeAnalysisRef
Definition CApi.h:38
struct EnzymeOpaqueTraceInterface * EnzymeTraceInterfaceRef
Definition CApi.h:47
struct EnzymeTypeTree * CTypeTreeRef
Definition CApi.h:87
CValueType
Definition CApi.h:79
uint8_t(* CustomAugmentedFunctionForward)(LLVMBuilderRef, LLVMValueRef, GradientUtils *, LLVMValueRef *, LLVMValueRef *, LLVMValueRef *)
Definition CApi.h:206
CConcreteType
Definition CApi.h:54
@ DT_Integer
Definition CApi.h:56
@ DT_Double
Definition CApi.h:60
@ DT_Anything
Definition CApi.h:55
@ DT_Unknown
Definition CApi.h:61
@ DT_FP128
Definition CApi.h:64
@ DT_BFloat16
Definition CApi.h:63
@ DT_X86_FP80
Definition CApi.h:62
@ DT_Float
Definition CApi.h:59
@ DT_Half
Definition CApi.h:58
@ DT_Pointer
Definition CApi.h:57
uint8_t(* CustomRuleType)(int, CTypeTreeRef, CTypeTreeRef *, struct IntList *, size_t, LLVMValueRef, void *)
Definition CApi.h:147
uint8_t(* CustomFunctionDiffUse)(LLVMValueRef, const GradientUtils *, LLVMValueRef, uint8_t, CDerivativeMode, uint8_t *)
Definition CApi.h:202
CProbProgMode
Definition CApi.h:142
void(* CustomFunctionReverse)(LLVMBuilderRef, LLVMValueRef, DiffeGradientUtils *, LLVMValueRef)
Definition CApi.h:212
CDIFFE_TYPE
Definition CApi.h:120
CDerivativeMode
Definition CApi.h:133
LLVMValueRef(* CustomShadowFree)(LLVMBuilderRef, LLVMValueRef)
Definition CApi.h:193
struct EnzymeOpaqueLogic * EnzymeLogicRef
Definition CApi.h:41
CBATCH_TYPE
Definition CApi.h:131
LLVMValueRef(* CustomShadowAlloc)(LLVMBuilderRef, LLVMValueRef, size_t, LLVMValueRef *, GradientUtils *)
Definition CApi.h:190
struct EnzymeOpaqueAugmentedReturn * EnzymeAugmentedReturnPtr
Definition CApi.h:44
StringMap< std::function< bool(const CallInst *, const GradientUtils *, const Value *, bool, DerivativeMode, bool &)> > customDiffUseHandlers
void ReplaceFunctionImplementation(Module &M)
bool DetectReadonlyOrThrow(Module &M)
bool LowerSparsification(llvm::Function *F, bool replaceAll)
Lower __enzyme_todense, returning if changed.
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
AugmentedStruct
Definition EnzymeLogic.h:60
StringMap< std::function< CallInst *(IRBuilder<> &, Value *)> > shadowErasers
StringMap< std::function< Value *(IRBuilder<> &, CallInst *, ArrayRef< Value * >, GradientUtils *)> > shadowHandlers
StringMap< std::pair< std::function< bool(IRBuilder<> &, CallInst *, GradientUtils &, Value *&, Value *&, Value *&)>, std::function< void(IRBuilder<> &, CallInst *, DiffeGradientUtils &, Value *)> > > customCallHandlers
StringMap< std::function< bool(IRBuilder<> &, CallInst *, GradientUtils &, Value *&, Value *&)> > customFwdCallHandlers
void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode, Type *secretty, Intrinsic::ID intrinsic, unsigned dstalign, unsigned srcalign, unsigned offset, bool dstConstant, Value *shadow_dst, bool srcConstant, Value *shadow_src, Value *length, Value *isVolatile, llvm::CallInst *MTI, bool allowForward, bool shadowsLookedUp, bool backwardsShadow)
bool attributeKnownFunctions(llvm::Function &F)
Definition Utils.cpp:114
bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, MapVector< Value *, APInt > &VariableOffsets, APInt &ConstantOffset)
Definition Utils.cpp:4169
BATCH_TYPE
Definition Utils.h:385
@ Args
Return is a struct of all args.
DIFFE_TYPE
Potential differentiable argument classifications.
Definition Utils.h:374
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
Definition Utils.h:339
ProbProgMode
Definition Utils.h:399
ValueType
Classification of value as an original program variable, a derivative variable, neither,...
Definition Utils.h:409
DerivativeMode
Definition Utils.h:390
return structtype if recursive function
llvm::BasicBlock * inversionAllocs
Concrete SubType of a given value.
BaseType SubTypeEnum
Category of underlying type.
llvm::Type * isFloat() const
Return the floating point type, if this is a float.
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal, llvm::Type *addingType, unsigned start, unsigned size, llvm::Value *origptr, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::MaybeAlign align=llvm::MaybeAlign(), llvm::Value *mask=nullptr)
align is the alignment that should be specified for load/store to pointer
llvm::Value * diffe(llvm::Value *val, llvm::IRBuilder<> &BuilderM)
void setDiffe(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM)
llvm::SmallVector< llvm::SelectInst *, 4 > addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::Type *addingType, llvm::ArrayRef< llvm::Value * > idxs={}, llvm::Value *mask=nullptr)
Returns created select instructions, if any.
void * ExternalContext
Provided through the frontend and only used from it.
DerivativeMode mode
llvm::ValueMap< const llvm::Value *, InvertedPointerVH > invertedPointers
llvm::DebugLoc getNewFromOriginal(const llvm::DebugLoc L) const
void eraseWithPlaceholder(llvm::Instruction *I, llvm::Instruction *orig, const llvm::Twine &suffix="_replacementA", bool erase=true)
TypeResults TR
llvm::BasicBlock * addReverseBlock(llvm::BasicBlock *currentBlock, llvm::Twine const &name, bool forkCache=true, bool push=true)
const std::map< llvm::CallInst *, std::pair< bool, const std::vector< bool > > > * overwritten_args_map_ptr
unsigned getWidth()
std::map< llvm::BasicBlock *, llvm::SmallVector< llvm::BasicBlock *, 4 > > reverseBlocks
Map of primal block to corresponding block(s) in reverse.
llvm::Function * oldFunc
llvm::SmallVector< llvm::OperandBundleDef, 2 > getInvertedBundles(llvm::CallInst *orig, llvm::ArrayRef< ValueType > types, llvm::IRBuilder<> &Builder2, bool lookup, const llvm::ValueToValueMapTy &available=llvm::ValueToValueMapTy())
void replaceAWithB(llvm::Value *A, llvm::Value *B, bool storeInCache=false) override
Replace this instruction both in LLVM modules and any local data-structures.
EnzymeLogic & Logic
DIFFE_TYPE getDiffeType(llvm::Value *v, bool foreignFunction) const
llvm::Value * lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &incoming_availalble=llvm::ValueToValueMapTy(), bool tryLegalRecomputeCheck=true, llvm::BasicBlock *scope=nullptr) override
High-level utility to get the value of an instruction at a new location specified by BuilderM.
llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > originalToNewFn
DIFFE_TYPE getReturnDiffeType(llvm::Value *orig, bool *primalReturnUsedP, bool *shadowReturnUsedP, DerivativeMode cmode) const
std::map< llvm::BasicBlock *, llvm::BasicBlock * > reverseBlockToPrimal
Map of block in reverse to corresponding primal block.
bool isConstantInstruction(const llvm::Instruction *inst) const
llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > newToOriginalFn
static llvm::Type * getShadowType(llvm::Type *ty, unsigned width)
bool isConstantValue(llvm::Value *val) const
llvm::Value * invertPointerM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, bool nullShadow=false)
void erase(llvm::Instruction *I) override
Erase this instruction both from LLVM modules and any local data-structures.
Full interprocedural TypeAnalysis.
llvm::StringMap< std::function< bool(int, TypeTree &, llvm::ArrayRef< TypeTree >, llvm::ArrayRef< std::set< int64_t > >, llvm::CallBase *, TypeAnalyzer *)> > CustomRules
Map of custom function call handlers.
void clear()
Clear existing analyses.
EnzymeLogic & Logic
Helper class that computes the fixed-point type results of a given function.
TypeAnalysis & interprocedural
Calling TypeAnalysis to be used in the case of calls to other functions.
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
TypeAnalyzer * analyzer
TypeTree query(llvm::Value *val) const
The TypeTree of a particular Value.
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Definition TypeTree.h:72
void insertFromMD(llvm::MDNode *md, const std::vector< int > &prev={})
Definition TypeTree.h:1425
struct IntList * KnownValues
The specific constant(s) known to be represented by an argument, if constant.
Definition CApi.h:117
CTypeTreeRef * Arguments
Types of arguments, assumed of size len(Arguments)
Definition CApi.h:110
CTypeTreeRef Return
Type of return.
Definition CApi.h:113
Struct containing all contextual type information for a particular function call.
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
TypeTree Return
Type of return.
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to be represented by an argument, if constant.
Definition CApi.h:49
int64_t * data
Definition CApi.h:50
size_t size
Definition CApi.h:51
void getAnalysisUsage(AnalysisUsage &AU) const override
Definition CApi.cpp:1131
bool runOnModule(Module &M) override
Definition CApi.cpp:1114
todiff is the function to differentiate; retType is the activity info of the return.