Enzyme main
Loading...
Searching...
No Matches
CApi.h
Go to the documentation of this file.
1//===- CApi.h - Enzyme API exported to C for external use -----------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file declares various utility functions of Enzyme for access via C
22//
23//===----------------------------------------------------------------------===//
24#ifndef ENZYME_CAPI_H
25#define ENZYME_CAPI_H
26
27#include "llvm-c/Core.h"
28#include "llvm-c/DataTypes.h"
29// #include "llvm-c/Initialization.h"
30#include "llvm-c/Target.h"
31#include <stddef.h>
32
33#ifdef __cplusplus
34extern "C" {
35#endif
36
37struct EnzymeOpaqueTypeAnalysis;
38typedef struct EnzymeOpaqueTypeAnalysis *EnzymeTypeAnalysisRef;
39
40struct EnzymeOpaqueLogic;
41typedef struct EnzymeOpaqueLogic *EnzymeLogicRef;
42
43struct EnzymeOpaqueAugmentedReturn;
44typedef struct EnzymeOpaqueAugmentedReturn *EnzymeAugmentedReturnPtr;
45
46struct EnzymeOpaqueTraceInterface;
47typedef struct EnzymeOpaqueTraceInterface *EnzymeTraceInterfaceRef;
48
49struct IntList {
50 int64_t *data;
51 size_t size;
52};
53
66
71
72/*
73struct CTypeTree {
74 struct CDataPair *data;
75 size_t size;
76};
77*/
78
85
86struct EnzymeTypeTree;
87typedef struct EnzymeTypeTree *CTypeTreeRef;
94void EnzymeTypeTreeOnlyEq(CTypeTreeRef dst, int64_t x);
96void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef dst, const char *datalayout,
97 int64_t offset, int64_t maxSize,
98 uint64_t addOffset);
99void EnzymeTypeTreeInsertEq(CTypeTreeRef dst, const int64_t *indices,
100 size_t len, CConcreteType ct, LLVMContextRef ctx);
101const char *EnzymeTypeTreeToString(CTypeTreeRef src);
102void EnzymeTypeTreeToStringFree(const char *cstr);
103
104void EnzymeSetCLBool(void *, uint8_t);
105void EnzymeSetCLInteger(void *, int64_t);
106void EnzymeSetCLString(void *, const char *);
107
109 /// Types of arguments, assumed of size len(Arguments)
111
112 /// Type of return
114
115 /// The specific constant(s) known to be represented by an argument, if constant
116 // map is [arg number] => list
118};
119
120typedef enum {
121 DFT_OUT_DIFF = 0, // add differential to an output struct. Only for scalar
122 // values in ReverseMode variants.
123 DFT_DUP_ARG = 1, // duplicate the argument and store differential inside.
124 // For references, pointers, or integers in ReverseMode
125 // variants. For all types in ForwardMode variants.
126 DFT_CONSTANT = 2, // no differential. Usable everywhere.
127 DFT_DUP_NONEED = 3 // duplicate this argument and store differential inside,
128 // but don't need the forward. Same as DUP_ARG otherwise.
130
131typedef enum { BT_SCALAR = 0, BT_VECTOR = 1 } CBATCH_TYPE;
132
141
142typedef enum {
146
147typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/,
148 CTypeTreeRef * /*args*/,
149 struct IntList * /*knownValues*/,
150 size_t /*numArgs*/, LLVMValueRef,
151 void * /*TA*/);
153 char **customRuleNames,
154 CustomRuleType *customRules,
155 size_t numRules);
158
161
164 LLVMContextRef C, LLVMValueRef getTraceFunction,
165 LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction,
166 LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction,
167 LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction,
168 LLVMValueRef insertChoiceGradientFunction,
169 LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction,
170 LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction,
171 LLVMValueRef hasChoiceFunction);
173CreateEnzymeDynamicTraceInterface(LLVMValueRef interface, LLVMValueRef F);
174EnzymeLogicRef CreateEnzymeLogic(uint8_t PostOpt);
177void EnzymeLogicSetExternalContext(EnzymeLogicRef, void *ExternalContext);
179
181 uint8_t *existed, size_t len);
182
183LLVMValueRef
186
187class GradientUtils;
189
190typedef LLVMValueRef (*CustomShadowAlloc)(LLVMBuilderRef, LLVMValueRef,
191 size_t /*numArgs*/, LLVMValueRef *,
192 GradientUtils *);
193typedef LLVMValueRef (*CustomShadowFree)(LLVMBuilderRef, LLVMValueRef);
194
196 CustomShadowFree FHandle);
197
198typedef uint8_t (*CustomFunctionForward)(LLVMBuilderRef, LLVMValueRef,
199 GradientUtils *, LLVMValueRef *,
200 LLVMValueRef *);
201
202typedef uint8_t (*CustomFunctionDiffUse)(LLVMValueRef, const GradientUtils *,
203 LLVMValueRef, uint8_t, CDerivativeMode,
204 uint8_t *);
205
206typedef uint8_t (*CustomAugmentedFunctionForward)(LLVMBuilderRef, LLVMValueRef,
208 LLVMValueRef *,
209 LLVMValueRef *,
210 LLVMValueRef *);
211
212typedef void (*CustomFunctionReverse)(LLVMBuilderRef, LLVMValueRef,
213 DiffeGradientUtils *, LLVMValueRef);
214
215LLVMValueRef EnzymeCreateForwardDiff(
216 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
217 LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
218 size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
219 CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity,
220 uint8_t strongZero, unsigned width, LLVMTypeRef additionalArg,
221 CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write,
222 uint8_t *_overwritten_args, size_t overwritten_args_size,
223 EnzymeAugmentedReturnPtr augmented);
224
226 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
227 LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
228 size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
229 uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity,
230 uint8_t strongZero, unsigned width, uint8_t freeMemory,
231 LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo,
232 uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args,
233 size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented,
234 uint8_t AtomicAdd);
235
236void EnzymeRegisterCallHandler(const char *Name,
238 CustomFunctionReverse RevHandle);
239
241 LLVMValueRef val);
242
243// TODO: Other API functions that are defined in CApi.cpp for GradientUtils
246
247#ifdef __cplusplus
248}
249#endif
250
251#endif
uint8_t(* CustomFunctionForward)(LLVMBuilderRef, LLVMValueRef, GradientUtils *, LLVMValueRef *, LLVMValueRef *)
Definition CApi.h:198
void * EnzymeLogicGetExternalContext(EnzymeLogicRef)
Definition CApi.cpp:221
EnzymeTypeAnalysisRef EnzymeGetTypeAnalysisFromTypeAnalyzer(void *TAR)
Definition CApi.cpp:333
void EnzymeSetCLBool(void *, uint8_t)
Definition CApi.cpp:188
void * EnzymeGradientUtilsGetExternalContext(GradientUtils *gutils)
Definition CApi.cpp:423
struct EnzymeOpaqueTypeAnalysis * EnzymeTypeAnalysisRef
Definition CApi.h:38
struct EnzymeOpaqueTraceInterface * EnzymeTraceInterfaceRef
Definition CApi.h:47
struct EnzymeTypeTree * CTypeTreeRef
Definition CApi.h:87
EnzymeTraceInterfaceRef CreateEnzymeDynamicTraceInterface(LLVMValueRef interface, LLVMValueRef F)
Definition CApi.cpp:255
void EnzymeTypeTreeData0Eq(CTypeTreeRef dst)
Definition CApi.cpp:893
void EnzymeTypeTreeInsertEq(CTypeTreeRef dst, const int64_t *indices, size_t len, CConcreteType ct, LLVMContextRef ctx)
Definition CApi.cpp:916
CValueType
Definition CApi.h:79
@ VT_Shadow
Definition CApi.h:82
@ VT_None
Definition CApi.h:80
@ VT_Primal
Definition CApi.h:81
@ VT_Both
Definition CApi.h:83
void EnzymeFreeTypeTree(CTypeTreeRef CTT)
Definition CApi.cpp:872
EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M)
Definition CApi.cpp:225
void EnzymeSetCLString(void *, const char *)
Definition CApi.cpp:208
void EnzymeTypeTreeToStringFree(const char *cstr)
Definition CApi.cpp:933
EnzymeLogicRef CreateEnzymeLogic(uint8_t PostOpt)
Definition CApi.cpp:213
CTypeTreeRef EnzymeNewTypeTreeTR(CTypeTreeRef)
Definition CApi.cpp:869
void EnzymeLogicSetExternalContext(EnzymeLogicRef, void *ExternalContext)
Definition CApi.cpp:217
void ClearTypeAnalysis(EnzymeTypeAnalysisRef)
Definition CApi.cpp:312
uint8_t EnzymeGradientUtilsGetAtomicAdd(GradientUtils *gutils)
Definition CApi.cpp:431
uint8_t(* CustomAugmentedFunctionForward)(LLVMBuilderRef, LLVMValueRef, GradientUtils *, LLVMValueRef *, LLVMValueRef *, LLVMValueRef *)
Definition CApi.h:206
LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils, LLVMValueRef val)
Definition CApi.cpp:448
void EnzymeSetCLInteger(void *, int64_t)
Definition CApi.cpp:198
CTypeTreeRef EnzymeNewTypeTree()
Definition CApi.cpp:865
CConcreteType
Definition CApi.h:54
@ DT_Integer
Definition CApi.h:56
@ DT_Double
Definition CApi.h:60
@ DT_Anything
Definition CApi.h:55
@ DT_Unknown
Definition CApi.h:61
@ DT_FP128
Definition CApi.h:64
@ DT_BFloat16
Definition CApi.h:63
@ DT_X86_FP80
Definition CApi.h:62
@ DT_Float
Definition CApi.h:59
@ DT_Half
Definition CApi.h:58
@ DT_Pointer
Definition CApi.h:57
uint8_t(* CustomRuleType)(int, CTypeTreeRef, CTypeTreeRef *, struct IntList *, size_t, LLVMValueRef, void *)
Definition CApi.h:147
void EnzymeRegisterCallHandler(const char *Name, CustomAugmentedFunctionForward FwdHandle, CustomFunctionReverse RevHandle)
Definition CApi.cpp:370
uint8_t(* CustomFunctionDiffUse)(LLVMValueRef, const GradientUtils *, LLVMValueRef, uint8_t, CDerivativeMode, uint8_t *)
Definition CApi.h:202
LLVMValueRef EnzymeCreateForwardDiff(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity, uint8_t strongZero, unsigned width, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented)
Definition CApi.cpp:661
uint8_t EnzymeSetTypeTree(CTypeTreeRef dst, CTypeTreeRef src)
Definition CApi.cpp:873
EnzymeLogicRef EnzymeTypeAnalysisGetLogic(EnzymeTypeAnalysisRef TAR)
Definition CApi.cpp:319
LLVMValueRef EnzymeCreatePrimalAndGradient(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity, uint8_t strongZero, unsigned width, uint8_t freeMemory, LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd)
Definition CApi.cpp:687
CProbProgMode
Definition CApi.h:142
@ DEM_Condition
Definition CApi.h:144
@ DEM_Trace
Definition CApi.h:143
void EnzymeTypeTreeOnlyEq(CTypeTreeRef dst, int64_t x)
Definition CApi.cpp:889
void(* CustomFunctionReverse)(LLVMBuilderRef, LLVMValueRef, DiffeGradientUtils *, LLVMValueRef)
Definition CApi.h:212
void FreeTypeAnalysis(EnzymeTypeAnalysisRef)
Definition CApi.cpp:314
LLVMTypeRef EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret)
Definition CApi.cpp:813
CDIFFE_TYPE
Definition CApi.h:120
@ DFT_DUP_NONEED
Definition CApi.h:127
@ DFT_OUT_DIFF
Definition CApi.h:121
@ DFT_CONSTANT
Definition CApi.h:126
@ DFT_DUP_ARG
Definition CApi.h:123
EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface(LLVMContextRef C, LLVMValueRef getTraceFunction, LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction, LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction, LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction, LLVMValueRef insertChoiceGradientFunction, LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction, LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction, LLVMValueRef hasChoiceFunction)
Definition CApi.cpp:229
void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef dst, const char *datalayout, int64_t offset, int64_t maxSize, uint64_t addOffset)
Definition CApi.cpp:909
uint8_t EnzymeMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src)
Definition CApi.cpp:876
CDerivativeMode
Definition CApi.h:133
@ DEM_ReverseModeGradient
Definition CApi.h:136
@ DEM_ReverseModePrimal
Definition CApi.h:135
@ DEM_ForwardModeError
Definition CApi.h:139
@ DEM_ForwardMode
Definition CApi.h:134
@ DEM_ReverseModeCombined
Definition CApi.h:137
@ DEM_ForwardModeSplit
Definition CApi.h:138
void ClearEnzymeLogic(EnzymeLogicRef)
Definition CApi.cpp:260
EnzymeTypeAnalysisRef CreateTypeAnalysis(EnzymeLogicRef Log, char **customRuleNames, CustomRuleType *customRules, size_t numRules)
Definition CApi.cpp:274
LLVMValueRef EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret)
Definition CApi.cpp:801
void EnzymeExtractReturnInfo(EnzymeAugmentedReturnPtr ret, int64_t *data, uint8_t *existed, size_t len)
Definition CApi.cpp:825
LLVMValueRef(* CustomShadowFree)(LLVMBuilderRef, LLVMValueRef)
Definition CApi.h:193
const char * EnzymeTypeTreeToString(CTypeTreeRef src)
Definition CApi.cpp:924
CTypeTreeRef EnzymeNewTypeTreeCT(CConcreteType, LLVMContextRef ctx)
Definition CApi.cpp:866
void FreeEnzymeLogic(EnzymeLogicRef)
Definition CApi.cpp:268
struct EnzymeOpaqueLogic * EnzymeLogicRef
Definition CApi.h:41
void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, CustomShadowFree FHandle)
Definition CApi.cpp:352
CBATCH_TYPE
Definition CApi.h:131
@ BT_SCALAR
Definition CApi.h:131
@ BT_VECTOR
Definition CApi.h:131
LLVMValueRef(* CustomShadowAlloc)(LLVMBuilderRef, LLVMValueRef, size_t, LLVMValueRef *, GradientUtils *)
Definition CApi.h:190
struct EnzymeOpaqueAugmentedReturn * EnzymeAugmentedReturnPtr
Definition CApi.h:44
DerivativeMode mode
TypeAnalysis & TA
EnzymeLogic & Logic
CConcreteType datatype
Definition CApi.h:69
struct IntList offsets
Definition CApi.h:68
struct IntList * KnownValues
The specific constant(s) known to be represented by an argument, if constant.
Definition CApi.h:117
CTypeTreeRef * Arguments
Types of arguments, assumed of size len(Arguments)
Definition CApi.h:110
CTypeTreeRef Return
Type of return.
Definition CApi.h:113
IntList
Definition CApi.h:49
int64_t * data
Definition CApi.h:50
size_t size
Definition CApi.h:51