Enzyme main
Loading...
Searching...
No Matches
CallDerivatives.cpp
Go to the documentation of this file.
1//===- CallDerivatives.cpp - Implementation of known call derivatives --===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains the implementation of functions in instruction visitor
22// AdjointGenerator that generate corresponding augmented forward pass code,
23// and adjoints for certain known functions.
24//
25//===----------------------------------------------------------------------===//
26
27#include "AdjointGenerator.h"
28
29using namespace llvm;
30
// Externally-visible (C-linkage) hook pointer. When an embedding tool sets
// this to a non-null callback, Enzyme can invoke it to customize/rewrite a
// shadow allocation; it defaults to nullptr, i.e. no rewriting.
// NOTE(review): the exact meaning of the callback parameters (two
// LLVMValueRef values plus a context pointer, a uint64_t, another
// LLVMValueRef, and a uint8_t flag) is not visible from this file — confirm
// against the Enzyme C API headers where the hook is documented/invoked.
extern "C" {
void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t,
                                 LLVMValueRef, uint8_t) = nullptr;
}
35
36void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
37 llvm::StringRef funcName) {
38 using namespace llvm;
39
40 assert(called);
41 assert(gutils->getWidth() == 1);
42
43 IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call));
44 BuilderZ.setFastMathFlags(getFast());
45
46 // MPI send / recv can only send float/integers
47 if (funcName == "PMPI_Isend" || funcName == "MPI_Isend" ||
48 funcName == "PMPI_Irecv" || funcName == "MPI_Irecv") {
49 if (!gutils->isConstantInstruction(&call)) {
52 assert(!gutils->isConstantValue(call.getOperand(0)));
53 assert(!gutils->isConstantValue(call.getOperand(6)));
54 Value *d_req = gutils->invertPointerM(call.getOperand(6), BuilderZ);
55 if (d_req->getType()->isIntegerTy()) {
56 d_req = BuilderZ.CreateIntToPtr(
57 d_req, getUnqual(getInt8PtrTy(call.getContext())));
58 }
59
60 auto i64 = Type::getInt64Ty(call.getContext());
61 auto impi = getMPIHelper(call.getContext());
62
63 Value *impialloc =
64 CreateAllocation(BuilderZ, impi, ConstantInt::get(i64, 1));
65 BuilderZ.SetInsertPoint(gutils->getNewFromOriginal(&call));
66
67 d_req = BuilderZ.CreateBitCast(d_req, getUnqual(impialloc->getType()));
68 Value *d_req_prev = BuilderZ.CreateLoad(impialloc->getType(), d_req);
69 BuilderZ.CreateStore(
70 BuilderZ.CreatePointerCast(d_req_prev,
71 getInt8PtrTy(call.getContext())),
72 getMPIMemberPtr<MPI_Elem::Old>(BuilderZ, impialloc, impi));
73 BuilderZ.CreateStore(impialloc, d_req);
74
75 if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
76 Value *tysize =
77 MPI_TYPE_SIZE(gutils->getNewFromOriginal(call.getOperand(2)),
78 BuilderZ, call.getType(), called);
79
80 auto len_arg = BuilderZ.CreateZExtOrTrunc(
81 gutils->getNewFromOriginal(call.getOperand(1)),
82 Type::getInt64Ty(call.getContext()));
83 len_arg = BuilderZ.CreateMul(
84 len_arg,
85 BuilderZ.CreateZExtOrTrunc(tysize,
86 Type::getInt64Ty(call.getContext())),
87 "", true, true);
88
89 Value *firstallocation =
90 CreateAllocation(BuilderZ, Type::getInt8Ty(call.getContext()),
91 len_arg, "mpirecv_malloccache");
92 BuilderZ.CreateStore(firstallocation, getMPIMemberPtr<MPI_Elem::Buf>(
93 BuilderZ, impialloc, impi));
94 BuilderZ.SetInsertPoint(gutils->getNewFromOriginal(&call));
95 } else {
96 Value *ibuf = gutils->invertPointerM(call.getOperand(0), BuilderZ);
97 if (ibuf->getType()->isIntegerTy())
98 ibuf =
99 BuilderZ.CreateIntToPtr(ibuf, getInt8PtrTy(call.getContext()));
100 BuilderZ.CreateStore(
101 ibuf, getMPIMemberPtr<MPI_Elem::Buf>(BuilderZ, impialloc, impi));
102 }
103
104 BuilderZ.CreateStore(
105 BuilderZ.CreateZExtOrTrunc(
106 gutils->getNewFromOriginal(call.getOperand(1)), i64),
107 getMPIMemberPtr<MPI_Elem::Count>(BuilderZ, impialloc, impi));
108
109 Value *dataType = gutils->getNewFromOriginal(call.getOperand(2));
110 if (dataType->getType()->isIntegerTy())
111 dataType = BuilderZ.CreateIntToPtr(
112 dataType, getInt8PtrTy(dataType->getContext()));
113 BuilderZ.CreateStore(
114 BuilderZ.CreatePointerCast(dataType,
115 getInt8PtrTy(call.getContext())),
116 getMPIMemberPtr<MPI_Elem::DataType>(BuilderZ, impialloc, impi));
117
118 BuilderZ.CreateStore(
119 BuilderZ.CreateZExtOrTrunc(
120 gutils->getNewFromOriginal(call.getOperand(3)), i64),
121 getMPIMemberPtr<MPI_Elem::Src>(BuilderZ, impialloc, impi));
122
123 BuilderZ.CreateStore(
124 BuilderZ.CreateZExtOrTrunc(
125 gutils->getNewFromOriginal(call.getOperand(4)), i64),
126 getMPIMemberPtr<MPI_Elem::Tag>(BuilderZ, impialloc, impi));
127
128 Value *comm = gutils->getNewFromOriginal(call.getOperand(5));
129 if (comm->getType()->isIntegerTy())
130 comm = BuilderZ.CreateIntToPtr(comm,
131 getInt8PtrTy(dataType->getContext()));
132 BuilderZ.CreateStore(
133 BuilderZ.CreatePointerCast(comm, getInt8PtrTy(call.getContext())),
134 getMPIMemberPtr<MPI_Elem::Comm>(BuilderZ, impialloc, impi));
135
136 BuilderZ.CreateStore(
137 ConstantInt::get(
138 Type::getInt8Ty(impialloc->getContext()),
139 (funcName == "MPI_Isend" || funcName == "PMPI_Isend")
141 : (int)MPI_CallType::IRECV),
142 getMPIMemberPtr<MPI_Elem::Call>(BuilderZ, impialloc, impi));
143 // TODO old
144 }
147 IRBuilder<> Builder2(&call);
148 getReverseBuilder(Builder2);
149
150 Type *statusType = nullptr;
151#if LLVM_VERSION_MAJOR < 17
152 if (Function *recvfn = called->getParent()->getFunction(
153 getRenamedPerCallingConv(called->getName(), "MPI_Wait"))) {
154 auto statusArg = recvfn->arg_end();
155 statusArg--;
156 if (auto PT = dyn_cast<PointerType>(statusArg->getType()))
157 statusType = PT->getPointerElementType();
158 }
159#endif
160 if (statusType == nullptr) {
161 statusType = ArrayType::get(Type::getInt8Ty(call.getContext()), 24);
162 llvm::errs() << " warning could not automatically determine mpi "
163 "status type, assuming [24 x i8]\n";
164 }
165 Value *req =
166 lookup(gutils->getNewFromOriginal(call.getOperand(6)), Builder2);
167 Value *d_req = lookup(
168 gutils->invertPointerM(call.getOperand(6), Builder2), Builder2);
169 if (d_req->getType()->isIntegerTy()) {
170 d_req =
171 Builder2.CreateIntToPtr(d_req, getInt8PtrTy(call.getContext()));
172 }
173 auto impi = getMPIHelper(call.getContext());
174 Type *helperTy = getUnqual(impi);
175 Value *helper = Builder2.CreatePointerCast(d_req, getUnqual(helperTy));
176 helper = Builder2.CreateLoad(helperTy, helper);
177
178 auto i64 = Type::getInt64Ty(call.getContext());
179
180 Value *firstallocation;
181 firstallocation = Builder2.CreateLoad(
182 getInt8PtrTy(call.getContext()),
183 getMPIMemberPtr<MPI_Elem::Buf>(Builder2, helper, impi));
184 Value *len_arg = nullptr;
185 if (auto C = dyn_cast<Constant>(
186 gutils->getNewFromOriginal(call.getOperand(1)))) {
187 len_arg = Builder2.CreateZExtOrTrunc(C, i64);
188 } else {
189 len_arg = Builder2.CreateLoad(
190 i64, getMPIMemberPtr<MPI_Elem::Count>(Builder2, helper, impi));
191 }
192 Value *tysize = nullptr;
193 if (auto C = dyn_cast<Constant>(
194 gutils->getNewFromOriginal(call.getOperand(2)))) {
195 tysize = C;
196 } else {
197 tysize = Builder2.CreateLoad(
198 getInt8PtrTy(call.getContext()),
199 getMPIMemberPtr<MPI_Elem::DataType>(Builder2, helper, impi));
200 }
201
202 Value *prev;
203 prev = Builder2.CreateLoad(
204 getInt8PtrTy(call.getContext()),
205 getMPIMemberPtr<MPI_Elem::Old>(Builder2, helper, impi));
206
207 Builder2.CreateStore(prev, Builder2.CreatePointerCast(
208 d_req, getUnqual(prev->getType())));
209
210 assert(shouldFree());
211
212 assert(tysize);
213 tysize = MPI_TYPE_SIZE(tysize, Builder2, call.getType(), called);
214
215 Value *args[] = {/*req*/ req,
216 /*status*/ IRBuilder<>(gutils->inversionAllocs)
217 .CreateAlloca(statusType)};
218 FunctionCallee waitFunc = nullptr;
219 for (auto name : {
220 "MPI_Wait",
221 })
222 if (Function *recvfn = called->getParent()->getFunction(
223 getRenamedPerCallingConv(called->getName(), name))) {
224 auto statusArg = recvfn->arg_end();
225 statusArg--;
226 if (statusArg->getType()->isIntegerTy())
227 args[1] = Builder2.CreatePtrToInt(args[1], statusArg->getType());
228 else
229 args[1] = Builder2.CreateBitCast(args[1], statusArg->getType());
230 waitFunc = recvfn;
231 break;
232 }
233 if (!waitFunc) {
234 Type *types[sizeof(args) / sizeof(*args)];
235 for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
236 types[i] = args[i]->getType();
237 FunctionType *FT = FunctionType::get(call.getType(), types, false);
238 waitFunc = called->getParent()->getOrInsertFunction(
239 getRenamedPerCallingConv(called->getName(), "MPI_Wait"), FT);
240 }
241 assert(waitFunc);
242
243 // Need to preserve the shadow Request (operand 6 in isend/irecv),
244 // which becomes operand 0 for iwait.
245 auto ReqDefs = gutils->getInvertedBundles(
246 &call,
249 Builder2, /*lookup*/ true);
250
251 auto BufferDefs = gutils->getInvertedBundles(
252 &call,
256 Builder2, /*lookup*/ true);
257
258 auto fcall = Builder2.CreateCall(waitFunc, args, ReqDefs);
259 fcall->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
260 if (auto F = dyn_cast<Function>(waitFunc.getCallee()))
261 fcall->setCallingConv(F->getCallingConv());
262 len_arg = Builder2.CreateMul(
263 len_arg,
264 Builder2.CreateZExtOrTrunc(tysize,
265 Type::getInt64Ty(Builder2.getContext())),
266 "", true, true);
267 if (funcName == "MPI_Irecv" || funcName == "PMPI_Irecv") {
268 auto val_arg =
269 ConstantInt::get(Type::getInt8Ty(Builder2.getContext()), 0);
270 auto volatile_arg = ConstantInt::getFalse(Builder2.getContext());
271 assert(!gutils->isConstantValue(call.getOperand(0)));
272 auto dbuf = firstallocation;
273 Value *nargs[] = {dbuf, val_arg, len_arg, volatile_arg};
274 Type *tys[] = {dbuf->getType(), len_arg->getType()};
275
276 auto memset = cast<CallInst>(Builder2.CreateCall(
277 getIntrinsicDeclaration(called->getParent(), Intrinsic::memset,
278 tys),
279 nargs, BufferDefs));
280 memset->addParamAttr(0, Attribute::NonNull);
281 } else if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
282 assert(!gutils->isConstantValue(call.getOperand(0)));
283 Value *shadow = lookup(
284 gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
285
286 // TODO add operand bundle (unless force inlined?)
287 DifferentiableMemCopyFloats(call, call.getOperand(0), firstallocation,
288 shadow, len_arg, Builder2, BufferDefs);
289
290 if (shouldFree()) {
291 CreateDealloc(Builder2, firstallocation);
292 }
293 } else
294 assert(0 && "illegal mpi");
295
296 CreateDealloc(Builder2, helper);
297 }
298 if (Mode == DerivativeMode::ForwardMode ||
300 IRBuilder<> Builder2(&call);
301 getForwardBuilder(Builder2);
302
303 assert(!gutils->isConstantValue(call.getOperand(0)));
304 assert(!gutils->isConstantValue(call.getOperand(6)));
305
306 Value *buf = gutils->invertPointerM(call.getOperand(0), Builder2);
307 Value *count = gutils->getNewFromOriginal(call.getOperand(1));
308 Value *datatype = gutils->getNewFromOriginal(call.getOperand(2));
309 Value *source = gutils->getNewFromOriginal(call.getOperand(3));
310 Value *tag = gutils->getNewFromOriginal(call.getOperand(4));
311 Value *comm = gutils->getNewFromOriginal(call.getOperand(5));
312 Value *request = gutils->invertPointerM(call.getOperand(6), Builder2);
313
314 Value *args[] = {
315 /*buf*/ buf,
316 /*count*/ count,
317 /*datatype*/ datatype,
318 /*source*/ source,
319 /*tag*/ tag,
320 /*comm*/ comm,
321 /*request*/ request,
322 };
323
324 auto Defs = gutils->getInvertedBundles(
325 &call,
329 Builder2, /*lookup*/ false);
330
331 auto callval = call.getCalledOperand();
332
333 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
334 return;
335 }
336 }
338 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
339 return;
340 }
341
342 if (funcName == "MPI_Wait" || funcName == "PMPI_Wait") {
343 Value *d_reqp = nullptr;
344 auto impi = getMPIHelper(call.getContext());
347 Value *req = gutils->getNewFromOriginal(call.getOperand(0));
348 Value *d_req = gutils->invertPointerM(call.getOperand(0), BuilderZ);
349
350 if (req->getType()->isIntegerTy()) {
351 req = BuilderZ.CreateIntToPtr(
352 req, getUnqual(getInt8PtrTy(call.getContext())));
353 }
354
355 Value *isNull = nullptr;
356 if (auto GV = gutils->newFunc->getParent()->getNamedValue(
357 "ompi_request_null")) {
358 Value *reql = BuilderZ.CreatePointerCast(req, getUnqual(GV->getType()));
359 reql = BuilderZ.CreateLoad(GV->getType(), reql);
360 isNull = BuilderZ.CreateICmpEQ(reql, GV);
361 }
362
363 if (d_req->getType()->isIntegerTy()) {
364 d_req = BuilderZ.CreateIntToPtr(
365 d_req, getUnqual(getInt8PtrTy(call.getContext())));
366 }
367
368 d_reqp = BuilderZ.CreateLoad(
369 getUnqual(impi),
370 BuilderZ.CreatePointerCast(d_req, getUnqual(getUnqual(impi))));
371 if (isNull)
372 d_reqp =
373 CreateSelect(BuilderZ, isNull,
374 Constant::getNullValue(d_reqp->getType()), d_reqp);
375 if (auto I = dyn_cast<Instruction>(d_reqp))
376 gutils->TapesToPreventRecomputation.insert(I);
377 d_reqp = gutils->cacheForReverse(
378 BuilderZ, d_reqp, getIndex(&call, CacheType::Tape, BuilderZ));
379 }
382 IRBuilder<> Builder2(&call);
383 getReverseBuilder(Builder2);
384
385 assert(!gutils->isConstantValue(call.getOperand(0)));
386 Value *req =
387 lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2);
388
390 d_reqp = BuilderZ.CreatePHI(getUnqual(impi), 0);
391 d_reqp = gutils->cacheForReverse(
392 BuilderZ, d_reqp, getIndex(&call, CacheType::Tape, BuilderZ));
393 } else
394 assert(d_reqp);
395 d_reqp = lookup(d_reqp, Builder2);
396
397 Value *isNull = Builder2.CreateICmpEQ(
398 d_reqp, Constant::getNullValue(d_reqp->getType()));
399
400 BasicBlock *currentBlock = Builder2.GetInsertBlock();
401 BasicBlock *nonnullBlock = gutils->addReverseBlock(
402 currentBlock, currentBlock->getName() + "_nonnull");
403 BasicBlock *endBlock = gutils->addReverseBlock(
404 nonnullBlock, currentBlock->getName() + "_end",
405 /*fork*/ true, /*push*/ false);
406
407 Builder2.CreateCondBr(isNull, endBlock, nonnullBlock);
408 Builder2.SetInsertPoint(nonnullBlock);
409
410 Value *cache = Builder2.CreateLoad(impi, d_reqp);
411
412 Value *args[] = {
413 getMPIMemberPtr<MPI_Elem::Buf, false>(Builder2, cache, impi),
414 getMPIMemberPtr<MPI_Elem::Count, false>(Builder2, cache, impi),
415 getMPIMemberPtr<MPI_Elem::DataType, false>(Builder2, cache, impi),
416 getMPIMemberPtr<MPI_Elem::Src, false>(Builder2, cache, impi),
417 getMPIMemberPtr<MPI_Elem::Tag, false>(Builder2, cache, impi),
418 getMPIMemberPtr<MPI_Elem::Comm, false>(Builder2, cache, impi),
419 getMPIMemberPtr<MPI_Elem::Call, false>(Builder2, cache, impi),
420 req};
421 Type *types[sizeof(args) / sizeof(*args) - 1];
422 for (size_t i = 0; i < sizeof(args) / sizeof(*args) - 1; i++)
423 types[i] = args[i]->getType();
424 Function *dwait = getOrInsertDifferentialMPI_Wait(
425 *called->getParent(), types, call.getOperand(0)->getType(),
426 called->getName());
427
428 // Need to preserve the shadow Request (operand 0 in wait).
429 // However, this doesn't end up preserving
430 // the underlying buffers for the adjoint. To rememdy, force inline.
431 auto cal =
432 Builder2.CreateCall(dwait, args,
433 gutils->getInvertedBundles(
434 &call, {ValueType::Shadow, ValueType::None},
435 Builder2, /*lookup*/ true));
436 cal->setCallingConv(dwait->getCallingConv());
437 cal->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
438 cal->addFnAttr(Attribute::AlwaysInline);
439 Builder2.CreateBr(endBlock);
440 {
441 auto found = gutils->reverseBlockToPrimal.find(endBlock);
442 assert(found != gutils->reverseBlockToPrimal.end());
443 SmallVector<BasicBlock *, 4> &vec =
444 gutils->reverseBlocks[found->second];
445 assert(vec.size());
446 vec.push_back(endBlock);
447 }
448 Builder2.SetInsertPoint(endBlock);
449 } else if (Mode == DerivativeMode::ForwardMode ||
451 IRBuilder<> Builder2(&call);
452 getForwardBuilder(Builder2);
453
454 assert(!gutils->isConstantValue(call.getOperand(0)));
455
456 Value *request = gutils->invertPointerM(call.getArgOperand(0), Builder2);
457 Value *status = gutils->invertPointerM(call.getArgOperand(1), Builder2);
458
459 Value *args[] = {/*request*/ request,
460 /*status*/ status};
461
462 auto Defs = gutils->getInvertedBundles(
463 &call, {ValueType::Shadow, ValueType::Shadow}, Builder2,
464 /*lookup*/ false);
465
466 auto callval = call.getCalledOperand();
467
468 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
469 return;
470 }
472 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
473 return;
474 }
475
476 if (funcName == "MPI_Waitall" || funcName == "PMPI_Waitall") {
477 Value *d_reqp = nullptr;
478 auto impi = getMPIHelper(call.getContext());
479 PointerType *reqType = getUnqual(impi);
482 Value *count = gutils->getNewFromOriginal(call.getOperand(0));
483 Value *req = gutils->getNewFromOriginal(call.getOperand(1));
484 Value *d_req = gutils->invertPointerM(call.getOperand(1), BuilderZ);
485
486 if (req->getType()->isIntegerTy()) {
487 req = BuilderZ.CreateIntToPtr(
488 req, getUnqual(getInt8PtrTy(call.getContext())));
489 }
490
491 if (d_req->getType()->isIntegerTy()) {
492 d_req = BuilderZ.CreateIntToPtr(
493 d_req, getUnqual(getInt8PtrTy(call.getContext())));
494 }
495
496 Function *dsave = getOrInsertDifferentialWaitallSave(
497 *gutils->oldFunc->getParent(),
498 {count->getType(), req->getType(), d_req->getType()}, reqType);
499
500 d_reqp = BuilderZ.CreateCall(dsave, {count, req, d_req});
501 cast<CallInst>(d_reqp)->setCallingConv(dsave->getCallingConv());
502 cast<CallInst>(d_reqp)->setDebugLoc(
503 gutils->getNewFromOriginal(call.getDebugLoc()));
504 d_reqp = gutils->cacheForReverse(
505 BuilderZ, d_reqp, getIndex(&call, CacheType::Tape, BuilderZ));
506 }
509 IRBuilder<> Builder2(&call);
510 getReverseBuilder(Builder2);
511
512 assert(!gutils->isConstantValue(call.getOperand(1)));
513 Value *count =
514 lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2);
515 Value *req_orig =
516 lookup(gutils->getNewFromOriginal(call.getOperand(1)), Builder2);
517
519 d_reqp = BuilderZ.CreatePHI(getUnqual(reqType), 0);
520 d_reqp = gutils->cacheForReverse(
521 BuilderZ, d_reqp, getIndex(&call, CacheType::Tape, BuilderZ));
522 }
523
524 d_reqp = lookup(d_reqp, Builder2);
525
526 BasicBlock *currentBlock = Builder2.GetInsertBlock();
527 BasicBlock *loopBlock = gutils->addReverseBlock(
528 currentBlock, currentBlock->getName() + "_loop");
529 BasicBlock *nonnullBlock = gutils->addReverseBlock(
530 loopBlock, currentBlock->getName() + "_nonnull");
531 BasicBlock *eloopBlock = gutils->addReverseBlock(
532 nonnullBlock, currentBlock->getName() + "_eloop");
533 BasicBlock *endBlock =
534 gutils->addReverseBlock(eloopBlock, currentBlock->getName() + "_end",
535 /*fork*/ true, /*push*/ false);
536
537 Builder2.CreateCondBr(
538 Builder2.CreateICmpNE(count,
539 ConstantInt::get(count->getType(), 0, false)),
540 loopBlock, endBlock);
541
542 Builder2.SetInsertPoint(loopBlock);
543 auto idx = Builder2.CreatePHI(count->getType(), 2);
544 idx->addIncoming(ConstantInt::get(count->getType(), 0, false),
545 currentBlock);
546 Value *inc = Builder2.CreateAdd(
547 idx, ConstantInt::get(count->getType(), 1, false), "", true, true);
548 idx->addIncoming(inc, eloopBlock);
549
550 Value *idxs[] = {idx};
551 Value *req = Builder2.CreateInBoundsGEP(reqType, req_orig, idxs);
552 Value *d_req = Builder2.CreateInBoundsGEP(reqType, d_reqp, idxs);
553
554 d_req = Builder2.CreateLoad(
555 getUnqual(impi),
556 Builder2.CreatePointerCast(d_req, getUnqual(getUnqual(impi))));
557
558 Value *isNull = Builder2.CreateICmpEQ(
559 d_req, Constant::getNullValue(d_req->getType()));
560
561 Builder2.CreateCondBr(isNull, eloopBlock, nonnullBlock);
562 Builder2.SetInsertPoint(nonnullBlock);
563
564 Value *cache = Builder2.CreateLoad(impi, d_req);
565
566 Value *args[] = {
567 getMPIMemberPtr<MPI_Elem::Buf, false>(Builder2, cache, impi),
568 getMPIMemberPtr<MPI_Elem::Count, false>(Builder2, cache, impi),
569 getMPIMemberPtr<MPI_Elem::DataType, false>(Builder2, cache, impi),
570 getMPIMemberPtr<MPI_Elem::Src, false>(Builder2, cache, impi),
571 getMPIMemberPtr<MPI_Elem::Tag, false>(Builder2, cache, impi),
572 getMPIMemberPtr<MPI_Elem::Comm, false>(Builder2, cache, impi),
573 getMPIMemberPtr<MPI_Elem::Call, false>(Builder2, cache, impi),
574 req};
575 Type *types[sizeof(args) / sizeof(*args) - 1];
576 for (size_t i = 0; i < sizeof(args) / sizeof(*args) - 1; i++)
577 types[i] = args[i]->getType();
578 Function *dwait = getOrInsertDifferentialMPI_Wait(
579 *called->getParent(), types, req->getType(), called->getName());
580 // Need to preserve the shadow Request (operand 6 in isend/irecv), which
581 // becomes operand 0 for iwait. However, this doesn't end up preserving
582 // the underlying buffers for the adjoint. To remedy, force inline the
583 // function.
584 auto cal = Builder2.CreateCall(
585 dwait, args,
586 gutils->getInvertedBundles(&call,
587 {ValueType::None, ValueType::None,
588 ValueType::None, ValueType::None,
589 ValueType::None, ValueType::None,
590 ValueType::Shadow},
591 Builder2, /*lookup*/ true));
592 cal->setCallingConv(dwait->getCallingConv());
593 cal->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
594 cal->addFnAttr(Attribute::AlwaysInline);
595 Builder2.CreateBr(eloopBlock);
596
597 Builder2.SetInsertPoint(eloopBlock);
598 Builder2.CreateCondBr(Builder2.CreateICmpEQ(inc, count), endBlock,
599 loopBlock);
600 {
601 auto found = gutils->reverseBlockToPrimal.find(endBlock);
602 assert(found != gutils->reverseBlockToPrimal.end());
603 SmallVector<BasicBlock *, 4> &vec =
604 gutils->reverseBlocks[found->second];
605 assert(vec.size());
606 vec.push_back(endBlock);
607 }
608 Builder2.SetInsertPoint(endBlock);
609 if (shouldFree()) {
610 CreateDealloc(Builder2, d_reqp);
611 }
612 } else if (Mode == DerivativeMode::ForwardMode ||
614 IRBuilder<> Builder2(&call);
615
616 assert(!gutils->isConstantValue(call.getOperand(1)));
617
618 Value *count = gutils->getNewFromOriginal(call.getOperand(0));
619 Value *array_of_requests =
620 gutils->invertPointerM(call.getOperand(1), Builder2);
621 if (array_of_requests->getType()->isIntegerTy()) {
622 array_of_requests = Builder2.CreateIntToPtr(
623 array_of_requests, getUnqual(getInt8PtrTy(call.getContext())));
624 }
625
626 Value *args[] = {
627 /*count*/ count,
628 /*array_of_requests*/ array_of_requests,
629 };
630
631 auto Defs = gutils->getInvertedBundles(
632 &call,
635 Builder2, /*lookup*/ false);
636
637 auto callval = call.getCalledOperand();
638
639 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
640 return;
641 }
643 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
644 return;
645 }
646
647 if (funcName == "MPI_Send" || funcName == "MPI_Ssend" ||
648 funcName == "PMPI_Send" || funcName == "PMPI_Ssend") {
653 bool forwardMode = Mode == DerivativeMode::ForwardMode ||
655
656 IRBuilder<> Builder2 =
657 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
658 if (forwardMode) {
659 getForwardBuilder(Builder2);
660 } else {
661 getReverseBuilder(Builder2);
662 }
663
664 Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
665 if (!forwardMode)
666 shadow = lookup(shadow, Builder2);
667 Value *shadowOrig = shadow;
668 if (shadow->getType()->isIntegerTy())
669 shadow =
670 Builder2.CreateIntToPtr(shadow, getInt8PtrTy(call.getContext()));
671
672 Type *statusType = nullptr;
673#if LLVM_VERSION_MAJOR < 17
674 if (called->getContext().supportsTypedPointers()) {
675 if (Function *recvfn = called->getParent()->getFunction(
676 getRenamedPerCallingConv(called->getName(), "MPI_Recv"))) {
677 auto statusArg = recvfn->arg_end();
678 statusArg--;
679 if (auto PT = dyn_cast<PointerType>(statusArg->getType()))
680 statusType = PT->getPointerElementType();
681 }
682 }
683#endif
684 if (statusType == nullptr) {
685 statusType = ArrayType::get(Type::getInt8Ty(call.getContext()), 24);
686 llvm::errs() << " warning could not automatically determine mpi "
687 "status type, assuming [24 x i8]\n";
688 }
689
690 Value *count = gutils->getNewFromOriginal(call.getOperand(1));
691 if (!forwardMode)
692 count = lookup(count, Builder2);
693
694 Value *datatype = gutils->getNewFromOriginal(call.getOperand(2));
695 if (!forwardMode)
696 datatype = lookup(datatype, Builder2);
697
698 Value *src = gutils->getNewFromOriginal(call.getOperand(3));
699 if (!forwardMode)
700 src = lookup(src, Builder2);
701
702 Value *tag = gutils->getNewFromOriginal(call.getOperand(4));
703 if (!forwardMode)
704 tag = lookup(tag, Builder2);
705
706 Value *comm = gutils->getNewFromOriginal(call.getOperand(5));
707 if (!forwardMode)
708 comm = lookup(comm, Builder2);
709
710 if (forwardMode) {
711 Value *args[] = {
712 /*buf*/ shadowOrig,
713 /*count*/ count,
714 /*datatype*/ datatype,
715 /*dest*/ src,
716 /*tag*/ tag,
717 /*comm*/ comm,
718 };
719
720 auto Defs = gutils->getInvertedBundles(
721 &call,
724 Builder2, /*lookup*/ false);
725
726 auto callval = call.getCalledOperand();
727 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
728 return;
729 }
730
731 Value *args[] = {
732 /*buf*/ NULL,
733 /*count*/ count,
734 /*datatype*/ datatype,
735 /*src*/ src,
736 /*tag*/ tag,
737 /*comm*/ comm,
738 /*status*/
739 IRBuilder<>(gutils->inversionAllocs).CreateAlloca(statusType)};
740
741 Value *tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
742
743 auto len_arg = Builder2.CreateZExtOrTrunc(
744 args[1], Type::getInt64Ty(call.getContext()));
745 len_arg =
746 Builder2.CreateMul(len_arg,
747 Builder2.CreateZExtOrTrunc(
748 tysize, Type::getInt64Ty(call.getContext())),
749 "", true, true);
750
751 Value *firstallocation =
752 CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()),
753 len_arg, "mpirecv_malloccache");
754 args[0] = firstallocation;
755
756 Type *types[sizeof(args) / sizeof(*args)];
757 for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
758 types[i] = args[i]->getType();
759 FunctionType *FT = FunctionType::get(call.getType(), types, false);
760
761 Builder2.SetInsertPoint(Builder2.GetInsertBlock());
762
763 auto BufferDefs = gutils->getInvertedBundles(
764 &call,
767 Builder2, /*lookup*/ true);
768
769 auto fcall = Builder2.CreateCall(
770 called->getParent()->getOrInsertFunction(
771 getRenamedPerCallingConv(called->getName(), "MPI_Recv"), FT),
772 args);
773 fcall->setCallingConv(call.getCallingConv());
774
775 DifferentiableMemCopyFloats(call, call.getOperand(0), firstallocation,
776 shadow, len_arg, Builder2, BufferDefs);
777
778 if (shouldFree()) {
779 CreateDealloc(Builder2, firstallocation);
780 }
781 }
783 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
784 return;
785 }
786
787 if (funcName == "MPI_Recv" || funcName == "PMPI_Recv") {
792 bool forwardMode = Mode == DerivativeMode::ForwardMode ||
794
795 IRBuilder<> Builder2 =
796 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
797 if (forwardMode) {
798 getForwardBuilder(Builder2);
799 } else {
800 getReverseBuilder(Builder2);
801 }
802
803 Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
804 if (!forwardMode)
805 shadow = lookup(shadow, Builder2);
806
807 Value *count = gutils->getNewFromOriginal(call.getOperand(1));
808 if (!forwardMode)
809 count = lookup(count, Builder2);
810
811 Value *datatype = gutils->getNewFromOriginal(call.getOperand(2));
812 if (!forwardMode)
813 datatype = lookup(datatype, Builder2);
814
815 Value *source = gutils->getNewFromOriginal(call.getOperand(3));
816 if (!forwardMode)
817 source = lookup(source, Builder2);
818
819 Value *tag = gutils->getNewFromOriginal(call.getOperand(4));
820 if (!forwardMode)
821 tag = lookup(tag, Builder2);
822
823 Value *comm = gutils->getNewFromOriginal(call.getOperand(5));
824 if (!forwardMode)
825 comm = lookup(comm, Builder2);
826
827 if (forwardMode) {
828 Value *status = gutils->getNewFromOriginal(call.getOperand(6));
829 Value *args[] = {shadow, count, datatype, source, tag, comm, status};
830
831 auto Defs = gutils->getInvertedBundles(
832 &call,
836 Builder2, /*lookup*/ !forwardMode);
837
838 auto callval = call.getCalledOperand();
839
840 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
841 return;
842 }
843
844 Value *args[] = {shadow, count, datatype, source, tag, comm};
845
846 auto Defs = gutils->getInvertedBundles(
847 &call,
851 Builder2, /*lookup*/ !forwardMode);
852
853 Type *types[sizeof(args) / sizeof(*args)];
854 for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
855 types[i] = args[i]->getType();
856 FunctionType *FT = FunctionType::get(call.getType(), types, false);
857
858 auto fcall = Builder2.CreateCall(
859 called->getParent()->getOrInsertFunction(
860 getRenamedPerCallingConv(called->getName(), "MPI_Send"), FT),
861 args, Defs);
862 fcall->setCallingConv(call.getCallingConv());
863
864 auto dst_arg =
865 Builder2.CreateBitCast(args[0], getInt8PtrTy(call.getContext()));
866 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
867 auto len_arg = Builder2.CreateZExtOrTrunc(
868 args[1], Type::getInt64Ty(call.getContext()));
869 auto tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
870 len_arg =
871 Builder2.CreateMul(len_arg,
872 Builder2.CreateZExtOrTrunc(
873 tysize, Type::getInt64Ty(call.getContext())),
874 "", true, true);
875 auto volatile_arg = ConstantInt::getFalse(call.getContext());
876
877 Value *nargs[] = {dst_arg, val_arg, len_arg, volatile_arg};
878 Type *tys[] = {dst_arg->getType(), len_arg->getType()};
879
880 auto MemsetDefs = gutils->getInvertedBundles(
881 &call,
884 Builder2, /*lookup*/ true);
885 auto memset = cast<CallInst>(Builder2.CreateCall(
886 getIntrinsicDeclaration(gutils->newFunc->getParent(),
887 Intrinsic::memset, tys),
888 nargs));
889 memset->addParamAttr(0, Attribute::NonNull);
890 }
892 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
893 return;
894 }
895
896 // int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int root,
897 // MPI_Comm comm )
898 // 1. if root, malloc intermediate buffer
899 // 2. reduce sum diff(buffer) into intermediate
900 // 3. if root, set shadow(buffer) = intermediate [memcpy] then free
901 // 3-e. else, set shadow(buffer) = 0 [memset]
902 if (funcName == "MPI_Bcast" || funcName == "PMPI_Bcast") {
907 bool forwardMode = Mode == DerivativeMode::ForwardMode ||
909
910 IRBuilder<> Builder2 =
911 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
912 if (forwardMode) {
913 getForwardBuilder(Builder2);
914 } else {
915 getReverseBuilder(Builder2);
916 }
917
918 Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
919 if (!forwardMode)
920 shadow = lookup(shadow, Builder2);
921 if (shadow->getType()->isIntegerTy())
922 shadow =
923 Builder2.CreateIntToPtr(shadow, getInt8PtrTy(call.getContext()));
924
925 ConcreteType CT = TR.firstPointer(1, call.getOperand(0), &call);
926 auto MPI_OP_type = getInt8PtrTy(call.getContext());
927 Type *MPI_OP_Ptr_type = getUnqual(MPI_OP_type);
928
929 Value *count = gutils->getNewFromOriginal(call.getOperand(1));
930 if (!forwardMode)
931 count = lookup(count, Builder2);
932 Value *datatype = gutils->getNewFromOriginal(call.getOperand(2));
933 if (!forwardMode)
934 datatype = lookup(datatype, Builder2);
935 Value *root = gutils->getNewFromOriginal(call.getOperand(3));
936 if (!forwardMode)
937 root = lookup(root, Builder2);
938
939 Value *comm = gutils->getNewFromOriginal(call.getOperand(4));
940 if (!forwardMode)
941 comm = lookup(comm, Builder2);
942
943 if (forwardMode) {
944 Value *args[] = {
945 /*buffer*/ shadow,
946 /*count*/ count,
947 /*datatype*/ datatype,
948 /*root*/ root,
949 /*comm*/ comm,
950 };
951
952 auto Defs = gutils->getInvertedBundles(
953 &call,
956 Builder2, /*lookup*/ false);
957
958 auto callval = call.getCalledOperand();
959 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
960 return;
961 }
962
963 Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType(), called);
964 Value *tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
965
966 auto len_arg = Builder2.CreateZExtOrTrunc(
967 count, Type::getInt64Ty(call.getContext()));
968 len_arg =
969 Builder2.CreateMul(len_arg,
970 Builder2.CreateZExtOrTrunc(
971 tysize, Type::getInt64Ty(call.getContext())),
972 "", true, true);
973
974 // 1. if root, malloc intermediate buffer, else undef
975 PHINode *buf;
976
977 {
978 BasicBlock *currentBlock = Builder2.GetInsertBlock();
979 BasicBlock *rootBlock = gutils->addReverseBlock(
980 currentBlock, currentBlock->getName() + "_root", gutils->newFunc);
981 BasicBlock *mergeBlock = gutils->addReverseBlock(
982 rootBlock, currentBlock->getName() + "_post", gutils->newFunc);
983
984 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
985 mergeBlock);
986
987 Builder2.SetInsertPoint(rootBlock);
988
989 Value *rootbuf =
990 CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()),
991 len_arg, "mpireduce_malloccache");
992 Builder2.CreateBr(mergeBlock);
993
994 Builder2.SetInsertPoint(mergeBlock);
995
996 buf = Builder2.CreatePHI(rootbuf->getType(), 2);
997 buf->addIncoming(rootbuf, rootBlock);
998 buf->addIncoming(UndefValue::get(buf->getType()), currentBlock);
999 }
1000
1001 // Need to preserve the shadow buffer.
1002 auto BufferDefs = gutils->getInvertedBundles(
1003 &call,
1006 Builder2, /*lookup*/ true);
1007
1008 // 2. reduce sum diff(buffer) into intermediate
1009 {
1010 // int MPI_Reduce(const void *sendbuf, void *recvbuf, int count,
1011 // MPI_Datatype datatype,
1012 // MPI_Op op, int root, MPI_Comm comm)
1013 Value *args[] = {
1014 /*sendbuf*/ shadow,
1015 /*recvbuf*/ buf,
1016 /*count*/ count,
1017 /*datatype*/ datatype,
1018 /*op (MPI_SUM)*/
1019 getOrInsertOpFloatSum(*gutils->newFunc->getParent(),
1020 MPI_OP_Ptr_type, MPI_OP_type, CT,
1021 root->getType(), Builder2),
1022 /*int root*/ root,
1023 /*comm*/ comm,
1024 };
1025 Type *types[sizeof(args) / sizeof(*args)];
1026 for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
1027 types[i] = args[i]->getType();
1028
1029 FunctionType *FT = FunctionType::get(call.getType(), types, false);
1030
1031 Builder2.CreateCall(
1032 called->getParent()->getOrInsertFunction(
1033 getRenamedPerCallingConv(called->getName(), "MPI_Reduce"), FT),
1034 args, BufferDefs);
1035 }
1036
1037 // 3. if root, set shadow(buffer) = intermediate [memcpy]
1038 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1039 BasicBlock *rootBlock = gutils->addReverseBlock(
1040 currentBlock, currentBlock->getName() + "_root", gutils->newFunc);
1041 BasicBlock *nonrootBlock = gutils->addReverseBlock(
1042 rootBlock, currentBlock->getName() + "_nonroot", gutils->newFunc);
1043 BasicBlock *mergeBlock = gutils->addReverseBlock(
1044 nonrootBlock, currentBlock->getName() + "_post", gutils->newFunc);
1045
1046 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1047 nonrootBlock);
1048
1049 Builder2.SetInsertPoint(rootBlock);
1050
1051 {
1052 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1053 Value *nargs[] = {shadow, buf, len_arg, volatile_arg};
1054
1055 Type *tys[] = {shadow->getType(), buf->getType(), len_arg->getType()};
1056
1057 auto memcpyF = getIntrinsicDeclaration(gutils->newFunc->getParent(),
1058 Intrinsic::memcpy, tys);
1059
1060 auto mem =
1061 cast<CallInst>(Builder2.CreateCall(memcpyF, nargs, BufferDefs));
1062 mem->setCallingConv(memcpyF->getCallingConv());
1063
1064 // Free up the memory of the buffer
1065 if (shouldFree()) {
1066 CreateDealloc(Builder2, buf);
1067 }
1068 }
1069
1070 Builder2.CreateBr(mergeBlock);
1071
1072 Builder2.SetInsertPoint(nonrootBlock);
1073
1074 // 3-e. else, set shadow(buffer) = 0 [memset]
1075 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1076 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1077 Value *args[] = {shadow, val_arg, len_arg, volatile_arg};
1078 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1079 auto memset = cast<CallInst>(Builder2.CreateCall(
1080 getIntrinsicDeclaration(gutils->newFunc->getParent(),
1081 Intrinsic::memset, tys),
1082 args, BufferDefs));
1083 memset->addParamAttr(0, Attribute::NonNull);
1084 Builder2.CreateBr(mergeBlock);
1085
1086 Builder2.SetInsertPoint(mergeBlock);
1087 }
1089 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
1090 return;
1091 }
1092
1093 // Approximate algo (for sum): -> if statement yet to be
1094 // 1. malloc intermediate buffer
1095 // 1.5 if root, set intermediate = diff(recvbuffer)
1096 // 2. MPI_Bcast intermediate to all
1097 // 3. if root, Zero diff(recvbuffer) [memset to 0]
1098 // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy)
1099 // 5. free intermediate buffer
1100
1101 // int MPI_Reduce(const void *sendbuf, void *recvbuf, int count,
1102 // MPI_Datatype datatype,
1103 // MPI_Op op, int root, MPI_Comm comm)
1104
1105 if (funcName == "MPI_Reduce" || funcName == "PMPI_Reduce") {
1110 // TODO insert a check for sum
1111
1112 bool forwardMode = Mode == DerivativeMode::ForwardMode ||
1114
1115 IRBuilder<> Builder2 =
1116 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1117 if (forwardMode) {
1118 getForwardBuilder(Builder2);
1119 } else {
1120 getReverseBuilder(Builder2);
1121 }
1122
1123 // Get the operations from MPI_Receive
1124 Value *orig_sendbuf = call.getOperand(0);
1125 Value *orig_recvbuf = call.getOperand(1);
1126 Value *orig_count = call.getOperand(2);
1127 Value *orig_datatype = call.getOperand(3);
1128 Value *orig_op = call.getOperand(4);
1129 Value *orig_root = call.getOperand(5);
1130 Value *orig_comm = call.getOperand(6);
1131
1132 bool isSum = false;
1133 if (Constant *C = dyn_cast<Constant>(orig_op)) {
1134 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
1135 C = CE->getOperand(0);
1136 }
1137 if (auto GV = dyn_cast<GlobalVariable>(C)) {
1138 if (GV->getName() == "ompi_mpi_op_sum") {
1139 isSum = true;
1140 }
1141 }
1142 // MPICH
1143 if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
1144 if (CI->getValue() == 1476395011) {
1145 isSum = true;
1146 }
1147 }
1148 }
1149 if (!isSum) {
1150 std::string s;
1151 llvm::raw_string_ostream ss(s);
1152 ss << " call: " << call << "\n";
1153 ss << " unhandled mpi_reduce op: " << *orig_op << "\n";
1154 EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ);
1155 return;
1156 }
1157
1158 Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
1159 if (!forwardMode)
1160 shadow_recvbuf = lookup(shadow_recvbuf, Builder2);
1161 if (shadow_recvbuf->getType()->isIntegerTy())
1162 shadow_recvbuf = Builder2.CreateIntToPtr(
1163 shadow_recvbuf, getInt8PtrTy(call.getContext()));
1164
1165 Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
1166 if (!forwardMode)
1167 shadow_sendbuf = lookup(shadow_sendbuf, Builder2);
1168 if (shadow_sendbuf->getType()->isIntegerTy())
1169 shadow_sendbuf = Builder2.CreateIntToPtr(
1170 shadow_sendbuf, getInt8PtrTy(call.getContext()));
1171
1172 // Need to preserve the shadow send/recv buffers.
1173 auto BufferDefs = gutils->getInvertedBundles(
1174 &call,
1178 Builder2, /*lookup*/ !forwardMode);
1179
1180 Value *count = gutils->getNewFromOriginal(orig_count);
1181 if (!forwardMode)
1182 count = lookup(count, Builder2);
1183
1184 Value *datatype = gutils->getNewFromOriginal(orig_datatype);
1185 if (!forwardMode)
1186 datatype = lookup(datatype, Builder2);
1187
1188 Value *op = gutils->getNewFromOriginal(orig_op);
1189 if (!forwardMode)
1190 op = lookup(op, Builder2);
1191
1192 Value *root = gutils->getNewFromOriginal(orig_root);
1193 if (!forwardMode)
1194 root = lookup(root, Builder2);
1195
1196 Value *comm = gutils->getNewFromOriginal(orig_comm);
1197 if (!forwardMode)
1198 comm = lookup(comm, Builder2);
1199
1200 Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType(), called);
1201
1202 if (forwardMode) {
1203 Value *args[] = {
1204 /*sendbuf*/ shadow_sendbuf,
1205 /*recvbuf*/ shadow_recvbuf,
1206 /*count*/ count,
1207 /*datatype*/ datatype,
1208 /*op*/ op,
1209 /*root*/ root,
1210 /*comm*/ comm,
1211 };
1212
1213 auto Defs = gutils->getInvertedBundles(
1214 &call,
1218 Builder2, /*lookup*/ false);
1219
1220 auto callval = call.getCalledOperand();
1221 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
1222 return;
1223 }
1224
1225 Value *tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
1226
1227 // Get the length for the allocation of the intermediate buffer
1228 auto len_arg = Builder2.CreateZExtOrTrunc(
1229 count, Type::getInt64Ty(call.getContext()));
1230 len_arg =
1231 Builder2.CreateMul(len_arg,
1232 Builder2.CreateZExtOrTrunc(
1233 tysize, Type::getInt64Ty(call.getContext())),
1234 "", true, true);
1235
1236 // 1. Alloc intermediate buffer
1237 Value *buf =
1238 CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()),
1239 len_arg, "mpireduce_malloccache");
1240
1241 // 1.5 if root, set intermediate = diff(recvbuffer)
1242 {
1243
1244 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1245 BasicBlock *rootBlock = gutils->addReverseBlock(
1246 currentBlock, currentBlock->getName() + "_root", gutils->newFunc);
1247 BasicBlock *mergeBlock = gutils->addReverseBlock(
1248 rootBlock, currentBlock->getName() + "_post", gutils->newFunc);
1249
1250 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1251 mergeBlock);
1252
1253 Builder2.SetInsertPoint(rootBlock);
1254
1255 {
1256 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1257 Value *nargs[] = {buf, shadow_recvbuf, len_arg, volatile_arg};
1258
1259 Type *tys[] = {nargs[0]->getType(), nargs[1]->getType(),
1260 len_arg->getType()};
1261
1262 auto memcpyF = getIntrinsicDeclaration(gutils->newFunc->getParent(),
1263 Intrinsic::memcpy, tys);
1264
1265 auto mem =
1266 cast<CallInst>(Builder2.CreateCall(memcpyF, nargs, BufferDefs));
1267 mem->setCallingConv(memcpyF->getCallingConv());
1268 }
1269
1270 Builder2.CreateBr(mergeBlock);
1271 Builder2.SetInsertPoint(mergeBlock);
1272 }
1273
1274 // 2. MPI_Bcast intermediate to all
1275 {
1276 // int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int
1277 // root,
1278 // MPI_Comm comm )
1279 Value *args[] = {
1280 /*buf*/ buf,
1281 /*count*/ count,
1282 /*datatype*/ datatype,
1283 /*int root*/ root,
1284 /*comm*/ comm,
1285 };
1286 Type *types[sizeof(args) / sizeof(*args)];
1287 for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
1288 types[i] = args[i]->getType();
1289
1290 FunctionType *FT = FunctionType::get(call.getType(), types, false);
1291 Builder2.CreateCall(
1292 called->getParent()->getOrInsertFunction(
1293 getRenamedPerCallingConv(called->getName(), "MPI_Bcast"), FT),
1294 args, BufferDefs);
1295 }
1296
1297 // 3. if root, Zero diff(recvbuffer) [memset to 0]
1298 {
1299 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1300 BasicBlock *rootBlock = gutils->addReverseBlock(
1301 currentBlock, currentBlock->getName() + "_root", gutils->newFunc);
1302 BasicBlock *mergeBlock = gutils->addReverseBlock(
1303 rootBlock, currentBlock->getName() + "_post", gutils->newFunc);
1304
1305 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1306 mergeBlock);
1307
1308 Builder2.SetInsertPoint(rootBlock);
1309
1310 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1311 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1312 Value *args[] = {shadow_recvbuf, val_arg, len_arg, volatile_arg};
1313 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1314 auto memset = cast<CallInst>(Builder2.CreateCall(
1315 getIntrinsicDeclaration(gutils->newFunc->getParent(),
1316 Intrinsic::memset, tys),
1317 args, BufferDefs));
1318 memset->addParamAttr(0, Attribute::NonNull);
1319
1320 Builder2.CreateBr(mergeBlock);
1321 Builder2.SetInsertPoint(mergeBlock);
1322 }
1323
1324 // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy)
1325 DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf,
1326 len_arg, Builder2, BufferDefs);
1327
1328 // Free up intermediate buffer
1329 if (shouldFree()) {
1330 CreateDealloc(Builder2, buf);
1331 }
1332 }
1334 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
1335 return;
1336 }
1337
1338 // Approximate algo (for sum): -> if statement yet to be
1339 // 1. malloc intermediate buffers
1340 // 2. MPI_Allreduce (sum) of diff(recvbuffer) to intermediate
1341 // 3. Zero diff(recvbuffer) [memset to 0]
1342 // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy)
1343 // 5. free intermediate buffer
1344
1345 // int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
1346 // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
1347
1348 if (funcName == "MPI_Allreduce" || funcName == "PMPI_Allreduce") {
1353 // TODO insert a check for sum
1354
1355 bool forwardMode = Mode == DerivativeMode::ForwardMode ||
1357
1358 IRBuilder<> Builder2 =
1359 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1360 if (forwardMode) {
1361 getForwardBuilder(Builder2);
1362 } else {
1363 getReverseBuilder(Builder2);
1364 }
1365
1366 // Get the operations from MPI_Receive
1367 Value *orig_sendbuf = call.getOperand(0);
1368 Value *orig_recvbuf = call.getOperand(1);
1369 Value *orig_count = call.getOperand(2);
1370 Value *orig_datatype = call.getOperand(3);
1371 Value *orig_op = call.getOperand(4);
1372 Value *orig_comm = call.getOperand(5);
1373
1374 bool isSum = false;
1375 if (Constant *C = dyn_cast<Constant>(orig_op)) {
1376 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
1377 C = CE->getOperand(0);
1378 }
1379 if (auto GV = dyn_cast<GlobalVariable>(C)) {
1380 if (GV->getName() == "ompi_mpi_op_sum") {
1381 isSum = true;
1382 }
1383 }
1384 // MPICH
1385 if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
1386 if (CI->getValue() == 1476395011) {
1387 isSum = true;
1388 }
1389 }
1390 }
1391 if (!isSum) {
1392 std::string s;
1393 llvm::raw_string_ostream ss(s);
1394 ss << " call: " << call << "\n";
1395 ss << " unhandled mpi_allreduce op: " << *orig_op << "\n";
1396 EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ);
1397 return;
1398 }
1399
1400 Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
1401 if (!forwardMode)
1402 shadow_recvbuf = lookup(shadow_recvbuf, Builder2);
1403 if (shadow_recvbuf->getType()->isIntegerTy())
1404 shadow_recvbuf = Builder2.CreateIntToPtr(
1405 shadow_recvbuf, getInt8PtrTy(call.getContext()));
1406
1407 Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
1408 if (!forwardMode)
1409 shadow_sendbuf = lookup(shadow_sendbuf, Builder2);
1410 if (shadow_sendbuf->getType()->isIntegerTy())
1411 shadow_sendbuf = Builder2.CreateIntToPtr(
1412 shadow_sendbuf, getInt8PtrTy(call.getContext()));
1413
1414 // Need to preserve the shadow send/recv buffers.
1415 auto BufferDefs = gutils->getInvertedBundles(
1416 &call,
1419 Builder2, /*lookup*/ !forwardMode);
1420
1421 Value *count = gutils->getNewFromOriginal(orig_count);
1422 if (!forwardMode)
1423 count = lookup(count, Builder2);
1424
1425 Value *datatype = gutils->getNewFromOriginal(orig_datatype);
1426 if (!forwardMode)
1427 datatype = lookup(datatype, Builder2);
1428
1429 Value *comm = gutils->getNewFromOriginal(orig_comm);
1430 if (!forwardMode)
1431 comm = lookup(comm, Builder2);
1432
1433 Value *op = gutils->getNewFromOriginal(orig_op);
1434 if (!forwardMode)
1435 op = lookup(op, Builder2);
1436
1437 if (forwardMode) {
1438 Value *args[] = {
1439 /*sendbuf*/ shadow_sendbuf,
1440 /*recvbuf*/ shadow_recvbuf,
1441 /*count*/ count,
1442 /*datatype*/ datatype,
1443 /*op*/ op,
1444 /*comm*/ comm,
1445 };
1446
1447 auto callval = call.getCalledOperand();
1448 Builder2.CreateCall(call.getFunctionType(), callval, args, BufferDefs);
1449
1450 return;
1451 }
1452
1453 Value *tysize = MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
1454
1455 // Get the length for the allocation of the intermediate buffer
1456 auto len_arg = Builder2.CreateZExtOrTrunc(
1457 count, Type::getInt64Ty(call.getContext()));
1458 len_arg =
1459 Builder2.CreateMul(len_arg,
1460 Builder2.CreateZExtOrTrunc(
1461 tysize, Type::getInt64Ty(call.getContext())),
1462 "", true, true);
1463
1464 // 1. Alloc intermediate buffer
1465 Value *buf =
1466 CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()),
1467 len_arg, "mpireduce_malloccache");
1468
1469 // 2. MPI_Allreduce (sum) of diff(recvbuffer) to intermediate
1470 {
1471 // int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
1472 // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
1473 Value *args[] = {
1474 /*sendbuf*/ shadow_recvbuf,
1475 /*recvbuf*/ buf,
1476 /*count*/ count,
1477 /*datatype*/ datatype,
1478 /*op*/ op,
1479 /*comm*/ comm,
1480 };
1481 Type *types[sizeof(args) / sizeof(*args)];
1482 for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
1483 types[i] = args[i]->getType();
1484
1485 FunctionType *FT = FunctionType::get(call.getType(), types, false);
1486 Builder2.CreateCall(
1487 called->getParent()->getOrInsertFunction(
1488 getRenamedPerCallingConv(called->getName(), "MPI_Allreduce"),
1489 FT),
1490 args, BufferDefs);
1491 }
1492
1493 // 3. Zero diff(recvbuffer) [memset to 0]
1494 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1495 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1496 Value *args[] = {shadow_recvbuf, val_arg, len_arg, volatile_arg};
1497 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1498 auto memset = cast<CallInst>(Builder2.CreateCall(
1499 getIntrinsicDeclaration(gutils->newFunc->getParent(),
1500 Intrinsic::memset, tys),
1501 args, BufferDefs));
1502 memset->addParamAttr(0, Attribute::NonNull);
1503
1504 // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy)
1505 DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf,
1506 len_arg, Builder2, BufferDefs);
1507
1508 // Free up intermediate buffer
1509 if (shouldFree()) {
1510 CreateDealloc(Builder2, buf);
1511 }
1512 }
1514 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
1515 return;
1516 }
1517
1518 // Approximate algo (for sum): -> if statement yet to be
1519 // 1. malloc intermediate buffer
1520 // 2. Scatter diff(recvbuffer) to intermediate buffer
1521 // 3. if root, Zero diff(recvbuffer) [memset to 0]
1522 // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy)
1523 // 5. free intermediate buffer
1524
1525 // int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
1526 // void *recvbuf, int recvcount, MPI_Datatype recvtype,
1527 // int root, MPI_Comm comm)
1528
1529 if (funcName == "MPI_Gather" || funcName == "PMPI_Gather") {
1534 bool forwardMode = Mode == DerivativeMode::ForwardMode ||
1536
1537 IRBuilder<> Builder2 =
1538 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1539 if (forwardMode) {
1540 getForwardBuilder(Builder2);
1541 } else {
1542 getReverseBuilder(Builder2);
1543 }
1544
1545 Value *orig_sendbuf = call.getOperand(0);
1546 Value *orig_sendcount = call.getOperand(1);
1547 Value *orig_sendtype = call.getOperand(2);
1548 Value *orig_recvbuf = call.getOperand(3);
1549 Value *orig_recvcount = call.getOperand(4);
1550 Value *orig_recvtype = call.getOperand(5);
1551 Value *orig_root = call.getOperand(6);
1552 Value *orig_comm = call.getOperand(7);
1553
1554 Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
1555 if (!forwardMode)
1556 shadow_recvbuf = lookup(shadow_recvbuf, Builder2);
1557 if (shadow_recvbuf->getType()->isIntegerTy())
1558 shadow_recvbuf = Builder2.CreateIntToPtr(
1559 shadow_recvbuf, getInt8PtrTy(call.getContext()));
1560
1561 Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
1562 if (!forwardMode)
1563 shadow_sendbuf = lookup(shadow_sendbuf, Builder2);
1564 if (shadow_sendbuf->getType()->isIntegerTy())
1565 shadow_sendbuf = Builder2.CreateIntToPtr(
1566 shadow_sendbuf, getInt8PtrTy(call.getContext()));
1567
1568 Value *recvcount = gutils->getNewFromOriginal(orig_recvcount);
1569 if (!forwardMode)
1570 recvcount = lookup(recvcount, Builder2);
1571
1572 Value *recvtype = gutils->getNewFromOriginal(orig_recvtype);
1573 if (!forwardMode)
1574 recvtype = lookup(recvtype, Builder2);
1575
1576 Value *sendcount = gutils->getNewFromOriginal(orig_sendcount);
1577 if (!sendcount)
1578 sendcount = lookup(sendcount, Builder2);
1579
1580 Value *sendtype = gutils->getNewFromOriginal(orig_sendtype);
1581 if (!forwardMode)
1582 sendtype = lookup(sendtype, Builder2);
1583
1584 Value *root = gutils->getNewFromOriginal(orig_root);
1585 if (!forwardMode)
1586 root = lookup(root, Builder2);
1587
1588 Value *comm = gutils->getNewFromOriginal(orig_comm);
1589 if (!forwardMode)
1590 comm = lookup(comm, Builder2);
1591
1592 Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType(), called);
1593 Value *tysize = MPI_TYPE_SIZE(sendtype, Builder2, call.getType(), called);
1594
1595 if (forwardMode) {
1596 Value *args[] = {
1597 /*sendbuf*/ shadow_sendbuf,
1598 /*sendcount*/ sendcount,
1599 /*sendtype*/ sendtype,
1600 /*recvbuf*/ shadow_recvbuf,
1601 /*recvcount*/ recvcount,
1602 /*recvtype*/ recvtype,
1603 /*root*/ root,
1604 /*comm*/ comm,
1605 };
1606
1607 auto Defs = gutils->getInvertedBundles(
1608 &call,
1612 Builder2, /*lookup*/ false);
1613
1614 auto callval = call.getCalledOperand();
1615 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
1616 return;
1617 }
1618
1619 // Get the length for the allocation of the intermediate buffer
1620 auto sendlen_arg = Builder2.CreateZExtOrTrunc(
1621 sendcount, Type::getInt64Ty(call.getContext()));
1622 sendlen_arg =
1623 Builder2.CreateMul(sendlen_arg,
1624 Builder2.CreateZExtOrTrunc(
1625 tysize, Type::getInt64Ty(call.getContext())),
1626 "", true, true);
1627
1628 // Need to preserve the shadow send/recv buffers.
1629 auto BufferDefs = gutils->getInvertedBundles(
1630 &call,
1634 Builder2, /*lookup*/ true);
1635
1636 // 1. Alloc intermediate buffer
1637 Value *buf =
1638 CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()),
1639 sendlen_arg, "mpireduce_malloccache");
1640
1641 // 2. Scatter diff(recvbuffer) to intermediate buffer
1642 {
1643 // int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype
1644 // sendtype,
1645 // void *recvbuf, int recvcount, MPI_Datatype recvtype, int root,
1646 // MPI_Comm comm)
1647 Value *args[] = {
1648 /*sendbuf*/ shadow_recvbuf,
1649 /*sendcount*/ recvcount,
1650 /*sendtype*/ recvtype,
1651 /*recvbuf*/ buf,
1652 /*recvcount*/ sendcount,
1653 /*recvtype*/ sendtype,
1654 /*op*/ root,
1655 /*comm*/ comm,
1656 };
1657 Type *types[sizeof(args) / sizeof(*args)];
1658 for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
1659 types[i] = args[i]->getType();
1660
1661 FunctionType *FT = FunctionType::get(call.getType(), types, false);
1662 Builder2.CreateCall(
1663 called->getParent()->getOrInsertFunction(
1664 getRenamedPerCallingConv(called->getName(), "MPI_Scatter"), FT),
1665 args, BufferDefs);
1666 }
1667
1668 // 3. if root, Zero diff(recvbuffer) [memset to 0]
1669 {
1670
1671 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1672 BasicBlock *rootBlock = gutils->addReverseBlock(
1673 currentBlock, currentBlock->getName() + "_root", gutils->newFunc);
1674 BasicBlock *mergeBlock = gutils->addReverseBlock(
1675 rootBlock, currentBlock->getName() + "_post", gutils->newFunc);
1676
1677 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1678 mergeBlock);
1679
1680 Builder2.SetInsertPoint(rootBlock);
1681 auto recvlen_arg = Builder2.CreateZExtOrTrunc(
1682 recvcount, Type::getInt64Ty(call.getContext()));
1683 recvlen_arg =
1684 Builder2.CreateMul(recvlen_arg,
1685 Builder2.CreateZExtOrTrunc(
1686 tysize, Type::getInt64Ty(call.getContext())),
1687 "", true, true);
1688 recvlen_arg = Builder2.CreateMul(
1689 recvlen_arg,
1690 Builder2.CreateZExtOrTrunc(
1691 MPI_COMM_SIZE(comm, Builder2, root->getType(), called),
1692 Type::getInt64Ty(call.getContext())),
1693 "", true, true);
1694
1695 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1696 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1697 Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg};
1698 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1699 auto memset = cast<CallInst>(Builder2.CreateCall(
1700 getIntrinsicDeclaration(gutils->newFunc->getParent(),
1701 Intrinsic::memset, tys),
1702 args, BufferDefs));
1703 memset->addParamAttr(0, Attribute::NonNull);
1704
1705 Builder2.CreateBr(mergeBlock);
1706 Builder2.SetInsertPoint(mergeBlock);
1707 }
1708
1709 // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy)
1710 DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf,
1711 sendlen_arg, Builder2, BufferDefs);
1712
1713 // Free up intermediate buffer
1714 if (shouldFree()) {
1715 CreateDealloc(Builder2, buf);
1716 }
1717 }
1719 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
1720 return;
1721 }
1722
1723 // Approximate algo (for sum): -> if statement yet to be
1724 // 1. if root, malloc intermediate buffer, else undef
1725 // 2. Gather diff(recvbuffer) to intermediate buffer
1726 // 3. Zero diff(recvbuffer) [memset to 0]
1727 // 4. if root, diff(sendbuffer) += intermediate buffer (diffmemcopy)
1728 // 5. if root, free intermediate buffer
1729
1730 // int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype
1731 // sendtype,
1732 // void *recvbuf, int recvcount, MPI_Datatype recvtype, int root,
1733 // MPI_Comm comm)
1734 if (funcName == "MPI_Scatter" || funcName == "PMPI_Scatter") {
1739 bool forwardMode = Mode == DerivativeMode::ForwardMode ||
1741
1742 IRBuilder<> Builder2 =
1743 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1744 if (forwardMode) {
1745 getForwardBuilder(Builder2);
1746 } else {
1747 getReverseBuilder(Builder2);
1748 }
1749
1750 Value *orig_sendbuf = call.getOperand(0);
1751 Value *orig_sendcount = call.getOperand(1);
1752 Value *orig_sendtype = call.getOperand(2);
1753 Value *orig_recvbuf = call.getOperand(3);
1754 Value *orig_recvcount = call.getOperand(4);
1755 Value *orig_recvtype = call.getOperand(5);
1756 Value *orig_root = call.getOperand(6);
1757 Value *orig_comm = call.getOperand(7);
1758
1759 Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
1760 if (!forwardMode)
1761 shadow_recvbuf = lookup(shadow_recvbuf, Builder2);
1762 if (shadow_recvbuf->getType()->isIntegerTy())
1763 shadow_recvbuf = Builder2.CreateIntToPtr(
1764 shadow_recvbuf, getInt8PtrTy(call.getContext()));
1765
1766 Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
1767 if (!forwardMode)
1768 shadow_sendbuf = lookup(shadow_sendbuf, Builder2);
1769 if (shadow_sendbuf->getType()->isIntegerTy())
1770 shadow_sendbuf = Builder2.CreateIntToPtr(
1771 shadow_sendbuf, getInt8PtrTy(call.getContext()));
1772
1773 Value *recvcount = gutils->getNewFromOriginal(orig_recvcount);
1774 if (!forwardMode)
1775 recvcount = lookup(recvcount, Builder2);
1776
1777 Value *recvtype = gutils->getNewFromOriginal(orig_recvtype);
1778 if (!forwardMode)
1779 recvtype = lookup(recvtype, Builder2);
1780
1781 Value *sendcount = gutils->getNewFromOriginal(orig_sendcount);
1782 if (!forwardMode)
1783 sendcount = lookup(sendcount, Builder2);
1784
1785 Value *sendtype = gutils->getNewFromOriginal(orig_sendtype);
1786 if (!forwardMode)
1787 sendtype = lookup(sendtype, Builder2);
1788
1789 Value *root = gutils->getNewFromOriginal(orig_root);
1790 if (!forwardMode)
1791 root = lookup(root, Builder2);
1792
1793 Value *comm = gutils->getNewFromOriginal(orig_comm);
1794 if (!forwardMode)
1795 comm = lookup(comm, Builder2);
1796
1797 Value *rank = MPI_COMM_RANK(comm, Builder2, root->getType(), called);
1798 Value *tysize = MPI_TYPE_SIZE(sendtype, Builder2, call.getType(), called);
1799
1800 if (forwardMode) {
1801 Value *args[] = {
1802 /*sendbuf*/ shadow_sendbuf,
1803 /*sendcount*/ sendcount,
1804 /*sendtype*/ sendtype,
1805 /*recvbuf*/ shadow_recvbuf,
1806 /*recvcount*/ recvcount,
1807 /*recvtype*/ recvtype,
1808 /*root*/ root,
1809 /*comm*/ comm,
1810 };
1811
1812 auto Defs = gutils->getInvertedBundles(
1813 &call,
1817 Builder2, /*lookup*/ false);
1818
1819 auto callval = call.getCalledOperand();
1820 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
1821 return;
1822 }
1823 // Get the length for the allocation of the intermediate buffer
1824 auto recvlen_arg = Builder2.CreateZExtOrTrunc(
1825 recvcount, Type::getInt64Ty(call.getContext()));
1826 recvlen_arg =
1827 Builder2.CreateMul(recvlen_arg,
1828 Builder2.CreateZExtOrTrunc(
1829 tysize, Type::getInt64Ty(call.getContext())),
1830 "", true, true);
1831
1832 // Need to preserve the shadow send/recv buffers.
1833 auto BufferDefs = gutils->getInvertedBundles(
1834 &call,
1838 Builder2, /*lookup*/ true);
1839
1840 // 1. if root, malloc intermediate buffer, else undef
1841 PHINode *buf;
1842 PHINode *sendlen_phi;
1843
1844 {
1845 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1846 BasicBlock *rootBlock = gutils->addReverseBlock(
1847 currentBlock, currentBlock->getName() + "_root", gutils->newFunc);
1848 BasicBlock *mergeBlock = gutils->addReverseBlock(
1849 rootBlock, currentBlock->getName() + "_post", gutils->newFunc);
1850
1851 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1852 mergeBlock);
1853
1854 Builder2.SetInsertPoint(rootBlock);
1855
1856 auto sendlen_arg = Builder2.CreateZExtOrTrunc(
1857 sendcount, Type::getInt64Ty(call.getContext()));
1858 sendlen_arg =
1859 Builder2.CreateMul(sendlen_arg,
1860 Builder2.CreateZExtOrTrunc(
1861 tysize, Type::getInt64Ty(call.getContext())),
1862 "", true, true);
1863 sendlen_arg = Builder2.CreateMul(
1864 sendlen_arg,
1865 Builder2.CreateZExtOrTrunc(
1866 MPI_COMM_SIZE(comm, Builder2, root->getType(), called),
1867 Type::getInt64Ty(call.getContext())),
1868 "", true, true);
1869
1870 Value *rootbuf =
1871 CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()),
1872 sendlen_arg, "mpireduce_malloccache");
1873
1874 Builder2.CreateBr(mergeBlock);
1875
1876 Builder2.SetInsertPoint(mergeBlock);
1877
1878 buf = Builder2.CreatePHI(rootbuf->getType(), 2);
1879 buf->addIncoming(rootbuf, rootBlock);
1880 buf->addIncoming(UndefValue::get(buf->getType()), currentBlock);
1881
1882 sendlen_phi = Builder2.CreatePHI(sendlen_arg->getType(), 2);
1883 sendlen_phi->addIncoming(sendlen_arg, rootBlock);
1884 sendlen_phi->addIncoming(UndefValue::get(sendlen_arg->getType()),
1885 currentBlock);
1886 }
1887
1888 // 2. Gather diff(recvbuffer) to intermediate buffer
1889 {
1890 // int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype
1891 // sendtype,
1892 // void *recvbuf, int recvcount, MPI_Datatype recvtype,
1893 // int root, MPI_Comm comm)
1894 Value *args[] = {
1895 /*sendbuf*/ shadow_recvbuf,
1896 /*sendcount*/ recvcount,
1897 /*sendtype*/ recvtype,
1898 /*recvbuf*/ buf,
1899 /*recvcount*/ sendcount,
1900 /*recvtype*/ sendtype,
1901 /*root*/ root,
1902 /*comm*/ comm,
1903 };
1904 Type *types[sizeof(args) / sizeof(*args)];
1905 for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
1906 types[i] = args[i]->getType();
1907
1908 FunctionType *FT = FunctionType::get(call.getType(), types, false);
1909 Builder2.CreateCall(
1910 called->getParent()->getOrInsertFunction(
1911 getRenamedPerCallingConv(called->getName(), "MPI_Gather"), FT),
1912 args, BufferDefs);
1913 }
1914
1915 // 3. Zero diff(recvbuffer) [memset to 0]
1916 {
1917 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1918 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1919 Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg};
1920 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1921 auto memset = cast<CallInst>(Builder2.CreateCall(
1922 getIntrinsicDeclaration(gutils->newFunc->getParent(),
1923 Intrinsic::memset, tys),
1924 args, BufferDefs));
1925 memset->addParamAttr(0, Attribute::NonNull);
1926 }
1927
1928 // 4. if root, diff(sendbuffer) += intermediate buffer (diffmemcopy)
1929 // 5. if root, free intermediate buffer
1930
1931 {
1932 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1933 BasicBlock *rootBlock = gutils->addReverseBlock(
1934 currentBlock, currentBlock->getName() + "_root", gutils->newFunc);
1935 BasicBlock *mergeBlock = gutils->addReverseBlock(
1936 rootBlock, currentBlock->getName() + "_post", gutils->newFunc);
1937
1938 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1939 mergeBlock);
1940
1941 Builder2.SetInsertPoint(rootBlock);
1942
1943 // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy)
1944 DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf,
1945 sendlen_phi, Builder2, BufferDefs);
1946
1947 // Free up intermediate buffer
1948 if (shouldFree()) {
1949 CreateDealloc(Builder2, buf);
1950 }
1951
1952 Builder2.CreateBr(mergeBlock);
1953 Builder2.SetInsertPoint(mergeBlock);
1954 }
1955 }
1957 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
1958 return;
1959 }
1960
1961 // Approximate algo (for sum): -> if statement yet to be
1962 // 1. malloc intermediate buffer
1963 // 2. reduce diff(recvbuffer) then scatter to corresponding input node's
1964 // intermediate buffer
1965 // 3. Zero diff(recvbuffer) [memset to 0]
1966 // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy)
1967 // 5. free intermediate buffer
1968
1969 // int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype
1970 // sendtype,
1971 // void *recvbuf, int recvcount, MPI_Datatype recvtype,
1972 // MPI_Comm comm)
1973
1974 if (funcName == "MPI_Allgather" || funcName == "PMPI_Allgather") {
1979 bool forwardMode = Mode == DerivativeMode::ForwardMode ||
1981
1982 IRBuilder<> Builder2 =
1983 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1984 if (forwardMode) {
1985 getForwardBuilder(Builder2);
1986 } else {
1987 getReverseBuilder(Builder2);
1988 }
1989
1990 Value *orig_sendbuf = call.getOperand(0);
1991 Value *orig_sendcount = call.getOperand(1);
1992 Value *orig_sendtype = call.getOperand(2);
1993 Value *orig_recvbuf = call.getOperand(3);
1994 Value *orig_recvcount = call.getOperand(4);
1995 Value *orig_recvtype = call.getOperand(5);
1996 Value *orig_comm = call.getOperand(6);
1997
1998 Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
1999 if (!forwardMode)
2000 shadow_recvbuf = lookup(shadow_recvbuf, Builder2);
2001
2002 if (shadow_recvbuf->getType()->isIntegerTy())
2003 shadow_recvbuf = Builder2.CreateIntToPtr(
2004 shadow_recvbuf, getInt8PtrTy(call.getContext()));
2005
2006 Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
2007 if (!forwardMode)
2008 shadow_sendbuf = lookup(shadow_sendbuf, Builder2);
2009
2010 if (shadow_sendbuf->getType()->isIntegerTy())
2011 shadow_sendbuf = Builder2.CreateIntToPtr(
2012 shadow_sendbuf, getInt8PtrTy(call.getContext()));
2013
2014 Value *recvcount = gutils->getNewFromOriginal(orig_recvcount);
2015 if (!forwardMode)
2016 recvcount = lookup(recvcount, Builder2);
2017
2018 Value *recvtype = gutils->getNewFromOriginal(orig_recvtype);
2019 if (!forwardMode)
2020 recvtype = lookup(recvtype, Builder2);
2021
2022 Value *sendcount = gutils->getNewFromOriginal(orig_sendcount);
2023 if (!forwardMode)
2024 sendcount = lookup(sendcount, Builder2);
2025
2026 Value *sendtype = gutils->getNewFromOriginal(orig_sendtype);
2027 if (!forwardMode)
2028 sendtype = lookup(sendtype, Builder2);
2029
2030 Value *comm = gutils->getNewFromOriginal(orig_comm);
2031 if (!forwardMode)
2032 comm = lookup(comm, Builder2);
2033
2034 Value *tysize = MPI_TYPE_SIZE(sendtype, Builder2, call.getType(), called);
2035
2036 if (forwardMode) {
2037 Value *args[] = {
2038 /*sendbuf*/ shadow_sendbuf,
2039 /*sendcount*/ sendcount,
2040 /*sendtype*/ sendtype,
2041 /*recvbuf*/ shadow_recvbuf,
2042 /*recvcount*/ recvcount,
2043 /*recvtype*/ recvtype,
2044 /*comm*/ comm,
2045 };
2046
2047 auto Defs = gutils->getInvertedBundles(
2048 &call,
2052 Builder2, /*lookup*/ false);
2053
2054 auto callval = call.getCalledOperand();
2055 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
2056 return;
2057 }
2058 // Get the length for the allocation of the intermediate buffer
2059 auto sendlen_arg = Builder2.CreateZExtOrTrunc(
2060 sendcount, Type::getInt64Ty(call.getContext()));
2061 sendlen_arg =
2062 Builder2.CreateMul(sendlen_arg,
2063 Builder2.CreateZExtOrTrunc(
2064 tysize, Type::getInt64Ty(call.getContext())),
2065 "", true, true);
2066
2067 // Need to preserve the shadow send/recv buffers.
2068 auto BufferDefs = gutils->getInvertedBundles(
2069 &call,
2073 Builder2, /*lookup*/ true);
2074
2075 // 1. Alloc intermediate buffer
2076 Value *buf =
2077 CreateAllocation(Builder2, Type::getInt8Ty(call.getContext()),
2078 sendlen_arg, "mpireduce_malloccache");
2079
2080 ConcreteType CT = TR.firstPointer(1, orig_sendbuf, &call);
2081 auto MPI_OP_type = getInt8PtrTy(call.getContext());
2082 Type *MPI_OP_Ptr_type = getUnqual(MPI_OP_type);
2083
2084 // 2. reduce diff(recvbuffer) then scatter to corresponding input node's
2085 // intermediate buffer
2086 {
2087 // int MPI_Reduce_scatter_block(const void* send_buffer,
2088 // void* receive_buffer,
2089 // int count,
2090 // MPI_Datatype datatype,
2091 // MPI_Op operation,
2092 // MPI_Comm communicator);
2093 Value *args[] = {
2094 /*sendbuf*/ shadow_recvbuf,
2095 /*recvbuf*/ buf,
2096 /*recvcount*/ sendcount,
2097 /*recvtype*/ sendtype,
2098 /*op (MPI_SUM)*/
2099 getOrInsertOpFloatSum(*gutils->newFunc->getParent(),
2100 MPI_OP_Ptr_type, MPI_OP_type, CT,
2101 call.getType(), Builder2),
2102 /*comm*/ comm,
2103 };
2104 Type *types[sizeof(args) / sizeof(*args)];
2105 for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
2106 types[i] = args[i]->getType();
2107
2108 FunctionType *FT = FunctionType::get(call.getType(), types, false);
2109 Builder2.CreateCall(
2110 called->getParent()->getOrInsertFunction(
2111 getRenamedPerCallingConv(called->getName(),
2112 "MPI_Reduce_scatter_block"),
2113 FT),
2114 args, BufferDefs);
2115 }
2116
2117 // 3. zero diff(recvbuffer) [memset to 0]
2118 {
2119 auto recvlen_arg = Builder2.CreateZExtOrTrunc(
2120 recvcount, Type::getInt64Ty(call.getContext()));
2121 recvlen_arg =
2122 Builder2.CreateMul(recvlen_arg,
2123 Builder2.CreateZExtOrTrunc(
2124 tysize, Type::getInt64Ty(call.getContext())),
2125 "", true, true);
2126 recvlen_arg = Builder2.CreateMul(
2127 recvlen_arg,
2128 Builder2.CreateZExtOrTrunc(
2129 MPI_COMM_SIZE(comm, Builder2, call.getType(), called),
2130 Type::getInt64Ty(call.getContext())),
2131 "", true, true);
2132 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
2133 auto volatile_arg = ConstantInt::getFalse(call.getContext());
2134 Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg};
2135 Type *tys[] = {args[0]->getType(), args[2]->getType()};
2136 auto memset = cast<CallInst>(Builder2.CreateCall(
2137 getIntrinsicDeclaration(gutils->newFunc->getParent(),
2138 Intrinsic::memset, tys),
2139 args, BufferDefs));
2140 memset->addParamAttr(0, Attribute::NonNull);
2141 }
2142
2143 // 4. diff(sendbuffer) += intermediate buffer (diffmemcopy)
2144 DifferentiableMemCopyFloats(call, orig_sendbuf, buf, shadow_sendbuf,
2145 sendlen_arg, Builder2, BufferDefs);
2146
2147 // Free up intermediate buffer
2148 if (shouldFree()) {
2149 CreateDealloc(Builder2, buf);
2150 }
2151 }
2153 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2154 return;
2155 }
2156
2157 // Adjoint of barrier is to place a barrier at the corresponding
2158 // location in the reverse.
2159 if (funcName == "MPI_Barrier" || funcName == "PMPI_Barrier") {
2162 IRBuilder<> Builder2(&call);
2163 getReverseBuilder(Builder2);
2164 auto callval = call.getCalledOperand();
2165 Value *args[] = {
2166 lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2)};
2167 Builder2.CreateCall(call.getFunctionType(), callval, args);
2168 }
2170 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2171 return;
2172 }
2173
2174 // Remove free's in forward pass so the comm can be used in the reverse
2175 // pass
2176 if (funcName == "MPI_Comm_free" || funcName == "MPI_Comm_disconnect") {
2177 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2178 return;
2179 }
2180
2181 // Adjoint of MPI_Comm_split / MPI_Graph_create (which allocates a comm in a
2182 // pointer) is to free the created comm at the corresponding place in the
2183 // reverse pass
2184 auto commFound = MPIInactiveCommAllocators.find(funcName);
2185 if (commFound != MPIInactiveCommAllocators.end()) {
2188 IRBuilder<> Builder2(&call);
2189 getReverseBuilder(Builder2);
2190
2191 Value *args[] = {lookup(call.getOperand(commFound->second), Builder2)};
2192 Type *types[] = {args[0]->getType()};
2193
2194 FunctionType *FT = FunctionType::get(call.getType(), types, false);
2195 Builder2.CreateCall(
2196 called->getParent()->getOrInsertFunction(
2197 getRenamedPerCallingConv(called->getName(), "MPI_Comm_free"), FT),
2198 args);
2199 }
2201 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2202 return;
2203 }
2204
2205 llvm::errs() << *gutils->oldFunc->getParent() << "\n";
2206 llvm::errs() << *gutils->oldFunc << "\n";
2207 llvm::errs() << call << "\n";
2208 llvm::errs() << called << "\n";
2209 llvm_unreachable("Unhandled MPI FUNCTION");
2210}
2211
2213 CallInst &call, Function *called, StringRef funcName,
2214 bool subsequent_calls_may_write, const std::vector<bool> &overwritten_args,
2215 CallInst *const newCall) {
2216 bool subretused = false;
2217 bool shadowReturnUsed = false;
2218 DIFFE_TYPE subretType =
2219 gutils->getReturnDiffeType(&call, &subretused, &shadowReturnUsed);
2220
2221 IRBuilder<> BuilderZ(newCall);
2222 BuilderZ.setFastMathFlags(getFast());
2223
2224 if (Mode != DerivativeMode::ReverseModePrimal && called) {
2225 if (funcName == "__kmpc_for_static_init_4" ||
2226 funcName == "__kmpc_for_static_init_4u" ||
2227 funcName == "__kmpc_for_static_init_8" ||
2228 funcName == "__kmpc_for_static_init_8u") {
2229 IRBuilder<> Builder2(&call);
2230 getReverseBuilder(Builder2);
2231 auto fini = called->getParent()->getFunction("__kmpc_for_static_fini");
2232 assert(fini);
2233 Value *args[] = {
2234 lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2),
2235 lookup(gutils->getNewFromOriginal(call.getArgOperand(1)), Builder2)};
2236 auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args);
2237 fcall->setCallingConv(fini->getCallingConv());
2238 return true;
2239 }
2240 }
2241
2242 if ((startsWith(funcName, "MPI_") || startsWith(funcName, "PMPI_")) &&
2243 (!gutils->isConstantInstruction(&call) || funcName == "MPI_Barrier" ||
2244 funcName == "MPI_Comm_free" || funcName == "MPI_Comm_disconnect" ||
2245 MPIInactiveCommAllocators.find(funcName) !=
2247 handleMPI(call, called, funcName);
2248 return true;
2249 }
2250
2251 if (auto blas = extractBLAS(funcName)) {
2252 if (handleBLAS(call, called, *blas, overwritten_args))
2253 return true;
2254 }
2255
2256 if (funcName == "printf" || funcName == "puts" ||
2257 startsWith(funcName, "_ZN3std2io5stdio6_print") ||
2258 startsWith(funcName, "_ZN4core3fmt")) {
2260 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2261 }
2262 return true;
2263 }
2264 if (called && (called->getName().contains("__enzyme_float") ||
2265 called->getName().contains("__enzyme_double") ||
2266 called->getName().contains("__enzyme_integer") ||
2267 called->getName().contains("__enzyme_pointer"))) {
2268 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2269 return true;
2270 }
2271
2272 // Handle lgamma, safe to recompute so no store/change to forward
2273 if (called) {
2274 if (funcName == "__kmpc_for_static_init_4" ||
2275 funcName == "__kmpc_for_static_init_4u" ||
2276 funcName == "__kmpc_for_static_init_8" ||
2277 funcName == "__kmpc_for_static_init_8u") {
2279 IRBuilder<> Builder2(&call);
2280 getReverseBuilder(Builder2);
2281 auto fini = called->getParent()->getFunction("__kmpc_for_static_fini");
2282 assert(fini);
2283 Value *args[] = {
2284 lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2),
2285 lookup(gutils->getNewFromOriginal(call.getArgOperand(1)),
2286 Builder2)};
2287 auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args);
2288 fcall->setCallingConv(fini->getCallingConv());
2289 }
2290 return true;
2291 }
2292 if (funcName == "__kmpc_for_static_fini") {
2294 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2295 }
2296 return true;
2297 }
2298 // TODO check
2299 // Adjoint of barrier is to place a barrier at the corresponding
2300 // location in the reverse.
2301 if (funcName == "__kmpc_barrier") {
2304 IRBuilder<> Builder2(&call);
2305 getReverseBuilder(Builder2);
2306 auto callval = call.getCalledOperand();
2307 Value *args[] = {
2308 lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2),
2309 lookup(gutils->getNewFromOriginal(call.getOperand(1)), Builder2)};
2310 Builder2.CreateCall(call.getFunctionType(), callval, args);
2311 }
2312 return true;
2313 }
2314 if (funcName == "__kmpc_critical") {
2316 IRBuilder<> Builder2(&call);
2317 getReverseBuilder(Builder2);
2318 auto crit2 = called->getParent()->getFunction("__kmpc_end_critical");
2319 assert(crit2);
2320 Value *args[] = {
2321 lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2),
2322 lookup(gutils->getNewFromOriginal(call.getArgOperand(1)), Builder2),
2323 lookup(gutils->getNewFromOriginal(call.getArgOperand(2)),
2324 Builder2)};
2325 auto fcall = Builder2.CreateCall(crit2->getFunctionType(), crit2, args);
2326 fcall->setCallingConv(crit2->getCallingConv());
2327 }
2328 return true;
2329 }
2330 if (funcName == "__kmpc_end_critical") {
2332 IRBuilder<> Builder2(&call);
2333 getReverseBuilder(Builder2);
2334 auto crit2 = called->getParent()->getFunction("__kmpc_critical");
2335 assert(crit2);
2336 Value *args[] = {
2337 lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2),
2338 lookup(gutils->getNewFromOriginal(call.getArgOperand(1)), Builder2),
2339 lookup(gutils->getNewFromOriginal(call.getArgOperand(2)),
2340 Builder2)};
2341 auto fcall = Builder2.CreateCall(crit2->getFunctionType(), crit2, args);
2342 fcall->setCallingConv(crit2->getCallingConv());
2343 }
2344 return true;
2345 }
2346
2347 if (startsWith(funcName, "__kmpc") &&
2348 funcName != "__kmpc_global_thread_num") {
2349 std::string s;
2350 llvm::raw_string_ostream ss(s);
2351 ss << " unhandled openmp function: " << call << "\n";
2352 EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ);
2353 return true;
2354 }
2355
2356 auto mod = call.getParent()->getParent()->getParent();
2357#include "CallDerivatives.inc"
2358
2359 if (funcName == "llvm.julia.gc_preserve_end") {
2362
2363 auto begin_call = cast<CallInst>(call.getOperand(0));
2364
2365 IRBuilder<> Builder2(&call);
2366 getReverseBuilder(Builder2);
2367 SmallVector<Value *, 1> args;
2368 for (auto &arg : begin_call->args()) {
2369 bool primalUsed = false;
2370 bool shadowUsed = false;
2371 gutils->getReturnDiffeType(arg, &primalUsed, &shadowUsed);
2372
2373 if (primalUsed)
2374 args.push_back(
2375 gutils->lookupM(gutils->getNewFromOriginal(arg), Builder2));
2376
2377 if (!gutils->isConstantValue(arg) && shadowUsed) {
2378 Value *ptrshadow = gutils->lookupM(
2379 gutils->invertPointerM(arg, BuilderZ), Builder2);
2380 if (gutils->getWidth() == 1)
2381 args.push_back(ptrshadow);
2382 else
2383 for (size_t i = 0; i < gutils->getWidth(); ++i)
2384 args.push_back(gutils->extractMeta(Builder2, ptrshadow, i));
2385 }
2386 }
2387
2388 auto newp = Builder2.CreateCall(
2389 called->getParent()->getOrInsertFunction(
2390 "llvm.julia.gc_preserve_begin",
2391 FunctionType::get(Type::getTokenTy(call.getContext()),
2392 ArrayRef<Type *>(), true)),
2393 args);
2394 auto ifound = gutils->invertedPointers.find(begin_call);
2395 assert(ifound != gutils->invertedPointers.end());
2396 auto placeholder = cast<CallInst>(&*ifound->second);
2397 gutils->invertedPointers.erase(ifound);
2398 gutils->invertedPointers.insert(std::make_pair(
2399 (const Value *)begin_call, InvertedPointerVH(gutils, newp)));
2400
2401 gutils->replaceAWithB(placeholder, newp);
2402 gutils->erase(placeholder);
2403 }
2404 return true;
2405 }
2406 if (funcName == "llvm.julia.gc_preserve_begin") {
2407 SmallVector<Value *, 1> args;
2408 for (auto &arg : call.args()) {
2409 bool primalUsed = false;
2410 bool shadowUsed = false;
2411 gutils->getReturnDiffeType(arg, &primalUsed, &shadowUsed);
2412
2413 if (primalUsed)
2414 args.push_back(gutils->getNewFromOriginal(arg));
2415
2416 if (!gutils->isConstantValue(arg) && shadowUsed) {
2417 Value *ptrshadow = gutils->invertPointerM(arg, BuilderZ);
2418 if (gutils->getWidth() == 1)
2419 args.push_back(ptrshadow);
2420 else
2421 for (size_t i = 0; i < gutils->getWidth(); ++i)
2422 args.push_back(gutils->extractMeta(BuilderZ, ptrshadow, i));
2423 }
2424 }
2425
2426 auto newp = BuilderZ.CreateCall(called, args);
2427 auto oldp = gutils->getNewFromOriginal(&call);
2428 gutils->replaceAWithB(oldp, newp);
2429 gutils->erase(oldp);
2430
2433 IRBuilder<> Builder2(&call);
2434 getReverseBuilder(Builder2);
2435
2436 auto ifound = gutils->invertedPointers.find(&call);
2437 assert(ifound != gutils->invertedPointers.end());
2438 auto placeholder = cast<CallInst>(&*ifound->second);
2439 Builder2.CreateCall(
2440 called->getParent()->getOrInsertFunction(
2441 "llvm.julia.gc_preserve_end",
2442 FunctionType::get(Builder2.getVoidTy(), call.getType(), false)),
2443 placeholder);
2444 }
2445 return true;
2446 }
2447
2448 /*
2449 * int gsl_sf_legendre_array_e(const gsl_sf_legendre_t norm,
2450 const size_t lmax,
2451 const double x,
2452 const double csphase,
2453 double result_array[]);
2454 */
2455 // d L(n, x) / dx = L(n,x) * x * (n-1) + 1
2456 if (funcName == "gsl_sf_legendre_array_e") {
2457 if (gutils->isConstantValue(call.getArgOperand(4))) {
2458 eraseIfUnused(call);
2459 return true;
2460 }
2462 eraseIfUnused(call);
2463 return true;
2464 }
2467 IRBuilder<> Builder2(&call);
2468 getReverseBuilder(Builder2);
2469 ValueType BundleTypes[5] = {ValueType::None, ValueType::None,
2472 auto Defs = gutils->getInvertedBundles(&call, BundleTypes, Builder2,
2473 /*lookup*/ true);
2474
2475 Type *types[6] = {
2476 call.getOperand(0)->getType(), call.getOperand(1)->getType(),
2477 call.getOperand(2)->getType(), call.getOperand(3)->getType(),
2478 call.getOperand(4)->getType(), call.getOperand(4)->getType(),
2479 };
2480 FunctionType *FT = FunctionType::get(call.getType(), types, false);
2481 auto F = called->getParent()->getOrInsertFunction(
2482 "gsl_sf_legendre_deriv_array_e", FT);
2483
2484 llvm::Value *args[6] = {
2485 gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(0)),
2486 Builder2),
2487 gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(1)),
2488 Builder2),
2489 gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(2)),
2490 Builder2),
2491 gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(3)),
2492 Builder2),
2493 nullptr,
2494 nullptr};
2495
2496 Type *typesS[] = {args[1]->getType()};
2497 FunctionType *FTS =
2498 FunctionType::get(args[1]->getType(), typesS, false);
2499 auto FS = called->getParent()->getOrInsertFunction(
2500 "gsl_sf_legendre_array_n", FTS);
2501 Value *alSize = Builder2.CreateCall(FS, args[1]);
2502 Value *tmp = CreateAllocation(Builder2, types[2], alSize);
2503 Value *dtmp = CreateAllocation(Builder2, types[2], alSize);
2504 Builder2.CreateLifetimeStart(tmp);
2505 Builder2.CreateLifetimeStart(dtmp);
2506
2507 args[4] = Builder2.CreateBitCast(tmp, types[4]);
2508 args[5] = Builder2.CreateBitCast(dtmp, types[5]);
2509
2510 Builder2.CreateCall(F, args, Defs);
2511 Builder2.CreateLifetimeEnd(tmp);
2512 CreateDealloc(Builder2, tmp);
2513
2514 BasicBlock *currentBlock = Builder2.GetInsertBlock();
2515
2516 BasicBlock *loopBlock = gutils->addReverseBlock(
2517 currentBlock, currentBlock->getName() + "_loop");
2518 BasicBlock *endBlock =
2519 gutils->addReverseBlock(loopBlock, currentBlock->getName() + "_end",
2520 /*fork*/ true, /*push*/ false);
2521
2522 Builder2.CreateCondBr(
2523 Builder2.CreateICmpEQ(args[1], Constant::getNullValue(types[1])),
2524 endBlock, loopBlock);
2525 Builder2.SetInsertPoint(loopBlock);
2526
2527 auto idx = Builder2.CreatePHI(types[1], 2);
2528 idx->addIncoming(ConstantInt::get(types[1], 0, false), currentBlock);
2529
2530 auto acc_idx = Builder2.CreatePHI(types[2], 2);
2531
2532 Value *inc = Builder2.CreateAdd(
2533 idx, ConstantInt::get(types[1], 1, false), "", true, true);
2534 idx->addIncoming(inc, loopBlock);
2535 acc_idx->addIncoming(Constant::getNullValue(types[2]), currentBlock);
2536
2537 Value *idxs[] = {idx};
2538 Value *dtmp_idx = Builder2.CreateInBoundsGEP(types[2], dtmp, idxs);
2539 Value *d_req = Builder2.CreateInBoundsGEP(
2540 types[2],
2541 Builder2.CreatePointerCast(
2542 gutils->invertPointerM(call.getOperand(4), Builder2),
2543 getUnqual(types[2])),
2544 idxs);
2545
2546 auto l0 = Builder2.CreateLoad(types[2], dtmp_idx);
2547 auto l1 = Builder2.CreateLoad(types[2], d_req);
2548 auto acc = Builder2.CreateFAdd(acc_idx, Builder2.CreateFMul(l0, l1));
2549 Builder2.CreateStore(Constant::getNullValue(types[2]), d_req);
2550
2551 acc_idx->addIncoming(acc, loopBlock);
2552
2553 Builder2.CreateCondBr(Builder2.CreateICmpEQ(inc, args[1]), endBlock,
2554 loopBlock);
2555
2556 Builder2.SetInsertPoint(endBlock);
2557 {
2558 auto found = gutils->reverseBlockToPrimal.find(endBlock);
2559 assert(found != gutils->reverseBlockToPrimal.end());
2560 SmallVector<BasicBlock *, 4> &vec =
2561 gutils->reverseBlocks[found->second];
2562 assert(vec.size());
2563 vec.push_back(endBlock);
2564 }
2565
2566 auto fin_idx = Builder2.CreatePHI(types[2], 2);
2567 fin_idx->addIncoming(Constant::getNullValue(types[2]), currentBlock);
2568 fin_idx->addIncoming(acc, loopBlock);
2569
2570 Builder2.CreateLifetimeEnd(dtmp);
2571 CreateDealloc(Builder2, dtmp);
2572
2573 ((DiffeGradientUtils *)gutils)
2574 ->addToDiffe(call.getOperand(2), fin_idx, Builder2, types[2]);
2575
2576 return true;
2577 }
2578 }
2579
2580 // Functions that only modify pointers and don't allocate memory,
2581 // needs to be run on shadow in primal
2582 if (funcName == "_ZSt29_Rb_tree_insert_and_rebalancebPSt18_Rb_tree_"
2583 "node_baseS0_RS_") {
2585 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2586 return true;
2587 }
2588 if (gutils->isConstantValue(call.getArgOperand(3)))
2589 return true;
2590 SmallVector<Value *, 2> args;
2591 for (auto &arg : call.args()) {
2592 if (gutils->isConstantValue(arg))
2593 args.push_back(gutils->getNewFromOriginal(arg));
2594 else
2595 args.push_back(gutils->invertPointerM(arg, BuilderZ));
2596 }
2597 BuilderZ.CreateCall(called, args);
2598 return true;
2599 }
2600
2601 // Functions that initialize a shadow data structure (with no
2602 // other arguments) needs to be run on shadow in primal.
2603 if (funcName == "_ZNSt8ios_baseC2Ev" || funcName == "_ZNSt8ios_baseD2Ev" ||
2604 funcName == "_ZNSt6localeC1Ev" || funcName == "_ZNSt6localeD1Ev" ||
2605 funcName == "_ZNKSt5ctypeIcE13_M_widen_initEv") {
2608 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2609 return true;
2610 }
2611 if (gutils->isConstantValue(call.getArgOperand(0)))
2612 return true;
2613 Value *args[] = {gutils->invertPointerM(call.getArgOperand(0), BuilderZ)};
2614 BuilderZ.CreateCall(called, args);
2615 return true;
2616 }
2617
2618 if (funcName == "_ZNSt9basic_iosIcSt11char_traitsIcEE4initEPSt15basic_"
2619 "streambufIcS1_E") {
2622 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2623 return true;
2624 }
2625 if (gutils->isConstantValue(call.getArgOperand(0)))
2626 return true;
2627 Value *args[] = {gutils->invertPointerM(call.getArgOperand(0), BuilderZ),
2628 gutils->invertPointerM(call.getArgOperand(1), BuilderZ)};
2629 BuilderZ.CreateCall(called, args);
2630 return true;
2631 }
2632
2633 // if constant instruction and readonly (thus must be pointer return)
2634 // and shadow return recomputable from shadow arguments.
2635 if (funcName == "__dynamic_cast" ||
2636 funcName == "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base" ||
2637 funcName == "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base" ||
2638 funcName == "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base" ||
2639 funcName == "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base" ||
2640 funcName == "jl_ptr_to_array" || funcName == "jl_ptr_to_array_1d") {
2641 bool shouldCache = false;
2642 if (gutils->knownRecomputeHeuristic.find(&call) !=
2643 gutils->knownRecomputeHeuristic.end()) {
2644 if (!gutils->knownRecomputeHeuristic[&call]) {
2645 shouldCache = true;
2646 }
2647 }
2648 ValueToValueMapTy empty;
2649 bool lrc = gutils->legalRecompute(&call, empty, nullptr);
2650
2651 if (!gutils->isConstantValue(&call)) {
2652 auto ifound = gutils->invertedPointers.find(&call);
2653 assert(ifound != gutils->invertedPointers.end());
2654 auto placeholder = cast<PHINode>(&*ifound->second);
2655
2656 if (subretType == DIFFE_TYPE::DUP_ARG) {
2657 Value *shadow = placeholder;
2658 if (lrc || Mode == DerivativeMode::ReverseModePrimal ||
2662 if (gutils->isConstantValue(call.getArgOperand(0)))
2663 shadow = gutils->getNewFromOriginal(&call);
2664 else {
2665 SmallVector<Value *, 2> args;
2666 size_t i = 0;
2667 for (auto &arg : call.args()) {
2668 if (gutils->isConstantValue(arg) ||
2669 (funcName == "__dynamic_cast" && i > 0) ||
2670 (funcName == "jl_ptr_to_array_1d" && i != 1) ||
2671 (funcName == "jl_ptr_to_array" && i != 1))
2672 args.push_back(gutils->getNewFromOriginal(arg));
2673 else
2674 args.push_back(gutils->invertPointerM(arg, BuilderZ));
2675 i++;
2676 }
2677 shadow = BuilderZ.CreateCall(called, args);
2678 }
2679 }
2680
2681 bool needsReplacement = true;
2682 if (!lrc && (Mode == DerivativeMode::ReverseModePrimal ||
2684 shadow = gutils->cacheForReverse(
2685 BuilderZ, shadow, getIndex(&call, CacheType::Shadow, BuilderZ));
2687 needsReplacement = false;
2688 }
2689 gutils->invertedPointers.erase((const Value *)&call);
2690 gutils->invertedPointers.insert(std::make_pair(
2691 (const Value *)&call, InvertedPointerVH(gutils, shadow)));
2692 if (needsReplacement) {
2693 assert(shadow != placeholder);
2694 gutils->replaceAWithB(placeholder, shadow);
2695 gutils->erase(placeholder);
2696 }
2697 } else {
2698 gutils->invertedPointers.erase((const Value *)&call);
2699 gutils->erase(placeholder);
2700 }
2701 }
2702
2703 if (Mode == DerivativeMode::ForwardMode ||
2705 eraseIfUnused(call);
2706 assert(gutils->isConstantInstruction(&call));
2707 return true;
2708 }
2709
2710 if (!shouldCache && !lrc) {
2711 std::map<UsageKey, bool> Seen;
2712 for (auto pair : gutils->knownRecomputeHeuristic)
2713 Seen[UsageKey(pair.first, QueryType::Primal)] = false;
2714 bool primalNeededInReverse =
2716 QueryType::Primal>(gutils, &call, Mode, Seen, oldUnreachable);
2717 shouldCache = primalNeededInReverse;
2718 }
2719
2720 if (shouldCache) {
2721 BuilderZ.SetInsertPoint(newCall->getNextNode());
2722 gutils->cacheForReverse(BuilderZ, newCall,
2723 getIndex(&call, CacheType::Self, BuilderZ));
2724 }
2725 eraseIfUnused(call);
2726 assert(gutils->isConstantInstruction(&call));
2727 return true;
2728 }
2729
2730 if (called) {
2731 if (funcName == "julia.write_barrier" ||
2732 funcName == "julia.write_barrier_binding") {
2733
2734 std::map<UsageKey, bool> Seen;
2735 for (auto pair : gutils->knownRecomputeHeuristic)
2736 if (!pair.second)
2737 Seen[UsageKey(pair.first, QueryType::Primal)] = false;
2738
2739 bool backwardsShadow = false;
2740 bool forwardsShadow = true;
2741 for (auto pair : gutils->backwardsOnlyShadows) {
2742 if (pair.second.stores.count(&call)) {
2743 backwardsShadow = true;
2744 forwardsShadow = pair.second.primalInitialize;
2745 if (auto inst = dyn_cast<Instruction>(pair.first))
2746 if (!forwardsShadow && pair.second.LI &&
2747 pair.second.LI->contains(inst->getParent()))
2748 backwardsShadow = false;
2749 break;
2750 }
2751 }
2752
2753 if (Mode == DerivativeMode::ForwardMode ||
2756 (forwardsShadow || backwardsShadow)) ||
2757 (Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
2758 (Mode == DerivativeMode::ReverseModeGradient && backwardsShadow)) {
2759 IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call));
2760 for (size_t i = 0; i < gutils->getWidth(); i++) {
2761 SmallVector<Value *, 1> iargs;
2762 bool first = true;
2763 for (auto &arg : call.args()) {
2764 if (!gutils->isConstantValue(arg)) {
2765 Value *ptrshadow = gutils->invertPointerM(arg, BuilderZ);
2766 if (gutils->getWidth() > 1) {
2767 ptrshadow = gutils->extractMeta(BuilderZ, ptrshadow, i);
2768 }
2769 iargs.push_back(ptrshadow);
2770 } else {
2771 if (first)
2772 break;
2773 }
2774 first = false;
2775 }
2776 if (iargs.size()) {
2777 BuilderZ.CreateCall(called, iargs);
2778 }
2779 }
2780 }
2781
2782 bool forceErase = false;
2784
2785 // Since we won't redo the store in the reverse pass, do not
2786 // force the write barrier.
2787 forceErase = true;
2788 for (const auto &pair : gutils->rematerializableAllocations) {
2789 if (!pair.second.stores.count(&call))
2790 continue;
2791 bool primalNeededInReverse =
2794 ? false
2796 QueryType::Primal>(gutils, pair.first, Mode, Seen,
2797 oldUnreachable);
2798
2799 bool cacheWholeAllocation =
2800 gutils->needsCacheWholeAllocation(pair.first);
2801 if (cacheWholeAllocation) {
2802 primalNeededInReverse = true;
2803 }
2804
2805 if (primalNeededInReverse && !cacheWholeAllocation)
2806 // However, if we are rematerailizing the allocation and not
2807 // inside the loop level rematerialization, we do still need the
2808 // reverse passes ``fake primal'' store and therefore write
2809 // barrier
2810 if (!pair.second.LI || !pair.second.LI->contains(&call)) {
2811 forceErase = false;
2812 }
2813 }
2814 }
2815 if (forceErase)
2816 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2817 else
2818 eraseIfUnused(call);
2819
2820 return true;
2821 }
2822 Intrinsic::ID ID = Intrinsic::not_intrinsic;
2823 if (isMemFreeLibMFunction(funcName, &ID)) {
2825 gutils->isConstantInstruction(&call)) {
2826
2827 if (gutils->knownRecomputeHeuristic.find(&call) !=
2828 gutils->knownRecomputeHeuristic.end()) {
2829 if (!gutils->knownRecomputeHeuristic[&call]) {
2830 gutils->cacheForReverse(
2831 BuilderZ, newCall,
2832 getIndex(&call, CacheType::Self, BuilderZ));
2833 }
2834 }
2835 eraseIfUnused(call);
2836 return true;
2837 }
2838
2839 if (ID != Intrinsic::not_intrinsic) {
2840 SmallVector<Value *, 2> orig_ops(call.getNumOperands());
2841 for (unsigned i = 0; i < call.getNumOperands(); ++i) {
2842 orig_ops[i] = call.getOperand(i);
2843 }
2844 bool cached = handleAdjointForIntrinsic(ID, call, orig_ops);
2845 if (!cached) {
2846 if (gutils->knownRecomputeHeuristic.find(&call) !=
2847 gutils->knownRecomputeHeuristic.end()) {
2848 if (!gutils->knownRecomputeHeuristic[&call]) {
2849 gutils->cacheForReverse(
2850 BuilderZ, newCall,
2851 getIndex(&call, CacheType::Self, BuilderZ));
2852 }
2853 }
2854 }
2855 eraseIfUnused(call);
2856 return true;
2857 }
2858 }
2859 }
2860 }
2861 if (auto assembly = dyn_cast<InlineAsm>(call.getCalledOperand())) {
2862 if (assembly->getAsmString() == "maxpd $1, $0") {
2864 gutils->isConstantInstruction(&call)) {
2865
2866 if (gutils->knownRecomputeHeuristic.find(&call) !=
2867 gutils->knownRecomputeHeuristic.end()) {
2868 if (!gutils->knownRecomputeHeuristic[&call]) {
2869 gutils->cacheForReverse(BuilderZ, newCall,
2870 getIndex(&call, CacheType::Self, BuilderZ));
2871 }
2872 }
2873 eraseIfUnused(call);
2874 return true;
2875 }
2876
2877 SmallVector<Value *, 2> orig_ops(call.getNumOperands());
2878 for (unsigned i = 0; i < call.getNumOperands(); ++i) {
2879 orig_ops[i] = call.getOperand(i);
2880 }
2881 handleAdjointForIntrinsic(Intrinsic::maxnum, call, orig_ops);
2882 if (gutils->knownRecomputeHeuristic.find(&call) !=
2883 gutils->knownRecomputeHeuristic.end()) {
2884 if (!gutils->knownRecomputeHeuristic[&call]) {
2885 gutils->cacheForReverse(BuilderZ, newCall,
2886 getIndex(&call, CacheType::Self, BuilderZ));
2887 }
2888 }
2889 eraseIfUnused(call);
2890 return true;
2891 }
2892 }
2893
2894 if (funcName == "realloc") {
2895 if (Mode == DerivativeMode::ForwardMode ||
2897 if (!gutils->isConstantValue(&call)) {
2898 IRBuilder<> Builder2(&call);
2899 getForwardBuilder(Builder2);
2900
2901 auto dbgLoc = gutils->getNewFromOriginal(&call)->getDebugLoc();
2902
2903 auto rule = [&](Value *ip) {
2904 ValueType BundleTypes[2] = {ValueType::Shadow, ValueType::Primal};
2905
2906 auto Defs = gutils->getInvertedBundles(&call, BundleTypes, Builder2,
2907 /*lookup*/ false);
2908
2909 llvm::Value *args[2] = {
2910 ip, gutils->getNewFromOriginal(call.getOperand(1))};
2911 CallInst *CI = Builder2.CreateCall(
2912 call.getFunctionType(), call.getCalledFunction(), args, Defs);
2913 CI->setAttributes(call.getAttributes());
2914 CI->setCallingConv(call.getCallingConv());
2915 CI->setTailCallKind(call.getTailCallKind());
2916 CI->setDebugLoc(dbgLoc);
2917 return CI;
2918 };
2919
2920 Value *CI = applyChainRule(
2921 call.getType(), Builder2, rule,
2922 gutils->invertPointerM(call.getOperand(0), Builder2));
2923
2924 auto found = gutils->invertedPointers.find(&call);
2925 PHINode *placeholder = cast<PHINode>(&*found->second);
2926
2927 gutils->invertedPointers.erase(found);
2928 gutils->replaceAWithB(placeholder, CI);
2929 gutils->erase(placeholder);
2930 gutils->invertedPointers.insert(
2931 std::make_pair(&call, InvertedPointerVH(gutils, CI)));
2932 }
2933 eraseIfUnused(call);
2934 return true;
2935 }
2936 }
2937
2938 if (isAllocationFunction(funcName, gutils->TLI)) {
2939
2940 bool constval = gutils->isConstantValue(&call);
2941
2942 if (!constval) {
2943 auto dbgLoc = gutils->getNewFromOriginal(&call)->getDebugLoc();
2944 auto found = gutils->invertedPointers.find(&call);
2945 PHINode *placeholder = cast<PHINode>(&*found->second);
2946 IRBuilder<> bb(placeholder);
2947
2948 SmallVector<Value *, 8> args;
2949 for (auto &arg : call.args()) {
2950 args.push_back(gutils->getNewFromOriginal(arg));
2951 }
2952
2957
2958 Value *anti = placeholder;
2959 // If rematerializable allocations and split mode, we can
2960 // simply elect to build the entire piece in the reverse
2961 // since it should be possible to perform any shadow stores
2962 // of pointers (from rematerializable property) and it does
2963 // not escape the function scope (lest it not be
2964 // rematerializable) so all input derivatives remain zero.
2965 bool backwardsShadow = false;
2966 bool forwardsShadow = true;
2967 bool inLoop = false;
2968 bool isAlloca = isa<AllocaInst>(&call);
2969 {
2970 auto found = gutils->backwardsOnlyShadows.find(&call);
2971 if (found != gutils->backwardsOnlyShadows.end()) {
2972 backwardsShadow = true;
2973 forwardsShadow = found->second.primalInitialize;
2974 // If in a loop context, maintain the same free behavior.
2975 if (found->second.LI &&
2976 found->second.LI->contains(call.getParent()))
2977 inLoop = true;
2978 }
2979 }
2980 {
2981
2982 if (!forwardsShadow) {
2984 // Needs a stronger replacement check/assertion.
2985 Value *replacement = getUndefinedValueForType(
2986 *gutils->oldFunc->getParent(), placeholder->getType());
2987 gutils->replaceAWithB(placeholder, replacement);
2988 gutils->invertedPointers.erase(found);
2989 gutils->invertedPointers.insert(std::make_pair(
2990 &call, InvertedPointerVH(gutils, replacement)));
2991 gutils->erase(placeholder);
2992 anti = nullptr;
2993 goto endAnti;
2994 } else if (inLoop) {
2996 placeholder);
2997 if (hasMetadata(&call, "enzyme_fromstack"))
2998 isAlloca = true;
2999 goto endAnti;
3000 }
3001 }
3002 placeholder->setName("");
3003 if (shadowHandlers.find(funcName) != shadowHandlers.end()) {
3004 bb.SetInsertPoint(placeholder);
3005
3007 (Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
3009 backwardsShadow)) {
3010 anti = applyChainRule(call.getType(), bb, [&]() {
3011 return shadowHandlers[funcName](bb, &call, args, gutils);
3012 });
3013 if (anti->getType() != placeholder->getType()) {
3014 llvm::errs() << "orig: " << call << "\n";
3015 llvm::errs() << "placeholder: " << *placeholder << "\n";
3016 llvm::errs() << "anti: " << *anti << "\n";
3017 }
3018 gutils->invertedPointers.erase(found);
3019 bb.SetInsertPoint(placeholder);
3020
3021 gutils->replaceAWithB(placeholder, anti);
3022 gutils->erase(placeholder);
3023 }
3024
3025 if (auto inst = dyn_cast<Instruction>(anti))
3026 bb.SetInsertPoint(inst);
3027
3028 if (!backwardsShadow)
3029 anti = gutils->cacheForReverse(
3030 bb, anti, getIndex(&call, CacheType::Shadow, BuilderZ));
3031 } else {
3032 bool zeroed = false;
3033 uint64_t idx = 0;
3034 Value *prev = nullptr;
3035 ;
3036 auto rule = [&]() {
3037 Value *anti =
3038 bb.CreateCall(call.getFunctionType(), call.getCalledOperand(),
3039 args, call.getName() + "'mi");
3040 cast<CallInst>(anti)->setAttributes(call.getAttributes());
3041 cast<CallInst>(anti)->setCallingConv(call.getCallingConv());
3042 cast<CallInst>(anti)->setTailCallKind(call.getTailCallKind());
3043 cast<CallInst>(anti)->setDebugLoc(dbgLoc);
3044
3045 if (anti->getType()->isPointerTy()) {
3046 cast<CallInst>(anti)->addAttributeAtIndex(
3047 AttributeList::ReturnIndex, Attribute::NoAlias);
3048 cast<CallInst>(anti)->addAttributeAtIndex(
3049 AttributeList::ReturnIndex, Attribute::NonNull);
3050
3051 if (funcName == "malloc" || funcName == "_Znwm" ||
3052 funcName == "??2@YAPAXI@Z" ||
3053 funcName == "??2@YAPEAX_K@Z") {
3054 if (auto ci = dyn_cast<ConstantInt>(args[0])) {
3055 unsigned derefBytes = ci->getLimitedValue();
3056 CallInst *cal =
3057 cast<CallInst>(gutils->getNewFromOriginal(&call));
3058 cast<CallInst>(anti)->addDereferenceableRetAttr(derefBytes);
3059 cal->addDereferenceableRetAttr(derefBytes);
3060#if !defined(FLANG) && !defined(ROCM)
3061 AttrBuilder B(ci->getContext());
3062#else
3063 AttrBuilder B;
3064#endif
3065 B.addDereferenceableOrNullAttr(derefBytes);
3066 cast<CallInst>(anti)->setAttributes(
3067 cast<CallInst>(anti)->getAttributes().addRetAttributes(
3068 call.getContext(), B));
3069 cal->setAttributes(cal->getAttributes().addRetAttributes(
3070 call.getContext(), B));
3071 cal->addAttributeAtIndex(AttributeList::ReturnIndex,
3072 Attribute::NoAlias);
3073 cal->addAttributeAtIndex(AttributeList::ReturnIndex,
3074 Attribute::NonNull);
3075 }
3076 }
3077 if (funcName == "julia.gc_alloc_obj" ||
3078 funcName == "jl_gc_alloc_typed" ||
3079 funcName == "ijl_gc_alloc_typed") {
3081 bool used = unnecessaryInstructions.find(&call) ==
3082 unnecessaryInstructions.end();
3083 EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call),
3084 idx, wrap(prev), used);
3085 }
3086 }
3087 }
3090 forwardsShadow) ||
3092 backwardsShadow) ||
3094 backwardsShadow)) {
3095 if (!inLoop) {
3096 zeroKnownAllocation(bb, anti, args, funcName, gutils->TLI,
3097 &call);
3098 zeroed = true;
3099 }
3100 }
3101 idx++;
3102 prev = anti;
3103 return anti;
3104 };
3105
3106 anti = applyChainRule(call.getType(), bb, rule);
3107
3108 gutils->invertedPointers.erase(found);
3109 if (&*bb.GetInsertPoint() == placeholder)
3110 bb.SetInsertPoint(placeholder->getNextNode());
3111 gutils->replaceAWithB(placeholder, anti);
3112 gutils->erase(placeholder);
3113
3114 if (!backwardsShadow)
3115 anti = gutils->cacheForReverse(
3116 bb, anti, getIndex(&call, CacheType::Shadow, BuilderZ));
3117 else {
3118 if (auto MD = hasMetadata(&call, "enzyme_fromstack")) {
3119 isAlloca = true;
3120 bb.SetInsertPoint(cast<Instruction>(anti));
3121 Value *Size;
3122 if (funcName == "malloc")
3123 Size = args[0];
3124 else if (funcName == "julia.gc_alloc_obj" ||
3125 funcName == "jl_gc_alloc_typed" ||
3126 funcName == "ijl_gc_alloc_typed")
3127 Size = args[1];
3128 else
3129 llvm_unreachable("Unknown allocation to upgrade");
3130
3131 Type *elTy = Type::getInt8Ty(call.getContext());
3132 if (MD->getNumOperands() == 2) {
3133 elTy = (Type *)cast<ConstantInt>(
3134 cast<ConstantAsMetadata>(MD->getOperand(1))
3135 ->getValue())
3136 ->getLimitedValue();
3137 Value *tsize = ConstantInt::get(
3138 Size->getType(), (gutils->newFunc->getParent()
3139 ->getDataLayout()
3140 .getTypeAllocSizeInBits(elTy) +
3141 7) /
3142 8);
3143
3144 Size = bb.CreateUDiv(Size, tsize, "", /*exact*/ true);
3145 }
3146 std::string name = "";
3147#if LLVM_VERSION_MAJOR < 17
3148 if (call.getContext().supportsTypedPointers()) {
3149 for (auto U : call.users()) {
3150 if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
3151 if (MD->getNumOperands() == 1) {
3152 elTy = U->getType()->getPointerElementType();
3153 Value *tsize = ConstantInt::get(
3154 Size->getType(),
3155 (gutils->newFunc->getParent()
3156 ->getDataLayout()
3157 .getTypeAllocSizeInBits(elTy) +
3158 7) /
3159 8);
3160
3161 Size = bb.CreateUDiv(Size, tsize, "", /*exact*/ true);
3162 }
3163 name = (U->getName() + "'ai").str();
3164 break;
3165 }
3166 }
3167 }
3168#endif
3169 auto rule = [&](Value *anti) {
3170 bb.SetInsertPoint(cast<Instruction>(anti));
3171 Value *replacement = bb.CreateAlloca(elTy, Size, name);
3172 if (name.size() == 0)
3173 replacement->takeName(anti);
3174 else
3175 anti->setName("");
3176 auto Alignment = cast<ConstantInt>(cast<ConstantAsMetadata>(
3177 MD->getOperand(0))
3178 ->getValue())
3179 ->getLimitedValue();
3180 if (Alignment) {
3181 cast<AllocaInst>(replacement)
3182 ->setAlignment(Align(Alignment));
3183 }
3184#if LLVM_VERSION_MAJOR < 17
3185 if (call.getContext().supportsTypedPointers()) {
3186 if (anti->getType()->getPointerElementType() != elTy)
3187 replacement = bb.CreatePointerCast(
3188 replacement,
3189 getUnqual(anti->getType()->getPointerElementType()));
3190 }
3191#endif
3192 if (int AS = cast<PointerType>(anti->getType())
3193 ->getAddressSpace()) {
3194 llvm::PointerType *PT;
3195#if LLVM_VERSION_MAJOR < 17
3196 if (call.getContext().supportsTypedPointers()) {
3197 PT = PointerType::get(
3198 anti->getType()->getPointerElementType(), AS);
3199#endif
3200#if LLVM_VERSION_MAJOR < 17
3201 } else {
3202#endif
3203 PT = PointerType::get(anti->getContext(), AS);
3204#if LLVM_VERSION_MAJOR < 17
3205 }
3206#endif
3207 replacement = bb.CreateAddrSpaceCast(replacement, PT);
3208 cast<Instruction>(replacement)
3209 ->setMetadata(
3210 "enzyme_backstack",
3211 MDNode::get(replacement->getContext(), {}));
3212 }
3213 gutils->replaceAWithB(cast<Instruction>(anti), replacement);
3214 bb.SetInsertPoint(cast<Instruction>(anti)->getNextNode());
3215 gutils->erase(cast<Instruction>(anti));
3216 return replacement;
3217 };
3218
3219 auto replacement =
3220 applyChainRule(call.getType(), bb, rule, anti);
3221 anti = replacement;
3222 }
3223 }
3224
3226 (Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
3228 backwardsShadow) ||
3229 (Mode == DerivativeMode::ForwardModeSplit && backwardsShadow)) {
3230 if (!inLoop) {
3231 assert(zeroed);
3232 }
3233 }
3234 }
3235 gutils->invertedPointers.insert(
3236 std::make_pair(&call, InvertedPointerVH(gutils, anti)));
3237 }
3238 endAnti:;
3239 if (((Mode == DerivativeMode::ReverseModeCombined && shouldFree()) ||
3242 !isAlloca) {
3243 IRBuilder<> Builder2(&call);
3244 getReverseBuilder(Builder2);
3245 assert(anti);
3246 Value *tofree = lookup(anti, Builder2);
3247 assert(tofree);
3248 assert(tofree->getType());
3249 auto rule = [&](Value *tofree) {
3250 auto CI = freeKnownAllocation(Builder2, tofree, funcName, dbgLoc,
3251 gutils->TLI, &call, gutils);
3252 if (CI)
3253 CI->addAttributeAtIndex(AttributeList::FirstArgIndex,
3254 Attribute::NonNull);
3255 };
3256 applyChainRule(Builder2, rule, tofree);
3257 }
3258 } else if (Mode == DerivativeMode::ForwardMode ||
3260 IRBuilder<> Builder2(&call);
3261 getForwardBuilder(Builder2);
3262
3263 SmallVector<Value *, 2> args;
3264 for (unsigned i = 0; i < call.arg_size(); ++i) {
3265 auto arg = call.getArgOperand(i);
3266 args.push_back(gutils->getNewFromOriginal(arg));
3267 }
3268
3269 uint64_t idx = 0;
3270 Value *prev = gutils->getNewFromOriginal(&call);
3271 auto rule = [&]() {
3272 SmallVector<ValueType, 2> BundleTypes(args.size(), ValueType::Primal);
3273
3274 auto Defs = gutils->getInvertedBundles(&call, BundleTypes, Builder2,
3275 /*lookup*/ false);
3276
3277 CallInst *CI = Builder2.CreateCall(
3278 call.getFunctionType(), call.getCalledFunction(), args, Defs);
3279 CI->setAttributes(call.getAttributes());
3280 CI->setCallingConv(call.getCallingConv());
3281 CI->setTailCallKind(call.getTailCallKind());
3282 CI->setDebugLoc(dbgLoc);
3283
3284 if (funcName == "julia.gc_alloc_obj" ||
3285 funcName == "jl_gc_alloc_typed" ||
3286 funcName == "ijl_gc_alloc_typed") {
3288 bool used = unnecessaryInstructions.find(&call) ==
3289 unnecessaryInstructions.end();
3290 EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx,
3291 wrap(prev), used);
3292 }
3293 }
3294 idx++;
3295 prev = CI;
3296 return CI;
3297 };
3298
3299 Value *CI = applyChainRule(call.getType(), Builder2, rule);
3300
3301 auto found = gutils->invertedPointers.find(&call);
3302 PHINode *placeholder = cast<PHINode>(&*found->second);
3303
3304 gutils->invertedPointers.erase(found);
3305 gutils->replaceAWithB(placeholder, CI);
3306 gutils->erase(placeholder);
3307 gutils->invertedPointers.insert(
3308 std::make_pair(&call, InvertedPointerVH(gutils, CI)));
3309 }
3310 }
3311
3312 // Cache and rematerialization irrelevant for forward mode.
3313 if (Mode == DerivativeMode::ForwardMode ||
3315 eraseIfUnused(call);
3316 return true;
3317 }
3318
3319 std::map<UsageKey, bool> Seen;
3320 for (auto pair : gutils->knownRecomputeHeuristic)
3321 if (!pair.second ||
3322 gutils->unnecessaryIntermediates.count(cast<Instruction>(pair.first)))
3323 Seen[UsageKey(pair.first, QueryType::Primal)] = false;
3324
3325 bool primalNeededInReverse =
3328 ? false
3330 QueryType::Primal>(gutils, &call, Mode, Seen, oldUnreachable);
3331
3332 bool cacheWholeAllocation = gutils->needsCacheWholeAllocation(&call);
3333 if (cacheWholeAllocation) {
3334 primalNeededInReverse = true;
3335 }
3336
3337 auto restoreFromStack = [&](MDNode *MD) {
3338 IRBuilder<> B(newCall);
3339 Value *Size;
3340 if (funcName == "malloc")
3341 Size = call.getArgOperand(0);
3342 else if (funcName == "julia.gc_alloc_obj" ||
3343 funcName == "jl_gc_alloc_typed" ||
3344 funcName == "ijl_gc_alloc_typed")
3345 Size = call.getArgOperand(1);
3346 else
3347 llvm_unreachable("Unknown allocation to upgrade");
3348 Size = gutils->getNewFromOriginal(Size);
3349
3350 if (isa<ConstantInt>(Size)) {
3351 B.SetInsertPoint(gutils->inversionAllocs);
3352 }
3353 Type *elTy = Type::getInt8Ty(call.getContext());
3354 if (MD->getNumOperands() == 2) {
3355 elTy = (Type *)cast<ConstantInt>(
3356 cast<ConstantAsMetadata>(MD->getOperand(1))->getValue())
3357 ->getLimitedValue();
3358 Value *tsize = ConstantInt::get(Size->getType(),
3359 (gutils->newFunc->getParent()
3360 ->getDataLayout()
3361 .getTypeAllocSizeInBits(elTy) +
3362 7) /
3363 8);
3364 Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
3365 }
3366 Instruction *I = nullptr;
3367#if LLVM_VERSION_MAJOR < 17
3368 if (call.getContext().supportsTypedPointers()) {
3369 for (auto U : call.users()) {
3370 if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
3371 if (MD->getNumOperands() == 1) {
3372 elTy = U->getType()->getPointerElementType();
3373 Value *tsize = ConstantInt::get(
3374 Size->getType(), (gutils->newFunc->getParent()
3375 ->getDataLayout()
3376 .getTypeAllocSizeInBits(elTy) +
3377 7) /
3378 8);
3379
3380 Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
3381 }
3382 I = gutils->getNewFromOriginal(cast<Instruction>(U));
3383 break;
3384 }
3385 }
3386 }
3387#endif
3388 Value *replacement = B.CreateAlloca(elTy, Size);
3389 for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type",
3390 "enzymejl_allocart", "enzymejl_allocart_name",
3391 "enzymejl_gc_alloc_rt"})
3392 if (auto M = call.getMetadata(MD))
3393 cast<AllocaInst>(replacement)->setMetadata(MD, M);
3394 if (I)
3395 replacement->takeName(I);
3396 else
3397 replacement->takeName(newCall);
3398 auto Alignment =
3399 cast<ConstantInt>(
3400 cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
3401 ->getLimitedValue();
3402 // Don't set zero alignment
3403 if (Alignment) {
3404 cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
3405 }
3406#if LLVM_VERSION_MAJOR < 17
3407 if (call.getContext().supportsTypedPointers()) {
3408 if (call.getType()->getPointerElementType() != elTy)
3409 replacement = B.CreatePointerCast(
3410 replacement, getUnqual(call.getType()->getPointerElementType()));
3411 }
3412#endif
3413 if (int AS = cast<PointerType>(call.getType())->getAddressSpace()) {
3414 llvm::PointerType *PT;
3415#if LLVM_VERSION_MAJOR < 17
3416 if (call.getContext().supportsTypedPointers()) {
3417 PT = PointerType::get(call.getType()->getPointerElementType(), AS);
3418#endif
3419#if LLVM_VERSION_MAJOR < 17
3420 } else {
3421#endif
3422 PT = PointerType::get(call.getContext(), AS);
3423#if LLVM_VERSION_MAJOR < 17
3424 }
3425#endif
3426 replacement = B.CreateAddrSpaceCast(replacement, PT);
3427 cast<Instruction>(replacement)
3428 ->setMetadata("enzyme_backstack",
3429 MDNode::get(replacement->getContext(), {}));
3430 }
3431 gutils->replaceAWithB(newCall, replacement);
3432 gutils->erase(newCall);
3433 };
3434
3435 // Don't erase any allocation that is being rematerialized.
3436 {
3437 auto found = gutils->rematerializableAllocations.find(&call);
3438 if (found != gutils->rematerializableAllocations.end()) {
3439 // If rematerializing (e.g. needed in reverse, but not needing
3440 // the whole allocation):
3441 if (primalNeededInReverse && !cacheWholeAllocation) {
3442 assert(!unnecessaryValues.count(&call));
3443 // if rematerialize, don't ever cache and downgrade to stack
3444 // allocation where possible. Note that for allocations which are
3445 // within a loop, we will create the rematerialized allocation in the
3446 // rematerialied loop. Note that what matters here is whether the
3447 // actual call itself here is inside the loop, not whether the
3448 // rematerialization is loop level. This is because one can have a
3449 // loop level cache, but a function level allocation (e.g. for stack
3450 // allocas). If we deleted it here, we would have no allocation!
3451 auto AllocationLoop = gutils->OrigLI->getLoopFor(call.getParent());
3452 // An allocation within a loop, must definitionally be a loop level
3453 // allocation (but not always the other way around.
3454 if (AllocationLoop)
3455 assert(found->second.LI);
3456 if (auto MD = hasMetadata(&call, "enzyme_fromstack")) {
3457 if (Mode == DerivativeMode::ReverseModeGradient && AllocationLoop) {
3458 gutils->rematerializedPrimalOrShadowAllocations.push_back(
3459 newCall);
3460 } else {
3461 restoreFromStack(MD);
3462 }
3463 return true;
3464 }
3465
3466 // No need to free GC.
3467 if (EnzymeJuliaAddrLoad && isa<PointerType>(call.getType()) &&
3468 cast<PointerType>(call.getType())->getAddressSpace() == 10) {
3469 if (Mode == DerivativeMode::ReverseModeGradient && AllocationLoop)
3470 gutils->rematerializedPrimalOrShadowAllocations.push_back(
3471 newCall);
3472 return true;
3473 }
3474
3475 // Otherwise if in reverse pass, free the newly created allocation.
3479 IRBuilder<> Builder2(&call);
3480 getReverseBuilder(Builder2);
3481 auto dbgLoc = gutils->getNewFromOriginal(call.getDebugLoc());
3482 freeKnownAllocation(Builder2, lookup(newCall, Builder2), funcName,
3483 dbgLoc, gutils->TLI, &call, gutils);
3484 if (Mode == DerivativeMode::ReverseModeGradient && AllocationLoop)
3485 gutils->rematerializedPrimalOrShadowAllocations.push_back(
3486 newCall);
3487 return true;
3488 }
3489 // If in primal, do nothing (keeping the original caching behavior)
3491 return true;
3492 } else if (!cacheWholeAllocation) {
3493 if (unnecessaryValues.count(&call)) {
3494 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3495 return true;
3496 }
3497 // If not caching allocation and not needed in the reverse, we can
3498 // use the original freeing behavior for the function. If in the
3499 // reverse pass we should not recreate this allocation.
3501 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3502 else if (auto MD = hasMetadata(&call, "enzyme_fromstack")) {
3503 restoreFromStack(MD);
3504 }
3505 return true;
3506 }
3507 }
3508 }
3509
3510 // If an allocation is not needed in the reverse, maintain the original
3511 // free behavior and do not rematerialize this for the reverse. However,
3512 // this is only safe to perform for allocations with a guaranteed free
3513 // as can we can only guarantee that we don't erase those frees.
3514 bool hasPDFree = gutils->allocationsWithGuaranteedFree.count(&call);
3515 if (!primalNeededInReverse && hasPDFree) {
3516 if (unnecessaryValues.count(&call)) {
3517 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3518 return true;
3519 }
3522 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3523 } else {
3524 if (auto MD = hasMetadata(&call, "enzyme_fromstack")) {
3525 restoreFromStack(MD);
3526 }
3527 }
3528 return true;
3529 }
3530
3531 // If an object is managed by the GC do not preserve it for later free,
3532 // Thus it only needs caching if there is a need for it in the reverse.
3533 if (EnzymeJuliaAddrLoad && isa<PointerType>(call.getType()) &&
3534 cast<PointerType>(call.getType())->getAddressSpace() == 10) {
3535 if (!subretused) {
3536 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3537 return true;
3538 }
3539 if (!primalNeededInReverse) {
3542 auto pn = BuilderZ.CreatePHI(call.getType(), 1,
3543 call.getName() + "_replacementJ");
3544 gutils->fictiousPHIs[pn] = &call;
3545 gutils->replaceAWithB(newCall, pn);
3546 gutils->erase(newCall);
3547 }
3548 } else if (Mode != DerivativeMode::ReverseModeCombined) {
3549 gutils->cacheForReverse(BuilderZ, newCall,
3550 getIndex(&call, CacheType::Self, BuilderZ));
3551 }
3552 return true;
3553 }
3554
3556 hasPDFree = true;
3557
3558 // TODO enable this if we need to free the memory
3559 // NOTE THAT TOPLEVEL IS THERE SIMPLY BECAUSE THAT WAS PREVIOUS ATTITUTE
3560 // TO FREE'ing
3561 if ((primalNeededInReverse &&
3562 !gutils->unnecessaryIntermediates.count(&call)) ||
3563 hasPDFree) {
3564 Value *nop = gutils->cacheForReverse(
3565 BuilderZ, newCall, getIndex(&call, CacheType::Self, BuilderZ));
3566 if (hasPDFree &&
3570 IRBuilder<> Builder2(&call);
3571 getReverseBuilder(Builder2);
3572 auto dbgLoc = gutils->getNewFromOriginal(call.getDebugLoc());
3573 freeKnownAllocation(Builder2, lookup(nop, Builder2), funcName, dbgLoc,
3574 gutils->TLI, &call, gutils);
3575 }
3576 } else if (Mode == DerivativeMode::ReverseModeGradient ||
3579 // Note that here we cannot simply replace with null as users who
3580 // try to find the shadow pointer will use the shadow of null rather
3581 // than the true shadow of this
3582 auto pn = BuilderZ.CreatePHI(call.getType(), 1,
3583 call.getName() + "_replacementB");
3584 gutils->fictiousPHIs[pn] = &call;
3585 gutils->replaceAWithB(newCall, pn);
3586 gutils->erase(newCall);
3587 }
3588
3589 return true;
3590 }
3591
3592 if (funcName == "julia.gc_loaded") {
3593 if (gutils->isConstantValue(&call)) {
3594 eraseIfUnused(call);
3595 return true;
3596 }
3597 auto ifound = gutils->invertedPointers.find(&call);
3598 assert(ifound != gutils->invertedPointers.end());
3599
3600 if (auto placeholder = dyn_cast<PHINode>(&*ifound->second)) {
3601
3603 QueryType::Shadow>(gutils, &call, Mode, oldUnreachable);
3604 if (!needShadow) {
3605 gutils->invertedPointers.erase(ifound);
3606 gutils->erase(placeholder);
3607 eraseIfUnused(call);
3608 return true;
3609 }
3610
3611 gutils->invertedPointers.erase(ifound);
3612 auto res = gutils->invertPointerM(&call, BuilderZ);
3613
3614 gutils->replaceAWithB(placeholder, res);
3615 gutils->erase(placeholder);
3616 }
3617 eraseIfUnused(call);
3618
3619 return true;
3620 }
3621
3622 if (funcName == "julia.pointer_from_objref") {
3623 if (gutils->isConstantValue(&call)) {
3624 eraseIfUnused(call);
3625 return true;
3626 }
3627
3628 auto ifound = gutils->invertedPointers.find(&call);
3629 assert(ifound != gutils->invertedPointers.end());
3630
3631 auto placeholder = cast<PHINode>(&*ifound->second);
3632
3633 bool needShadow =
3635 gutils, &call, Mode, oldUnreachable);
3636 if (!needShadow) {
3637 gutils->invertedPointers.erase(ifound);
3638 gutils->erase(placeholder);
3639 eraseIfUnused(call);
3640 return true;
3641 }
3642
3643 Value *ptrshadow = gutils->invertPointerM(call.getArgOperand(0), BuilderZ);
3644
3645 Value *val = applyChainRule(
3646 call.getType(), BuilderZ,
3647 [&](Value *v) -> Value * { return BuilderZ.CreateCall(called, {v}); },
3648 ptrshadow);
3649
3650 gutils->replaceAWithB(placeholder, val);
3651 gutils->erase(placeholder);
3652 eraseIfUnused(call);
3653 return true;
3654 }
3655 if (funcName.contains("__enzyme_todense")) {
3656 if (gutils->isConstantValue(&call)) {
3657 eraseIfUnused(call);
3658 return true;
3659 }
3660
3661 auto ifound = gutils->invertedPointers.find(&call);
3662 assert(ifound != gutils->invertedPointers.end());
3663
3664 auto placeholder = cast<PHINode>(&*ifound->second);
3665
3666 bool needShadow =
3668 gutils, &call, Mode, oldUnreachable);
3669 if (!needShadow) {
3670 gutils->invertedPointers.erase(ifound);
3671 gutils->erase(placeholder);
3672 eraseIfUnused(call);
3673 return true;
3674 }
3675
3676 SmallVector<Value *, 3> args;
3677 for (size_t i = 0; i < 2; i++)
3678 args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i)));
3679 for (size_t i = 2; i < call.arg_size(); ++i)
3680 args.push_back(gutils->invertPointerM(call.getArgOperand(0), BuilderZ));
3681
3682 Value *res = UndefValue::get(gutils->getShadowType(call.getType()));
3683 if (gutils->getWidth() == 1) {
3684 res = BuilderZ.CreateCall(called, args);
3685 } else {
3686 for (size_t w = 0; w < gutils->getWidth(); ++w) {
3687 SmallVector<Value *, 3> targs = {args[0], args[1]};
3688 for (size_t i = 2; i < call.arg_size(); ++i)
3689 targs.push_back(GradientUtils::extractMeta(BuilderZ, args[i], w));
3690
3691 auto tres = BuilderZ.CreateCall(called, targs);
3692 res = BuilderZ.CreateInsertValue(res, tres, w);
3693 }
3694 }
3695
3696 gutils->replaceAWithB(placeholder, res);
3697 gutils->erase(placeholder);
3698 eraseIfUnused(call);
3699 return true;
3700 }
3701
3702 if (funcName == "memcpy" || funcName == "memmove") {
3703 auto ID = (funcName == "memcpy") ? Intrinsic::memcpy : Intrinsic::memmove;
3704 visitMemTransferCommon(ID, /*srcAlign*/ MaybeAlign(1),
3705 /*dstAlign*/ MaybeAlign(1), call,
3706 call.getArgOperand(0), call.getArgOperand(1),
3707 gutils->getNewFromOriginal(call.getArgOperand(2)),
3708 ConstantInt::getFalse(call.getContext()));
3709 return true;
3710 }
3711 if (funcName == "memset" || funcName == "memset_pattern16" ||
3712 funcName == "__memset_chk") {
3713 visitMemSetCommon(call);
3714 return true;
3715 }
3716 if (funcName == "enzyme_zerotype") {
3717 IRBuilder<> BuilderZ(&call);
3718 getForwardBuilder(BuilderZ);
3719
3720 bool backwardsShadow = false;
3721 bool forwardsShadow = true;
3722 for (auto pair : gutils->backwardsOnlyShadows) {
3723 if (pair.second.stores.count(&call)) {
3724 backwardsShadow = true;
3725 forwardsShadow = pair.second.primalInitialize;
3726 if (auto inst = dyn_cast<Instruction>(pair.first))
3727 if (!forwardsShadow && pair.second.LI &&
3728 pair.second.LI->contains(inst->getParent()))
3729 backwardsShadow = false;
3730 }
3731 }
3732
3733 bool forceErase =
3734 !((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
3735 (Mode == DerivativeMode::ReverseModeCombined && forwardsShadow) ||
3736 (Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) ||
3737 (Mode == DerivativeMode::ForwardModeSplit && backwardsShadow));
3738
3739 if (forceErase)
3740 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3741 else
3742 eraseIfUnused(call);
3743
3744 Value *orig_op0 = call.getArgOperand(0);
3745
3746 // If constant destination then no operation needs doing
3747 if (gutils->isConstantValue(orig_op0)) {
3748 return true;
3749 }
3750
3751 if (!forceErase) {
3752 Value *op0 = gutils->invertPointerM(orig_op0, BuilderZ);
3753 Value *op1 = gutils->getNewFromOriginal(call.getArgOperand(1));
3754 Value *op2 = gutils->getNewFromOriginal(call.getArgOperand(2));
3755 auto Defs = gutils->getInvertedBundles(
3757 BuilderZ, /*lookup*/ false);
3758
3759 applyChainRule(
3760 BuilderZ,
3761 [&](Value *op0) {
3762 SmallVector<Value *, 4> args = {op0, op1, op2};
3763 auto cal =
3764 BuilderZ.CreateCall(call.getCalledFunction(), args, Defs);
3765 llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy);
3766 ToCopy2.push_back(LLVMContext::MD_noalias);
3767 cal->copyMetadata(call, ToCopy2);
3768 cal->setAttributes(call.getAttributes());
3769 if (auto m = hasMetadata(&call, "enzyme_zerostack"))
3770 cal->setMetadata("enzyme_zerostack", m);
3771 cal->setCallingConv(call.getCallingConv());
3772 cal->setTailCallKind(call.getTailCallKind());
3773 cal->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
3774 },
3775 op0);
3776 }
3777 return true;
3778 }
3779 if (funcName == "cuStreamCreate") {
3780 Value *val = nullptr;
3781 llvm::Type *PT = getInt8PtrTy(call.getContext());
3782#if LLVM_VERSION_MAJOR < 17
3783 if (call.getContext().supportsTypedPointers()) {
3784 if (isa<PointerType>(call.getArgOperand(0)->getType()))
3785 PT = call.getArgOperand(0)->getType()->getPointerElementType();
3786 }
3787#endif
3790 val = gutils->getNewFromOriginal(call.getOperand(0));
3791 if (!isa<PointerType>(val->getType()))
3792 val = BuilderZ.CreateIntToPtr(val, getUnqual(PT));
3793 val = BuilderZ.CreateLoad(PT, val);
3794 val = gutils->cacheForReverse(BuilderZ, val,
3795 getIndex(&call, CacheType::Tape, BuilderZ));
3796
3797 } else if (Mode == DerivativeMode::ReverseModeGradient) {
3798 PHINode *toReplace =
3799 BuilderZ.CreatePHI(PT, 1, call.getName() + "_psxtmp");
3800 val = gutils->cacheForReverse(BuilderZ, toReplace,
3801 getIndex(&call, CacheType::Tape, BuilderZ));
3802 }
3805 IRBuilder<> Builder2(&call);
3806 getReverseBuilder(Builder2);
3807 val = gutils->lookupM(val, Builder2);
3808 auto FreeFunc = gutils->newFunc->getParent()->getOrInsertFunction(
3809 "cuStreamDestroy", call.getType(), PT);
3810 Value *nargs[] = {val};
3811 Builder2.CreateCall(FreeFunc, nargs);
3812 }
3814 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3815 return true;
3816 }
3817 if (funcName == "cuStreamDestroy") {
3820 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3821 return true;
3822 }
3823 if (funcName == "cuStreamSynchronize") {
3826 IRBuilder<> Builder2(&call);
3827 getReverseBuilder(Builder2);
3828 Value *nargs[] = {gutils->lookupM(
3829 gutils->getNewFromOriginal(call.getOperand(0)), Builder2)};
3830 auto callval = call.getCalledOperand();
3831 Builder2.CreateCall(call.getFunctionType(), callval, nargs);
3832 }
3834 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
3835 return true;
3836 }
3837 if (funcName == "posix_memalign" || funcName == "cuMemAllocAsync" ||
3838 funcName == "cuMemAlloc" || funcName == "cuMemAlloc_v2" ||
3839 funcName == "cudaMalloc" || funcName == "cudaMallocAsync" ||
3840 funcName == "cudaMallocHost" || funcName == "cudaMallocFromPoolAsync") {
3841 bool constval = gutils->isConstantInstruction(&call);
3842
3843 Value *val;
3844 llvm::Type *PT = getInt8PtrTy(call.getContext());
3845#if LLVM_VERSION_MAJOR < 17
3846 if (call.getContext().supportsTypedPointers()) {
3847 if (isa<PointerType>(call.getArgOperand(0)->getType()))
3848 PT = call.getArgOperand(0)->getType()->getPointerElementType();
3849 }
3850#endif
3851 if (!constval) {
3852 Value *stream = nullptr;
3853 if (funcName == "cuMemAllocAsync")
3854 stream = gutils->getNewFromOriginal(call.getArgOperand(2));
3855 else if (funcName == "cudaMallocAsync")
3856 stream = gutils->getNewFromOriginal(call.getArgOperand(2));
3857 else if (funcName == "cudaMallocFromPoolAsync")
3858 stream = gutils->getNewFromOriginal(call.getArgOperand(3));
3859
3860 auto M = gutils->newFunc->getParent();
3861
3866 Value *ptrshadow =
3867 gutils->invertPointerM(call.getArgOperand(0), BuilderZ);
3868 SmallVector<Value *, 1> args;
3869 SmallVector<ValueType, 1> valtys;
3870 args.push_back(ptrshadow);
3871 valtys.push_back(ValueType::Shadow);
3872 for (size_t i = 1; i < call.arg_size(); ++i) {
3873 args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i)));
3874 valtys.push_back(ValueType::Primal);
3875 }
3876
3877 auto Defs = gutils->getInvertedBundles(&call, valtys, BuilderZ,
3878 /*lookup*/ false);
3879
3880 val = applyChainRule(
3881 PT, BuilderZ,
3882 [&](Value *ptrshadow) {
3883 args[0] = ptrshadow;
3884
3885 BuilderZ.CreateCall(called, args, Defs);
3886 if (!isa<PointerType>(ptrshadow->getType()))
3887 ptrshadow = BuilderZ.CreateIntToPtr(ptrshadow, getUnqual(PT));
3888 Value *val = BuilderZ.CreateLoad(PT, ptrshadow);
3889
3890 auto dst_arg =
3891 BuilderZ.CreateBitCast(val, getInt8PtrTy(call.getContext()));
3892
3893 auto val_arg =
3894 ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
3895 auto len_arg = gutils->getNewFromOriginal(
3896 call.getArgOperand((funcName == "posix_memalign") ? 2 : 1));
3897
3898 if (funcName == "posix_memalign" ||
3899 funcName == "cudaMallocHost") {
3900 BuilderZ.CreateMemSet(dst_arg, val_arg, len_arg, MaybeAlign());
3901 } else if (funcName == "cudaMalloc") {
3902 Type *tys[] = {PT, val_arg->getType(), len_arg->getType()};
3903 auto F = M->getOrInsertFunction(
3904 "cudaMemset",
3905 FunctionType::get(call.getType(), tys, false));
3906 Value *nargs[] = {dst_arg, val_arg, len_arg};
3907 auto memset = cast<CallInst>(BuilderZ.CreateCall(F, nargs));
3908 memset->addParamAttr(0, Attribute::NonNull);
3909 } else if (funcName == "cudaMallocAsync" ||
3910 funcName == "cudaMallocFromPoolAsync") {
3911 Type *tys[] = {PT, val_arg->getType(), len_arg->getType(),
3912 stream->getType()};
3913 auto F = M->getOrInsertFunction(
3914 "cudaMemsetAsync",
3915 FunctionType::get(call.getType(), tys, false));
3916 Value *nargs[] = {dst_arg, val_arg, len_arg, stream};
3917 auto memset = cast<CallInst>(BuilderZ.CreateCall(F, nargs));
3918 memset->addParamAttr(0, Attribute::NonNull);
3919 } else if (funcName == "cuMemAllocAsync") {
3920 Type *tys[] = {PT, val_arg->getType(), len_arg->getType(),
3921 stream->getType()};
3922 auto F = M->getOrInsertFunction(
3923 "cuMemsetD8Async",
3924 FunctionType::get(call.getType(), tys, false));
3925 Value *nargs[] = {dst_arg, val_arg, len_arg, stream};
3926 auto memset = cast<CallInst>(BuilderZ.CreateCall(F, nargs));
3927 memset->addParamAttr(0, Attribute::NonNull);
3928 } else if (funcName == "cuMemAlloc" ||
3929 funcName == "cuMemAlloc_v2") {
3930 Type *tys[] = {PT, val_arg->getType(), len_arg->getType()};
3931 auto F = M->getOrInsertFunction(
3932 "cuMemsetD8",
3933 FunctionType::get(call.getType(), tys, false));
3934 Value *nargs[] = {dst_arg, val_arg, len_arg};
3935 auto memset = cast<CallInst>(BuilderZ.CreateCall(F, nargs));
3936 memset->addParamAttr(0, Attribute::NonNull);
3937 } else {
3938 llvm_unreachable("unhandled allocation");
3939 }
3940 return val;
3941 },
3942 ptrshadow);
3943
3944 if (Mode != DerivativeMode::ForwardMode &&
3946 val = gutils->cacheForReverse(
3947 BuilderZ, val, getIndex(&call, CacheType::Tape, BuilderZ));
3948 } else if (Mode == DerivativeMode::ReverseModeGradient) {
3949 PHINode *toReplace = BuilderZ.CreatePHI(gutils->getShadowType(PT), 1,
3950 call.getName() + "_psxtmp");
3951 val = gutils->cacheForReverse(
3952 BuilderZ, toReplace, getIndex(&call, CacheType::Tape, BuilderZ));
3953 }
3954
3957 if (shouldFree()) {
3958 IRBuilder<> Builder2(&call);
3959 getReverseBuilder(Builder2);
3960 Value *tofree = gutils->lookupM(val, Builder2, ValueToValueMapTy(),
3961 /*tryLegalRecompute*/ false);
3962
3963 Type *VoidTy = Type::getVoidTy(M->getContext());
3964 Type *IntPtrTy = getInt8PtrTy(M->getContext());
3965
3966 Value *streamL = nullptr;
3967 if (stream)
3968 streamL = gutils->lookupM(stream, Builder2);
3969
3970 applyChainRule(
3971 BuilderZ,
3972 [&](Value *tofree) {
3973 if (funcName == "posix_memalign") {
3974 auto FreeFunc =
3975 M->getOrInsertFunction("free", VoidTy, IntPtrTy);
3976 Builder2.CreateCall(FreeFunc, tofree);
3977 } else if (funcName == "cuMemAllocAsync") {
3978 auto FreeFunc = M->getOrInsertFunction(
3979 "cuMemFreeAsync", VoidTy, IntPtrTy, streamL->getType());
3980 Value *nargs[] = {tofree, streamL};
3981 Builder2.CreateCall(FreeFunc, nargs);
3982 } else if (funcName == "cuMemAlloc" ||
3983 funcName == "cuMemAlloc_v2") {
3984 auto FreeFunc =
3985 M->getOrInsertFunction("cuMemFree", VoidTy, IntPtrTy);
3986 Value *nargs[] = {tofree};
3987 Builder2.CreateCall(FreeFunc, nargs);
3988 } else if (funcName == "cudaMalloc") {
3989 auto FreeFunc =
3990 M->getOrInsertFunction("cudaFree", VoidTy, IntPtrTy);
3991 Value *nargs[] = {tofree};
3992 Builder2.CreateCall(FreeFunc, nargs);
3993 } else if (funcName == "cudaMallocAsync" ||
3994 funcName == "cudaMallocFromPoolAsync") {
3995 auto FreeFunc = M->getOrInsertFunction(
3996 "cudaFreeAsync", VoidTy, IntPtrTy, streamL->getType());
3997 Value *nargs[] = {tofree, streamL};
3998 Builder2.CreateCall(FreeFunc, nargs);
3999 } else if (funcName == "cudaMallocHost") {
4000 auto FreeFunc =
4001 M->getOrInsertFunction("cudaFreeHost", VoidTy, IntPtrTy);
4002 Value *nargs[] = {tofree};
4003 Builder2.CreateCall(FreeFunc, nargs);
4004 } else
4005 llvm_unreachable("unknown function to free");
4006 },
4007 tofree);
4008 }
4009 }
4010 }
4011
4012 // TODO enable this if we need to free the memory
 4013 // NOTE THAT TOPLEVEL IS THERE SIMPLY BECAUSE THAT WAS PREVIOUS ATTITUDE
4014 // TO FREE'ing
4016 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
4017 } else if (Mode == DerivativeMode::ReverseModePrimal) {
4018 // if (is_value_needed_in_reverse<Primal>(
4019 // TR, gutils, orig, /*topLevel*/ Mode ==
4020 // DerivativeMode::Both))
4021 // {
4022
4023 // gutils->cacheForReverse(BuilderZ, newCall,
4024 // getIndex(orig, CacheType::Self, BuilderZ));
4025 //} else if (Mode != DerivativeMode::Forward) {
4026 // Note that here we cannot simply replace with null as users who try
4027 // to find the shadow pointer will use the shadow of null rather than
4028 // the true shadow of this
4029 //}
4030 } else if (Mode == DerivativeMode::ReverseModeCombined && shouldFree()) {
4031 IRBuilder<> Builder2(newCall->getNextNode());
4032 auto ptrv = gutils->getNewFromOriginal(call.getOperand(0));
4033 if (!isa<PointerType>(ptrv->getType()))
4034 ptrv = BuilderZ.CreateIntToPtr(ptrv, getUnqual(PT));
4035 auto load = Builder2.CreateLoad(PT, ptrv, "posix_preread");
4036 Builder2.SetInsertPoint(&call);
4037 getReverseBuilder(Builder2);
4038 auto tofree = gutils->lookupM(load, Builder2, ValueToValueMapTy(),
4039 /*tryLegal*/ false);
4040 Value *streamL = nullptr;
4041 if (funcName == "cuMemAllocAsync")
4042 streamL = gutils->getNewFromOriginal(call.getArgOperand(2));
4043 else if (funcName == "cudaMallocAsync")
4044 streamL = gutils->getNewFromOriginal(call.getArgOperand(2));
4045 else if (funcName == "cudaMallocFromPoolAsync")
4046 streamL = gutils->getNewFromOriginal(call.getArgOperand(3));
4047 if (streamL)
4048 streamL = gutils->lookupM(streamL, Builder2);
4049
4050 auto M = gutils->newFunc->getParent();
4051 Type *VoidTy = Type::getVoidTy(M->getContext());
4052 Type *IntPtrTy = getInt8PtrTy(M->getContext());
4053
4054 if (funcName == "posix_memalign") {
4055 auto FreeFunc = M->getOrInsertFunction("free", VoidTy, IntPtrTy);
4056 Builder2.CreateCall(FreeFunc, tofree);
4057 } else if (funcName == "cuMemAllocAsync") {
4058 auto FreeFunc = M->getOrInsertFunction("cuMemFreeAsync", VoidTy,
4059 IntPtrTy, streamL->getType());
4060 Value *nargs[] = {tofree, streamL};
4061 Builder2.CreateCall(FreeFunc, nargs);
4062 } else if (funcName == "cuMemAlloc" || funcName == "cuMemAlloc_v2") {
4063 auto FreeFunc = M->getOrInsertFunction("cuMemFree", VoidTy, IntPtrTy);
4064 Value *nargs[] = {tofree};
4065 Builder2.CreateCall(FreeFunc, nargs);
4066 } else if (funcName == "cudaMalloc") {
4067 auto FreeFunc = M->getOrInsertFunction("cudaFree", VoidTy, IntPtrTy);
4068 Value *nargs[] = {tofree};
4069 Builder2.CreateCall(FreeFunc, nargs);
4070 } else if (funcName == "cudaMallocAsync" ||
4071 funcName == "cudaMallocFromPoolAsync") {
4072 auto FreeFunc = M->getOrInsertFunction("cudaFreeAsync", VoidTy,
4073 IntPtrTy, streamL->getType());
4074 Value *nargs[] = {tofree, streamL};
4075 Builder2.CreateCall(FreeFunc, nargs);
4076 } else if (funcName == "cudaMallocHost") {
4077 auto FreeFunc =
4078 M->getOrInsertFunction("cudaFreeHost", VoidTy, IntPtrTy);
4079 Value *nargs[] = {tofree};
4080 Builder2.CreateCall(FreeFunc, nargs);
4081 } else
4082 llvm_unreachable("unknown function to free");
4083 }
4084
4085 return true;
4086 }
4087
4088 // Remove free's in forward pass so the memory can be used in the reverse
4089 // pass
4090 if (isDeallocationFunction(funcName, gutils->TLI)) {
4091 assert(gutils->invertedPointers.find(&call) ==
4092 gutils->invertedPointers.end());
4093
4094 if (Mode == DerivativeMode::ForwardMode ||
4096 if (!gutils->isConstantValue(call.getArgOperand(0))) {
4097 IRBuilder<> Builder2(&call);
4098 getForwardBuilder(Builder2);
4099 auto origfree = call.getArgOperand(0);
4100 auto newfree = gutils->getNewFromOriginal(call.getArgOperand(0));
4101 auto tofree = gutils->invertPointerM(origfree, Builder2);
4102
4103 Function *free = getOrInsertCheckedFree(
4104 *call.getModule(), &call, newfree->getType(), gutils->getWidth());
4105
4106 bool used = true;
4107 if (auto instArg = dyn_cast<Instruction>(call.getArgOperand(0)))
4108 used = unnecessaryInstructions.find(instArg) ==
4109 unnecessaryInstructions.end();
4110
4111 SmallVector<Value *, 3> args;
4112 if (used)
4113 args.push_back(newfree);
4114 else
4115 args.push_back(
4116 Constant::getNullValue(call.getArgOperand(0)->getType()));
4117
4118 auto rule = [&args](Value *tofree) { args.push_back(tofree); };
4119 applyChainRule(Builder2, rule, tofree);
4120
4121 for (size_t i = 1; i < call.arg_size(); i++) {
4122 args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i)));
4123 }
4124
4125 auto frees = Builder2.CreateCall(free->getFunctionType(), free, args);
4126 frees->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
4127
4128 eraseIfUnused(call);
4129 return true;
4130 }
4131 eraseIfUnused(call);
4132 }
4133 auto callval = call.getCalledOperand();
4134
4135 for (auto rmat : gutils->backwardsOnlyShadows) {
4136 if (rmat.second.frees.count(&call)) {
4137 bool shouldFree = false;
4138 if (rmat.second.primalInitialize) {
4140 shouldFree = true;
4141 }
4142
4143 if (shouldFree) {
4144 IRBuilder<> Builder2(&call);
4145 getForwardBuilder(Builder2);
4146 auto origfree = call.getArgOperand(0);
4147 auto tofree = gutils->invertPointerM(origfree, Builder2);
4148 if (tofree != origfree) {
4149 SmallVector<Value *, 2> args = {tofree};
4150 CallInst *CI =
4151 Builder2.CreateCall(call.getFunctionType(), callval, args);
4152 CI->setAttributes(call.getAttributes());
4153 }
4154 }
4155 break;
4156 }
4157 }
4158
4159 // If a rematerializable allocation.
4160 for (auto rmat : gutils->rematerializableAllocations) {
4161 if (rmat.second.frees.count(&call)) {
4162 // Leave the original free behavior since this won't be used
4163 // in the reverse pass in split mode
4165 eraseIfUnused(call);
4166 return true;
4167 } else if (Mode == DerivativeMode::ReverseModeGradient) {
4168 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
4169 return true;
4170 } else {
4172 std::map<UsageKey, bool> Seen;
4173 for (auto pair : gutils->knownRecomputeHeuristic)
4174 if (!pair.second)
4175 Seen[UsageKey(pair.first, QueryType::Primal)] = false;
4176 bool primalNeededInReverse =
4178 QueryType::Primal>(gutils, rmat.first, Mode, Seen,
4179 oldUnreachable);
4180 bool cacheWholeAllocation =
4181 gutils->needsCacheWholeAllocation(rmat.first);
4182 if (cacheWholeAllocation) {
4183 primalNeededInReverse = true;
4184 }
4185 // If in a loop context, maintain the same free behavior, unless
4186 // caching whole allocation.
4187 if (!cacheWholeAllocation) {
4188 eraseIfUnused(call);
4189 return true;
4190 }
4191 assert(!unnecessaryValues.count(rmat.first));
4192 (void)primalNeededInReverse;
4193 assert(primalNeededInReverse);
4194 }
4195 }
4196 }
4197
4198 if (gutils->forwardDeallocations.count(&call)) {
4200 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
4201 } else
4202 eraseIfUnused(call);
4203 return true;
4204 }
4205
4206 if (gutils->postDominatingFrees.count(&call)) {
4207 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
4208 return true;
4209 }
4210
4211 llvm::Value *val = getBaseObject(call.getArgOperand(0));
4212 if (isa<ConstantPointerNull>(val)) {
4213 llvm::errs() << "removing free of null pointer\n";
4214 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
4215 return true;
4216 }
4217
4218 // TODO HANDLE FREE
4219 llvm::errs() << "freeing without malloc " << *val << "\n";
4220 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
4221 return true;
4222 }
4223
4224 if (call.hasFnAttr("enzyme_sample")) {
4227 return true;
4228
4229 bool constval = gutils->isConstantInstruction(&call);
4230
4231 if (constval)
4232 return true;
4233
4234 IRBuilder<> Builder2(&call);
4235 getReverseBuilder(Builder2);
4236
4237 auto trace = call.getArgOperand(call.arg_size() - 1);
4238 auto address = call.getArgOperand(0);
4239
4240 auto dtrace = lookup(gutils->getNewFromOriginal(trace), Builder2);
4241 auto daddress = lookup(gutils->getNewFromOriginal(address), Builder2);
4242
4243 Value *dchoice;
4244 if (TR.query(&call)[{-1}].isPossiblePointer()) {
4245 dchoice = gutils->invertPointerM(&call, Builder2);
4246 } else {
4247 dchoice = diffe(&call, Builder2);
4248 }
4249
4250 if (call.hasMetadata("enzyme_gradient_setter")) {
4251 auto gradient_setter = cast<Function>(
4252 cast<ValueAsMetadata>(
4253 call.getMetadata("enzyme_gradient_setter")->getOperand(0).get())
4254 ->getValue());
4255
4257 Builder2, gradient_setter->getFunctionType(), gradient_setter,
4258 daddress, dchoice, dtrace);
4259 }
4260
4261 return true;
4262 }
4263
4264 if (call.hasFnAttr("enzyme_insert_argument")) {
4265 IRBuilder<> Builder2(&call);
4266 getReverseBuilder(Builder2);
4267
4268 auto name = call.getArgOperand(0);
4269 auto arg = call.getArgOperand(1);
4270 auto trace = call.getArgOperand(2);
4271
4272 auto gradient_setter = cast<Function>(
4273 cast<ValueAsMetadata>(
4274 call.getMetadata("enzyme_gradient_setter")->getOperand(0).get())
4275 ->getValue());
4276
4277 auto dtrace = lookup(gutils->getNewFromOriginal(trace), Builder2);
4278 auto dname = lookup(gutils->getNewFromOriginal(name), Builder2);
4279 Value *darg;
4280
4281 if (TR.query(arg)[{-1}].isPossiblePointer()) {
4282 darg = gutils->invertPointerM(arg, Builder2);
4283 } else {
4284 darg = diffe(arg, Builder2);
4285 }
4286
4288 gradient_setter->getFunctionType(),
4289 gradient_setter, dname, darg, dtrace);
4290 return true;
4291 }
4292
4293 return false;
4294}
void(* EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t, LLVMValueRef, uint8_t)
std::pair< const llvm::Value *, QueryType > UsageKey
CallInst * isSum(llvm::Value *v)
static bool isDeallocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory deallocation function For updating below one ...
static bool isAllocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory allocation function For updating below one sh...
static void zeroKnownAllocation(llvm::IRBuilder<> &bb, llvm::Value *toZero, llvm::ArrayRef< llvm::Value * > argValues, const llvm::StringRef funcName, const llvm::TargetLibraryInfo &TLI, llvm::CallInst *orig)
const llvm::StringMap< size_t > MPIInactiveCommAllocators
StringMap< std::function< Value *(IRBuilder<> &, CallInst *, ArrayRef< Value * >, GradientUtils *)> > shadowHandlers
llvm::CallInst * freeKnownAllocation(llvm::IRBuilder<> &builder, llvm::Value *tofree, llvm::StringRef allocationfn, const llvm::DebugLoc &debuglocation, const llvm::TargetLibraryInfo &TLI, llvm::CallInst *orig, GradientUtils *gutils)
Perform the corresponding deallocation of tofree, given it was allocated by allocationfn.
SmallVector< unsigned int, 9 > MD_ToCopy
llvm::cl::opt< bool > EnzymeFreeInternalAllocations
static bool isMemFreeLibMFunction(llvm::StringRef str, llvm::Intrinsic::ID *ID=nullptr)
CallInst * CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree)
Definition Utils.cpp:742
llvm::Function * getOrInsertDifferentialMPI_Wait(llvm::Module &M, ArrayRef< llvm::Type * > T, Type *reqType, StringRef caller)
Definition Utils.cpp:2271
llvm::Value * getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr, llvm::Type *OpType, ConcreteType CT, llvm::Type *intType, IRBuilder<> &B2)
Definition Utils.cpp:2378
llvm::Function * getOrInsertDifferentialWaitallSave(llvm::Module &M, ArrayRef< llvm::Type * > T, PointerType *reqType)
Definition Utils.cpp:2203
llvm::Optional< BlasInfo > extractBLAS(llvm::StringRef in)
Definition Utils.cpp:3563
Value * CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count, const Twine &Name, CallInst **caller, Instruction **ZeroMem, bool isDefault)
Definition Utils.cpp:619
llvm::Value * EmitNoDerivativeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &Builder2, llvm::Value *condition)
Definition Utils.cpp:4295
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
Definition Utils.cpp:3731
llvm::Constant * getUndefinedValueForType(llvm::Module &M, llvm::Type *T, bool forceZero)
Definition Utils.cpp:3623
Function * getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty, unsigned width)
Definition Utils.cpp:2084
static std::string getRenamedPerCallingConv(llvm::StringRef caller, llvm::StringRef callee)
Definition Utils.h:2413
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
Definition Utils.h:713
DIFFE_TYPE
Potential differentiable argument classifications.
Definition Utils.h:374
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
static llvm::Value * getMPIMemberPtr(llvm::IRBuilder<> &B, llvm::Value *V, llvm::Type *T)
Definition Utils.h:1200
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
Definition Utils.h:1174
static llvm::Function * getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id, llvm::ArrayRef< llvm::Type * > Tys={})
Definition Utils.h:2263
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
Definition Utils.h:1507
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
Definition Utils.h:339
static llvm::StructType * getMPIHelper(llvm::LLVMContext &Context)
Definition Utils.h:1183
llvm::cl::opt< bool > EnzymeJuliaAddrLoad
static llvm::Value * CreateSelect(llvm::IRBuilder<> &Builder2, llvm::Value *cmp, llvm::Value *tval, llvm::Value *fval, const llvm::Twine &Name="")
Definition Utils.h:2005
ValueType
Classification of value as an original program variable, a derivative variable, neither,...
Definition Utils.h:409
llvm::Value * lookup(llvm::Value *val, llvm::IRBuilder<> &Builder)
void DifferentiableMemCopyFloats(llvm::CallInst &call, llvm::Value *origArg, llvm::Value *dsto, llvm::Value *srco, llvm::Value *len_arg, llvm::IRBuilder<> &Builder2, llvm::ArrayRef< llvm::OperandBundleDef > ReverseDefs)
bool handleKnownCallDerivatives(llvm::CallInst &call, llvm::Function *called, llvm::StringRef funcName, bool subsequent_calls_may_write, const std::vector< bool > &overwritten_args, llvm::CallInst *const newCall)
llvm::Value * MPI_TYPE_SIZE(llvm::Value *DT, llvm::IRBuilder<> &B, llvm::Type *intType, llvm::Function *caller)
llvm::Value * applyChainRule(llvm::Type *diffType, llvm::IRBuilder<> &Builder, Func rule, Args... args)
Unwraps a vector derivative from its internal representation and applies a function f to each element...
void eraseIfUnused(llvm::Instruction &I, bool erase=true, bool check=true)
void getForwardBuilder(llvm::IRBuilder<> &Builder2)
void getReverseBuilder(llvm::IRBuilder<> &Builder2, bool original=true)
llvm::Value * MPI_COMM_SIZE(llvm::Value *comm, llvm::IRBuilder<> &B, llvm::Type *rankTy, llvm::Function *caller)
llvm::Value * MPI_COMM_RANK(llvm::Value *comm, llvm::IRBuilder<> &B, llvm::Type *rankTy, llvm::Function *caller)
bool handleAdjointForIntrinsic(llvm::Intrinsic::ID ID, llvm::Instruction &I, llvm::SmallVectorImpl< llvm::Value * > &orig_ops)
void handleMPI(llvm::CallInst &call, llvm::Function *called, llvm::StringRef funcName)
llvm::Function *const newFunc
The function whose instructions we are caching.
llvm::TargetLibraryInfo & TLI
Various analysis results of newFunc.
llvm::BasicBlock * inversionAllocs
Concrete SubType of a given value.
llvm::SmallPtrSet< llvm::Instruction *, 4 > unnecessaryIntermediates
llvm::ValueMap< const llvm::Value *, InvertedPointerVH > invertedPointers
llvm::DebugLoc getNewFromOriginal(const llvm::DebugLoc L) const
static llvm::Value * extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, unsigned off, const llvm::Twine &name="")
Helper routine to extract a nested element from a struct/array. This is.
llvm::SmallPtrSet< llvm::Instruction *, 4 > TapesToPreventRecomputation
A set of tape extractions to enforce a cache of rather than attempting to recompute.
bool legalRecompute(const llvm::Value *val, const llvm::ValueToValueMapTy &available, llvm::IRBuilder<> *BuilderM, bool reverse=false, bool legalRecomputeCache=true) const
std::map< const llvm::Value *, bool > knownRecomputeHeuristic
llvm::BasicBlock * addReverseBlock(llvm::BasicBlock *currentBlock, llvm::Twine const &name, bool forkCache=true, bool push=true)
llvm::ValueMap< llvm::Value *, Rematerializer > rematerializableAllocations
unsigned getWidth()
std::map< llvm::BasicBlock *, llvm::SmallVector< llvm::BasicBlock *, 4 > > reverseBlocks
Map of primal block to corresponding block(s) in reverse.
llvm::Function * oldFunc
llvm::SmallVector< llvm::OperandBundleDef, 2 > getInvertedBundles(llvm::CallInst *orig, llvm::ArrayRef< ValueType > types, llvm::IRBuilder<> &Builder2, bool lookup, const llvm::ValueToValueMapTy &available=llvm::ValueToValueMapTy())
void replaceAWithB(llvm::Value *A, llvm::Value *B, bool storeInCache=false) override
Replace this instruction both in LLVM modules and any local data-structures.
llvm::Value * lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &incoming_availalble=llvm::ValueToValueMapTy(), bool tryLegalRecomputeCheck=true, llvm::BasicBlock *scope=nullptr) override
High-level utility to get the value an instruction at a new location specified by BuilderM.
DIFFE_TYPE getReturnDiffeType(llvm::Value *orig, bool *primalReturnUsedP, bool *shadowReturnUsedP, DerivativeMode cmode) const
llvm::SmallVector< llvm::Instruction *, 1 > rematerializedPrimalOrShadowAllocations
std::map< llvm::BasicBlock *, llvm::BasicBlock * > reverseBlockToPrimal
Map of block in reverse to corresponding primal block.
bool isConstantInstruction(const llvm::Instruction *inst) const
llvm::Value * cacheForReverse(llvm::IRBuilder<> &BuilderQ, llvm::Value *malloc, int idx, bool replace=true)
llvm::ValueMap< llvm::Value *, ShadowRematerializer > backwardsOnlyShadows
Only loaded from and stored to (not captured), mapped to the stores (and memset).
bool isConstantValue(llvm::Value *val) const
llvm::Value * invertPointerM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, bool nullShadow=false)
void erase(llvm::Instruction *I) override
Erase this instruction both from LLVM modules and any local data-structures.
static llvm::CallInst * InsertChoiceGradient(llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *address, llvm::Value *choice, llvm::Value *trace)
static llvm::CallInst * InsertArgumentGradient(llvm::IRBuilder<> &Builder, llvm::FunctionType *interface_type, llvm::Value *interface_function, llvm::Value *name, llvm::Value *argument, llvm::Value *trace)
ConcreteType firstPointer(size_t num, llvm::Value *val, llvm::Instruction *I, bool errIfNotFound=true, bool pointerIntSame=false) const
Returns whether in the first num bytes there is pointer, int, float, or none If pointerIntSame is set...
bool is_value_needed_in_reverse(const GradientUtils *gutils, const llvm::Value *inst, DerivativeMode mode, std::map< UsageKey, bool > &seen, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &oldUnreachable)