28#ifndef ENZYME_DIFFEGRADIENTUTILS_H_
29#define ENZYME_DIFFEGRADIENTUTILS_H_
31#include "GradientUtils.h"
33#include <llvm/Config/llvm-config.h>
35#include "llvm/ADT/ArrayRef.h"
36#include "llvm/ADT/SmallPtrSet.h"
37#include "llvm/ADT/SmallVector.h"
39#include "llvm/IR/BasicBlock.h"
40#include "llvm/IR/Dominators.h"
41#include "llvm/IR/IRBuilder.h"
42#include "llvm/IR/Instructions.h"
43#include "llvm/IR/Metadata.h"
44#include "llvm/IR/Type.h"
45#include "llvm/IR/Value.h"
47#include "llvm/Analysis/AliasAnalysis.h"
48#include "llvm/Analysis/LoopInfo.h"
49#include "llvm/Analysis/PostDominators.h"
50#include "llvm/Analysis/ValueTracking.h"
52#include "ActivityAnalysis.h"
53#include "EnzymeLogic.h"
56#include "llvm-c/Core.h"
58#if LLVM_VERSION_MAJOR <= 16
59#include "llvm/ADT/Triple.h"
66 llvm::ValueToValueMapTy &invertedPointers_,
67 const llvm::SmallPtrSetImpl<llvm::Value *> &constantvalues_,
68 const llvm::SmallPtrSetImpl<llvm::Value *> &returnvals_,
70 llvm::ArrayRef<DIFFE_TYPE> constant_values,
71 llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
73 unsigned width,
bool omp);
78 llvm::ValueMap<const llvm::Value *, llvm::TrackingVH<llvm::AllocaInst>>
82 bool strongZero,
unsigned width, llvm::Function *todiff,
85 bool shadowReturnArg,
bool diffeReturnArg,
86 llvm::ArrayRef<DIFFE_TYPE> constant_args,
bool returnTape,
87 bool returnPrimal, llvm::Type *additionalArg,
bool omp);
/// Takes a value and an IR builder and returns an `llvm::Value *`.
/// NOTE(review): presumably yields the derivative value currently associated
/// with `val`, emitting any required instructions through `BuilderM` — the
/// implementation is not part of this listing, so confirm before relying on it.
92 llvm::Value *
diffe(llvm::Value *val, llvm::IRBuilder<> &BuilderM);
/// Overload that takes no explicit `start`/`size` range.
/// Returns created select instructions, if any.
/// `idxs` defaults to an empty list and `mask` defaults to nullptr.
/// NOTE(review): presumably accumulates `dif` (interpreted as `addingType`)
/// into the derivative of `val` — confirm against the implementation.
95 llvm::SmallVector<llvm::SelectInst *, 4>
96 addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM,
97 llvm::Type *addingType, llvm::ArrayRef<llvm::Value *> idxs = {},
98 llvm::Value *mask =
nullptr);
/// Overload that additionally takes an unsigned `start` and `size`.
/// Returns created select instructions, if any.
/// `idxs` defaults to an empty list, `mask` defaults to nullptr, and
/// `ignoreFirstSlicesToDiff` defaults to 0.
/// NOTE(review): `start`/`size` presumably select a byte range of the
/// derivative of `val` to update, and `ignoreFirstSlicesToDiff` presumably
/// skips leading slices — neither is established by this listing; confirm
/// against the implementation.
101 llvm::SmallVector<llvm::SelectInst *, 4>
102 addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM,
103 llvm::Type *addingType,
unsigned start,
unsigned size,
104 llvm::ArrayRef<llvm::Value *> idxs = {},
105 llvm::Value *mask =
nullptr,
size_t ignoreFirstSlicesToDiff = 0);
/// Takes a value, a second value `toset`, and an IR builder; returns nothing.
/// NOTE(review): presumably overwrites the derivative associated with `val`
/// with `toset` using `BuilderM` — confirm against the implementation.
107 void setDiffe(llvm::Value *val, llvm::Value *toset,
108 llvm::IRBuilder<> &BuilderM);
/// Override of a base-class hook; returns an `llvm::CallInst *`.
/// Per the summary text later in this listing, this is invoked when an
/// allocation is requested to be freed so the subclass can decide how and
/// where to do so (that summary is truncated in this listing).
/// NOTE(review): the signature summary later in this listing also shows
/// `const SubLimitType &sublimits, int i` between `forwardPreheader` and
/// `alloc`; that parameter line appears to be missing from the declaration
/// below — confirm against the original header.
110 llvm::CallInst *
freeCache(llvm::BasicBlock *forwardPreheader,
112 llvm::AllocaInst *alloc, llvm::Type *myType,
113 llvm::ConstantInt *byteSizeOfType,
114 llvm::Value *storeInto,
115 llvm::MDNode *InvariantMD)
override;
119 llvm::Type *addingType,
unsigned start,
120 unsigned size, llvm::Value *origptr,
121 llvm::Value *dif, llvm::IRBuilder<> &BuilderM,
122 llvm::MaybeAlign align = llvm::MaybeAlign(),
123 llvm::Value *mask =
nullptr);
126 TypeTree vd,
unsigned size, llvm::Value *origptr,
127 llvm::Value *prediff, llvm::IRBuilder<> &Builder2,
128 llvm::MaybeAlign align = llvm::MaybeAlign(),
129 llvm::Value *premask =
nullptr);
DIFFE_TYPE
Potential differentiable argument classifications.
llvm::TargetLibraryInfo & TLI
Various analysis results of newFunc.
llvm::SmallVector< std::pair< llvm::Value *, llvm::SmallVector< std::pair< LoopContext, llvm::Value * >, 4 > >, 0 > SubLimitType
Given a LimitContext ctx, representing a location inside a loop nest, break each of the loops up into...
llvm::ValueMap< const llvm::Value *, llvm::TrackingVH< llvm::AllocaInst > > differentials
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal, llvm::Type *addingType, unsigned start, unsigned size, llvm::Value *origptr, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::MaybeAlign align=llvm::MaybeAlign(), llvm::Value *mask=nullptr)
align is the alignment that should be specified for load/store to pointer
llvm::Value * diffe(llvm::Value *val, llvm::IRBuilder<> &BuilderM)
llvm::AllocaInst * getDifferential(llvm::Value *val)
llvm::SmallVector< llvm::SelectInst *, 4 > addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::Type *addingType, unsigned start, unsigned size, llvm::ArrayRef< llvm::Value * > idxs={}, llvm::Value *mask=nullptr, size_t ignoreFirstSlicesToDiff=0)
Returns created select instructions, if any.
bool FreeMemory
Whether to free memory in reverse pass or split forward.
static DiffeGradientUtils * CreateFromClone(EnzymeLogic &Logic, DerivativeMode mode, bool runtimeActivity, bool strongZero, unsigned width, llvm::Function *todiff, llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool shadowReturnArg, bool diffeReturnArg, llvm::ArrayRef< DIFFE_TYPE > constant_args, bool returnTape, bool returnPrimal, llvm::Type *additionalArg, bool omp)
llvm::CallInst * freeCache(llvm::BasicBlock *forwardPreheader, const SubLimitType &sublimits, int i, llvm::AllocaInst *alloc, llvm::Type *myType, llvm::ConstantInt *byteSizeOfType, llvm::Value *storeInto, llvm::MDNode *InvariantMD) override
If an allocation is requested to be freed, this subclass will be called to choose how and where to fre...
void setDiffe(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM)
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal, TypeTree vd, unsigned size, llvm::Value *origptr, llvm::Value *prediff, llvm::IRBuilder<> &Builder2, llvm::MaybeAlign align=llvm::MaybeAlign(), llvm::Value *premask=nullptr)
llvm::SmallVector< llvm::SelectInst *, 4 > addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::Type *addingType, llvm::ArrayRef< llvm::Value * > idxs={}, llvm::Value *mask=nullptr)
Returns created select instructions, if any.
Full interprocedural TypeAnalysis.
A holder class representing the results of running TypeAnalysis on a given function.
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Struct containing all contextual type information for a particular function call.