Enzyme main
Loading...
Searching...
No Matches
TraceGenerator.h
Go to the documentation of this file.
1//===- TraceGenerator.h - Trace sample statements and calls --------------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains an instruction visitor that generates probabilistic
22// programming traces for call sites and sample statements.
23//
24//===----------------------------------------------------------------------===//
25
26#ifndef TraceGenerator_h
27#define TraceGenerator_h
28
29#include "llvm/IR/InstVisitor.h"
30#include "llvm/IR/Instructions.h"
31
32#include "EnzymeLogic.h"
33#include "TraceUtils.h"
34#include "Utils.h"
35
36class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
37private:
38 EnzymeLogic &Logic;
39 TraceUtils *const tutils;
40 ProbProgMode mode = tutils->mode;
41 bool autodiff;
42 llvm::ValueMap<const llvm::Value *, llvm::WeakTrackingVH> &originalToNewFn;
43 const llvm::SmallPtrSetImpl<llvm::Function *> &generativeFunctions;
44 const llvm::StringSet<> &activeRandomVariables;
45
46public:
48 EnzymeLogic &Logic, TraceUtils *tutils, bool autodiff,
49 llvm::ValueMap<const llvm::Value *, llvm::WeakTrackingVH>
50 &originalToNewFn,
51 const llvm::SmallPtrSetImpl<llvm::Function *> &generativeFunctions,
52 const llvm::StringSet<> &activeRandomVariables);
53
54 void visitFunction(llvm::Function &F);
55
56 void handleSampleCall(llvm::CallInst &call, llvm::CallInst *new_call);
57
58 void handleObserveCall(llvm::CallInst &call, llvm::CallInst *new_call);
59
60 void handleArbitraryCall(llvm::CallInst &call, llvm::CallInst *new_call);
61
62 void visitCallInst(llvm::CallInst &call);
63
64 void visitReturnInst(llvm::ReturnInst &ret);
65};
66
67#endif /* TraceGenerator_h */
ProbProgMode
Definition Utils.h:399
void visitFunction(llvm::Function &F)
void handleArbitraryCall(llvm::CallInst &call, llvm::CallInst *new_call)
void handleObserveCall(llvm::CallInst &call, llvm::CallInst *new_call)
void visitCallInst(llvm::CallInst &call)
void handleSampleCall(llvm::CallInst &call, llvm::CallInst *new_call)
void visitReturnInst(llvm::ReturnInst &ret)
TraceGenerator(EnzymeLogic &Logic, TraceUtils *tutils, bool autodiff, llvm::ValueMap< const llvm::Value *, llvm::WeakTrackingVH > &originalToNewFn, const llvm::SmallPtrSetImpl< llvm::Function * > &generativeFunctions, const llvm::StringSet<> &activeRandomVariables)
ProbProgMode mode
Definition TraceUtils.h:52