Enzyme main
Loading...
Searching...
No Matches
FunctionUtils.h
Go to the documentation of this file.
1//===- FunctionUtils.h - Declaration of function utilities ---------------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file declares utilities on LLVM Functions that are used as part of the
22// AD process.
23//
24//===----------------------------------------------------------------------===//
25#ifndef ENZYME_FUNCTION_UTILS_H
26#define ENZYME_FUNCTION_UTILS_H
27
28#include <deque>
29#include <set>
30
31#include <llvm/Config/llvm-config.h>
32
33#if LLVM_VERSION_MAJOR >= 16
34#define private public
35#include "llvm/Analysis/ScalarEvolution.h"
36#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
37#undef private
38#else
39#include "SCEV/ScalarEvolution.h"
40#include "SCEV/ScalarEvolutionExpander.h"
41#endif
42
43#include "Utils.h"
44
45#include "llvm/Analysis/AliasAnalysis.h"
46#include "llvm/Analysis/LoopAnalysisManager.h"
47#include "llvm/Analysis/TargetLibraryInfo.h"
48
49#include "llvm/IR/Function.h"
50#include "llvm/IR/Module.h"
51#include "llvm/IR/Type.h"
52
53#include "llvm/IR/Instructions.h"
54#include "llvm/Transforms/Utils/ValueMapper.h"
55
56#include "llvm/ADT/STLExtras.h"
57
58//;
59
// Command-line flag defined in an Enzyme .cpp file; per the name, presumably
// forces generated derivative functions to always be inlined — confirm at the
// definition. Exposed with C linkage so it can be toggled from C callers.
extern "C" {
extern llvm::cl::opt<bool> EnzymeAlwaysInlineDiff;
}
63
// Perform an analysis to detect functions which only write to visible memory
// outside the function if an error is not thrown. Such a function can touch
// inaccessible memory [e.g. the insides of malloc/etc], and the only violation
// is whether existing memory before the call is written to.
// If non-local, returning memory written to is a violation (since it writes to
// externally visible memory).
// If local, returning memory written to is fine (since existing memory before
// the call remains unchanged).
// In other words, malloc [local and non-local], calloc [local and non-local],
// copy_array [local only], and friends, are all considered
// readonly_or_throw, as they only either read externally visible state, throw
// an error, or write to inaccessible memory.
76bool DetectReadonlyOrThrow(llvm::Module &M);
77
79public:
82 // Using the default move constructor will botch the FAM/MAM proxy passes
83 // since now the new location of FAM/MAM will not be used. Therefore, use a
84 // custom move constructor and default initialize these, and move the
85 // cache/origin maps.
87 cache = std::move(prev.cache);
88 CloneOrigin = std::move(prev.CloneOrigin);
89 };
90
91 llvm::LoopAnalysisManager LAM;
92 llvm::FunctionAnalysisManager FAM;
93 llvm::ModuleAnalysisManager MAM;
94
95 std::map<std::pair<llvm::Function *, DerivativeMode>, llvm::Function *> cache;
96 std::map<llvm::Function *, llvm::Function *> CloneOrigin;
97
98 llvm::Function *preprocessForClone(llvm::Function *F, DerivativeMode mode);
99
100 llvm::AAResults &getAAResultsFromFunction(llvm::Function *NewF);
101
102 llvm::Function *CloneFunctionWithReturns(
103 DerivativeMode mode, unsigned width, llvm::Function *&F,
104 llvm::ValueToValueMapTy &ptrInputs,
105 llvm::ArrayRef<DIFFE_TYPE> constant_args,
106 llvm::SmallPtrSetImpl<llvm::Value *> &constants,
107 llvm::SmallPtrSetImpl<llvm::Value *> &nonconstant,
108 llvm::SmallPtrSetImpl<llvm::Value *> &returnvals, bool returnTape,
109 bool returnPrimal, bool returnShadow, const llvm::Twine &name,
110 llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> *VMapO,
111 bool diffeReturnArg, llvm::Type *additionalArg = nullptr);
112
113 void ReplaceReallocs(llvm::Function *NewF, bool mem2reg = false);
114 void LowerAllocAddr(llvm::Function *NewF);
115 void AlwaysInline(llvm::Function *NewF);
116 void optimizeIntermediate(llvm::Function *F);
117
118 void clear();
119};
120
121class GradientUtils;
122
123static inline void
124getExitBlocks(const llvm::Loop *L,
125 llvm::SmallPtrSetImpl<llvm::BasicBlock *> &ExitBlocks) {
126 llvm::SmallVector<llvm::BasicBlock *, 8> PotentialExitBlocks;
127 L->getExitBlocks(PotentialExitBlocks);
128 for (auto a : PotentialExitBlocks) {
129
130 llvm::SmallVector<llvm::BasicBlock *, 4> tocheck;
131 llvm::SmallPtrSet<llvm::BasicBlock *, 4> checked;
132 tocheck.push_back(a);
133
134 bool isExit = false;
135
136 while (tocheck.size()) {
137 auto foo = tocheck.back();
138 tocheck.pop_back();
139 if (checked.count(foo)) {
140 isExit = true;
141 goto exitblockcheck;
142 }
143 checked.insert(foo);
144 if (auto bi = llvm::dyn_cast<llvm::BranchInst>(foo->getTerminator())) {
145 for (auto nb : bi->successors()) {
146 if (L->contains(nb))
147 continue;
148 tocheck.push_back(nb);
149 }
150 } else if (llvm::isa<llvm::UnreachableInst>(foo->getTerminator())) {
151 continue;
152 } else {
153 isExit = true;
154 goto exitblockcheck;
155 }
156 }
157
158 exitblockcheck:
159 if (isExit) {
160 ExitBlocks.insert(a);
161 }
162 }
163}
164
165static inline llvm::SmallVector<llvm::BasicBlock *, 3>
166getLatches(const llvm::Loop *L,
167 const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &ExitBlocks) {
168 llvm::BasicBlock *Preheader = L->getLoopPreheader();
169 if (!Preheader) {
170 llvm::errs() << *L->getHeader()->getParent() << "\n";
171 llvm::errs() << *L->getHeader() << "\n";
172 llvm::errs() << *L << "\n";
173 }
174 assert(Preheader && "requires preheader");
175
176 // Find latch, defined as a (perhaps unique) block in loop that branches to
177 // exit block
178 llvm::SmallVector<llvm::BasicBlock *, 3> Latches;
179 for (llvm::BasicBlock *ExitBlock : ExitBlocks) {
180 for (llvm::BasicBlock *pred : llvm::predecessors(ExitBlock)) {
181 if (L->contains(pred)) {
182 if (std::find(Latches.begin(), Latches.end(), pred) != Latches.end())
183 continue;
184 Latches.push_back(pred);
185 }
186 }
187 }
188 return Latches;
189}
190
191// TODO note this doesn't go through [loop, unreachable], and we could get more
192// performance by doing this can consider doing some domtree magic potentially
193static inline llvm::SmallPtrSet<llvm::BasicBlock *, 4>
194getGuaranteedUnreachable(llvm::Function *F) {
195 llvm::SmallPtrSet<llvm::BasicBlock *, 4> knownUnreachables;
196 if (F->empty())
197 return knownUnreachables;
198 std::deque<llvm::BasicBlock *> todo;
199 for (auto &BB : *F) {
200 todo.push_back(&BB);
201 }
202
203 while (!todo.empty()) {
204 llvm::BasicBlock *next = todo.front();
205 todo.pop_front();
206
207 if (knownUnreachables.find(next) != knownUnreachables.end())
208 continue;
209
210 if (llvm::isa<llvm::ReturnInst>(next->getTerminator()))
211 continue;
212
213 if (llvm::isa<llvm::UnreachableInst>(next->getTerminator())) {
214 knownUnreachables.insert(next);
215 for (llvm::BasicBlock *Pred : predecessors(next)) {
216 todo.push_back(Pred);
217 }
218 continue;
219 }
220
221 // Assume resumes don't happen
222 // TODO consider EH
223 if (llvm::isa<llvm::ResumeInst>(next->getTerminator())) {
224 knownUnreachables.insert(next);
225 for (llvm::BasicBlock *Pred : predecessors(next)) {
226 todo.push_back(Pred);
227 }
228 continue;
229 }
230
231 bool unreachable = true;
232 for (llvm::BasicBlock *Succ : llvm::successors(next)) {
233 if (knownUnreachables.find(Succ) == knownUnreachables.end()) {
234 unreachable = false;
235 break;
236 }
237 }
238
239 if (!unreachable)
240 continue;
241 knownUnreachables.insert(next);
242 for (llvm::BasicBlock *Pred : llvm::predecessors(next)) {
243 todo.push_back(Pred);
244 }
245 continue;
246 }
247
248 return knownUnreachables;
249}
250
// How a user of a value affects whether that value must be kept; returned by
// the instneeded callback of calculateUnusedValues in this file.
enum class UseReq {
  Need,   // this user forces the value to be kept
  Recur,  // undecided here: the answer depends on this user's own users
  Cached, // does not force the value and does not propagate to its users
          // (per the name, presumably satisfiable from a cache — confirm at
          // the call sites supplying the callback)
};
// Compute which values and instructions of oldFunc are unnecessary, driven by
// three caller-supplied predicates:
//   valneeded(V)    - V must be kept as a value regardless of its users.
//   instneeded(I)   - whether instruction I must execute (Need), defers the
//                     decision to I's own users (Recur), or neither forces
//                     the value nor propagates (Cached); see UseReq above.
//   useneeded(U, V) - whether user U's use of V counts at all.
// Results are accumulated into unnecessaryValues (values whose result is not
// needed) and unnecessaryInstructions (instructions that need not execute).
// If returnValue is false, return instructions are also marked unnecessary as
// instructions; their values are always marked unnecessary.
// Runs a worklist to a fixed point: marking an instruction re-enqueues its
// users (value phase) and its operands (instruction phase).
static inline void calculateUnusedValues(
    const llvm::Function &oldFunc,
    llvm::SmallPtrSetImpl<const llvm::Value *> &unnecessaryValues,
    llvm::SmallPtrSetImpl<const llvm::Instruction *> &unnecessaryInstructions,
    bool returnValue, llvm::function_ref<bool(const llvm::Value *)> valneeded,
    llvm::function_ref<UseReq(const llvm::Instruction *)> instneeded,
    llvm::function_ref<bool(const llvm::Instruction *, const llvm::Value *)>
        useneeded) {

  std::deque<const llvm::Instruction *> todo;

  // Seed: returns are handled specially; every other non-terminator
  // instruction goes on the worklist.
  for (const llvm::BasicBlock &BB : oldFunc) {
    if (auto ri = llvm::dyn_cast<llvm::ReturnInst>(BB.getTerminator())) {
      if (!returnValue) {
        unnecessaryInstructions.insert(ri);
      }
      // A return's value is never needed by anything inside the function.
      unnecessaryValues.insert(ri);
    }
    for (auto &inst : BB) {
      if (&inst == BB.getTerminator())
        continue;
      todo.push_back(&inst);
    }
  }

  while (!todo.empty()) {
    auto inst = todo.front();
    todo.pop_front();

    // Fully processed already: an unnecessary instruction implies an
    // unnecessary value (established below).
    if (unnecessaryInstructions.count(inst)) {
      assert(unnecessaryValues.count(inst));
      continue;
    }

    // Phase 1: decide whether inst's VALUE is unnecessary.
    if (!unnecessaryValues.count(inst)) {

      if (valneeded(inst)) {
        continue;
      }

      bool necessaryUse = false;

      llvm::SmallPtrSet<const llvm::Instruction *, 4> seen;
      std::deque<const llvm::Instruction *> users;

      // Start from the direct users whose use actually counts.
      for (auto user_dtx : inst->users()) {
        if (auto cst = llvm::dyn_cast<llvm::Instruction>(user_dtx)) {
          if (useneeded(cst, inst))
            users.push_back(cst);
        }
      }

      // Transitively search users (through Recur edges) for one that Needs
      // to execute; users already known unnecessary are skipped.
      while (users.size()) {
        auto val = users.front();
        users.pop_front();

        if (seen.count(val))
          continue;
        seen.insert(val);

        if (unnecessaryInstructions.count(val))
          continue;

        switch (instneeded(val)) {
        case UseReq::Need:
          necessaryUse = true;
          break;
        case UseReq::Recur:
          // Decision deferred: examine this user's own (counting) users.
          for (auto user_dtx : val->users()) {
            if (auto cst = llvm::dyn_cast<llvm::Instruction>(user_dtx)) {
              if (useneeded(cst, val))
                users.push_back(cst);
            }
          }
          break;
        case UseReq::Cached:
          break;
        }
        if (necessaryUse)
          break;
      }

      if (necessaryUse)
        continue;

      unnecessaryValues.insert(inst);

      // The value became unnecessary; users may now become unnecessary too.
      for (auto user : inst->users()) {
        if (auto usedinst = llvm::dyn_cast<llvm::Instruction>(user))
          todo.push_back(usedinst);
      }
    }

    // Phase 2 (only reached once the value is unnecessary): decide whether
    // the INSTRUCTION itself still has to execute.
    if (instneeded(inst) == UseReq::Need)
      continue;

    unnecessaryInstructions.insert(inst);

    // Operands lost a user that executed; revisit them.
    for (auto &operand : inst->operands()) {
      if (auto usedinst = llvm::dyn_cast<llvm::Instruction>(operand.get())) {
        todo.push_back(usedinst);
      }
    }
  }

  // Intentionally disabled debug dump (flip `false` to trace a function whose
  // name ends in "subfn").
  if (false && endsWith(oldFunc.getName(), "subfn")) {
    llvm::errs() << "Prepping values for: " << oldFunc.getName()
                 << " returnValue: " << returnValue << "\n";
    for (auto v : unnecessaryInstructions) {
      llvm::errs() << "+ unnecessaryInstructions: " << *v << "\n";
    }
    for (auto v : unnecessaryValues) {
      llvm::errs() << "+ unnecessaryValues: " << *v << "\n";
    }
    llvm::errs() << "</end>\n";
  }
}
373
374static inline void calculateUnusedStores(
375 const llvm::Function &oldFunc,
376 llvm::SmallPtrSetImpl<const llvm::Instruction *> &unnecessaryStores,
377 llvm::function_ref<bool(const llvm::Instruction *)> needStore) {
378
379 std::deque<const llvm::Instruction *> todo;
380
381 for (const llvm::BasicBlock &BB : oldFunc) {
382 for (auto &inst : BB) {
383 if (&inst == BB.getTerminator())
384 continue;
385 todo.push_back(&inst);
386 }
387 }
388
389 while (!todo.empty()) {
390 auto inst = todo.front();
391 todo.pop_front();
392
393 if (unnecessaryStores.count(inst)) {
394 continue;
395 }
396
397 if (needStore(inst))
398 continue;
399
400 unnecessaryStores.insert(inst);
401 }
402}
403
// Replace uses of AI with rep across address spaces.
// NOTE(review): semantics (including the role of `legal`) inferred from the
// name only — confirm at the definition.
void RecursivelyReplaceAddressSpace(llvm::Value *AI, llvm::Value *rep,
                                    bool legal);

// Rewrite function implementations within M.
// NOTE(review): which functions are replaced, and with what, is not visible
// here — confirm at the definition.
void ReplaceFunctionImplementation(llvm::Module &M);

/// Is the use of value val as an argument of call CI potentially captured
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val);

// Build the function type of a clone produced by CloneFunctionWithReturns,
// from the original type FTy and the same mode/width/return-shape flags.
llvm::FunctionType *getFunctionTypeForClone(
    llvm::FunctionType *FTy, DerivativeMode mode, unsigned width,
    llvm::Type *additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
    bool diffeReturnArg, bool returnTape, bool returnPrimal, bool returnShadow);

/// Lower __enzyme_todense, returning if changed.
bool LowerSparsification(llvm::Function *F, bool replaceAll = true);
420#endif
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val)
Is the use of value val as an argument of call CI potentially captured.
UseReq
llvm::cl::opt< bool > EnzymeAlwaysInlineDiff
void RecursivelyReplaceAddressSpace(llvm::Value *AI, llvm::Value *rep, bool legal)
llvm::FunctionType * getFunctionTypeForClone(llvm::FunctionType *FTy, DerivativeMode mode, unsigned width, llvm::Type *additionalArg, llvm::ArrayRef< DIFFE_TYPE > constant_args, bool diffeReturnArg, bool returnTape, bool returnPrimal, bool returnShadow)
void ReplaceFunctionImplementation(llvm::Module &M)
bool DetectReadonlyOrThrow(llvm::Module &M)
static llvm::SmallPtrSet< llvm::BasicBlock *, 4 > getGuaranteedUnreachable(llvm::Function *F)
static void getExitBlocks(const llvm::Loop *L, llvm::SmallPtrSetImpl< llvm::BasicBlock * > &ExitBlocks)
static void calculateUnusedValues(const llvm::Function &oldFunc, llvm::SmallPtrSetImpl< const llvm::Value * > &unnecessaryValues, llvm::SmallPtrSetImpl< const llvm::Instruction * > &unnecessaryInstructions, bool returnValue, llvm::function_ref< bool(const llvm::Value *)> valneeded, llvm::function_ref< UseReq(const llvm::Instruction *)> instneeded, llvm::function_ref< bool(const llvm::Instruction *, const llvm::Value *)> useneeded)
static void calculateUnusedStores(const llvm::Function &oldFunc, llvm::SmallPtrSetImpl< const llvm::Instruction * > &unnecessaryStores, llvm::function_ref< bool(const llvm::Instruction *)> needStore)
bool LowerSparsification(llvm::Function *F, bool replaceAll=true)
Lower __enzyme_todense, returning if changed.
static llvm::SmallVector< llvm::BasicBlock *, 3 > getLatches(const llvm::Loop *L, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &ExitBlocks)
static bool endsWith(llvm::StringRef string, llvm::StringRef suffix)
Definition Utils.h:721
DerivativeMode
Definition Utils.h:390
DerivativeMode mode
llvm::Function * oldFunc
llvm::Function * CloneFunctionWithReturns(DerivativeMode mode, unsigned width, llvm::Function *&F, llvm::ValueToValueMapTy &ptrInputs, llvm::ArrayRef< DIFFE_TYPE > constant_args, llvm::SmallPtrSetImpl< llvm::Value * > &constants, llvm::SmallPtrSetImpl< llvm::Value * > &nonconstant, llvm::SmallPtrSetImpl< llvm::Value * > &returnvals, bool returnTape, bool returnPrimal, bool returnShadow, const llvm::Twine &name, llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > *VMapO, bool diffeReturnArg, llvm::Type *additionalArg=nullptr)
std::map< llvm::Function *, llvm::Function * > CloneOrigin
llvm::Function * preprocessForClone(llvm::Function *F, DerivativeMode mode)
llvm::ModuleAnalysisManager MAM
void optimizeIntermediate(llvm::Function *F)
llvm::LoopAnalysisManager LAM
void AlwaysInline(llvm::Function *NewF)
llvm::AAResults & getAAResultsFromFunction(llvm::Function *NewF)
std::map< std::pair< llvm::Function *, DerivativeMode >, llvm::Function * > cache
llvm::FunctionAnalysisManager FAM
void LowerAllocAddr(llvm::Function *NewF)
PreProcessCache(PreProcessCache &)=delete
PreProcessCache(PreProcessCache &&prev)
void ReplaceReallocs(llvm::Function *NewF, bool mem2reg=false)
Calls to realloc with an appropriate implementation.