Enzyme main
Loading...
Searching...
No Matches
SimpleGVN.cpp
Go to the documentation of this file.
1//=- SimpleGVN.cpp - GVN-like load forwarding optimization ============//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains a GVN-like optimization pass that forwards loads from
22// noalias/nocapture arguments to their corresponding stores, with support
23// for offsets and type conversions.
24//
25// This pass addresses the limitation of LLVM's built-in GVN pass which has
26// a small limit on the number of instructions/memory offsets it analyzes
27// via its use of the memdep analysis.
28//
29// Algorithm:
30// 1. Identify candidate pointers: noalias pointer arguments and allocas
31// 2. Verify all uses are loads, stores to the pointer, constant-offset GEPs, casts, or calls that do not capture the pointer
32// 3. For each load from such an argument:
33// a. Find all stores to the argument with constant offsets
34// b. Find a dominating store that covers the load's memory range
35// c. Check that no aliasing store exists between the store and load
36// d. If safe, replace the load with the stored value, performing
37// type conversion or extraction as needed
38//
39// Example transformation:
40// define i32 @foo(i32* noalias nocapture %ptr) {
41// store i32 42, i32* %ptr
42// %v = load i32, i32* %ptr
43// ret i32 %v
44// }
45// becomes:
46// define i32 @foo(i32* noalias nocapture %ptr) {
47// store i32 42, i32* %ptr
48// ret i32 42
49// }
50//
51//===----------------------------------------------------------------------===//
52#include <llvm/Config/llvm-config.h>
53
54#include "llvm/ADT/APInt.h"
55#include "llvm/ADT/DenseMap.h"
56#include "llvm/ADT/MapVector.h"
57#include "llvm/ADT/SmallPtrSet.h"
58#include "llvm/ADT/SmallVector.h"
59
60#include "llvm/IR/BasicBlock.h"
61#include "llvm/IR/Constants.h"
62#include "llvm/IR/DataLayout.h"
63#include "llvm/IR/Dominators.h"
64
65#include "llvm/IR/CFG.h"
66#include "llvm/IR/Function.h"
67#include "llvm/IR/GetElementPtrTypeIterator.h"
68#include "llvm/IR/IRBuilder.h"
69#include "llvm/IR/InstrTypes.h"
70#include "llvm/IR/Instruction.h"
71#include "llvm/IR/Instructions.h"
72#include "llvm/IR/Value.h"
73
74#include "llvm/Analysis/AliasAnalysis.h"
75#include "llvm/Analysis/MemoryLocation.h"
76#include "llvm/Analysis/TargetLibraryInfo.h"
77#include "llvm/Analysis/ValueTracking.h"
78
79#include "llvm/IR/LegacyPassManager.h"
80
81#include "llvm/Support/Debug.h"
82#include "llvm/Support/raw_ostream.h"
83
84#include "llvm/Transforms/Utils/Local.h"
85
86#include "SimpleGVN.h"
87#include "Utils.h"
88
89using namespace llvm;
90
91#ifdef DEBUG_TYPE
92#undef DEBUG_TYPE
93#endif
94#define DEBUG_TYPE "simple-gvn"
95
96namespace {
97
98// Extract a value with potential type conversion
99Value *extractValue(IRBuilder<> &Builder, Value *StoredVal, Type *LoadType,
100 const DataLayout &DL, APInt LoadOffset, APInt StoreOffset,
101 uint64_t LoadSize) {
102 Type *StoreType = StoredVal->getType();
103 uint64_t StoreSize = DL.getTypeStoreSize(StoreType);
104
105 // Calculate relative offset
106 int64_t RelativeOffset = (LoadOffset - StoreOffset).getSExtValue();
107
108 // Check if the load is completely within the stored value
109 if (RelativeOffset < 0 || (uint64_t)RelativeOffset + LoadSize > StoreSize) {
110 return nullptr;
111 }
112
113 // If types match and offsets are the same, return directly
114 if (RelativeOffset == 0 && LoadType == StoreType) {
115 return StoredVal;
116 }
117
118 if (RelativeOffset == 0 && isa<PointerType>(LoadType) &&
119 isa<PointerType>(StoreType)) {
120 return Builder.CreatePointerCast(StoredVal, LoadType);
121 }
122
123 if (RelativeOffset == 0 && StoreSize >= LoadSize &&
124 StoreType->isAggregateType()) {
125 auto first = Builder.CreateExtractValue(StoredVal, 0);
126 auto res = extractValue(Builder, first, LoadType, DL, LoadOffset,
127 StoreOffset, LoadSize);
128 if (res) {
129 return res;
130 } else {
131 if (auto I = dyn_cast<Instruction>(first))
132 I->eraseFromParent();
133 }
134 }
135
136 // Handle extraction with offset or type mismatch
137 // First, bitcast to an integer type if needed
138 if (!StoreType->isIntegerTy()) {
139 IntegerType *IntTy = Builder.getIntNTy(StoreSize * 8);
140 if (!CastInst::castIsValid(Instruction::BitCast, StoredVal->getType(),
141 IntTy)) {
142 return nullptr;
143 }
144 StoredVal = Builder.CreateBitCast(StoredVal, IntTy);
145 }
146
147 // Extract the relevant bits if there's an offset
148 if (RelativeOffset > 0) {
149 uint64_t ShiftBits = RelativeOffset * 8;
150 StoredVal = Builder.CreateLShr(StoredVal, ShiftBits);
151 }
152
153 // Truncate to the load size if needed
154 IntegerType *LoadIntTy = Builder.getIntNTy(LoadSize * 8);
155 if (StoredVal->getType() != LoadIntTy) {
156 StoredVal = Builder.CreateTrunc(StoredVal, LoadIntTy);
157 }
158
159 // Bitcast to the final type if needed
160 if (LoadIntTy != LoadType) {
161 if (LoadType->isPointerTy()) {
162 StoredVal = Builder.CreateIntToPtr(StoredVal, LoadType);
163 } else {
164 if (!CastInst::castIsValid(Instruction::BitCast, StoredVal->getType(),
165 LoadType)) {
166 return nullptr;
167 }
168 StoredVal = Builder.CreateBitCast(StoredVal, LoadType);
169 }
170 }
171
172 return StoredVal;
173}
174
175// Helper to check if a source instruction dominates and completely covers a
176// target instruction's memory access
177// For stores: checks if store covers a load
178// For loads: checks if load covers another load
179static bool dominatesAndCovers(Instruction *Source, Instruction *Target,
180 const APInt &SourceOffset,
181 const APInt &TargetOffset, uint64_t TargetSize,
182 const DataLayout &DL, DominatorTree &DT) {
183 if (!DT.dominates(Source, Target))
184 return false;
185
186 // Get the size of the source memory access
187 uint64_t SourceSize;
188 if (auto *SI = dyn_cast<StoreInst>(Source)) {
189 SourceSize = DL.getTypeStoreSize(SI->getValueOperand()->getType());
190 } else if (auto *LI = dyn_cast<LoadInst>(Source)) {
191 SourceSize = DL.getTypeStoreSize(LI->getType());
192 } else {
193 return false;
194 }
195
196 int64_t RelOffset = (TargetOffset - SourceOffset).getSExtValue();
197 return RelOffset >= 0 && (uint64_t)RelOffset + TargetSize <= SourceSize;
198}
199
200// Helper to check if two memory ranges alias
201// Range1: [Offset1, Offset1 + Size1)
202// Range2: [Offset2, Offset2 + Size2)
203static bool memoryRangesAlias(const APInt &Offset1, uint64_t Size1,
204 const APInt &Offset2, uint64_t Size2) {
205 // Check if range2 ends before range1 begins
206 if ((Offset2 + Size2).sle(Offset1))
207 return false;
208
209 // Check if range1 ends before range2 begins
210 if ((Offset1 + Size1).sle(Offset2))
211 return false;
212
213 // Otherwise, they may alias
214 return true;
215}
216
217// Collect memory operations (loads, stores) and calls for a given pointer value
218// Returns false if the value has uses that prevent optimization
219// Nocapture calls are only rejected (causing failure) if Calls is empty on
220// entry If Calls is non-empty on entry, nocapture calls are collected
221static bool
222collectMemoryOps(Value *Arg, const DataLayout &DL,
223 SmallVectorImpl<std::pair<StoreInst *, APInt>> &Stores,
224 SmallVectorImpl<std::pair<LoadInst *, APInt>> &Loads,
225 SmallVectorImpl<std::pair<CallInst *, APInt>> &Calls) {
226 // WorkList tracks (Value*, Offset from Arg)
227 SmallVector<std::pair<Value *, APInt>, 16> ToProcess;
228 SmallPtrSet<Value *, 16> Visited;
229
230 APInt ZeroOffset(DL.getIndexTypeSizeInBits(Arg->getType()), 0);
231 ToProcess.push_back({Arg, ZeroOffset});
232
233 while (!ToProcess.empty()) {
234 auto [V, CurrentOffset] = ToProcess.pop_back_val();
235
236 // Skip if already visited
237 if (!Visited.insert(V).second)
238 continue;
239
240 for (Use &U : V->uses()) {
241 User *Usr = U.getUser();
242 if (auto *LI = dyn_cast<LoadInst>(Usr)) {
243 Loads.push_back({LI, CurrentOffset});
244 } else if (auto *SI = dyn_cast<StoreInst>(Usr)) {
245 // Check if this is a store TO the pointer (not storing the pointer
246 // value)
247 if (SI->getPointerOperand() == V) {
248 Stores.push_back({SI, CurrentOffset});
249 } else {
250 // Pointer value is being stored somewhere - reject this argument
251 return false;
252 }
253 } else if (auto *GEP = dyn_cast<GetElementPtrInst>(Usr)) {
254 // Compute the offset for this GEP
255 APInt GEPOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
256 if (!GEP->accumulateConstantOffset(DL, GEPOffset)) {
257 // Cannot compute constant offset - reject this argument
258 return false;
259 }
260
261 APInt NewOffset = CurrentOffset + GEPOffset;
262 ToProcess.push_back({GEP, NewOffset});
263 } else if (auto *CI = dyn_cast<CastInst>(Usr)) {
264 // Casts don't change offset
265 ToProcess.push_back({CI, CurrentOffset});
266 } else if (auto *Call = dyn_cast<CallInst>(Usr)) {
267 // Get the argument index from the Use
268 unsigned ArgIdx = U.getOperandNo();
269 if (isNoCapture(Call, ArgIdx)) {
270 Calls.push_back({Call, CurrentOffset});
271 } else {
272 // Call that may capture - reject this argument
273 return false;
274 }
275 } else {
276 // Unknown use - reject this argument
277 return false;
278 }
279 }
280 }
281
282 return true;
283}
284
285// Main optimization function
286bool simplifyGVN(Function &F, DominatorTree &DT, const DataLayout &DL) {
287 bool Changed = false;
288
289 // Find noalias arguments
290 SmallVector<Value *, 4> CandidateArgs;
291 for (Argument &Arg : F.args()) {
292 if (Arg.getType()->isPointerTy() && Arg.hasNoAliasAttr()) {
293 CandidateArgs.push_back(&Arg);
294 }
295 }
296
297 for (BasicBlock &BB : F) {
298 for (Instruction &I : BB) {
299 if (isa<AllocaInst>(&I)) {
300 CandidateArgs.push_back(&I);
301 }
302 }
303 }
304
305 if (CandidateArgs.empty())
306 return false;
307
308 // For each candidate argument, collect stores and loads with their offsets
309 for (Value *Arg : CandidateArgs) {
310 // Collect all stores and loads to this argument with offsets
311 SmallVector<std::pair<StoreInst *, APInt>, 8> Stores;
312 SmallVector<std::pair<LoadInst *, APInt>, 8> Loads;
313 SmallVector<std::pair<CallInst *, APInt>, 8> Calls;
314
315 // First pass: strict collection (no nocapture calls) for store-load
316 // forwarding (pass empty Calls to reject nocapture calls)
317 if (!collectMemoryOps(Arg, DL, Stores, Loads, Calls)) {
318 // Argument has uses that prevent optimization
319 continue;
320 }
321
322 APInt ZeroOffset(DL.getIndexTypeSizeInBits(Arg->getType()), 0);
323
324 // Try to forward {stores, previous loads} to loads using simplified
325 // algorithm
326 for (auto &[LI, LoadOffset] : Loads) {
327 uint64_t LoadSize = DL.getTypeStoreSize(LI->getType());
328
329 // Step 1: Find all stores that may alias with this load
330 SmallVector<std::tuple<Instruction *, APInt, uint64_t>, 8> AliasingStores;
331 for (auto &[SI, StoreOffset] : Stores) {
332 uint64_t StoreSize =
333 DL.getTypeStoreSize(SI->getValueOperand()->getType());
334 if (memoryRangesAlias(LoadOffset, LoadSize, StoreOffset, StoreSize)) {
335 AliasingStores.push_back({SI, StoreOffset, StoreSize});
336 }
337 }
338
339 // Assume the call can touch any memory, so just set it to directly
340 // overlap.
341 for (auto &[CI, CallOffset] : Calls) {
342 AliasingStores.push_back({CI, LoadOffset, LoadSize});
343 }
344
345 // Step 2: Filter to dominating + covering stores
346 // Tuple of instruction storing, offset in the instruction, and the
347 // equivalent value.
348 SmallVector<std::tuple<Instruction *, APInt, Value *>, 8>
349 DominatingCoveringStores;
350 for (auto &[I, StoreOffset, StoreSize] : AliasingStores) {
351 if (auto SI = dyn_cast<StoreInst>(I))
352 if (dominatesAndCovers(SI, LI, StoreOffset, LoadOffset, LoadSize, DL,
353 DT)) {
354 DominatingCoveringStores.push_back(
355 {SI, StoreOffset, SI->getValueOperand()});
356 }
357 }
358
359 // Step 3: If only one aliasing store and it's dominating+covering,
360 // forward
361 if (AliasingStores.size() == 1 && DominatingCoveringStores.size() == 1) {
362 Instruction *SI = std::get<0>(DominatingCoveringStores[0]);
363 APInt StoreOffset = std::get<1>(DominatingCoveringStores[0]);
364
365 IRBuilder<> Builder(LI);
366 Value *StoredVal = std::get<2>(DominatingCoveringStores[0]);
367 Value *ExtractedVal =
368 extractValue(Builder, StoredVal, LI->getType(), DL, LoadOffset,
369 StoreOffset, LoadSize);
370
371 if (ExtractedVal) {
372 LLVM_DEBUG(dbgs() << "SimpleGVN: Forwarding (single alias)\n"
373 << " Store: " << *SI << "\n"
374 << " Load: " << *LI << "\n");
375 LI->replaceAllUsesWith(ExtractedVal);
376 LI->eraseFromParent();
377 LI = nullptr;
378 Changed = true;
379 }
380 continue;
381 }
382
383 for (auto &[LI2, LoadOffset2] : Loads) {
384 if (!LI2 || LI2 == LI)
385 continue;
386 if (dominatesAndCovers(LI2, LI, LoadOffset2, LoadOffset, LoadSize, DL,
387 DT)) {
388 DominatingCoveringStores.emplace_back(LI2, LoadOffset2, LI2);
389 }
390 }
391
392 // Step 4: If no dominating+covering stores, bail
393 if (DominatingCoveringStores.empty()) {
394 continue;
395 }
396
397 // Step 5: Build map of last store in each block before LI
398 DenseMap<BasicBlock *, std::tuple<Instruction *, APInt, uint64_t>>
399 LastStoreInBlockBeforeLI;
400 for (auto &[SI, StoreOffset, Size] : AliasingStores) {
401 BasicBlock *BB = SI->getParent();
402 if (BB == LI->getParent()) {
403 // Only consider stores before LI in the same block
404 if (SI->comesBefore(LI)) {
405 auto &Entry = LastStoreInBlockBeforeLI[BB];
406 if (!std::get<0>(Entry) || std::get<0>(Entry)->comesBefore(SI)) {
407 Entry = {SI, StoreOffset, Size};
408 }
409 }
410 } else {
411 // For other blocks, take the last store in the block
412 auto &Entry = LastStoreInBlockBeforeLI[BB];
413 if (!std::get<0>(Entry) || std::get<0>(Entry)->comesBefore(SI)) {
414 Entry = {SI, StoreOffset, Size};
415 }
416 }
417 }
418
419 // Step 6: Check if LI's parent block has a dominating+covering store
420 BasicBlock *LIBlock = LI->getParent();
421 auto It = LastStoreInBlockBeforeLI.find(LIBlock);
422 if (It != LastStoreInBlockBeforeLI.end()) {
423 Instruction *SI = std::get<0>(It->second);
424
425 for (auto &&[DCS, StoreOffset, StoredVal] : DominatingCoveringStores) {
426 if (SI == DCS ||
427 (DCS->getParent() == LI->getParent() && SI->comesBefore(DCS))) {
428
429 IRBuilder<> Builder(LI);
430 Value *ExtractedVal =
431 extractValue(Builder, StoredVal, LI->getType(), DL, LoadOffset,
432 StoreOffset, LoadSize);
433
434 if (ExtractedVal) {
435 LLVM_DEBUG(dbgs() << "SimpleGVN: Forwarding (same block)\n"
436 << " Store: " << *DCS << "\n"
437 << " Load: " << *LI << "\n");
438 LI->replaceAllUsesWith(ExtractedVal);
439 LI->eraseFromParent();
440 LI = nullptr;
441 Changed = true;
442 break;
443 }
444 }
445 }
446 continue;
447 } else {
448 for (auto &&[DCS, StoreOffset, StoredVal] : DominatingCoveringStores) {
449 if (DCS->getParent() == LI->getParent()) {
450
451 IRBuilder<> Builder(LI);
452 Value *ExtractedVal =
453 extractValue(Builder, StoredVal, LI->getType(), DL, LoadOffset,
454 StoreOffset, LoadSize);
455
456 if (ExtractedVal) {
457 LLVM_DEBUG(dbgs() << "SimpleGVN: Forwarding (same block)\n"
458 << " Store: " << *DCS << "\n"
459 << " Load: " << *LI << "\n");
460 LI->replaceAllUsesWith(ExtractedVal);
461 LI->eraseFromParent();
462 LI = nullptr;
463 Changed = true;
464 break;
465 }
466 }
467 }
468 if (LI == nullptr) {
469 continue;
470 }
471 }
472
473 // Step 7: BFS backwards from LI's parent block
474 SmallPtrSet<BasicBlock *, 32> Visited;
475 SmallVector<BasicBlock *, 16> Worklist;
476 StoreInst *Candidate = nullptr;
477 APInt CandidateOffset = ZeroOffset;
478
479 // Start with predecessors of LI's block
480 for (BasicBlock *Pred : predecessors(LIBlock)) {
481 if (Visited.insert(Pred).second)
482 Worklist.push_back(Pred);
483 }
484
485 while (!Worklist.empty()) {
486 BasicBlock *BB = Worklist.pop_back_val();
487
488 auto It = LastStoreInBlockBeforeLI.find(BB);
489 if (It != LastStoreInBlockBeforeLI.end()) {
490 StoreInst *SI = dyn_cast<StoreInst>(std::get<0>(It->second));
491 APInt StoreOffset = std::get<1>(It->second);
492
493 if (!SI || !dominatesAndCovers(SI, LI, StoreOffset, LoadOffset,
494 LoadSize, DL, DT)) {
495 // Non-dominating+covering store on path, bail
496 Candidate = nullptr;
497 break;
498 }
499
500 // Found dominating+covering store
501 if (!Candidate) {
502 Candidate = SI;
503 CandidateOffset = StoreOffset;
504 } else if (Candidate != SI) {
505 // Multiple different candidates, bail
506 Candidate = nullptr;
507 break;
508 }
509 }
510
511 // Continue BFS
512 for (BasicBlock *Pred : predecessors(BB)) {
513 if (Visited.insert(Pred).second)
514 Worklist.push_back(Pred);
515 }
516 }
517
518 // Step 8: If unique candidate found, forward
519 if (Candidate) {
520 IRBuilder<> Builder(LI);
521 Value *StoredVal = Candidate->getValueOperand();
522 Value *ExtractedVal =
523 extractValue(Builder, StoredVal, LI->getType(), DL, LoadOffset,
524 CandidateOffset, LoadSize);
525
526 if (ExtractedVal) {
527 LLVM_DEBUG(dbgs() << "SimpleGVN: Forwarding (BFS candidate)\n"
528 << " Store: " << *Candidate << "\n"
529 << " Load: " << *LI << "\n");
530 LI->replaceAllUsesWith(ExtractedVal);
531 LI->eraseFromParent();
532 LI = nullptr;
533 Changed = true;
534 }
535 }
536 }
537 }
538 return Changed;
539}
540
541class SimpleGVN final : public FunctionPass {
542public:
543 static char ID;
544 SimpleGVN() : FunctionPass(ID) {}
545
546 void getAnalysisUsage(AnalysisUsage &AU) const override {
547 AU.addRequired<DominatorTreeWrapperPass>();
548 }
549
550 bool runOnFunction(Function &F) override {
551 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
552 const DataLayout &DL = F.getParent()->getDataLayout();
553 return simplifyGVN(F, DT, DL);
554 }
555};
556
557} // namespace
558
559FunctionPass *createSimpleGVNPass() { return new SimpleGVN(); }
560
561extern "C" void LLVMAddSimpleGVNPass(LLVMPassManagerRef PM) {
562 unwrap(PM)->add(createSimpleGVNPass());
563}
564
565char SimpleGVN::ID = 0;
566
567static RegisterPass<SimpleGVN> X("simple-gvn",
568 "GVN-like load forwarding optimization");
569
571 FunctionAnalysisManager &FAM) {
572 bool Changed = false;
573 const DataLayout &DL = F.getParent()->getDataLayout();
574 Changed = simplifyGVN(F, FAM.getResult<DominatorTreeAnalysis>(F), DL);
575 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
576}
577
578llvm::AnalysisKey SimpleGVNNewPM::Key;
static RegisterPass< SimpleGVN > X("simple-gvn", "GVN-like load forwarding optimization")
FunctionPass * createSimpleGVNPass()
void LLVMAddSimpleGVNPass(LLVMPassManagerRef PM)
static bool isNoCapture(const llvm::CallBase *call, size_t idx)
Definition Utils.h:1840
llvm::PreservedAnalyses Result
Definition SimpleGVN.h:49
Result run(llvm::Function &F, llvm::FunctionAnalysisManager &FAM)