Enzyme main
Loading...
Searching...
No Matches
TraceUtils Class Reference

#include "TraceUtils.h"

Collaboration diagram for TraceUtils:

Public Member Functions

 TraceUtils (ProbProgMode mode, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, llvm::Function *newFunc, llvm::Argument *trace, llvm::Argument *observations, llvm::Argument *likelihood, TraceInterface *interface)
 
 ~TraceUtils ()
 
TraceInterface * getTraceInterface ()
 
llvm::Value * getTrace ()
 
llvm::Value * getObservations ()
 
llvm::Value * getLikelihood ()
 
llvm::CallInst * CreateTrace (llvm::IRBuilder<> &Builder, const llvm::Twine &Name="trace")
 
llvm::CallInst * FreeTrace (llvm::IRBuilder<> &Builder)
 
llvm::CallInst * InsertChoice (llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Value *score, llvm::Value *choice)
 
llvm::CallInst * InsertCall (llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Value *subtrace)
 
llvm::CallInst * InsertArgument (llvm::IRBuilder<> &Builder, llvm::Value *name, llvm::Value *argument)
 
llvm::CallInst * InsertReturn (llvm::IRBuilder<> &Builder, llvm::Value *ret)
 
llvm::CallInst * InsertFunction (llvm::IRBuilder<> &Builder, llvm::Function *function)
 
llvm::CallInst * GetTrace (llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
 
llvm::Instruction * GetChoice (llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Type *choiceType, const llvm::Twine &Name="")
 
llvm::Instruction * HasChoice (llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
 
llvm::Instruction * HasCall (llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
 
llvm::Instruction * SampleOrCondition (llvm::IRBuilder<> &Builder, llvm::Function *sample_fn, llvm::ArrayRef< llvm::Value * > sample_args, llvm::Value *address, const llvm::Twine &Name="")
 
llvm::CallInst * CreateOutlinedFunction (llvm::IRBuilder<> &Builder, llvm::function_ref< void(llvm::IRBuilder<> &, TraceUtils *, llvm::ArrayRef< llvm::Value * >)> Outlined, llvm::Type *RetTy, llvm::ArrayRef< llvm::Value * > Arguments, bool needsLikelihood=true, const llvm::Twine &Name="")
 
bool isSampleCall (llvm::CallInst *call)
 
bool isObserveCall (llvm::CallInst *call)
 

Static Public Member Functions

static TraceUtils * FromClone (ProbProgMode mode, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, TraceInterface *interface, llvm::Function *oldFunc, llvm::ValueMap< const llvm::Value *, llvm::WeakTrackingVH > &originalToNewFn)
 
static llvm::CallInst * InsertChoiceGradient (llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *address, llvm::Value *choice, llvm::Value *trace)
 
static llvm::CallInst * InsertArgumentGradient (llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *name, llvm::Value *argument, llvm::Value *trace)
 

Public Attributes

TraceInterface * interface
 
ProbProgMode mode
 
llvm::Function * newFunc
 
llvm::SmallPtrSet< llvm::Function *, 4 > sampleFunctions
 
llvm::SmallPtrSet< llvm::Function *, 4 > observeFunctions
 

Static Public Attributes

static constexpr const char TraceParameterAttribute [] = "enzyme_trace"
 
static constexpr const char ObservationsParameterAttribute [] = "enzyme_observations"
 
static constexpr const char LikelihoodParameterAttribute [] = "enzyme_likelihood"
 

Detailed Description

Definition at line 43 of file TraceUtils.h.

Constructor & Destructor Documentation

◆ TraceUtils()

TraceUtils::TraceUtils ( ProbProgMode mode,
const llvm::SmallPtrSetImpl< llvm::Function * > & sampleFunctions,
const llvm::SmallPtrSetImpl< llvm::Function * > & observeFunctions,
llvm::Function * newFunc,
llvm::Argument * trace,
llvm::Argument * observations,
llvm::Argument * likelihood,
TraceInterface * interface )

Definition at line 47 of file TraceUtils.cpp.

Referenced by CreateOutlinedFunction(), and FromClone().

◆ ~TraceUtils()

TraceUtils::~TraceUtils ( )
default

Member Function Documentation

◆ CreateOutlinedFunction()

CallInst * TraceUtils::CreateOutlinedFunction ( llvm::IRBuilder<> & Builder,
llvm::function_ref< void(llvm::IRBuilder<> &, TraceUtils *, llvm::ArrayRef< llvm::Value * >)> Outlined,
llvm::Type * RetTy,
llvm::ArrayRef< llvm::Value * > Arguments,
bool needsLikelihood = true,
const llvm::Twine & Name = "" )

◆ CreateTrace()

CallInst * TraceUtils::CreateTrace ( llvm::IRBuilder<> & Builder,
const llvm::Twine & Name = "trace" )

◆ FreeTrace()

CallInst * TraceUtils::FreeTrace ( llvm::IRBuilder<> & Builder)

◆ FromClone()

TraceUtils * TraceUtils::FromClone ( ProbProgMode mode,
const llvm::SmallPtrSetImpl< llvm::Function * > & sampleFunctions,
const llvm::SmallPtrSetImpl< llvm::Function * > & observeFunctions,
TraceInterface * interface,
llvm::Function * oldFunc,
llvm::ValueMap< const llvm::Value *, llvm::WeakTrackingVH > & originalToNewFn )
static

◆ GetChoice()

Instruction * TraceUtils::GetChoice ( llvm::IRBuilder<> & Builder,
llvm::Value * address,
llvm::Type * choiceType,
const llvm::Twine & Name = "" )

◆ getLikelihood()

Value * TraceUtils::getLikelihood ( )

◆ getObservations()

Value * TraceUtils::getObservations ( )

Definition at line 169 of file TraceUtils.cpp.

◆ GetTrace()

CallInst * TraceUtils::GetTrace ( llvm::IRBuilder<> & Builder,
llvm::Value * address,
const llvm::Twine & Name = "" )

◆ getTrace()

Value * TraceUtils::getTrace ( )

Definition at line 167 of file TraceUtils.cpp.

◆ getTraceInterface()

TraceInterface * TraceUtils::getTraceInterface ( )

Definition at line 165 of file TraceUtils.cpp.

References interface.

Referenced by TraceGenerator::handleArbitraryCall().

◆ HasCall()

Instruction * TraceUtils::HasCall ( llvm::IRBuilder<> & Builder,
llvm::Value * address,
const llvm::Twine & Name = "" )

◆ HasChoice()

Instruction * TraceUtils::HasChoice ( llvm::IRBuilder<> & Builder,
llvm::Value * address,
const llvm::Twine & Name = "" )

◆ InsertArgument()

CallInst * TraceUtils::InsertArgument ( llvm::IRBuilder<> & Builder,
llvm::Value * name,
llvm::Value * argument )

◆ InsertArgumentGradient()

CallInst * TraceUtils::InsertArgumentGradient ( llvm::IRBuilder<> & Builder,
llvm::FunctionType * interface_type,
llvm::Value * interface_function,
llvm::Value * name,
llvm::Value * argument,
llvm::Value * trace )
static

Definition at line 326 of file TraceUtils.cpp.

References addCallSiteNoCapture().

◆ InsertCall()

CallInst * TraceUtils::InsertCall ( llvm::IRBuilder<> & Builder,
llvm::Value * address,
llvm::Value * subtrace )

◆ InsertChoice()

CallInst * TraceUtils::InsertChoice ( llvm::IRBuilder<> & Builder,
llvm::Value * address,
llvm::Value * score,
llvm::Value * choice )

◆ InsertChoiceGradient()

CallInst * TraceUtils::InsertChoiceGradient ( llvm::IRBuilder<> & Builder,
llvm::FunctionType * interface_type,
llvm::Value * interface_function,
llvm::Value * address,
llvm::Value * choice,
llvm::Value * trace )
static

Definition at line 310 of file TraceUtils.cpp.

References addCallSiteNoCapture().

◆ InsertFunction()

CallInst * TraceUtils::InsertFunction ( llvm::IRBuilder<> & Builder,
llvm::Function * function )

◆ InsertReturn()

CallInst * TraceUtils::InsertReturn ( llvm::IRBuilder<> & Builder,
llvm::Value * ret )

◆ isObserveCall()

bool TraceUtils::isObserveCall ( llvm::CallInst * call)

Definition at line 522 of file TraceUtils.cpp.

References getFunctionFromCall(), and observeFunctions.

Referenced by TraceGenerator::visitCallInst().

◆ isSampleCall()

bool TraceUtils::isSampleCall ( llvm::CallInst * call)

Definition at line 517 of file TraceUtils.cpp.

References getFunctionFromCall(), and sampleFunctions.

Referenced by TraceGenerator::visitCallInst().

◆ SampleOrCondition()

Instruction * TraceUtils::SampleOrCondition ( llvm::IRBuilder<> & Builder,
llvm::Function * sample_fn,
llvm::ArrayRef< llvm::Value * > sample_args,
llvm::Value * address,
const llvm::Twine & Name = "" )

Definition at line 409 of file TraceUtils.cpp.

References Condition, GetChoice(), HasChoice(), Likelihood, mode, and Trace.

Member Data Documentation

◆ interface

TraceInterface* TraceUtils::interface

◆ LikelihoodParameterAttribute

const char TraceUtils::LikelihoodParameterAttribute[]
static constexpr
Initial value: = "enzyme_likelihood"

Definition at line 60 of file TraceUtils.h.

Referenced by FromClone(), and TraceGenerator::visitFunction().

◆ mode

ProbProgMode TraceUtils::mode

Definition at line 52 of file TraceUtils.h.

Referenced by CreateOutlinedFunction(), FromClone(), and SampleOrCondition().

◆ newFunc

llvm::Function* TraceUtils::newFunc

Definition at line 53 of file TraceUtils.h.

Referenced by FromClone(), and TraceGenerator::visitFunction().

◆ ObservationsParameterAttribute

const char TraceUtils::ObservationsParameterAttribute[]
static constexpr
Initial value: = "enzyme_observations"

Definition at line 58 of file TraceUtils.h.

Referenced by FromClone(), and TraceGenerator::visitFunction().

◆ observeFunctions

llvm::SmallPtrSet<llvm::Function *, 4> TraceUtils::observeFunctions

◆ sampleFunctions

llvm::SmallPtrSet<llvm::Function *, 4> TraceUtils::sampleFunctions

◆ TraceParameterAttribute

const char TraceUtils::TraceParameterAttribute[] = "enzyme_trace"
static constexpr

Definition at line 57 of file TraceUtils.h.

Referenced by FromClone(), and TraceGenerator::visitFunction().


The documentation for this class was generated from the following files:

- TraceUtils.h
- TraceUtils.cpp