Enzyme main
Loading...
Searching...
No Matches
TraceUtils.cpp
Go to the documentation of this file.
1//===- TraceUtils.cpp - Utilities for interacting with traces -----------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains utilities for interacting with probabilistic programming
22// traces using the probabilistic programming
23// trace interface
24//
25//===----------------------------------------------------------------------===//
26
27#include "TraceUtils.h"
28
29#include "llvm/ADT/SmallPtrSet.h"
30#include "llvm/ADT/SmallVector.h"
31
32#include "llvm/IR/BasicBlock.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/IR/Type.h"
36#include "llvm/IR/User.h"
37#include "llvm/IR/Value.h"
38#include "llvm/IR/ValueMap.h"
39
40#include "llvm/Transforms/Utils/BasicBlockUtils.h"
41#include "llvm/Transforms/Utils/Cloning.h"
42
43#include "TraceInterface.h"
44
45using namespace llvm;
46
48 const SmallPtrSetImpl<Function *> &sampleFunctions,
49 const SmallPtrSetImpl<Function *> &observeFunctions,
50 Function *newFunc, Argument *trace,
51 Argument *observations, Argument *likelihood,
52 TraceInterface *interface)
53 : trace(trace), observations(observations), likelihood(likelihood),
54 interface(interface), mode(mode), newFunc(newFunc),
55 sampleFunctions(sampleFunctions.begin(), sampleFunctions.end()),
56 observeFunctions(observeFunctions.begin(), observeFunctions.end()){};
57
60 const SmallPtrSetImpl<Function *> &sampleFunctions,
61 const SmallPtrSetImpl<Function *> &observeFunctions,
62 TraceInterface *interface, Function *oldFunc,
63 ValueToValueMapTy &originalToNewFn) {
64 auto &Context = oldFunc->getContext();
65 FunctionType *orig_FTy = oldFunc->getFunctionType();
66 SmallVector<Type *, 4> params;
67
68 for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) {
69 params.push_back(orig_FTy->getParamType(i));
70 }
71
72 Type *likelihood_acc_type = getUnqual(Type::getDoubleTy(Context));
73 params.push_back(likelihood_acc_type);
74
76 Type *traceType = interface->getTraceTy()->getReturnType();
77
79 params.push_back(traceType);
80
81 params.push_back(traceType);
82 }
83
84 Type *RetTy = oldFunc->getReturnType();
85 FunctionType *FTy = FunctionType::get(RetTy, params, oldFunc->isVarArg());
86
87 const char *mode_str;
88 switch (mode) {
90 mode_str = "likelihood";
91 break;
93 mode_str = "trace";
94 break;
96 mode_str = "condition";
97 break;
98 }
99
100 Function *newFunc = Function::Create(
101 FTy, Function::LinkageTypes::InternalLinkage,
102 Twine(mode_str) + "_" + oldFunc->getName(), oldFunc->getParent());
103
104 auto DestArg = newFunc->arg_begin();
105 auto SrcArg = oldFunc->arg_begin();
106
107 for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) {
108 Argument *arg = SrcArg;
109 originalToNewFn[arg] = DestArg;
110 DestArg->setName(arg->getName());
111 DestArg++;
112 SrcArg++;
113 }
114
115 SmallVector<ReturnInst *, 4> Returns;
116 if (!oldFunc->empty()) {
117#if LLVM_VERSION_MAJOR >= 13
118 CloneFunctionInto(newFunc, oldFunc, originalToNewFn,
119 CloneFunctionChangeType::LocalChangesOnly, Returns, "",
120 nullptr);
121#else
122 CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "",
123 nullptr);
124#endif
125 }
126 if (newFunc->empty()) {
127 auto entry = BasicBlock::Create(newFunc->getContext(), "entry", newFunc);
128 IRBuilder<> B(entry);
129 B.CreateUnreachable();
130 }
131
132 newFunc->setLinkage(Function::LinkageTypes::InternalLinkage);
133
134 Argument *trace = nullptr;
135 Argument *observations = nullptr;
136 Argument *likelihood = nullptr;
137
138 auto arg = newFunc->arg_end();
139
141 arg -= 1;
142 trace = arg;
143 arg->setName("trace");
144 arg->addAttr(Attribute::get(Context, TraceParameterAttribute));
145 }
146
148 arg -= 1;
149 observations = arg;
150 arg->setName("observations");
151 arg->addAttr(Attribute::get(Context, ObservationsParameterAttribute));
152 }
153
154 arg -= 1;
155 likelihood = arg;
156 arg->setName("likelihood");
157 arg->addAttr(Attribute::get(Context, LikelihoodParameterAttribute));
158
160 observations, likelihood, interface);
161};
162
// Defaulted destructor: no hand-written cleanup is performed here.
163TraceUtils::~TraceUtils() = default;
164
166
167Value *TraceUtils::getTrace() { return trace; }
168
169Value *TraceUtils::getObservations() { return observations; }
170
171Value *TraceUtils::getLikelihood() { return likelihood; }
172
173std::pair<Value *, Constant *>
174TraceUtils::ValueToVoidPtrAndSize(IRBuilder<> &Builder, Value *val,
175 Type *size_type) {
176 auto valsize = val->getType()->getPrimitiveSizeInBits();
177
178 if (val->getType()->isPointerTy()) {
179 Value *retval =
180 Builder.CreatePointerCast(val, getInt8PtrTy(val->getContext()));
181 return {retval, ConstantInt::get(size_type, valsize / 8)};
182 }
183
184 auto M = Builder.GetInsertBlock()->getModule();
185 auto &DL = M->getDataLayout();
186 auto pointersize = DL.getPointerSizeInBits();
187
188 if (valsize <= pointersize) {
189 auto cast =
190 Builder.CreateBitCast(val, IntegerType::get(M->getContext(), valsize));
191 if (valsize != pointersize)
192 cast = Builder.CreateZExt(cast, Builder.getIntPtrTy(DL));
193
194 Value *retval =
195 Builder.CreateIntToPtr(cast, getInt8PtrTy(cast->getContext()));
196 return {retval, ConstantInt::get(size_type, valsize / 8)};
197 } else {
198 IRBuilder<> AllocaBuilder(getFirstNonPHIOrDbgOrLifetime(
199 &Builder.GetInsertBlock()->getParent()->getEntryBlock()));
200 auto alloca = AllocaBuilder.CreateAlloca(val->getType(), nullptr,
201 val->getName() + ".ptr");
202 Builder.CreateStore(val, alloca);
203 return {alloca, ConstantInt::get(size_type, valsize / 8)};
204 }
205}
206
207CallInst *TraceUtils::CreateTrace(IRBuilder<> &Builder, const Twine &Name) {
208 auto call = Builder.CreateCall(interface->newTraceTy(),
209 interface->newTrace(Builder), {}, Name);
210#if LLVM_VERSION_MAJOR >= 14
211 call->addAttributeAtIndex(
212 AttributeList::FunctionIndex,
213 Attribute::get(call->getContext(), "enzyme_newtrace"));
214#else
215 call->addAttribute(AttributeList::FunctionIndex,
216 Attribute::get(call->getContext(), "enzyme_newtrace"));
217
218#endif
219 return call;
220}
221
222CallInst *TraceUtils::FreeTrace(IRBuilder<> &Builder) {
223 auto call = Builder.CreateCall(interface->freeTraceTy(),
224 interface->freeTrace(Builder), {trace});
225#if LLVM_VERSION_MAJOR >= 14
226 call->addAttributeAtIndex(
227 AttributeList::FunctionIndex,
228 Attribute::get(call->getContext(), "enzyme_freetrace"));
229#else
230 call->addAttribute(AttributeList::FunctionIndex,
231 Attribute::get(call->getContext(), "enzyme_freetrace"));
232
233#endif
234 return call;
235}
236
237CallInst *TraceUtils::InsertChoice(IRBuilder<> &Builder, Value *address,
238 Value *score, Value *choice) {
239 Type *size_type = interface->insertChoiceTy()->getParamType(4);
240 auto &&[retval, sizeval] = ValueToVoidPtrAndSize(Builder, choice, size_type);
241
242 Value *args[] = {trace, address, score, retval, sizeval};
243
244 auto call = Builder.CreateCall(interface->insertChoiceTy(),
245 interface->insertChoice(Builder), args);
246 call->addParamAttr(1, Attribute::ReadOnly);
247
248 addCallSiteNoCapture(call, 1);
249 return call;
250}
251
252CallInst *TraceUtils::InsertCall(IRBuilder<> &Builder, Value *address,
253 Value *subtrace) {
254 Value *args[] = {trace, address, subtrace};
255
256 auto call = Builder.CreateCall(interface->insertCallTy(),
257 interface->insertCall(Builder), args);
258 call->addParamAttr(1, Attribute::ReadOnly);
259 addCallSiteNoCapture(call, 1);
260#if LLVM_VERSION_MAJOR >= 14
261 call->addAttributeAtIndex(
262 AttributeList::FunctionIndex,
263 Attribute::get(call->getContext(), "enzyme_insert_call"));
264#else
265 call->addAttribute(AttributeList::FunctionIndex,
266 Attribute::get(call->getContext(), "enzyme_insert_call"));
267
268#endif
269 return call;
270}
271
272CallInst *TraceUtils::InsertArgument(IRBuilder<> &Builder, Value *name,
273 Value *argument) {
274 Type *size_type = interface->insertArgumentTy()->getParamType(3);
275 auto &&[retval, sizeval] =
276 ValueToVoidPtrAndSize(Builder, argument, size_type);
277
278 Value *args[] = {trace, name, retval, sizeval};
279
280 auto call = Builder.CreateCall(interface->insertArgumentTy(),
281 interface->insertArgument(Builder), args);
282 call->addParamAttr(1, Attribute::ReadOnly);
283 addCallSiteNoCapture(call, 1);
284 return call;
285}
286
287CallInst *TraceUtils::InsertReturn(IRBuilder<> &Builder, Value *val) {
288 Type *size_type = interface->insertReturnTy()->getParamType(2);
289 auto &&[retval, sizeval] = ValueToVoidPtrAndSize(Builder, val, size_type);
290
291 Value *args[] = {trace, retval, sizeval};
292
293 auto call = Builder.CreateCall(interface->insertReturnTy(),
294 interface->insertReturn(Builder), args);
295 return call;
296}
297
298CallInst *TraceUtils::InsertFunction(IRBuilder<> &Builder, Function *function) {
299 assert(!function->isIntrinsic());
300 auto FunctionPtr =
301 Builder.CreateBitCast(function, getInt8PtrTy(function->getContext()));
302
303 Value *args[] = {trace, FunctionPtr};
304
305 auto call = Builder.CreateCall(interface->insertFunctionTy(),
306 interface->insertFunction(Builder), args);
307 return call;
308}
309
310CallInst *TraceUtils::InsertChoiceGradient(IRBuilder<> &Builder,
311 FunctionType *interface_type,
312 Value *interface_function,
313 Value *address, Value *choice,
314 Value *trace) {
315 Type *size_type = interface_type->getParamType(3);
316 auto &&[retval, sizeval] = ValueToVoidPtrAndSize(Builder, choice, size_type);
317
318 Value *args[] = {trace, address, retval, sizeval};
319
320 auto call = Builder.CreateCall(interface_type, interface_function, args);
321 call->addParamAttr(1, Attribute::ReadOnly);
322 addCallSiteNoCapture(call, 1);
323 return call;
324}
325
326CallInst *TraceUtils::InsertArgumentGradient(IRBuilder<> &Builder,
327 FunctionType *interface_type,
328 Value *interface_function,
329 Value *name, Value *argument,
330 Value *trace) {
331 Type *size_type = interface_type->getParamType(3);
332 auto &&[retval, sizeval] =
333 ValueToVoidPtrAndSize(Builder, argument, size_type);
334
335 Value *args[] = {trace, name, retval, sizeval};
336
337 auto call = Builder.CreateCall(interface_type, interface_function, args);
338 call->addParamAttr(1, Attribute::ReadOnly);
339 addCallSiteNoCapture(call, 1);
340 return call;
341}
342
343CallInst *TraceUtils::GetTrace(IRBuilder<> &Builder, Value *address,
344 const Twine &Name) {
345 assert(address->getType()->isPointerTy());
346
347 Value *args[] = {observations, address};
348
349 auto call = Builder.CreateCall(interface->getTraceTy(),
350 interface->getTrace(Builder), args, Name);
351 call->addParamAttr(1, Attribute::ReadOnly);
352 addCallSiteNoCapture(call, 1);
353 return call;
354}
355
356Instruction *TraceUtils::GetChoice(IRBuilder<> &Builder, Value *address,
357 Type *choiceType, const Twine &Name) {
358 IRBuilder<> AllocaBuilder(getFirstNonPHIOrDbgOrLifetime(
359 &Builder.GetInsertBlock()->getParent()->getEntryBlock()));
360 AllocaInst *store_dest =
361 AllocaBuilder.CreateAlloca(choiceType, nullptr, Name + ".ptr");
362 auto preallocated_size = choiceType->getPrimitiveSizeInBits() / 8;
363 Type *size_type = interface->getChoiceTy()->getParamType(3);
364
365 Value *args[] = {observations, address,
366 Builder.CreatePointerCast(
367 store_dest, getInt8PtrTy(store_dest->getContext())),
368 ConstantInt::get(size_type, preallocated_size)};
369
370 auto call =
371 Builder.CreateCall(interface->getChoiceTy(),
372 interface->getChoice(Builder), args, Name + ".size");
373
374#if LLVM_VERSION_MAJOR >= 14
375 call->addAttributeAtIndex(
376 AttributeList::FunctionIndex,
377 Attribute::get(call->getContext(), "enzyme_inactive"));
378#else
379 call->addAttribute(AttributeList::FunctionIndex,
380 Attribute::get(call->getContext(), "enzyme_inactive"));
381#endif
382 call->addParamAttr(1, Attribute::ReadOnly);
383 addCallSiteNoCapture(call, 1);
384 return Builder.CreateLoad(choiceType, store_dest, "from.trace." + Name);
385}
386
387Instruction *TraceUtils::HasChoice(IRBuilder<> &Builder, Value *address,
388 const Twine &Name) {
389 Value *args[]{observations, address};
390
391 auto call = Builder.CreateCall(interface->hasChoiceTy(),
392 interface->hasChoice(Builder), args, Name);
393 call->addParamAttr(1, Attribute::ReadOnly);
394 addCallSiteNoCapture(call, 1);
395 return call;
396}
397
398Instruction *TraceUtils::HasCall(IRBuilder<> &Builder, Value *address,
399 const Twine &Name) {
400 Value *args[]{observations, address};
401
402 auto call = Builder.CreateCall(interface->hasCallTy(),
403 interface->hasCall(Builder), args, Name);
404 call->addParamAttr(1, Attribute::ReadOnly);
405 addCallSiteNoCapture(call, 1);
406 return call;
407}
408
409Instruction *TraceUtils::SampleOrCondition(IRBuilder<> &Builder,
410 Function *sample_fn,
411 ArrayRef<Value *> sample_args,
412 Value *address, const Twine &Name) {
413 auto &Context = Builder.getContext();
414 auto parent_fn = Builder.GetInsertBlock()->getParent();
415
416 switch (mode) {
418 case ProbProgMode::Trace: {
419 auto sample_call = Builder.CreateCall(sample_fn->getFunctionType(),
420 sample_fn, sample_args);
421 return sample_call;
422 }
424 Instruction *hasChoice = HasChoice(Builder, address, "has.choice." + Name);
425
426 Value *ThenChoice, *ElseChoice;
427 BasicBlock *ThenBlock = BasicBlock::Create(
428 Context, "condition." + Name + ".with.trace", parent_fn);
429 BasicBlock *ElseBlock = BasicBlock::Create(
430 Context, "condition." + Name + ".without.trace", parent_fn);
431 BasicBlock *EndBlock = BasicBlock::Create(Context, "end", parent_fn);
432
433 Builder.CreateCondBr(hasChoice, ThenBlock, ElseBlock);
434 Builder.SetInsertPoint(ThenBlock);
435 ThenChoice = GetChoice(Builder, address, sample_fn->getReturnType(), Name);
436 Builder.CreateBr(EndBlock);
437
438 Builder.SetInsertPoint(ElseBlock);
439 ElseChoice = Builder.CreateCall(sample_fn->getFunctionType(), sample_fn,
440 sample_args, "sample." + Name);
441 Builder.CreateBr(EndBlock);
442
443 Builder.SetInsertPoint(EndBlock);
444 auto phi = Builder.CreatePHI(sample_fn->getReturnType(), 2);
445 phi->addIncoming(ThenChoice, ThenBlock);
446 phi->addIncoming(ElseChoice, ElseBlock);
447
448 return phi;
449 }
450 }
451 llvm_unreachable("Invalid sample_or_condition");
452}
453
455 IRBuilder<> &Builder,
456 function_ref<void(IRBuilder<> &, TraceUtils *, ArrayRef<Value *>)> Outlined,
457 Type *RetTy, ArrayRef<Value *> Arguments, bool needsLikelihood,
458 const Twine &Name) {
459 SmallVector<Type *, 4> Tys;
460 SmallVector<Value *, 4> Vals;
461 Module *M = Builder.GetInsertBlock()->getModule();
462
463 for (auto Arg : Arguments) {
464 Vals.push_back(Arg);
465 Tys.push_back(Arg->getType());
466 }
467
468 if (needsLikelihood) {
469 Vals.push_back(likelihood);
470 Tys.push_back(likelihood->getType());
471 }
472
474 Vals.push_back(observations);
475 Tys.push_back(observations->getType());
476 }
477
479 Vals.push_back(trace);
480 Tys.push_back(trace->getType());
481 }
482
483 FunctionType *FTy = FunctionType::get(RetTy, Tys, false);
484 Function *F =
485 Function::Create(FTy, Function::LinkageTypes::InternalLinkage, Name, M);
486 F->addFnAttr(Attribute::AlwaysInline);
487
488 auto Entry = BasicBlock::Create(M->getContext(), "entry", F);
489
490 auto ArgRange = make_pointer_range(
491 make_range(F->arg_begin(), F->arg_begin() + Arguments.size()));
492 SmallVector<Value *, 4> Rets(ArgRange);
493
494 auto idx = F->arg_begin() + Arguments.size();
495
496 Argument *likelihood_arg = nullptr;
497 if (needsLikelihood)
498 likelihood_arg = idx++;
499
500 Argument *observations_arg = nullptr;
502 observations_arg = idx++;
503
504 Argument *trace_arg = nullptr;
506 trace_arg = idx++;
507
508 TraceUtils OutlineTutils =
510 observations_arg, likelihood_arg, interface);
511 IRBuilder<> OutlineBuilder(Entry);
512 Outlined(OutlineBuilder, &OutlineTutils, Rets);
513
514 return Builder.CreateCall(FTy, F, Vals);
515}
516
517bool TraceUtils::isSampleCall(CallInst *call) {
518 auto F = getFunctionFromCall(call);
519 return sampleFunctions.count(F);
520}
521
522bool TraceUtils::isObserveCall(CallInst *call) {
523 auto F = getFunctionFromCall(call);
524 return observeFunctions.count(F);
525}
static Operation * getFunctionFromCall(CallOpInterface iface)
PointerType * traceType(LLVMContext &C)
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
Definition Utils.h:1174
ProbProgMode
Definition Utils.h:399
static void addCallSiteNoCapture(llvm::CallBase *call, size_t idx)
Definition Utils.h:2289
static llvm::Instruction * getFirstNonPHIOrDbgOrLifetime(llvm::BasicBlock *B)
Definition Utils.h:2281
virtual llvm::Value * insertFunction(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * insertChoice(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * newTrace(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * freeTrace(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * hasCallTy()
llvm::FunctionType * insertFunctionTy()
llvm::FunctionType * hasChoiceTy()
llvm::FunctionType * insertChoiceTy()
virtual llvm::Value * insertArgument(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * insertArgumentTy()
llvm::FunctionType * insertCallTy()
virtual llvm::Value * hasCall(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * insertReturn(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * freeTraceTy()
virtual llvm::Value * getChoice(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * hasChoice(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * insertReturnTy()
virtual llvm::Value * insertCall(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * getChoiceTy()
llvm::FunctionType * newTraceTy()
llvm::FunctionType * getTraceTy()
virtual llvm::Value * getTrace(llvm::IRBuilder<> &Builder)=0
llvm::CallInst * InsertFunction(llvm::IRBuilder<> &Builder, llvm::Function *function)
static TraceUtils * FromClone(ProbProgMode mode, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, TraceInterface *interface, llvm::Function *oldFunc, llvm::ValueMap< const llvm::Value *, llvm::WeakTrackingVH > &originalToNewFn)
static constexpr const char TraceParameterAttribute[]
Definition TraceUtils.h:57
llvm::CallInst * CreateTrace(llvm::IRBuilder<> &Builder, const llvm::Twine &Name="trace")
llvm::Instruction * SampleOrCondition(llvm::IRBuilder<> &Builder, llvm::Function *sample_fn, llvm::ArrayRef< llvm::Value * > sample_args, llvm::Value *address, const llvm::Twine &Name="")
TraceUtils(ProbProgMode mode, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, llvm::Function *newFunc, llvm::Argument *trace, llvm::Argument *observations, llvm::Argument *likelihood, TraceInterface *interface)
llvm::CallInst * InsertChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Value *score, llvm::Value *choice)
llvm::Instruction * GetChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Type *choiceType, const llvm::Twine &Name="")
llvm::CallInst * InsertReturn(llvm::IRBuilder<> &Builder, llvm::Value *ret)
TraceInterface * getTraceInterface()
static llvm::CallInst * InsertChoiceGradient(llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *address, llvm::Value *choice, llvm::Value *trace)
TraceInterface * interface
Definition TraceUtils.h:51
llvm::CallInst * FreeTrace(llvm::IRBuilder<> &Builder)
llvm::Instruction * HasCall(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
static constexpr const char ObservationsParameterAttribute[]
Definition TraceUtils.h:58
llvm::Instruction * HasChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
llvm::SmallPtrSet< llvm::Function *, 4 > observeFunctions
Definition TraceUtils.h:55
llvm::CallInst * CreateOutlinedFunction(llvm::IRBuilder<> &Builder, llvm::function_ref< void(llvm::IRBuilder<> &, TraceUtils *, llvm::ArrayRef< llvm::Value * >)> Outlined, llvm::Type *RetTy, llvm::ArrayRef< llvm::Value * > Arguments, bool needsLikelihood=true, const llvm::Twine &Name="")
static constexpr const char LikelihoodParameterAttribute[]
Definition TraceUtils.h:60
static llvm::CallInst * InsertArgumentGradient(llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *name, llvm::Value *argument, llvm::Value *trace)
llvm::Function * newFunc
Definition TraceUtils.h:53
llvm::SmallPtrSet< llvm::Function *, 4 > sampleFunctions
Definition TraceUtils.h:54
llvm::Value * getLikelihood()
bool isObserveCall(llvm::CallInst *call)
ProbProgMode mode
Definition TraceUtils.h:52
llvm::CallInst * InsertArgument(llvm::IRBuilder<> &Builder, llvm::Value *name, llvm::Value *argument)
llvm::CallInst * GetTrace(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
bool isSampleCall(llvm::CallInst *call)
llvm::Value * getObservations()
llvm::CallInst * InsertCall(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Value *subtrace)
llvm::Value * getTrace()