27#ifndef ENZYME_CACHE_UTILITY_H
28#define ENZYME_CACHE_UTILITY_H
30#include <llvm/Config/llvm-config.h>
31#if LLVM_VERSION_MAJOR >= 16
33#include "llvm/Analysis/ScalarEvolution.h"
34#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
37#include "SCEV/ScalarEvolution.h"
38#include "SCEV/ScalarEvolutionExpander.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/SmallPtrSet.h"
43#include "llvm/Analysis/LoopInfo.h"
44#include "llvm/IR/Instructions.h"
46#include "llvm/Analysis/AssumptionCache.h"
47#include "llvm/Analysis/TargetLibraryInfo.h"
48#include "llvm/IR/Dominators.h"
49#include "llvm/Support/CommandLine.h"
50#include "llvm/Transforms/Utils/ValueMapper.h"
65 llvm::AssertingVH<llvm::PHINode>
var;
68 llvm::AssertingVH<llvm::Instruction>
incvar;
122static inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
126 os <<
"LegalFullUnwrap";
129 os <<
"LegalFullUnwrapNoTapeReplace";
132 os <<
"AttemptFullUnwrapWithLookup";
135 os <<
"AttemptFullUnwrap";
138 os <<
"AttemptSingleUnwrap";
150 llvm::TargetLibraryInfo &
TLI;
151 llvm::DominatorTree
DT;
155 llvm::AssumptionCache
AC;
168 "allocsForInversion",
newFunc);
189 if (context.second.var == &I || context.second.incvar == &I ||
190 context.second.maxLimit == &I || context.second.trueLimit == &I) {
198 bool ReverseLimit =
true);
202 llvm::errs() <<
"scope:\n";
204 llvm::errs() <<
" scopeMap[" << *a.first <<
"] = " << *a.second.first
205 <<
" ctx:" << a.second.second.Block->getName() <<
"\n";
207 llvm::errs() <<
"end scope\n";
211 if ((bsize & (bsize - 1)) == 0) {
216 }
else if (bsize > 0 && bsize % 8 == 0) {
218 }
else if (bsize > 0 && bsize % 4 == 0) {
220 }
else if (bsize > 0 && bsize % 2 == 0) {
228 virtual void erase(llvm::Instruction *I);
232 bool storeInCache =
false);
259 typedef llvm::SmallVector<std::pair<
262 std::pair<LoopContext, llvm::Value *>, 4>>,
273 llvm::ValueMap<llvm::Value *,
274 std::map<llvm::BasicBlock *, llvm::WeakTrackingVH>>
281 std::map<std::tuple<llvm::Value *, llvm::Value *, llvm::BasicBlock *>,
287 llvm::Value *computeIndexOfChunk(
288 bool inForwardPass, llvm::IRBuilder<> &v,
289 llvm::ArrayRef<std::pair<LoopContext, llvm::Value *>> containedloops,
290 const llvm::ValueToValueMapTy &available);
298 std::map<std::pair<llvm::Value *, int>, llvm::MDNode *>
299 CachePointerInvariantGroups;
302 std::map<llvm::Value *, llvm::MDNode *> ValueInvariantGroups;
306 std::map<llvm::Value *,
307 std::pair<llvm::AssertingVH<llvm::AllocaInst>,
LimitContext>>
314 std::map<llvm::AllocaInst *,
315 llvm::SmallVector<llvm::AssertingVH<llvm::Instruction>, 4>>
320 std::map<llvm::AllocaInst *, std::set<llvm::AssertingVH<llvm::CallInst>>>
325 std::map<llvm::AllocaInst *,
326 llvm::SmallVector<llvm::AssertingVH<llvm::CallInst>, 4>>
332 llvm::Value *cptr, llvm::Value *cache);
339 llvm::StringRef name,
bool shouldFree,
340 bool allocateInternal =
true,
341 llvm::Value *extraSize =
nullptr);
350 virtual llvm::Value *
351 unwrapM(llvm::Value *
const val, llvm::IRBuilder<> &BuilderM,
352 const llvm::ValueToValueMapTy &available,
UnwrapMode mode,
353 llvm::BasicBlock *scope =
nullptr,
bool permitCache =
true) = 0;
363 virtual llvm::Value *
364 lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM,
365 const llvm::ValueToValueMapTy &incoming_availalble =
366 llvm::ValueToValueMapTy(),
367 bool tryLegalityCheck =
true, llvm::BasicBlock *scope =
nullptr) = 0;
375 virtual llvm::CallInst *
freeCache(llvm::BasicBlock *forwardPreheader,
377 llvm::AllocaInst *alloc, llvm::Type *myType,
378 llvm::ConstantInt *byteSizeOfType,
379 llvm::Value *storeInto,
380 llvm::MDNode *InvariantMD) {
381 assert(0 &&
"freeing cache not handled in this scenario");
382 llvm_unreachable(
"freeing cache not handled in this scenario");
388 llvm::Value *val, llvm::AllocaInst *cache,
389 llvm::MDNode *TBAA =
nullptr);
394 llvm::AllocaInst *cache,
395 llvm::MDNode *TBAA =
nullptr);
403 llvm::Value *cache,
bool storeInInstructionsMap,
404 const llvm::ValueToValueMapTy &available,
405 llvm::Value *extraSize);
410 llvm::IRBuilder<> &BuilderM,
413 const llvm::ValueToValueMapTy &available,
414 llvm::Value *extraSize =
nullptr,
415 llvm::Value *extraOffset =
nullptr);
424std::pair<llvm::PHINode *, llvm::Instruction *>
426 const llvm::Twine &Name =
"iv");
431 llvm::BasicBlock *Header, llvm::PHINode *CanonicalIV,
433 llvm::function_ref<
void(llvm::Instruction *, llvm::Value *)> replacer,
434 llvm::function_ref<
void(llvm::Instruction *)> eraser);
std::pair< llvm::PHINode *, llvm::Instruction * > InsertNewCanonicalIV(llvm::Loop *L, llvm::Type *Ty, const llvm::Twine &Name="iv")
llvm::cl::opt< bool > EnzymeZeroCache
llvm::cl::opt< bool > EfficientBoolCache
Pack 8 bools together in a single byte.
UnwrapMode
Modes of potential unwraps.
@ LegalFullUnwrapNoTapeReplace
@ AttemptFullUnwrapWithLookup
void RemoveRedundantIVs(llvm::BasicBlock *Header, llvm::PHINode *CanonicalIV, llvm::Instruction *Increment, MustExitScalarEvolution &SE, llvm::function_ref< void(llvm::Instruction *, llvm::Value *)> replacer, llvm::function_ref< void(llvm::Instruction *)> eraser)
static llvm::raw_ostream & operator<<(llvm::raw_ostream &os, UnwrapMode mode)
static bool operator==(const LoopContext &lhs, const LoopContext &rhs)
llvm::Function *const newFunc
The function whose instructions we are caching.
MustExitScalarEvolution SE
virtual llvm::Value * unwrapM(llvm::Value *const val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &available, UnwrapMode mode, llvm::BasicBlock *scope=nullptr, bool permitCache=true)=0
High-level utility to "unwrap" an instruction at a new location specified by BuilderM.
virtual llvm::CallInst * freeCache(llvm::BasicBlock *forwardPreheader, const SubLimitType &antimap, int i, llvm::AllocaInst *alloc, llvm::Type *myType, llvm::ConstantInt *byteSizeOfType, llvm::Value *storeInto, llvm::MDNode *InvariantMD)
If an allocation is requested to be freed, this subclass will be called to choose how and where to fre...
std::map< llvm::AllocaInst *, llvm::SmallVector< llvm::AssertingVH< llvm::CallInst >, 4 > > scopeAllocs
A map of allocations to a set of instructions which allocate memory as part of the cache.
llvm::AllocaInst * createCacheForScope(LimitContext ctx, llvm::Type *T, llvm::StringRef name, bool shouldFree, bool allocateInternal=true, llvm::Value *extraSize=nullptr)
Create a cache of Type T at the given LimitContext.
virtual void erase(llvm::Instruction *I)
Erase this instruction both from LLVM modules and any local data-structures.
CacheUtility(llvm::TargetLibraryInfo &TLI, llvm::Function *newFunc)
llvm::Value * getCachePointer(llvm::Type *T, bool inForwardPass, llvm::IRBuilder<> &BuilderM, LimitContext ctx, llvm::Value *cache, bool storeInInstructionsMap, const llvm::ValueToValueMapTy &available, llvm::Value *extraSize)
Given an allocation specified by the LimitContext ctx and cache, compute a pointer that can hold the ...
llvm::Value * lookupValueFromCache(llvm::Type *T, bool inForwardPass, llvm::IRBuilder<> &BuilderM, LimitContext ctx, llvm::Value *cache, bool isi1, const llvm::ValueToValueMapTy &available, llvm::Value *extraSize=nullptr, llvm::Value *extraOffset=nullptr)
Given an allocation specified by the LimitContext ctx and cache, lookup the underlying cached value.
bool getContext(llvm::BasicBlock *BB, LoopContext &loopContext, bool ReverseLimit)
Given a BasicBlock BB in newFunc, set loopContext to the relevant contained loop and return true.
llvm::Value * loadFromCachePointer(llvm::Type *T, llvm::IRBuilder<> &BuilderM, llvm::Value *cptr, llvm::Value *cache)
Perform the final load from the cache, applying requisite invariant group and alignment.
std::map< llvm::AllocaInst *, std::set< llvm::AssertingVH< llvm::CallInst > > > scopeFrees
A map of allocations to a set of instructions which free memory as part of the cache.
void dumpScope()
Print out all currently cached values.
llvm::TargetLibraryInfo & TLI
Various analysis results of newFunc.
std::map< llvm::Loop *, LoopContext > loopContexts
Map of Loop to requisite loop information needed for AD (forward/reverse induction/etc)
std::map< llvm::AllocaInst *, llvm::SmallVector< llvm::AssertingVH< llvm::Instruction >, 4 > > scopeInstructions
A map of allocations to a vector of instructions used to create the allocation. Keeping track of the...
llvm::BasicBlock * inversionAllocs
std::map< llvm::Value *, std::pair< llvm::AssertingVH< llvm::AllocaInst >, LimitContext > > scopeMap
A map of values being cached to their underlying allocation/limit context.
llvm::SmallVector< std::pair< llvm::Value *, llvm::SmallVector< std::pair< LoopContext, llvm::Value * >, 4 > >, 0 > SubLimitType
Given a LimitContext ctx, representing a location inside a loop nest, break each of the loops up into...
virtual void replaceAWithB(llvm::Value *A, llvm::Value *B, bool storeInCache=false)
Replace this instruction both in LLVM modules and any local data-structures.
unsigned getCacheAlignment(unsigned bsize) const
llvm::AllocaInst * getDynamicLoopLimit(llvm::Loop *L, bool ReverseLimit=true)
llvm::SmallPtrSet< llvm::LoadInst *, 10 > CacheLookups
SubLimitType getSubLimits(bool inForwardPass, llvm::IRBuilder<> *RB, LimitContext ctx, llvm::Value *extraSize=nullptr)
Given a LimitContext ctx, representing a location inside a loop nest, break each of the loops up into...
void storeInstructionInCache(LimitContext ctx, llvm::IRBuilder<> &BuilderM, llvm::Value *val, llvm::AllocaInst *cache, llvm::MDNode *TBAA=nullptr)
Given an allocation defined at a particular ctx, store the value val in the cache at the location def...
bool isInstructionUsedInLoopInduction(llvm::Instruction &I)
Return whether the given instruction is used as a necessary part of a loop context. This includes as ...
virtual llvm::Value * lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &incoming_availalble=llvm::ValueToValueMapTy(), bool tryLegalityCheck=true, llvm::BasicBlock *scope=nullptr)=0
High-level utility to get the value of an instruction at a new location specified by BuilderM.
virtual bool assumeDynamicLoopOfSizeOne(llvm::Loop *L) const =0
LimitContext(bool ReverseLimit, llvm::BasicBlock *Block, bool ForceSingleIteration=false)
bool ForceSingleIteration
Container for all loop information to synthesize gradients.
llvm::Loop * parent
Parent loop of this loop.
llvm::BasicBlock * header
Header of this loop.
llvm::AssertingVH< llvm::Instruction > incvar
Increment of the induction.
bool dynamic
Whether this loop has a statically analyzable number of iterations.
llvm::SmallPtrSet< llvm::BasicBlock *, 8 > exitBlocks
All blocks this loop exits to.
llvm::AssertingVH< llvm::AllocaInst > antivaralloc
Allocation of induction variable of reverse pass.
llvm::AssertingVH< llvm::PHINode > var
Canonical induction variable of the loop.
AssertingReplacingVH offset
An offset to add to the index when getting the cache pointer.
llvm::BasicBlock * preheader
Preheader of this loop.
AssertingReplacingVH maxLimit
limit is the last value of a canonical induction variable; iters is the number of times the loop is run (thus iter...
AssertingReplacingVH allocLimit
An overriding allocation limit size.
AssertingReplacingVH trueLimit