Enzyme main
Loading...
Searching...
No Matches
TraceInterface.h
Go to the documentation of this file.
1//===- TraceInterface.h - Interact with probabilistic programming traces
2//---===//
3//
4// Enzyme Project
5//
6// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
7// Exceptions. See https://llvm.org/LICENSE.txt for license information.
8// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9//
10// If using this code in an academic setting, please cite the following:
11// @incollection{enzymeNeurips,
12// title = {Instead of Rewriting Foreign Code for Machine Learning,
13// Automatically Synthesize Fast Gradients},
14// author = {Moses, William S. and Churavy, Valentin},
15// booktitle = {Advances in Neural Information Processing Systems 33},
16// year = {2020},
17// note = {To appear in},
18// }
19//
20//===----------------------------------------------------------------------===//
21//
22// This file contains an abstraction for static and dynamic implementations of
23// the probabilistic programming interface.
24//
25//===----------------------------------------------------------------------===//
26
27#ifndef TraceInterface_h
28#define TraceInterface_h
29
30#include "llvm/IR/IRBuilder.h"
31#include "llvm/IR/Module.h"
32#include "llvm/IR/Type.h"
33#include "llvm/IR/Value.h"
34
36private:
37 llvm::LLVMContext &C;
38
39public:
40 TraceInterface(llvm::LLVMContext &C);
41
42 virtual ~TraceInterface() = default;
43
44public:
45 // user implemented
46 virtual llvm::Value *getTrace(llvm::IRBuilder<> &Builder) = 0;
47 virtual llvm::Value *getChoice(llvm::IRBuilder<> &Builder) = 0;
48 virtual llvm::Value *insertCall(llvm::IRBuilder<> &Builder) = 0;
49 virtual llvm::Value *insertChoice(llvm::IRBuilder<> &Builder) = 0;
50
51 virtual llvm::Value *insertArgument(llvm::IRBuilder<> &Builder) = 0;
52 virtual llvm::Value *insertReturn(llvm::IRBuilder<> &Builder) = 0;
53 virtual llvm::Value *insertFunction(llvm::IRBuilder<> &Builder) = 0;
54 virtual llvm::Value *insertChoiceGradient(llvm::IRBuilder<> &Builder) = 0;
55 virtual llvm::Value *insertArgumentGradient(llvm::IRBuilder<> &Builder) = 0;
56
57 virtual llvm::Value *newTrace(llvm::IRBuilder<> &Builder) = 0;
58 virtual llvm::Value *freeTrace(llvm::IRBuilder<> &Builder) = 0;
59 virtual llvm::Value *hasCall(llvm::IRBuilder<> &Builder) = 0;
60 virtual llvm::Value *hasChoice(llvm::IRBuilder<> &Builder) = 0;
61
62public:
63 static llvm::IntegerType *sizeType(llvm::LLVMContext &C);
64 static llvm::Type *stringType(llvm::LLVMContext &C);
65
66public:
67 llvm::FunctionType *getTraceTy();
68 llvm::FunctionType *getChoiceTy();
69 llvm::FunctionType *insertCallTy();
70 llvm::FunctionType *insertChoiceTy();
71
72 llvm::FunctionType *insertArgumentTy();
73 llvm::FunctionType *insertReturnTy();
74 llvm::FunctionType *insertFunctionTy();
75 llvm::FunctionType *insertChoiceGradientTy();
76 llvm::FunctionType *insertArgumentGradientTy();
77
78 llvm::FunctionType *newTraceTy();
79 llvm::FunctionType *freeTraceTy();
80 llvm::FunctionType *hasCallTy();
81 llvm::FunctionType *hasChoiceTy();
82
83 static llvm::FunctionType *getTraceTy(llvm::LLVMContext &C);
84 static llvm::FunctionType *getChoiceTy(llvm::LLVMContext &C);
85 static llvm::FunctionType *insertCallTy(llvm::LLVMContext &C);
86 static llvm::FunctionType *insertChoiceTy(llvm::LLVMContext &C);
87
88 static llvm::FunctionType *insertArgumentTy(llvm::LLVMContext &C);
89 static llvm::FunctionType *insertReturnTy(llvm::LLVMContext &C);
90 static llvm::FunctionType *insertFunctionTy(llvm::LLVMContext &C);
91 static llvm::FunctionType *insertChoiceGradientTy(llvm::LLVMContext &C);
92 static llvm::FunctionType *insertArgumentGradientTy(llvm::LLVMContext &C);
93
94 static llvm::FunctionType *newTraceTy(llvm::LLVMContext &C);
95 static llvm::FunctionType *freeTraceTy(llvm::LLVMContext &C);
96 static llvm::FunctionType *hasCallTy(llvm::LLVMContext &C);
97 static llvm::FunctionType *hasChoiceTy(llvm::LLVMContext &C);
98};
99
101private:
102 llvm::Function *getTraceFunction = nullptr;
103 llvm::Function *getChoiceFunction = nullptr;
104 llvm::Function *insertCallFunction = nullptr;
105 llvm::Function *insertChoiceFunction = nullptr;
106 llvm::Function *insertArgumentFunction = nullptr;
107 llvm::Function *insertReturnFunction = nullptr;
108 llvm::Function *insertFunctionFunction = nullptr;
109 llvm::Function *insertChoiceGradientFunction = nullptr;
110 llvm::Function *insertArgumentGradientFunction = nullptr;
111 llvm::Function *newTraceFunction = nullptr;
112 llvm::Function *freeTraceFunction = nullptr;
113 llvm::Function *hasCallFunction = nullptr;
114 llvm::Function *hasChoiceFunction = nullptr;
115
116public:
117 StaticTraceInterface(llvm::Module *M);
118
119 StaticTraceInterface(llvm::LLVMContext &C, llvm::Function *getTraceFunction,
120 llvm::Function *getChoiceFunction,
121 llvm::Function *insertCallFunction,
122 llvm::Function *insertChoiceFunction,
123 llvm::Function *insertArgumentFunction,
124 llvm::Function *insertReturnFunction,
125 llvm::Function *insertFunctionFunction,
126 llvm::Function *insertChoiceGradientFunction,
127 llvm::Function *insertArgumentGradientFunction,
128 llvm::Function *newTraceFunction,
129 llvm::Function *freeTraceFunction,
130 llvm::Function *hasCallFunction,
131 llvm::Function *hasChoiceFunction);
132
134
135public:
136 // user implemented
137 llvm::Value *getTrace(llvm::IRBuilder<> &Builder);
138 llvm::Value *getChoice(llvm::IRBuilder<> &Builder);
139 llvm::Value *insertCall(llvm::IRBuilder<> &Builder);
140 llvm::Value *insertChoice(llvm::IRBuilder<> &Builder);
141 llvm::Value *insertArgument(llvm::IRBuilder<> &Builder);
142 llvm::Value *insertReturn(llvm::IRBuilder<> &Builder);
143 llvm::Value *insertFunction(llvm::IRBuilder<> &Builder);
144 llvm::Value *insertChoiceGradient(llvm::IRBuilder<> &Builder);
145 llvm::Value *insertArgumentGradient(llvm::IRBuilder<> &Builder);
146 llvm::Value *newTrace(llvm::IRBuilder<> &Builder);
147 llvm::Value *freeTrace(llvm::IRBuilder<> &Builder);
148 llvm::Value *hasCall(llvm::IRBuilder<> &Builder);
149 llvm::Value *hasChoice(llvm::IRBuilder<> &Builder);
150};
151
153private:
154 llvm::Function *getTraceFunction;
155 llvm::Function *getChoiceFunction;
156 llvm::Function *insertCallFunction;
157 llvm::Function *insertChoiceFunction;
158 llvm::Function *insertArgumentFunction;
159 llvm::Function *insertReturnFunction;
160 llvm::Function *insertFunctionFunction;
161 llvm::Function *insertChoiceGradientFunction;
162 llvm::Function *insertArgumentGradientFunction;
163 llvm::Function *newTraceFunction;
164 llvm::Function *freeTraceFunction;
165 llvm::Function *hasCallFunction;
166 llvm::Function *hasChoiceFunction;
167
168public:
169 DynamicTraceInterface(llvm::Value *dynamicInterface, llvm::Function *F);
170
172
173private:
174 llvm::Function *MaterializeInterfaceFunction(llvm::IRBuilder<> &Builder,
175 llvm::Value *,
176 llvm::FunctionType *,
177 unsigned index, llvm::Module &M,
178 const llvm::Twine &Name = "");
179
180public:
181 // user implemented
182 llvm::Value *getTrace(llvm::IRBuilder<> &Builder);
183 llvm::Value *getChoice(llvm::IRBuilder<> &Builder);
184 llvm::Value *insertCall(llvm::IRBuilder<> &Builder);
185 llvm::Value *insertChoice(llvm::IRBuilder<> &Builder);
186 llvm::Value *insertArgument(llvm::IRBuilder<> &Builder);
187 llvm::Value *insertReturn(llvm::IRBuilder<> &Builder);
188 llvm::Value *insertFunction(llvm::IRBuilder<> &Builder);
189 llvm::Value *insertChoiceGradient(llvm::IRBuilder<> &Builder);
190 llvm::Value *insertArgumentGradient(llvm::IRBuilder<> &Builder);
191 llvm::Value *newTrace(llvm::IRBuilder<> &Builder);
192 llvm::Value *freeTrace(llvm::IRBuilder<> &Builder);
193 llvm::Value *hasCall(llvm::IRBuilder<> &Builder);
194 llvm::Value *hasChoice(llvm::IRBuilder<> &Builder);
195};
196
197#endif
llvm::Value * insertArgumentGradient(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoiceGradient(llvm::IRBuilder<> &Builder)
DynamicTraceInterface(llvm::Value *dynamicInterface, llvm::Function *F)
llvm::Value * hasChoice(llvm::IRBuilder<> &Builder)
llvm::Value * hasCall(llvm::IRBuilder<> &Builder)
llvm::Value * getChoice(llvm::IRBuilder<> &Builder)
llvm::Value * freeTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoice(llvm::IRBuilder<> &Builder)
llvm::Value * insertFunction(llvm::IRBuilder<> &Builder)
llvm::Value * insertReturn(llvm::IRBuilder<> &Builder)
llvm::Value * getTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertArgument(llvm::IRBuilder<> &Builder)
~DynamicTraceInterface()=default
llvm::Value * insertCall(llvm::IRBuilder<> &Builder)
llvm::Value * newTrace(llvm::IRBuilder<> &Builder)
llvm::Value * getChoice(llvm::IRBuilder<> &Builder)
llvm::Value * insertCall(llvm::IRBuilder<> &Builder)
StaticTraceInterface(llvm::LLVMContext &C, llvm::Function *getTraceFunction, llvm::Function *getChoiceFunction, llvm::Function *insertCallFunction, llvm::Function *insertChoiceFunction, llvm::Function *insertArgumentFunction, llvm::Function *insertReturnFunction, llvm::Function *insertFunctionFunction, llvm::Function *insertChoiceGradientFunction, llvm::Function *insertArgumentGradientFunction, llvm::Function *newTraceFunction, llvm::Function *freeTraceFunction, llvm::Function *hasCallFunction, llvm::Function *hasChoiceFunction)
llvm::Value * insertArgument(llvm::IRBuilder<> &Builder)
llvm::Value * insertReturn(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoiceGradient(llvm::IRBuilder<> &Builder)
llvm::Value * freeTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertArgumentGradient(llvm::IRBuilder<> &Builder)
llvm::Value * newTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoice(llvm::IRBuilder<> &Builder)
llvm::Value * getTrace(llvm::IRBuilder<> &Builder)
~StaticTraceInterface()=default
llvm::Value * hasChoice(llvm::IRBuilder<> &Builder)
StaticTraceInterface(llvm::Module *M)
llvm::Value * insertFunction(llvm::IRBuilder<> &Builder)
llvm::Value * hasCall(llvm::IRBuilder<> &Builder)
static llvm::FunctionType * getChoiceTy(llvm::LLVMContext &C)
static llvm::FunctionType * insertReturnTy(llvm::LLVMContext &C)
virtual llvm::Value * insertFunction(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * insertChoice(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * newTrace(llvm::IRBuilder<> &Builder)=0
static llvm::FunctionType * insertArgumentTy(llvm::LLVMContext &C)
static llvm::FunctionType * insertChoiceTy(llvm::LLVMContext &C)
llvm::FunctionType * insertArgumentGradientTy()
virtual llvm::Value * freeTrace(llvm::IRBuilder<> &Builder)=0
TraceInterface(llvm::LLVMContext &C)
static llvm::FunctionType * hasCallTy(llvm::LLVMContext &C)
llvm::FunctionType * hasCallTy()
llvm::FunctionType * insertFunctionTy()
llvm::FunctionType * hasChoiceTy()
virtual ~TraceInterface()=default
static llvm::FunctionType * newTraceTy(llvm::LLVMContext &C)
llvm::FunctionType * insertChoiceTy()
llvm::FunctionType * insertChoiceGradientTy()
virtual llvm::Value * insertArgument(llvm::IRBuilder<> &Builder)=0
static llvm::FunctionType * insertFunctionTy(llvm::LLVMContext &C)
llvm::FunctionType * insertArgumentTy()
static llvm::FunctionType * hasChoiceTy(llvm::LLVMContext &C)
static llvm::Type * stringType(llvm::LLVMContext &C)
llvm::FunctionType * insertCallTy()
static llvm::FunctionType * freeTraceTy(llvm::LLVMContext &C)
virtual llvm::Value * hasCall(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * insertReturn(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * freeTraceTy()
virtual llvm::Value * insertArgumentGradient(llvm::IRBuilder<> &Builder)=0
static llvm::IntegerType * sizeType(llvm::LLVMContext &C)
static llvm::FunctionType * insertCallTy(llvm::LLVMContext &C)
virtual llvm::Value * getChoice(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * hasChoice(llvm::IRBuilder<> &Builder)=0
static llvm::FunctionType * insertArgumentGradientTy(llvm::LLVMContext &C)
llvm::FunctionType * insertReturnTy()
virtual llvm::Value * insertChoiceGradient(llvm::IRBuilder<> &Builder)=0
static llvm::FunctionType * getTraceTy(llvm::LLVMContext &C)
virtual llvm::Value * insertCall(llvm::IRBuilder<> &Builder)=0
static llvm::FunctionType * insertChoiceGradientTy(llvm::LLVMContext &C)
llvm::FunctionType * getChoiceTy()
llvm::FunctionType * newTraceTy()
llvm::FunctionType * getTraceTy()
virtual llvm::Value * getTrace(llvm::IRBuilder<> &Builder)=0