Enzyme main
Loading...
Searching...
No Matches
AdjointGenerator Class Reference

#include "AdjointGenerator.h"

Inheritance diagram for AdjointGenerator:
Collaboration diagram for AdjointGenerator:

Public Member Functions

 AdjointGenerator (DerivativeMode Mode, GradientUtils *gutils, llvm::ArrayRef< DIFFE_TYPE > constant_args, DIFFE_TYPE retType, std::function< unsigned(llvm::Instruction *, CacheType, llvm::IRBuilder<> &)> getIndex, const std::map< llvm::CallInst *, std::pair< bool, const std::vector< bool > > > overwritten_args_map, const AugmentedReturn *augmentedReturn, const std::map< llvm::ReturnInst *, llvm::StoreInst * > *replacedReturns, const llvm::SmallPtrSetImpl< const llvm::Value * > &unnecessaryValues, const llvm::SmallPtrSetImpl< const llvm::Instruction * > &unnecessaryInstructions, const llvm::SmallPtrSetImpl< const llvm::Instruction * > &unnecessaryStores, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &oldUnreachable)
 
void eraseIfUnused (llvm::Instruction &I, bool erase=true, bool check=true)
 
llvm::Value * MPI_TYPE_SIZE (llvm::Value *DT, llvm::IRBuilder<> &B, llvm::Type *intType, llvm::Function *caller)
 
llvm::Value * MPI_COMM_RANK (llvm::Value *comm, llvm::IRBuilder<> &B, llvm::Type *rankTy, llvm::Function *caller)
 
llvm::Value * MPI_COMM_SIZE (llvm::Value *comm, llvm::IRBuilder<> &B, llvm::Type *rankTy, llvm::Function *caller)
 
void visitInstruction (llvm::Instruction &inst)
 
void forwardModeInvertedPointerFallback (llvm::Instruction &I)
 
void visitAllocaInst (llvm::AllocaInst &I)
 
void visitICmpInst (llvm::ICmpInst &I)
 
void visitFCmpInst (llvm::FCmpInst &I)
 
void visitLoadLike (llvm::Instruction &I, llvm::MaybeAlign alignment, bool constantval, llvm::Value *mask=nullptr, llvm::Value *orig_maskInit=nullptr)
 
void visitLoadInst (llvm::LoadInst &LI)
 
void visitAtomicRMWInst (llvm::AtomicRMWInst &I)
 
void visitStoreInst (llvm::StoreInst &SI)
 
void visitCommonStore (llvm::Instruction &I, llvm::Value *orig_ptr, llvm::Value *orig_val, llvm::MaybeAlign prevalign, bool isVolatile, llvm::AtomicOrdering ordering, llvm::SyncScope::ID syncScope, llvm::Value *mask)
 
void visitGetElementPtrInst (llvm::GetElementPtrInst &gep)
 
void visitPHINode (llvm::PHINode &phi)
 
void visitCastInst (llvm::CastInst &I)
 
void visitSelectInst (llvm::SelectInst &SI)
 
void createSelectInstAdjoint (llvm::SelectInst &SI)
 
void visitExtractElementInst (llvm::ExtractElementInst &EEI)
 
void visitInsertElementInst (llvm::InsertElementInst &IEI)
 
void visitShuffleVectorInst (llvm::ShuffleVectorInst &SVI)
 
void visitExtractValueInst (llvm::ExtractValueInst &EVI)
 
void visitInsertValueInst (llvm::InsertValueInst &IVI)
 
void getReverseBuilder (llvm::IRBuilder<> &Builder2, bool original=true)
 
void getForwardBuilder (llvm::IRBuilder<> &Builder2)
 
llvm::Value * diffe (llvm::Value *val, llvm::IRBuilder<> &Builder)
 
void setDiffe (llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &Builder)
 
template<typename Func , typename... Args>
llvm::Value * applyChainRule (llvm::Type *diffType, llvm::IRBuilder<> &Builder, Func rule, Args... args)
 Unwraps a vector derivative from its internal representation and applies a function f to each element.
 
template<typename Func , typename... Args>
void applyChainRule (llvm::IRBuilder<> &Builder, Func rule, Args... args)
 Unwraps a vector derivative from its internal representation and applies a function f to each element.
 
template<typename Func >
void applyChainRule (llvm::ArrayRef< llvm::Value * > diffs, llvm::IRBuilder<> &Builder, Func rule)
 Unwraps a collection of constant vector derivatives from their internal representations and applies a function f to each element.
 
bool shouldFree ()
 
llvm::SmallVector< llvm::SelectInst *, 4 > addToDiffe (llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *mask=nullptr)
 
llvm::Value * lookup (llvm::Value *val, llvm::IRBuilder<> &Builder)
 
void visitBinaryOperator (llvm::BinaryOperator &BO)
 
void createBinaryOperatorAdjoint (llvm::BinaryOperator &BO)
 
void createBinaryOperatorDual (llvm::BinaryOperator &BO)
 
void visitMemSetInst (llvm::MemSetInst &MS)
 
void visitMemSetCommon (llvm::CallInst &MS)
 
void visitMemTransferInst (llvm::MemTransferInst &MTI)
 
void visitMemTransferCommon (llvm::Intrinsic::ID ID, llvm::MaybeAlign srcAlign, llvm::MaybeAlign dstAlign, llvm::CallInst &MTI, llvm::Value *orig_dst, llvm::Value *orig_src, llvm::Value *new_size, llvm::Value *isVolatile)
 
void visitFenceInst (llvm::FenceInst &FI)
 
void visitIntrinsicInst (llvm::IntrinsicInst &II)
 
bool handleAdjointForIntrinsic (llvm::Intrinsic::ID ID, llvm::Instruction &I, llvm::SmallVectorImpl< llvm::Value * > &orig_ops)
 
void visitOMPCall (llvm::CallInst &call)
 
void DifferentiableMemCopyFloats (llvm::CallInst &call, llvm::Value *origArg, llvm::Value *dsto, llvm::Value *srco, llvm::Value *len_arg, llvm::IRBuilder<> &Builder2, llvm::ArrayRef< llvm::OperandBundleDef > ReverseDefs)
 
void recursivelyHandleSubfunction (llvm::CallInst &call, llvm::Function *called, bool subsequent_calls_may_write, const std::vector< bool > &overwritten_args, bool shadowReturnUsed, DIFFE_TYPE subretType, bool subretused)
 
void handleMPI (llvm::CallInst &call, llvm::Function *called, llvm::StringRef funcName)
 
bool handleKnownCallDerivatives (llvm::CallInst &call, llvm::Function *called, llvm::StringRef funcName, bool subsequent_calls_may_write, const std::vector< bool > &overwritten_args, llvm::CallInst *const newCall)
 
void visitCallInst (llvm::CallInst &call)
 

Detailed Description

Definition at line 54 of file AdjointGenerator.h.

Constructor & Destructor Documentation

◆ AdjointGenerator()

AdjointGenerator::AdjointGenerator ( DerivativeMode Mode,
GradientUtils * gutils,
llvm::ArrayRef< DIFFE_TYPE > constant_args,
DIFFE_TYPE retType,
std::function< unsigned(llvm::Instruction *, CacheType, llvm::IRBuilder<> &)> getIndex,
const std::map< llvm::CallInst *, std::pair< bool, const std::vector< bool > > > overwritten_args_map,
const AugmentedReturn * augmentedReturn,
const std::map< llvm::ReturnInst *, llvm::StoreInst * > * replacedReturns,
const llvm::SmallPtrSetImpl< const llvm::Value * > & unnecessaryValues,
const llvm::SmallPtrSetImpl< const llvm::Instruction * > & unnecessaryInstructions,
const llvm::SmallPtrSetImpl< const llvm::Instruction * > & unnecessaryStores,
const llvm::SmallPtrSetImpl< llvm::BasicBlock * > & oldUnreachable )
inline

Member Function Documentation

◆ addToDiffe()

llvm::SmallVector< llvm::SelectInst *, 4 > AdjointGenerator::addToDiffe ( llvm::Value * val,
llvm::Value * dif,
llvm::IRBuilder<> & Builder,
llvm::Type * T,
llvm::Value * mask = nullptr )
inline

◆ applyChainRule() [1/3]

template<typename Func >
void AdjointGenerator::applyChainRule ( llvm::ArrayRef< llvm::Value * > diffs,
llvm::IRBuilder<> & Builder,
Func rule )
inline

Unwraps a collection of constant vector derivatives from their internal representations and applies a function f to each element.

Definition at line 2218 of file AdjointGenerator.h.

◆ applyChainRule() [2/3]

template<typename Func , typename... Args>
void AdjointGenerator::applyChainRule ( llvm::IRBuilder<> & Builder,
Func rule,
Args... args )
inline

Unwraps a vector derivative from its internal representation and applies a function f to each element.

Definition at line 2211 of file AdjointGenerator.h.

◆ applyChainRule() [3/3]

template<typename Func , typename... Args>
llvm::Value * AdjointGenerator::applyChainRule ( llvm::Type * diffType,
llvm::IRBuilder<> & Builder,
Func rule,
Args... args )
inline

Unwraps a vector derivative from its internal representation and applies a function f to each element.

Return values of f are collected and wrapped.

Definition at line 2202 of file AdjointGenerator.h.

Referenced by createBinaryOperatorAdjoint(), createBinaryOperatorDual(), handleAdjointForIntrinsic(), handleKnownCallDerivatives(), visitAtomicRMWInst(), visitBinaryOperator(), visitCastInst(), visitCommonStore(), visitInsertValueInst(), visitLoadLike(), visitMemSetCommon(), and visitMemTransferCommon().

◆ createBinaryOperatorAdjoint()

◆ createBinaryOperatorDual()

◆ createSelectInstAdjoint()

◆ diffe()

◆ DifferentiableMemCopyFloats()

void AdjointGenerator::DifferentiableMemCopyFloats ( llvm::CallInst & call,
llvm::Value * origArg,
llvm::Value * dsto,
llvm::Value * srco,
llvm::Value * len_arg,
llvm::IRBuilder<> & Builder2,
llvm::ArrayRef< llvm::OperandBundleDef > ReverseDefs )
inline

◆ eraseIfUnused()

◆ forwardModeInvertedPointerFallback()

◆ getForwardBuilder()

◆ getReverseBuilder()

◆ handleAdjointForIntrinsic()

◆ handleKnownCallDerivatives()

bool AdjointGenerator::handleKnownCallDerivatives ( llvm::CallInst & call,
llvm::Function * called,
llvm::StringRef funcName,
bool subsequent_calls_may_write,
const std::vector< bool > & overwritten_args,
llvm::CallInst *const newCall )

Definition at line 2212 of file CallDerivatives.cpp.

References GradientUtils::addReverseBlock(), applyChainRule(), GradientUtils::backwardsOnlyShadows, GradientUtils::cacheForReverse(), CreateAllocation(), CreateDealloc(), DUP_ARG, EmitNoDerivativeError(), EnzymeFreeInternalAllocations, EnzymeJuliaAddrLoad, EnzymeShadowAllocRewrite, GradientUtils::erase(), eraseIfUnused(), extractBLAS(), GradientUtils::extractMeta(), ForwardMode, ForwardModeError, ForwardModeSplit, freeKnownAllocation(), getFast(), getForwardBuilder(), GradientUtils::getInvertedBundles(), GradientUtils::getNewFromOriginal(), GradientUtils::getReturnDiffeType(), getReverseBuilder(), getUndefinedValueForType(), getUnqual(), GradientUtils::getWidth(), handleAdjointForIntrinsic(), handleMPI(), hasMetadata(), GradientUtils::invertedPointers, GradientUtils::invertPointerM(), DifferentialUseAnalysis::is_value_needed_in_reverse(), isAllocationFunction(), GradientUtils::isConstantInstruction(), GradientUtils::isConstantValue(), isMemFreeLibMFunction(), GradientUtils::knownRecomputeHeuristic, GradientUtils::legalRecompute(), lookup(), GradientUtils::lookupM(), MPIInactiveCommAllocators, CacheUtility::newFunc, None, GradientUtils::oldFunc, Primal, GradientUtils::rematerializableAllocations, GradientUtils::rematerializedPrimalOrShadowAllocations, GradientUtils::replaceAWithB(), GradientUtils::reverseBlocks, GradientUtils::reverseBlockToPrimal, ReverseModeCombined, ReverseModeGradient, ReverseModePrimal, Self, Shadow, shadowHandlers, shouldFree(), startsWith(), CacheUtility::TLI, GradientUtils::unnecessaryIntermediates, and zeroKnownAllocation().

Referenced by visitCallInst().

◆ handleMPI()

◆ lookup()

◆ MPI_COMM_RANK()

llvm::Value * AdjointGenerator::MPI_COMM_RANK ( llvm::Value * comm,
llvm::IRBuilder<> & B,
llvm::Type * rankTy,
llvm::Function * caller )
inline

◆ MPI_COMM_SIZE()

llvm::Value * AdjointGenerator::MPI_COMM_SIZE ( llvm::Value * comm,
llvm::IRBuilder<> & B,
llvm::Type * rankTy,
llvm::Function * caller )
inline

◆ MPI_TYPE_SIZE()

llvm::Value * AdjointGenerator::MPI_TYPE_SIZE ( llvm::Value * DT,
llvm::IRBuilder<> & B,
llvm::Type * intType,
llvm::Function * caller )
inline

◆ recursivelyHandleSubfunction()

void AdjointGenerator::recursivelyHandleSubfunction ( llvm::CallInst & call,
llvm::Function * called,
bool subsequent_calls_may_write,
const std::vector< bool > & overwritten_args,
bool shadowReturnUsed,
DIFFE_TYPE subretType,
bool subretused )
inline

In non-forward modes, the shadow pointer is only needed if it is used in a non-return setting.

topLevel

Definition at line 4917 of file AdjointGenerator.h.

References TypeResults::addingType(), addToDiffe(), TypeResults::analyzer, TypeResults::anyPointer(), GradientUtils::AtomicAdd, Both, GradientUtils::cacheForReverse(), CONSTANT, AugmentedReturn::constant_args, convertSRetTypeToString(), EnzymeLogic::CreateAugmentedPrimal(), CreateDealloc(), EnzymeLogic::CreateForwardDiff(), EnzymeLogic::CreatePrimalAndGradient(), CustomErrorHandler, diffe(), DifferentialReturn, dumpMap(), DUP_ARG, DUP_NONEED, EmitFailure(), EmitNoDerivativeError(), GradientUtils::erase(), eraseIfUnused(), ErrorIfRuntimeInactive(), GradientUtils::fictiousPHIs, AugmentedReturn::fn, ForwardMode, ForwardModeError, ForwardModeSplit, getBaseObject(), TypeResults::getCallInfo(), getDefaultFunctionTypeForAugmentation(), getDefaultFunctionTypeForGradient(), GradientUtils::getDiffeType(), getFast(), getForwardBuilder(), getFuncNameFromCall(), getFunctionTypeForClone(), getInt8PtrTy(), GradientUtils::getInvertedBundles(), GradientUtils::getNewFromOriginal(), getPointerType(), getReverseBuilder(), GradientUtils::getShadowType(), getUndefinedValueForType(), getUnqual(), GradientUtils::getWidth(), IndexMappingError, insert_or_assign2(), InternalError, TypeAnalyzer::interprocedural, GradientUtils::invertedPointers, GradientUtils::invertPointerM(), DifferentialUseAnalysis::is_value_needed_in_reverse(), GradientUtils::isConstantInstruction(), GradientUtils::isConstantValue(), isMemFreeLibMFunction(), isNoCapture(), GradientUtils::isOriginal(), isReadOnly(), isWriteOnly(), GradientUtils::knownRecomputeHeuristic, legalCombinedForwardReverse(), GradientUtils::Logic, lookup(), GradientUtils::lookupM(), CacheUtility::newFunc, GradientUtils::newToOriginalFn, None, GradientUtils::oldFunc, GradientUtils::originalToNewFn, OUT_DIFF, Pointer, Primal, PrimalParamAttrsToPreserve, TypeResults::query(), GradientUtils::replaceAndRemoveUnwrapCacheFor(), GradientUtils::replaceAWithB(), Return, AugmentedReturn::returns, ReverseModeCombined, ReverseModeGradient, ReverseModePrimal, 
GradientUtils::runtimeActivity, Self, setDiffe(), Shadow, ShadowParamAttrsToPreserve, shouldAugmentCall(), shouldDisableNoWrite(), shouldFree(), str(), GradientUtils::strongZero, AugmentedReturn::subaugmentations, Tape, GradientUtils::TapesToPreventRecomputation, AugmentedReturn::tapeType, to_string(), GradientUtils::unnecessaryIntermediates, and whatType().

Referenced by visitCallInst().

◆ setDiffe()

◆ shouldFree()

bool AdjointGenerator::shouldFree ( )
inline

◆ visitAllocaInst()

void AdjointGenerator::visitAllocaInst ( llvm::AllocaInst & I)
inline

◆ visitAtomicRMWInst()

◆ visitBinaryOperator()

◆ visitCallInst()

◆ visitCastInst()

◆ visitCommonStore()

◆ visitExtractElementInst()

◆ visitExtractValueInst()

◆ visitFCmpInst()

void AdjointGenerator::visitFCmpInst ( llvm::FCmpInst & I)
inline

Definition at line 363 of file AdjointGenerator.h.

References eraseIfUnused().

◆ visitFenceInst()

void AdjointGenerator::visitFenceInst ( llvm::FenceInst & FI)
inline

◆ visitGetElementPtrInst()

void AdjointGenerator::visitGetElementPtrInst ( llvm::GetElementPtrInst & gep)
inline

◆ visitICmpInst()

void AdjointGenerator::visitICmpInst ( llvm::ICmpInst & I)
inline

Definition at line 361 of file AdjointGenerator.h.

References eraseIfUnused().

◆ visitInsertElementInst()

◆ visitInsertValueInst()

◆ visitInstruction()

◆ visitIntrinsicInst()

◆ visitLoadInst()

void AdjointGenerator::visitLoadInst ( llvm::LoadInst & LI)
inline

◆ visitLoadLike()

◆ visitMemSetCommon()

◆ visitMemSetInst()

void AdjointGenerator::visitMemSetInst ( llvm::MemSetInst & MS)
inline

Definition at line 2933 of file AdjointGenerator.h.

References visitMemSetCommon().

◆ visitMemTransferCommon()

◆ visitMemTransferInst()

void AdjointGenerator::visitMemTransferInst ( llvm::MemTransferInst & MTI)
inline

◆ visitOMPCall()

◆ visitPHINode()

void AdjointGenerator::visitPHINode ( llvm::PHINode & phi)
inline

◆ visitSelectInst()

◆ visitShuffleVectorInst()

◆ visitStoreInst()

void AdjointGenerator::visitStoreInst ( llvm::StoreInst & SI)
inline

The documentation for this class was generated from the following files: