51 ValueMap<const Value *, WeakTrackingVH> &originalToNewFn,
52 const SmallPtrSetImpl<Function *> &generativeFunctions,
53 const StringSet<> &activeRandomVariables)
54 : Logic(Logic), tutils(tutils), autodiff(autodiff),
55 originalToNewFn(originalToNewFn),
56 generativeFunctions(generativeFunctions),
57 activeRandomVariables(activeRandomVariables) {
68 while (isa<AllocaInst>(entry) && entry->getNextNode()) {
69 entry = entry->getNextNode();
72 IRBuilder<> Builder(entry);
76 auto attributes = fn->getAttributes();
77 for (
size_t i = 0; i < fn->getFunctionType()->getNumParams(); ++i) {
78 bool shouldSkipParam =
80 attributes.hasParamAttr(i,
86 auto arg = fn->arg_begin() + i;
87#if LLVM_VERSION_MAJOR >= 17
88 auto name = Builder.CreateGlobalString(arg->getName());
90 auto name = Builder.CreateGlobalStringPtr(arg->getName());
93 auto Outlined = [](IRBuilder<> &OutlineBuilder,
TraceUtils *OutlineTutils,
94 ArrayRef<Value *> Arguments) {
95 OutlineTutils->InsertArgument(OutlineBuilder, Arguments[0], Arguments[1]);
96 OutlineBuilder.CreateRetVoid();
100 Builder, Outlined, Builder.getVoidTy(), {name, arg},
false,
101 "outline_insert_argument");
103 call->addAttributeAtIndex(
104 AttributeList::FunctionIndex,
105 Attribute::get(F.getContext(),
"enzyme_insert_argument"));
106 call->addAttributeAtIndex(AttributeList::FunctionIndex,
107 Attribute::get(F.getContext(),
"enzyme_active"));
109 auto gradient_setter = ValueAsMetadata::get(
111 auto gradient_setter_node =
112 MDNode::get(F.getContext(), {gradient_setter});
114 call->setMetadata(
"enzyme_gradient_setter", gradient_setter_node);
120 IRBuilder<> Builder(new_call);
122 SmallVector<Value *, 4>
Args(
123 make_range(new_call->arg_begin() + 2, new_call->arg_end()));
125 Value *observed = new_call->getArgOperand(0);
127 Value *address = new_call->getArgOperand(2);
129 StringRef const_address;
130 bool is_address_const = getConstantStringInfo(address, const_address);
131 bool is_random_var_active =
132 activeRandomVariables.empty() ||
133 (is_address_const && activeRandomVariables.count(const_address));
134 Attribute activity_attribute = Attribute::get(
136 is_random_var_active ?
"enzyme_active" :
"enzyme_inactive_val");
139 Args.push_back(observed);
141 auto score = Builder.CreateCall(likelihoodfn->getFunctionType(), likelihoodfn,
142 ArrayRef<Value *>(
Args).slice(1),
143 "likelihood." + call.getName());
145 score->addAttributeAtIndex(AttributeList::FunctionIndex, activity_attribute);
147 auto log_prob_sum = Builder.CreateLoad(
148 Builder.getDoubleTy(), tutils->
getLikelihood(),
"log_prob_sum");
149 auto acc = Builder.CreateFAdd(log_prob_sum, score);
154 Value *trace_args[] = {address, score, observed};
156 auto OutlinedTrace = [](IRBuilder<> &OutlineBuilder,
158 ArrayRef<Value *> Arguments) {
159 OutlineTutils->InsertChoice(OutlineBuilder, Arguments[0], Arguments[1],
161 OutlineBuilder.CreateRetVoid();
165 Builder, OutlinedTrace, Builder.getVoidTy(), trace_args,
false,
166 "outline_insert_choice");
168 trace_call->addAttributeAtIndex(
169 AttributeList::FunctionIndex,
170 Attribute::get(call.getContext(),
"enzyme_inactive"));
171 trace_call->addAttributeAtIndex(
172 AttributeList::FunctionIndex,
173 Attribute::get(call.getContext(),
"enzyme_notypeanalysis"));
176 if (!call.getType()->isVoidTy()) {
177 observed->takeName(new_call);
178 new_call->replaceAllUsesWith(observed);
180 new_call->eraseFromParent();
185 SmallVector<Value *, 4>
Args(
186 make_range(new_call->arg_begin() + 2, new_call->arg_end()));
190 Value *address = new_call->getArgOperand(2);
192 IRBuilder<> Builder(new_call);
194 auto OutlinedSample = [samplefn](IRBuilder<> &OutlineBuilder,
196 ArrayRef<Value *> Arguments) {
197 auto choice = OutlineTutils->SampleOrCondition(
198 OutlineBuilder, samplefn, Arguments.slice(1), Arguments[0],
199 samplefn->getName());
200 OutlineBuilder.CreateRet(choice);
203 const char *mode_str;
210 mode_str =
"condition";
215 Builder, OutlinedSample, samplefn->getReturnType(),
Args,
false,
216 Twine(mode_str) +
"_" + samplefn->getName());
218 StringRef const_address;
219 bool is_address_const = getConstantStringInfo(address, const_address);
220 bool is_random_var_active =
221 activeRandomVariables.empty() ||
222 (is_address_const && activeRandomVariables.count(const_address));
223 Attribute activity_attribute = Attribute::get(
225 is_random_var_active ?
"enzyme_active" :
"enzyme_inactive_val");
227 sample_call->addAttributeAtIndex(
228 AttributeList::FunctionIndex,
229 Attribute::get(call.getContext(),
"enzyme_sample"));
230 sample_call->addAttributeAtIndex(AttributeList::FunctionIndex,
235 auto gradient_setter =
237 auto gradient_setter_node =
238 MDNode::get(call.getContext(), {gradient_setter});
240 sample_call->setMetadata(
"enzyme_gradient_setter", gradient_setter_node);
244 Args.push_back(sample_call);
246 auto score = Builder.CreateCall(likelihoodfn->getFunctionType(), likelihoodfn,
247 ArrayRef<Value *>(
Args).slice(1),
248 "likelihood." + call.getName());
250 score->addAttributeAtIndex(AttributeList::FunctionIndex, activity_attribute);
252 auto log_prob_sum = Builder.CreateLoad(
253 Builder.getDoubleTy(), tutils->
getLikelihood(),
"log_prob_sum");
254 auto acc = Builder.CreateFAdd(log_prob_sum, score);
260 Value *trace_args[] = {address, score, sample_call};
262 auto OutlinedTrace = [](IRBuilder<> &OutlineBuilder,
264 ArrayRef<Value *> Arguments) {
265 OutlineTutils->InsertChoice(OutlineBuilder, Arguments[0], Arguments[1],
267 OutlineBuilder.CreateRetVoid();
271 Builder, OutlinedTrace, Builder.getVoidTy(), trace_args,
false,
272 "outline_insert_choice");
274 trace_call->addAttributeAtIndex(
275 AttributeList::FunctionIndex,
276 Attribute::get(call.getContext(),
"enzyme_inactive"));
277 trace_call->addAttributeAtIndex(
278 AttributeList::FunctionIndex,
279 Attribute::get(call.getContext(),
"enzyme_notypeanalysis"));
282 sample_call->takeName(new_call);
283 new_call->replaceAllUsesWith(sample_call);
284 new_call->eraseFromParent();
288 IRBuilder<> Builder(new_call);
290 SmallVector<Value *, 2> args;
291 for (
auto it = new_call->arg_begin(); it != new_call->arg_end(); it++) {
303 Instruction *replacement;
306 SmallVector<Value *, 2> args_and_likelihood(args);
309 Builder.CreateCall(samplefn->getFunctionType(), samplefn,
310 args_and_likelihood,
"eval." + called->getName());
315#if LLVM_VERSION_MAJOR >= 17
316 auto address = Builder.CreateGlobalString(
317 (call.getName() +
"." + called->getName()).str());
319 auto address = Builder.CreateGlobalStringPtr(
320 (call.getName() +
"." + called->getName()).str());
323 SmallVector<Value *, 2> args_and_trace(args);
325 args_and_trace.push_back(trace);
327 Builder.CreateCall(samplefn->getFunctionType(), samplefn,
328 args_and_trace,
"trace." + called->getName());
335#if LLVM_VERSION_MAJOR >= 17
336 auto address = Builder.CreateGlobalString(
337 (call.getName() +
"." + called->getName()).str());
339 auto address = Builder.CreateGlobalStringPtr(
340 (call.getName() +
"." + called->getName()).str());
343 Instruction *hasCall =
344 tutils->
HasCall(Builder, address,
"has.call." + call.getName());
345 Instruction *ThenTerm, *ElseTerm;
346 Value *ElseTracecall, *ThenTracecall;
347 SplitBlockAndInsertIfThenElse(hasCall, new_call, &ThenTerm, &ElseTerm);
349 new_call->getParent()->setName(hasCall->getParent()->getName() +
".cntd");
351 Builder.SetInsertPoint(ThenTerm);
353 ThenTerm->getParent()->setName(
"condition." + call.getName() +
355 SmallVector<Value *, 2> args_and_cond(args);
357 tutils->
GetTrace(Builder, address, called->getName() +
".subtrace");
359 args_and_cond.push_back(observations);
360 args_and_cond.push_back(trace);
362 Builder.CreateCall(samplefn->getFunctionType(), samplefn,
363 args_and_cond,
"condition." + called->getName());
366 Builder.SetInsertPoint(ElseTerm);
368 ElseTerm->getParent()->setName(
"condition." + call.getName() +
370 SmallVector<Value *, 2> args_and_null(args);
371 auto observations = ConstantPointerNull::get(cast<PointerType>(
374 args_and_null.push_back(observations);
375 args_and_null.push_back(trace);
377 Builder.CreateCall(samplefn->getFunctionType(), samplefn,
378 args_and_null,
"trace." + called->getName());
381 Builder.SetInsertPoint(new_call);
382 auto phi = Builder.CreatePHI(samplefn->getFunctionType()->getReturnType(),
384 phi->addIncoming(ThenTracecall, ThenTerm->getParent());
385 phi->addIncoming(ElseTracecall, ElseTerm->getParent());
393 replacement->takeName(new_call);
394 new_call->replaceAllUsesWith(replacement);
395 new_call->eraseFromParent();
llvm::Function * CreateTrace(RequestContext context, llvm::Function *totrace, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, const llvm::StringSet<> &ActiveRandomVariables, ProbProgMode mode, bool autodiff, TraceInterface *interface)
Create a traced version of a function. The `context` parameter is the instruction that requested this trace (or null if there is none).