31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/Instructions.h"
34#include "llvm/IR/Type.h"
35#include "llvm/IR/Value.h"
48 return IntegerType::getInt64Ty(C);
76 return FunctionType::get(
82 return FunctionType::get(Type::getVoidTy(C),
88 return FunctionType::get(Type::getVoidTy(C),
95 return FunctionType::get(
101 return FunctionType::get(Type::getVoidTy(C),
107 return FunctionType::get(Type::getVoidTy(C),
112 return FunctionType::get(
118 return FunctionType::get(
128 return FunctionType::get(Type::getVoidTy(C), {
getInt8PtrTy(C)},
false);
143 for (
auto &&F : M->functions()) {
146 if (F.getName().contains(
"__enzyme_newtrace")) {
147 assert(F.getFunctionType() == newTraceTy());
148 newTraceFunction = &F;
149 }
else if (F.getName().contains(
"__enzyme_freetrace")) {
150 assert(F.getFunctionType() == freeTraceTy());
151 freeTraceFunction = &F;
152 }
else if (F.getName().contains(
"__enzyme_get_trace")) {
153 assert(F.getFunctionType() == getTraceTy());
154 getTraceFunction = &F;
155 }
else if (F.getName().contains(
"__enzyme_get_choice")) {
156 assert(F.getFunctionType() == getChoiceTy());
157 getChoiceFunction = &F;
158 }
else if (F.getName().contains(
"__enzyme_insert_call")) {
159 assert(F.getFunctionType() == insertCallTy());
160 insertCallFunction = &F;
161 }
else if (F.getName().contains(
"__enzyme_insert_choice")) {
162 assert(F.getFunctionType() == insertChoiceTy());
163 insertChoiceFunction = &F;
164 }
else if (F.getName().contains(
"__enzyme_insert_argument")) {
165 assert(F.getFunctionType() == insertArgumentTy());
166 insertArgumentFunction = &F;
167 }
else if (F.getName().contains(
"__enzyme_insert_return")) {
168 assert(F.getFunctionType() == insertReturnTy());
169 insertReturnFunction = &F;
170 }
else if (F.getName().contains(
"__enzyme_insert_function")) {
171 assert(F.getFunctionType() == insertFunctionTy());
172 insertFunctionFunction = &F;
173 }
else if (F.getName().contains(
"__enzyme_insert_gradient_choice")) {
174 assert(F.getFunctionType() == insertChoiceGradientTy());
175 insertChoiceGradientFunction = &F;
176 }
else if (F.getName().contains(
"__enzyme_insert_gradient_argument")) {
177 assert(F.getFunctionType() == insertArgumentGradientTy());
178 insertArgumentGradientFunction = &F;
179 }
else if (F.getName().contains(
"__enzyme_has_call")) {
180 assert(F.getFunctionType() == hasCallTy());
181 hasCallFunction = &F;
182 }
else if (F.getName().contains(
"__enzyme_has_choice")) {
183 assert(F.getFunctionType() == hasChoiceTy());
184 hasChoiceFunction = &F;
188 assert(newTraceFunction);
189 assert(freeTraceFunction);
190 assert(getTraceFunction);
191 assert(getChoiceFunction);
192 assert(insertCallFunction);
193 assert(insertChoiceFunction);
195 assert(insertArgumentFunction);
196 assert(insertReturnFunction);
197 assert(insertFunctionFunction);
199 assert(insertChoiceGradientFunction);
200 assert(insertArgumentGradientFunction);
202 assert(hasCallFunction);
203 assert(hasChoiceFunction);
205 newTraceFunction->addFnAttr(
"enzyme_notypeanalysis");
206 freeTraceFunction->addFnAttr(
"enzyme_notypeanalysis");
207 getTraceFunction->addFnAttr(
"enzyme_notypeanalysis");
208 getChoiceFunction->addFnAttr(
"enzyme_notypeanalysis");
209 insertCallFunction->addFnAttr(
"enzyme_notypeanalysis");
210 insertChoiceFunction->addFnAttr(
"enzyme_notypeanalysis");
211 insertArgumentFunction->addFnAttr(
"enzyme_notypeanalysis");
212 insertReturnFunction->addFnAttr(
"enzyme_notypeanalysis");
213 insertFunctionFunction->addFnAttr(
"enzyme_notypeanalysis");
214 insertChoiceGradientFunction->addFnAttr(
"enzyme_notypeanalysis");
215 insertArgumentGradientFunction->addFnAttr(
"enzyme_notypeanalysis");
216 hasCallFunction->addFnAttr(
"enzyme_notypeanalysis");
217 hasChoiceFunction->addFnAttr(
"enzyme_notypeanalysis");
219 newTraceFunction->addFnAttr(
"enzyme_inactive");
220 freeTraceFunction->addFnAttr(
"enzyme_inactive");
221 getTraceFunction->addFnAttr(
"enzyme_inactive");
222 getChoiceFunction->addFnAttr(
"enzyme_inactive");
223 insertCallFunction->addFnAttr(
"enzyme_inactive");
224 insertChoiceFunction->addFnAttr(
"enzyme_inactive");
225 insertArgumentFunction->addFnAttr(
"enzyme_inactive");
226 insertReturnFunction->addFnAttr(
"enzyme_inactive");
227 insertFunctionFunction->addFnAttr(
"enzyme_inactive");
228 insertChoiceGradientFunction->addFnAttr(
"enzyme_inactive");
229 insertArgumentGradientFunction->addFnAttr(
"enzyme_inactive");
230 hasCallFunction->addFnAttr(
"enzyme_inactive");
231 hasChoiceFunction->addFnAttr(
"enzyme_inactive");
233 newTraceFunction->addFnAttr(Attribute::NoFree);
234 getTraceFunction->addFnAttr(Attribute::NoFree);
235 getChoiceFunction->addFnAttr(Attribute::NoFree);
236 insertCallFunction->addFnAttr(Attribute::NoFree);
237 insertChoiceFunction->addFnAttr(Attribute::NoFree);
238 insertArgumentFunction->addFnAttr(Attribute::NoFree);
239 insertReturnFunction->addFnAttr(Attribute::NoFree);
240 insertFunctionFunction->addFnAttr(Attribute::NoFree);
241 insertChoiceGradientFunction->addFnAttr(Attribute::NoFree);
242 insertArgumentGradientFunction->addFnAttr(Attribute::NoFree);
243 hasCallFunction->addFnAttr(Attribute::NoFree);
244 hasChoiceFunction->addFnAttr(Attribute::NoFree);
248 LLVMContext &C, Function *getTraceFunction, Function *getChoiceFunction,
249 Function *insertCallFunction, Function *insertChoiceFunction,
250 Function *insertArgumentFunction, Function *insertReturnFunction,
251 Function *insertFunctionFunction, Function *insertChoiceGradientFunction,
252 Function *insertArgumentGradientFunction, Function *newTraceFunction,
253 Function *freeTraceFunction, Function *hasCallFunction,
254 Function *hasChoiceFunction)
256 getChoiceFunction(getChoiceFunction),
257 insertCallFunction(insertCallFunction),
258 insertChoiceFunction(insertChoiceFunction),
259 insertArgumentFunction(insertArgumentFunction),
260 insertReturnFunction(insertReturnFunction),
261 insertFunctionFunction(insertFunctionFunction),
262 insertChoiceGradientFunction(insertChoiceGradientFunction),
263 insertArgumentGradientFunction(insertArgumentGradientFunction),
264 newTraceFunction(newTraceFunction), freeTraceFunction(freeTraceFunction),
265 hasCallFunction(hasCallFunction), hasChoiceFunction(hasChoiceFunction){};
269 return getTraceFunction;
272 return getChoiceFunction;
275 return insertCallFunction;
278 return insertChoiceFunction;
281 return insertArgumentFunction;
284 return insertReturnFunction;
287 return insertFunctionFunction;
290 return insertChoiceGradientFunction;
293 return insertArgumentGradientFunction;
296 return newTraceFunction;
299 return freeTraceFunction;
302 return hasCallFunction;
305 return hasChoiceFunction;
311 assert(dynamicInterface);
313 auto &M = *F->getParent();
316 getTraceFunction = MaterializeInterfaceFunction(
317 Builder, dynamicInterface,
getTraceTy(), 0, M,
"get_trace");
318 getChoiceFunction = MaterializeInterfaceFunction(
319 Builder, dynamicInterface,
getChoiceTy(), 1, M,
"get_choice");
320 insertCallFunction = MaterializeInterfaceFunction(
321 Builder, dynamicInterface,
insertCallTy(), 2, M,
"insert_call");
322 insertChoiceFunction = MaterializeInterfaceFunction(
323 Builder, dynamicInterface,
insertChoiceTy(), 3, M,
"insert_choice");
324 insertArgumentFunction = MaterializeInterfaceFunction(
326 insertReturnFunction = MaterializeInterfaceFunction(
327 Builder, dynamicInterface,
insertReturnTy(), 5, M,
"insert_return");
328 insertFunctionFunction = MaterializeInterfaceFunction(
330 insertChoiceGradientFunction = MaterializeInterfaceFunction(
332 "insert_choice_gradient");
333 insertArgumentGradientFunction = MaterializeInterfaceFunction(
335 "insert_argument_gradient");
336 newTraceFunction = MaterializeInterfaceFunction(
337 Builder, dynamicInterface,
newTraceTy(), 9, M,
"new_trace");
338 freeTraceFunction = MaterializeInterfaceFunction(
339 Builder, dynamicInterface,
freeTraceTy(), 10, M,
"free_trace");
340 hasCallFunction = MaterializeInterfaceFunction(
341 Builder, dynamicInterface,
hasCallTy(), 11, M,
"has_call");
342 hasChoiceFunction = MaterializeInterfaceFunction(
343 Builder, dynamicInterface,
hasChoiceTy(), 12, M,
"has_choice");
345 assert(newTraceFunction);
346 assert(freeTraceFunction);
347 assert(getTraceFunction);
348 assert(getChoiceFunction);
349 assert(insertCallFunction);
350 assert(insertChoiceFunction);
352 assert(insertArgumentFunction);
353 assert(insertReturnFunction);
354 assert(insertFunctionFunction);
356 assert(insertChoiceGradientFunction);
357 assert(insertArgumentGradientFunction);
359 assert(hasCallFunction);
360 assert(hasChoiceFunction);
363Function *DynamicTraceInterface::MaterializeInterfaceFunction(
364 IRBuilder<> &Builder, Value *dynamicInterface, FunctionType *FTy,
365 unsigned index, Module &M,
const Twine &Name) {
367 Builder.CreateInBoundsGEP(
getInt8PtrTy(dynamicInterface->getContext()),
368 dynamicInterface, Builder.getInt32(index));
370 Builder.CreateLoad(
getInt8PtrTy(dynamicInterface->getContext()), ptr);
371 auto pty = PointerType::get(FTy, load->getPointerAddressSpace());
372 auto cast = Builder.CreatePointerCast(load, pty);
375 new GlobalVariable(M, pty,
false, GlobalVariable::PrivateLinkage,
376 ConstantPointerNull::get(pty), Name +
"_ptr");
377 Builder.CreateStore(cast, global);
379 Function *F = Function::Create(FTy, Function::PrivateLinkage, Name, M);
380 F->addFnAttr(Attribute::AlwaysInline);
381 BasicBlock *Entry = BasicBlock::Create(M.getContext(),
"entry", F);
383 IRBuilder<> WrapperBuilder(Entry);
385 auto ToWrap = WrapperBuilder.CreateLoad(pty, global, Name);
386 auto Args = SmallVector<Value *, 4>(make_pointer_range(F->args()));
387 auto Call = WrapperBuilder.CreateCall(FTy, ToWrap,
Args);
389 if (!FTy->getReturnType()->isVoidTy()) {
390 WrapperBuilder.CreateRet(
Call);
392 WrapperBuilder.CreateRetVoid();
400 return getTraceFunction;
404 return getChoiceFunction;
408 return insertCallFunction;
412 return insertChoiceFunction;
416 return insertArgumentFunction;
420 return insertReturnFunction;
424 return insertFunctionFunction;
428 return insertChoiceGradientFunction;
432 return insertArgumentGradientFunction;
436 return newTraceFunction;
440 return freeTraceFunction;
444 return hasCallFunction;
448 return hasChoiceFunction;
Type * addressType(LLVMContext &C)
PointerType * traceType(LLVMContext &C)
llvm::PointerType * getDefaultAnonymousTapeType(llvm::LLVMContext &C)
@ Args
Return is a struct of all args.
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
static llvm::Instruction * getFirstNonPHIOrDbg(llvm::BasicBlock *B)
llvm::Value * insertArgumentGradient(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoiceGradient(llvm::IRBuilder<> &Builder)
DynamicTraceInterface(llvm::Value *dynamicInterface, llvm::Function *F)
llvm::Value * hasChoice(llvm::IRBuilder<> &Builder)
llvm::Value * hasCall(llvm::IRBuilder<> &Builder)
llvm::Value * getChoice(llvm::IRBuilder<> &Builder)
llvm::Value * freeTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoice(llvm::IRBuilder<> &Builder)
llvm::Value * insertFunction(llvm::IRBuilder<> &Builder)
llvm::Value * insertReturn(llvm::IRBuilder<> &Builder)
llvm::Value * getTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertArgument(llvm::IRBuilder<> &Builder)
llvm::Value * insertCall(llvm::IRBuilder<> &Builder)
llvm::Value * newTrace(llvm::IRBuilder<> &Builder)
llvm::Value * getChoice(llvm::IRBuilder<> &Builder)
llvm::Value * insertCall(llvm::IRBuilder<> &Builder)
llvm::Value * insertArgument(llvm::IRBuilder<> &Builder)
llvm::Value * insertReturn(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoiceGradient(llvm::IRBuilder<> &Builder)
llvm::Value * freeTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertArgumentGradient(llvm::IRBuilder<> &Builder)
llvm::Value * newTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoice(llvm::IRBuilder<> &Builder)
llvm::Value * getTrace(llvm::IRBuilder<> &Builder)
llvm::Value * hasChoice(llvm::IRBuilder<> &Builder)
StaticTraceInterface(llvm::Module *M)
llvm::Value * insertFunction(llvm::IRBuilder<> &Builder)
llvm::Value * hasCall(llvm::IRBuilder<> &Builder)
llvm::FunctionType * insertArgumentGradientTy()
TraceInterface(llvm::LLVMContext &C)
llvm::FunctionType * hasCallTy()
llvm::FunctionType * insertFunctionTy()
llvm::FunctionType * hasChoiceTy()
llvm::FunctionType * insertChoiceTy()
llvm::FunctionType * insertChoiceGradientTy()
llvm::FunctionType * insertArgumentTy()
static llvm::Type * stringType(llvm::LLVMContext &C)
llvm::FunctionType * insertCallTy()
llvm::FunctionType * freeTraceTy()
static llvm::IntegerType * sizeType(llvm::LLVMContext &C)
llvm::FunctionType * insertReturnTy()
llvm::FunctionType * getChoiceTy()
llvm::FunctionType * newTraceTy()
llvm::FunctionType * getTraceTy()