Enzyme main
Loading...
Searching...
No Matches
DifferentialUseAnalysis Namespace Reference

Classes

struct  Node
 

Typedefs

using Graph = std::map<Node, std::set<Node>>
 

Functions

bool is_use_directly_needed_in_reverse (const GradientUtils *gutils, const llvm::Value *val, DerivativeMode mode, const llvm::Instruction *user, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &oldUnreachable, QueryType shadow, bool *recursiveUse=nullptr)
 Determine if a value is needed directly to compute the adjoint of the given instruction user.
 
template<QueryType VT, bool OneLevel = false>
bool is_value_needed_in_reverse (const GradientUtils *gutils, const llvm::Value *inst, DerivativeMode mode, std::map< UsageKey, bool > &seen, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &oldUnreachable)
 
template<QueryType VT>
static bool is_value_needed_in_reverse (const GradientUtils *gutils, const llvm::Value *inst, DerivativeMode mode, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &oldUnreachable)
 
void dump (std::map< Node, std::set< Node > > &G)
 
void bfs (const std::map< Node, std::set< Node > > &G, const llvm::SetVector< llvm::Value * > &Recompute, std::map< Node, Node > &parent)
 
int cmpLoopNest (llvm::Loop *prev, llvm::Loop *next)
 
void minCut (const llvm::DataLayout &DL, llvm::LoopInfo &OrigLI, const llvm::SetVector< llvm::Value * > &Recomputes, const llvm::SetVector< llvm::Value * > &Intermediates, llvm::SetVector< llvm::Value * > &Required, llvm::SetVector< llvm::Value * > &MinReq, const GradientUtils *gutils, llvm::TargetLibraryInfo &TLI)
 
 __attribute__ ((always_inline)) static inline void forEachDirectInsertUser(llvm
 
bool callShouldNotUseDerivative (const GradientUtils *gutils, llvm::CallBase &orig, QueryType qtype, const llvm::Value *val)
 Return whether or not this call is a constant and therefore should not use the reverse pass.
 

Typedef Documentation

◆ Graph

using DifferentialUseAnalysis::Graph = std::map<Node, std::set<Node>>

Definition at line 497 of file DifferentialUseAnalysis.h.

Function Documentation

◆ __attribute__()

DifferentialUseAnalysis::__attribute__ ( (always_inline) )

◆ bfs()

void DifferentialUseAnalysis::bfs ( const std::map< Node, std::set< Node > > & G,
const llvm::SetVector< llvm::Value * > & Recompute,
std::map< Node, Node > & parent )

◆ callShouldNotUseDerivative()

bool DifferentialUseAnalysis::callShouldNotUseDerivative ( const GradientUtils * gutils,
llvm::CallBase & orig,
QueryType qtype,
const llvm::Value * val )

Return whether or not this call is a constant and therefore should not use the reverse pass.

Referenced by AdjointGenerator::visitCallInst().

◆ cmpLoopNest()

int DifferentialUseAnalysis::cmpLoopNest ( llvm::Loop * prev,
llvm::Loop * next )

◆ dump()

void DifferentialUseAnalysis::dump ( std::map< Node, std::set< Node > > & G)

◆ is_use_directly_needed_in_reverse()

bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse ( const GradientUtils * gutils,
const llvm::Value * val,
DerivativeMode mode,
const llvm::Instruction * user,
const llvm::SmallPtrSetImpl< llvm::BasicBlock * > & oldUnreachable,
QueryType shadow,
bool * recursiveUse = nullptr )

Determine if a value is needed directly to compute the adjoint of the given instruction user.

shadow denotes whether we are considering the shadow of the value (shadow=true) or the primal of the value (shadow=false). The recursiveUse out-parameter is only meaningful in shadow mode.

Referenced by is_value_needed_in_reverse().

◆ is_value_needed_in_reverse() [1/2]

template<QueryType VT>
static bool DifferentialUseAnalysis::is_value_needed_in_reverse ( const GradientUtils * gutils,
const llvm::Value * inst,
DerivativeMode mode,
const llvm::SmallPtrSetImpl< llvm::BasicBlock * > & oldUnreachable )
inlinestatic

Definition at line 471 of file DifferentialUseAnalysis.h.

References is_value_needed_in_reverse(), Primal, and Shadow.

◆ is_value_needed_in_reverse() [2/2]

◆ minCut()

void DifferentialUseAnalysis::minCut ( const llvm::DataLayout & DL,
llvm::LoopInfo & OrigLI,
const llvm::SetVector< llvm::Value * > & Recomputes,
const llvm::SetVector< llvm::Value * > & Intermediates,
llvm::SetVector< llvm::Value * > & Required,
llvm::SetVector< llvm::Value * > & MinReq,
const GradientUtils * gutils,
llvm::TargetLibraryInfo & TLI )