Enzyme main
Loading...
Searching...
No Matches
DiffeGradientUtils.cpp
Go to the documentation of this file.
1//===- DiffeGradientUtils.cpp - Helper class and utilities for AD ---------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file declares two helper classes GradientUtils and subclass
22// DiffeGradientUtils. These classes contain utilities for managing the cache,
23// recomputing statements, and in the case of DiffeGradientUtils, managing
24// adjoint values and shadow pointers.
25//
26//===----------------------------------------------------------------------===//
27
28#include <string>
29
30#include "DiffeGradientUtils.h"
31
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/SmallPtrSet.h"
34#include "llvm/ADT/SmallVector.h"
35
36#include "llvm/IR/BasicBlock.h"
37#include "llvm/IR/DebugInfoMetadata.h"
38#include "llvm/IR/Dominators.h"
39#include "llvm/IR/IRBuilder.h"
40#include "llvm/IR/Instructions.h"
41#include "llvm/IR/Type.h"
42#include "llvm/IR/Value.h"
43
44#include "llvm/Transforms/Utils/BasicBlockUtils.h"
45
46#include "llvm/Support/Casting.h"
47#include "llvm/Support/ErrorHandling.h"
48
49#include "LibraryFuncs.h"
50#include "Utils.h"
51
52using namespace llvm;
53
54namespace {
55bool elementwiseReadForContext(const Instruction *orig, const Value *origptr) {
56 if (orig) {
57 if (const Function *F = orig->getFunction()) {
58 if (F->hasFnAttribute("enzyme_elementwise_read")) {
59 return true;
60 }
61 }
62 }
63 const Value *base = getBaseObject(origptr);
64 if (auto *arg = dyn_cast<Argument>(base)) {
65 if (const Function *F = arg->getParent()) {
66 return F->getAttributes().hasParamAttr(arg->getArgNo(),
67 "enzyme_elementwise_read");
68 }
69 }
70 return false;
71}
72} // namespace
73
// Construct a DiffeGradientUtils: forward all state to the GradientUtils base
// class, then pre-create one empty "invert<BB>" basic block per primal basic
// block to host the reverse-pass code (skipped entirely in forward modes).
DiffeGradientUtils::DiffeGradientUtils(
    EnzymeLogic &Logic, Function *newFunc_, Function *oldFunc_,
    TargetLibraryInfo &TLI, TypeAnalysis &TA, TypeResults TR,
    ValueToValueMapTy &invertedPointers_,
    const SmallPtrSetImpl<Value *> &constantvalues_,
    const SmallPtrSetImpl<Value *> &returnvals_, DIFFE_TYPE ActiveReturn,
    bool shadowReturnUsed, ArrayRef<DIFFE_TYPE> constant_values,
    llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
    DerivativeMode mode, bool runtimeActivity, bool strongZero, unsigned width,
    bool omp)
    : GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_,
                    constantvalues_, returnvals_, ActiveReturn,
                    shadowReturnUsed, constant_values, origToNew_, mode,
                    runtimeActivity, strongZero, width, omp) {
  // A declaration-only function has no blocks to mirror.
  if (oldFunc_->empty())
    return;
  assert(reverseBlocks.size() == 0);
  // Forward modes emit no reverse pass, so no reverse blocks are needed.
  // NOTE(review): the right-hand side of the `||` below is missing from this
  // copy of the file (extraction truncation) — restore from upstream.
  if (mode == DerivativeMode::ForwardMode ||
    return;
  }
  // Create one (initially empty) reverse block per primal block; multiple
  // reverse blocks per primal block may be appended later.
  for (BasicBlock *BB : originalBlocks) {
    if (BB == inversionAllocs)
      continue;
    BasicBlock *RBB =
        BasicBlock::Create(BB->getContext(), "invert" + BB->getName(), newFunc);
    reverseBlocks[BB].push_back(RBB);
    reverseBlockToPrimal[RBB] = BB;
  }
  assert(reverseBlocks.size() != 0);
}
106
    // NOTE(review): the opening line of this definition (presumably
    // `DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(`) and several
    // interior lines (the clone call, the switch's case labels, trailing
    // constructor arguments) are missing from this copy of the file —
    // restore from upstream before compiling.
    EnzymeLogic &Logic, DerivativeMode mode, bool runtimeActivity,
    bool strongZero, unsigned width, Function *todiff, TargetLibraryInfo &TLI,
    TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType,
    bool shadowReturn, bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
    bool returnTape, bool returnPrimal, Type *additionalArg, bool omp) {
  Function *oldFunc = todiff;
  ValueToValueMapTy invertedPointers;
  SmallPtrSet<Instruction *, 4> constants;
  SmallPtrSet<Instruction *, 20> nonconstant;
  SmallPtrSet<Value *, 2> returnvals;
  llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> originalToNew;

  SmallPtrSet<Value *, 4> constant_values;
  SmallPtrSet<Value *, 4> nonconstant_values;

  // Name prefix of the generated derivative function, chosen by mode.
  std::string prefix;

  switch (mode) {
    prefix = "fwddiffe";
    break;
    prefix = "diffe";
    break;
    llvm_unreachable("invalid DerivativeMode: ReverseModePrimal\n");
  }

  // Vector-mode derivatives encode the width in the name (e.g. "diffe4").
  if (width > 1)
    prefix += std::to_string(width);

    mode, width, oldFunc, invertedPointers, constant_args, constant_values,
    nonconstant_values, returnvals, returnTape, returnPrimal,
    (mode == DerivativeMode::ReverseModeGradient) ? false : shadowReturn,
    prefix + oldFunc->getName(), &originalToNew,
    /*diffeReturnArg*/ diffeReturnArg, additionalArg);

  // Convert overwritten args from the input function to the preprocessed
  // function

  // Remap per-argument type trees and known constant values from `todiff`'s
  // arguments onto the corresponding arguments of `oldFunc`.
  FnTypeInfo typeInfo(oldFunc);
  {
    auto toarg = todiff->arg_begin();
    auto olarg = oldFunc->arg_begin();
    for (; toarg != todiff->arg_end(); ++toarg, ++olarg) {

      {
        auto fd = oldTypeInfo.Arguments.find(toarg);
        assert(fd != oldTypeInfo.Arguments.end());
        typeInfo.Arguments.insert(
            std::pair<Argument *, TypeTree>(olarg, fd->second));
      }

      {
        auto cfd = oldTypeInfo.KnownValues.find(toarg);
        assert(cfd != oldTypeInfo.KnownValues.end());
        typeInfo.KnownValues.insert(
            std::pair<Argument *, std::set<int64_t>>(olarg, cfd->second));
      }
    }
    typeInfo.Return = oldTypeInfo.Return;
  }

  // Run type analysis on the remapped info.
  TypeResults TR = TA.analyzeFunction(typeInfo);
  if (!oldFunc->empty())
    assert(TR.getFunction() == oldFunc);

  auto res = new DiffeGradientUtils(
      Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values,
      nonconstant_values, retType, shadowReturn, constant_args, originalToNew,

  return res;
}
191
// Return the stack slot that accumulates the adjoint ("differential") of
// `val`, creating and zero-initializing it in `inversionAllocs` on first use.
// The same alloca is returned for all later queries of the same value.
AllocaInst *DiffeGradientUtils::getDifferential(Value *val) {
  assert(val);
#ifndef NDEBUG
  // Adjoints are keyed on values of the *original* function, never the clone.
  if (auto arg = dyn_cast<Argument>(val))
    assert(arg->getParent() == oldFunc);
  if (auto inst = dyn_cast<Instruction>(val))
    assert(inst->getParent()->getParent() == oldFunc);
#endif
  assert(inversionAllocs);

  // The shadow type accounts for vector-mode width (aggregate of copies).
  Type *type = getShadowType(val->getType());
  if (differentials.find(val) == differentials.end()) {
    IRBuilder<> entryBuilder(inversionAllocs);
    entryBuilder.setFastMathFlags(getFast());
    differentials[val] =
        entryBuilder.CreateAlloca(type, nullptr, val->getName() + "'de");
    auto Alignment =
        oldFunc->getParent()->getDataLayout().getPrefTypeAlign(type);
    differentials[val]->setAlignment(Alignment);
    // A fresh adjoint slot must start at zero before any accumulation.
    ZeroMemory(entryBuilder, type, differentials[val],
               /*isTape*/ false);
  }
#if LLVM_VERSION_MAJOR < 17
  // With typed pointers, sanity-check the alloca's pointee type.
  if (val->getContext().supportsTypedPointers()) {
    assert(differentials[val]->getType()->getPointerElementType() == type);
  }
#endif
  return differentials[val];
}
224
// Load the current adjoint of `val` at BuilderM's insertion point.
// Inactive (constant) values have no adjoint and trigger an assertion.
Value *DiffeGradientUtils::diffe(Value *val, IRBuilder<> &BuilderM) {
#ifndef NDEBUG
  // `val` must belong to the original function.
  if (auto arg = dyn_cast<Argument>(val))
    assert(arg->getParent() == oldFunc);
  if (auto inst = dyn_cast<Instruction>(val))
    assert(inst->getParent()->getParent() == oldFunc);
#endif

  if (isConstantValue(val)) {
    llvm::errs() << *newFunc << "\n";
    llvm::errs() << *val << "\n";
    assert(0 && "getting diffe of constant value");
  }
  // NOTE(review): the guard controlling this early return (presumably a
  // mode/pointer check) is missing from this copy of the file — restore from
  // upstream. As written, the code below it is unreachable.
  return invertPointerM(val, BuilderM);
  if (val->getType()->isPointerTy()) {
    llvm::errs() << *newFunc << "\n";
    llvm::errs() << *val << "\n";
  }
  assert(!val->getType()->isPointerTy());
  assert(!val->getType()->isVoidTy());
  // Load the accumulated adjoint from its stack slot.
  Type *ty = getShadowType(val->getType());
  return BuilderM.CreateLoad(ty, getDifferential(val));
}
251
252SmallVector<SelectInst *, 4> DiffeGradientUtils::addToDiffe(
253 Value *val, Value *dif, IRBuilder<> &BuilderM, Type *addingType,
254 unsigned start, unsigned size, llvm::ArrayRef<llvm::Value *> idxs,
255 llvm::Value *mask, size_t ignoreFirstSlicesOfDif) {
256 assert(addingType);
257 auto &DL = oldFunc->getParent()->getDataLayout();
258 Type *VT = val->getType();
259 for (auto cv : idxs) {
260 auto i = dyn_cast<ConstantInt>(cv)->getSExtValue();
261 if (auto ST = dyn_cast<StructType>(VT)) {
262 VT = ST->getElementType(i);
263 continue;
264 }
265 if (auto AT = dyn_cast<ArrayType>(VT)) {
266 assert((size_t)i < AT->getNumElements());
267 VT = AT->getElementType();
268 continue;
269 }
270 assert(0 && "illegal indexing type");
271 }
272 auto storeSize = (DL.getTypeSizeInBits(VT) + 7) / 8;
273
274 assert(start < storeSize);
275 assert(start + size <= storeSize);
276
277 // If VT is a struct type the addToDiffe algorithm will lose type information
278 // so we do the recurrence here, with full type information.
279 if (start == 0 && size == storeSize && !isa<StructType>(VT)) {
280 if (getWidth() == 1) {
281 SmallVector<unsigned, 1> eidxs;
282 for (auto idx : idxs.slice(ignoreFirstSlicesOfDif)) {
283 eidxs.push_back((unsigned)cast<ConstantInt>(idx)->getZExtValue());
284 }
285 return addToDiffe(val, extractMeta(BuilderM, dif, eidxs), BuilderM,
286 addingType, idxs, mask);
287 } else {
288 SmallVector<SelectInst *, 4> res;
289 for (unsigned j = 0; j < getWidth(); j++) {
290 SmallVector<Value *, 1> lidxs;
291 SmallVector<unsigned, 1> eidxs = {(unsigned)j};
292 lidxs.push_back(
293 ConstantInt::get(Type::getInt32Ty(val->getContext()), j));
294 for (auto idx : idxs.slice(ignoreFirstSlicesOfDif)) {
295 eidxs.push_back((unsigned)cast<ConstantInt>(idx)->getZExtValue());
296 }
297 for (auto idx : idxs) {
298 lidxs.push_back(idx);
299 }
300 for (auto v : addToDiffe(val, extractMeta(BuilderM, dif, eidxs),
301 BuilderM, addingType, lidxs, mask))
302 res.push_back(v);
303 }
304 return res;
305 }
306 }
307 if (auto ST = dyn_cast<StructType>(VT)) {
308 auto SL = DL.getStructLayout(ST);
309 auto left_idx = SL->getElementContainingOffset(start);
310 auto right_idx = ST->getNumElements();
311 if (storeSize != start + size) {
312 right_idx = SL->getElementContainingOffset(start + size);
313 // If this doesn't cleanly end the window, make sure we do a partial
314 // accumulate for the remaining part in right_idx.
315 if (SL->getElementOffset(right_idx) != start + size)
316 right_idx++;
317 }
318 SmallVector<SelectInst *, 4> res;
319 for (auto i = left_idx; i < right_idx; i++) {
320 auto subType = ST->getElementType(i);
321 SmallVector<Value *, 1> lidxs(idxs.begin(), idxs.end());
322 lidxs.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), i));
323 auto sub_start =
324 (i == left_idx) ? (start - (unsigned)SL->getElementOffset(i)) : 0;
325 auto subTypeSize = (DL.getTypeSizeInBits(subType) + 7) / 8;
326 auto sub_end = (i == right_idx - 1)
327 ? min(start + size - (unsigned)SL->getElementOffset(i),
328 (unsigned)subTypeSize)
329 : subTypeSize;
330 for (auto v :
331 addToDiffe(val, dif, BuilderM, addingType, sub_start,
332 sub_end - sub_start, lidxs, mask, ignoreFirstSlicesOfDif))
333 res.push_back(v);
334 }
335 return res;
336 }
337
338 if (auto AT = dyn_cast<ArrayType>(VT)) {
339 auto subType = AT->getElementType();
340 auto subTypeSize = (DL.getTypeSizeInBits(subType) + 7) / 8;
341 auto left_idx = start / subTypeSize;
342 auto right_idx = AT->getNumElements();
343 if (storeSize != start + size) {
344 right_idx = (start + size) / subTypeSize;
345 // If this doesn't cleanly end the window, make sure we do a partial
346 // accumulate for the remaining part in right_idx.
347 if (right_idx * subTypeSize != start + size)
348 right_idx++;
349 }
350 SmallVector<SelectInst *, 4> res;
351 for (auto i = left_idx; i < right_idx; i++) {
352 SmallVector<Value *, 1> lidxs(idxs.begin(), idxs.end());
353 lidxs.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), i));
354 auto sub_start = (i == left_idx) ? (start - (i * subTypeSize)) : 0;
355 auto sub_end = (i == right_idx - 1)
356 ? min(start + size - (unsigned)(i * subTypeSize),
357 (unsigned)subTypeSize)
358 : subTypeSize;
359 for (auto v :
360 addToDiffe(val, dif, BuilderM, addingType, sub_start,
361 sub_end - sub_start, lidxs, mask, ignoreFirstSlicesOfDif))
362 res.push_back(v);
363 }
364 return res;
365 }
366
367 if (auto VecT = dyn_cast<VectorType>(VT)) {
368 if (!VecT->getElementCount().isScalable()) {
369 Type *elemTy = VecT->getElementType();
370 auto elemBytes = (DL.getTypeSizeInBits(elemTy) + 7) / 8;
371
372 // Only handle element-aligned windows
373 if (elemBytes != 0 && start % elemBytes == 0 && size % elemBytes == 0) {
374 unsigned left_idx = start / elemBytes;
375 unsigned right_idx = (start + size) / elemBytes; // exclusive
376
377 unsigned numElts = VecT->getElementCount().getFixedValue();
378 if (left_idx > numElts)
379 left_idx = numElts;
380 if (right_idx > numElts)
381 right_idx = numElts;
382
383 auto maskVec = [&](Value *dsub) -> Value * {
384 if (left_idx == 0 && right_idx == numElts)
385 return dsub;
386 Value *masked = Constant::getNullValue(VT);
387 for (unsigned i = left_idx; i < right_idx; i++) {
388 Value *vidx =
389 ConstantInt::get(Type::getInt32Ty(val->getContext()), i);
390 Value *el = BuilderM.CreateExtractElement(dsub, vidx);
391 masked = BuilderM.CreateInsertElement(masked, el, vidx);
392 }
393 return masked;
394 };
395
396 if (getWidth() == 1) {
397 SmallVector<unsigned, 1> eidxs;
398 for (auto idx : idxs.slice(ignoreFirstSlicesOfDif))
399 eidxs.push_back((unsigned)cast<ConstantInt>(idx)->getZExtValue());
400
401 Value *subdif = extractMeta(BuilderM, dif, eidxs);
402 return addToDiffe(val, maskVec(subdif), BuilderM, addingType, idxs,
403 mask);
404 } else {
405 SmallVector<SelectInst *, 4> res;
406 for (unsigned j = 0; j < getWidth(); j++) {
407 SmallVector<Value *, 1> lidxs;
408 SmallVector<unsigned, 1> eidxs = {(unsigned)j};
409
410 lidxs.push_back(
411 ConstantInt::get(Type::getInt32Ty(val->getContext()), j));
412 for (auto idx : idxs.slice(ignoreFirstSlicesOfDif))
413 eidxs.push_back((unsigned)cast<ConstantInt>(idx)->getZExtValue());
414 for (auto idx : idxs)
415 lidxs.push_back(idx);
416
417 Value *subdif = extractMeta(BuilderM, dif, eidxs);
418 for (auto v : addToDiffe(val, maskVec(subdif), BuilderM, addingType,
419 lidxs, mask))
420 res.push_back(v);
421 }
422 return res;
423 }
424 }
425 }
426 }
427
428 llvm::errs() << " VT: " << *VT << " idxs:{";
429 for (auto idx : idxs)
430 llvm::errs() << *idx << ",";
431 llvm::errs() << "} start=" << start << " size=" << size
432 << " storeSize=" << storeSize << " val=" << *val << "\n";
433 assert(0 && "unhandled accumulate with partial sizes");
434 return {};
435}
436
437static bool isZero(llvm::Constant *cst) {
438#if LLVM_VERSION_MAJOR >= 22
439 return cst->isNullValue() || cst->isNegativeZeroValue();
440#else
441 return cst->isZeroValue();
442#endif
443}
// Accumulate (+=) `dif` into the adjoint slot of `val` at the aggregate
// position selected by `idxs`, emitting the add at BuilderM's insertion
// point. `addingType` names the floating-point type with which integer- or
// pointer-typed storage is reinterpreted for the addition; `mask`, if
// non-null, turns the final store into a masked store. Returns any select
// instructions created by the select-peephole below so callers can
// post-process them.
SmallVector<SelectInst *, 4>
DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM,
                               Type *addingType, ArrayRef<Value *> idxs,
                               Value *mask) {

#ifndef NDEBUG
  // `val` must belong to the original function.
  if (auto arg = dyn_cast<Argument>(val))
    assert(arg->getParent() == oldFunc);
  if (auto inst = dyn_cast<Instruction>(val))
    assert(inst->getParent()->getParent() == oldFunc);
#endif

  SmallVector<SelectInst *, 4> addedSelects;

  // Emit old + inc, folding `old + (0.0 - x)` into `old - x`; optionally
  // sanitize the result (skipped when the caller sanitizes later).
  auto faddForNeg = [&](Value *old, Value *inc, bool san) {
    if (auto bi = dyn_cast<BinaryOperator>(inc)) {
      if (auto ci = dyn_cast<ConstantFP>(bi->getOperand(0))) {
        if (bi->getOpcode() == BinaryOperator::FSub && ci->isZero()) {
          Value *res = BuilderM.CreateFSub(old, bi->getOperand(1));
          if (san)
            res = SanitizeDerivatives(val, res, BuilderM, mask);
          return res;
        }
      }
    }
    Value *res = BuilderM.CreateFAdd(old, inc);
    if (san)
      res = SanitizeDerivatives(val, res, BuilderM, mask);
    return res;
  };

  auto faddForSelect = [&](Value *old, Value *dif) -> Value * {
    //! optimize fadd of select to select of fadd
    // old + select(c, 0, x)  ==>  select(c, old, old + x); created selects
    // are recorded in addedSelects for the caller.
    if (SelectInst *select = dyn_cast<SelectInst>(dif)) {
      if (Constant *ci = dyn_cast<Constant>(select->getTrueValue())) {
        if (isZero(ci)) {
          SelectInst *res = cast<SelectInst>(BuilderM.CreateSelect(
              select->getCondition(), old,
              faddForNeg(old, select->getFalseValue(), false)));
          addedSelects.push_back(res);
          return SanitizeDerivatives(val, res, BuilderM, mask);
        }
      }
      if (Constant *ci = dyn_cast<Constant>(select->getFalseValue())) {
        if (isZero(ci)) {
          SelectInst *res = cast<SelectInst>(BuilderM.CreateSelect(
              select->getCondition(),
              faddForNeg(old, select->getTrueValue(), false), old));
          addedSelects.push_back(res);
          return SanitizeDerivatives(val, res, BuilderM, mask);
        }
      }
    }

    //! optimize fadd of bitcast select to select of bitcast fadd
    // Same peephole with a bitcast between the select and the add.
    if (BitCastInst *bc = dyn_cast<BitCastInst>(dif)) {
      if (SelectInst *select = dyn_cast<SelectInst>(bc->getOperand(0))) {
        if (Constant *ci = dyn_cast<Constant>(select->getTrueValue())) {
          if (isZero(ci)) {
            SelectInst *res = cast<SelectInst>(BuilderM.CreateSelect(
                select->getCondition(), old,
                faddForNeg(old,
                           BuilderM.CreateCast(bc->getOpcode(),
                                               select->getFalseValue(),
                                               bc->getDestTy()),
                           false)));
            addedSelects.push_back(res);
            return SanitizeDerivatives(val, res, BuilderM, mask);
          }
        }
        if (Constant *ci = dyn_cast<Constant>(select->getFalseValue())) {
          if (isZero(ci)) {
            SelectInst *res = cast<SelectInst>(BuilderM.CreateSelect(
                select->getCondition(),
                faddForNeg(old,
                           BuilderM.CreateCast(bc->getOpcode(),
                                               select->getTrueValue(),
                                               bc->getDestTy()),
                           false),
                old));
            addedSelects.push_back(res);
            return SanitizeDerivatives(val, res, BuilderM, mask);
          }
        }
      }
    }

    // fallback
    return faddForNeg(old, dif, true);
  };

  // Pointers and inactive values have no scalar adjoint; dump context before
  // the asserts below fire.
  if (val->getType()->isPointerTy()) {
    llvm::errs() << *newFunc << "\n";
    llvm::errs() << *val << "\n";
  }
  if (isConstantValue(val)) {
    llvm::errs() << *newFunc << "\n";
    llvm::errs() << *val << "\n";
  }
  assert(!val->getType()->isPointerTy());
  assert(!isConstantValue(val));

  Value *ptr = getDifferential(val);

  // Load the current adjoint (or the indexed sub-element of it).
  Value *old;
  if (idxs.size() != 0) {
    SmallVector<Value *, 4> sv = {
        ConstantInt::get(Type::getInt32Ty(val->getContext()), 0)};
    for (auto i : idxs)
      sv.push_back(i);
    ptr = BuilderM.CreateGEP(getShadowType(val->getType()), ptr, sv);
    cast<GetElementPtrInst>(ptr)->setIsInBounds(true);
    old = BuilderM.CreateLoad(
        GetElementPtrInst::getIndexedType(getShadowType(val->getType()), sv),
        ptr);
  } else {
    old = BuilderM.CreateLoad(getShadowType(val->getType()), ptr);
  }
  if (dif->getType() != old->getType()) {
    if (auto inst = dyn_cast<Instruction>(val)) {
      EmitFailure("IllegalAddingType", inst->getDebugLoc(), inst, "val ", *val,
                  " dif ", *dif, " old ", *old);
      return addedSelects;
    }
    llvm::errs() << " IllegalAddingType val: " << *val << " dif: " << *dif
                 << " old: " << *old << "\n";
    llvm_unreachable("IllegalAddingType");
  }

  assert(dif->getType() == old->getType());
  Value *res = nullptr;
  if (old->getType()->isIntOrIntVectorTy() || old->getType()->isPointerTy()) {
    // Integer/pointer storage holds float data: add via bitcast to
    // addingType. With loose type analysis, guess double/float from the
    // integer width when no addingType was deduced.
    if (!addingType) {
      if (looseTypeAnalysis) {
        if (old->getType()->isIntegerTy(64))
          addingType = Type::getDoubleTy(old->getContext());
        else if (old->getType()->isIntegerTy(32))
          addingType = Type::getFloatTy(old->getContext());
      }
    }
    if (!addingType) {
      // Still unknown: report a type-deduction failure via the available
      // error channels.
      std::string s;
      llvm::raw_string_ostream ss(s);
      ss << "oldFunc: " << *oldFunc << "\n";
      ss << "Cannot deduce adding type of: " << *val << "\n";
      ss << " + idxs {";
      for (auto idx : idxs)
        ss << *idx << ",";
      ss << "}\n";
      if (auto inst = dyn_cast<Instruction>(val)) {
        EmitNoTypeError(ss.str(), *inst, this, BuilderM);
        return addedSelects;
      } else if (CustomErrorHandler) {
        CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType,
                           TR.analyzer, nullptr, wrap(&BuilderM));
        return addedSelects;
      } else {
        TR.dump(ss);
        llvm::errs() << ss.str() << "\n";
        llvm_unreachable("Cannot deduce adding type");
        return addedSelects;
      }
    }
    assert(addingType);
    assert(addingType->isFPOrFPVectorTy());

    auto oldBitSize =
        oldFunc->getParent()->getDataLayout().getTypeSizeInBits(old->getType());
    auto newBitSize =
        oldFunc->getParent()->getDataLayout().getTypeSizeInBits(addingType);

    // Widen addingType to a vector when the storage holds several of them.
    if (oldBitSize == newBitSize) {
    } else if (oldBitSize > newBitSize && oldBitSize % newBitSize == 0) {
      if (!addingType->isVectorTy())
        addingType =
            VectorType::get(addingType, oldBitSize / newBitSize, false);
    } else {
      std::string s;
      llvm::raw_string_ostream ss(s);
      ss << "oldFunc: " << *oldFunc << "\n";
      ss << "Illegal intermediate when adding to: " << *val
         << " with addingType: " << *addingType << "\n"
         << " old: " << *old << " dif: " << *dif << "\n"
         << " oldBitSize: " << oldBitSize << " newBitSize: " << newBitSize
         << "\n";
      if (CustomErrorHandler) {
        CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType,
                           TR.analyzer, nullptr, wrap(&BuilderM));
        return addedSelects;
      } else {
        DebugLoc loc;
        if (auto inst = dyn_cast<Instruction>(val))
          EmitFailure("CannotDeduceType", inst->getDebugLoc(), inst, ss.str());
        else {
          llvm::errs() << ss.str() << "\n";
          llvm_unreachable("Cannot deduce adding type");
        }
        return addedSelects;
      }
    }

    // Reinterpret old/dif as addingType (pointers go through an intptr
    // round-trip), add, then cast the result back to the storage type.
    Value *bcold = old;
    Value *bcdif = dif;
    Type *intTy = nullptr;
    if (old->getType()->isPointerTy()) {
      auto &DL = oldFunc->getParent()->getDataLayout();
      intTy = Type::getIntNTy(old->getContext(), DL.getPointerSizeInBits());
      bcold = BuilderM.CreatePtrToInt(bcold, intTy);
      bcdif = BuilderM.CreatePtrToInt(bcdif, intTy);
    } else {
      intTy = old->getType();
    }

    bcold = BuilderM.CreateBitCast(bcold, addingType);
    bcdif = BuilderM.CreateBitCast(bcdif, addingType);

    res = faddForSelect(bcold, bcdif);
    if (SelectInst *select = dyn_cast<SelectInst>(res)) {
      // Rebuild the select over storage-typed operands so the peephole's
      // select survives the cast back; the fp-typed select becomes dead.
      assert(addedSelects.back() == select);
      addedSelects.erase(addedSelects.end() - 1);

      Value *tval = BuilderM.CreateBitCast(select->getTrueValue(), intTy);
      Value *fval = BuilderM.CreateBitCast(select->getFalseValue(), intTy);
      if (old->getType()->isPointerTy()) {
        tval = BuilderM.CreateIntToPtr(tval, old->getType());
        fval = BuilderM.CreateIntToPtr(fval, old->getType());
      }
      res = BuilderM.CreateSelect(select->getCondition(), tval, fval);
      assert(select->getNumUses() == 0);
    } else {
      res = BuilderM.CreateBitCast(res, intTy);
      if (old->getType()->isPointerTy())
        res = BuilderM.CreateIntToPtr(res, old->getType());
    }
    // Store the accumulated value back (masked store when a mask is given).
    if (!mask) {
      BuilderM.CreateStore(res, ptr);
      // store->setAlignment(align);
    } else {
      Type *tys[] = {res->getType(), ptr->getType()};
      auto F = getIntrinsicDeclaration(oldFunc->getParent(),
                                       Intrinsic::masked_store, tys);
      auto align = cast<AllocaInst>(ptr)->getAlign().value();
      assert(align);
      Value *alignv =
          ConstantInt::get(Type::getInt32Ty(mask->getContext()), align);
      Value *args[] = {res, ptr, alignv, mask};
      BuilderM.CreateCall(F, args);
    }
    return addedSelects;
  } else if (old->getType()->isFPOrFPVectorTy()) {
    // TODO consider adding type
    // Native floating-point storage: accumulate directly.
    res = faddForSelect(old, dif);

    if (!mask) {
      BuilderM.CreateStore(res, ptr);
      // store->setAlignment(align);
    } else {
      Type *tys[] = {res->getType(), ptr->getType()};
      auto F = getIntrinsicDeclaration(oldFunc->getParent(),
                                       Intrinsic::masked_store, tys);
      auto align = cast<AllocaInst>(ptr)->getAlign().value();
      assert(align);
      Value *alignv =
          ConstantInt::get(Type::getInt32Ty(mask->getContext()), align);
      Value *args[] = {res, ptr, alignv, mask};
      BuilderM.CreateCall(F, args);
    }
    return addedSelects;
  } else if (auto st = dyn_cast<StructType>(old->getType())) {
    // Struct storage: recurse per field, skipping pointer and i1/i8 fields.
    assert(!mask);
    if (mask)
      llvm_unreachable("cannot handle recursive addToDiffe with mask");
    for (unsigned i = 0; i < st->getNumElements(); ++i) {
      // TODO pass in full type tree here and recurse into tree.
      if (st->getElementType(i)->isPointerTy())
        continue;
      if (st->getElementType(i)->isIntegerTy(8) ||
          st->getElementType(i)->isIntegerTy(1))
        continue;
      Value *v = ConstantInt::get(Type::getInt32Ty(st->getContext()), i);
      SmallVector<Value *, 2> idx2(idxs.begin(), idxs.end());
      idx2.push_back(v);
      // FIXME: reconsider if passing a nullptr is correct here.
      auto selects = addToDiffe(val, extractMeta(BuilderM, dif, i), BuilderM,
                                nullptr, idx2);
      for (auto select : selects) {
        addedSelects.push_back(select);
      }
    }
    return addedSelects;
  } else if (auto at = dyn_cast<ArrayType>(old->getType())) {
    // Array storage: recurse per element (pointer elements are skipped
    // wholesale).
    assert(!mask);
    if (mask)
      llvm_unreachable("cannot handle recursive addToDiffe with mask");
    if (at->getElementType()->isPointerTy())
      return addedSelects;
    for (unsigned i = 0; i < at->getNumElements(); ++i) {
      // TODO pass in full type tree here and recurse into tree.
      Value *v = ConstantInt::get(Type::getInt32Ty(at->getContext()), i);
      SmallVector<Value *, 2> idx2(idxs.begin(), idxs.end());
      idx2.push_back(v);
      auto selects = addToDiffe(val, extractMeta(BuilderM, dif, i), BuilderM,
                                addingType, idx2);
      for (auto select : selects) {
        addedSelects.push_back(select);
      }
    }
    return addedSelects;
  } else {
    // Unknown storage type: report context and abort.
    llvm::errs() << " idx: {";
    for (auto i : idxs)
      llvm::errs() << *i << ", ";
    llvm::errs() << "}\n";
    if (addingType)
      llvm::errs() << " addingType: " << *addingType << "\n";
    else
      llvm::errs() << " addingType: null\n";
    llvm::errs() << " oldType:" << *old->getType() << " old:" << *old << "\n";
    llvm_unreachable("unknown type to add to diffe");
    exit(1);
  }
}
768
// Overwrite (rather than accumulate) the adjoint of `val` with `toset`.
// In the branch that manages invertedPointers the existing placeholder PHI is
// replaced by `toset`; otherwise the value is stored to the adjoint stack
// slot obtained from getDifferential.
void DiffeGradientUtils::setDiffe(Value *val, Value *toset,
                                  IRBuilder<> &BuilderM) {
#ifndef NDEBUG
  if (auto arg = dyn_cast<Argument>(val))
    assert(arg->getParent() == oldFunc);
  if (auto inst = dyn_cast<Instruction>(val))
    assert(inst->getParent()->getParent() == oldFunc);
  if (isConstantValue(val)) {
    llvm::errs() << *newFunc << "\n";
    llvm::errs() << *val << "\n";
  }
  assert(!isConstantValue(val));
#endif
  toset = SanitizeDerivatives(val, toset, BuilderM);
  // NOTE(review): lines are missing from this copy of the file here — the
  // region below through the orphaned closing brace is presumably a
  // mode-conditional branch whose opening `if` was dropped; restore from
  // upstream.
  assert(getShadowType(val->getType()) == toset->getType());
  // Swap the invertedPointers placeholder PHI for the concrete shadow value.
  auto found = invertedPointers.find(val);
  assert(found != invertedPointers.end());
  auto placeholder0 = &*found->second;
  auto placeholder = cast<PHINode>(placeholder0);
  invertedPointers.erase(found);
  replaceAWithB(placeholder, toset);
  placeholder->replaceAllUsesWith(toset);
  erase(placeholder);
  invertedPointers.insert(
      std::make_pair((const Value *)val, InvertedPointerVH(this, toset)));
  return;
  }
  Value *tostore = getDifferential(val);
#if LLVM_VERSION_MAJOR < 17
  // Typed-pointer sanity check: store type must match the slot's pointee.
  if (toset->getContext().supportsTypedPointers()) {
    if (toset->getType() != tostore->getType()->getPointerElementType()) {
      llvm::errs() << "toset:" << *toset << "\n";
      llvm::errs() << "tostore:" << *tostore << "\n";
    }
    assert(toset->getType() == tostore->getType()->getPointerElementType());
  }
#endif
  BuilderM.CreateStore(toset, tostore);
}
811
// Emit, in the reverse block of `forwardPreheader`, the load of a cached
// allocation and the corresponding deallocation call. `sublimits`/`i`
// describe the loop-nest context so induction variables can be remapped to
// their reverse-pass counterparts when unwrapping `storeInto`. Returns the
// created dealloc call (registered in scopeFrees), or nullptr when memory
// freeing is disabled.
CallInst *DiffeGradientUtils::freeCache(BasicBlock *forwardPreheader,
                                        const SubLimitType &sublimits, int i,
                                        AllocaInst *alloc, llvm::Type *T,
                                        ConstantInt *byteSizeOfType,
                                        Value *storeInto, MDNode *InvariantMD) {
  if (!FreeMemory)
    return nullptr;
  assert(reverseBlocks.find(forwardPreheader) != reverseBlocks.end());
  assert(reverseBlocks[forwardPreheader].size());
  // Build into the last reverse block of the preheader.
  IRBuilder<> tbuild(reverseBlocks[forwardPreheader].back());
  tbuild.setFastMathFlags(getFast());

  // ensure we are before the terminator if it exists
  if (tbuild.GetInsertBlock()->size() &&
      tbuild.GetInsertBlock()->getTerminator()) {
    tbuild.SetInsertPoint(tbuild.GetInsertBlock()->getTerminator());
  }

  // Map each loop induction variable (from outermost remaining level down)
  // to the value loaded from its reverse-pass "anti-variable" slot.
  ValueToValueMapTy antimap;
  for (int j = sublimits.size() - 1; j >= i; j--) {
    auto &innercontainedloops = sublimits[j].second;
    for (auto riter = innercontainedloops.rbegin(),
              rend = innercontainedloops.rend();
         riter != rend; ++riter) {
      const auto &idx = riter->first;
      if (idx.var) {
        antimap[idx.var] =
            tbuild.CreateLoad(idx.var->getType(), idx.antivaralloc);
      }
    }
  }

  // NOTE(review): the trailing argument(s) of this unwrapM call are missing
  // from this copy of the file; restore from upstream.
  Value *metaforfree = unwrapM(storeInto, tbuild, antimap,

#if LLVM_VERSION_MAJOR < 17
  if (metaforfree->getContext().supportsTypedPointers()) {
    assert(T == metaforfree->getType()->getPointerElementType());
  }
#endif

  // Load the cached pointer; mark it invariant and dereferenceable for the
  // allocation's byte size so later passes can reason about it.
  LoadInst *forfree = cast<LoadInst>(tbuild.CreateLoad(T, metaforfree));
  forfree->setMetadata(LLVMContext::MD_invariant_group, InvariantMD);
  forfree->setMetadata(LLVMContext::MD_dereferenceable,
                       MDNode::get(forfree->getContext(),
                                   ArrayRef<Metadata *>(ConstantAsMetadata::get(
                                       byteSizeOfType))));
  forfree->setName("forfree");
  unsigned align = getCacheAlignment(
      (unsigned)newFunc->getParent()->getDataLayout().getPointerSize());
  forfree->setAlignment(Align(align));

  // Emit the deallocation and record it against the owning alloca.
  CallInst *ci = CreateDealloc(tbuild, forfree);
  if (ci) {
    if (newFunc->getSubprogram())
      ci->setDebugLoc(DILocation::get(newFunc->getContext(), 0, 0,
                                      newFunc->getSubprogram(), 0));
    scopeFrees[alloc].insert(ci);
  }
  return ci;
}
873
874void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
875 Value *origVal, Type *addingType,
876 unsigned start, unsigned size,
877 Value *origptr, Value *dif,
878 IRBuilder<> &BuilderM,
879 MaybeAlign align, Value *mask) {
880 auto &DL = oldFunc->getParent()->getDataLayout();
881
882 auto addingSize = (DL.getTypeSizeInBits(addingType) + 1) / 8;
883 if (addingSize != size) {
884 assert(size > addingSize);
885 addingType =
886 VectorType::get(addingType, size / addingSize, /*isScalable*/ false);
887 size = (size / addingSize) * addingSize;
888 }
889
890 Value *ptr;
891
892 switch (mode) {
896 ptr = invertPointerM(origptr, BuilderM);
897 break;
899 assert(false && "Invalid derivative mode (ReverseModePrimal)");
900 break;
903 ptr = lookupM(invertPointerM(origptr, BuilderM), BuilderM);
904 break;
905 }
906
907 bool needsCast = false;
908#if LLVM_VERSION_MAJOR < 17
909 if (isa<PointerType>(origptr->getType()) &&
910 origptr->getContext().supportsTypedPointers()) {
911 needsCast = origptr->getType()->getPointerElementType() != addingType;
912 }
913#endif
914
915 assert(ptr);
916 if (start != 0 || needsCast || !isa<PointerType>(origptr->getType())) {
917 auto rule = [&](Value *ptr) {
918 if (!isa<PointerType>(origptr->getType())) {
919 ptr = BuilderM.CreateIntToPtr(ptr, getUnqual(addingType));
920 }
921 if (start != 0) {
922 auto i8 = Type::getInt8Ty(ptr->getContext());
923 ptr = BuilderM.CreatePointerCast(
924 ptr, PointerType::get(
925 i8, cast<PointerType>(ptr->getType())->getAddressSpace()));
926 auto off = ConstantInt::get(Type::getInt64Ty(ptr->getContext()), start);
927 ptr = BuilderM.CreateInBoundsGEP(i8, ptr, off);
928 }
929 if (needsCast) {
930 ptr = BuilderM.CreatePointerCast(
931 ptr, PointerType::get(
932 addingType,
933 cast<PointerType>(ptr->getType())->getAddressSpace()));
934 }
935 return ptr;
936 };
937 ptr = applyChainRule(
938 PointerType::get(
939 addingType,
940 isa<PointerType>(origptr->getType())
941 ? cast<PointerType>(origptr->getType())->getAddressSpace()
942 : 0),
943 BuilderM, rule, ptr);
944 }
945
946 if (getWidth() == 1)
947 needsCast = dif->getType() != addingType;
948 else if (auto AT = cast<ArrayType>(dif->getType()))
949 needsCast = AT->getElementType() != addingType;
950 else
951 needsCast =
952 cast<VectorType>(dif->getType())->getElementType() != addingType;
953
954 if (start != 0 || needsCast) {
955 auto rule = [&](Value *dif) {
956 if (start != 0) {
957 IRBuilder<> A(inversionAllocs);
958 auto i8 = Type::getInt8Ty(ptr->getContext());
959 auto prevSize = (DL.getTypeSizeInBits(dif->getType()) + 1) / 8;
960 Type *tys[] = {ArrayType::get(i8, start), addingType,
961 ArrayType::get(i8, prevSize - start - size)};
962 auto ST = StructType::get(i8->getContext(), tys, /*isPacked*/ true);
963 auto Al = A.CreateAlloca(ST);
964 BuilderM.CreateStore(
965 dif, BuilderM.CreatePointerCast(Al, getUnqual(dif->getType())));
966 Value *idxs[] = {
967 ConstantInt::get(Type::getInt64Ty(ptr->getContext()), 0),
968 ConstantInt::get(Type::getInt32Ty(ptr->getContext()), 1)};
969
970 auto difp = BuilderM.CreateInBoundsGEP(ST, Al, idxs);
971 dif = BuilderM.CreateLoad(addingType, difp);
972 }
973 if (dif->getType() != addingType) {
974 auto difSize = (DL.getTypeSizeInBits(dif->getType()) + 1) / 8;
975 if (difSize < size) {
976 llvm::errs() << " ds: " << difSize << " as: " << size << "\n";
977 llvm::errs() << " dif: " << *dif << " adding: " << *addingType
978 << "\n";
979 }
980 assert(difSize >= size);
981 if (CastInst::castIsValid(Instruction::CastOps::BitCast, dif,
982 addingType))
983 dif = BuilderM.CreateBitCast(dif, addingType);
984 else {
985 IRBuilder<> A(inversionAllocs);
986 auto Al = A.CreateAlloca(addingType);
987 BuilderM.CreateStore(
988 dif, BuilderM.CreatePointerCast(Al, getUnqual(dif->getType())));
989 dif = BuilderM.CreateLoad(addingType, Al);
990 }
991 }
992 return dif;
993 };
994 dif = applyChainRule(addingType, BuilderM, rule, dif);
995 }
996
997 auto TmpOrig = getBaseObject(origptr);
998
999 // atomics
1000 bool Atomic = AtomicAdd;
1001 auto Arch = llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch();
1002
1003 // No need to do atomic on local memory for CUDA since it can't be raced
1004 // upon
1005 if (isa<AllocaInst>(TmpOrig) &&
1006 (Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
1007 Arch == Triple::amdgcn)) {
1008 Atomic = false;
1009 }
1010 // Moreover no need to do atomic on local shadows regardless since they are
1011 // not captured/escaping and created in this function. This assumes that
1012 // all additional parallelism in this function is outlined.
1013 if (backwardsOnlyShadows.find(TmpOrig) != backwardsOnlyShadows.end())
1014 Atomic = false;
1015 if (Atomic && elementwiseReadForContext(orig, origptr))
1016 Atomic = false;
1017
1018 if (Atomic) {
1019 // For amdgcn constant AS is 4 and if the primal is in it we need to cast
1020 // the derivative value to AS 1
1021 if (Arch == Triple::amdgcn &&
1022 cast<PointerType>(origptr->getType())->getAddressSpace() == 4) {
1023 auto rule = [&](Value *ptr) {
1024 return BuilderM.CreateAddrSpaceCast(ptr,
1025 PointerType::get(addingType, 1));
1026 };
1027 ptr =
1028 applyChainRule(PointerType::get(addingType, 1), BuilderM, rule, ptr);
1029 }
1030
1031 if (mask) {
1032 std::string s;
1033 llvm::raw_string_ostream ss(s);
1034 ss << "Unimplemented masked atomic fadd for ptr:" << *ptr
1035 << " dif:" << *dif << " mask: " << *mask << " orig: " << *orig << "\n";
1036 if (CustomErrorHandler) {
1037 CustomErrorHandler(ss.str().c_str(), wrap(orig),
1038 ErrorType::NoDerivative, this, nullptr,
1039 wrap(&BuilderM));
1040 return;
1041 } else {
1042 EmitFailure("NoDerivative", orig->getDebugLoc(), orig, ss.str());
1043 return;
1044 }
1045 }
1046
1047 /*
1048 while (auto ASC = dyn_cast<AddrSpaceCastInst>(ptr)) {
1049 ptr = ASC->getOperand(0);
1050 }
1051 while (auto ASC = dyn_cast<ConstantExpr>(ptr)) {
1052 if (!ASC->isCast()) break;
1053 if (ASC->getOpcode() != Instruction::AddrSpaceCast) break;
1054 ptr = ASC->getOperand(0);
1055 }
1056 */
1057 AtomicRMWInst::BinOp op = AtomicRMWInst::FAdd;
1058 if (auto vt = dyn_cast<VectorType>(addingType)) {
1059 assert(!vt->getElementCount().isScalable());
1060 size_t numElems = vt->getElementCount().getKnownMinValue();
1061 auto rule = [&](Value *dif, Value *ptr) {
1062 for (size_t i = 0; i < numElems; ++i) {
1063 auto vdif = BuilderM.CreateExtractElement(dif, i);
1064 vdif = SanitizeDerivatives(orig, vdif, BuilderM);
1065 Value *Idxs[] = {
1066 ConstantInt::get(Type::getInt64Ty(vt->getContext()), 0),
1067 ConstantInt::get(Type::getInt32Ty(vt->getContext()), i)};
1068 auto vptr = BuilderM.CreateGEP(addingType, ptr, Idxs);
1069 MaybeAlign alignv = align;
1070 if (alignv) {
1071 if (start != 0) {
1072 // todo make better alignment calculation
1073 assert((*alignv).value() != 0);
1074 if (start % (*alignv).value() != 0) {
1075 alignv = Align(1);
1076 }
1077 }
1078 }
1079 BuilderM.CreateAtomicRMW(op, vptr, vdif, alignv,
1080 AtomicOrdering::Monotonic,
1081 SyncScope::System);
1082 }
1083 };
1084 applyChainRule(BuilderM, rule, dif, ptr);
1085 } else {
1086 auto rule = [&](Value *dif, Value *ptr) {
1087 dif = SanitizeDerivatives(orig, dif, BuilderM);
1088 MaybeAlign alignv = align;
1089 if (alignv) {
1090 if (start != 0) {
1091 // todo make better alignment calculation
1092 assert((*alignv).value() != 0);
1093 if (start % (*alignv).value() != 0) {
1094 alignv = Align(1);
1095 }
1096 }
1097 }
1098 BuilderM.CreateAtomicRMW(op, ptr, dif, alignv,
1099 AtomicOrdering::Monotonic, SyncScope::System);
1100 };
1101 applyChainRule(BuilderM, rule, dif, ptr);
1102 }
1103 return;
1104 }
1105
1106 if (!mask) {
1107
1108 size_t idx = 0;
1109 auto rule = [&](Value *ptr, Value *dif) {
1110 auto LI = BuilderM.CreateLoad(addingType, ptr);
1111
1112 Value *res = BuilderM.CreateFAdd(LI, dif);
1113 res = SanitizeDerivatives(orig, res, BuilderM);
1114 StoreInst *st = BuilderM.CreateStore(res, ptr);
1115
1116 SmallVector<Metadata *, 1> scopeMD = {
1117 getDerivativeAliasScope(origptr, idx)};
1118 if (auto origValI = dyn_cast_or_null<Instruction>(origVal))
1119 if (auto MD = origValI->getMetadata(LLVMContext::MD_alias_scope)) {
1120 auto MDN = cast<MDNode>(MD);
1121 for (auto &o : MDN->operands())
1122 scopeMD.push_back(o);
1123 }
1124 auto scope = MDNode::get(LI->getContext(), scopeMD);
1125 LI->setMetadata(LLVMContext::MD_alias_scope, scope);
1126 st->setMetadata(LLVMContext::MD_alias_scope, scope);
1127
1128 SmallVector<Metadata *, 1> MDs;
1129 for (ssize_t j = -1; j < getWidth(); j++) {
1130 if (j != (ssize_t)idx)
1131 MDs.push_back(getDerivativeAliasScope(origptr, j));
1132 }
1133 if (auto origValI = dyn_cast_or_null<Instruction>(origVal))
1134 if (auto MD = origValI->getMetadata(LLVMContext::MD_noalias)) {
1135 auto MDN = cast<MDNode>(MD);
1136 for (auto &o : MDN->operands())
1137 MDs.push_back(o);
1138 }
1139 idx++;
1140 auto noscope = MDNode::get(ptr->getContext(), MDs);
1141 LI->setMetadata(LLVMContext::MD_noalias, noscope);
1142 st->setMetadata(LLVMContext::MD_noalias, noscope);
1143
1144 if (origVal && isa<Instruction>(origVal) && start == 0 &&
1145 size == (DL.getTypeSizeInBits(origVal->getType()) + 7) / 8) {
1146 auto origValI = cast<Instruction>(origVal);
1147 LI->copyMetadata(*origValI, MD_ToCopy);
1148 unsigned int StoreData[] = {LLVMContext::MD_tbaa,
1149 LLVMContext::MD_tbaa_struct};
1150 for (auto MD : StoreData)
1151 st->setMetadata(MD, origValI->getMetadata(MD));
1152 }
1153
1154 LI->setDebugLoc(getNewFromOriginal(orig->getDebugLoc()));
1155 st->setDebugLoc(getNewFromOriginal(orig->getDebugLoc()));
1156
1157 if (align) {
1158 auto alignv = align ? (*align).value() : 0;
1159 if (alignv != 0) {
1160 if (start != 0) {
1161 // todo make better alignment calculation
1162 if (start % alignv != 0) {
1163 alignv = 1;
1164 }
1165 }
1166
1167 LI->setAlignment(Align(alignv));
1168 st->setAlignment(Align(alignv));
1169 }
1170 }
1171 };
1172 applyChainRule(BuilderM, rule, ptr, dif);
1173 } else {
1174 Type *tys[] = {addingType, origptr->getType()};
1175 auto LF = getIntrinsicDeclaration(oldFunc->getParent(),
1176 Intrinsic::masked_load, tys);
1177 auto SF = getIntrinsicDeclaration(oldFunc->getParent(),
1178 Intrinsic::masked_store, tys);
1179 unsigned aligni = align ? align->value() : 0;
1180
1181 if (aligni != 0)
1182 if (start != 0) {
1183 // todo make better alignment calculation
1184 if (start % aligni != 0) {
1185 aligni = 1;
1186 }
1187 }
1188 Value *alignv =
1189 ConstantInt::get(Type::getInt32Ty(mask->getContext()), aligni);
1190 auto rule = [&](Value *ptr, Value *dif) {
1191 Value *largs[] = {ptr, alignv, mask,
1192 Constant::getNullValue(dif->getType())};
1193 Value *LI = BuilderM.CreateCall(LF, largs);
1194 Value *res = BuilderM.CreateFAdd(LI, dif);
1195 res = SanitizeDerivatives(orig, res, BuilderM, mask);
1196 Value *sargs[] = {res, ptr, alignv, mask};
1197 BuilderM.CreateCall(SF, sargs);
1198 };
1199 applyChainRule(BuilderM, rule, ptr, dif);
1200 }
1201}
1202
1204 llvm::Instruction *orig, llvm::Value *origVal, TypeTree vd,
1205 unsigned LoadSize, llvm::Value *origptr, llvm::Value *prediff,
1206 llvm::IRBuilder<> &Builder2, MaybeAlign alignment, llvm::Value *premask)
1207
1208{
1209
1210 unsigned start = 0;
1211 unsigned size = LoadSize;
1212
1213 assert(prediff);
1214
1215 BasicBlock *merge = nullptr;
1216
1217 while (1) {
1218 unsigned nextStart = size;
1219
1220 auto dt = vd[{-1}];
1221 for (size_t i = start; i < size; ++i) {
1222 bool Legal = true;
1223 dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
1224 if (!Legal) {
1225 nextStart = i;
1226 break;
1227 }
1228 }
1229 if (!dt.isKnown()) {
1230 TR.dump();
1231 llvm::errs() << " vd:" << vd.str() << " start:" << start
1232 << " size: " << size << " dt:" << dt.str() << "\n";
1233 }
1234 assert(dt.isKnown());
1235
1236 if (Type *isfloat = dt.isFloat()) {
1237
1238 if (origVal) {
1239 if (start == 0 && nextStart == LoadSize) {
1240 setDiffe(origVal,
1241 Constant::getNullValue(getShadowType(origVal->getType())),
1242 Builder2);
1243 } else {
1244 Value *tostore = getDifferential(origVal);
1245
1246 auto i8 = Type::getInt8Ty(tostore->getContext());
1247 if (start != 0) {
1248 tostore = Builder2.CreatePointerCast(
1249 tostore,
1250 PointerType::get(
1251 i8,
1252 cast<PointerType>(tostore->getType())->getAddressSpace()));
1253 auto off = ConstantInt::get(Type::getInt64Ty(tostore->getContext()),
1254 start);
1255 tostore = Builder2.CreateInBoundsGEP(i8, tostore, off);
1256 }
1257 auto AT = ArrayType::get(i8, nextStart - start);
1258 tostore = Builder2.CreatePointerCast(
1259 tostore,
1260 PointerType::get(
1261 AT,
1262 cast<PointerType>(tostore->getType())->getAddressSpace()));
1263 Builder2.CreateStore(Constant::getNullValue(AT), tostore);
1264 }
1265 }
1266
1267 if (!isConstantValue(origptr)) {
1268 auto basePtr = getBaseObject(origptr);
1269 assert(!isConstantValue(basePtr));
1270 // If runtime activity, first see if we can prove that the shadow/primal
1271 // are distinct statically as they are allocas/mallocs, if not compare
1272 // the pointers and conditionally execute.
1273 if ((!isa<AllocaInst>(basePtr) && !isAllocationCall(basePtr, TLI)) &&
1274 runtimeActivity && !merge) {
1275 Value *primal_val = lookupM(getNewFromOriginal(origptr), Builder2);
1276 Value *shadow_val =
1277 lookupM(invertPointerM(origptr, Builder2), Builder2);
1278 if (getWidth() != 1) {
1279 shadow_val = extractMeta(Builder2, shadow_val, 0);
1280 }
1281 Value *shadow = Builder2.CreateICmpNE(primal_val, shadow_val);
1282
1283 BasicBlock *current = Builder2.GetInsertBlock();
1284 BasicBlock *conditional =
1285 addReverseBlock(current, current->getName() + "_active");
1286 merge = addReverseBlock(conditional, current->getName() + "_amerge");
1287 Builder2.CreateCondBr(shadow, conditional, merge);
1288 Builder2.SetInsertPoint(conditional);
1289 }
1290 // Masked partial type is unhandled.
1291 if (premask)
1292 assert(start == 0 && nextStart == LoadSize);
1293 addToInvertedPtrDiffe(orig, origVal, isfloat, start, nextStart - start,
1294 origptr, prediff, Builder2, alignment, premask);
1295 }
1296 }
1297
1298 if (nextStart == size)
1299 break;
1300 start = nextStart;
1301 }
1302 if (merge) {
1303 Builder2.CreateBr(merge);
1304 Builder2.SetInsertPoint(merge);
1305 }
1306}
@ AttemptFullUnwrapWithLookup
static bool isZero(llvm::Constant *cst)
static bool isAllocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI)
llvm::cl::opt< bool > looseTypeAnalysis
SmallVector< unsigned int, 9 > MD_ToCopy
CallInst * CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree)
Definition Utils.cpp:742
void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, bool isTape)
Definition Utils.cpp:414
void EmitNoTypeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &Builder2)
Definition Utils.cpp:4377
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
Definition Utils.cpp:62
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
Definition Utils.cpp:3731
llvm::Value * SanitizeDerivatives(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM, llvm::Value *mask)
Definition Utils.cpp:3634
DIFFE_TYPE
Potential differentiable argument classifications.
Definition Utils.h:374
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
static llvm::Function * getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id, llvm::ArrayRef< llvm::Type * > Tys={})
Definition Utils.h:2263
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
Definition Utils.h:1507
static T min(T a, T b)
Pick the minimum value.
Definition Utils.h:268
DerivativeMode
Definition Utils.h:390
llvm::Function *const newFunc
The function whose instructions we are caching.
std::map< llvm::AllocaInst *, std::set< llvm::AssertingVH< llvm::CallInst > > > scopeFrees
A map of allocations to a set of instructions which free memory as part of the cache.
llvm::TargetLibraryInfo & TLI
Various analysis results of newFunc.
llvm::LoopInfo LI
llvm::BasicBlock * inversionAllocs
llvm::SmallVector< std::pair< llvm::Value *, llvm::SmallVector< std::pair< LoopContext, llvm::Value * >, 4 > >, 0 > SubLimitType
Given a LimitContext ctx, representing a location inside a loop nest, break each of the loops up into...
unsigned getCacheAlignment(unsigned bsize) const
llvm::ValueMap< const llvm::Value *, llvm::TrackingVH< llvm::AllocaInst > > differentials
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal, llvm::Type *addingType, unsigned start, unsigned size, llvm::Value *origptr, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::MaybeAlign align=llvm::MaybeAlign(), llvm::Value *mask=nullptr)
align is the alignment that should be specified for load/store to pointer
llvm::Value * diffe(llvm::Value *val, llvm::IRBuilder<> &BuilderM)
llvm::AllocaInst * getDifferential(llvm::Value *val)
bool FreeMemory
Whether to free memory in reverse pass or split forward.
static DiffeGradientUtils * CreateFromClone(EnzymeLogic &Logic, DerivativeMode mode, bool runtimeActivity, bool strongZero, unsigned width, llvm::Function *todiff, llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool shadowReturnArg, bool diffeReturnArg, llvm::ArrayRef< DIFFE_TYPE > constant_args, bool returnTape, bool returnPrimal, llvm::Type *additionalArg, bool omp)
llvm::CallInst * freeCache(llvm::BasicBlock *forwardPreheader, const SubLimitType &sublimits, int i, llvm::AllocaInst *alloc, llvm::Type *myType, llvm::ConstantInt *byteSizeOfType, llvm::Value *storeInto, llvm::MDNode *InvariantMD) override
If an allocation is requested to be freed, this subclass will be called to chose how and where to fre...
void setDiffe(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM)
llvm::SmallVector< llvm::SelectInst *, 4 > addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::Type *addingType, llvm::ArrayRef< llvm::Value * > idxs={}, llvm::Value *mask=nullptr)
Returns created select instructions, if any.
PreProcessCache PPC
DerivativeMode mode
llvm::ValueMap< const llvm::Value *, InvertedPointerVH > invertedPointers
llvm::DebugLoc getNewFromOriginal(const llvm::DebugLoc L) const
static llvm::Value * extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, unsigned off, const llvm::Twine &name="")
Helper routine to extract a nested element from a struct/array. This is.
llvm::Value * applyChainRule(llvm::Type *diffType, llvm::IRBuilder<> &Builder, Func rule, Args... args)
Unwraps a vector derivative from its internal representation and applies a function f to each element...
TypeAnalysis & TA
TypeResults TR
llvm::BasicBlock * addReverseBlock(llvm::BasicBlock *currentBlock, llvm::Twine const &name, bool forkCache=true, bool push=true)
unsigned getWidth()
std::map< llvm::BasicBlock *, llvm::SmallVector< llvm::BasicBlock *, 4 > > reverseBlocks
Map of primal block to corresponding block(s) in reverse.
llvm::Function * oldFunc
void replaceAWithB(llvm::Value *A, llvm::Value *B, bool storeInCache=false) override
Replace this instruction both in LLVM modules and any local data-structures.
EnzymeLogic & Logic
llvm::Value * unwrapM(llvm::Value *const val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &available, UnwrapMode unwrapMode, llvm::BasicBlock *scope=nullptr, bool permitCache=true) override final
if full unwrap, don't just unwrap this instruction, but also its operands, etc
llvm::Value * lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &incoming_availalble=llvm::ValueToValueMapTy(), bool tryLegalRecomputeCheck=true, llvm::BasicBlock *scope=nullptr) override
High-level utility to get the value an instruction at a new location specified by BuilderM.
llvm::MDNode * getDerivativeAliasScope(const llvm::Value *origptr, ssize_t newptr)
static llvm::Type * getShadowType(llvm::Type *ty, unsigned width)
llvm::ValueMap< llvm::Value *, ShadowRematerializer > backwardsOnlyShadows
Only loaded from and stored to (not captured), mapped to the stores (and memset).
bool isConstantValue(llvm::Value *val) const
llvm::Value * invertPointerM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, bool nullShadow=false)
void erase(llvm::Instruction *I) override
Erase this instruction both from LLVM modules and any local data-structures.
llvm::Function * CloneFunctionWithReturns(DerivativeMode mode, unsigned width, llvm::Function *&F, llvm::ValueToValueMapTy &ptrInputs, llvm::ArrayRef< DIFFE_TYPE > constant_args, llvm::SmallPtrSetImpl< llvm::Value * > &constants, llvm::SmallPtrSetImpl< llvm::Value * > &nonconstant, llvm::SmallPtrSetImpl< llvm::Value * > &returnvals, bool returnTape, bool returnPrimal, bool returnShadow, const llvm::Twine &name, llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > *VMapO, bool diffeReturnArg, llvm::Type *additionalArg=nullptr)
Full interprocedural TypeAnalysis.
TypeResults analyzeFunction(const FnTypeInfo &fn)
Analyze a particular function, returning the results.
A holder class representing the results of running TypeAnalysis on a given function.
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
TypeAnalyzer * analyzer
llvm::Function * getFunction() const
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Definition TypeTree.h:72
std::string str() const
Returns a string representation of this TypeTree.
Definition TypeTree.h:1383
Struct containing all contextual type information for a particular function call.
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
TypeTree Return
Type of return.
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to represented by an argument, if constant.