25#ifndef ENZYME_FUNCTION_UTILS_H
26#define ENZYME_FUNCTION_UTILS_H
31#include <llvm/Config/llvm-config.h>
33#if LLVM_VERSION_MAJOR >= 16
35#include "llvm/Analysis/ScalarEvolution.h"
36#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
39#include "SCEV/ScalarEvolution.h"
40#include "SCEV/ScalarEvolutionExpander.h"
45#include "llvm/Analysis/AliasAnalysis.h"
46#include "llvm/Analysis/LoopAnalysisManager.h"
47#include "llvm/Analysis/TargetLibraryInfo.h"
49#include "llvm/IR/Function.h"
50#include "llvm/IR/Module.h"
51#include "llvm/IR/Type.h"
53#include "llvm/IR/Instructions.h"
54#include "llvm/Transforms/Utils/ValueMapper.h"
56#include "llvm/ADT/STLExtras.h"
87 cache = std::move(prev.cache);
91 llvm::LoopAnalysisManager
LAM;
92 llvm::FunctionAnalysisManager
FAM;
93 llvm::ModuleAnalysisManager
MAM;
95 std::map<std::pair<llvm::Function *, DerivativeMode>, llvm::Function *>
cache;
104 llvm::ValueToValueMapTy &ptrInputs,
105 llvm::ArrayRef<DIFFE_TYPE> constant_args,
106 llvm::SmallPtrSetImpl<llvm::Value *> &constants,
107 llvm::SmallPtrSetImpl<llvm::Value *> &nonconstant,
108 llvm::SmallPtrSetImpl<llvm::Value *> &returnvals,
bool returnTape,
109 bool returnPrimal,
bool returnShadow,
const llvm::Twine &name,
110 llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> *VMapO,
111 bool diffeReturnArg, llvm::Type *additionalArg =
nullptr);
125 llvm::SmallPtrSetImpl<llvm::BasicBlock *> &ExitBlocks) {
126 llvm::SmallVector<llvm::BasicBlock *, 8> PotentialExitBlocks;
127 L->getExitBlocks(PotentialExitBlocks);
128 for (
auto a : PotentialExitBlocks) {
130 llvm::SmallVector<llvm::BasicBlock *, 4> tocheck;
131 llvm::SmallPtrSet<llvm::BasicBlock *, 4> checked;
132 tocheck.push_back(a);
136 while (tocheck.size()) {
137 auto foo = tocheck.back();
139 if (checked.count(foo)) {
144 if (
auto bi = llvm::dyn_cast<llvm::BranchInst>(foo->getTerminator())) {
145 for (
auto nb : bi->successors()) {
148 tocheck.push_back(nb);
150 }
else if (llvm::isa<llvm::UnreachableInst>(foo->getTerminator())) {
160 ExitBlocks.insert(a);
165static inline llvm::SmallVector<llvm::BasicBlock *, 3>
167 const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &ExitBlocks) {
168 llvm::BasicBlock *Preheader = L->getLoopPreheader();
170 llvm::errs() << *L->getHeader()->getParent() <<
"\n";
171 llvm::errs() << *L->getHeader() <<
"\n";
172 llvm::errs() << *L <<
"\n";
174 assert(Preheader &&
"requires preheader");
178 llvm::SmallVector<llvm::BasicBlock *, 3> Latches;
179 for (llvm::BasicBlock *ExitBlock : ExitBlocks) {
180 for (llvm::BasicBlock *pred : llvm::predecessors(ExitBlock)) {
181 if (L->contains(pred)) {
182 if (std::find(Latches.begin(), Latches.end(), pred) != Latches.end())
184 Latches.push_back(pred);
193static inline llvm::SmallPtrSet<llvm::BasicBlock *, 4>
195 llvm::SmallPtrSet<llvm::BasicBlock *, 4> knownUnreachables;
197 return knownUnreachables;
198 std::deque<llvm::BasicBlock *> todo;
199 for (
auto &BB : *F) {
203 while (!todo.empty()) {
204 llvm::BasicBlock *next = todo.front();
207 if (knownUnreachables.find(next) != knownUnreachables.end())
210 if (llvm::isa<llvm::ReturnInst>(next->getTerminator()))
213 if (llvm::isa<llvm::UnreachableInst>(next->getTerminator())) {
214 knownUnreachables.insert(next);
215 for (llvm::BasicBlock *Pred : predecessors(next)) {
216 todo.push_back(Pred);
223 if (llvm::isa<llvm::ResumeInst>(next->getTerminator())) {
224 knownUnreachables.insert(next);
225 for (llvm::BasicBlock *Pred : predecessors(next)) {
226 todo.push_back(Pred);
231 bool unreachable =
true;
232 for (llvm::BasicBlock *Succ : llvm::successors(next)) {
233 if (knownUnreachables.find(Succ) == knownUnreachables.end()) {
241 knownUnreachables.insert(next);
242 for (llvm::BasicBlock *Pred : llvm::predecessors(next)) {
243 todo.push_back(Pred);
248 return knownUnreachables;
258 llvm::SmallPtrSetImpl<const llvm::Value *> &unnecessaryValues,
259 llvm::SmallPtrSetImpl<const llvm::Instruction *> &unnecessaryInstructions,
260 bool returnValue, llvm::function_ref<
bool(
const llvm::Value *)> valneeded,
261 llvm::function_ref<
UseReq(
const llvm::Instruction *)> instneeded,
262 llvm::function_ref<
bool(
const llvm::Instruction *,
const llvm::Value *)>
265 std::deque<const llvm::Instruction *> todo;
267 for (
const llvm::BasicBlock &BB :
oldFunc) {
268 if (
auto ri = llvm::dyn_cast<llvm::ReturnInst>(BB.getTerminator())) {
270 unnecessaryInstructions.insert(ri);
272 unnecessaryValues.insert(ri);
274 for (
auto &inst : BB) {
275 if (&inst == BB.getTerminator())
277 todo.push_back(&inst);
281 while (!todo.empty()) {
282 auto inst = todo.front();
285 if (unnecessaryInstructions.count(inst)) {
286 assert(unnecessaryValues.count(inst));
290 if (!unnecessaryValues.count(inst)) {
292 if (valneeded(inst)) {
296 bool necessaryUse =
false;
298 llvm::SmallPtrSet<const llvm::Instruction *, 4> seen;
299 std::deque<const llvm::Instruction *> users;
301 for (
auto user_dtx : inst->users()) {
302 if (
auto cst = llvm::dyn_cast<llvm::Instruction>(user_dtx)) {
303 if (useneeded(cst, inst))
304 users.push_back(cst);
308 while (users.size()) {
309 auto val = users.front();
316 if (unnecessaryInstructions.count(val))
319 switch (instneeded(val)) {
324 for (
auto user_dtx : val->users()) {
325 if (
auto cst = llvm::dyn_cast<llvm::Instruction>(user_dtx)) {
326 if (useneeded(cst, val))
327 users.push_back(cst);
341 unnecessaryValues.insert(inst);
343 for (
auto user : inst->users()) {
344 if (
auto usedinst = llvm::dyn_cast<llvm::Instruction>(user))
345 todo.push_back(usedinst);
352 unnecessaryInstructions.insert(inst);
354 for (
auto &operand : inst->operands()) {
355 if (
auto usedinst = llvm::dyn_cast<llvm::Instruction>(operand.get())) {
356 todo.push_back(usedinst);
362 llvm::errs() <<
"Prepping values for: " <<
oldFunc.getName()
363 <<
" returnValue: " << returnValue <<
"\n";
364 for (
auto v : unnecessaryInstructions) {
365 llvm::errs() <<
"+ unnecessaryInstructions: " << *v <<
"\n";
367 for (
auto v : unnecessaryValues) {
368 llvm::errs() <<
"+ unnecessaryValues: " << *v <<
"\n";
370 llvm::errs() <<
"</end>\n";
376 llvm::SmallPtrSetImpl<const llvm::Instruction *> &unnecessaryStores,
377 llvm::function_ref<
bool(
const llvm::Instruction *)> needStore) {
379 std::deque<const llvm::Instruction *> todo;
381 for (
const llvm::BasicBlock &BB :
oldFunc) {
382 for (
auto &inst : BB) {
383 if (&inst == BB.getTerminator())
385 todo.push_back(&inst);
389 while (!todo.empty()) {
390 auto inst = todo.front();
393 if (unnecessaryStores.count(inst)) {
400 unnecessaryStores.insert(inst);
414 llvm::Type *additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
415 bool diffeReturnArg,
bool returnTape,
bool returnPrimal,
bool returnShadow);
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val)
Returns whether the use of value val as an argument of call CI is potentially captured.
llvm::cl::opt< bool > EnzymeAlwaysInlineDiff
void RecursivelyReplaceAddressSpace(llvm::Value *AI, llvm::Value *rep, bool legal)
llvm::FunctionType * getFunctionTypeForClone(llvm::FunctionType *FTy, DerivativeMode mode, unsigned width, llvm::Type *additionalArg, llvm::ArrayRef< DIFFE_TYPE > constant_args, bool diffeReturnArg, bool returnTape, bool returnPrimal, bool returnShadow)
void ReplaceFunctionImplementation(llvm::Module &M)
bool DetectReadonlyOrThrow(llvm::Module &M)
static llvm::SmallPtrSet< llvm::BasicBlock *, 4 > getGuaranteedUnreachable(llvm::Function *F)
static void getExitBlocks(const llvm::Loop *L, llvm::SmallPtrSetImpl< llvm::BasicBlock * > &ExitBlocks)
static void calculateUnusedValues(const llvm::Function &oldFunc, llvm::SmallPtrSetImpl< const llvm::Value * > &unnecessaryValues, llvm::SmallPtrSetImpl< const llvm::Instruction * > &unnecessaryInstructions, bool returnValue, llvm::function_ref< bool(const llvm::Value *)> valneeded, llvm::function_ref< UseReq(const llvm::Instruction *)> instneeded, llvm::function_ref< bool(const llvm::Instruction *, const llvm::Value *)> useneeded)
static void calculateUnusedStores(const llvm::Function &oldFunc, llvm::SmallPtrSetImpl< const llvm::Instruction * > &unnecessaryStores, llvm::function_ref< bool(const llvm::Instruction *)> needStore)
bool LowerSparsification(llvm::Function *F, bool replaceAll=true)
Lowers __enzyme_todense, returning whether anything was changed.
static llvm::SmallVector< llvm::BasicBlock *, 3 > getLatches(const llvm::Loop *L, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &ExitBlocks)
static bool endsWith(llvm::StringRef string, llvm::StringRef suffix)
llvm::Function * CloneFunctionWithReturns(DerivativeMode mode, unsigned width, llvm::Function *&F, llvm::ValueToValueMapTy &ptrInputs, llvm::ArrayRef< DIFFE_TYPE > constant_args, llvm::SmallPtrSetImpl< llvm::Value * > &constants, llvm::SmallPtrSetImpl< llvm::Value * > &nonconstant, llvm::SmallPtrSetImpl< llvm::Value * > &returnvals, bool returnTape, bool returnPrimal, bool returnShadow, const llvm::Twine &name, llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > *VMapO, bool diffeReturnArg, llvm::Type *additionalArg=nullptr)
std::map< llvm::Function *, llvm::Function * > CloneOrigin
llvm::Function * preprocessForClone(llvm::Function *F, DerivativeMode mode)
llvm::ModuleAnalysisManager MAM
void optimizeIntermediate(llvm::Function *F)
llvm::LoopAnalysisManager LAM
void AlwaysInline(llvm::Function *NewF)
llvm::AAResults & getAAResultsFromFunction(llvm::Function *NewF)
std::map< std::pair< llvm::Function *, DerivativeMode >, llvm::Function * > cache
llvm::FunctionAnalysisManager FAM
void LowerAllocAddr(llvm::Function *NewF)
PreProcessCache(PreProcessCache &)=delete
PreProcessCache(PreProcessCache &&prev)
void ReplaceReallocs(llvm::Function *NewF, bool mem2reg=false)
Replaces calls to realloc with an appropriate implementation.