Enzyme main
Loading...
Searching...
No Matches
TraceInterface.cpp
Go to the documentation of this file.
1//===- TraceInterface.cpp - Interact with probabilistic programming traces
2//---===//
3//
4// Enzyme Project
5//
6// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
7// Exceptions. See https://llvm.org/LICENSE.txt for license information.
8// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9//
10// If using this code in an academic setting, please cite the following:
11// @incollection{enzymeNeurips,
12// title = {Instead of Rewriting Foreign Code for Machine Learning,
13// Automatically Synthesize Fast Gradients},
14// author = {Moses, William S. and Churavy, Valentin},
15// booktitle = {Advances in Neural Information Processing Systems 33},
16// year = {2020},
17// note = {To appear in},
18// }
19//
20//===----------------------------------------------------------------------===//
21//
22// This file contains an abstraction for static and dynamic implementations of
23// the probabilistic programming interface.
24//
25//===----------------------------------------------------------------------===//
26
27#include "TraceInterface.h"
28
29#include "Utils.h"
30
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/Instructions.h"
34#include "llvm/IR/Type.h"
35#include "llvm/IR/Value.h"
36
37using namespace llvm;
38
39TraceInterface::TraceInterface(LLVMContext &C) : C(C){};
40
// Free function giving the LLVM type used for trace values in the hook
// signatures below (e.g. getTraceTy uses it for both result and first param).
// NOTE(review): the body of this definition (original line 42) is missing
// from this rendering; only the signature is visible, so the concrete type
// it returns cannot be confirmed from here.
41PointerType *traceType(LLVMContext &C) {
43}
44
45Type *addressType(LLVMContext &C) { return getInt8PtrTy(C); }
46
47IntegerType *TraceInterface::sizeType(LLVMContext &C) {
48 return IntegerType::getInt64Ty(C);
49}
50
51Type *TraceInterface::stringType(LLVMContext &C) { return getInt8PtrTy(C); }
52
53FunctionType *TraceInterface::getTraceTy() { return getTraceTy(C); }
54FunctionType *TraceInterface::getChoiceTy() { return getChoiceTy(C); }
55FunctionType *TraceInterface::insertCallTy() { return insertCallTy(C); }
56FunctionType *TraceInterface::insertChoiceTy() { return insertChoiceTy(C); }
58FunctionType *TraceInterface::insertReturnTy() { return insertReturnTy(C); }
66FunctionType *TraceInterface::newTraceTy() { return newTraceTy(C); }
67FunctionType *TraceInterface::freeTraceTy() { return freeTraceTy(C); }
68FunctionType *TraceInterface::hasCallTy() { return hasCallTy(C); }
69FunctionType *TraceInterface::hasChoiceTy() { return hasChoiceTy(C); }
70
71FunctionType *TraceInterface::getTraceTy(LLVMContext &C) {
72 return FunctionType::get(traceType(C), {traceType(C), stringType(C)}, false);
73}
74
// Signature of the get_choice hook. The StaticTraceInterface constructor
// asserts `__enzyme_get_choice` has this type, and the DynamicTraceInterface
// constructor materializes slot 1 ("get_choice") with it.
// NOTE(review): the line holding the result and parameter types (original
// line 77) is missing from this rendering; the exact signature cannot be
// read from here.
75FunctionType *TraceInterface::getChoiceTy(LLVMContext &C) {
76 return FunctionType::get(
78 false);
79}
80
// Signature of the insert_call hook; returns void.
// NOTE(review): the parameter-list line (original line 83) is missing from
// this rendering, so the parameter types cannot be confirmed from here.
81FunctionType *TraceInterface::insertCallTy(LLVMContext &C) {
82 return FunctionType::get(Type::getVoidTy(C),
84 false);
85}
86
// Signature of the insert_choice hook; returns void. The visible trailing
// parameters are (double, i8 pointer, size).
// NOTE(review): the leading part of the parameter list (original line 89) is
// missing from this rendering.
87FunctionType *TraceInterface::insertChoiceTy(LLVMContext &C) {
88 return FunctionType::get(Type::getVoidTy(C),
90 Type::getDoubleTy(C), getInt8PtrTy(C), sizeType(C)},
91 false);
92}
93
94FunctionType *TraceInterface::insertArgumentTy(LLVMContext &C) {
95 return FunctionType::get(
96 Type::getVoidTy(C),
97 {getInt8PtrTy(C), stringType(C), getInt8PtrTy(C), sizeType(C)}, false);
98}
99
// Signature of the insert_return hook; returns void.
// NOTE(review): the parameter-list line (original line 102) is missing from
// this rendering, so the parameter types cannot be confirmed from here.
100FunctionType *TraceInterface::insertReturnTy(LLVMContext &C) {
101 return FunctionType::get(Type::getVoidTy(C),
103 false);
104}
105
106FunctionType *TraceInterface::insertFunctionTy(LLVMContext &C) {
107 return FunctionType::get(Type::getVoidTy(C),
108 {getInt8PtrTy(C), getInt8PtrTy(C)}, false);
109}
110
111FunctionType *TraceInterface::insertChoiceGradientTy(LLVMContext &C) {
112 return FunctionType::get(
113 Type::getVoidTy(C),
114 {getInt8PtrTy(C), stringType(C), getInt8PtrTy(C), sizeType(C)}, false);
115}
116
117FunctionType *TraceInterface::insertArgumentGradientTy(LLVMContext &C) {
118 return FunctionType::get(
119 Type::getVoidTy(C),
120 {getInt8PtrTy(C), stringType(C), getInt8PtrTy(C), sizeType(C)}, false);
121}
122
123FunctionType *TraceInterface::newTraceTy(LLVMContext &C) {
124 return FunctionType::get(getInt8PtrTy(C), {}, false);
125}
126
127FunctionType *TraceInterface::freeTraceTy(LLVMContext &C) {
128 return FunctionType::get(Type::getVoidTy(C), {getInt8PtrTy(C)}, false);
129}
130
131FunctionType *TraceInterface::hasCallTy(LLVMContext &C) {
132 return FunctionType::get(Type::getInt1Ty(C), {getInt8PtrTy(C), stringType(C)},
133 false);
134}
135
136FunctionType *TraceInterface::hasChoiceTy(LLVMContext &C) {
137 return FunctionType::get(Type::getInt1Ty(C), {getInt8PtrTy(C), stringType(C)},
138 false);
139}
140
// Module-scanning StaticTraceInterface constructor (declared as
// `StaticTraceInterface(llvm::Module *M)`).
// NOTE(review): the signature line (original line 141) is missing from this
// rendering; only the initializer and body are visible.
//
// Binds every trace hook by scanning M's non-intrinsic functions. Each hook is
// selected by a *substring* match on the function name (e.g. any name that
// contains "__enzyme_newtrace"), and its type is asserted to equal the
// expected signature. The if/else chain stops at the first matching pattern.
// If several functions match the same pattern, the one visited last wins.
// After binding, every hook is tagged with "enzyme_notypeanalysis" and
// "enzyme_inactive". NoFree is added to every hook except freeTraceFunction.
142 : TraceInterface(M->getContext()) {
143 for (auto &&F : M->functions()) {
144 if (F.isIntrinsic())
145 continue;
146 if (F.getName().contains("__enzyme_newtrace")) {
147 assert(F.getFunctionType() == newTraceTy());
148 newTraceFunction = &F;
149 } else if (F.getName().contains("__enzyme_freetrace")) {
150 assert(F.getFunctionType() == freeTraceTy());
151 freeTraceFunction = &F;
152 } else if (F.getName().contains("__enzyme_get_trace")) {
153 assert(F.getFunctionType() == getTraceTy());
154 getTraceFunction = &F;
155 } else if (F.getName().contains("__enzyme_get_choice")) {
156 assert(F.getFunctionType() == getChoiceTy());
157 getChoiceFunction = &F;
158 } else if (F.getName().contains("__enzyme_insert_call")) {
159 assert(F.getFunctionType() == insertCallTy());
160 insertCallFunction = &F;
161 } else if (F.getName().contains("__enzyme_insert_choice")) {
162 assert(F.getFunctionType() == insertChoiceTy());
163 insertChoiceFunction = &F;
164 } else if (F.getName().contains("__enzyme_insert_argument")) {
165 assert(F.getFunctionType() == insertArgumentTy());
166 insertArgumentFunction = &F;
167 } else if (F.getName().contains("__enzyme_insert_return")) {
168 assert(F.getFunctionType() == insertReturnTy());
169 insertReturnFunction = &F;
170 } else if (F.getName().contains("__enzyme_insert_function")) {
171 assert(F.getFunctionType() == insertFunctionTy());
172 insertFunctionFunction = &F;
173 } else if (F.getName().contains("__enzyme_insert_gradient_choice")) {
174 assert(F.getFunctionType() == insertChoiceGradientTy());
175 insertChoiceGradientFunction = &F;
176 } else if (F.getName().contains("__enzyme_insert_gradient_argument")) {
177 assert(F.getFunctionType() == insertArgumentGradientTy());
178 insertArgumentGradientFunction = &F;
179 } else if (F.getName().contains("__enzyme_has_call")) {
180 assert(F.getFunctionType() == hasCallTy());
181 hasCallFunction = &F;
182 } else if (F.getName().contains("__enzyme_has_choice")) {
183 assert(F.getFunctionType() == hasChoiceTy());
184 hasChoiceFunction = &F;
185 }
186 }
187
 // Every hook must have been found.
 // NOTE(review): these are the only presence checks and they are assert-only.
 // In an NDEBUG build a hook missing from M is not diagnosed, and the
 // addFnAttr calls below would dereference its (null or uninitialized)
 // pointer. Consider a hard error instead.
188 assert(newTraceFunction);
189 assert(freeTraceFunction);
190 assert(getTraceFunction);
191 assert(getChoiceFunction);
192 assert(insertCallFunction);
193 assert(insertChoiceFunction);
194
195 assert(insertArgumentFunction);
196 assert(insertReturnFunction);
197 assert(insertFunctionFunction);
198
199 assert(insertChoiceGradientFunction);
200 assert(insertArgumentGradientFunction);
201
202 assert(hasCallFunction);
203 assert(hasChoiceFunction);
204
 // Exclude every hook from Enzyme's type analysis.
205 newTraceFunction->addFnAttr("enzyme_notypeanalysis");
206 freeTraceFunction->addFnAttr("enzyme_notypeanalysis");
207 getTraceFunction->addFnAttr("enzyme_notypeanalysis");
208 getChoiceFunction->addFnAttr("enzyme_notypeanalysis");
209 insertCallFunction->addFnAttr("enzyme_notypeanalysis");
210 insertChoiceFunction->addFnAttr("enzyme_notypeanalysis");
211 insertArgumentFunction->addFnAttr("enzyme_notypeanalysis");
212 insertReturnFunction->addFnAttr("enzyme_notypeanalysis");
213 insertFunctionFunction->addFnAttr("enzyme_notypeanalysis");
214 insertChoiceGradientFunction->addFnAttr("enzyme_notypeanalysis");
215 insertArgumentGradientFunction->addFnAttr("enzyme_notypeanalysis");
216 hasCallFunction->addFnAttr("enzyme_notypeanalysis");
217 hasChoiceFunction->addFnAttr("enzyme_notypeanalysis");
218
 // Mark every hook as inactive for differentiation.
219 newTraceFunction->addFnAttr("enzyme_inactive");
220 freeTraceFunction->addFnAttr("enzyme_inactive");
221 getTraceFunction->addFnAttr("enzyme_inactive");
222 getChoiceFunction->addFnAttr("enzyme_inactive");
223 insertCallFunction->addFnAttr("enzyme_inactive");
224 insertChoiceFunction->addFnAttr("enzyme_inactive");
225 insertArgumentFunction->addFnAttr("enzyme_inactive");
226 insertReturnFunction->addFnAttr("enzyme_inactive");
227 insertFunctionFunction->addFnAttr("enzyme_inactive");
228 insertChoiceGradientFunction->addFnAttr("enzyme_inactive");
229 insertArgumentGradientFunction->addFnAttr("enzyme_inactive");
230 hasCallFunction->addFnAttr("enzyme_inactive");
231 hasChoiceFunction->addFnAttr("enzyme_inactive");
232
 // NoFree on every hook except freeTraceFunction. That omission is presumably
 // deliberate because it releases a trace; confirm.
233 newTraceFunction->addFnAttr(Attribute::NoFree);
234 getTraceFunction->addFnAttr(Attribute::NoFree);
235 getChoiceFunction->addFnAttr(Attribute::NoFree);
236 insertCallFunction->addFnAttr(Attribute::NoFree);
237 insertChoiceFunction->addFnAttr(Attribute::NoFree);
238 insertArgumentFunction->addFnAttr(Attribute::NoFree);
239 insertReturnFunction->addFnAttr(Attribute::NoFree);
240 insertFunctionFunction->addFnAttr(Attribute::NoFree);
241 insertChoiceGradientFunction->addFnAttr(Attribute::NoFree);
242 insertArgumentGradientFunction->addFnAttr(Attribute::NoFree);
243 hasCallFunction->addFnAttr(Attribute::NoFree);
244 hasChoiceFunction->addFnAttr(Attribute::NoFree);
245}
246
// Member-wise StaticTraceInterface constructor: stores each caller-supplied
// hook function as-is.
// NOTE(review): the first line of the signature (original line 247) is
// missing from this rendering.
// Unlike the Module-scanning constructor, this overload performs no null or
// type checks and adds no function attributes ("enzyme_notypeanalysis",
// "enzyme_inactive", NoFree). Confirm that callers are expected to supply
// already-prepared functions.
248 LLVMContext &C, Function *getTraceFunction, Function *getChoiceFunction,
249 Function *insertCallFunction, Function *insertChoiceFunction,
250 Function *insertArgumentFunction, Function *insertReturnFunction,
251 Function *insertFunctionFunction, Function *insertChoiceGradientFunction,
252 Function *insertArgumentGradientFunction, Function *newTraceFunction,
253 Function *freeTraceFunction, Function *hasCallFunction,
254 Function *hasChoiceFunction)
255 : TraceInterface(C), getTraceFunction(getTraceFunction),
256 getChoiceFunction(getChoiceFunction),
257 insertCallFunction(insertCallFunction),
258 insertChoiceFunction(insertChoiceFunction),
259 insertArgumentFunction(insertArgumentFunction),
260 insertReturnFunction(insertReturnFunction),
261 insertFunctionFunction(insertFunctionFunction),
262 insertChoiceGradientFunction(insertChoiceGradientFunction),
263 insertArgumentGradientFunction(insertArgumentGradientFunction),
264 newTraceFunction(newTraceFunction), freeTraceFunction(freeTraceFunction),
265 hasCallFunction(hasCallFunction), hasChoiceFunction(hasChoiceFunction){};
266
267// user implemented
268Value *StaticTraceInterface::getTrace(IRBuilder<> &Builder) {
269 return getTraceFunction;
270}
271Value *StaticTraceInterface::getChoice(IRBuilder<> &Builder) {
272 return getChoiceFunction;
273}
274Value *StaticTraceInterface::insertCall(IRBuilder<> &Builder) {
275 return insertCallFunction;
276}
277Value *StaticTraceInterface::insertChoice(IRBuilder<> &Builder) {
278 return insertChoiceFunction;
279}
280Value *StaticTraceInterface::insertArgument(IRBuilder<> &Builder) {
281 return insertArgumentFunction;
282}
283Value *StaticTraceInterface::insertReturn(IRBuilder<> &Builder) {
284 return insertReturnFunction;
285}
286Value *StaticTraceInterface::insertFunction(IRBuilder<> &Builder) {
287 return insertFunctionFunction;
288}
289Value *StaticTraceInterface::insertChoiceGradient(IRBuilder<> &Builder) {
290 return insertChoiceGradientFunction;
291}
// NOTE(review): the signature line of this accessor (original line 292) is
// missing from this rendering. From its position and its return value it is
// presumably StaticTraceInterface::insertArgumentGradient(IRBuilder<> &).
293 return insertArgumentGradientFunction;
294}
295Value *StaticTraceInterface::newTrace(IRBuilder<> &Builder) {
296 return newTraceFunction;
297}
298Value *StaticTraceInterface::freeTrace(IRBuilder<> &Builder) {
299 return freeTraceFunction;
300}
301Value *StaticTraceInterface::hasCall(IRBuilder<> &Builder) {
302 return hasCallFunction;
303}
304Value *StaticTraceInterface::hasChoice(IRBuilder<> &Builder) {
305 return hasChoiceFunction;
306}
307
// DynamicTraceInterface constructor (declared as
// `DynamicTraceInterface(llvm::Value *dynamicInterface, llvm::Function *F)`).
// NOTE(review): the first line of the signature (original line 308) is
// missing from this rendering.
//
// For each hook, emits code at the top of F's entry block (first non-PHI,
// non-debug instruction). That code reads the hook pointer from a fixed slot
// of the dynamicInterface table, and a private wrapper function is created
// that calls through it. The slot order below is the table layout:
//   0 get_trace, 1 get_choice, 2 insert_call, 3 insert_choice,
//   4 insert_argument, 5 insert_return, 6 insert_function,
//   7 insert_choice_gradient, 8 insert_argument_gradient, 9 new_trace,
//   10 free_trace, 11 has_call, 12 has_choice.
309 Function *F)
310 : TraceInterface(F->getContext()) {
311 assert(dynamicInterface);
312
313 auto &M = *F->getParent();
314 IRBuilder<> Builder(getFirstNonPHIOrDbg(&F->getEntryBlock()));
315
316 getTraceFunction = MaterializeInterfaceFunction(
317 Builder, dynamicInterface, getTraceTy(), 0, M, "get_trace");
318 getChoiceFunction = MaterializeInterfaceFunction(
319 Builder, dynamicInterface, getChoiceTy(), 1, M, "get_choice");
320 insertCallFunction = MaterializeInterfaceFunction(
321 Builder, dynamicInterface, insertCallTy(), 2, M, "insert_call");
322 insertChoiceFunction = MaterializeInterfaceFunction(
323 Builder, dynamicInterface, insertChoiceTy(), 3, M, "insert_choice");
324 insertArgumentFunction = MaterializeInterfaceFunction(
325 Builder, dynamicInterface, insertArgumentTy(), 4, M, "insert_argument");
326 insertReturnFunction = MaterializeInterfaceFunction(
327 Builder, dynamicInterface, insertReturnTy(), 5, M, "insert_return");
328 insertFunctionFunction = MaterializeInterfaceFunction(
329 Builder, dynamicInterface, insertFunctionTy(), 6, M, "insert_function");
330 insertChoiceGradientFunction = MaterializeInterfaceFunction(
331 Builder, dynamicInterface, insertChoiceGradientTy(), 7, M,
332 "insert_choice_gradient");
333 insertArgumentGradientFunction = MaterializeInterfaceFunction(
334 Builder, dynamicInterface, insertArgumentGradientTy(), 8, M,
335 "insert_argument_gradient");
336 newTraceFunction = MaterializeInterfaceFunction(
337 Builder, dynamicInterface, newTraceTy(), 9, M, "new_trace");
338 freeTraceFunction = MaterializeInterfaceFunction(
339 Builder, dynamicInterface, freeTraceTy(), 10, M, "free_trace");
340 hasCallFunction = MaterializeInterfaceFunction(
341 Builder, dynamicInterface, hasCallTy(), 11, M, "has_call");
342 hasChoiceFunction = MaterializeInterfaceFunction(
343 Builder, dynamicInterface, hasChoiceTy(), 12, M, "has_choice");
344
 // Sanity checks. MaterializeInterfaceFunction returns the result of
 // Function::Create, so these cannot fail here; they only document that every
 // hook has been bound.
345 assert(newTraceFunction);
346 assert(freeTraceFunction);
347 assert(getTraceFunction);
348 assert(getChoiceFunction);
349 assert(insertCallFunction);
350 assert(insertChoiceFunction);
351
352 assert(insertArgumentFunction);
353 assert(insertReturnFunction);
354 assert(insertFunctionFunction);
355
356 assert(insertChoiceGradientFunction);
357 assert(insertArgumentGradientFunction);
358
359 assert(hasCallFunction);
360 assert(hasChoiceFunction);
361}
362
363Function *DynamicTraceInterface::MaterializeInterfaceFunction(
364 IRBuilder<> &Builder, Value *dynamicInterface, FunctionType *FTy,
365 unsigned index, Module &M, const Twine &Name) {
366 auto ptr =
367 Builder.CreateInBoundsGEP(getInt8PtrTy(dynamicInterface->getContext()),
368 dynamicInterface, Builder.getInt32(index));
369 auto load =
370 Builder.CreateLoad(getInt8PtrTy(dynamicInterface->getContext()), ptr);
371 auto pty = PointerType::get(FTy, load->getPointerAddressSpace());
372 auto cast = Builder.CreatePointerCast(load, pty);
373
374 auto global =
375 new GlobalVariable(M, pty, false, GlobalVariable::PrivateLinkage,
376 ConstantPointerNull::get(pty), Name + "_ptr");
377 Builder.CreateStore(cast, global);
378
379 Function *F = Function::Create(FTy, Function::PrivateLinkage, Name, M);
380 F->addFnAttr(Attribute::AlwaysInline);
381 BasicBlock *Entry = BasicBlock::Create(M.getContext(), "entry", F);
382
383 IRBuilder<> WrapperBuilder(Entry);
384
385 auto ToWrap = WrapperBuilder.CreateLoad(pty, global, Name);
386 auto Args = SmallVector<Value *, 4>(make_pointer_range(F->args()));
387 auto Call = WrapperBuilder.CreateCall(FTy, ToWrap, Args);
388
389 if (!FTy->getReturnType()->isVoidTy()) {
390 WrapperBuilder.CreateRet(Call);
391 } else {
392 WrapperBuilder.CreateRetVoid();
393 }
394
395 return F;
396}
397
398// user implemented
399Value *DynamicTraceInterface::getTrace(IRBuilder<> &Builder) {
400 return getTraceFunction;
401}
402
403Value *DynamicTraceInterface::getChoice(IRBuilder<> &Builder) {
404 return getChoiceFunction;
405}
406
407Value *DynamicTraceInterface::insertCall(IRBuilder<> &Builder) {
408 return insertCallFunction;
409}
410
411Value *DynamicTraceInterface::insertChoice(IRBuilder<> &Builder) {
412 return insertChoiceFunction;
413}
414
415Value *DynamicTraceInterface::insertArgument(IRBuilder<> &Builder) {
416 return insertArgumentFunction;
417}
418
419Value *DynamicTraceInterface::insertReturn(IRBuilder<> &Builder) {
420 return insertReturnFunction;
421}
422
423Value *DynamicTraceInterface::insertFunction(IRBuilder<> &Builder) {
424 return insertFunctionFunction;
425}
426
// NOTE(review): the signature line of this accessor (original line 427) is
// missing from this rendering. From its return value it is presumably
// DynamicTraceInterface::insertChoiceGradient(IRBuilder<> &).
428 return insertChoiceGradientFunction;
429}
430
// NOTE(review): the signature line of this accessor (original line 431) is
// missing from this rendering. From its return value it is presumably
// DynamicTraceInterface::insertArgumentGradient(IRBuilder<> &).
432 return insertArgumentGradientFunction;
433}
434
435Value *DynamicTraceInterface::newTrace(IRBuilder<> &Builder) {
436 return newTraceFunction;
437}
438
439Value *DynamicTraceInterface::freeTrace(IRBuilder<> &Builder) {
440 return freeTraceFunction;
441}
442
443Value *DynamicTraceInterface::hasCall(IRBuilder<> &Builder) {
444 return hasCallFunction;
445}
446
447Value *DynamicTraceInterface::hasChoice(IRBuilder<> &Builder) {
448 return hasChoiceFunction;
449}
Type * addressType(LLVMContext &C)
PointerType * traceType(LLVMContext &C)
llvm::PointerType * getDefaultAnonymousTapeType(llvm::LLVMContext &C)
Definition Utils.cpp:437
@ Args
Return is a struct of all args.
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
Definition Utils.h:1174
static llvm::Instruction * getFirstNonPHIOrDbg(llvm::BasicBlock *B)
Definition Utils.h:2272
llvm::Value * insertArgumentGradient(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoiceGradient(llvm::IRBuilder<> &Builder)
DynamicTraceInterface(llvm::Value *dynamicInterface, llvm::Function *F)
llvm::Value * hasChoice(llvm::IRBuilder<> &Builder)
llvm::Value * hasCall(llvm::IRBuilder<> &Builder)
llvm::Value * getChoice(llvm::IRBuilder<> &Builder)
llvm::Value * freeTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoice(llvm::IRBuilder<> &Builder)
llvm::Value * insertFunction(llvm::IRBuilder<> &Builder)
llvm::Value * insertReturn(llvm::IRBuilder<> &Builder)
llvm::Value * getTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertArgument(llvm::IRBuilder<> &Builder)
llvm::Value * insertCall(llvm::IRBuilder<> &Builder)
llvm::Value * newTrace(llvm::IRBuilder<> &Builder)
llvm::Value * getChoice(llvm::IRBuilder<> &Builder)
llvm::Value * insertCall(llvm::IRBuilder<> &Builder)
llvm::Value * insertArgument(llvm::IRBuilder<> &Builder)
llvm::Value * insertReturn(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoiceGradient(llvm::IRBuilder<> &Builder)
llvm::Value * freeTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertArgumentGradient(llvm::IRBuilder<> &Builder)
llvm::Value * newTrace(llvm::IRBuilder<> &Builder)
llvm::Value * insertChoice(llvm::IRBuilder<> &Builder)
llvm::Value * getTrace(llvm::IRBuilder<> &Builder)
llvm::Value * hasChoice(llvm::IRBuilder<> &Builder)
StaticTraceInterface(llvm::Module *M)
llvm::Value * insertFunction(llvm::IRBuilder<> &Builder)
llvm::Value * hasCall(llvm::IRBuilder<> &Builder)
llvm::FunctionType * insertArgumentGradientTy()
TraceInterface(llvm::LLVMContext &C)
llvm::FunctionType * hasCallTy()
llvm::FunctionType * insertFunctionTy()
llvm::FunctionType * hasChoiceTy()
llvm::FunctionType * insertChoiceTy()
llvm::FunctionType * insertChoiceGradientTy()
llvm::FunctionType * insertArgumentTy()
static llvm::Type * stringType(llvm::LLVMContext &C)
llvm::FunctionType * insertCallTy()
llvm::FunctionType * freeTraceTy()
static llvm::IntegerType * sizeType(llvm::LLVMContext &C)
llvm::FunctionType * insertReturnTy()
llvm::FunctionType * getChoiceTy()
llvm::FunctionType * newTraceTy()
llvm::FunctionType * getTraceTy()