Enzyme main
Loading...
Searching...
No Matches
TraceGenerator.cpp
Go to the documentation of this file.
//===- TraceGenerator.cpp - Trace sample statements and calls ------------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains an instruction visitor that generates probabilistic
22// programming traces for call sites and sample statements.
23//
24//===----------------------------------------------------------------------===//
25
26#include "TraceGenerator.h"
27
28#include "llvm/ADT/SmallVector.h"
29
30#include "llvm/Analysis/ValueTracking.h"
31#include "llvm/IR/Constants.h"
32#include "llvm/IR/Function.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/InstrTypes.h"
35#include "llvm/IR/Instructions.h"
36#include "llvm/IR/Module.h"
37#include "llvm/IR/Type.h"
38#include "llvm/IR/Value.h"
39
40#include "llvm/Transforms/Utils/BasicBlockUtils.h"
41
42#include "FunctionUtils.h"
43#include "TraceInterface.h"
44#include "TraceUtils.h"
45#include "Utils.h"
46
47using namespace llvm;
48
50 EnzymeLogic &Logic, TraceUtils *tutils, bool autodiff,
51 ValueMap<const Value *, WeakTrackingVH> &originalToNewFn,
52 const SmallPtrSetImpl<Function *> &generativeFunctions,
53 const StringSet<> &activeRandomVariables)
54 : Logic(Logic), tutils(tutils), autodiff(autodiff),
55 originalToNewFn(originalToNewFn),
56 generativeFunctions(generativeFunctions),
57 activeRandomVariables(activeRandomVariables) {
58 assert(tutils);
59};
60
62 if (mode == ProbProgMode::Likelihood)
63 return;
64
65 auto fn = tutils->newFunc;
66 auto entry = getFirstNonPHIOrDbgOrLifetime(&fn->getEntryBlock());
67
68 while (isa<AllocaInst>(entry) && entry->getNextNode()) {
69 entry = entry->getNextNode();
70 }
71
72 IRBuilder<> Builder(entry);
73
74 tutils->InsertFunction(Builder, tutils->newFunc);
75
76 auto attributes = fn->getAttributes();
77 for (size_t i = 0; i < fn->getFunctionType()->getNumParams(); ++i) {
78 bool shouldSkipParam =
79 attributes.hasParamAttr(i, TraceUtils::TraceParameterAttribute) ||
80 attributes.hasParamAttr(i,
82 attributes.hasParamAttr(i, TraceUtils::LikelihoodParameterAttribute);
83 if (shouldSkipParam)
84 continue;
85
86 auto arg = fn->arg_begin() + i;
87#if LLVM_VERSION_MAJOR >= 17
88 auto name = Builder.CreateGlobalString(arg->getName());
89#else
90 auto name = Builder.CreateGlobalStringPtr(arg->getName());
91#endif
92
93 auto Outlined = [](IRBuilder<> &OutlineBuilder, TraceUtils *OutlineTutils,
94 ArrayRef<Value *> Arguments) {
95 OutlineTutils->InsertArgument(OutlineBuilder, Arguments[0], Arguments[1]);
96 OutlineBuilder.CreateRetVoid();
97 };
98
99 auto call = tutils->CreateOutlinedFunction(
100 Builder, Outlined, Builder.getVoidTy(), {name, arg}, false,
101 "outline_insert_argument");
102
103 call->addAttributeAtIndex(
104 AttributeList::FunctionIndex,
105 Attribute::get(F.getContext(), "enzyme_insert_argument"));
106 call->addAttributeAtIndex(AttributeList::FunctionIndex,
107 Attribute::get(F.getContext(), "enzyme_active"));
108 if (autodiff) {
109 auto gradient_setter = ValueAsMetadata::get(
110 tutils->interface->insertArgumentGradient(Builder));
111 auto gradient_setter_node =
112 MDNode::get(F.getContext(), {gradient_setter});
113
114 call->setMetadata("enzyme_gradient_setter", gradient_setter_node);
115 }
116 }
117}
118
119void TraceGenerator::handleObserveCall(CallInst &call, CallInst *new_call) {
120 IRBuilder<> Builder(new_call);
121
122 SmallVector<Value *, 4> Args(
123 make_range(new_call->arg_begin() + 2, new_call->arg_end()));
124
125 Value *observed = new_call->getArgOperand(0);
126 Function *likelihoodfn = GetFunctionFromValue(new_call->getArgOperand(1));
127 Value *address = new_call->getArgOperand(2);
128
129 StringRef const_address;
130 bool is_address_const = getConstantStringInfo(address, const_address);
131 bool is_random_var_active =
132 activeRandomVariables.empty() ||
133 (is_address_const && activeRandomVariables.count(const_address));
134 Attribute activity_attribute = Attribute::get(
135 call.getContext(),
136 is_random_var_active ? "enzyme_active" : "enzyme_inactive_val");
137
138 // calculate and accumulate log likelihood
139 Args.push_back(observed);
140
141 auto score = Builder.CreateCall(likelihoodfn->getFunctionType(), likelihoodfn,
142 ArrayRef<Value *>(Args).slice(1),
143 "likelihood." + call.getName());
144
145 score->addAttributeAtIndex(AttributeList::FunctionIndex, activity_attribute);
146
147 auto log_prob_sum = Builder.CreateLoad(
148 Builder.getDoubleTy(), tutils->getLikelihood(), "log_prob_sum");
149 auto acc = Builder.CreateFAdd(log_prob_sum, score);
150 Builder.CreateStore(acc, tutils->getLikelihood());
151
152 // create outlined trace function
153 if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) {
154 Value *trace_args[] = {address, score, observed};
155
156 auto OutlinedTrace = [](IRBuilder<> &OutlineBuilder,
157 TraceUtils *OutlineTutils,
158 ArrayRef<Value *> Arguments) {
159 OutlineTutils->InsertChoice(OutlineBuilder, Arguments[0], Arguments[1],
160 Arguments[2]);
161 OutlineBuilder.CreateRetVoid();
162 };
163
164 auto trace_call = tutils->CreateOutlinedFunction(
165 Builder, OutlinedTrace, Builder.getVoidTy(), trace_args, false,
166 "outline_insert_choice");
167
168 trace_call->addAttributeAtIndex(
169 AttributeList::FunctionIndex,
170 Attribute::get(call.getContext(), "enzyme_inactive"));
171 trace_call->addAttributeAtIndex(
172 AttributeList::FunctionIndex,
173 Attribute::get(call.getContext(), "enzyme_notypeanalysis"));
174 }
175
176 if (!call.getType()->isVoidTy()) {
177 observed->takeName(new_call);
178 new_call->replaceAllUsesWith(observed);
179 }
180 new_call->eraseFromParent();
181}
182
183void TraceGenerator::handleSampleCall(CallInst &call, CallInst *new_call) {
184 // create outlined sample function
185 SmallVector<Value *, 4> Args(
186 make_range(new_call->arg_begin() + 2, new_call->arg_end()));
187
188 Function *samplefn = GetFunctionFromValue(new_call->getArgOperand(0));
189 Function *likelihoodfn = GetFunctionFromValue(new_call->getArgOperand(1));
190 Value *address = new_call->getArgOperand(2);
191
192 IRBuilder<> Builder(new_call);
193
194 auto OutlinedSample = [samplefn](IRBuilder<> &OutlineBuilder,
195 TraceUtils *OutlineTutils,
196 ArrayRef<Value *> Arguments) {
197 auto choice = OutlineTutils->SampleOrCondition(
198 OutlineBuilder, samplefn, Arguments.slice(1), Arguments[0],
199 samplefn->getName());
200 OutlineBuilder.CreateRet(choice);
201 };
202
203 const char *mode_str;
204 switch (mode) {
207 mode_str = "sample";
208 break;
210 mode_str = "condition";
211 break;
212 }
213
214 auto sample_call = tutils->CreateOutlinedFunction(
215 Builder, OutlinedSample, samplefn->getReturnType(), Args, false,
216 Twine(mode_str) + "_" + samplefn->getName());
217
218 StringRef const_address;
219 bool is_address_const = getConstantStringInfo(address, const_address);
220 bool is_random_var_active =
221 activeRandomVariables.empty() ||
222 (is_address_const && activeRandomVariables.count(const_address));
223 Attribute activity_attribute = Attribute::get(
224 call.getContext(),
225 is_random_var_active ? "enzyme_active" : "enzyme_inactive_val");
226
227 sample_call->addAttributeAtIndex(
228 AttributeList::FunctionIndex,
229 Attribute::get(call.getContext(), "enzyme_sample"));
230 sample_call->addAttributeAtIndex(AttributeList::FunctionIndex,
231 activity_attribute);
232
233 if (autodiff &&
234 (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition)) {
235 auto gradient_setter =
236 ValueAsMetadata::get(tutils->interface->insertChoiceGradient(Builder));
237 auto gradient_setter_node =
238 MDNode::get(call.getContext(), {gradient_setter});
239
240 sample_call->setMetadata("enzyme_gradient_setter", gradient_setter_node);
241 }
242
243 // calculate and accumulate log likelihood
244 Args.push_back(sample_call);
245
246 auto score = Builder.CreateCall(likelihoodfn->getFunctionType(), likelihoodfn,
247 ArrayRef<Value *>(Args).slice(1),
248 "likelihood." + call.getName());
249
250 score->addAttributeAtIndex(AttributeList::FunctionIndex, activity_attribute);
251
252 auto log_prob_sum = Builder.CreateLoad(
253 Builder.getDoubleTy(), tutils->getLikelihood(), "log_prob_sum");
254 auto acc = Builder.CreateFAdd(log_prob_sum, score);
255 Builder.CreateStore(acc, tutils->getLikelihood());
256
257 // create outlined trace function
258
259 if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) {
260 Value *trace_args[] = {address, score, sample_call};
261
262 auto OutlinedTrace = [](IRBuilder<> &OutlineBuilder,
263 TraceUtils *OutlineTutils,
264 ArrayRef<Value *> Arguments) {
265 OutlineTutils->InsertChoice(OutlineBuilder, Arguments[0], Arguments[1],
266 Arguments[2]);
267 OutlineBuilder.CreateRetVoid();
268 };
269
270 auto trace_call = tutils->CreateOutlinedFunction(
271 Builder, OutlinedTrace, Builder.getVoidTy(), trace_args, false,
272 "outline_insert_choice");
273
274 trace_call->addAttributeAtIndex(
275 AttributeList::FunctionIndex,
276 Attribute::get(call.getContext(), "enzyme_inactive"));
277 trace_call->addAttributeAtIndex(
278 AttributeList::FunctionIndex,
279 Attribute::get(call.getContext(), "enzyme_notypeanalysis"));
280 }
281
282 sample_call->takeName(new_call);
283 new_call->replaceAllUsesWith(sample_call);
284 new_call->eraseFromParent();
285}
286
287void TraceGenerator::handleArbitraryCall(CallInst &call, CallInst *new_call) {
288 IRBuilder<> Builder(new_call);
289
290 SmallVector<Value *, 2> args;
291 for (auto it = new_call->arg_begin(); it != new_call->arg_end(); it++) {
292 args.push_back(*it);
293 }
294
295 Function *called = getFunctionFromCall(&call);
296 assert(called);
297
298 Function *samplefn = Logic.CreateTrace(
299 RequestContext(&call, &Builder), called, tutils->sampleFunctions,
300 tutils->observeFunctions, activeRandomVariables, mode, autodiff,
301 tutils->interface);
302
303 Instruction *replacement;
304 switch (mode) {
306 SmallVector<Value *, 2> args_and_likelihood(args);
307 args_and_likelihood.push_back(tutils->getLikelihood());
308 replacement =
309 Builder.CreateCall(samplefn->getFunctionType(), samplefn,
310 args_and_likelihood, "eval." + called->getName());
311 break;
312 }
313 case ProbProgMode::Trace: {
314 auto trace = tutils->CreateTrace(Builder);
315#if LLVM_VERSION_MAJOR >= 17
316 auto address = Builder.CreateGlobalString(
317 (call.getName() + "." + called->getName()).str());
318#else
319 auto address = Builder.CreateGlobalStringPtr(
320 (call.getName() + "." + called->getName()).str());
321#endif
322
323 SmallVector<Value *, 2> args_and_trace(args);
324 args_and_trace.push_back(tutils->getLikelihood());
325 args_and_trace.push_back(trace);
326 replacement =
327 Builder.CreateCall(samplefn->getFunctionType(), samplefn,
328 args_and_trace, "trace." + called->getName());
329
330 tutils->InsertCall(Builder, address, trace);
331 break;
332 }
334 auto trace = tutils->CreateTrace(Builder);
335#if LLVM_VERSION_MAJOR >= 17
336 auto address = Builder.CreateGlobalString(
337 (call.getName() + "." + called->getName()).str());
338#else
339 auto address = Builder.CreateGlobalStringPtr(
340 (call.getName() + "." + called->getName()).str());
341#endif
342
343 Instruction *hasCall =
344 tutils->HasCall(Builder, address, "has.call." + call.getName());
345 Instruction *ThenTerm, *ElseTerm;
346 Value *ElseTracecall, *ThenTracecall;
347 SplitBlockAndInsertIfThenElse(hasCall, new_call, &ThenTerm, &ElseTerm);
348
349 new_call->getParent()->setName(hasCall->getParent()->getName() + ".cntd");
350
351 Builder.SetInsertPoint(ThenTerm);
352 {
353 ThenTerm->getParent()->setName("condition." + call.getName() +
354 ".with.trace");
355 SmallVector<Value *, 2> args_and_cond(args);
356 auto observations =
357 tutils->GetTrace(Builder, address, called->getName() + ".subtrace");
358 args_and_cond.push_back(tutils->getLikelihood());
359 args_and_cond.push_back(observations);
360 args_and_cond.push_back(trace);
361 ThenTracecall =
362 Builder.CreateCall(samplefn->getFunctionType(), samplefn,
363 args_and_cond, "condition." + called->getName());
364 }
365
366 Builder.SetInsertPoint(ElseTerm);
367 {
368 ElseTerm->getParent()->setName("condition." + call.getName() +
369 ".without.trace");
370 SmallVector<Value *, 2> args_and_null(args);
371 auto observations = ConstantPointerNull::get(cast<PointerType>(
372 tutils->getTraceInterface()->newTraceTy()->getReturnType()));
373 args_and_null.push_back(tutils->getLikelihood());
374 args_and_null.push_back(observations);
375 args_and_null.push_back(trace);
376 ElseTracecall =
377 Builder.CreateCall(samplefn->getFunctionType(), samplefn,
378 args_and_null, "trace." + called->getName());
379 }
380
381 Builder.SetInsertPoint(new_call);
382 auto phi = Builder.CreatePHI(samplefn->getFunctionType()->getReturnType(),
383 2, call.getName());
384 phi->addIncoming(ThenTracecall, ThenTerm->getParent());
385 phi->addIncoming(ElseTracecall, ElseTerm->getParent());
386 replacement = phi;
387
388 tutils->InsertCall(Builder, address, trace);
389 break;
390 }
391 }
392
393 replacement->takeName(new_call);
394 new_call->replaceAllUsesWith(replacement);
395 new_call->eraseFromParent();
396}
397
398void TraceGenerator::visitCallInst(CallInst &call) {
399 auto fn = getFunctionFromCall(&call);
400
401 if (!generativeFunctions.count(fn))
402 return;
403
404 CallInst *new_call = dyn_cast<CallInst>(originalToNewFn[&call]);
405
406 if (tutils->isSampleCall(&call)) {
407 handleSampleCall(call, new_call);
408 } else if (tutils->isObserveCall(&call)) {
409 handleObserveCall(call, new_call);
410 } else {
411 handleArbitraryCall(call, new_call);
412 }
413}
414
415void TraceGenerator::visitReturnInst(ReturnInst &ret) {
416
417 if (!ret.getReturnValue())
418 return;
419
420 ReturnInst *new_ret = dyn_cast<ReturnInst>(originalToNewFn[&ret]);
421
422 IRBuilder<> Builder(new_ret);
423 tutils->InsertReturn(Builder, new_ret->getReturnValue());
424}
static Operation * getFunctionFromCall(CallOpInterface iface)
Function * GetFunctionFromValue(Value *fn)
Definition Utils.cpp:3547
@ Args
Return is a struct of all args.
static llvm::Instruction * getFirstNonPHIOrDbgOrLifetime(llvm::BasicBlock *B)
Definition Utils.h:2281
llvm::Function * CreateTrace(RequestContext context, llvm::Function *totrace, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, const llvm::StringSet<> &ActiveRandomVariables, ProbProgMode mode, bool autodiff, TraceInterface *interface)
Create a traced version of a function context the instruction which requested this trace (or null).
void visitFunction(llvm::Function &F)
void handleArbitraryCall(llvm::CallInst &call, llvm::CallInst *new_call)
void handleObserveCall(llvm::CallInst &call, llvm::CallInst *new_call)
void visitCallInst(llvm::CallInst &call)
void handleSampleCall(llvm::CallInst &call, llvm::CallInst *new_call)
void visitReturnInst(llvm::ReturnInst &ret)
TraceGenerator(EnzymeLogic &Logic, TraceUtils *tutils, bool autodiff, llvm::ValueMap< const llvm::Value *, llvm::WeakTrackingVH > &originalToNewFn, const llvm::SmallPtrSetImpl< llvm::Function * > &generativeFunctions, const llvm::StringSet<> &activeRandomVariables)
virtual llvm::Value * insertArgumentGradient(llvm::IRBuilder<> &Builder)=0
virtual llvm::Value * insertChoiceGradient(llvm::IRBuilder<> &Builder)=0
llvm::FunctionType * newTraceTy()
llvm::CallInst * InsertFunction(llvm::IRBuilder<> &Builder, llvm::Function *function)
static constexpr const char TraceParameterAttribute[]
Definition TraceUtils.h:57
llvm::CallInst * CreateTrace(llvm::IRBuilder<> &Builder, const llvm::Twine &Name="trace")
llvm::CallInst * InsertReturn(llvm::IRBuilder<> &Builder, llvm::Value *ret)
TraceInterface * getTraceInterface()
TraceInterface * interface
Definition TraceUtils.h:51
llvm::Instruction * HasCall(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
static constexpr const char ObservationsParameterAttribute[]
Definition TraceUtils.h:58
llvm::SmallPtrSet< llvm::Function *, 4 > observeFunctions
Definition TraceUtils.h:55
llvm::CallInst * CreateOutlinedFunction(llvm::IRBuilder<> &Builder, llvm::function_ref< void(llvm::IRBuilder<> &, TraceUtils *, llvm::ArrayRef< llvm::Value * >)> Outlined, llvm::Type *RetTy, llvm::ArrayRef< llvm::Value * > Arguments, bool needsLikelihood=true, const llvm::Twine &Name="")
static constexpr const char LikelihoodParameterAttribute[]
Definition TraceUtils.h:60
llvm::Function * newFunc
Definition TraceUtils.h:53
llvm::SmallPtrSet< llvm::Function *, 4 > sampleFunctions
Definition TraceUtils.h:54
llvm::Value * getLikelihood()
bool isObserveCall(llvm::CallInst *call)
llvm::CallInst * GetTrace(llvm::IRBuilder<> &Builder, llvm::Value *address, const llvm::Twine &Name="")
bool isSampleCall(llvm::CallInst *call)
llvm::CallInst * InsertCall(llvm::IRBuilder<> &Builder, llvm::Value *address, llvm::Value *subtrace)