29#include "llvm/ADT/SmallPtrSet.h"
30#include "llvm/ADT/SmallVector.h"
32#include "llvm/IR/BasicBlock.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/IR/Type.h"
36#include "llvm/IR/User.h"
37#include "llvm/IR/Value.h"
38#include "llvm/IR/ValueMap.h"
40#include "llvm/Transforms/Utils/BasicBlockUtils.h"
41#include "llvm/Transforms/Utils/Cloning.h"
// TraceUtils constructor (its header line and the `interface` parameter line
// are not visible in this excerpt). Stores the rewritten function, its
// trace / observations / likelihood arguments, the trace interface and the
// mode. The caller's sample/observe function sets are copied element-wise into
// this object's own SmallPtrSet members, so the caller's sets need not outlive it.
48 const SmallPtrSetImpl<Function *> &sampleFunctions,
49 const SmallPtrSetImpl<Function *> &observeFunctions,
50 Function *newFunc, Argument *trace,
51 Argument *observations, Argument *likelihood,
53 : trace(trace), observations(observations), likelihood(likelihood),
54 interface(interface), mode(mode), newFunc(newFunc),
55 sampleFunctions(sampleFunctions.begin(), sampleFunctions.end()),
56 observeFunctions(observeFunctions.begin(), observeFunctions.end()){};
// TraceUtils::FromClone. The function header and several body lines are not
// visible in this excerpt.
// It clones `oldFunc` into a new internal function. The new function's
// parameters are the original ones followed by extra tracing parameters.
// It records the old->new argument mapping in `originalToNewFn` and names the
// appended tracing arguments.
60 const SmallPtrSetImpl<Function *> &sampleFunctions,
61 const SmallPtrSetImpl<Function *> &observeFunctions,
63 ValueToValueMapTy &originalToNewFn) {
64 auto &Context = oldFunc->getContext();
65 FunctionType *orig_FTy = oldFunc->getFunctionType();
66 SmallVector<Type *, 4> params;
// Keep every original parameter type, in order, at the front of the new
// signature. This lets the original arguments map 1:1 onto the clone's
// leading arguments.
68 for (
unsigned i = 0; i < orig_FTy->getNumParams(); ++i) {
69 params.push_back(orig_FTy->getParamType(i));
// Append a pointer-to-double parameter used as the likelihood accumulator.
// NOTE(review): lines not visible here decide whether this depends on the mode
// and which other tracing parameters are appended after it.
72 Type *likelihood_acc_type =
getUnqual(Type::getDoubleTy(Context));
73 params.push_back(likelihood_acc_type);
// The clone keeps the original return type and varargs-ness; only the
// parameter list grows.
84 Type *RetTy = oldFunc->getReturnType();
85 FunctionType *FTy = FunctionType::get(RetTy, params, oldFunc->isVarArg());
// mode_str becomes the clone's name prefix. The dispatch that picks between
// these values is not visible here.
90 mode_str =
"likelihood";
96 mode_str =
"condition";
// The clone is named "<mode_str>_<original name>". It is created with internal
// linkage in the original function's module.
100 Function *
newFunc = Function::Create(
101 FTy, Function::LinkageTypes::InternalLinkage,
102 Twine(mode_str) +
"_" + oldFunc->getName(), oldFunc->getParent());
// Map each original argument to the clone's argument at the same position and
// copy its name. The iterator advances are not visible here.
104 auto DestArg =
newFunc->arg_begin();
105 auto SrcArg = oldFunc->arg_begin();
107 for (
unsigned i = 0; i < orig_FTy->getNumParams(); ++i) {
108 Argument *arg = SrcArg;
109 originalToNewFn[arg] = DestArg;
110 DestArg->setName(arg->getName());
// Clone the body only if oldFunc has one. LLVM >= 13 passes
// CloneFunctionChangeType::LocalChangesOnly. Older releases pass `true`
// (the boolean form of that argument) instead.
115 SmallVector<ReturnInst *, 4> Returns;
116 if (!oldFunc->empty()) {
117#if LLVM_VERSION_MAJOR >= 13
118 CloneFunctionInto(
newFunc, oldFunc, originalToNewFn,
119 CloneFunctionChangeType::LocalChangesOnly, Returns,
"",
122 CloneFunctionInto(
newFunc, oldFunc, originalToNewFn,
true, Returns,
"",
// Presumably the path for a body-less oldFunc (the `else` line is not visible).
// It gives the clone a single "entry" block ending in `unreachable`.
127 auto entry = BasicBlock::Create(
newFunc->getContext(),
"entry",
newFunc);
128 IRBuilder<> B(entry);
129 B.CreateUnreachable();
// Set internal linkage again after the clone step. The function was already
// created with InternalLinkage when Function::Create was called.
132 newFunc->setLinkage(Function::LinkageTypes::InternalLinkage);
// Locate the appended tracing arguments and give them readable names in the
// IR. The selection logic is not visible here.
134 Argument *trace =
nullptr;
135 Argument *observations =
nullptr;
136 Argument *likelihood =
nullptr;
143 arg->setName(
"trace");
150 arg->setName(
"observations");
156 arg->setName(
"likelihood");
// Packs `val` into an (i8*, byte-size) pair for the type-erased trace
// interface. `size_type` is the integer type of the returned size constant; its
// parameter line is not visible in this excerpt. There are three strategies:
//  * Pointer values are passed through, cast to i8*.
//  * Non-pointer values no wider than a pointer are bit-cast to an integer of
//    the same width. If narrower, they are zero-extended to pointer width.
//    They are then turned into an i8* with inttoptr, so the bits travel inside
//    the pointer itself rather than being pointed to.
//  * Wider values are stored into an alloca and the alloca's address is
//    returned. The alloca is created via a builder positioned from the
//    function's entry block.
// NOTE(review): the size is always getPrimitiveSizeInBits() / 8.
//  * LLVM documents that getPrimitiveSizeInBits() returns 0 for non-primitive
//    types, so pointer values report a size of 0.
//  * A struct/array value would take the "fits in a pointer" branch and ask
//    for a 0-bit IntegerType.
//  * Types narrower than 8 bits (e.g. i1) also report 0 bytes.
// Confirm this matches what the runtime interface expects.
// DataLayout::getTypeStoreSize yields a byte size for any sized type.
173std::pair<Value *, Constant *>
174TraceUtils::ValueToVoidPtrAndSize(IRBuilder<> &Builder, Value *val,
176 auto valsize = val->getType()->getPrimitiveSizeInBits();
178 if (val->getType()->isPointerTy()) {
180 Builder.CreatePointerCast(val,
getInt8PtrTy(val->getContext()));
181 return {retval, ConstantInt::get(size_type, valsize / 8)};
// Non-pointer path: compare the value's width with the target pointer width
// taken from the module's DataLayout.
184 auto M = Builder.GetInsertBlock()->getModule();
185 auto &DL = M->getDataLayout();
186 auto pointersize = DL.getPointerSizeInBits();
188 if (valsize <= pointersize) {
190 Builder.CreateBitCast(val, IntegerType::get(M->getContext(), valsize));
191 if (valsize != pointersize)
192 cast = Builder.CreateZExt(cast, Builder.getIntPtrTy(DL));
195 Builder.CreateIntToPtr(cast,
getInt8PtrTy(cast->getContext()));
196 return {retval, ConstantInt::get(size_type, valsize / 8)};
199 &Builder.GetInsertBlock()->getParent()->getEntryBlock()));
200 auto alloca = AllocaBuilder.CreateAlloca(val->getType(),
nullptr,
201 val->getName() +
".ptr");
// Spill the wide value at the current insertion point and hand out the slot's
// address.
202 Builder.CreateStore(val, alloca);
203 return {alloca, ConstantInt::get(size_type, valsize / 8)};
// Tags the new-trace call with the "enzyme_newtrace" function attribute. This is
// presumably the tail of CreateTrace; the call's creation is not visible here.
// LLVM >= 14 uses CallBase::addAttributeAtIndex. Older releases use the
// addAttribute(index, attr) spelling.
210#if LLVM_VERSION_MAJOR >= 14
211 call->addAttributeAtIndex(
212 AttributeList::FunctionIndex,
213 Attribute::get(call->getContext(),
"enzyme_newtrace"));
215 call->addAttribute(AttributeList::FunctionIndex,
216 Attribute::get(call->getContext(),
"enzyme_newtrace"));
// Tags the free-trace call with the "enzyme_freetrace" function attribute. This
// is presumably the tail of FreeTrace; the call's creation is not visible here.
// LLVM >= 14 uses CallBase::addAttributeAtIndex. Older releases use the
// addAttribute(index, attr) spelling.
225#if LLVM_VERSION_MAJOR >= 14
226 call->addAttributeAtIndex(
227 AttributeList::FunctionIndex,
228 Attribute::get(call->getContext(),
"enzyme_freetrace"));
230 call->addAttribute(AttributeList::FunctionIndex,
231 Attribute::get(call->getContext(),
"enzyme_freetrace"));
// InsertChoice (header, call creation and return are not visible here).
// Packs `choice` into an (i8*, size) pair and builds the argument list
// (trace, address, score, ptr, size) for the interface call.
// `size_type` is defined in a line not visible here.
238 Value *score, Value *choice) {
240 auto &&[retval, sizeval] = ValueToVoidPtrAndSize(Builder, choice, size_type);
242 Value *args[] = {trace, address, score, retval, sizeval};
// Parameter 1 is `address`. The call site promises the callee only reads
// through it.
246 call->addParamAttr(1, Attribute::ReadOnly);
// Presumably InsertCall (header and call creation are not visible here).
// It passes (trace, address, subtrace) to the interface and marks `address`
// (parameter 1) read-only at the call site. It also tags the call with the
// "enzyme_insert_call" function attribute, using addAttributeAtIndex on
// LLVM >= 14 and addAttribute otherwise.
254 Value *args[] = {trace, address, subtrace};
258 call->addParamAttr(1, Attribute::ReadOnly);
260#if LLVM_VERSION_MAJOR >= 14
261 call->addAttributeAtIndex(
262 AttributeList::FunctionIndex,
263 Attribute::get(call->getContext(),
"enzyme_insert_call"));
265 call->addAttribute(AttributeList::FunctionIndex,
266 Attribute::get(call->getContext(),
"enzyme_insert_call"));
// Presumably InsertArgument (header and call creation are not visible here).
// It packs `argument` into an (i8*, size) pair and passes
// (trace, name, ptr, size) to the interface.
275 auto &&[retval, sizeval] =
276 ValueToVoidPtrAndSize(Builder, argument, size_type);
278 Value *args[] = {trace, name, retval, sizeval};
// Parameter 1 is `name`; it is marked read-only at the call site.
282 call->addParamAttr(1, Attribute::ReadOnly);
// Presumably InsertReturn (header and call creation are not visible here).
// It packs the value `val` into an (i8*, size) pair and builds the argument
// list (trace, ptr, size) for the interface call.
289 auto &&[retval, sizeval] = ValueToVoidPtrAndSize(Builder, val, size_type);
291 Value *args[] = {trace, retval, sizeval};
// InsertFunction (header and call creation are not visible here). It passes
// (trace, function address as i8*) to the interface. Intrinsics are rejected
// by assertion because LLVM does not allow taking an intrinsic's address.
299 assert(!function->isIntrinsic());
301 Builder.CreateBitCast(function,
getInt8PtrTy(function->getContext()));
303 Value *args[] = {trace, FunctionPtr};
// InsertChoiceGradient: a static helper whose header and return are not
// visible here. It calls the caller-supplied `interface_function` of type
// `interface_type` with (trace, address, choice as i8*, size).
// The size constant's type is taken from the callee's 4th parameter (index 3),
// so it always matches the function being called.
311 FunctionType *interface_type,
312 Value *interface_function,
313 Value *address, Value *choice,
315 Type *size_type = interface_type->getParamType(3);
316 auto &&[retval, sizeval] = ValueToVoidPtrAndSize(Builder, choice, size_type);
318 Value *args[] = {trace, address, retval, sizeval};
// `address` is parameter 1 and is marked read-only at the call site.
320 auto call = Builder.CreateCall(interface_type, interface_function, args);
321 call->addParamAttr(1, Attribute::ReadOnly);
// InsertArgumentGradient: a static helper whose header and return are not
// visible here. It calls the caller-supplied `interface_function` of type
// `interface_type` with (trace, name, argument as i8*, size).
// The size constant's type is taken from the callee's 4th parameter (index 3),
// so it always matches the function being called.
327 FunctionType *interface_type,
328 Value *interface_function,
329 Value *name, Value *argument,
331 Type *size_type = interface_type->getParamType(3);
332 auto &&[retval, sizeval] =
333 ValueToVoidPtrAndSize(Builder, argument, size_type);
335 Value *args[] = {trace, name, retval, sizeval};
// `name` is parameter 1 and is marked read-only at the call site.
337 auto call = Builder.CreateCall(interface_type, interface_function, args);
338 call->addParamAttr(1, Attribute::ReadOnly);
// GetTrace (header, call creation and return are not visible here). It
// presumably fetches the sub-trace at `address`. Note that it queries the
// `observations` argument, not `trace`. `address` must be pointer-typed and is
// marked read-only (parameter 1).
345 assert(address->getType()->isPointerTy());
347 Value *args[] = {observations, address};
351 call->addParamAttr(1, Attribute::ReadOnly);
// GetChoice: reads the value recorded at `address` in `observations` as a
// `choiceType`. Several lines, including the call creation, are not visible here.
// Steps:
//  1. Allocate a `choiceType` stack slot named "<Name>.ptr", using a builder
//     positioned from the entry block.
//  2. Pass (observations, address, slot cast to a pointer, size) to the
//     interface.
//  3. Tag the call "enzyme_inactive" and mark `address` (parameter 1)
//     read-only.
//  4. Load the slot as "from.trace.<Name>" and return the loaded value.
357 Type *choiceType,
const Twine &Name) {
359 &Builder.GetInsertBlock()->getParent()->getEntryBlock()));
360 AllocaInst *store_dest =
361 AllocaBuilder.CreateAlloca(choiceType,
nullptr, Name +
".ptr");
// NOTE(review): getPrimitiveSizeInBits() returns 0 for non-primitive types
// (pointers, structs, arrays). Sub-byte types also round down to 0 bytes.
// So for those choice types the size passed to the interface is 0.
// Confirm against the runtime, or consider DataLayout::getTypeStoreSize.
362 auto preallocated_size = choiceType->getPrimitiveSizeInBits() / 8;
365 Value *args[] = {observations, address,
366 Builder.CreatePointerCast(
368 ConstantInt::get(size_type, preallocated_size)};
374#if LLVM_VERSION_MAJOR >= 14
375 call->addAttributeAtIndex(
376 AttributeList::FunctionIndex,
377 Attribute::get(call->getContext(),
"enzyme_inactive"));
379 call->addAttribute(AttributeList::FunctionIndex,
380 Attribute::get(call->getContext(),
"enzyme_inactive"));
382 call->addParamAttr(1, Attribute::ReadOnly);
// Read back whatever the interface call wrote into the slot.
384 return Builder.CreateLoad(choiceType, store_dest,
"from.trace." + Name);
// Presumably HasChoice (header and call creation are not visible here). It
// queries the interface with (observations, address). `address` (parameter 1)
// is marked read-only at the call site.
389 Value *args[]{observations, address};
393 call->addParamAttr(1, Attribute::ReadOnly);
// Presumably HasCall (header and call creation are not visible here). It
// queries the interface with (observations, address). `address` (parameter 1)
// is marked read-only at the call site.
400 Value *args[]{observations, address};
404 call->addParamAttr(1, Attribute::ReadOnly);
// SampleOrCondition: emits the value for a sample site. The header and the
// mode dispatch are not visible in this excerpt. Two strategies are visible:
//  * a plain call to `sample_fn` with `sample_args`;
//  * conditioning: branch on HasChoice(address).
//    - Block "condition.<Name>.with.trace" loads the recorded value via
//      GetChoice.
//    - Block "condition.<Name>.without.trace" calls `sample_fn` instead.
//    - Both join in "end" through a PHI of sample_fn's return type.
// Any other path reaches llvm_unreachable.
411 ArrayRef<Value *> sample_args,
412 Value *address,
const Twine &Name) {
413 auto &Context = Builder.getContext();
414 auto parent_fn = Builder.GetInsertBlock()->getParent();
// Unconditional sampling path. What happens to sample_call afterwards is not
// visible here.
419 auto sample_call = Builder.CreateCall(sample_fn->getFunctionType(),
420 sample_fn, sample_args);
// Conditioning path: use the recorded choice when one exists, otherwise sample.
424 Instruction *hasChoice =
HasChoice(Builder, address,
"has.choice." + Name);
426 Value *ThenChoice, *ElseChoice;
427 BasicBlock *ThenBlock = BasicBlock::Create(
428 Context,
"condition." + Name +
".with.trace", parent_fn);
429 BasicBlock *ElseBlock = BasicBlock::Create(
430 Context,
"condition." + Name +
".without.trace", parent_fn);
431 BasicBlock *EndBlock = BasicBlock::Create(Context,
"end", parent_fn);
433 Builder.CreateCondBr(hasChoice, ThenBlock, ElseBlock);
434 Builder.SetInsertPoint(ThenBlock);
435 ThenChoice =
GetChoice(Builder, address, sample_fn->getReturnType(), Name);
436 Builder.CreateBr(EndBlock);
438 Builder.SetInsertPoint(ElseBlock);
439 ElseChoice = Builder.CreateCall(sample_fn->getFunctionType(), sample_fn,
440 sample_args,
"sample." + Name);
441 Builder.CreateBr(EndBlock);
// The PHI lists ThenBlock/ElseBlock as predecessors. This assumes GetChoice
// and CreateCall emit into the current block without splitting it; neither
// visibly creates new blocks.
443 Builder.SetInsertPoint(EndBlock);
444 auto phi = Builder.CreatePHI(sample_fn->getReturnType(), 2);
445 phi->addIncoming(ThenChoice, ThenBlock);
446 phi->addIncoming(ElseChoice, ElseBlock);
// Reached only for a mode this function does not handle.
451 llvm_unreachable(
"Invalid sample_or_condition");
// CreateOutlinedFunction: builds a new internal, always-inline function and
// returns a call to it at `Builder`'s insertion point. The header and several
// body lines are not visible in this excerpt.
// The new function's parameters are, in order:
//  1. the types of `Arguments`;
//  2. the likelihood value, only when `needsLikelihood`;
//  3. the observations value;
//  4. the trace value.
// The conditions guarding (3) and (4) are in lines not visible here.
// A TraceUtils bound to the new function's trailing arguments is handed to the
// `Outlined` callback together with the forwarded leading arguments. The
// callback then emits the body starting in the "entry" block.
455 IRBuilder<> &Builder,
456 function_ref<
void(IRBuilder<> &,
TraceUtils *, ArrayRef<Value *>)> Outlined,
457 Type *RetTy, ArrayRef<Value *> Arguments,
bool needsLikelihood,
459 SmallVector<Type *, 4> Tys;
460 SmallVector<Value *, 4> Vals;
461 Module *M = Builder.GetInsertBlock()->getModule();
// Tys collects the new function's parameter types. Vals collects the matching
// call operands; the push of each Arg into Vals is not visible here.
463 for (
auto Arg : Arguments) {
465 Tys.push_back(Arg->getType());
468 if (needsLikelihood) {
469 Vals.push_back(likelihood);
470 Tys.push_back(likelihood->getType());
474 Vals.push_back(observations);
475 Tys.push_back(observations->getType());
479 Vals.push_back(trace);
480 Tys.push_back(trace->getType());
483 FunctionType *FTy = FunctionType::get(RetTy, Tys,
false);
485 Function::Create(FTy, Function::LinkageTypes::InternalLinkage, Name, M);
486 F->addFnAttr(Attribute::AlwaysInline);
488 auto Entry = BasicBlock::Create(M->getContext(),
"entry", F);
// The first Arguments.size() parameters of F are forwarded to the callback
// as plain Value*s.
490 auto ArgRange = make_pointer_range(
491 make_range(F->arg_begin(), F->arg_begin() + Arguments.size()));
492 SmallVector<Value *, 4> Rets(ArgRange);
// Walk the trailing parameters in the same order they were appended above:
// likelihood, observations, trace.
494 auto idx = F->arg_begin() + Arguments.size();
496 Argument *likelihood_arg =
nullptr;
498 likelihood_arg = idx++;
500 Argument *observations_arg =
nullptr;
502 observations_arg = idx++;
504 Argument *trace_arg =
nullptr;
510 observations_arg, likelihood_arg,
interface);
511 IRBuilder<> OutlineBuilder(Entry);
512 Outlined(OutlineBuilder, &OutlineTutils, Rets);
514 return Builder.CreateCall(FTy, F, Vals);
static Operation * getFunctionFromCall(CallOpInterface iface)
PointerType * traceType(LLVMContext &C)
static llvm::PointerType * getUnqual(llvm::Type *T)
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
static void addCallSiteNoCapture(llvm::CallBase *call, size_t idx)
static llvm::Instruction * getFirstNonPHIOrDbgOrLifetime(llvm::BasicBlock *B)
virtual llvm::Value * insertFunction(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * insertChoice(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * newTrace(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * freeTrace(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * hasCallTy()
llvm::FunctionType * insertFunctionTy()
llvm::FunctionType * hasChoiceTy()
llvm::FunctionType * insertChoiceTy()
virtual llvm::Value * insertArgument(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * insertArgumentTy()
llvm::FunctionType * insertCallTy()
virtual llvm::Value * hasCall(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * insertReturn(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * freeTraceTy()
virtual llvm::Value * getChoice(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * hasChoice(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * insertReturnTy()
virtual llvm::Value * insertCall(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * getChoiceTy()
llvm::FunctionType * newTraceTy()
llvm::FunctionType * getTraceTy()
virtual llvm::Value * getTrace(llvm::IRBuilder<> &Builder)=0
llvm::CallInst * InsertFunction(llvm::IRBuilder<> &Builder, llvm::Function *function)
static TraceUtils * FromClone(ProbProgMode mode, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, TraceInterface *interface, llvm::Function *oldFunc, llvm::ValueMap< const llvm::Value *, llvm::WeakTrackingVH > &originalToNewFn)
static constexpr const char TraceParameterAttribute[]
llvm::CallInst * CreateTrace(llvm::IRBuilder<> &Builder, const llvm::Twine &Name="trace")
llvm::Instruction * SampleOrCondition(llvm::IRBuilder<> &Builder, llvm::Function *sample_fn, llvm::ArrayRef< llvm::Value * > sample_args, llvm::Value *address, const llvm::Twine &Name="")
TraceUtils(ProbProgMode mode, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, llvm::Function *newFunc, llvm::Argument *trace, llvm::Argument *observations, llvm::Argument *likelihood, TraceInterface *interface)
llvm::CallInst * InsertChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Value *score, llvm::Value *choice)
llvm::Instruction * GetChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Type *choiceType, const llvm::Twine &Name="")
llvm::CallInst * InsertReturn(llvm::IRBuilder<> &Builder, llvm::Value *ret)
TraceInterface * getTraceInterface()
static llvm::CallInst * InsertChoiceGradient(llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *address, llvm::Value *choice, llvm::Value *trace)
TraceInterface * interface
llvm::CallInst * FreeTrace(llvm::IRBuilder<> &Builder)
llvm::Instruction * HasCall(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
static constexpr const char ObservationsParameterAttribute[]
llvm::Instruction * HasChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
llvm::SmallPtrSet< llvm::Function *, 4 > observeFunctions
llvm::CallInst * CreateOutlinedFunction(llvm::IRBuilder<> &Builder, llvm::function_ref< void(llvm::IRBuilder<> &, TraceUtils *, llvm::ArrayRef< llvm::Value * >)> Outlined, llvm::Type *RetTy, llvm::ArrayRef< llvm::Value * > Arguments, bool needsLikelihood=true, const llvm::Twine &Name="")
static constexpr const char LikelihoodParameterAttribute[]
static llvm::CallInst * InsertArgumentGradient(llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *name, llvm::Value *argument, llvm::Value *trace)
llvm::SmallPtrSet< llvm::Function *, 4 > sampleFunctions
llvm::Value * getLikelihood()
bool isObserveCall(llvm::CallInst *call)
llvm::CallInst * InsertArgument(llvm::IRBuilder<> &Builder, llvm::Value *name, llvm::Value *argument)
llvm::CallInst * GetTrace(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
bool isSampleCall(llvm::CallInst *call)
llvm::Value * getObservations()
llvm::CallInst * InsertCall(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Value *subtrace)