Enzyme main
Loading...
Searching...
No Matches
TraceUtils.h
Go to the documentation of this file.
1//===- TraceUtils.h - Utilities for interacting with traces --------------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains utilities for interacting with probabilistic programming
22// traces through the probabilistic programming trace interface
23// (TraceInterface).
24//
25//===----------------------------------------------------------------------===//
26
27#ifndef TraceUtils_h
28#define TraceUtils_h
29
30#include "llvm/ADT/SmallPtrSet.h"
31#include "llvm/ADT/StringSet.h"
32#include "llvm/IR/BasicBlock.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/IR/Type.h"
36#include "llvm/IR/User.h"
37#include "llvm/IR/Value.h"
38#include "llvm/IR/ValueMap.h"
39
40#include "TraceInterface.h"
41#include "Utils.h"
42
44
45private:
46 llvm::Value *trace;
47 llvm::Value *observations;
48 llvm::Value *likelihood;
49
50public:
53 llvm::Function *newFunc;
54 llvm::SmallPtrSet<llvm::Function *, 4> sampleFunctions;
55 llvm::SmallPtrSet<llvm::Function *, 4> observeFunctions;
56
57 constexpr static const char TraceParameterAttribute[] = "enzyme_trace";
58 constexpr static const char ObservationsParameterAttribute[] =
59 "enzyme_observations";
60 constexpr static const char LikelihoodParameterAttribute[] =
61 "enzyme_likelihood";
62
63public:
65 const llvm::SmallPtrSetImpl<llvm::Function *> &sampleFunctions,
66 const llvm::SmallPtrSetImpl<llvm::Function *> &observeFunctions,
67 llvm::Function *newFunc, llvm::Argument *trace,
68 llvm::Argument *observations, llvm::Argument *likelihood,
70
71 static TraceUtils *
73 const llvm::SmallPtrSetImpl<llvm::Function *> &sampleFunctions,
74 const llvm::SmallPtrSetImpl<llvm::Function *> &observeFunctions,
75 TraceInterface *interface, llvm::Function *oldFunc,
76 llvm::ValueMap<const llvm::Value *, llvm::WeakTrackingVH>
77 &originalToNewFn);
78
80
81private:
82 static std::pair<llvm::Value *, llvm::Constant *>
83 ValueToVoidPtrAndSize(llvm::IRBuilder<> &Builder, llvm::Value *val,
84 llvm::Type *size_type);
85
86public:
88
89 llvm::Value *getTrace();
90
91 llvm::Value *getObservations();
92
93 llvm::Value *getLikelihood();
94
95 llvm::CallInst *CreateTrace(llvm::IRBuilder<> &Builder,
96 const llvm::Twine &Name = "trace");
97
98 llvm::CallInst *FreeTrace(llvm::IRBuilder<> &Builder);
99
100 llvm::CallInst *InsertChoice(llvm::IRBuilder<> &Builder, llvm::Value *address,
101 llvm::Value *score, llvm::Value *choice);
102
103 llvm::CallInst *InsertCall(llvm::IRBuilder<> &Builder, llvm::Value *address,
104 llvm::Value *subtrace);
105
106 llvm::CallInst *InsertArgument(llvm::IRBuilder<> &Builder, llvm::Value *name,
107 llvm::Value *argument);
108
109 llvm::CallInst *InsertReturn(llvm::IRBuilder<> &Builder, llvm::Value *ret);
110
111 llvm::CallInst *InsertFunction(llvm::IRBuilder<> &Builder,
112 llvm::Function *function);
113
114 static llvm::CallInst *
115 InsertChoiceGradient(llvm::IRBuilder<> &Builder,
116 llvm::FunctionType *interface_type,
117 llvm::Value *interface_function, llvm::Value *address,
118 llvm::Value *choice, llvm::Value *trace);
119
120 static llvm::CallInst *
121 InsertArgumentGradient(llvm::IRBuilder<> &Builder,
122 llvm::FunctionType *interface_type,
123 llvm::Value *interface_function, llvm::Value *name,
124 llvm::Value *argument, llvm::Value *trace);
125
126 llvm::CallInst *GetTrace(llvm::IRBuilder<> &Builder, llvm::Value *address,
127 const llvm::Twine &Name = "");
128
129 llvm::Instruction *GetChoice(llvm::IRBuilder<> &Builder, llvm::Value *address,
130 llvm::Type *choiceType,
131 const llvm::Twine &Name = "");
132
133 llvm::Instruction *HasChoice(llvm::IRBuilder<> &Builder, llvm::Value *address,
134 const llvm::Twine &Name = "");
135
136 llvm::Instruction *HasCall(llvm::IRBuilder<> &Builder, llvm::Value *address,
137 const llvm::Twine &Name = "");
138
139 llvm::Instruction *
140 SampleOrCondition(llvm::IRBuilder<> &Builder, llvm::Function *sample_fn,
141 llvm::ArrayRef<llvm::Value *> sample_args,
142 llvm::Value *address, const llvm::Twine &Name = "");
143
144 llvm::CallInst *CreateOutlinedFunction(
145 llvm::IRBuilder<> &Builder,
146 llvm::function_ref<void(llvm::IRBuilder<> &, TraceUtils *,
147 llvm::ArrayRef<llvm::Value *>)>
148 Outlined,
149 llvm::Type *RetTy, llvm::ArrayRef<llvm::Value *> Arguments,
150 bool needsLikelihood = true, const llvm::Twine &Name = "");
151
152 bool isSampleCall(llvm::CallInst *call);
153
154 bool isObserveCall(llvm::CallInst *call);
155};
156
157#endif /* TraceUtils_h */
ProbProgMode
Definition Utils.h:399
llvm::CallInst * InsertFunction(llvm::IRBuilder<> &Builder, llvm::Function *function)
static TraceUtils * FromClone(ProbProgMode mode, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, TraceInterface *interface, llvm::Function *oldFunc, llvm::ValueMap< const llvm::Value *, llvm::WeakTrackingVH > &originalToNewFn)
static constexpr const char TraceParameterAttribute[]
Definition TraceUtils.h:57
llvm::CallInst * CreateTrace(llvm::IRBuilder<> &Builder, const llvm::Twine &Name="trace")
llvm::Instruction * SampleOrCondition(llvm::IRBuilder<> &Builder, llvm::Function *sample_fn, llvm::ArrayRef< llvm::Value * > sample_args, llvm::Value *address, const llvm::Twine &Name="")
TraceUtils(ProbProgMode mode, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, llvm::Function *newFunc, llvm::Argument *trace, llvm::Argument *observations, llvm::Argument *likelihood, TraceInterface *interface)
llvm::CallInst * InsertChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Value *score, llvm::Value *choice)
llvm::Instruction * GetChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Type *choiceType, const llvm::Twine &Name="")
llvm::CallInst * InsertReturn(llvm::IRBuilder<> &Builder, llvm::Value *ret)
TraceInterface * getTraceInterface()
static llvm::CallInst * InsertChoiceGradient(llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *address, llvm::Value *choice, llvm::Value *trace)
TraceInterface * interface
Definition TraceUtils.h:51
llvm::CallInst * FreeTrace(llvm::IRBuilder<> &Builder)
llvm::Instruction * HasCall(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
static constexpr const char ObservationsParameterAttribute[]
Definition TraceUtils.h:58
llvm::Instruction * HasChoice(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
llvm::SmallPtrSet< llvm::Function *, 4 > observeFunctions
Definition TraceUtils.h:55
llvm::CallInst * CreateOutlinedFunction(llvm::IRBuilder<> &Builder, llvm::function_ref< void(llvm::IRBuilder<> &, TraceUtils *, llvm::ArrayRef< llvm::Value * >)> Outlined, llvm::Type *RetTy, llvm::ArrayRef< llvm::Value * > Arguments, bool needsLikelihood=true, const llvm::Twine &Name="")
static constexpr const char LikelihoodParameterAttribute[]
Definition TraceUtils.h:60
static llvm::CallInst * InsertArgumentGradient(llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *name, llvm::Value *argument, llvm::Value *trace)
llvm::Function * newFunc
Definition TraceUtils.h:53
llvm::SmallPtrSet< llvm::Function *, 4 > sampleFunctions
Definition TraceUtils.h:54
llvm::Value * getLikelihood()
bool isObserveCall(llvm::CallInst *call)
ProbProgMode mode
Definition TraceUtils.h:52
llvm::CallInst * InsertArgument(llvm::IRBuilder<> &Builder, llvm::Value *name, llvm::Value *argument)
llvm::CallInst * GetTrace(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
bool isSampleCall(llvm::CallInst *call)
llvm::Value * getObservations()
llvm::CallInst * InsertCall(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Value *subtrace)
llvm::Value * getTrace()