Enzyme main — FunctionUtils.cpp (source listing)
Go to the documentation of this file.
1//===- FunctionUtils.cpp - Implementation of function utilities -----------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file defines utilities on LLVM Functions that are used as part of the AD
22// process.
23//
24//===----------------------------------------------------------------------===//
25#include "FunctionUtils.h"
26
27#include "DiffeGradientUtils.h"
28#include "EnzymeLogic.h"
29#include "GradientUtils.h"
30#include "LibraryFuncs.h"
31
32#include "llvm/IR/Attributes.h"
33#include "llvm/IR/BasicBlock.h"
34#include "llvm/IR/DebugInfoMetadata.h"
35#include "llvm/IR/DerivedTypes.h"
36#include "llvm/IR/Function.h"
37#include "llvm/IR/IRBuilder.h"
38#include "llvm/IR/Module.h"
39#include "llvm/IR/Type.h"
40#include "llvm/IR/Verifier.h"
41#include "llvm/Passes/PassBuilder.h"
42
43#include "llvm/ADT/APSInt.h"
44#include "llvm/ADT/DenseMapInfo.h"
45#include "llvm/ADT/SetOperations.h"
46#include "llvm/ADT/SetVector.h"
47#include "llvm/Analysis/AliasAnalysis.h"
48#include "llvm/Analysis/AssumptionCache.h"
49#include "llvm/Analysis/BasicAliasAnalysis.h"
50#include "llvm/Analysis/CallGraph.h"
51#include "llvm/Analysis/GlobalsModRef.h"
52#include "llvm/Analysis/LazyValueInfo.h"
53#include "llvm/Analysis/LoopInfo.h"
54#include "llvm/Analysis/MemoryDependenceAnalysis.h"
55#include "llvm/Analysis/MemorySSA.h"
56#include "llvm/Analysis/OptimizationRemarkEmitter.h"
57#include <set>
58
59#if LLVM_VERSION_MAJOR < 16
60#include "llvm/Analysis/CFLSteensAliasAnalysis.h"
61#endif
62#include "llvm/Analysis/DependenceAnalysis.h"
63#include "llvm/Analysis/TypeBasedAliasAnalysis.h"
64#include "llvm/CodeGen/UnreachableBlockElim.h"
65
66#include "llvm/Analysis/PhiValues.h"
67#include "llvm/Analysis/ProfileSummaryInfo.h"
68#include "llvm/Analysis/ScalarEvolution.h"
69#include "llvm/Analysis/ScopedNoAliasAA.h"
70#include "llvm/Analysis/TargetTransformInfo.h"
71
72#include "llvm/Support/TimeProfiler.h"
73
74#include "llvm/Transforms/IPO/FunctionAttrs.h"
75#include "llvm/Transforms/Utils/Mem2Reg.h"
76
77#include "llvm/Transforms/Utils.h"
78
79#include "llvm/Transforms/InstCombine/InstCombine.h"
80#include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h"
81#include "llvm/Transforms/Scalar/DCE.h"
82#include "llvm/Transforms/Scalar/DeadStoreElimination.h"
83#include "llvm/Transforms/Scalar/EarlyCSE.h"
84#include "llvm/Transforms/Scalar/GVN.h"
85#include "llvm/Transforms/Scalar/IndVarSimplify.h"
86#include "llvm/Transforms/Scalar/InstSimplifyPass.h"
87#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h"
88#include "llvm/Transforms/Scalar/MemCpyOptimizer.h"
89#include "llvm/Transforms/Scalar/SROA.h"
90#include "llvm/Transforms/Scalar/SimplifyCFG.h"
91#include "llvm/Transforms/Utils/Cloning.h"
92#include "llvm/Transforms/Utils/LCSSA.h"
93#include "llvm/Transforms/Utils/LowerInvoke.h"
94
95#include "llvm/Transforms/IPO/FunctionAttrs.h"
96#include "llvm/Transforms/Scalar/DCE.h"
97#include "llvm/Transforms/Scalar/LoopDeletion.h"
98#include "llvm/Transforms/Scalar/LoopRotation.h"
99
100#include "llvm/Transforms/Utils/CodeExtractor.h"
101
102#include "llvm/Transforms/Utils/BasicBlockUtils.h"
103#include "llvm/Transforms/Utils/Local.h"
104
105#include "llvm/IR/LegacyPassManager.h"
106#if LLVM_VERSION_MAJOR <= 16
107#include "llvm/Transforms/IPO/PassManagerBuilder.h"
108#endif
109#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
110
111#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
112
113#include <optional>
114
115#include "CacheUtility.h"
116#include "Utils.h"
117
118#define addAttribute addAttributeAtIndex
119#define removeAttribute removeAttributeAtIndex
120#define getAttribute getAttributeAtIndex
121#define hasAttribute hasAttributeAtIndex
122
123#define DEBUG_TYPE "enzyme"
124using namespace llvm;
125
126extern "C" {
127cl::opt<bool> EnzymePreopt("enzyme-preopt", cl::init(true), cl::Hidden,
128 cl::desc("Run enzyme preprocessing optimizations"));
129
130cl::opt<bool> EnzymeInline("enzyme-inline", cl::init(false), cl::Hidden,
131 cl::desc("Force inlining of autodiff"));
132
133cl::opt<int> EnzymePostInlineOpt("enzyme-post-inline-opt", cl::init(0),
134 cl::Hidden,
135 cl::desc("Force inlining of autodiff"));
136
137cl::opt<bool> EnzymeNoAlias("enzyme-noalias", cl::init(false), cl::Hidden,
138 cl::desc("Force noalias of autodiff"));
139#if LLVM_VERSION_MAJOR < 16
140cl::opt<bool>
141 EnzymeAggressiveAA("enzyme-aggressive-aa", cl::init(false), cl::Hidden,
142 cl::desc("Use more unstable but aggressive LLVM AA"));
143#endif
144cl::opt<bool> EnzymeLowerGlobals(
145 "enzyme-lower-globals", cl::init(false), cl::Hidden,
146 cl::desc("Lower globals to locals assuming the global values are not "
147 "needed outside of this gradient"));
148
149cl::opt<int>
150 EnzymeInlineCount("enzyme-inline-count", cl::init(10000), cl::Hidden,
151 cl::desc("Limit of number of functions to inline"));
152
153cl::opt<bool> EnzymeCoalese("enzyme-coalese", cl::init(false), cl::Hidden,
154 cl::desc("Whether to coalese memory allocations"));
155
156static cl::opt<bool> EnzymePHIRestructure(
157 "enzyme-phi-restructure", cl::init(false), cl::Hidden,
158 cl::desc("Whether to restructure phi's to have better unwrap behavior"));
159
160cl::opt<bool>
161 EnzymeNameInstructions("enzyme-name-instructions", cl::init(false),
162 cl::Hidden,
163 cl::desc("Have enzyme name all instructions"));
164
165cl::opt<bool> EnzymeSelectOpt("enzyme-select-opt", cl::init(true), cl::Hidden,
166 cl::desc("Run Enzyme select optimization"));
167
168cl::opt<bool> EnzymeAutoSparsity("enzyme-auto-sparsity", cl::init(false),
169 cl::Hidden,
170 cl::desc("Run Enzyme auto sparsity"));
171
173 "enzyme-post-opt-level", cl::init(0), cl::Hidden,
174 cl::desc("Post optimization level within Enzyme differentiated function"));
175
177 "enzyme-always-inline", cl::init(false), cl::Hidden,
178 cl::desc("Mark generated functions as always-inline"));
179}
180
181/// Is the use of value val as an argument of call CI potentially captured
182bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val) {
183 Function *F = CI->getCalledFunction();
184
185 if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
186 if (castinst->isCast())
187 if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
188 F = fn;
189 }
190 }
191
192 if (F == nullptr)
193 return true;
194
195 if (F->getIntrinsicID() == Intrinsic::memset)
196 return false;
197 if (F->getIntrinsicID() == Intrinsic::memcpy)
198 return false;
199 if (F->getIntrinsicID() == Intrinsic::memmove)
200 return false;
201
202 auto arg = F->arg_begin();
203 for (size_t i = 0, size = CI->arg_size(); i < size; i++) {
204 if (val == CI->getArgOperand(i)) {
205 // This is a vararg, assume captured
206 if (arg == F->arg_end()) {
207 return true;
208 } else {
209 if (!arg->hasNoCaptureAttr()) {
210 return true;
211 }
212 }
213 }
214 if (arg != F->arg_end())
215 arg++;
216 }
217 // No argument captured
218 return false;
219}
220
226/// Return whether this function eventually calls itself
227static bool
229 std::map<const Function *, RecurType> &Results) {
230
231 // If we haven't seen this function before, look at all callers
232 // and mark this as potentially recursive. If we see this function
233 // still as marked as MaybeRecursive, we will definitionally have
234 // found an eventual caller of the original function. If not,
235 // the function does not eventually call itself (in a static way)
236 if (Results.find(F) == Results.end()) {
237 Results[F] = MaybeRecursive; // staging
238 for (auto &BB : *F) {
239 for (auto &I : BB) {
240 if (auto call = dyn_cast<CallInst>(&I)) {
241 if (call->getCalledFunction() == nullptr)
242 continue;
243 if (call->getCalledFunction()->empty())
244 continue;
245 IsFunctionRecursive(call->getCalledFunction(), Results);
246 }
247 if (auto call = dyn_cast<InvokeInst>(&I)) {
248 if (call->getCalledFunction() == nullptr)
249 continue;
250 if (call->getCalledFunction()->empty())
251 continue;
252 IsFunctionRecursive(call->getCalledFunction(), Results);
253 }
254 }
255 }
256 if (Results[F] == MaybeRecursive) {
257 Results[F] = NotRecursive; // not recursive
258 }
259 } else if (Results[F] == MaybeRecursive) {
260 Results[F] = DefinitelyRecursive; // definitely recursive
261 }
262 assert(Results[F] != MaybeRecursive);
263 return Results[F] == DefinitelyRecursive;
264}
265
266static inline bool OnlyUsedInOMP(AllocaInst *AI) {
267 bool ompUse = false;
268 for (auto U : AI->users()) {
269 if (auto SI = dyn_cast<StoreInst>(U))
270 if (SI->getPointerOperand() == AI)
271 continue;
272 if (auto CI = dyn_cast<CallInst>(U)) {
273 if (auto F = CI->getCalledFunction()) {
274 if (F->getName() == "__kmpc_for_static_init_4" ||
275 F->getName() == "__kmpc_for_static_init_4u" ||
276 F->getName() == "__kmpc_for_static_init_8" ||
277 F->getName() == "__kmpc_for_static_init_8u") {
278 ompUse = true;
279 }
280 }
281 }
282 }
283
284 if (!ompUse)
285 return false;
286 return true;
287}
288
290 SmallVector<std::tuple<Value *, Value *, Instruction *>, 1> &Todo,
291 SmallVector<Instruction *, 1> &toErase, bool legal) {
292 SmallVector<StoreInst *, 1> toPostCache;
293 while (Todo.size()) {
294 auto cur = Todo.back();
295 Todo.pop_back();
296 Value *rep = std::get<0>(cur);
297 Value *prev = std::get<1>(cur);
298 Value *inst = std::get<2>(cur);
299 if (auto ASC = dyn_cast<AddrSpaceCastInst>(inst)) {
300 auto AS = cast<PointerType>(rep->getType())->getAddressSpace();
301 if (AS == ASC->getDestAddressSpace()) {
302 ASC->replaceAllUsesWith(rep);
303 toErase.push_back(ASC);
304 continue;
305 }
307 cast<PointerType>(rep->getType())->getAddressSpace() == 0 &&
308 ASC->getDestAddressSpace() == 11) {
309 for (auto U : ASC->users()) {
310 Todo.push_back(
311 std::make_tuple(rep, (Value *)ASC, cast<Instruction>(U)));
312 }
313 toErase.push_back(ASC);
314 continue;
315 }
316 ASC->setOperand(0, rep);
317 continue;
318 }
319 if (auto CI = dyn_cast<CastInst>(inst)) {
320 if (!CI->getType()->isPointerTy()) {
321 CI->setOperand(0, rep);
322 continue;
323 }
324 IRBuilder<> B(CI);
325 Type *resTy;
326#if LLVM_VERSION_MAJOR < 17
327 if (CI->getContext().supportsTypedPointers()) {
328 resTy = PointerType::get(
329 CI->getType()->getPointerElementType(),
330 cast<PointerType>(rep->getType())->getAddressSpace());
331 } else {
332 resTy = rep->getType();
333 }
334#else
335 resTy = rep->getType();
336#endif
337
338 auto nCI0 = B.CreateCast(CI->getOpcode(), rep, resTy);
339 if (auto nCI = dyn_cast<CastInst>(nCI0))
340 nCI->takeName(CI);
341 for (auto U : CI->users()) {
342 Todo.push_back(
343 std::make_tuple((Value *)nCI0, (Value *)CI, cast<Instruction>(U)));
344 }
345 toErase.push_back(CI);
346 continue;
347 }
348 if (auto GEP = dyn_cast<GetElementPtrInst>(inst)) {
349 IRBuilder<> B(GEP);
351 cast<PointerType>(rep->getType())->getAddressSpace() == 10) {
352
353 Type *resTy;
354#if LLVM_VERSION_MAJOR < 17
355 if (GEP->getContext().supportsTypedPointers()) {
356 resTy = PointerType::get(rep->getType()->getPointerElementType(), 11);
357 } else {
358 resTy = PointerType::get(rep->getContext(), 11);
359 }
360#else
361 resTy = PointerType::get(rep->getContext(), 11);
362#endif
363 rep = B.CreateAddrSpaceCast(rep, resTy);
364 }
365 SmallVector<Value *, 1> ind(GEP->indices());
366 auto nGEP = cast<GetElementPtrInst>(
367 B.CreateGEP(GEP->getSourceElementType(), rep, ind));
368 nGEP->takeName(GEP);
369 for (auto U : GEP->users()) {
370 Todo.push_back(
371 std::make_tuple((Value *)nGEP, (Value *)GEP, cast<Instruction>(U)));
372 }
373 toErase.push_back(GEP);
374 continue;
375 }
376 if (auto P = dyn_cast<PHINode>(inst)) {
377 auto NumOperands = P->getNumIncomingValues();
378 SmallVector<Value *, 1> replacedOperands(NumOperands, nullptr);
379 for (size_t i = 0; i < NumOperands; i++)
380 if (P->getOperand(i) == prev)
381 replacedOperands[i] = rep;
382
383 for (auto tval : Todo) {
384 if (std::get<2>(tval) != P)
385 continue;
386 for (size_t i = 0; i < NumOperands; i++)
387 if (P->getOperand(i) == std::get<1>(tval)) {
388 replacedOperands[i] = std::get<0>(tval);
389 }
390 }
391 bool allReplaced = true;
392 for (size_t i = 0; i < NumOperands; i++) {
393 if (!replacedOperands[i]) {
394 allReplaced = false;
395 }
396 }
397 if (!allReplaced) {
398 bool remainingArePHIs = true;
399 for (auto v : Todo) {
400 if (isa<PHINode>(std::get<2>(v))) {
401 } else {
402 remainingArePHIs = false;
403 }
404 }
405 if (!remainingArePHIs) {
406 Todo.insert(Todo.begin(), cur);
407 continue;
408 }
409 } else {
410 IRBuilder<> B(&(*P->getParent()->getFirstNonPHIOrDbgOrLifetime()));
411 auto nP = B.CreatePHI(rep->getType(), P->getNumOperands());
412 for (size_t i = 0; i < NumOperands; i++) {
413 nP->addIncoming(replacedOperands[i], P->getIncomingBlock(i));
414 }
415 nP->takeName(P);
416 for (auto U : P->users()) {
417 Todo.push_back(
418 std::make_tuple((Value *)nP, (Value *)P, cast<Instruction>(U)));
419 }
420 toErase.push_back(P);
421 for (int i = Todo.size() - 1; i >= 0; i--) {
422 if (std::get<2>(Todo[i]) != P)
423 continue;
424 Todo.erase(Todo.begin() + i);
425 }
426 continue;
427 }
428 }
429 if (auto II = dyn_cast<IntrinsicInst>(inst)) {
430 if (isIntelSubscriptIntrinsic(*II)) {
431
432 const std::array<size_t, 4> idxArgsIndices{{0, 1, 2, 4}};
433 const size_t ptrArgIndex = 3;
434
435 SmallVector<Value *, 5> args(5);
436 for (auto i : idxArgsIndices) {
437 Value *idx = II->getOperand(i);
438 args[i] = idx;
439 }
440 args[ptrArgIndex] = rep;
441
442 IRBuilder<> B(II);
443 auto nII = cast<CallInst>(B.CreateCall(II->getCalledFunction(), args));
444 // Must copy the elementtype attribute as it is needed by the intrinsic
445 nII->addParamAttr(
446 ptrArgIndex,
447 II->getParamAttr(ptrArgIndex, Attribute::AttrKind::ElementType));
448 nII->takeName(II);
449 for (auto U : II->users()) {
450 Todo.push_back(
451 std::make_tuple((Value *)nII, (Value *)II, cast<Instruction>(U)));
452 }
453 toErase.push_back(II);
454 continue;
455 }
456 }
457 if (auto LI = dyn_cast<LoadInst>(inst)) {
459 cast<PointerType>(rep->getType())->getAddressSpace() == 10) {
460 IRBuilder<> B(LI);
461 Type *resTy;
462#if LLVM_VERSION_MAJOR < 17
463 if (LI->getContext().supportsTypedPointers()) {
464 resTy = PointerType::get(rep->getType()->getPointerElementType(), 11);
465 } else {
466 resTy = PointerType::get(rep->getContext(), 11);
467 }
468#else
469 resTy = PointerType::get(rep->getContext(), 11);
470#endif
471 rep = B.CreateAddrSpaceCast(rep, resTy);
472 }
473 LI->setOperand(0, rep);
474 continue;
475 }
476 if (auto SI = dyn_cast<StoreInst>(inst)) {
477 if (SI->getPointerOperand() == prev) {
479 cast<PointerType>(rep->getType())->getAddressSpace() == 10) {
480 IRBuilder<> B(SI);
481 Type *resTy;
482#if LLVM_VERSION_MAJOR < 17
483 if (SI->getContext().supportsTypedPointers()) {
484 resTy =
485 PointerType::get(rep->getType()->getPointerElementType(), 11);
486 } else {
487 resTy = PointerType::get(rep->getContext(), 11);
488 }
489#else
490 resTy = PointerType::get(rep->getContext(), 11);
491#endif
492 rep = B.CreateAddrSpaceCast(rep, resTy);
493 }
494 SI->setOperand(1, rep);
496 cast<PointerType>(rep->getType())->getAddressSpace() == 11 &&
497 cast<PointerType>(SI->getPointerOperand()->getType())
498 ->getAddressSpace() == 0) {
499 IRBuilder<> B(SI);
500 auto subvals = getJuliaObjects(SI->getValueOperand(), B);
501 if (subvals.size()) {
502 auto JLT =
503 PointerType::get(StructType::get(SI->getContext(), {}), 10);
504 auto FT = FunctionType::get(Type::getVoidTy(rep->getContext()),
505 {JLT}, true);
506 auto wb = B.GetInsertBlock()
507 ->getParent()
508 ->getParent()
509 ->getOrInsertFunction("julia.write_barrier", FT);
510 auto obj = getBaseObject(rep);
511 assert(obj->getType() == JLT);
512 subvals.insert(subvals.begin(), obj);
513 B.CreateCall(wb, subvals);
514 }
515 }
516 toPostCache.push_back(SI);
517 continue;
518 }
519 }
520 if (auto MS = dyn_cast<MemSetInst>(inst)) {
521 IRBuilder<> B(MS);
522
523 Value *nargs[] = {MS->getArgOperand(0), MS->getArgOperand(1),
524 MS->getArgOperand(2), MS->getArgOperand(3)};
525
526 if (nargs[0] == prev)
527 nargs[0] = rep;
528
529 if (nargs[1] == prev)
530 nargs[1] = rep;
531
532 Type *tys[] = {nargs[0]->getType(), nargs[2]->getType()};
533 auto nMS = cast<CallInst>(B.CreateCall(
534 getIntrinsicDeclaration(MS->getParent()->getParent()->getParent(),
535 Intrinsic::memset, tys),
536 nargs));
537 nMS->copyMetadata(*MS);
538 nMS->setAttributes(MS->getAttributes());
539 toErase.push_back(MS);
540 continue;
541 }
542 if (auto MTI = dyn_cast<MemTransferInst>(inst)) {
543 IRBuilder<> B(MTI);
544
545 Value *nargs[4] = {MTI->getArgOperand(0), MTI->getArgOperand(1),
546 MTI->getArgOperand(2), MTI->getArgOperand(3)};
547
548 if (nargs[0] == prev)
549 nargs[0] = rep;
550
551 if (nargs[1] == prev)
552 nargs[1] = rep;
553
554 Type *tys[] = {nargs[0]->getType(), nargs[1]->getType(),
555 nargs[2]->getType()};
556
557 auto nMTI = cast<CallInst>(B.CreateCall(
558 getIntrinsicDeclaration(MTI->getParent()->getParent()->getParent(),
559 MTI->getIntrinsicID(), tys),
560 nargs));
561 nMTI->copyMetadata(*MTI);
562 nMTI->setAttributes(MTI->getAttributes());
563 toErase.push_back(MTI);
564 continue;
565 }
566 if (auto CI = dyn_cast<CallInst>(inst)) {
567 if (auto F = CI->getCalledFunction()) {
568 if (F->getName() == "julia.write_barrier" && legal) {
569 toErase.push_back(CI);
570 continue;
571 }
572 if (F->getName() == "julia.write_barrier_binding" && legal) {
573 toErase.push_back(CI);
574 continue;
575 }
576 }
577 IRBuilder<> B(CI);
578 auto Addr = B.CreateAddrSpaceCast(rep, prev->getType());
579 for (size_t i = 0; i < CI->arg_size(); i++) {
580 if (CI->getArgOperand(i) == prev) {
581 CI->setArgOperand(i, Addr);
582 }
583 }
584 continue;
585 }
586 if (auto IVI = dyn_cast<InsertValueInst>(inst)) {
587 if (IVI->getInsertedValueOperand() == prev && EnzymeJuliaAddrLoad &&
588 cast<PointerType>(rep->getType())->getAddressSpace() == 0 &&
589 cast<PointerType>(IVI->getInsertedValueOperand()->getType())
590 ->getAddressSpace() == 11) {
591 IRBuilder<> B(IVI);
592 auto Addr = B.CreateAddrSpaceCast(rep, prev->getType());
593 IVI->setOperand(1, Addr);
594 continue;
595 }
596 }
597
598 std::string s;
599 llvm::raw_string_ostream ss(s);
600 ss << "Illegal address space propagation\n";
601 ss << " + rep: " << *rep << "\n";
602 ss << " + prev: " << *prev << "\n";
603 ss << " + inst: " << *inst << "\n";
604
605 if (CustomErrorHandler) {
606 CustomErrorHandler(s.c_str(), wrap(inst), ErrorType::InternalError,
607 nullptr, nullptr, nullptr);
608 } else {
609 auto instI = cast<Instruction>(inst);
610 ss << *instI->getParent()->getParent() << "\n";
611 EmitFailure("IllegalAddressSpacePropagation", instI->getDebugLoc(), instI,
612 ss.str());
613 }
614 llvm_unreachable("Illegal address space propagation");
615 }
616
617 for (auto I : llvm::reverse(toErase)) {
618 I->eraseFromParent();
619 }
620 for (auto SI : toPostCache) {
621 IRBuilder<> B(SI->getNextNode());
622 PostCacheStore(SI, B);
623 }
624}
625
626void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
627 SmallVector<std::tuple<Value *, Value *, Instruction *>, 1> Todo;
628 for (auto U : AI->users()) {
629 Todo.push_back(
630 std::make_tuple((Value *)rep, (Value *)AI, cast<Instruction>(U)));
631 }
632 SmallVector<Instruction *, 1> toErase;
633 if (auto I = dyn_cast<Instruction>(AI))
634 toErase.push_back(I);
635 RecursivelyReplaceAddressSpace(Todo, toErase, legal);
636}
637
638/// Convert necessary stack allocations into mallocs for use in the reverse
639/// pass. Specifically if we're not topLevel all allocations must be upgraded
640/// Even if topLevel any allocations that aren't in the entry block (and
641/// therefore may not be reachable in the reverse pass) must be upgraded.
642static inline void
644 SmallPtrSetImpl<llvm::BasicBlock *> &Unreachable) {
645 SmallVector<AllocaInst *, 4> ToConvert;
646
647 for (auto &BB : *NewF) {
648 if (Unreachable.count(&BB))
649 continue;
650 for (auto &I : BB) {
651 if (auto AI = dyn_cast<AllocaInst>(&I)) {
652 bool UsableEverywhere = AI->getParent() == &NewF->getEntryBlock();
653 // TODO use is_value_needed_in_reverse (requiring GradientUtils)
654 if (OnlyUsedInOMP(AI))
655 continue;
656 if (!UsableEverywhere || mode != DerivativeMode::ReverseModeCombined) {
657 ToConvert.push_back(AI);
658 }
659 }
660 }
661 }
662
663#if LLVM_VERSION_MAJOR >= 22
664 Function *start_lifetime = nullptr;
665 Function *end_lifetime = nullptr;
666#endif
667
668 SmallVector<std::tuple<Value *, Value *, Instruction *>, 1> Todo;
669 SmallVector<Instruction *, 1> toErase;
670 for (auto AI : ToConvert) {
671 std::string nam = AI->getName().str();
672 AI->setName("");
673
674#if LLVM_VERSION_MAJOR >= 22
675 for (auto U : llvm::make_early_inc_range(AI->users())) {
676 if (auto II = dyn_cast<IntrinsicInst>(U)) {
677 if (II->getIntrinsicID() == Intrinsic::lifetime_start) {
678 if (!start_lifetime) {
679 start_lifetime = cast<Function>(
680 NewF->getParent()
681 ->getOrInsertFunction(
682 "llvm.enzyme.lifetime_start",
683 FunctionType::get(Type::getVoidTy(NewF->getContext()),
684 {}, true))
685 .getCallee());
686 }
687 IRBuilder<> B(II);
688 SmallVector<Value *, 2> args(II->arg_size());
689 for (unsigned i = 0; i < II->arg_size(); ++i) {
690 args[i] = II->getArgOperand(i);
691 }
692 auto newI = B.CreateCall(start_lifetime, args);
693 newI->takeName(II);
694 newI->setDebugLoc(II->getDebugLoc());
695 II->eraseFromParent();
696 continue;
697 }
698 if (II->getIntrinsicID() == Intrinsic::lifetime_end) {
699 if (!end_lifetime) {
700 end_lifetime = cast<Function>(
701 NewF->getParent()
702 ->getOrInsertFunction(
703 "llvm.enzyme.lifetime_end",
704 FunctionType::get(Type::getVoidTy(NewF->getContext()),
705 {}, true))
706 .getCallee());
707 }
708 IRBuilder<> B(II);
709 SmallVector<Value *, 2> args(II->arg_size());
710 for (unsigned i = 0; i < II->arg_size(); ++i) {
711 args[i] = II->getArgOperand(i);
712 }
713 auto newI = B.CreateCall(end_lifetime, args);
714 newI->takeName(II);
715 newI->setDebugLoc(II->getDebugLoc());
716 II->eraseFromParent();
717 continue;
718 }
719 }
720 }
721#endif
722
723 // Ensure we insert the malloc after the allocas
724 Instruction *insertBefore = AI;
725 while (isa<AllocaInst>(insertBefore->getNextNode())) {
726 insertBefore = insertBefore->getNextNode();
727 assert(insertBefore);
728 }
729
730 auto i64 = Type::getInt64Ty(NewF->getContext());
731 IRBuilder<> B(insertBefore);
732 CallInst *CI = nullptr;
733 Instruction *ZeroInst = nullptr;
734 auto rep = CreateAllocation(
735 B, AI->getAllocatedType(), B.CreateZExtOrTrunc(AI->getArraySize(), i64),
736 nam, &CI, /*ZeroMem*/ EnzymeZeroCache ? &ZeroInst : nullptr);
737 auto align = AI->getAlign().value();
738 CI->setMetadata(
739 "enzyme_fromstack",
740 MDNode::get(CI->getContext(),
741 {
742 ConstantAsMetadata::get(ConstantInt::get(
743 IntegerType::get(AI->getContext(), 64), align)),
744 ConstantAsMetadata::get(ConstantInt::get(
745 IntegerType::get(AI->getContext(), 64),
746 (size_t)AI->getAllocatedType())),
747 }));
748
749 for (auto MD : {"enzyme_active", "enzyme_inactive", "enzyme_type",
750 "enzymejl_allocart", "enzymejl_allocart_name",
751 "enzymejl_gc_alloc_rt"})
752 if (auto M = AI->getMetadata(MD))
753 CI->setMetadata(MD, M);
754
755 if (rep != CI) {
756 cast<Instruction>(rep)->setMetadata("enzyme_caststack",
757 MDNode::get(CI->getContext(), {}));
758 }
759 if (ZeroInst) {
760 ZeroInst->setMetadata("enzyme_zerostack",
761 MDNode::get(CI->getContext(), {}));
762 }
763
764 auto PT0 = cast<PointerType>(rep->getType());
765 auto PT1 = cast<PointerType>(AI->getType());
766 if (PT0->getAddressSpace() != PT1->getAddressSpace()) {
767 for (auto U : AI->users()) {
768 Todo.push_back(
769 std::make_tuple((Value *)rep, (Value *)AI, cast<Instruction>(U)));
770 }
771 if (auto I = dyn_cast<Instruction>(AI))
772 toErase.push_back(I);
773 } else {
774 assert(rep->getType() == AI->getType());
775 AI->replaceAllUsesWith(rep);
776 AI->eraseFromParent();
777 }
778 }
779 RecursivelyReplaceAddressSpace(Todo, toErase, /*legal*/ false);
780}
781
782// Create a stack variable containing the size of the allocation
783// error if not possible (e.g. not local)
784static inline AllocaInst *
785OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T,
786 const std::map<CallInst *, Value *> &reallocSizes) {
787 IRBuilder<> B(&*NewF->getEntryBlock().begin());
788 AllocaInst *AI = B.CreateAlloca(T);
789
790 std::set<std::pair<Value *, Instruction *>> seen;
791 std::deque<std::pair<Value *, Instruction *>> todo = {{Ptr, Loc}};
792
793 while (todo.size()) {
794 auto next = todo.front();
795 todo.pop_front();
796 if (seen.count(next))
797 continue;
798 seen.insert(next);
799
800 if (auto CI = dyn_cast<CastInst>(next.first)) {
801 todo.push_back({CI->getOperand(0), CI});
802 continue;
803 }
804
805 // Assume zero size if realloc of undef pointer
806 if (isa<UndefValue>(next.first)) {
807 B.SetInsertPoint(next.second);
808 B.CreateStore(ConstantInt::get(T, 0), AI);
809 continue;
810 }
811
812 if (auto CE = dyn_cast<ConstantExpr>(next.first)) {
813 if (CE->isCast()) {
814 todo.push_back({CE->getOperand(0), next.second});
815 continue;
816 }
817 }
818
819 if (auto C = dyn_cast<Constant>(next.first)) {
820 if (C->isNullValue()) {
821 B.SetInsertPoint(next.second);
822 B.CreateStore(ConstantInt::get(T, 0), AI);
823 continue;
824 }
825 }
826 if (auto CI = dyn_cast<ConstantInt>(next.first)) {
827 // if negative or below 0xFFF this cannot possibly represent
828 // a real pointer, so ignore this case by setting to 0
829 if (CI->isNegative() || CI->getLimitedValue() <= 0xFFF) {
830 B.SetInsertPoint(next.second);
831 B.CreateStore(ConstantInt::get(T, 0), AI);
832 continue;
833 }
834 }
835
836 // Todo consider more general method for selects
837 if (auto SI = dyn_cast<SelectInst>(next.first)) {
838 if (auto C1 = dyn_cast<ConstantInt>(SI->getTrueValue())) {
839 // if negative or below 0xFFF this cannot possibly represent
840 // a real pointer, so ignore this case by setting to 0
841 if (C1->isNegative() || C1->getLimitedValue() <= 0xFFF) {
842 if (auto C2 = dyn_cast<ConstantInt>(SI->getFalseValue())) {
843 if (C2->isNegative() || C2->getLimitedValue() <= 0xFFF) {
844 B.SetInsertPoint(next.second);
845 B.CreateStore(ConstantInt::get(T, 0), AI);
846 continue;
847 }
848 }
849 }
850 }
851 }
852
853 if (auto PN = dyn_cast<PHINode>(next.first)) {
854 for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
855 todo.push_back({PN->getIncomingValue(i),
856 PN->getIncomingBlock(i)->getTerminator()});
857 }
858 continue;
859 }
860
861 if (auto CI = dyn_cast<CallInst>(next.first)) {
862 if (auto F = CI->getCalledFunction()) {
863 if (F->getName() == "malloc") {
864 B.SetInsertPoint(next.second);
865 B.CreateStore(CI->getArgOperand(0), AI);
866 continue;
867 }
868 if (F->getName() == "calloc") {
869 B.SetInsertPoint(next.second);
870 B.CreateStore(B.CreateMul(CI->getArgOperand(0), CI->getArgOperand(1)),
871 AI);
872 continue;
873 }
874 if (F->getName() == "realloc") {
875 assert(reallocSizes.find(CI) != reallocSizes.end());
876 B.SetInsertPoint(next.second);
877 B.CreateStore(reallocSizes.find(CI)->second, AI);
878 continue;
879 }
880 }
881 }
882
883 if (auto LI = dyn_cast<LoadInst>(next.first)) {
884 bool success = false;
885 for (Instruction *prev = LI->getPrevNode(); prev != nullptr;
886 prev = prev->getPrevNode()) {
887 if (auto CI = dyn_cast<CallInst>(prev)) {
888 if (auto F = CI->getCalledFunction()) {
889 if (F->getName() == "posix_memalign" &&
890 CI->getArgOperand(0) == LI->getOperand(0)) {
891 B.SetInsertPoint(next.second);
892 B.CreateStore(CI->getArgOperand(2), AI);
893 success = true;
894 break;
895 }
896 }
897 }
898 if (prev->mayWriteToMemory()) {
899 break;
900 }
901 }
902 if (success)
903 continue;
904
905 auto v2 = simplifyLoad(LI);
906 if (v2) {
907 todo.push_back({v2, next.second});
908 continue;
909 }
910 }
911
912 EmitFailure("DynamicReallocSize", Loc->getDebugLoc(), Loc,
913 "could not statically determine size of realloc ", *Loc,
914 " - because of - ", *next.first);
915 return AI;
916
917 std::string allocName;
918 switch (llvm::Triple(NewF->getParent()->getTargetTriple()).getOS()) {
919 case llvm::Triple::Linux:
920 case llvm::Triple::FreeBSD:
921 case llvm::Triple::NetBSD:
922 case llvm::Triple::OpenBSD:
923 case llvm::Triple::Fuchsia:
924 allocName = "malloc_usable_size";
925 break;
926
927 case llvm::Triple::Darwin:
928 case llvm::Triple::IOS:
929 case llvm::Triple::MacOSX:
930 case llvm::Triple::WatchOS:
931 case llvm::Triple::TvOS:
932 allocName = "malloc_size";
933 break;
934
935 case llvm::Triple::Win32:
936 allocName = "_msize";
937 break;
938
939 default:
940 llvm_unreachable("unknown reallocation for OS");
941 }
942
943 AttributeList list;
944 list = list.addFnAttribute(NewF->getContext(), Attribute::ReadOnly);
945 list = list.addParamAttribute(NewF->getContext(), 0, Attribute::ReadNone);
946 list = addFunctionNoCapture(NewF->getContext(), list, 0);
947 auto allocSize = NewF->getParent()->getOrInsertFunction(
948 allocName,
949 FunctionType::get(
950 IntegerType::get(NewF->getContext(), 8 * sizeof(size_t)),
951 {getInt8PtrTy(NewF->getContext())}, /*isVarArg*/ false),
952 list);
953
954 B.SetInsertPoint(Loc);
955 Value *sz = B.CreateZExtOrTrunc(B.CreateCall(allocSize, {Ptr}), T);
956 B.CreateStore(sz, AI);
957 return AI;
958
959 llvm_unreachable("DynamicReallocSize");
960 }
961 return AI;
962}
963
964void PreProcessCache::AlwaysInline(Function *NewF) {
965
966 PreservedAnalyses PA;
967 PA.preserve<AssumptionAnalysis>();
968 PA.preserve<TargetLibraryAnalysis>();
969 FAM.invalidate(*NewF, PA);
970 SmallVector<CallInst *, 2> ToInline;
971 // TODO this logic should be combined with the dynamic loop emission
972 // to minimize the number of branches if the realloc is used for multiple
973 // values with the same bound.
974 for (auto &BB : *NewF) {
975 for (auto &I : make_early_inc_range(BB)) {
976 if (hasMetadata(&I, "enzyme_zerostack")) {
977 if (isa<AllocaInst>(getBaseObject(I.getOperand(0)))) {
978 I.eraseFromParent();
979 continue;
980 }
981 }
982 if (auto CI = dyn_cast<CallInst>(&I)) {
983 if (!CI->getCalledFunction())
984 continue;
985 if (CI->getCalledFunction()->hasFnAttribute(Attribute::AlwaysInline))
986 ToInline.push_back(CI);
987 }
988 }
989 }
990
991 for (auto CI : ToInline) {
992 InlineFunctionInfo IFI;
993#if LLVM_VERSION_MAJOR >= 18 && LLVM_VERSION_MAJOR < 21
994 auto F = CI->getCalledFunction();
995 if (CI->getParent()->IsNewDbgInfoFormat != F->IsNewDbgInfoFormat) {
996 if (CI->getParent()->IsNewDbgInfoFormat) {
997 F->convertToNewDbgValues();
998 } else {
999 F->convertFromNewDbgValues();
1000 }
1001 }
1002#endif
1003 InlineFunction(*CI, IFI);
1004 }
1005}
1006
1007// Simplify all extractions to use inserted values, if possible.
1008void simplifyExtractions(Function *NewF) {
1009 // First rewrite/remove any extractions
1010 for (auto &BB : *NewF) {
1011 IRBuilder<> B(&BB);
1012 auto first = BB.begin();
1013 auto last = BB.empty() ? BB.end() : std::prev(BB.end());
1014 for (auto it = first; it != last;) {
1015 auto inst = &*it;
1016 // We iterate first here, since we may delete the instruction
1017 // in the body
1018 ++it;
1019 if (auto E = dyn_cast<ExtractValueInst>(inst)) {
1020 auto rep = GradientUtils::extractMeta(B, E->getAggregateOperand(),
1021 E->getIndices(), E->getName(),
1022 /*fallback*/ false);
1023 if (rep) {
1024 E->replaceAllUsesWith(rep);
1025 E->eraseFromParent();
1026 }
1027 }
1028 }
1029 }
1030 // Now that there may be unused insertions, delete them. We keep a list of
1031 // todo's since deleting an insertvalue may cause a different insertvalue to
1032 // have no uses
1033 SmallVector<InsertValueInst *, 1> todo;
1034 for (auto &BB : *NewF) {
1035 for (auto &inst : BB)
1036 if (auto I = dyn_cast<InsertValueInst>(&inst)) {
1037 if (I->getNumUses() == 0)
1038 todo.push_back(I);
1039 }
1040 }
1041 while (todo.size()) {
1042 auto I = todo.pop_back_val();
1043 auto op = I->getAggregateOperand();
1044 I->eraseFromParent();
1045 if (auto I2 = dyn_cast<InsertValueInst>(op))
1046 if (I2->getNumUses() == 0)
1047 todo.push_back(I2);
1048 }
1049}
1050
// NOTE(review): the signature line of this function is not visible in this
// view; these comments describe only the body below.
// The body first calls simplifyExtractions(NewF). It then takes every
// instruction tagged with "enzyme_backstack" metadata and passes it, together
// with the alloca it was built from, to RecursivelyReplaceAddressSpace.
// On LLVM >= 22 it also lowers the placeholder llvm.enzyme.lifetime_start /
// llvm.enzyme.lifetime_end calls.
1052 simplifyExtractions(NewF);
1053 SmallVector<Instruction *, 1> Todo;
1054 for (auto &BB : *NewF) {
1055 for (auto &I : BB) {
1056 if (hasMetadata(&I, "enzyme_backstack")) {
1057 Todo.push_back(&I);
1058 // TODO
1059 // I.eraseMetadata("enzyme_backstack");
1060 }
1061 }
1062 }
// Operand 0 of each tagged instruction, looking through at most one bitcast,
// must be an AllocaInst (cast<> asserts otherwise). The instruction and that
// alloca are handed to RecursivelyReplaceAddressSpace with legal=true;
// presumably this rewrites T's uses onto the alloca — confirm against that
// helper.
1063 for (auto T : Todo) {
1064 auto T0 = T->getOperand(0);
1065 if (auto CI = dyn_cast<BitCastInst>(T0))
1066 T0 = CI->getOperand(0);
1067 auto AI = cast<AllocaInst>(T0);
1068 llvm::Value *AIV = AI;
// With typed pointers, bitcast the alloca right after its definition when its
// pointee type differs from the pointee type of T's result.
1069#if LLVM_VERSION_MAJOR < 17
1070 if (AIV->getContext().supportsTypedPointers() &&
1071 AIV->getType()->getPointerElementType() !=
1072 T->getType()->getPointerElementType()) {
1073 IRBuilder<> B(AI->getNextNode());
1074 AIV = B.CreateBitCast(
1075 AIV, PointerType::get(
1076 T->getType()->getPointerElementType(),
1077 cast<PointerType>(AI->getType())->getAddressSpace()));
1078 }
1079#endif
1080 RecursivelyReplaceAddressSpace(T, AIV, /*legal*/ true);
1081 }
1082
// LLVM >= 22: replace each call to the module's llvm.enzyme.lifetime_start /
// llvm.enzyme.lifetime_end declarations. When operand 1 is not directly an
// alloca the call is simply dropped. Otherwise a real lifetime marker is
// emitted at the call, and the placeholder is then erased.
1083#if LLVM_VERSION_MAJOR >= 22
1084 {
1085 auto start_lifetime =
1086 NewF->getParent()->getFunction("llvm.enzyme.lifetime_start");
1087 auto end_lifetime =
1088 NewF->getParent()->getFunction("llvm.enzyme.lifetime_end");
1089
1090 SmallVector<CallInst *, 1> Todo;
1091 for (auto &BB : *NewF) {
1092 for (auto &I : BB) {
1093 if (auto CB = dyn_cast<CallInst>(&I)) {
1094 if (!CB->getCalledFunction())
1095 continue;
1096 if (CB->getCalledFunction() == start_lifetime ||
1097 CB->getCalledFunction() == end_lifetime) {
1098 Todo.push_back(CB);
1099 }
1100 }
1101 }
1102 }
1103
1104 for (auto CB : Todo) {
1105 if (!isa<AllocaInst>(CB->getArgOperand(1))) {
1106 CB->eraseFromParent();
1107 continue;
1108 }
1109 IRBuilder<> B(CB);
// NOTE(review): operand 1 is the one checked to be an alloca above, yet the
// value passed to CreateLifetimeStart/End below is operand 0 cast to
// ConstantInt (the size). If IRBuilder's lifetime helpers take the pointer as
// their argument in this LLVM version, operand 1 is likely what should be
// passed here — confirm.
1110 if (CB->getCalledFunction() == start_lifetime) {
1111 B.CreateLifetimeStart(cast<ConstantInt>(CB->getArgOperand(0)));
1112 } else {
1113 B.CreateLifetimeEnd(cast<ConstantInt>(CB->getArgOperand(0)));
1114 }
1115 CB->eraseFromParent();
1116 }
1117 }
1118#endif
1119}
1120
1121/// Calls to realloc with an appropriate implementation
1122void PreProcessCache::ReplaceReallocs(Function *NewF, bool mem2reg) {
1123 if (mem2reg) {
1124 auto PA = PromotePass().run(*NewF, FAM);
1125 FAM.invalidate(*NewF, PA);
1126#if !defined(FLANG)
1127 PA = GVNPass().run(*NewF, FAM);
1128#else
1129 PA = GVN().run(*NewF, FAM);
1130#endif
1131 FAM.invalidate(*NewF, PA);
1132 }
1133
1134 SmallVector<CallInst *, 4> ToConvert;
1135 std::map<CallInst *, Value *> reallocSizes;
1136 IntegerType *T = nullptr;
1137
1138 for (auto &BB : *NewF) {
1139 for (auto &I : BB) {
1140 if (auto CI = dyn_cast<CallInst>(&I)) {
1141 if (auto F = CI->getCalledFunction()) {
1142 if (F->getName() == "realloc") {
1143 ToConvert.push_back(CI);
1144 IRBuilder<> B(CI->getNextNode());
1145 T = cast<IntegerType>(CI->getArgOperand(1)->getType());
1146 reallocSizes[CI] = B.CreatePHI(T, 0);
1147 }
1148 }
1149 }
1150 }
1151 }
1152
1153 SmallVector<AllocaInst *, 4> memoryLocations;
1154
1155 for (auto CI : ToConvert) {
1156 assert(T);
1157 AllocaInst *AI =
1158 OldAllocationSize(CI->getArgOperand(0), CI, NewF, T, reallocSizes);
1159
1160 BasicBlock *resize =
1161 BasicBlock::Create(CI->getContext(), "resize" + CI->getName(), NewF);
1162 assert(resize->getParent() == NewF);
1163
1164 BasicBlock *splitParent = CI->getParent();
1165 BasicBlock *nextBlock = splitParent->splitBasicBlock(CI);
1166
1167 splitParent->getTerminator()->eraseFromParent();
1168 IRBuilder<> B(splitParent);
1169
1170 Value *p = CI->getArgOperand(0);
1171 Value *req = CI->getArgOperand(1);
1172 Value *old = B.CreateLoad(AI->getAllocatedType(), AI);
1173 Value *cmp = B.CreateICmpULE(req, old);
1174 // if (req < old)
1175 B.CreateCondBr(cmp, nextBlock, resize);
1176
1177 B.SetInsertPoint(resize);
1178 // size_t newsize = nextPowerOfTwo(req);
1179 // void* next = malloc(newsize);
1180 // memcpy(next, p, newsize);
1181 // free(p);
1182 // return { next, newsize };
1183
1184 Value *newsize = nextPowerOfTwo(B, req);
1185
1186 Module *M = NewF->getParent();
1187 Type *BPTy = getInt8PtrTy(NewF->getContext());
1188 auto MallocFunc =
1189 M->getOrInsertFunction("malloc", BPTy, newsize->getType());
1190 auto next = B.CreateCall(MallocFunc, newsize);
1191 B.SetInsertPoint(resize);
1192
1193 auto volatile_arg = ConstantInt::getFalse(CI->getContext());
1194
1195 Value *nargs[] = {next, p, old, volatile_arg};
1196
1197 Type *tys[] = {next->getType(), p->getType(), old->getType()};
1198
1199 auto memcpyF =
1200 getIntrinsicDeclaration(NewF->getParent(), Intrinsic::memcpy, tys);
1201
1202 auto mem = cast<CallInst>(B.CreateCall(memcpyF, nargs));
1203 mem->setCallingConv(memcpyF->getCallingConv());
1204
1205 Type *VoidTy = Type::getVoidTy(M->getContext());
1206 auto FreeFunc = M->getOrInsertFunction("free", VoidTy, BPTy);
1207 B.CreateCall(FreeFunc, p);
1208 B.SetInsertPoint(resize);
1209
1210 B.CreateBr(nextBlock);
1211
1212 // else
1213 // return { p, old }
1214 B.SetInsertPoint(&*nextBlock->begin());
1215
1216 PHINode *retPtr = B.CreatePHI(CI->getType(), 2);
1217 retPtr->addIncoming(p, splitParent);
1218 retPtr->addIncoming(next, resize);
1219 CI->replaceAllUsesWith(retPtr);
1220 std::string nam = CI->getName().str();
1221 CI->setName("");
1222 retPtr->setName(nam);
1223 Value *nextSize = B.CreateSelect(cmp, old, req);
1224 reallocSizes[CI]->replaceAllUsesWith(nextSize);
1225 cast<PHINode>(reallocSizes[CI])->eraseFromParent();
1226 reallocSizes[CI] = nextSize;
1227 }
1228
1229 for (auto CI : ToConvert) {
1230 CI->eraseFromParent();
1231 }
1232
1233 PreservedAnalyses PA;
1234 FAM.invalidate(*NewF, PA);
1235
1236 PA = PromotePass().run(*NewF, FAM);
1237 FAM.invalidate(*NewF, PA);
1238}
1239
1240Function *CreateMPIWrapper(Function *F) {
1241 std::string name = ("enzyme_wrapmpi$$" + F->getName() + "#").str();
1242 if (auto W = F->getParent()->getFunction(name))
1243 return W;
1244
1245 // MPI_Comm_rank(MPI_Comm comm, int *rank)
1246 // MPI_Comm_size(MPI_Comm comm, int *size)
1247 Type *ReturnType = Type::getInt32Ty(F->getContext());
1248 Type *types = {F->getFunctionType()->getParamType(0)}; // MPI_Comm
1249 auto FT = FunctionType::get(ReturnType, types, false);
1250 Function *W = Function::Create(FT, GlobalVariable::InternalLinkage, name,
1251 F->getParent());
1252 llvm::Attribute::AttrKind attrs[] = {
1253 Attribute::WillReturn,
1254 Attribute::MustProgress,
1255#if LLVM_VERSION_MAJOR < 16
1256 Attribute::ReadOnly,
1257#endif
1258 Attribute::Speculatable,
1259 Attribute::NoUnwind,
1260 Attribute::AlwaysInline,
1261 Attribute::NoFree,
1262 Attribute::NoSync,
1263#if LLVM_VERSION_MAJOR < 16
1264 Attribute::InaccessibleMemOnly
1265#endif
1266 };
1267 for (auto attr : attrs) {
1268 W->addFnAttr(attr);
1269 }
1270#if LLVM_VERSION_MAJOR >= 16
1271 W->setOnlyAccessesInaccessibleMemory();
1272 W->setOnlyReadsMemory();
1273#endif
1274 W->addFnAttr(Attribute::get(F->getContext(), "enzyme_inactive"));
1275 BasicBlock *entry = BasicBlock::Create(W->getContext(), "entry", W);
1276 IRBuilder<> B(entry);
1277 auto alloc = B.CreateAlloca(ReturnType);
1278 Value *args[] = {W->arg_begin(), alloc};
1279
1280 auto T = F->getFunctionType()->getParamType(1);
1281 if (!isa<PointerType>(T)) {
1282 assert(isa<IntegerType>(T));
1283 args[1] = B.CreatePtrToInt(args[1], T);
1284 }
1285 B.CreateCall(F, args);
1286 B.CreateRet(B.CreateLoad(ReturnType, alloc));
1287 return W;
1288}
1289
1290static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
1291 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(NewF);
1292 SmallVector<CallBase *, 4> Todo;
1293 SmallVector<CallBase *, 4> OMPBounds;
1294 for (auto &BB : NewF) {
1295 for (auto &I : BB) {
1296 if (auto CI = dyn_cast<CallBase>(&I)) {
1297 Function *Fn = CI->getCalledFunction();
1298 if (Fn == nullptr)
1299 continue;
1300 auto name = getFuncName(Fn);
1301 if (name == "MPI_Comm_rank" || name == "PMPI_Comm_rank" ||
1302 name == "MPI_Comm_size" || name == "PMPI_Comm_size") {
1303 Todo.push_back(CI);
1304 }
1305 if (name == "__kmpc_for_static_init_4" ||
1306 name == "__kmpc_for_static_init_4u" ||
1307 name == "__kmpc_for_static_init_8" ||
1308 name == "__kmpc_for_static_init_8u") {
1309 OMPBounds.push_back(CI);
1310 }
1311 }
1312 }
1313 }
1314 if (Todo.size() == 0 && OMPBounds.size() == 0)
1315 return;
1316 for (auto CI : Todo) {
1317 IRBuilder<> B(CI);
1318 Value *arg[] = {CI->getArgOperand(0)};
1319 SmallVector<OperandBundleDef, 2> Defs;
1320 CI->getOperandBundlesAsDefs(Defs);
1321 CallBase *res = nullptr;
1322 if (auto II = dyn_cast<InvokeInst>(CI))
1323 res = B.CreateInvoke(CreateMPIWrapper(CI->getCalledFunction()),
1324 II->getNormalDest(), II->getUnwindDest(), arg, Defs);
1325 else
1326 res = B.CreateCall(CreateMPIWrapper(CI->getCalledFunction()), arg, Defs);
1327 Value *storePointer = CI->getArgOperand(1);
1328
1329 // Comm_rank and Comm_size return Err, assume 0 is success
1330 CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0));
1331 CI->eraseFromParent();
1332
1333 while (auto Cast = dyn_cast<CastInst>(storePointer)) {
1334 storePointer = Cast->getOperand(0);
1335 if (Cast->use_empty())
1336 Cast->eraseFromParent();
1337 }
1338
1339 B.SetInsertPoint(res);
1340
1341 if (auto PT = dyn_cast<PointerType>(storePointer->getType())) {
1342 (void)PT;
1343#if LLVM_VERSION_MAJOR < 17
1344 if (PT->getContext().supportsTypedPointers()) {
1345 if (PT->getPointerElementType() != res->getType())
1346 storePointer = B.CreateBitCast(
1347 storePointer,
1348 PointerType::get(res->getType(), PT->getAddressSpace()));
1349 }
1350#endif
1351 } else {
1352 assert(isa<IntegerType>(storePointer->getType()));
1353 storePointer = B.CreateIntToPtr(storePointer, getUnqual(res->getType()));
1354 }
1355 if (isa<AllocaInst>(storePointer)) {
1356 // If this is only loaded from, immedaitely replace
1357 // Immediately replace all dominated stores.
1358 SmallVector<LoadInst *, 2> LI;
1359 bool nonload = false;
1360 for (auto &U : storePointer->uses()) {
1361 if (auto L = dyn_cast<LoadInst>(U.getUser())) {
1362 LI.push_back(L);
1363 } else
1364 nonload = true;
1365 }
1366 if (!nonload) {
1367 for (auto L : LI) {
1368 if (DT.dominates(res, L)) {
1369 L->replaceAllUsesWith(res);
1370 L->eraseFromParent();
1371 }
1372 }
1373 }
1374 }
1375 if (auto II = dyn_cast<InvokeInst>(res)) {
1376 B.SetInsertPoint(II->getNormalDest()->getFirstNonPHI());
1377 } else {
1378 B.SetInsertPoint(res->getNextNode());
1379 }
1380 B.CreateStore(res, storePointer);
1381 }
1382 for (auto Bound : OMPBounds) {
1383 for (int i = 4; i <= 6; i++) {
1384 auto AI = cast<AllocaInst>(Bound->getArgOperand(i));
1385 IRBuilder<> B(AI);
1386 auto AI2 = B.CreateAlloca(AI->getAllocatedType(), nullptr,
1387 AI->getName() + "_smpl");
1388 B.SetInsertPoint(Bound);
1389 B.CreateStore(B.CreateLoad(AI->getAllocatedType(), AI), AI2);
1390 Bound->setArgOperand(i, AI2);
1391 if (auto II = dyn_cast<InvokeInst>(Bound)) {
1392 B.SetInsertPoint(II->getNormalDest()->getFirstNonPHI());
1393 } else {
1394 B.SetInsertPoint(Bound->getNextNode());
1395 }
1396 B.CreateStore(B.CreateLoad(AI2->getAllocatedType(), AI2), AI);
1397 addCallSiteNoCapture(Bound, i);
1398 }
1399 }
1400 PreservedAnalyses PA;
1401 PA.preserve<AssumptionAnalysis>();
1402 PA.preserve<TargetLibraryAnalysis>();
1403 PA.preserve<LoopAnalysis>();
1404 PA.preserve<DominatorTreeAnalysis>();
1405 PA.preserve<PostDominatorTreeAnalysis>();
1406 FAM.invalidate(NewF, PA);
1407}
1408
1409/// Perform recursive inlinining on NewF up to the given limit
1410static void ForceRecursiveInlining(Function *NewF, size_t Limit) {
1411 std::map<const Function *, RecurType> RecurResults;
1412 for (size_t count = 0; count < Limit; count++) {
1413 for (auto &BB : *NewF) {
1414 for (auto &I : BB) {
1415 if (auto CI = dyn_cast<CallInst>(&I)) {
1416 if (CI->getCalledFunction() == nullptr)
1417 continue;
1418 if (CI->getCalledFunction()->empty())
1419 continue;
1420 if (startsWith(CI->getCalledFunction()->getName(),
1421 "_ZN3std2io5stdio6_print"))
1422 continue;
1423 if (startsWith(CI->getCalledFunction()->getName(), "_ZN4core3fmt"))
1424 continue;
1425 if (startsWith(CI->getCalledFunction()->getName(),
1426 "enzyme_wrapmpi$$"))
1427 continue;
1428 if (CI->getCalledFunction()->hasFnAttribute(
1429 Attribute::ReturnsTwice) ||
1430 CI->getCalledFunction()->hasFnAttribute(Attribute::NoInline))
1431 continue;
1432 if (IsFunctionRecursive(CI->getCalledFunction(), RecurResults)) {
1433 LLVM_DEBUG(llvm::dbgs()
1434 << "not inlining recursive "
1435 << CI->getCalledFunction()->getName() << "\n");
1436 continue;
1437 }
1438 InlineFunctionInfo IFI;
1439 InlineFunction(*CI, IFI);
1440 goto outermostContinue;
1441 }
1442 }
1443 }
1444
1445 // No functions were inlined, break
1446 break;
1447
1448 outermostContinue:;
1449 }
1450}
1451
// Put every loop of F into LoopSimplify form, then give each loop (visited in
// preorder) a fresh i64 canonical induction variable named "iv" via
// InsertNewCanonicalIV. The loop header, that IV, the second element returned
// by InsertNewCanonicalIV and a MustExitScalarEvolution for F are then passed
// to a helper whose call line is not visible in this view. Two callbacks are
// passed along with them: one that RAUWs an instruction with a value and one
// that erases an instruction. Finally, cached analyses for F are invalidated
// except those listed as preserved.
// NOTE(review): the PreservedAnalyses returned by LoopSimplifyPass is
// discarded rather than passed to FAM.invalidate; confirm no non-preserved
// analysis cached for F is relied on before the invalidate at the end.
1452void CanonicalizeLoops(Function *F, FunctionAnalysisManager &FAM) {
1453 LoopSimplifyPass().run(*F, FAM);
1454 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(*F);
1455 LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
1456 AssumptionCache &AC = FAM.getResult<AssumptionAnalysis>(*F);
1457 TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(*F);
1458 MustExitScalarEvolution SE(*F, TLI, AC, DT, LI);
// For each loop: create the canonical IV, then hand it to the redundant-IV
// helper together with the RAUW/erase callbacks.
1459 for (Loop *L : LI.getLoopsInPreorder()) {
1460 auto pair =
1461 InsertNewCanonicalIV(L, Type::getInt64Ty(F->getContext()), "iv");
1462 PHINode *CanonicalIV = pair.first;
1463 assert(CanonicalIV);
1465 L->getHeader(), CanonicalIV, pair.second, SE,
1466 [&](Instruction *I, Value *V) { I->replaceAllUsesWith(V); },
1467 [&](Instruction *I) { I->eraseFromParent(); });
1468 }
1469 PreservedAnalyses PA;
1470 PA.preserve<AssumptionAnalysis>();
1471 PA.preserve<TargetLibraryAnalysis>();
1472 PA.preserve<LoopAnalysis>();
1473 PA.preserve<DominatorTreeAnalysis>();
1474 PA.preserve<PostDominatorTreeAnalysis>();
1475 PA.preserve<TypeBasedAA>();
1476 PA.preserve<BasicAA>();
1477 PA.preserve<ScopedNoAliasAA>();
1478 FAM.invalidate(*F, PA);
1479}
1480
// For each PHI node PN in F, try to prove that PN always equals a single value
// V. If so, and V is not an instruction or dominates PN, replace all uses of
// PN with V and erase PN.
//
// The search walks backwards through PN's incoming values. It passes through
// other PHI nodes transitively (each visited at most once via `done`),
// ignores undef inputs, and collects distinct non-PHI values in `vals`. It
// stops as soon as a second distinct value is found.
// Special case: a PHI N other than PN may be popped while no concrete value
// has been found and nothing else is pending. If N then dominates PN, N itself
// is taken as the single value.
// The iterator is advanced before PN may be erased, so erasing is safe.
1481void RemoveRedundantPHI(Function *F, FunctionAnalysisManager &FAM) {
1482 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(*F);
1483 for (BasicBlock &BB : *F) {
1484 for (BasicBlock::iterator II = BB.begin(); isa<PHINode>(II);) {
1485 PHINode *PN = cast<PHINode>(II);
1486 ++II;
1487 SmallPtrSet<Value *, 2> vals;
1488 SmallPtrSet<PHINode *, 2> done;
1489 SmallVector<PHINode *, 2> todo = {PN};
1490 while (todo.size() > 0) {
1491 PHINode *N = todo.back();
1492 todo.pop_back();
1493 if (done.count(N))
1494 continue;
1495 done.insert(N);
// Shortcut: nothing concrete seen yet and N is the last pending PHI.
1496 if (vals.size() == 0 && todo.size() == 0 && PN != N &&
1497 DT.dominates(N, PN)) {
1498 vals.insert(N);
1499 break;
1500 }
1501 for (auto &v : N->incoming_values()) {
1502 if (isa<UndefValue>(v))
1503 continue;
1504 if (auto NN = dyn_cast<PHINode>(v)) {
1505 todo.push_back(NN);
1506 continue;
1507 }
1508 vals.insert(v);
1509 if (vals.size() > 1)
1510 break;
1511 }
1512 if (vals.size() > 1)
1513 break;
1514 }
// Replace only when exactly one candidate was found and it is usable at PN.
1515 if (vals.size() == 1) {
1516 auto V = *vals.begin();
1517 if (!isa<Instruction>(V) || DT.dominates(cast<Instruction>(V), PN)) {
1518 PN->replaceAllUsesWith(V);
1519 PN->eraseFromParent();
1520 }
1521 }
1522 }
1523 }
1524}
1525
// NOTE(review): the opening line of this body is not visible in this view.
// The body registers the following, in order:
//  1. Alias-analysis passes on FAM/MAM: TypeBasedAA, BasicAA, GlobalsAA (with
//     the CallGraphAnalysis it needs) and ScopedNoAliasAA, plus CFLSteensAA
//     before LLVM 16.
//  2. The proxies linking FAM with MAM and with LAM.
//  3. An AAManager that aggregates those alias analyses.
//  4. PassBuilder's default module, function and loop analyses, last —
//     presumably so the explicit registrations above take precedence over
//     the defaults. Confirm that registerPass keeps the first registration.
1527 // Explicitly chose AA passes that are stateless
1528 // and will not be invalidated
1529 FAM.registerPass([] { return TypeBasedAA(); });
1530 FAM.registerPass([] { return BasicAA(); });
1531 MAM.registerPass([] { return GlobalsAA(); });
1532 // CallGraphAnalysis required for GlobalsAA
1533 MAM.registerPass([] { return CallGraphAnalysis(); });
1534
1535 FAM.registerPass([] { return ScopedNoAliasAA(); });
1536
1537 // SCEVAA causes some breakage/segfaults
1538 // disable for now, consider enabling in future
1539 // FAM.registerPass([] { return SCEVAA(); });
1540
1541#if LLVM_VERSION_MAJOR < 16
1543 FAM.registerPass([] { return CFLSteensAA(); });
1544#endif
1545
// Cross-manager proxies; the lambdas capture the managers by reference.
1546 MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
1547 FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
1548
1549 LAM.registerPass([&] { return FunctionAnalysisManagerLoopProxy(FAM); });
1550 FAM.registerPass([&] { return LoopAnalysisManagerFunctionProxy(LAM); });
1551
// Build the AAManager that combines the alias analyses registered above.
1552 FAM.registerPass([] {
1553 auto AM = AAManager();
1554 AM.registerFunctionAnalysis<BasicAA>();
1555 AM.registerFunctionAnalysis<TypeBasedAA>();
1556 AM.registerModuleAnalysis<GlobalsAA>();
1557 AM.registerFunctionAnalysis<ScopedNoAliasAA>();
1558
1559 // broken for different reasons
1560 // AM.registerFunctionAnalysis<SCEVAA>();
1561
1562#if LLVM_VERSION_MAJOR < 16
1564 AM.registerFunctionAnalysis<CFLSteensAA>();
1565#endif
1566
1567 return AM;
1568 });
1569
1570 PassBuilder PB;
1571 PB.registerModuleAnalyses(MAM);
1572 PB.registerFunctionAnalyses(FAM);
1573 PB.registerLoopAnalyses(LAM);
1574}
1575
// Return the AAManager result for NewF, computed by (or fetched from the
// cache of) this object's FunctionAnalysisManager. The middle line of the
// signature is not visible in this view.
1576llvm::AAResults &
1578 return FAM.getResult<AAManager>(*NewF);
1579}
1580
1581void setFullWillReturn(Function *NewF) {
1582 for (auto &BB : *NewF) {
1583 for (auto &I : BB) {
1584 if (auto CI = dyn_cast<CallInst>(&I)) {
1585 CI->addFnAttr(Attribute::WillReturn);
1586 CI->addFnAttr(Attribute::MustProgress);
1587 }
1588 if (auto CI = dyn_cast<InvokeInst>(&I)) {
1589 CI->addFnAttr(Attribute::WillReturn);
1590 CI->addFnAttr(Attribute::MustProgress);
1591 }
1592 }
1593 }
1594}
1595
1596void SplitPHIs(llvm::Function &F) {
1597 SetVector<Instruction *> todo;
1598 for (auto &BB : F) {
1599 for (auto &I : BB) {
1600 if (isa<PHINode>(&I)) {
1601 todo.insert(&I);
1602 } else if (isa<SelectInst>(&I)) {
1603 todo.insert(&I);
1604 }
1605 }
1606 }
1607 while (todo.size()) {
1608 auto cur = todo.pop_back_val();
1609 IRBuilder<> B(cur);
1610 auto ST = dyn_cast<StructType>(cur->getType());
1611 if (!ST)
1612 continue;
1613 bool justExtract = true;
1614 for (auto U : cur->users()) {
1615 if (!isa<ExtractValueInst>(U)) {
1616 justExtract = false;
1617 break;
1618 }
1619 if (cast<ExtractValueInst>(U)->getIndices().size() == 0) {
1620 justExtract = false;
1621 break;
1622 }
1623 }
1624 if (!justExtract)
1625 continue;
1626
1627 SmallVector<Value *, 1> replacements;
1628 for (size_t i = 0, e = ST->getNumElements(); i < e; i++) {
1629 if (auto cur2 = dyn_cast<PHINode>(cur)) {
1630 auto nPhi =
1631 B.CreatePHI(ST->getElementType(i), cur2->getNumIncomingValues(),
1632 cur->getName() + ".extract." + std::to_string(i));
1633 for (auto &&[blk, val] :
1634 llvm::zip(cur2->blocks(), cur2->incoming_values())) {
1635 IRBuilder B2(blk->getTerminator());
1636 nPhi->addIncoming(GradientUtils::extractMeta(B2, val, i), blk);
1637 }
1638 replacements.push_back(nPhi);
1639 todo.insert(nPhi);
1640 } else {
1641 auto cur3 = cast<SelectInst>(cur);
1642 auto rep = B.CreateSelect(
1643 cur3->getCondition(),
1644 GradientUtils::extractMeta(B, cur3->getTrueValue(), i),
1645 GradientUtils::extractMeta(B, cur3->getFalseValue(), i),
1646 cur->getName() + ".extract." + std::to_string(i));
1647 replacements.push_back(rep);
1648 if (auto sel = dyn_cast<SelectInst>(rep))
1649 todo.insert(sel);
1650 }
1651 }
1652 for (auto &U : make_early_inc_range(cur->uses())) {
1653 auto user = cast<ExtractValueInst>(U.getUser());
1654 Value *rep = replacements[user->getIndices()[0]];
1655 IRBuilder<> B(user);
1656 if (user->getIndices().size() > 1)
1657 rep = B.CreateExtractValue(rep, user->getIndices().slice(1));
1658 assert(rep->getType() == user->getType());
1659 user->replaceAllUsesWith(rep);
1660 user->eraseFromParent();
1661 }
1662 cur->eraseFromParent();
1663 }
1664}
1665
1666// returns if newly changed, subject to the pending calls
1667bool DetectPointerArgOfFn(llvm::Function &F,
1668 SmallPtrSetImpl<Function *> &calls_todo) {
1669 if (F.empty())
1670 return false;
1671 bool changed = false;
1672 for (auto &arg : F.args()) {
1673 if (!arg.getType()->isPointerTy())
1674 continue;
1675 // Store list of values we need to check
1676 std::deque<Value *> todo = {&arg};
1677 SmallPtrSet<Value *, 1> seen;
1678
1679 bool captured = false;
1680 bool read = false;
1681 bool written = false;
1682
1683 AttributeList Attrs = arg.getParent()->getAttributes();
1684
1685 // We have already hit the max state.
1686
1687 if (Attrs.hasParamAttr(arg.getArgNo(), Attribute::ReadNone) &&
1688 arg.hasNoCaptureAttr())
1689 continue;
1690
1691 while (!todo.empty()) {
1692 auto cur = todo.back();
1693 todo.pop_back();
1694 if (seen.contains(cur))
1695 continue;
1696 seen.insert(cur);
1697 for (auto &U : cur->uses()) {
1698 auto I = cast<Instruction>(U.getUser());
1699 if (isPointerArithmeticInst(I)) {
1700 todo.push_back(I);
1701 continue;
1702 }
1703 if (isa<LoadInst>(I)) {
1704 read = true;
1705 continue;
1706 }
1707 if (auto SI = dyn_cast<StoreInst>(I)) {
1708 if (SI->getValueOperand() == cur) {
1709 captured = true;
1710 read = true;
1711 written = true;
1712 break;
1713 }
1714 if (SI->getPointerOperand() == cur) {
1715 written = true;
1716 continue;
1717 }
1718 }
1719 if (auto MSI = dyn_cast<MemSetInst>(I)) {
1720 if (MSI->getRawDest() == cur) {
1721 written = true;
1722 }
1723 continue;
1724 }
1725 if (auto MTI = dyn_cast<MemTransferInst>(I)) {
1726 if (MTI->getRawDest() == cur) {
1727 written = true;
1728 }
1729 if (MTI->getRawSource() == cur) {
1730 read = true;
1731 }
1732 continue;
1733 }
1734 if (auto CB = dyn_cast<CallBase>(I)) {
1735
1736 if (CB->getCalledOperand() == cur) {
1737 captured = true;
1738 read = true;
1739 written = true;
1740 break;
1741 }
1742
1743 auto F2 = dyn_cast<Function>(CB->getCalledOperand());
1744
1745 if (F2 == &F && U.getOperandNo() == arg.getArgNo()) {
1746 continue;
1747 }
1748
1749 auto name = getFuncNameFromCall(CB);
1750
1751 if (name == "julia.write_barrier") {
1752 continue;
1753 }
1754
1755 // Used as operand bundle
1756 if (U.getOperandNo() >= CB->arg_size()) {
1757 captured = true;
1758 read = true;
1759 written = true;
1760 break;
1761 }
1762
1763 if (!isNoCapture(CB, U.getOperandNo()) && !arg.hasNoCaptureAttr()) {
1764 captured = true;
1765 if (F2)
1766 calls_todo.insert(F2);
1767 }
1768
1769 if (!isReadOnly(CB, U.getOperandNo()) &&
1770 !Attrs.hasParamAttr(arg.getArgNo(), Attribute::ReadNone) &&
1771 !Attrs.hasParamAttr(arg.getArgNo(), Attribute::ReadOnly)) {
1772 written = true;
1773 if (F2)
1774 calls_todo.insert(F2);
1775 }
1776
1777 if (!isWriteOnly(CB, U.getOperandNo()) &&
1778 !Attrs.hasParamAttr(arg.getArgNo(), Attribute::ReadNone) &&
1779 !Attrs.hasParamAttr(arg.getArgNo(), Attribute::WriteOnly)) {
1780 read = true;
1781 if (F2)
1782 calls_todo.insert(F2);
1783 }
1784 continue;
1785 }
1786
1787 captured = true;
1788 read = true;
1789 written = true;
1790 break;
1791 }
1792 }
1793
1794 if (!captured && !arg.hasNoCaptureAttr()) {
1795 addFunctionNoCapture(arg.getParent(), arg.getArgNo());
1796 changed = true;
1797 }
1798
1799 if ((!read && !written) ||
1800 Attrs.hasParamAttr(arg.getArgNo(), Attribute::ReadNone)) {
1801 if (!Attrs.hasParamAttr(arg.getArgNo(), Attribute::ReadNone)) {
1802 if (Attrs.hasParamAttr(arg.getArgNo(), Attribute::ReadOnly)) {
1803 arg.removeAttr(Attribute::ReadOnly);
1804 }
1805 if (Attrs.hasParamAttr(arg.getArgNo(), Attribute::WriteOnly)) {
1806 arg.removeAttr(Attribute::WriteOnly);
1807 }
1808 arg.addAttr(Attribute::ReadNone);
1809 changed = true;
1810 }
1811 } else if (!written) {
1812 if (!Attrs.hasParamAttr(arg.getArgNo(), Attribute::ReadOnly)) {
1813 arg.addAttr(Attribute::ReadOnly);
1814 changed = true;
1815 }
1816 } else if (!read) {
1817 if (!Attrs.hasParamAttr(arg.getArgNo(), Attribute::WriteOnly)) {
1818 arg.addAttr(Attribute::WriteOnly);
1819 changed = true;
1820 }
1821 }
1822 }
1823 return changed;
1824}
1825
1826// returns if newly changed, subject to the pending calls
1827bool DetectNoUnwindOfFn(llvm::Function &F,
1828 SmallPtrSetImpl<Function *> &calls_todo) {
1829 if (F.empty())
1830 return false;
1831 if (F.doesNotThrow())
1832 return false;
1833
1834 bool mayThrow = false;
1835
1836 for (auto &BB : F) {
1837 for (auto &I : BB) {
1838#if LLVM_VERSION_MAJOR >= 17
1839 if (!I.mayThrow(/*IncludePhaseOneUnwind*/ true)) {
1840 continue;
1841 }
1842#else
1843 if (!I.mayThrow()) {
1844 continue;
1845 }
1846#endif
1847 if (auto CB = dyn_cast<CallBase>(&I)) {
1848 if (auto F2 = CB->getCalledFunction()) {
1849 if (F2 == &F)
1850 continue;
1851 if (F2->doesNotThrow())
1852 continue;
1853 calls_todo.insert(F2);
1854 }
1855 }
1856 mayThrow = true;
1857 break;
1858 }
1859 }
1860 if (mayThrow)
1861 return false;
1862 F.setDoesNotThrow();
1863 return true;
1864}
1865
1866// returns if newly legal, subject to the pending calls
// Decide whether F can be tagged "enzyme_ReadOnlyOrThrow" or
// "enzyme_LocalReadOnlyOrThrow": every memory-writing instruction in a block
// not in getGuaranteedUnreachable(&F) must be excusable.
//
// Writes are excused in these cases:
//  * The instruction carries either ReadOnlyOrThrow metadata tag.
//  * A memcpy/memmove, memset, "zeroType" call or store whose destination's
//    base object is one of:
//      - an alloca;
//      - an allocation call (this sets `local` when the allocation is
//        captured and `local` was not already set);
//      - an argument with sret, "enzymejl_returnRoots" or
//        "enzymejl_sret_union_bytes" (this sets `local`).
//  * A call that is local-readonly-or-throw or an allocation.
//  * A call to a debug function or julia.write_barrier.
//  * A same-calling-convention call to F itself or to a function already
//    readonly-or-throw.
//  * Any load.
//  * With EnzymeJuliaAddrLoad, a fence directly before or after a call to
//    julia.safepoint.
//
// Calls to other defined functions with a matching calling convention are
// deferred by adding the callee to `calls_todo`. Any other write makes the
// function return false.
//
// The attribute is only added when `calls_todo` is empty at the end. That set
// may also hold entries the caller inserted earlier. Returns true whenever no
// disqualifying write was found, even if the attribute is still pending on
// `calls_todo`.
1867bool DetectReadonlyOrThrowFn(llvm::Function &F,
1868 SmallPtrSetImpl<Function *> &calls_todo,
1869 llvm::TargetLibraryInfo &TLI, bool &local) {
1870 if (isReadOnlyOrThrow(&F))
1871 return false;
1872 if (F.empty())
1873 return false;
1874 const auto unreachable = getGuaranteedUnreachable(&F);
1875 for (auto &BB : F) {
1876 if (unreachable.find(&BB) != unreachable.end()) {
1877 continue;
1878 }
1879 for (auto &I : BB) {
1880 if (!I.mayWriteToMemory())
1881 continue;
1882 if (hasMetadata(&I, "enzyme_ReadOnlyOrThrow"))
1883 continue;
1884 if (hasMetadata(&I, "enzyme_LocalReadOnlyOrThrow"))
1885 continue;
1886
// Memory transfers: classify by the base object of the destination
// (operand 0). A destination that is not excused falls through to the
// generic CallBase handling below.
1887 if (auto MTI = dyn_cast<MemTransferInst>(&I)) {
1888 auto Obj = getBaseObject(MTI->getOperand(0));
1889 // Storing into local memory is fine since it definitionally will not be
1890 // seen outside the function. Note, even if one stored into x =
1891 // malloc(..), and stored x into a global/arg pointer, that second store
1892 // would trigger not readonly.
1893 if (isa<AllocaInst>(Obj))
1894 continue;
1895 if (isAllocationCall(Obj, TLI)) {
1896 if (local)
1897 continue;
1898 if (notCaptured(Obj))
1899 continue;
1900 local = true;
1901 continue;
1902 }
1903 if (auto arg = dyn_cast<Argument>(Obj)) {
1904 if (arg->hasStructRetAttr() ||
1905 arg->getParent()
1906 ->getAttribute(arg->getArgNo() + AttributeList::FirstArgIndex,
1907 "enzymejl_returnRoots")
1908 .isValid() ||
1909 arg->getParent()
1910 ->getAttribute(arg->getArgNo() + AttributeList::FirstArgIndex,
1911 "enzymejl_sret_union_bytes")
1912 .isValid()) {
1913 local = true;
1914 continue;
1915 }
1916 }
1917 }
// memset: same destination classification as above.
1918 if (auto MSI = dyn_cast<MemSetInst>(&I)) {
1919 auto Obj = getBaseObject(MSI->getOperand(0));
1920 // Storing into local memory is fine since it definitionally will not be
1921 // seen outside the function. Note, even if one stored into x =
1922 // malloc(..), and stored x into a global/arg pointer, that second store
1923 // would trigger not readonly.
1924 if (isa<AllocaInst>(Obj))
1925 continue;
1926 if (isAllocationCall(Obj, TLI)) {
1927 if (local)
1928 continue;
1929 if (notCaptured(Obj))
1930 continue;
1931 local = true;
1932 continue;
1933 }
1934 if (auto arg = dyn_cast<Argument>(Obj)) {
1935 if (arg->hasStructRetAttr() ||
1936 arg->getParent()
1937 ->getAttribute(arg->getArgNo() + AttributeList::FirstArgIndex,
1938 "enzymejl_returnRoots")
1939 .isValid() ||
1940 arg->getParent()
1941 ->getAttribute(arg->getArgNo() + AttributeList::FirstArgIndex,
1942 "enzymejl_sret_union_bytes")
1943 .isValid()) {
1944 local = true;
1945 continue;
1946 }
1947 }
1948 }
1949
// Calls (including intrinsics not excused above).
1950 if (auto CI = dyn_cast<CallBase>(&I)) {
1951 if (isLocalReadOnlyOrThrow(CI)) {
1952 continue;
1953 }
1954 if (isAllocationCall(CI, TLI)) {
1955 continue;
1956 }
1957 if (getFuncNameFromCall(CI) == "zeroType") {
1958 auto Obj = getBaseObject(CI->getArgOperand(0));
1959 // Storing into local memory is fine since it definitionally will not
1960 // be seen outside the function. Note, even if one stored into x =
1961 // malloc(..), and stored x into a global/arg pointer, that second
1962 // store would trigger not readonly.
1963 if (isa<AllocaInst>(Obj))
1964 continue;
1965 if (isAllocationCall(Obj, TLI)) {
1966 if (local)
1967 continue;
1968 if (notCaptured(Obj))
1969 continue;
1970 local = true;
1971 continue;
1972 }
1973 if (auto arg = dyn_cast<Argument>(Obj)) {
1974 if (arg->hasStructRetAttr() ||
1975 arg->getParent()
1976 ->getAttribute(arg->getArgNo() +
1977 AttributeList::FirstArgIndex,
1978 "enzymejl_returnRoots")
1979 .isValid() ||
1980 arg->getParent()
1981 ->getAttribute(arg->getArgNo() +
1982 AttributeList::FirstArgIndex,
1983 "enzymejl_sret_union_bytes")
1984 .isValid()) {
1985 local = true;
1986 continue;
1987 }
1988 }
1989 }
// Direct callees: a same-calling-convention call to a defined function that
// is not yet known readonly-or-throw is deferred via calls_todo.
1990 if (auto F2 = CI->getCalledFunction()) {
1991 if (isDebugFunction(F2))
1992 continue;
1993 if (F2->getName() == "julia.write_barrier")
1994 continue;
1995 if (F2->getCallingConv() == CI->getCallingConv()) {
1996 if (F2 == &F)
1997 continue;
1998 if (isReadOnlyOrThrow(F2))
1999 continue;
2000 if (!F2->empty()) {
// NOTE(review): the line opening the warning call below is not visible in
// this view.
2001 if (EnzymePrintPerf) {
2003 "WritingInstruction", I, "Instruction could write forcing ",
2004 F.getName(),
2005 " to not be marked readonly_or_throw per sub-call of ",
2006 F2->getName());
2007 }
2008 calls_todo.insert(F2);
2009 continue;
2010 }
2011 }
2012 }
2013 }
// Plain stores: same destination classification, on the pointer operand.
2014 if (auto SI = dyn_cast<StoreInst>(&I)) {
2015 auto Obj = getBaseObject(SI->getPointerOperand());
2016 // Storing into local memory is fine since it definitionally will not be
2017 // seen outside the function. Note, even if one stored into x =
2018 // malloc(..), and stored x into a global/arg pointer, that second store
2019 // would trigger not readonly.
2020 if (isa<AllocaInst>(Obj))
2021 continue;
2022 if (isAllocationCall(Obj, TLI)) {
2023 if (local)
2024 continue;
2025 if (notCaptured(Obj))
2026 continue;
2027 local = true;
2028 continue;
2029 }
2030 if (auto arg = dyn_cast<Argument>(Obj)) {
2031 if (arg->hasStructRetAttr() ||
2032 arg->getParent()
2033 ->getAttribute(arg->getArgNo() + AttributeList::FirstArgIndex,
2034 "enzymejl_returnRoots")
2035 .isValid() ||
2036 arg->getParent()
2037 ->getAttribute(arg->getArgNo() + AttributeList::FirstArgIndex,
2038 "enzymejl_sret_union_bytes")
2039 .isValid()) {
2040 local = true;
2041 continue;
2042 }
2043 }
2044 }
2045 // ignore atomic load impacts
2046 if (isa<LoadInst>(&I))
2047 continue;
// A fence adjacent (before or after) to a julia.safepoint call is excused
// when EnzymeJuliaAddrLoad is set. Note the inner `F` shadows the parameter.
2048 if (EnzymeJuliaAddrLoad && isa<FenceInst>(&I)) {
2049 if (auto prev = dyn_cast_or_null<CallBase>(I.getPrevNode())) {
2050 if (auto F = prev->getCalledFunction())
2051 if (F->getName() == "julia.safepoint")
2052 continue;
2053 }
2054 if (auto prev = dyn_cast_or_null<CallBase>(I.getNextNode())) {
2055 if (auto F = prev->getCalledFunction())
2056 if (F->getName() == "julia.safepoint")
2057 continue;
2058 }
2059 }
2060
// Anything reaching here is a write that disqualifies F.
2061 if (EnzymePrintPerf) {
2062 EmitWarning("WritingInstruction", I, "Instruction could write forcing ",
2063 F.getName(), " to not be marked readonly_or_throw", I);
2064 }
2065 return false;
2066 }
2067 }
2068
// Commit the attribute only if nothing is pending on other callees.
2069 if (calls_todo.size() == 0) {
2070 if (local) {
2071 F.addFnAttr("enzyme_LocalReadOnlyOrThrow");
2072 } else {
2073 F.addFnAttr("enzyme_ReadOnlyOrThrow");
2074 }
2075 }
2076 return true;
2077}
2078
2079bool DetectReadonlyOrThrow(Module &M) {
2080
2081 bool changed = false;
2082
2083 {
2084 // Set of functions newly deduced readonly/nocapture/etc by this pass
2085 SmallVector<llvm::Function *> todo;
2086
2087 // Map of functions which could be readonly if all functions in the set are
2088 // marked readonly
2089 DenseMap<llvm::Function *, SmallPtrSet<Function *, 1>> todo_map;
2090
2091 for (Function &F : M) {
2092 SmallPtrSet<Function *, 1> calls_todo;
2093 if (DetectNoUnwindOfFn(F, calls_todo)) {
2094 changed = true;
2095 for (auto F2 : todo_map[&F]) {
2096 todo.push_back(F2);
2097 }
2098 todo_map.erase(&F);
2099 }
2100 for (auto tocheck : calls_todo) {
2101 todo_map[tocheck].insert(&F);
2102 todo.push_back(tocheck);
2103 }
2104 }
2105
2106 while (!todo.empty()) {
2107 auto cur = todo.pop_back_val();
2108
2109 SmallPtrSet<Function *, 1> calls_todo;
2110
2111 if (!DetectNoUnwindOfFn(*cur, calls_todo))
2112 continue;
2113
2114 for (auto F2 : todo_map[cur]) {
2115 todo.push_back(F2);
2116 }
2117
2118 todo_map.erase(cur);
2119
2120 for (auto tocheck : calls_todo) {
2121 todo_map[tocheck].insert(cur);
2122 todo.push_back(tocheck);
2123 }
2124
2125 changed = true;
2126 }
2127 }
2128
2129 {
2130 // Set of functions newly deduced readonly/nocapture/etc by this pass
2131 SmallVector<llvm::Function *> todo;
2132
2133 // Map of functions which could be readonly if all functions in the set are
2134 // marked readonly
2135 DenseMap<llvm::Function *, SmallPtrSet<Function *, 1>> todo_map;
2136
2137 for (Function &F : M) {
2138 SmallPtrSet<Function *, 1> calls_todo;
2139 if (DetectPointerArgOfFn(F, calls_todo)) {
2140 changed = true;
2141 for (auto F2 : todo_map[&F]) {
2142 todo.push_back(F2);
2143 }
2144 todo_map.erase(&F);
2145 }
2146 for (auto tocheck : calls_todo) {
2147 todo_map[tocheck].insert(&F);
2148 todo.push_back(tocheck);
2149 }
2150 }
2151
2152 while (!todo.empty()) {
2153 auto cur = todo.pop_back_val();
2154
2155 SmallPtrSet<Function *, 1> calls_todo;
2156
2157 if (!DetectPointerArgOfFn(*cur, calls_todo))
2158 continue;
2159
2160 for (auto F2 : todo_map[cur]) {
2161 todo.push_back(F2);
2162 }
2163
2164 todo_map.erase(cur);
2165
2166 for (auto tocheck : calls_todo) {
2167 todo_map[tocheck].insert(cur);
2168 todo.push_back(tocheck);
2169 }
2170
2171 changed = true;
2172 }
2173 }
2174
2175 PassBuilder PB;
2176 LoopAnalysisManager LAM;
2177 FunctionAnalysisManager FAM;
2178 CGSCCAnalysisManager CGAM;
2179 ModuleAnalysisManager MAM;
2180 PB.registerModuleAnalyses(MAM);
2181 PB.registerFunctionAnalyses(FAM);
2182 PB.registerLoopAnalyses(LAM);
2183 PB.registerCGSCCAnalyses(CGAM);
2184 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
2185
2186 // Set of functions newly deduced readonlyorthrow by this pass
2187 SmallVector<llvm::Function *> todo;
2188
2189 // Map of functions which could be readonly if all functions in the set are
2190 // marked readonly
2191 DenseMap<llvm::Function *, SmallPtrSet<Function *, 1>> todo_map;
2192
2193 // Map from a function `f` to all the functions that have `f` as a
2194 // prerequisite for being readonly. Inverse of `todo_map`
2195 DenseMap<llvm::Function *, SmallPtrSet<Function *, 1>> inverse_todo_map;
2196
2197 SmallPtrSet<Function *, 1> LocalReadOnlyFunctions;
2198
2199 for (Function &F : M) {
2200 SmallPtrSet<Function *, 1> calls_todo;
2201 auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
2202 bool local = false;
2203 if (DetectReadonlyOrThrowFn(F, calls_todo, TLI, local)) {
2204 if (local)
2205 LocalReadOnlyFunctions.insert(&F);
2206 if (calls_todo.size() == 0) {
2207 changed = true;
2208 todo.push_back(&F);
2209 } else {
2210 for (auto F2 : calls_todo) {
2211 inverse_todo_map[F2].insert(&F);
2212 }
2213 }
2214 todo_map[&F] = std::move(calls_todo);
2215 }
2216 }
2217
2218 while (todo.size()) {
2219 auto cur = todo.pop_back_val();
2220 auto found = inverse_todo_map.find(cur);
2221
2222 // Nobody needs cur as a prerequisite
2223 if (found == inverse_todo_map.end()) {
2224 continue;
2225 }
2226 for (auto F2 : found->second) {
2227 auto found2 = todo_map.find(F2);
2228 assert(found2 != todo_map.end());
2229 auto &fwd_set = found2->second;
2230 fwd_set.erase(cur);
2231 if (fwd_set.size() == 0) {
2232 if (LocalReadOnlyFunctions.contains(F2)) {
2233 F2->addFnAttr("enzyme_LocalReadOnlyOrThrow");
2234 } else {
2235 F2->addFnAttr("enzyme_ReadOnlyOrThrow");
2236 }
2237 todo.push_back(F2);
2238 todo_map.erase(F2);
2239 }
2240 }
2241
2242 inverse_todo_map.erase(found);
2243 }
2244 return changed;
2245}
2246
2248 DerivativeMode mode) {
2249
2250 TimeTraceScope timeScope("preprocessForClone", F->getName());
2251
2256
2257 // If we've already processed this, return the previous version
2258 // and derive aliasing information
2259 if (cache.find(std::make_pair(F, mode)) != cache.end()) {
2260 Function *NewF = cache[std::make_pair(F, mode)];
2261 return NewF;
2262 }
2263
2264 Function *NewF =
2265 Function::Create(F->getFunctionType(), F->getLinkage(),
2266 "preprocess_" + F->getName(), F->getParent());
2267
2268 ValueToValueMapTy VMap;
2269 for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) {
2270 VMap[i] = j;
2271 j->setName(i->getName());
2272 if (EnzymeNoAlias && j->getType()->isPointerTy()) {
2273 j->addAttr(Attribute::NoAlias);
2274 }
2275 ++i;
2276 ++j;
2277 }
2278
2279 SmallVector<ReturnInst *, 4> Returns;
2280
2281 if (!F->empty()) {
2282 CloneFunctionInto(
2283 NewF, F, VMap,
2284 /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly,
2285 Returns, "", nullptr);
2286 }
2287 CloneOrigin[NewF] = F;
2288 NewF->setAttributes(F->getAttributes());
2289 if (EnzymeNoAlias)
2290 for (auto j = NewF->arg_begin(); j != NewF->arg_end(); j++) {
2291 if (j->getType()->isPointerTy()) {
2292 j->addAttr(Attribute::NoAlias);
2293 }
2294 }
2295 NewF->addFnAttr(Attribute::WillReturn);
2296 NewF->addFnAttr(Attribute::MustProgress);
2297 setFullWillReturn(NewF);
2298
2299 if (EnzymePreopt) {
2300 if (EnzymeInline) {
2302 setFullWillReturn(NewF);
2303 PreservedAnalyses PA;
2304 FAM.invalidate(*NewF, PA);
2305
2306 OptimizationLevel Level = OptimizationLevel::O0;
2307
2308 switch (EnzymePostInlineOpt) {
2309 default:
2310 case 0:
2311 Level = OptimizationLevel::O0;
2312 break;
2313 case 1:
2314 Level = OptimizationLevel::O1;
2315 break;
2316 case 2:
2317 Level = OptimizationLevel::O2;
2318 break;
2319 case 3:
2320 Level = OptimizationLevel::O3;
2321 break;
2322 }
2323 if (Level != OptimizationLevel::O0) {
2324 PassBuilder PB;
2325 FunctionPassManager FPM = PB.buildFunctionSimplificationPipeline(
2326 Level, ThinOrFullLTOPhase::None);
2327 PA = FPM.run(*NewF, FAM);
2328 FAM.invalidate(*NewF, PA);
2329 }
2330 }
2331 }
2332
2333 {
2334 SmallVector<CallInst *, 4> ItersToErase;
2335 for (auto &BB : *NewF) {
2336 for (auto &I : BB) {
2337
2338 if (auto CI = dyn_cast<CallInst>(&I)) {
2339
2340 Function *called = CI->getCalledFunction();
2341 if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
2342 if (castinst->isCast()) {
2343 if (auto fn = dyn_cast<Function>(castinst->getOperand(0)))
2344 called = fn;
2345 }
2346 }
2347
2348 if (called && called->getName().contains("__enzyme_iter")) {
2349 ItersToErase.push_back(CI);
2350 }
2351 }
2352 }
2353 }
2354 for (auto Call : ItersToErase) {
2355 IRBuilder<> B(Call);
2356 Call->setArgOperand(
2357 0, B.CreateAdd(Call->getArgOperand(0), Call->getArgOperand(1)));
2358 }
2359 }
2360
2361 // Assume allocations do not return null
2362 {
2363 TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(*F);
2364 SmallVector<Instruction *, 4> CmpsToErase;
2365 SmallVector<BasicBlock *, 4> BranchesToErase;
2366 for (auto &BB : *NewF) {
2367 for (auto &I : BB) {
2368 if (auto IC = dyn_cast<ICmpInst>(&I)) {
2369 if (!IC->isEquality())
2370 continue;
2371 for (int i = 0; i < 2; i++) {
2372 if (isa<ConstantPointerNull>(IC->getOperand(1 - i)))
2373 if (isAllocationCall(IC->getOperand(i), TLI)) {
2374 for (auto U : IC->users()) {
2375 if (auto BI = dyn_cast<BranchInst>(U))
2376 BranchesToErase.push_back(BI->getParent());
2377 }
2378 IC->replaceAllUsesWith(
2379 IC->getPredicate() == ICmpInst::ICMP_NE
2380 ? ConstantInt::getTrue(I.getContext())
2381 : ConstantInt::getFalse(I.getContext()));
2382 CmpsToErase.push_back(&I);
2383 break;
2384 }
2385 }
2386 }
2387 }
2388 }
2389 for (auto I : CmpsToErase)
2390 I->eraseFromParent();
2391 for (auto BE : BranchesToErase)
2392 ConstantFoldTerminator(BE);
2393 }
2394
2395 SimplifyMPIQueries(*NewF, FAM);
2396 {
2397 auto PA = PromotePass().run(*NewF, FAM);
2398 FAM.invalidate(*NewF, PA);
2399 }
2400
2401 if (EnzymeLowerGlobals) {
2402 SmallVector<CallInst *, 4> Calls;
2403 SmallVector<ReturnInst *, 4> Returns;
2404 for (BasicBlock &BB : *NewF) {
2405 for (Instruction &I : BB) {
2406 if (auto CI = dyn_cast<CallInst>(&I)) {
2407 Calls.push_back(CI);
2408 }
2409 if (auto RI = dyn_cast<ReturnInst>(&I)) {
2410 Returns.push_back(RI);
2411 }
2412 }
2413 }
2414
2415 // TODO consider using TBAA and globals as well
2416 // instead of just BasicAA
2417 AAResults AA2(FAM.getResult<TargetLibraryAnalysis>(*NewF));
2418 AA2.addAAResult(FAM.getResult<BasicAA>(*NewF));
2419 AA2.addAAResult(FAM.getResult<TypeBasedAA>(*NewF));
2420 AA2.addAAResult(FAM.getResult<ScopedNoAliasAA>(*NewF));
2421
2422 for (auto &g : NewF->getParent()->globals()) {
2423 bool inF = false;
2424 {
2425 std::set<Constant *> seen;
2426 std::deque<Constant *> todo = {(Constant *)&g};
2427 while (todo.size()) {
2428 auto GV = todo.front();
2429 todo.pop_front();
2430 if (!seen.insert(GV).second)
2431 continue;
2432 for (auto u : GV->users()) {
2433 if (auto C = dyn_cast<Constant>(u)) {
2434 todo.push_back(C);
2435 } else if (auto I = dyn_cast<Instruction>(u)) {
2436 if (I->getParent()->getParent() == NewF) {
2437 inF = true;
2438 goto doneF;
2439 }
2440 }
2441 }
2442 }
2443 }
2444 doneF:;
2445 if (inF) {
2446 bool activeCall = false;
2447 bool hasWrite = false;
2448 MemoryLocation Loc =
2449 MemoryLocation(&g, LocationSize::beforeOrAfterPointer());
2450
2451 for (CallInst *CI : Calls) {
2452 if (isa<IntrinsicInst>(CI))
2453 continue;
2454 Function *F = CI->getCalledFunction();
2455 if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
2456 if (castinst->isCast())
2457 if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
2458 F = fn;
2459 }
2460 }
2461 if (F && isMemFreeLibMFunction(F->getName())) {
2462 continue;
2463 }
2464 if (F && F->getName().contains("__enzyme_integer")) {
2465 continue;
2466 }
2467 if (F && F->getName().contains("__enzyme_pointer")) {
2468 continue;
2469 }
2470 if (F && F->getName().contains("__enzyme_float")) {
2471 continue;
2472 }
2473 if (F && F->getName().contains("__enzyme_double")) {
2474 continue;
2475 }
2476 if (F && (startsWith(F->getName(), "f90io") ||
2477 F->getName() == "ftnio_fmt_write64" ||
2478 F->getName() == "__mth_i_ipowi" ||
2479 F->getName() == "f90_pausea")) {
2480 continue;
2481 }
2482 if (llvm::isModOrRefSet(AA2.getModRefInfo(CI, Loc))) {
2483 llvm::errs() << " failed to inline global: " << g << " due to "
2484 << *CI << "\n";
2485 activeCall = true;
2486 break;
2487 }
2488 }
2489
2490 if (!activeCall) {
2491 std::set<Value *> seen;
2492 std::deque<Value *> todo = {(Value *)&g};
2493 while (todo.size()) {
2494 auto GV = todo.front();
2495 todo.pop_front();
2496 if (!seen.insert(GV).second)
2497 continue;
2498 for (auto u : GV->users()) {
2499 if (isa<Constant>(u) || isa<GetElementPtrInst>(u) ||
2500 isa<CastInst>(u) || isa<LoadInst>(u)) {
2501 todo.push_back(u);
2502 continue;
2503 }
2504
2505 if (auto II = dyn_cast<IntrinsicInst>(u)) {
2506 if (isIntelSubscriptIntrinsic(*II)) {
2507 todo.push_back(u);
2508 continue;
2509 }
2510 }
2511
2512 if (auto CI = dyn_cast<CallInst>(u)) {
2513 Function *F = CI->getCalledFunction();
2514 if (auto castinst =
2515 dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
2516 if (castinst->isCast())
2517 if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
2518 F = fn;
2519 }
2520 }
2521 if (F && isMemFreeLibMFunction(F->getName())) {
2522 continue;
2523 }
2524 if (F && F->getName().contains("__enzyme_integer")) {
2525 continue;
2526 }
2527 if (F && F->getName().contains("__enzyme_pointer")) {
2528 continue;
2529 }
2530 if (F && F->getName().contains("__enzyme_float")) {
2531 continue;
2532 }
2533 if (F && F->getName().contains("__enzyme_double")) {
2534 continue;
2535 }
2536 if (F && (startsWith(F->getName(), "f90io") ||
2537 F->getName() == "ftnio_fmt_write64" ||
2538 F->getName() == "__mth_i_ipowi" ||
2539 F->getName() == "f90_pausea")) {
2540 continue;
2541 }
2542
2543 if (couldFunctionArgumentCapture(CI, GV)) {
2544 hasWrite = true;
2545 goto endCheck;
2546 }
2547
2548 if (llvm::isModSet(AA2.getModRefInfo(CI, Loc))) {
2549 hasWrite = true;
2550 goto endCheck;
2551 }
2552 }
2553
2554 else if (auto I = dyn_cast<Instruction>(u)) {
2555 if (llvm::isModSet(AA2.getModRefInfo(I, Loc))) {
2556 hasWrite = true;
2557 goto endCheck;
2558 }
2559 }
2560 }
2561 }
2562 }
2563
2564 endCheck:;
2565 if (!activeCall && hasWrite) {
2566 IRBuilder<> bb(&NewF->getEntryBlock(), NewF->getEntryBlock().begin());
2567 AllocaInst *antialloca = bb.CreateAlloca(
2568 g.getValueType(), g.getType()->getPointerAddressSpace(), nullptr,
2569 g.getName() + "_local");
2570
2571 if (g.getAlignment()) {
2572 antialloca->setAlignment(Align(g.getAlignment()));
2573 }
2574
2575 std::map<Constant *, Value *> remap;
2576 remap[&g] = antialloca;
2577
2578 std::deque<Constant *> todo = {&g};
2579 while (todo.size()) {
2580 auto GV = todo.front();
2581 todo.pop_front();
2582 if (&g != GV && remap.find(GV) != remap.end())
2583 continue;
2584 Value *replaced = nullptr;
2585 if (remap.find(GV) != remap.end()) {
2586 replaced = remap[GV];
2587 } else if (auto CE = dyn_cast<ConstantExpr>(GV)) {
2588 auto I = CE->getAsInstruction();
2589 bb.Insert(I);
2590 assert(isa<Constant>(I->getOperand(0)));
2591 assert(remap[cast<Constant>(I->getOperand(0))]);
2592 I->setOperand(0, remap[cast<Constant>(I->getOperand(0))]);
2593 replaced = remap[GV] = I;
2594 }
2595 assert(replaced && "unhandled constantexpr");
2596
2597 SmallVector<std::pair<Instruction *, size_t>, 4> uses;
2598 for (Use &U : GV->uses()) {
2599 if (auto I = dyn_cast<Instruction>(U.getUser())) {
2600 if (I->getParent()->getParent() == NewF) {
2601 uses.emplace_back(I, U.getOperandNo());
2602 }
2603 }
2604 if (auto C = dyn_cast<Constant>(U.getUser())) {
2605 assert(C != &g);
2606 todo.push_back(C);
2607 }
2608 }
2609 for (auto &U : uses) {
2610 U.first->setOperand(U.second, replaced);
2611 }
2612 }
2613
2614 Value *args[] = {
2615 bb.CreateBitCast(antialloca, getInt8PtrTy(g.getContext())),
2616 bb.CreateBitCast(&g, getInt8PtrTy(g.getContext())),
2617 ConstantInt::get(
2618 Type::getInt64Ty(g.getContext()),
2619 g.getParent()->getDataLayout().getTypeAllocSizeInBits(
2620 g.getValueType()) /
2621 8),
2622 ConstantInt::getFalse(g.getContext())};
2623
2624 Type *tys[] = {args[0]->getType(), args[1]->getType(),
2625 args[2]->getType()};
2626 auto intr =
2627 getIntrinsicDeclaration(g.getParent(), Intrinsic::memcpy, tys);
2628 {
2629
2630 auto cal = bb.CreateCall(intr, args);
2631 if (g.getAlignment()) {
2632 cal->addParamAttr(
2633 0, Attribute::getWithAlignment(g.getContext(),
2634 Align(g.getAlignment())));
2635 cal->addParamAttr(
2636 1, Attribute::getWithAlignment(g.getContext(),
2637 Align(g.getAlignment())));
2638 }
2639 }
2640
2641 std::swap(args[0], args[1]);
2642
2643 for (ReturnInst *RI : Returns) {
2644 IRBuilder<> IB(RI);
2645 auto cal = IB.CreateCall(intr, args);
2646 if (g.getAlignment()) {
2647 cal->addParamAttr(
2648 0, Attribute::getWithAlignment(g.getContext(),
2649 Align(g.getAlignment())));
2650 cal->addParamAttr(
2651 1, Attribute::getWithAlignment(g.getContext(),
2652 Align(g.getAlignment())));
2653 }
2654 }
2655 }
2656 }
2657 }
2658
2659 auto Level = OptimizationLevel::O2;
2660
2661 PassBuilder PB;
2662 FunctionPassManager FPM =
2663 PB.buildFunctionSimplificationPipeline(Level, ThinOrFullLTOPhase::None);
2664 auto PA = FPM.run(*F, FAM);
2665 FAM.invalidate(*F, PA);
2666 }
2667
2668 if (EnzymePreopt) {
2669 {
2670 auto PA = LowerInvokePass().run(*NewF, FAM);
2671 FAM.invalidate(*NewF, PA);
2672 }
2673 {
2674 auto PA = UnreachableBlockElimPass().run(*NewF, FAM);
2675 FAM.invalidate(*NewF, PA);
2676 }
2677
2678 {
2679 auto PA = PromotePass().run(*NewF, FAM);
2680 FAM.invalidate(*NewF, PA);
2681 }
2682
2683 {
2684#if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG)
2685 auto PA = SROAPass(llvm::SROAOptions::ModifyCFG).run(*NewF, FAM);
2686#elif !defined(FLANG)
2687 auto PA = SROAPass().run(*NewF, FAM);
2688#else
2689 auto PA = SROA().run(*NewF, FAM);
2690#endif
2691 FAM.invalidate(*NewF, PA);
2692 }
2693
2694 if (mode != DerivativeMode::ForwardMode)
2695 ReplaceReallocs(NewF);
2696
2697 {
2698#if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG)
2699 auto PA = SROAPass(llvm::SROAOptions::PreserveCFG).run(*NewF, FAM);
2700#elif !defined(FLANG)
2701 auto PA = SROAPass().run(*NewF, FAM);
2702#else
2703 auto PA = SROA().run(*NewF, FAM);
2704#endif
2705 FAM.invalidate(*NewF, PA);
2706 }
2707
2708 SimplifyCFGOptions scfgo;
2709 {
2710 auto PA = SimplifyCFGPass(scfgo).run(*NewF, FAM);
2711 FAM.invalidate(*NewF, PA);
2712 }
2713 }
2714
2715 {
2716 SplitPHIs(*NewF);
2717 PreservedAnalyses PA;
2718 PA.preserve<AssumptionAnalysis>();
2719 PA.preserve<TargetLibraryAnalysis>();
2720 PA.preserve<LoopAnalysis>();
2721 PA.preserve<DominatorTreeAnalysis>();
2722 PA.preserve<PostDominatorTreeAnalysis>();
2723 PA.preserve<TypeBasedAA>();
2724 PA.preserve<BasicAA>();
2725 PA.preserve<ScopedNoAliasAA>();
2726 PA.preserve<ScalarEvolutionAnalysis>();
2727 PA.preserve<PhiValuesAnalysis>();
2728 }
2729
2730 if (mode != DerivativeMode::ForwardMode)
2731 ReplaceReallocs(NewF);
2732
2736 // For subfunction calls upgrade stack allocations to mallocs
2737 // to ensure availability in the reverse pass
2738 auto unreachable = getGuaranteedUnreachable(NewF);
2739 UpgradeAllocasToMallocs(NewF, mode, unreachable);
2740 }
2741
2742 CanonicalizeLoops(NewF, FAM);
2743 RemoveRedundantPHI(NewF, FAM);
2744
2745 // Run LoopSimplifyPass to ensure preheaders exist on all loops
2746 {
2747 auto PA = LoopSimplifyPass().run(*NewF, FAM);
2748 FAM.invalidate(*NewF, PA);
2749 }
2750
2751 {
2752 for (auto &BB : *NewF) {
2753 for (auto &I : make_early_inc_range(BB)) {
2754 if (auto MTI = dyn_cast<MemTransferInst>(&I)) {
2755
2756 if (auto CI = dyn_cast<ConstantInt>(MTI->getOperand(2))) {
2757 if (CI->getValue() == 0) {
2758 MTI->eraseFromParent();
2759 }
2760 }
2761 }
2762 }
2763 }
2764
2765 PreservedAnalyses PA;
2766 PA.preserve<AssumptionAnalysis>();
2767 PA.preserve<TargetLibraryAnalysis>();
2768 PA.preserve<LoopAnalysis>();
2769 PA.preserve<DominatorTreeAnalysis>();
2770 PA.preserve<PostDominatorTreeAnalysis>();
2771 PA.preserve<TypeBasedAA>();
2772 PA.preserve<BasicAA>();
2773 PA.preserve<ScopedNoAliasAA>();
2774 PA.preserve<ScalarEvolutionAnalysis>();
2775 PA.preserve<PhiValuesAnalysis>();
2776
2777 FAM.invalidate(*NewF, PA);
2778
2780 for (auto &Arg : NewF->args()) {
2781 if (!Arg.hasName())
2782 Arg.setName("arg");
2783 }
2784 for (BasicBlock &BB : *NewF) {
2785 if (!BB.hasName())
2786 BB.setName("bb");
2787
2788 for (Instruction &I : BB) {
2789 if (!I.hasName() && !I.getType()->isVoidTy())
2790 I.setName("i");
2791 }
2792 }
2793 }
2794 }
2795
2797 if (false) {
2798 reset:;
2799 PreservedAnalyses PA;
2800 FAM.invalidate(*NewF, PA);
2801 }
2802
2803 SmallVector<BasicBlock *, 4> MultiBlocks;
2804 for (auto &B : *NewF) {
2805 if (B.hasNPredecessorsOrMore(3))
2806 MultiBlocks.push_back(&B);
2807 }
2808
2809 LoopInfo &LI = FAM.getResult<LoopAnalysis>(*NewF);
2810 for (BasicBlock *B : MultiBlocks) {
2811
2812 // Map of function edges to list of values possible
2813 std::map<std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>,
2814 std::set<BasicBlock *>>
2815 done;
2816 {
2817 std::deque<std::tuple<
2818 std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>,
2819 BasicBlock *>>
2820 Q; // newblock, target
2821
2822 for (auto P : predecessors(B)) {
2823 Q.emplace_back(std::make_pair(P, B), P);
2824 }
2825
2826 for (std::tuple<
2827 std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>,
2828 BasicBlock *>
2829 trace;
2830 Q.size() > 0;) {
2831 trace = Q.front();
2832 Q.pop_front();
2833 auto edge = std::get<0>(trace);
2834 auto block = edge.first;
2835 auto target = std::get<1>(trace);
2836
2837 if (done[edge].count(target))
2838 continue;
2839 done[edge].insert(target);
2840
2841 Loop *blockLoop = LI.getLoopFor(block);
2842
2843 for (BasicBlock *Pred : predecessors(block)) {
2844 // Don't go up the backedge as we can use the last value if desired
2845 // via lcssa
2846 if (blockLoop && blockLoop->getHeader() == block &&
2847 blockLoop == LI.getLoopFor(Pred))
2848 continue;
2849
2850 Q.push_back(
2851 std::tuple<std::pair<BasicBlock *, BasicBlock *>, BasicBlock *>(
2852 std::make_pair(Pred, block), target));
2853 }
2854 }
2855 }
2856
2857 SmallPtrSet<BasicBlock *, 4> Preds;
2858 for (auto &pair : done) {
2859 Preds.insert(pair.first.first);
2860 }
2861
2862 for (auto BB : Preds) {
2863 bool illegal = false;
2864 SmallPtrSet<BasicBlock *, 2> UnionSet;
2865 size_t numSuc = 0;
2866 for (BasicBlock *sucI : successors(BB)) {
2867 numSuc++;
2868 const auto &SI = done[std::make_pair(BB, sucI)];
2869 if (SI.size() == 0) {
2870 // sucI->getName();
2871 illegal = true;
2872 break;
2873 }
2874 for (auto si : SI) {
2875 UnionSet.insert(si);
2876
2877 for (BasicBlock *sucJ : successors(BB)) {
2878 if (sucI == sucJ)
2879 continue;
2880 if (done[std::make_pair(BB, sucJ)].count(si)) {
2881 illegal = true;
2882 goto endIllegal;
2883 }
2884 }
2885 }
2886 }
2887 endIllegal:;
2888
2889 if (!illegal && numSuc > 1 && !B->hasNPredecessors(UnionSet.size())) {
2890 BasicBlock *Ins =
2891 BasicBlock::Create(BB->getContext(), "tmpblk", BB->getParent());
2892 IRBuilder<> Builder(Ins);
2893 for (auto &phi : B->phis()) {
2894 auto nphi = Builder.CreatePHI(phi.getType(), 2);
2895 SmallVector<BasicBlock *, 4> Blocks;
2896
2897 for (auto blk : UnionSet) {
2898 nphi->addIncoming(phi.getIncomingValueForBlock(blk), blk);
2899 phi.removeIncomingValue(blk, /*deleteifempty*/ false);
2900 }
2901
2902 phi.addIncoming(nphi, Ins);
2903 }
2904 Builder.CreateBr(B);
2905 for (auto blk : UnionSet) {
2906 auto term = blk->getTerminator();
2907 for (unsigned Idx = 0, NumSuccessors = term->getNumSuccessors();
2908 Idx != NumSuccessors; ++Idx)
2909 if (term->getSuccessor(Idx) == B)
2910 term->setSuccessor(Idx, Ins);
2911 }
2912 goto reset;
2913 }
2914 }
2915 }
2916 }
2917
2918 {
2919 SmallPtrSet<Function *, 1> calls_todo;
2920 bool local = false;
2921 DetectReadonlyOrThrowFn(*NewF, calls_todo,
2922 FAM.getResult<TargetLibraryAnalysis>(*NewF), local);
2923 }
2924
2925 if (EnzymePrint)
2926 llvm::errs() << "after simplification :\n" << *NewF << "\n";
2927
2928 if (llvm::verifyFunction(*NewF, &llvm::errs())) {
2929 llvm::errs() << *NewF << "\n";
2930 report_fatal_error("function failed verification (1)");
2931 }
2932 cache[std::make_pair(F, mode)] = NewF;
2933 return NewF;
2934}
2935
2936FunctionType *getFunctionTypeForClone(llvm::FunctionType *FTy,
2937 DerivativeMode mode, unsigned width,
2938 llvm::Type *additionalArg,
2939 llvm::ArrayRef<DIFFE_TYPE> constant_args,
2940 bool diffeReturnArg, bool returnTape,
2941 bool returnPrimal, bool returnShadow) {
2942 SmallVector<Type *, 4> RetTypes;
2943 if (returnPrimal)
2944 RetTypes.push_back(FTy->getReturnType());
2945 if (returnShadow)
2946 RetTypes.push_back(
2947 GradientUtils::getShadowType(FTy->getReturnType(), width));
2948 SmallVector<Type *, 4> ArgTypes;
2949
2950 // The user might be deleting arguments to the function by specifying them in
2951 // the VMap. If so, we need to not add the arguments to the arg ty vector
2952 unsigned argno = 0;
2953
2954 for (auto &I : FTy->params()) {
2955 ArgTypes.push_back(I);
2956 if (constant_args[argno] == DIFFE_TYPE::DUP_ARG ||
2957 constant_args[argno] == DIFFE_TYPE::DUP_NONEED) {
2958 ArgTypes.push_back(GradientUtils::getShadowType(I, width));
2959 } else if (constant_args[argno] == DIFFE_TYPE::OUT_DIFF && !returnTape) {
2960 RetTypes.push_back(GradientUtils::getShadowType(I, width));
2961 }
2962 ++argno;
2963 }
2964
2965 if (diffeReturnArg) {
2966 assert(!FTy->getReturnType()->isVoidTy());
2967 ArgTypes.push_back(
2968 GradientUtils::getShadowType(FTy->getReturnType(), width));
2969 }
2970 if (additionalArg) {
2971 ArgTypes.push_back(additionalArg);
2972 }
2973 Type *RetType = StructType::get(FTy->getContext(), RetTypes);
2974 if (returnTape) {
2975 RetTypes.insert(RetTypes.begin(),
2976 getDefaultAnonymousTapeType(FTy->getContext()));
2977 }
2978
2979 if (RetTypes.size() == 0)
2980 RetType = Type::getVoidTy(RetType->getContext());
2981 else if (RetTypes.size() == 1 && (returnPrimal || returnShadow) &&
2983 RetType = RetTypes[0];
2984 else
2985 RetType = StructType::get(FTy->getContext(), RetTypes);
2986
2987 bool noReturn = RetTypes.size() == 0;
2988 if (noReturn)
2989 RetType = Type::getVoidTy(RetType->getContext());
2990
2991 // Create a new function type...
2992 return FunctionType::get(RetType, ArgTypes, FTy->isVarArg());
2993}
2994
2996 DerivativeMode mode, unsigned width, Function *&F,
2997 ValueToValueMapTy &ptrInputs, ArrayRef<DIFFE_TYPE> constant_args,
2998 SmallPtrSetImpl<Value *> &constants, SmallPtrSetImpl<Value *> &nonconstant,
2999 SmallPtrSetImpl<Value *> &returnvals, bool returnTape, bool returnPrimal,
3000 bool returnShadow, const Twine &name,
3001 llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> *VMapO,
3002 bool diffeReturnArg, llvm::Type *additionalArg) {
3003 if (!F->empty())
3004 F = preprocessForClone(F, mode);
3005 llvm::ValueToValueMapTy VMap;
3006 llvm::FunctionType *FTy = getFunctionTypeForClone(
3007 F->getFunctionType(), mode, width, additionalArg, constant_args,
3008 diffeReturnArg, returnTape, returnPrimal, returnShadow);
3009
3010 for (BasicBlock &BB : *F) {
3011 if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator())) {
3012 if (auto rv = ri->getReturnValue()) {
3013 returnvals.insert(rv);
3014 }
3015 }
3016 }
3017
3018 // Create the new function...
3019 Function *NewF = Function::Create(FTy, F->getLinkage(), name, F->getParent());
3020 if (diffeReturnArg) {
3021 auto I = NewF->arg_end();
3022 I--;
3023 if (additionalArg)
3024 I--;
3025 I->setName("differeturn");
3026 }
3027 if (additionalArg) {
3028 auto I = NewF->arg_end();
3029 I--;
3030 I->setName("tapeArg");
3031 }
3032
3033 {
3034 unsigned ii = 0;
3035 for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) {
3036 VMap[i] = j;
3037 ++j;
3038 ++i;
3039 if (constant_args[ii] == DIFFE_TYPE::DUP_ARG ||
3040 constant_args[ii] == DIFFE_TYPE::DUP_NONEED) {
3041 ++j;
3042 }
3043 ++ii;
3044 }
3045 }
3046
3047 // Loop over the arguments, copying the names of the mapped arguments over...
3048 Function::arg_iterator DestI = NewF->arg_begin();
3049
3050 for (const Argument &I : F->args())
3051 if (VMap.count(&I) == 0) { // Is this argument preserved?
3052 DestI->setName(I.getName()); // Copy the name over...
3053 VMap[&I] = &*DestI++; // Add mapping to VMap
3054 }
3055 SmallVector<ReturnInst *, 4> Returns;
3056 if (!F->empty()) {
3057 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
3058 Returns, "", nullptr);
3059 }
3060 if (NewF->empty()) {
3061 auto entry = BasicBlock::Create(NewF->getContext(), "entry", NewF);
3062 IRBuilder<> B(entry);
3063 B.CreateUnreachable();
3064 }
3065 CloneOrigin[NewF] = F;
3066 if (VMapO) {
3067 for (const auto &data : VMap)
3068 VMapO->insert(std::pair<const llvm::Value *, AssertingReplacingVH>(
3069 data.first, (llvm::Value *)data.second));
3070 VMapO->getMDMap() = VMap.getMDMap();
3071 }
3072
3073 for (auto attr : {"enzyme_ta_norecur", "frame-pointer"})
3074 if (F->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) {
3075 NewF->addAttribute(
3076 AttributeList::FunctionIndex,
3077 F->getAttributes().getAttribute(AttributeList::FunctionIndex, attr));
3078 }
3079
3080 for (auto attr :
3081 {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"})
3082 if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) {
3083 NewF->addAttribute(
3084 AttributeList::ReturnIndex,
3085 F->getAttributes().getAttribute(AttributeList::ReturnIndex, attr));
3086 }
3087
3088 bool hasPtrInput = false;
3089 unsigned ii = 0, jj = 0;
3090
3091 for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) {
3092 if (F->hasParamAttribute(ii, Attribute::StructRet)) {
3093 NewF->addParamAttr(
3094 jj, Attribute::get(F->getContext(), "enzyme_sret",
3096 F->getParamAttribute(ii, Attribute::StructRet)
3097 .getValueAsType())));
3098 }
3099 for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref",
3100 "enzyme_type", "enzymejl_rooted_typ",
3101 "enzymejl_returnRoots", "enzymejl_sret_union_bytes"})
3102 if (F->getAttributes().hasParamAttr(ii, attr)) {
3103 NewF->addParamAttr(jj, F->getAttributes().getParamAttr(ii, attr));
3104 }
3105 for (auto ty : PrimalParamAttrsToPreserve)
3106 if (F->getAttributes().hasParamAttr(ii, ty)) {
3107 auto attr = F->getAttributes().getParamAttr(ii, ty);
3108 NewF->addParamAttr(jj, attr);
3109 }
3110 if (constant_args[ii] == DIFFE_TYPE::CONSTANT) {
3111 if (!i->hasByValAttr())
3112 constants.insert(i);
3114 llvm::errs() << "in new function " << NewF->getName()
3115 << " constant arg " << *j << "\n";
3116 } else {
3117 nonconstant.insert(i);
3119 llvm::errs() << "in new function " << NewF->getName()
3120 << " nonconstant arg " << *j << "\n";
3121 }
3122
3123 // Always remove nonnull/noundef since the caller may choose to pass
3124 // undef as an arg if provably it will not be used in the reverse pass
3125 if (constant_args[ii] == DIFFE_TYPE::DUP_NONEED ||
3127 if (F->hasParamAttribute(ii, Attribute::NonNull)) {
3128 NewF->removeParamAttr(jj, Attribute::NonNull);
3129 }
3130 if (F->hasParamAttribute(ii, Attribute::NoUndef)) {
3131 NewF->removeParamAttr(jj, Attribute::NoUndef);
3132 }
3133 if (F->hasParamAttribute(ii, Attribute::Dereferenceable)) {
3134 NewF->removeParamAttr(jj, Attribute::Dereferenceable);
3135 }
3136 if (F->hasParamAttribute(ii, Attribute::DereferenceableOrNull)) {
3137 NewF->removeParamAttr(jj, Attribute::DereferenceableOrNull);
3138 }
3139 }
3140
3141 if (constant_args[ii] == DIFFE_TYPE::DUP_ARG ||
3142 constant_args[ii] == DIFFE_TYPE::DUP_NONEED) {
3143 hasPtrInput = true;
3144 ptrInputs[i] = (j + 1);
3145 // TODO: find a way to keep the attributes in vector mode.
3146 if (width == 1) {
3147 for (auto ty : ShadowParamAttrsToPreserve) {
3148 if (F->getAttributes().hasParamAttr(ii, ty)) {
3149 auto attr = F->getAttributes().getParamAttr(ii, ty);
3150 NewF->addParamAttr(jj + 1, attr);
3151 }
3152 }
3153 }
3154
3155 for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref",
3156 "enzyme_type", "enzymejl_rooted_typ",
3157 "enzymejl_returnRoots", "enzymejl_sret_union_bytes"})
3158 if (F->getAttributes().hasParamAttr(ii, attr)) {
3159 if (width == 1) {
3160 NewF->addParamAttr(jj + 1,
3161 F->getAttributes().getParamAttr(ii, attr));
3162 } else {
3163 NewF->addParamAttr(jj + 1,
3164 Attribute::get(F->getContext(),
3165 attr + std::string("_v"),
3166 F->getAttributes()
3167 .getParamAttr(ii, attr)
3168 .getValueAsString()));
3169 }
3170 }
3171
3172 if (F->hasParamAttribute(ii, Attribute::StructRet)) {
3173 if (width == 1) {
3174 NewF->addParamAttr(
3175 jj + 1,
3176 Attribute::get(F->getContext(), "enzyme_sret",
3178 F->getParamAttribute(ii, Attribute::StructRet)
3179 .getValueAsType())));
3180 } else {
3181 NewF->addParamAttr(
3182 jj + 1,
3183 Attribute::get(F->getContext(), "enzyme_sret_v",
3185 F->getParamAttribute(ii, Attribute::StructRet)
3186 .getValueAsType())));
3187 }
3188 }
3189
3190 j->setName(i->getName());
3191 ++j;
3192 j->setName(i->getName() + "'");
3193 nonconstant.insert(j);
3194 ++j;
3195 jj += 2;
3196
3197 ++i;
3198
3199 } else {
3200 j->setName(i->getName());
3201 ++j;
3202 ++jj;
3203 ++i;
3204 }
3205 ++ii;
3206 }
3207
3208 if (hasPtrInput && (mode == DerivativeMode::ReverseModeCombined ||
3210 if (NewF->hasFnAttribute(Attribute::ReadOnly)) {
3211 NewF->removeFnAttr(Attribute::ReadOnly);
3212 }
3213#if LLVM_VERSION_MAJOR >= 16
3214 auto eff = NewF->getMemoryEffects();
3215 for (auto loc : MemoryEffects::locations()) {
3216 if (loc == MemoryEffects::Location::InaccessibleMem)
3217 continue;
3218 auto mr = eff.getModRef(loc);
3219 if (isModSet(mr))
3220 eff |= MemoryEffects(loc, ModRefInfo::Ref);
3221 if (isRefSet(mr))
3222 eff |= MemoryEffects(loc, ModRefInfo::Mod);
3223 }
3224 NewF->setMemoryEffects(eff);
3225#endif
3226 }
3227 NewF->setLinkage(Function::LinkageTypes::InternalLinkage);
3229 NewF->addFnAttr(Attribute::AlwaysInline);
3230 assert(NewF->hasLocalLinkage());
3231
3232 return NewF;
3233}
3234
3235void CoaleseTrivialMallocs(Function &F, DominatorTree &DT) {
3236 std::map<BasicBlock *, std::vector<std::pair<CallInst *, CallInst *>>>
3237 LegalMallocs;
3238
3239 std::map<Metadata *, std::vector<CallInst *>> frees;
3240 for (BasicBlock &BB : F) {
3241 for (Instruction &I : BB) {
3242 if (auto CI = dyn_cast<CallInst>(&I)) {
3243 if (auto F2 = CI->getCalledFunction()) {
3244 if (F2->getName() == "free") {
3245 if (auto MD = hasMetadata(CI, "enzyme_cache_free")) {
3246 Metadata *op = MD->getOperand(0);
3247 frees[op].push_back(CI);
3248 }
3249 }
3250 }
3251 }
3252 }
3253 }
3254
3255 for (BasicBlock &BB : F) {
3256 for (Instruction &I : BB) {
3257 if (auto CI = dyn_cast<CallInst>(&I)) {
3258 if (auto F = CI->getCalledFunction()) {
3259 if (F->getName() == "malloc") {
3260 CallInst *freeCall = nullptr;
3261 for (auto U : CI->users()) {
3262 if (auto CI2 = dyn_cast<CallInst>(U)) {
3263 if (auto F2 = CI2->getCalledFunction()) {
3264 if (F2->getName() == "free") {
3265 if (DT.dominates(CI, CI2)) {
3266 freeCall = CI2;
3267 break;
3268 }
3269 }
3270 }
3271 }
3272 }
3273 if (!freeCall) {
3274 if (auto MD = hasMetadata(CI, "enzyme_cache_alloc")) {
3275 Metadata *op = MD->getOperand(0);
3276 if (frees[op].size() == 1)
3277 freeCall = frees[op][0];
3278 }
3279 }
3280 if (freeCall)
3281 LegalMallocs[&BB].emplace_back(CI, freeCall);
3282 }
3283 }
3284 }
3285 }
3286 }
3287 for (auto &pair : LegalMallocs) {
3288 if (pair.second.size() < 2)
3289 continue;
3290 CallInst *First = pair.second[0].first;
3291 for (auto &z : pair.second) {
3292 if (!DT.dominates(First, z.first))
3293 First = z.first;
3294 }
3295 bool legal = true;
3296 for (auto &z : pair.second) {
3297 if (auto inst = dyn_cast<Instruction>(z.first->getArgOperand(0)))
3298 if (!DT.dominates(inst, First))
3299 legal = false;
3300 }
3301 if (!legal)
3302 continue;
3303 IRBuilder<> B(First);
3304 Value *Size = First->getArgOperand(0);
3305 for (auto &z : pair.second) {
3306 if (z.first == First)
3307 continue;
3308 Size = B.CreateAdd(
3309 B.CreateOr(B.CreateSub(Size, ConstantInt::get(Size->getType(), 1)),
3310 ConstantInt::get(Size->getType(), 15)),
3311 ConstantInt::get(Size->getType(), 1));
3312 z.second->eraseFromParent();
3313 IRBuilder<> B2(z.first);
3314 Value *gepPtr = B2.CreateInBoundsGEP(Type::getInt8Ty(First->getContext()),
3315 First, Size);
3316 z.first->replaceAllUsesWith(gepPtr);
3317 Size = B.CreateAdd(Size, z.first->getArgOperand(0));
3318 z.first->eraseFromParent();
3319 }
3320 auto NewMalloc =
3321 cast<CallInst>(B.CreateCall(First->getCalledFunction(), Size));
3322 NewMalloc->copyIRFlags(First);
3323 NewMalloc->setMetadata("enzyme_cache_alloc",
3324 hasMetadata(First, "enzyme_cache_alloc"));
3325 First->replaceAllUsesWith(NewMalloc);
3326 First->eraseFromParent();
3327 }
3328}
3329
3330void SelectOptimization(Function *F) {
3331 DominatorTree DT(*F);
3332 for (auto &BB : *F) {
3333 if (auto BI = dyn_cast<BranchInst>(BB.getTerminator())) {
3334 if (BI->isConditional()) {
3335 for (auto &I : BB) {
3336 if (auto SI = dyn_cast<SelectInst>(&I)) {
3337 if (SI->getCondition() == BI->getCondition()) {
3338 for (Value::use_iterator UI = SI->use_begin(), E = SI->use_end();
3339 UI != E;) {
3340 Use &U = *UI;
3341 ++UI;
3342 if (DT.dominates(BasicBlockEdge(&BB, BI->getSuccessor(0)), U))
3343 U.set(SI->getTrueValue());
3344 else if (DT.dominates(BasicBlockEdge(&BB, BI->getSuccessor(1)),
3345 U))
3346 U.set(SI->getFalseValue());
3347 }
3348 }
3349 }
3350 }
3351 }
3352 }
3353 }
3354}
3355
3357 for (Function &Impl : M) {
3358 for (auto attr : {"implements", "implements2"}) {
3359 if (!Impl.hasFnAttribute(attr))
3360 continue;
3361 const Attribute &A = Impl.getFnAttribute(attr);
3362
3363 const StringRef SpecificationName = A.getValueAsString();
3364 Function *Specification = M.getFunction(SpecificationName);
3365 if (!Specification) {
3366 LLVM_DEBUG(dbgs() << "Found implementation '" << Impl.getName()
3367 << "' but no matching specification with name '"
3368 << SpecificationName
3369 << "', potentially inlined and/or eliminated.\n");
3370 continue;
3371 }
3372 LLVM_DEBUG(dbgs() << "Replace specification '" << Specification->getName()
3373 << "' with implementation '" << Impl.getName()
3374 << "'\n");
3375
3376 for (auto I = Specification->use_begin(), UE = Specification->use_end();
3377 I != UE;) {
3378 auto &use = *I;
3379 ++I;
3380 auto cext = ConstantExpr::getBitCast(&Impl, Specification->getType());
3381 if (cast<Instruction>(use.getUser())->getParent()->getParent() == &Impl)
3382 continue;
3383 use.set(cext);
3384 if (auto CI = dyn_cast<CallInst>(use.getUser())) {
3385 if (CI->getCalledOperand() == cext ||
3386 CI->getCalledFunction() == &Impl) {
3387 CI->setCallingConv(Impl.getCallingConv());
3388 }
3389 }
3390 }
3391 }
3392 }
3393}
3394
3396 PreservedAnalyses PA;
3397 PA = PromotePass().run(*F, FAM);
3398 FAM.invalidate(*F, PA);
3399#if !defined(FLANG)
3400 PA = GVNPass().run(*F, FAM);
3401#else
3402 PA = GVN().run(*F, FAM);
3403#endif
3404 FAM.invalidate(*F, PA);
3405#if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG)
3406 PA = SROAPass(llvm::SROAOptions::PreserveCFG).run(*F, FAM);
3407#elif !defined(FLANG)
3408 PA = SROAPass().run(*F, FAM);
3409#else
3410 PA = SROA().run(*F, FAM);
3411#endif
3412 FAM.invalidate(*F, PA);
3413
3414 if (EnzymeSelectOpt) {
3415 SimplifyCFGOptions scfgo;
3416 PA = SimplifyCFGPass(scfgo).run(*F, FAM);
3417 FAM.invalidate(*F, PA);
3418 PA = CorrelatedValuePropagationPass().run(*F, FAM);
3419 FAM.invalidate(*F, PA);
3421 }
3422 // EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM);
3423
3424 if (EnzymeCoalese)
3425 CoaleseTrivialMallocs(*F, FAM.getResult<DominatorTreeAnalysis>(*F));
3426
3427 ReplaceFunctionImplementation(*F->getParent());
3428
3429 {
3430 PreservedAnalyses PA;
3431 FAM.invalidate(*F, PA);
3432 }
3433
3434 OptimizationLevel Level = OptimizationLevel::O0;
3435
3436 switch (EnzymePostOptLevel) {
3437 default:
3438 case 0:
3439 Level = OptimizationLevel::O0;
3440 break;
3441 case 1:
3442 Level = OptimizationLevel::O1;
3443 break;
3444 case 2:
3445 Level = OptimizationLevel::O2;
3446 break;
3447 case 3:
3448 Level = OptimizationLevel::O3;
3449 break;
3450 }
3451 if (Level != OptimizationLevel::O0) {
3452 PassBuilder PB;
3453 FunctionPassManager FPM =
3454 PB.buildFunctionSimplificationPipeline(Level, ThinOrFullLTOPhase::None);
3455 PA = FPM.run(*F, FAM);
3456 FAM.invalidate(*F, PA);
3457 }
3458
3459 // TODO actually run post optimizations.
3460}
3461
3463 LAM.clear();
3464 FAM.clear();
3465 MAM.clear();
3466 cache.clear();
3467}
3468
3469// Returns if a is guaranteed to be equivalent to not b
3470static bool isNot(Value *a, Value *b) {
3471 // cmp pred, a, b and cmp inverse(pred), a, b
3472 if (auto I1 = dyn_cast<CmpInst>(a))
3473 if (auto I2 = dyn_cast<CmpInst>(b))
3474 if (I1->getOperand(0) == I2->getOperand(0) &&
3475 I1->getOperand(1) == I2->getOperand(1) &&
3476 I1->getPredicate() == I2->getInversePredicate())
3477 return true;
3478 // a := xor true, b
3479 if (auto I = dyn_cast<Instruction>(a))
3480 if (I->getOpcode() == Instruction::Xor)
3481 for (int i = 0; i < 2; i++) {
3482 if (I->getOperand(i) == b)
3483 if (auto CI = dyn_cast<ConstantInt>(I->getOperand(1 - i)))
3484#if LLVM_VERSION_MAJOR > 16
3485 if (CI->getValue().isAllOnes())
3486#else
3487 if (CI->getValue().isAllOnesValue())
3488#endif
3489 return true;
3490 }
3491 // b := xor true, a
3492 if (auto I = dyn_cast<Instruction>(b))
3493 if (I->getOpcode() == Instruction::Xor)
3494 for (int i = 0; i < 2; i++) {
3495 if (I->getOperand(i) == a)
3496 if (auto CI = dyn_cast<ConstantInt>(I->getOperand(1 - i)))
3497#if LLVM_VERSION_MAJOR > 16
3498 if (CI->getValue().isAllOnes())
3499#else
3500 if (CI->getValue().isAllOnesValue())
3501#endif
3502 return true;
3503 }
3504 return false;
3505}
3506
3508public:
3509 DominatorTree &DT;
3510 LoopInfo &LI;
3511 compare_insts(DominatorTree &DT, LoopInfo &LI) : DT(DT), LI(LI) {}
3512
3513 // return true if A appears later than B.
3514 bool operator()(Instruction *A, Instruction *B) const {
3515 if (A == B) {
3516 return false;
3517 }
3518 if (A->getParent() == B->getParent()) {
3519 return !A->comesBefore(B);
3520 }
3521 auto AB = A->getParent();
3522 auto BB = B->getParent();
3523 assert(AB->getParent() == BB->getParent());
3524
3525 for (auto prev = BB->getPrevNode(); prev; prev = prev->getPrevNode()) {
3526 if (prev == AB)
3527 return false;
3528 }
3529 return true;
3530 }
3531};
3532
3533class DominatorOrderSet : public std::set<Instruction *, compare_insts> {
3534public:
3535 DominatorOrderSet(DominatorTree &DT, LoopInfo &LI)
3536 : std::set<Instruction *, compare_insts>(compare_insts(DT, LI)) {}
3537 bool contains(Instruction *I) const {
3538 auto __i = find(I);
3539 return __i != end();
3540 }
3541 void remove(Instruction *I) {
3542 auto __i = find(I);
3543 assert(__i != end());
3544 erase(__i);
3545 }
3546 Instruction *pop_back_val() {
3547 auto back = end();
3548 back--;
3549 auto v = *back;
3550 erase(back);
3551 return v;
3552 }
3553};
3554
3555bool directlySparse(Value *z) {
3556 if (isa<UIToFPInst>(z))
3557 return true;
3558 if (isa<SIToFPInst>(z))
3559 return true;
3560 if (isa<ZExtInst>(z))
3561 return true;
3562 if (isa<SExtInst>(z))
3563 return true;
3564 if (auto SI = dyn_cast<SelectInst>(z)) {
3565 if (auto CI = dyn_cast<ConstantInt>(SI->getTrueValue()))
3566 if (CI->isZero())
3567 return true;
3568 if (auto CI = dyn_cast<ConstantInt>(SI->getFalseValue()))
3569 if (CI->isZero())
3570 return true;
3571 }
3572 return false;
3573}
3574
3576
3577Function *getProductIntrinsic(llvm::Module &M, llvm::Type *T) {
3578 std::string name = "__enzyme_product.";
3579 if (T->isFloatTy())
3580 name += "f32";
3581 else if (T->isDoubleTy())
3582 name += "f64";
3583 else if (T->isIntegerTy())
3584 name += "i" + std::to_string(cast<IntegerType>(T)->getBitWidth());
3585 else
3586 assert(0);
3587 auto FT = llvm::FunctionType::get(T, {}, true);
3588 AttributeList AL;
3589 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3590 Attribute::ReadNone);
3591 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3592 Attribute::NoUnwind);
3593 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3594 Attribute::NoFree);
3595 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3596 Attribute::NoSync);
3597 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3598 Attribute::WillReturn);
3599 return cast<Function>(M.getOrInsertFunction(name, FT, AL).getCallee());
3600}
3601
3602Function *getSumIntrinsic(llvm::Module &M, llvm::Type *T) {
3603 std::string name = "__enzyme_sum.";
3604 if (T->isFloatTy())
3605 name += "f32";
3606 else if (T->isDoubleTy())
3607 name += "f64";
3608 else if (T->isIntegerTy())
3609 name += "i" + std::to_string(cast<IntegerType>(T)->getBitWidth());
3610 else
3611 assert(0);
3612 auto FT = llvm::FunctionType::get(T, {}, true);
3613 AttributeList AL;
3614 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3615 Attribute::ReadNone);
3616 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3617 Attribute::NoUnwind);
3618 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3619 Attribute::NoFree);
3620 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3621 Attribute::NoSync);
3622 AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex,
3623 Attribute::WillReturn);
3624 return cast<Function>(M.getOrInsertFunction(name, FT, AL).getCallee());
3625}
3626
3627CallInst *isProduct(llvm::Value *v) {
3628 if (auto prod = dyn_cast<CallInst>(v))
3629 if (auto F = getFunctionFromCall(prod))
3630 if (startsWith(F->getName(), "__enzyme_product"))
3631 return prod;
3632 return nullptr;
3633}
3634
3635CallInst *isSum(llvm::Value *v) {
3636 if (auto prod = dyn_cast<CallInst>(v))
3637 if (auto F = getFunctionFromCall(prod))
3638 if (startsWith(F->getName(), "__enzyme_sum"))
3639 return prod;
3640 return nullptr;
3641}
3642
3643SmallVector<Value *, 1> callOperands(llvm::CallBase *CB) {
3644 return SmallVector<Value *, 1>(CB->args().begin(), CB->args().end());
3645}
3646
3648 if (isa<LoadInst>(z))
3649 return true;
3650 if (isa<Constant>(z))
3651 return true;
3652 if (auto BO = dyn_cast<BinaryOperator>(z))
3653 return guaranteedDataDependent(BO->getOperand(0)) &&
3654 guaranteedDataDependent(BO->getOperand(1));
3655 if (auto C = dyn_cast<CastInst>(z))
3656 return guaranteedDataDependent(C->getOperand(0));
3657 if (auto S = isSum(z)) {
3658 for (auto op : callOperands(S))
3660 return true;
3661 return false;
3662 }
3663 if (auto S = isProduct(z)) {
3664 for (auto op : callOperands(S))
3665 if (!guaranteedDataDependent(op))
3666 return false;
3667 return true;
3668 }
3669 if (auto II = dyn_cast<IntrinsicInst>(z)) {
3670 switch (II->getIntrinsicID()) {
3671 case Intrinsic::sqrt:
3672 case Intrinsic::sin:
3673 case Intrinsic::cos:
3674#if LLVM_VERSION_MAJOR >= 19
3675 case Intrinsic::sinh:
3676 case Intrinsic::cosh:
3677 case Intrinsic::tanh:
3678#endif
3679 return guaranteedDataDependent(II->getArgOperand(0));
3680 default:
3681 break;
3682 }
3683 }
3684 return false;
3685}
3686
3687std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
3688 QueueType &Q, DominatorTree &DT,
3689 ScalarEvolution &SE, LoopInfo &LI,
3690 const DataLayout &DL) {
3691 auto push = [&](llvm::Value *V) {
3692 if (V == cur)
3693 return V;
3694 assert(V);
3695 if (auto I = dyn_cast<Instruction>(V)) {
3696 Q.insert(I);
3697 for (auto U : I->users()) {
3698 if (auto I2 = dyn_cast<Instruction>(U)) {
3699 if (I2 == cur)
3700 continue;
3701 Q.insert(I2);
3702 }
3703 }
3704 }
3705 return V;
3706 };
3707 auto pushcse = [&](llvm::Value *V) -> llvm::Value * {
3708 if (auto I = dyn_cast<Instruction>(V)) {
3709 for (size_t i = 0; i < I->getNumOperands(); i++) {
3710 if (auto I2 = dyn_cast<Instruction>(I->getOperand(i))) {
3711 Instruction *candidate = nullptr;
3712 for (auto U : I2->users()) {
3713 candidate = dyn_cast<Instruction>(U);
3714 if (!candidate)
3715 continue;
3716 if (candidate == I && candidate->getType() != I->getType()) {
3717 candidate = nullptr;
3718 continue;
3719 }
3720 bool isSame = candidate->isIdenticalTo(I);
3721 if (!isSame) {
3722 if (auto P1 = isProduct(I))
3723 if (auto P2 = isProduct(I2)) {
3724 std::multiset<llvm::Value *> s1;
3725 std::multiset<llvm::Value *> s2;
3726 for (auto &v : callOperands(P1))
3727 s1.insert(v);
3728 for (auto &v : callOperands(P2))
3729 s2.insert(v);
3730 isSame = s1 == s2;
3731 }
3732 if (auto P1 = isSum(I))
3733 if (auto P2 = isSum(I2)) {
3734 std::multiset<llvm::Value *> s1;
3735 std::multiset<llvm::Value *> s2;
3736 for (auto &v : callOperands(P1))
3737 s1.insert(v);
3738 for (auto &v : callOperands(P2))
3739 s2.insert(v);
3740 isSame = s1 == s2;
3741 }
3742 }
3743 if (!isSame) {
3744 candidate = nullptr;
3745 continue;
3746 }
3747
3748 if (DT.dominates(candidate, I)) {
3749 break;
3750 }
3751 candidate = nullptr;
3752 }
3753 if (candidate) {
3754 I->eraseFromParent();
3755 return candidate;
3756 }
3757 break;
3758 }
3759 }
3760 return push(I);
3761 }
3762 return V;
3763 };
3764 auto replaceAndErase = [&](llvm::Instruction *I, llvm::Value *candidate) {
3765 for (auto U : I->users())
3766 push(U);
3767 I->replaceAllUsesWith(candidate);
3768 push(candidate);
3769
3770 SetVector<Instruction *> operands;
3771 for (size_t i = 0; i < I->getNumOperands(); i++) {
3772 if (auto I2 = dyn_cast<Instruction>(I->getOperand(i))) {
3773 if ((!I2->mayWriteToMemory() ||
3774 (isa<CallInst>(I2) && isReadOnly(cast<CallInst>(I2)))))
3775 operands.insert(I2);
3776 }
3777 }
3778 if (Q.contains(I)) {
3779 Q.remove(I);
3780 }
3781 assert(!Q.contains(I));
3782 I->eraseFromParent();
3783 for (auto op : operands)
3784 if (op->getNumUses() == 0) {
3785 if (Q.contains(op))
3786 Q.remove(op);
3787 op->eraseFromParent();
3788 }
3789 };
3790 if (!cur->getType()->isVoidTy() &&
3791 (!cur->mayWriteToMemory() ||
3792 (isa<CallInst>(cur) && isReadOnly(cast<CallInst>(cur))))) {
3793 // DCE
3794 if (cur->getNumUses() == 0) {
3795 for (size_t i = 0; i < cur->getNumOperands(); i++)
3796 push(cur->getOperand(i));
3797 assert(!Q.contains(cur));
3798 cur->eraseFromParent();
3799 return "DCE";
3800 }
3801 // CSE
3802 {
3803 for (size_t i = 0; i < cur->getNumOperands(); i++) {
3804 if (auto I = dyn_cast<Instruction>(cur->getOperand(i))) {
3805 Instruction *candidate = nullptr;
3806 bool reverse = false;
3807 for (auto U : I->users()) {
3808 candidate = dyn_cast<Instruction>(U);
3809 if (!candidate)
3810 continue;
3811 if (candidate == cur && candidate->getType() != cur->getType()) {
3812 candidate = nullptr;
3813 continue;
3814 }
3815 bool isSame = candidate->isIdenticalTo(cur);
3816 if (!isSame) {
3817 if (auto P1 = isProduct(candidate))
3818 if (auto P2 = isProduct(cur)) {
3819 std::multiset<llvm::Value *> s1;
3820 std::multiset<llvm::Value *> s2;
3821 for (auto &v : callOperands(P1))
3822 s1.insert(v);
3823 for (auto &v : callOperands(P2))
3824 s2.insert(v);
3825 isSame = s1 == s2;
3826 }
3827 if (auto P1 = isSum(candidate))
3828 if (auto P2 = isSum(cur)) {
3829 std::multiset<llvm::Value *> s1;
3830 std::multiset<llvm::Value *> s2;
3831 for (auto &v : callOperands(P1))
3832 s1.insert(v);
3833 for (auto &v : callOperands(P2))
3834 s2.insert(v);
3835 isSame = s1 == s2;
3836 }
3837 }
3838
3839 if (!isSame) {
3840 candidate = nullptr;
3841 continue;
3842 }
3843
3844 if (DT.dominates(candidate, cur)) {
3845 break;
3846 } else if (DT.dominates(cur, candidate)) {
3847 reverse = true;
3848 break;
3849 }
3850 candidate = nullptr;
3851 }
3852 if (candidate) {
3853 if (reverse) {
3854 if (Q.contains(candidate))
3855 Q.remove(candidate);
3856 auto tmp = candidate;
3857 candidate = cur;
3858 cur = tmp;
3859 }
3860 replaceAndErase(cur, candidate);
3861 return "CSE";
3862 }
3863 break;
3864 }
3865 }
3866 }
3867 }
3868
3869 if (auto SI = dyn_cast<SelectInst>(cur))
3870 if (auto CI = dyn_cast<ConstantInt>(SI->getCondition())) {
3871 if (CI->isOne()) {
3872 replaceAndErase(cur, SI->getTrueValue());
3873 return "SelectToTrue";
3874 } else {
3875 replaceAndErase(cur, SI->getFalseValue());
3876 return "SelectToFalse";
3877 }
3878 }
3879 if (cur->getOpcode() == Instruction::Or) {
3880 for (int i = 0; i < 2; i++) {
3881 if (auto C = dyn_cast<ConstantInt>(cur->getOperand(i))) {
3882 // or a, 0 -> a
3883 if (C->isZero()) {
3884 replaceAndErase(cur, cur->getOperand(1 - i));
3885 return "OrZero";
3886 }
3887 // or a, 1 -> 1
3888 if (C->isOne() && cur->getType()->isIntegerTy(1)) {
3889 replaceAndErase(cur, C);
3890 return "OrOne";
3891 }
3892 }
3893 }
3894 }
3895 if (cur->getOpcode() == Instruction::And) {
3896 for (int i = 0; i < 2; i++) {
3897 if (auto C = dyn_cast<ConstantInt>(cur->getOperand(i))) {
3898 // and a, 1 -> a
3899 if (C->isOne() && cur->getType()->isIntegerTy(1)) {
3900 replaceAndErase(cur, cur->getOperand(1 - i));
3901 return "AndOne";
3902 }
3903 // and a, 0 -> 0
3904 if (C->isZero()) {
3905 replaceAndErase(cur, C);
3906 return "AndZero";
3907 }
3908 }
3909 }
3910 }
3911
3912 IRBuilder<> B(cur);
3913 if (auto CI = dyn_cast<CastInst>(cur))
3914 if (auto C = dyn_cast<Constant>(CI->getOperand(0))) {
3915 replaceAndErase(
3916 cur, cast<Constant>(B.CreateCast(CI->getOpcode(), C, CI->getType())));
3917 return "CastConstProp";
3918 }
3919 std::function<Value *(Value *, Value *, Value *)> replace = [&](Value *val,
3920 Value *orig,
3921 Value *with) {
3922 if (val == orig) {
3923 return with;
3924 }
3925 if (isNot(val, orig)) {
3926 return pushcse(B.CreateNot(with));
3927 }
3928 if (isa<PHINode>(val))
3929 return val;
3930
3931 if (auto I = dyn_cast<Instruction>(val)) {
3932 if (I->mayWriteToMemory() &&
3933 !(isa<CallInst>(I) && isReadOnly(cast<CallInst>(I))))
3934 return val;
3935
3936 if (I->getOpcode() == Instruction::Add) {
3937 Value *lhs = replace(I->getOperand(0), orig, with);
3938 Value *rhs = replace(I->getOperand(1), orig, with);
3939 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
3940 return val;
3941 push(I);
3942 return pushcse(B.CreateAdd(lhs, rhs, "sel." + I->getName(),
3943 I->hasNoUnsignedWrap(),
3944 I->hasNoSignedWrap()));
3945 }
3946
3947 if (I->getOpcode() == Instruction::Sub) {
3948 Value *lhs = replace(I->getOperand(0), orig, with);
3949 Value *rhs = replace(I->getOperand(1), orig, with);
3950 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
3951 return val;
3952 push(I);
3953 return pushcse(B.CreateSub(lhs, rhs, "sel." + I->getName(),
3954 I->hasNoUnsignedWrap(),
3955 I->hasNoSignedWrap()));
3956 }
3957
3958 if (I->getOpcode() == Instruction::Mul) {
3959 Value *lhs = replace(I->getOperand(0), orig, with);
3960 Value *rhs = replace(I->getOperand(1), orig, with);
3961 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
3962 return val;
3963 push(I);
3964 return pushcse(B.CreateMul(lhs, rhs, "sel." + I->getName(),
3965 I->hasNoUnsignedWrap(),
3966 I->hasNoSignedWrap()));
3967 }
3968
3969 if (I->getOpcode() == Instruction::And) {
3970 Value *lhs = replace(I->getOperand(0), orig, with);
3971 Value *rhs = replace(I->getOperand(1), orig, with);
3972 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
3973 return val;
3974 push(I);
3975 return pushcse(B.CreateAnd(lhs, rhs, "sel." + I->getName()));
3976 }
3977
3978 if (I->getOpcode() == Instruction::Or) {
3979 Value *lhs = replace(I->getOperand(0), orig, with);
3980 Value *rhs = replace(I->getOperand(1), orig, with);
3981 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
3982 return val;
3983 push(I);
3984 return pushcse(B.CreateOr(lhs, rhs, "sel." + I->getName()));
3985 }
3986
3987 if (I->getOpcode() == Instruction::Xor) {
3988 Value *lhs = replace(I->getOperand(0), orig, with);
3989 Value *rhs = replace(I->getOperand(1), orig, with);
3990 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
3991 return val;
3992 push(I);
3993 return pushcse(B.CreateXor(lhs, rhs, "sel." + I->getName()));
3994 }
3995
3996 if (I->getOpcode() == Instruction::FAdd) {
3997 Value *lhs = replace(I->getOperand(0), orig, with);
3998 Value *rhs = replace(I->getOperand(1), orig, with);
3999 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
4000 return val;
4001 push(I);
4002 return pushcse(B.CreateFAddFMF(lhs, rhs, I, "sel." + I->getName()));
4003 }
4004
4005 if (I->getOpcode() == Instruction::FSub) {
4006 Value *lhs = replace(I->getOperand(0), orig, with);
4007 Value *rhs = replace(I->getOperand(1), orig, with);
4008 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
4009 return val;
4010 push(I);
4011 return pushcse(B.CreateFSubFMF(lhs, rhs, I, "sel." + I->getName()));
4012 }
4013
4014 if (I->getOpcode() == Instruction::FMul) {
4015 Value *lhs = replace(I->getOperand(0), orig, with);
4016 Value *rhs = replace(I->getOperand(1), orig, with);
4017 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
4018 return val;
4019 push(I);
4020 return pushcse(B.CreateFMulFMF(lhs, rhs, I, "sel." + I->getName()));
4021 }
4022
4023 if (I->getOpcode() == Instruction::ZExt) {
4024 Value *op = replace(I->getOperand(0), orig, with);
4025 if (op == I->getOperand(0))
4026 return val;
4027 push(I);
4028 return pushcse(B.CreateZExt(op, I->getType(), "sel." + I->getName()));
4029 }
4030
4031 if (I->getOpcode() == Instruction::SExt) {
4032 Value *op = replace(I->getOperand(0), orig, with);
4033 if (op == I->getOperand(0))
4034 return val;
4035 push(I);
4036 return pushcse(B.CreateSExt(op, I->getType(), "sel." + I->getName()));
4037 }
4038
4039 if (I->getOpcode() == Instruction::UIToFP) {
4040 Value *op = replace(I->getOperand(0), orig, with);
4041 if (op == I->getOperand(0))
4042 return val;
4043 push(I);
4044 return pushcse(B.CreateUIToFP(op, I->getType(), "sel." + I->getName()));
4045 }
4046
4047 if (I->getOpcode() == Instruction::SIToFP) {
4048 Value *op = replace(I->getOperand(0), orig, with);
4049 if (op == I->getOperand(0))
4050 return val;
4051 push(I);
4052 return pushcse(B.CreateSIToFP(op, I->getType(), "sel." + I->getName()));
4053 }
4054
4055 if (auto CI = dyn_cast<CmpInst>(I)) {
4056 Value *lhs = replace(I->getOperand(0), orig, with);
4057 Value *rhs = replace(I->getOperand(1), orig, with);
4058 if (lhs == I->getOperand(0) && rhs == I->getOperand(1))
4059 return val;
4060 push(I);
4061 return pushcse(
4062 B.CreateCmp(CI->getPredicate(), lhs, rhs, "sel." + I->getName()));
4063 }
4064
4065 if (auto SI = dyn_cast<SelectInst>(I)) {
4066 Value *cond = replace(SI->getCondition(), orig, with);
4067 Value *tval = replace(SI->getTrueValue(), orig, with);
4068 Value *fval = replace(SI->getFalseValue(), orig, with);
4069 if (cond == SI->getCondition() && tval == SI->getTrueValue() &&
4070 fval == SI->getFalseValue())
4071 return val;
4072 push(I);
4073 if (auto CI = dyn_cast<ConstantInt>(cond)) {
4074 if (CI->isOne())
4075 return tval;
4076 else
4077 return fval;
4078 }
4079 return pushcse(B.CreateSelect(cond, tval, fval, "sel." + I->getName()));
4080 }
4081
4082 if (isProduct(I) || isSum(I)) {
4083 auto C = cast<CallBase>(I);
4084 auto ops = callOperands(C);
4085 bool changed = false;
4086 for (auto &op : ops) {
4087 auto next = replace(op, orig, with);
4088 if (next != op) {
4089 changed = true;
4090 op = next;
4091 }
4092 }
4093 if (!changed)
4094 return (Value *)I;
4095 push(I);
4096 pushcse(
4097 B.CreateCall(getFunctionFromCall(C), ops, "sel." + I->getName()));
4098 }
4099 }
4100 return val;
4101 };
4102
4103 if (auto II = dyn_cast<IntrinsicInst>(cur))
4104 if (II->getIntrinsicID() == Intrinsic::fmuladd ||
4105 II->getIntrinsicID() == Intrinsic::fma) {
4106 B.setFastMathFlags(getFast());
4107 auto mul = pushcse(B.CreateFMul(II->getOperand(0), II->getOperand(1)));
4108 auto add = pushcse(B.CreateFAdd(mul, II->getOperand(2)));
4109 replaceAndErase(cur, add);
4110 return "FMulAddExpand";
4111 }
4112
4113 if (auto BO = dyn_cast<BinaryOperator>(cur)) {
4114 if (BO->getOpcode() == Instruction::FMul && BO->isFast()) {
4115 Value *args[2] = {BO->getOperand(0), BO->getOperand(1)};
4116 auto mul = pushcse(
4117 B.CreateCall(getProductIntrinsic(*F.getParent(), BO->getType()), args,
4118 cur->getName()));
4119 replaceAndErase(cur, mul);
4120 return "FMulToProduct";
4121 }
4122 if (BO->getOpcode() == Instruction::FDiv && BO->isFast()) {
4123 auto c0 = dyn_cast<ConstantFP>(BO->getOperand(0));
4124 if (!c0 || !c0->isExactlyValue(1.0)) {
4125 B.setFastMathFlags(getFast());
4126 auto div = pushcse(B.CreateFDivFMF(ConstantFP::get(BO->getType(), 1.0),
4127 BO->getOperand(1), BO));
4128 auto mul = pushcse(
4129 B.CreateFMulFMF(BO->getOperand(0), div, BO, cur->getName()));
4130 replaceAndErase(cur, mul);
4131 return "FDivToFMul";
4132 }
4133 }
4134 if (BO->getOpcode() == Instruction::FAdd && BO->isFast()) {
4135 Value *args[2] = {BO->getOperand(0), BO->getOperand(1)};
4136 auto mul = pushcse(
4137 B.CreateCall(getSumIntrinsic(*F.getParent(), BO->getType()), args));
4138 replaceAndErase(cur, mul);
4139 return "FAddToSum";
4140 }
4141 if (BO->getOpcode() == Instruction::FSub && BO->isFast()) {
4142 B.setFastMathFlags(getFast());
4143 Value *args[2] = {BO->getOperand(0),
4144 pushcse(B.CreateFNeg(BO->getOperand(1)))};
4145 auto mul =
4146 pushcse(B.CreateCall(getSumIntrinsic(*F.getParent(), BO->getType()),
4147 args, cur->getName()));
4148 replaceAndErase(cur, mul);
4149 return "FAddToSum";
4150 }
4151 }
4152 if (cur->getOpcode() == Instruction::FNeg) {
4153 B.setFastMathFlags(getFast());
4154 auto mul =
4155 pushcse(B.CreateFMulFMF(ConstantFP::get(cur->getType(), -1.0),
4156 cur->getOperand(0), cur, cur->getName()));
4157 replaceAndErase(cur, mul);
4158 return "FNegToMul";
4159 }
4160
4161 if (auto SI = dyn_cast<SelectInst>(cur)) {
4162 if (auto tc = dyn_cast<ConstantFP>(SI->getTrueValue()))
4163 if (auto fc = dyn_cast<ConstantFP>(SI->getFalseValue()))
4164 if (fc->isZero()) {
4165 if (tc->isExactlyValue(1.0)) {
4166 auto res =
4167 pushcse(B.CreateUIToFP(SI->getCondition(), tc->getType()));
4168 replaceAndErase(cur, res);
4169 return "SelToUIFP";
4170 }
4171 if (tc->isExactlyValue(-1.0)) {
4172 auto res =
4173 pushcse(B.CreateSIToFP(SI->getCondition(), tc->getType()));
4174 replaceAndErase(cur, res);
4175 return "SelToSIFP";
4176 }
4177 }
4178 }
4179
4180 if (auto P = isProduct(cur)) {
4181 SmallVector<Value *, 1> operands;
4182 std::optional<APFloat> constval;
4183 bool changed = false;
4184 for (auto &v : callOperands(P))
4185
4186 {
4187 if (auto P2 = isProduct(v)) {
4188 for (auto &v2 : callOperands(P2)) {
4189 push(v2);
4190 operands.push_back(v2);
4191 }
4192 push(P2);
4193 changed = true;
4194 continue;
4195 }
4196 if (auto C = dyn_cast<ConstantFP>(v)) {
4197 if (C->isExactlyValue(1.0)) {
4198 changed = true;
4199 continue;
4200 }
4201 if (C->isZero()) {
4202 replaceAndErase(cur, C);
4203 return "ZeroProduct";
4204 }
4205 if (!constval) {
4206 constval = C->getValue();
4207 continue;
4208 }
4209 constval = (*constval) * C->getValue();
4210 changed = true;
4211 continue;
4212 }
4213 if (auto op = dyn_cast<SelectInst>(v)) {
4214 if (auto tc = dyn_cast<ConstantFP>(op->getTrueValue()))
4215 if (tc->isZero()) {
4216 operands.push_back(pushcse(B.CreateUIToFP(
4217 pushcse(B.CreateNot(op->getCondition())), op->getType())));
4218 operands.push_back(op->getFalseValue());
4219 changed = true;
4220 continue;
4221 }
4222 if (auto tc = dyn_cast<ConstantFP>(op->getFalseValue()))
4223 if (tc->isZero()) {
4224 operands.push_back(
4225 pushcse(B.CreateUIToFP(op->getCondition(), op->getType())));
4226 operands.push_back(op->getTrueValue());
4227 changed = true;
4228 continue;
4229 }
4230 }
4231 operands.push_back(v);
4232 }
4233 if (constval)
4234 operands.push_back(ConstantFP::get(cur->getType(), *constval));
4235
4236 if (operands.size() == 0) {
4237 replaceAndErase(cur, ConstantFP::get(cur->getType(), 1.0));
4238 return "EmptyProduct";
4239 }
4240 if (operands.size() == 1) {
4241 replaceAndErase(cur, operands[0]);
4242 return "SingleProduct";
4243 }
4244 if (changed) {
4245 auto mul = pushcse(
4246 B.CreateCall(getProductIntrinsic(*F.getParent(), cur->getType()),
4247 operands, cur->getName()));
4248 replaceAndErase(cur, mul);
4249 return "ProductSimplification";
4250 }
4251 }
4252
4253 if (auto P = isSum(cur)) {
4254 // map from operand, to number of counts
4255 std::map<Value *, unsigned> operands;
4256 std::optional<APFloat> constval;
4257 bool changed = false;
4258 for (auto &v : callOperands(P)) {
4259 if (auto P2 = isSum(v)) {
4260 for (auto &v2 : callOperands(P2)) {
4261 push(v2);
4262 operands[v2]++;
4263 }
4264 push(P2);
4265 changed = true;
4266 continue;
4267 }
4268 if (auto C = dyn_cast<ConstantFP>(v)) {
4269 if (C->isExactlyValue(0.0)) {
4270 changed = true;
4271 continue;
4272 }
4273 if (!constval) {
4274 constval = C->getValue();
4275 continue;
4276 }
4277 constval = (*constval) + C->getValue();
4278 changed = true;
4279 continue;
4280 }
4281 operands[v]++;
4282 }
4283 if (constval)
4284 operands[ConstantFP::get(cur->getType(), *constval)]++;
4285
4286 if (operands.size() == 0) {
4287 replaceAndErase(cur, ConstantFP::get(cur->getType(), 0.0));
4288 return "EmptySum";
4289 }
4290 SmallVector<Value *, 1> args;
4291 for (auto &pair : operands) {
4292 if (pair.second == 1) {
4293 args.push_back(pair.first);
4294 continue;
4295 }
4296 changed = true;
4297 Value *sargs[] = {pair.first,
4298 ConstantFP::get(cur->getType(), (double)pair.second)};
4299 args.push_back(pushcse(B.CreateCall(
4300 getProductIntrinsic(*F.getParent(), cur->getType()), sargs)));
4301 }
4302 if (args.size() == 1) {
4303 replaceAndErase(cur, args[0]);
4304 return "SingleSum";
4305 }
4306 if (changed) {
4307 auto sum =
4308 pushcse(B.CreateCall(getSumIntrinsic(*F.getParent(), cur->getType()),
4309 args, cur->getName()));
4310 replaceAndErase(cur, sum);
4311 return "SumSimplification";
4312 }
4313 }
4314
4315 if (auto P = isProduct(cur)) {
4316 SmallVector<Value *, 1> operands;
4317 SmallVector<Value *, 1> conditions;
4318 for (auto &v : callOperands(P)) {
4319 // z = uitofp i1 c to float -> select c, (prod withot z), 0
4320 if (auto op = dyn_cast<UIToFPInst>(v)) {
4321 if (op->getOperand(0)->getType()->isIntegerTy(1)) {
4322 conditions.push_back(op->getOperand(0));
4323 continue;
4324 }
4325 }
4326 // z = sitofp i1 c to float -> select c, (-prod withot z), 0
4327 if (auto op = dyn_cast<SIToFPInst>(v)) {
4328 if (op->getOperand(0)->getType()->isIntegerTy(1)) {
4329 conditions.push_back(op->getOperand(0));
4330 operands.push_back(ConstantFP::get(cur->getType(), -1.0));
4331 continue;
4332 }
4333 }
4334 if (auto op = dyn_cast<SelectInst>(v)) {
4335 if (auto tc = dyn_cast<ConstantFP>(op->getTrueValue()))
4336 if (tc->isZero()) {
4337 conditions.push_back(pushcse(B.CreateNot(op->getCondition())));
4338 operands.push_back(op->getFalseValue());
4339 continue;
4340 }
4341 if (auto tc = dyn_cast<ConstantFP>(op->getFalseValue()))
4342 if (tc->isZero()) {
4343 conditions.push_back(op->getCondition());
4344 operands.push_back(op->getTrueValue());
4345 continue;
4346 }
4347 }
4348 operands.push_back(v);
4349 }
4350
4351 if (conditions.size()) {
4352 auto mul = pushcse(B.CreateCall(
4353 getProductIntrinsic(*F.getParent(), cur->getType()), operands));
4354 Value *condition = nullptr;
4355 for (auto v : conditions) {
4356 assert(v->getType()->isIntegerTy(1));
4357 if (condition == nullptr) {
4358 condition = v;
4359 continue;
4360 }
4361 condition = pushcse(B.CreateAnd(condition, v));
4362 }
4363 auto zero = ConstantFP::get(cur->getType(), 0.0);
4364 auto sel = pushcse(B.CreateSelect(condition, mul, zero, cur->getName()));
4365 replaceAndErase(cur, sel);
4366 return "ProductSelect";
4367 }
4368 }
4369
4370 // TODO
4371 if (auto P = isSum(cur)) {
4372 // whether negated
4373 SmallVector<std::pair<Value *, bool>, 1> conditions;
4374 bool legal = true;
4375 for (auto &v : callOperands(P)) {
4376 // z = uitofp i1 c to float -> select c, (prod withot z), 0
4377 if (auto op = dyn_cast<UIToFPInst>(v)) {
4378 if (op->getOperand(0)->getType()->isIntegerTy(1)) {
4379 conditions.emplace_back(op->getOperand(0), false);
4380 continue;
4381 }
4382 }
4383 // z = sitofp i1 c to float -> select c, (-prod withot z), 0
4384 if (auto op = dyn_cast<SIToFPInst>(v)) {
4385 if (op->getOperand(0)->getType()->isIntegerTy(1)) {
4386 conditions.emplace_back(op->getOperand(0), false);
4387 continue;
4388 }
4389 }
4390 if (auto op = dyn_cast<SelectInst>(v)) {
4391 if (auto tc = dyn_cast<ConstantFP>(op->getTrueValue()))
4392 if (tc->isZero()) {
4393 conditions.emplace_back(op->getCondition(), true);
4394 continue;
4395 }
4396 if (auto tc = dyn_cast<ConstantFP>(op->getFalseValue()))
4397 if (tc->isZero()) {
4398 conditions.emplace_back(op->getCondition(), false);
4399 continue;
4400 }
4401 }
4402 legal = false;
4403 break;
4404 }
4405 Value *condition = nullptr;
4406 if (legal)
4407 for (size_t i = 0; i < conditions.size(); i++) {
4408 size_t count = 0;
4409 for (size_t j = 0; j < conditions.size(); j++) {
4410 if (((conditions[i].first == conditions[j].first) &&
4411 (conditions[i].second == conditions[i].second)) ||
4412 ((isNot(conditions[i].first, conditions[j].first) &&
4413 (conditions[i].second != conditions[i].second))))
4414 count++;
4415 }
4416 if (count == conditions.size() && count > 1) {
4417 condition = conditions[i].first;
4418 if (conditions[i].second)
4419 condition = pushcse(B.CreateNot(condition, "sumpnot"));
4420 break;
4421 }
4422 }
4423
4424 if (condition) {
4425
4426 SmallVector<Value *, 1> operands;
4427 for (auto &v : callOperands(P)) {
4428 // z = uitofp i1 c to float -> select c, (prod withot z), 0
4429 if (auto op = dyn_cast<UIToFPInst>(v)) {
4430 if (op->getOperand(0)->getType()->isIntegerTy(1)) {
4431 operands.push_back(ConstantFP::get(cur->getType(), 1.0));
4432 continue;
4433 }
4434 }
4435 // z = sitofp i1 c to float -> select c, (-prod withot z), 0
4436 if (auto op = dyn_cast<SIToFPInst>(v)) {
4437 if (op->getOperand(0)->getType()->isIntegerTy(1)) {
4438 operands.push_back(ConstantFP::get(cur->getType(), -1.0));
4439 continue;
4440 }
4441 }
4442 if (auto op = dyn_cast<SelectInst>(v)) {
4443 if (auto tc = dyn_cast<ConstantFP>(op->getTrueValue()))
4444 if (tc->isZero()) {
4445 operands.push_back(op->getFalseValue());
4446 continue;
4447 }
4448 if (auto tc = dyn_cast<ConstantFP>(op->getFalseValue()))
4449 if (tc->isZero()) {
4450 operands.push_back(op->getTrueValue());
4451 continue;
4452 }
4453 }
4454 llvm::errs() << " unhandled call op sumselect: " << *v << "\n";
4455 assert(0);
4456 }
4457
4458 if (conditions.size()) {
4459 auto sum = pushcse(B.CreateCall(
4460 getSumIntrinsic(*F.getParent(), cur->getType()), operands));
4461 auto zero = ConstantFP::get(cur->getType(), 0.0);
4462 auto sel =
4463 pushcse(B.CreateSelect(condition, sum, zero, cur->getName()));
4464 replaceAndErase(cur, sel);
4465 return "SumSelect";
4466 }
4467 }
4468 }
4469 // (a1*b1) + (a1*c1) + (a1*d1 ) + ... -> a1 * (b1 + c1 + d1 + ...)
4470 if (auto S = isSum(cur)) {
4471 SmallVector<Value *, 1> allOps;
4472 auto combine = [](const SmallVector<Value *, 1> &lhs,
4473 SmallVector<Value *, 1> rhs) {
4474 SmallVector<Value *, 1> out;
4475 for (auto v : lhs) {
4476 bool seen = false;
4477 for (auto &v2 : rhs) {
4478 if (v == v2) {
4479 v2 = nullptr;
4480 seen = true;
4481 break;
4482 }
4483 }
4484 if (seen) {
4485 out.push_back(v);
4486 }
4487 }
4488 return out;
4489 };
4490 auto subtract = [](SmallVector<Value *, 1> lhs,
4491 const SmallVector<Value *, 1> &rhs) {
4492 for (auto v : rhs) {
4493 auto found = find(lhs, v);
4494 assert(found != lhs.end());
4495 lhs.erase(found);
4496 }
4497 return lhs;
4498 };
4499 bool seen = false;
4500 bool legal = true;
4501 for (auto op : callOperands(S)) {
4502 auto P = isProduct(op);
4503 if (!P) {
4504 legal = false;
4505 break;
4506 }
4507 if (!seen) {
4508 allOps = callOperands(P);
4509 seen = true;
4510 continue;
4511 }
4512 allOps = combine(allOps, callOperands(P));
4513 }
4514
4515 if (legal && allOps.size() > 0) {
4516 SmallVector<Value *, 1> operands;
4517 for (auto op : callOperands(S)) {
4518 auto P = isProduct(op);
4519 push(op);
4520 auto sub = subtract(callOperands(P), allOps);
4521 auto newprod = pushcse(B.CreateCall(
4522 getProductIntrinsic(*F.getParent(), S->getType()), sub));
4523 operands.push_back(newprod);
4524 }
4525 auto newsum = pushcse(B.CreateCall(
4526 getSumIntrinsic(*F.getParent(), S->getType()), operands));
4527 allOps.push_back(newsum);
4528 auto fprod = pushcse(B.CreateCall(
4529 getProductIntrinsic(*F.getParent(), S->getType()), allOps));
4530 replaceAndErase(cur, fprod);
4531 return "SumFactor";
4532 }
4533 }
4534
4535 /*
4536 // add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1
4537 != c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2;
4538 j++) if (auto c0 = dyn_cast<ZExtInst>(cur->getOperand(j))) if (auto cmp0 =
4539 dyn_cast<ICmpInst>(c0->getOperand(0))) if (auto c1 =
4540 dyn_cast<CastInst>(cur->getOperand(1-j))) if (auto cmp1 =
4541 dyn_cast<ICmpInst>(c0->getOperand(0))) if (cmp0->getPredicate() ==
4542 ICmpInst::ICMP_EQ && cmp1->getPredicate() == ICmpInst::ICMP_EQ)
4543 {
4544 for (size_t i0 = 0; i0 < 2; i0++)
4545 for (size_t i1 = 0; i1 < 2; i1++)
4546 if (cmp0->getOperand(1 - i0) == cmp1->getOperand(1 - i1))
4547 auto e0 = SE.getSCEV(cmp0->getOperand(i0));
4548 auto e1 = SE.getSCEV(cmp1->getOperand(i1));
4549 auto m = SE.getMinusSCEV(e0, e1, SCEV::NoWrapMask);
4550 if (auto C = dyn_cast<SCEVConstant>(m)) {
4551 // if c1 == c2 don't need the and they are equivalent
4552 if (C->getValue()->isZero()) {
4553 } else {
4554 auto sel0 = pushcse(B.CreateSelect(cmp0,
4555 ConstantInt::get(cur->getType(), isa<ZExtInst>(cmp0) ? 1 : -1),
4556 ConstantInt::get(cur->getType(), 0));
4557 // if non one constant they must be distinct.
4558 replaceAndErase(cur,
4559 ConstantInt::getFalse(cur->getContext()));
4560 return "AndNEExpr";
4561 }
4562 }
4563 }
4564 }
4565 */
4566
4567 if (auto fcmp = dyn_cast<FCmpInst>(cur)) {
4568 auto predicate = fcmp->getPredicate();
4569 if (predicate == FCmpInst::FCMP_OEQ || predicate == FCmpInst::FCMP_UEQ ||
4570 predicate == FCmpInst::FCMP_UNE || predicate == FCmpInst::FCMP_ONE) {
4571 for (int i = 0; i < 2; i++)
4572 if (auto C = dyn_cast<ConstantFP>(fcmp->getOperand(i))) {
4573 if (C->isZero()) {
4574 // (a1*a2*...an) == 0 -> (a1 == 0) || (a2 == 0) || ... (a2 == 0)
4575 // (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... (a2 ==
4576 // 0)
4577 // ]
4578 if (auto P = isProduct(fcmp->getOperand(1 - i))) {
4579 Value *res = nullptr;
4580
4581 auto eq_predicate = predicate;
4582 if (predicate == FCmpInst::FCMP_UNE ||
4583 predicate == FCmpInst::FCMP_ONE)
4584 eq_predicate = fcmp->getInversePredicate();
4585
4586 for (auto &v : callOperands(P)) {
4587 auto ncmp1 = pushcse(B.CreateFCmp(eq_predicate, v, C));
4588 if (!res)
4589 res = ncmp1;
4590 else
4591 res = pushcse(B.CreateOr(res, ncmp1));
4592 }
4593
4594 if (predicate == FCmpInst::FCMP_UNE ||
4595 predicate == FCmpInst::FCMP_ONE) {
4596 res = pushcse(B.CreateNot(res));
4597 }
4598
4599 replaceAndErase(cur, res);
4600 return "CmpProductSplit";
4601 }
4602
4603 // (a1*b1) + (a1*c1) + (a1*d1 ) + ... ?= 0 -> a1 * (b1 + c1 + d1 +
4604 // ...) ?= 0
4605 if (auto S = isSum(fcmp->getOperand(1 - i))) {
4606 SmallVector<Value *, 1> allOps;
4607 auto combine = [](const SmallVector<Value *, 1> &lhs,
4608 SmallVector<Value *, 1> rhs) {
4609 SmallVector<Value *, 1> out;
4610 for (auto v : lhs) {
4611 bool seen = false;
4612 for (auto &v2 : rhs) {
4613 if (v == v2) {
4614 v2 = nullptr;
4615 seen = true;
4616 break;
4617 }
4618 }
4619 if (seen) {
4620 out.push_back(v);
4621 }
4622 }
4623 return out;
4624 };
4625 auto subtract = [](SmallVector<Value *, 1> lhs,
4626 const SmallVector<Value *, 1> &rhs) {
4627 for (auto v : rhs) {
4628 auto found = find(lhs, v);
4629 assert(found != lhs.end());
4630 lhs.erase(found);
4631 }
4632 return lhs;
4633 };
4634 bool seen = false;
4635 bool legal = true;
4636 for (auto op : callOperands(S)) {
4637 auto P = isProduct(op);
4638 if (!P) {
4639 legal = false;
4640 break;
4641 }
4642 if (!seen) {
4643 allOps = callOperands(P);
4644 seen = true;
4645 continue;
4646 }
4647 allOps = combine(allOps, callOperands(P));
4648 }
4649
4650 if (legal && allOps.size() > 0) {
4651 SmallVector<Value *, 1> operands;
4652 for (auto op : callOperands(S)) {
4653 auto P = isProduct(op);
4654 push(op);
4655 auto sub = subtract(callOperands(P), allOps);
4656 auto newprod = pushcse(B.CreateCall(
4657 getProductIntrinsic(*F.getParent(), C->getType()), sub));
4658 operands.push_back(newprod);
4659 }
4660 auto newsum = pushcse(B.CreateCall(
4661 getSumIntrinsic(*F.getParent(), C->getType()), operands));
4662 allOps.push_back(newsum);
4663 auto fprod = pushcse(B.CreateCall(
4664 getProductIntrinsic(*F.getParent(), C->getType()), allOps));
4665 auto fcmp = pushcse(B.CreateCmp(predicate, fprod, C));
4666 replaceAndErase(cur, fcmp);
4667 return "CmpSumFactor";
4668 }
4669 }
4670 }
4671 }
4672 }
4673 }
4674
4675 if (auto fcmp = dyn_cast<FCmpInst>(cur)) {
4676 auto predicate = fcmp->getPredicate();
4677 if (predicate == FCmpInst::FCMP_OEQ || predicate == FCmpInst::FCMP_UEQ ||
4678 predicate == FCmpInst::FCMP_UNE || predicate == FCmpInst::FCMP_ONE) {
4679 for (int i = 0; i < 2; i++)
4680 if (auto C = dyn_cast<ConstantFP>(fcmp->getOperand(i))) {
4681 if (C->isZero()) {
4682 // a + b == 0 -> ( (a == 0 & b == 0) || a == -b)
4683 if (auto S = isSum(fcmp->getOperand(1 - i))) {
4684 auto allOps = callOperands(S);
4685 if (!llvm::any_of(allOps, guaranteedDataDependent)) {
4686 auto eq_predicate = predicate;
4687 if (predicate == FCmpInst::FCMP_UNE ||
4688 predicate == FCmpInst::FCMP_ONE)
4689 eq_predicate = fcmp->getInversePredicate();
4690
4691 Value *op_checks = nullptr;
4692 for (auto a : allOps) {
4693 auto a_e0 = pushcse(B.CreateFCmp(eq_predicate, a, C));
4694 if (op_checks == nullptr)
4695 op_checks = a_e0;
4696 else
4697 op_checks = pushcse(B.CreateAnd(op_checks, a_e0));
4698 }
4699 SmallVector<Value *, 1> slice;
4700 for (size_t i = 1; i < allOps.size(); i++)
4701 slice.push_back(allOps[i]);
4702 auto ane = pushcse(B.CreateFCmp(
4703 eq_predicate, pushcse(B.CreateFNeg(allOps[0])),
4704 pushcse(B.CreateCall(getFunctionFromCall(S), slice))));
4705 auto ori = pushcse(B.CreateOr(op_checks, ane));
4706 if (predicate == FCmpInst::FCMP_UNE ||
4707 predicate == FCmpInst::FCMP_ONE) {
4708 ori = pushcse(B.CreateNot(ori));
4709 }
4710 replaceAndErase(cur, ori);
4711 return "Sum2ZeroSplit";
4712 }
4713 }
4714 }
4715 }
4716 }
4717 }
4718
4719 // (zext a) + (zext b) ?= 0 -> zext a ?= - zext b
4720 if (auto icmp = dyn_cast<CmpInst>(cur)) {
4721 if (icmp->getPredicate() == CmpInst::ICMP_EQ ||
4722 icmp->getPredicate() == CmpInst::ICMP_NE) {
4723 for (int i = 0; i < 2; i++)
4724 if (auto C = dyn_cast<ConstantInt>(icmp->getOperand(i)))
4725 if (C->isZero())
4726 if (auto add = dyn_cast<BinaryOperator>(icmp->getOperand(1 - i)))
4727 if (add->getOpcode() == Instruction::Add)
4728 if (auto a0 = dyn_cast<CastInst>(add->getOperand(0)))
4729 if (auto a1 = dyn_cast<CastInst>(add->getOperand(1)))
4730 if (a0->getOperand(0)->getType() ==
4731 a1->getOperand(0)->getType() &&
4732 (isa<ZExtInst>(a0) || isa<SExtInst>(a0))) {
4733 auto cmp2 = pushcse(B.CreateCmp(
4734 icmp->getPredicate(), a0, pushcse(B.CreateNeg(a1))));
4735 replaceAndErase(cur, cmp2);
4736 return "CmpExt0Shuffle";
4737 }
4738 }
4739 }
4740
4741 // sub 0, (zext i1 to N) -> sext i1 to N
4742 // sub 0, (sext i1 to N) -> zext i1 to N
4743 if (auto sub = dyn_cast<BinaryOperator>(cur))
4744 if (sub->getOpcode() == Instruction::Sub)
4745 if (auto C = dyn_cast<ConstantInt>(sub->getOperand(0)))
4746 if (C->isZero())
4747 if (auto a0 = dyn_cast<CastInst>(sub->getOperand(1)))
4748 if (a0->getOperand(0)->getType()->isIntegerTy(1)) {
4749
4750 Value *tmp = nullptr;
4751 if (isa<ZExtInst>(a0))
4752 tmp = pushcse(B.CreateSExt(a0->getOperand(0), a0->getType()));
4753 else if (isa<SExtInst>(a0))
4754 tmp = pushcse(B.CreateZExt(a0->getOperand(0), a0->getType()));
4755 else
4756 assert(0);
4757 replaceAndErase(cur, tmp);
4758 return "NegSZExtI1";
4759 }
4760
4761 if ((cur->getOpcode() == Instruction::LShr ||
4762 cur->getOpcode() == Instruction::SDiv ||
4763 cur->getOpcode() == Instruction::UDiv) &&
4764 cur->isExact())
4765 if (auto C2 = dyn_cast<ConstantInt>(cur->getOperand(1)))
4766 if (auto mul = dyn_cast<BinaryOperator>(cur->getOperand(0))) {
4767 // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if
4768 // C2 divides C1
4769 if (mul->getOpcode() == Instruction::Mul)
4770 for (int i0 = 0; i0 < 2; i0++)
4771 if (auto C1 = dyn_cast<ConstantInt>(mul->getOperand(i0))) {
4772 auto lhs = C1->getValue();
4773 APInt rhs = C2->getValue();
4774 if (cur->getOpcode() == Instruction::LShr) {
4775 rhs = APInt(rhs.getBitWidth(), 1) << rhs;
4776 }
4777
4778 APInt div, rem;
4779 if (cur->getOpcode() == Instruction::LShr ||
4780 cur->getOpcode() == Instruction::UDiv)
4781 APInt::udivrem(lhs, rhs, div, rem);
4782 else
4783 APInt::sdivrem(lhs, rhs, div, rem);
4784 if (rem == 0) {
4785 auto res = pushcse(B.CreateMul(
4786 mul->getOperand(1 - i0),
4787 ConstantInt::get(cur->getType(), div),
4788 "mdiv." + cur->getName(), mul->hasNoUnsignedWrap(),
4789 mul->hasNoSignedWrap()));
4790 push(mul);
4791 replaceAndErase(cur, res);
4792 return "IMulDivConst";
4793 }
4794 }
4795 // (lshr exact (add a, C1), C2), C -> add a, (lhsr exact C1, C2) if
4796 // C2
4797 if (mul->getOpcode() == Instruction::Add)
4798 for (int i0 = 0; i0 < 2; i0++)
4799 if (auto C1 = dyn_cast<ConstantInt>(mul->getOperand(i0))) {
4800 auto lhs = C1->getValue();
4801 APInt rhs = C2->getValue();
4802 if (cur->getOpcode() == Instruction::LShr) {
4803 rhs = APInt(rhs.getBitWidth(), 1) << rhs;
4804 }
4805
4806 APInt div, rem;
4807 if (cur->getOpcode() == Instruction::LShr ||
4808 cur->getOpcode() == Instruction::UDiv)
4809 APInt::udivrem(lhs, rhs, div, rem);
4810 else
4811 APInt::sdivrem(lhs, rhs, div, rem);
4812 if (rem == 0 && ((mul->hasNoUnsignedWrap() &&
4813 (cur->getOpcode() == Instruction::LShr ||
4814 cur->getOpcode() == Instruction::UDiv)) ||
4815 (mul->hasNoSignedWrap() &&
4816 (cur->getOpcode() == Instruction::AShr ||
4817 cur->getOpcode() == Instruction::SDiv)))) {
4818 auto res = pushcse(B.CreateAdd(
4819 mul->getOperand(1 - i0),
4820 ConstantInt::get(cur->getType(), div),
4821 "madd." + cur->getName(), mul->hasNoUnsignedWrap(),
4822 mul->hasNoSignedWrap()));
4823 push(mul);
4824 replaceAndErase(cur, res);
4825 return "IAddDivConst";
4826 }
4827 }
4828 }
4829
4830 // mul (mul a, const1), (mul b, const2) -> mul (mul a, b), (const1, const2)
4831 if (cur->getOpcode() == Instruction::FMul)
4832 if (cur->isFast())
4833 if (auto mul1 = dyn_cast<Instruction>(cur->getOperand(0)))
4834 if (mul1->getOpcode() == Instruction::FMul && mul1->isFast())
4835 if (auto mul2 = dyn_cast<Instruction>(cur->getOperand(1)))
4836 if (mul2->getOpcode() == Instruction::FMul && mul2->isFast()) {
4837 for (auto i1 = 0; i1 < 2; i1++)
4838 for (auto i2 = 0; i2 < 2; i2++)
4839 if (isa<Constant>(mul1->getOperand(i1)))
4840 if (isa<Constant>(mul2->getOperand(i2))) {
4841
4842 auto n0 = pushcse(
4843 B.CreateFMulFMF(mul1->getOperand(1 - i1),
4844 mul2->getOperand(1 - i2), cur));
4845 auto n1 = pushcse(B.CreateFMulFMF(
4846 mul1->getOperand(i1), mul2->getOperand(i2), cur));
4847 auto n2 = pushcse(B.CreateFMulFMF(n0, n1, cur));
4848 push(mul1);
4849 push(mul2);
4850 replaceAndErase(cur, n2);
4851 return "MulMulConstConst";
4852 }
4853 }
4854
4855 // mul (mul a, const1), const2 -> mul a, (mul const1, const2)
4856 if ((cur->getOpcode() == Instruction::FMul && cur->isFast()) ||
4857 cur->getOpcode() == Instruction::Mul)
4858 for (auto i1 = 0; i1 < 2; i1++)
4859 if (auto mul1 = dyn_cast<Instruction>(cur->getOperand(i1)))
4860 if (((mul1->getOpcode() == Instruction::FMul && mul1->isFast())) ||
4861 mul1->getOpcode() == Instruction::FMul)
4862 if (auto const2 = dyn_cast<Constant>(cur->getOperand(1 - i1)))
4863 for (auto i2 = 0; i2 < 2; i2++)
4864 if (auto const1 = dyn_cast<Constant>(mul1->getOperand(i2))) {
4865 Value *res = nullptr;
4866 if (cur->getOpcode() == Instruction::FMul) {
4867 auto const3 = pushcse(B.CreateFMulFMF(const1, const2, mul1));
4868 res = pushcse(
4869 B.CreateFMulFMF(mul1->getOperand(1 - i2), const3, cur));
4870 } else {
4871 auto const3 = pushcse(B.CreateMul(const1, const2));
4872 res = pushcse(B.CreateMul(mul1->getOperand(1 - i2), const3));
4873 }
4874 push(mul1);
4875 replaceAndErase(cur, res);
4876 return "MulConstConst";
4877 }
4878
4879 if (auto fcmp = dyn_cast<FCmpInst>(cur)) {
4880 if (fcmp->getPredicate() == FCmpInst::FCMP_OEQ) {
4881 for (int i = 0; i < 2; i++)
4882 if (auto C = dyn_cast<ConstantFP>(fcmp->getOperand(i))) {
4883 if (C->isZero()) {
4884 if (auto fmul = dyn_cast<BinaryOperator>(fcmp->getOperand(1 - i))) {
4885 // (a*b) == 0 -> (a == 0) || (b == 0)
4886 if (fmul->getOpcode() == Instruction::FMul) {
4887 auto ncmp1 = pushcse(
4888 B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C));
4889 auto ncmp2 = pushcse(
4890 B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(1), C));
4891 auto ori = pushcse(B.CreateOr(ncmp1, ncmp2));
4892 replaceAndErase(cur, ori);
4893 return "CmpFMulSplit";
4894 }
4895 // (a/b) == 0 -> (a == 0)
4896 if (fmul->getOpcode() == Instruction::FDiv) {
4897 auto ncmp1 = pushcse(
4898 B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C));
4899 replaceAndErase(cur, ncmp1);
4900 return "CmpFDivSplit";
4901 }
4902 // (a - b) ?= 0 -> a ?= b
4903 if (fmul->getOpcode() == Instruction::FSub) {
4904 auto ncmp1 = pushcse(B.CreateFCmp(fcmp->getPredicate(),
4905 fmul->getOperand(0),
4906 fmul->getOperand(1)));
4907 replaceAndErase(cur, ncmp1);
4908 return "CmpFSubSplit";
4909 }
4910 }
4911 if (auto cast = dyn_cast<SIToFPInst>(fcmp->getOperand(1 - i))) {
4912 auto ncmp1 = pushcse(B.CreateICmp(
4913 ICmpInst::ICMP_EQ, cast->getOperand(0),
4914 ConstantInt::get(cast->getOperand(0)->getType(), 0)));
4915 replaceAndErase(cur, ncmp1);
4916 return "SFCmpToICmp";
4917 }
4918 if (auto cast = dyn_cast<UIToFPInst>(fcmp->getOperand(1 - i))) {
4919 auto ncmp1 = pushcse(B.CreateICmp(
4920 ICmpInst::ICMP_EQ, cast->getOperand(0),
4921 ConstantInt::get(cast->getOperand(0)->getType(), 0)));
4922 replaceAndErase(cur, ncmp1);
4923 return "UFCmpToICmp";
4924 }
4925 if (auto SI = dyn_cast<SelectInst>(fcmp->getOperand(1 - i))) {
4926 auto res = pushcse(
4927 B.CreateSelect(SI->getCondition(),
4928 pushcse(B.CreateCmp(fcmp->getPredicate(), C,
4929 SI->getTrueValue())),
4930 pushcse(B.CreateCmp(fcmp->getPredicate(), C,
4931 SI->getFalseValue()))));
4932 replaceAndErase(cur, res);
4933 return "FCmpSelect";
4934 }
4935 }
4936 }
4937 }
4938 }
4939 if (auto fcmp = dyn_cast<CmpInst>(cur)) {
4940 if (fcmp->getPredicate() == CmpInst::ICMP_EQ ||
4941 fcmp->getPredicate() == CmpInst::ICMP_NE ||
4942 fcmp->getPredicate() == CmpInst::FCMP_OEQ ||
4943 fcmp->getPredicate() == CmpInst::FCMP_ONE) {
4944
4945 // a + c ?= a -> c ?= 0 , if fast
4946 for (int i = 0; i < 2; i++)
4947 if (auto inst = dyn_cast<Instruction>(fcmp->getOperand(i)))
4948 if (inst->getOpcode() == Instruction::FAdd && inst->isFast())
4949 for (int i2 = 0; i2 < 2; i2++)
4950 if (inst->getOperand(i2) == fcmp->getOperand(1 - i)) {
4951 auto res = pushcse(
4952 B.CreateCmp(fcmp->getPredicate(), inst->getOperand(1 - i2),
4953 ConstantFP::get(inst->getType(), 0)));
4954 replaceAndErase(cur, res);
4955 return "CmpFAddSame";
4956 }
4957
4958 // a == b -> a & b | !a & !b
4959 // a != b -> a & !b | !a & b
4960 if (fcmp->getOperand(0)->getType()->isIntegerTy(1)) {
4961 auto a = fcmp->getOperand(0);
4962 auto b = fcmp->getOperand(1);
4963 if (fcmp->getPredicate() == CmpInst::ICMP_EQ) {
4964 auto res = pushcse(
4965 B.CreateOr(pushcse(B.CreateAnd(a, b)),
4966 pushcse(B.CreateAnd(pushcse(B.CreateNot(a)),
4967 pushcse(B.CreateNot(b))))));
4968 replaceAndErase(cur, res);
4969 return "CmpI1EQ";
4970 }
4971 if (fcmp->getPredicate() == CmpInst::ICMP_NE) {
4972 auto res = pushcse(
4973 B.CreateOr(pushcse(B.CreateAnd(pushcse(B.CreateNot(a)), b)),
4974 pushcse(B.CreateAnd(a, pushcse(B.CreateNot(b))))));
4975 replaceAndErase(cur, res);
4976 return "CmpI1NE";
4977 }
4978 }
4979
4980 for (int i = 0; i < 2; i++)
4981 if (auto CI = dyn_cast<ConstantInt>(fcmp->getOperand(i)))
4982 if (CI->isZero()) {
4983 // a + a ?= 0 -> a ?= 0
4984 if (auto addI = dyn_cast<Instruction>(fcmp->getOperand(1 - i))) {
4985 if (addI->getOpcode() == Instruction::Add &&
4986 addI->getOperand(0) == addI->getOperand(1)) {
4987 Value *res = pushcse(
4988 B.CreateCmp(fcmp->getPredicate(), addI->getOperand(0), CI));
4989 replaceAndErase(cur, res);
4990 return "CmpAddAdd";
4991 }
4992 // (a-b) ?= 0 -> a ?= b
4993 if (addI->getOpcode() == Instruction::Sub) {
4994 auto ncmp1 = pushcse(B.CreateICmp(fcmp->getPredicate(),
4995 addI->getOperand(0),
4996 addI->getOperand(1)));
4997 replaceAndErase(cur, ncmp1);
4998 return "CmpISubSplit";
4999 }
5000 }
5001 }
5002
5003 // (a * b) == (c * b) -> (a == c) || b == 0
5004 // (a * b) != (c * b) -> (a != c) && b != 0
5005 // auto S1 = SE.getSCEV(cur->getOperand(0));
5006 // auto S2 = SE.getSCEV(cur->getOperand(1));
5007 // llvm::errs() <<" attempting push: " << *cur << " S1: " << *S1 << "
5008 // S2: " << *S2 << " and " << *cur->getOperand(0) << " " <<
5009 // *cur->getOperand(1) << "\n";
5010 if (auto mul1 = dyn_cast<Instruction>(cur->getOperand(0)))
5011 if (auto mul2 = dyn_cast<Instruction>(cur->getOperand(1))) {
5012 if (mul1->getOpcode() == Instruction::Mul &&
5013 mul2->getOpcode() == Instruction::Mul &&
5014 mul1->hasNoUnsignedWrap() && mul1->hasNoSignedWrap() &&
5015 mul2->hasNoUnsignedWrap() && mul2->hasNoSignedWrap()) {
5016 for (int i = 0; i < 2; i++) {
5017 if (mul1->getOperand(i) == mul2->getOperand(i)) {
5018 Value *res = pushcse(B.CreateICmp(fcmp->getPredicate(),
5019 mul1->getOperand(1 - i),
5020 mul2->getOperand(1 - i)));
5021 auto b = mul1->getOperand(i);
5022 if (fcmp->getPredicate() == CmpInst::ICMP_EQ) {
5023 Value *bZero = pushcse(B.CreateICmp(
5024 CmpInst::ICMP_EQ, b, ConstantInt::get(b->getType(), 0)));
5025 res = pushcse(B.CreateOr(res, bZero));
5026 } else {
5027 Value *bZero = pushcse(B.CreateICmp(
5028 ICmpInst::ICMP_NE, b, ConstantInt::get(b->getType(), 0)));
5029 res = pushcse(B.CreateAnd(res, bZero));
5030 }
5031 replaceAndErase(cur, res);
5032 return "CmpMulCommon";
5033 }
5034 }
5035 }
5036 // same as above but now with floats
5037 if (mul1->getOpcode() == Instruction::FMul &&
5038 mul2->getOpcode() == Instruction::FMul && mul1->isFast() &&
5039 mul2->isFast()) {
5040 for (int i = 0; i < 2; i++) {
5041 if (mul1->getOperand(i) == mul2->getOperand(i)) {
5042 Value *res = pushcse(B.CreateFCmp(fcmp->getPredicate(),
5043 mul1->getOperand(1 - i),
5044 mul2->getOperand(1 - i)));
5045 auto b = mul1->getOperand(i);
5046 if (fcmp->getPredicate() == CmpInst::FCMP_OEQ) {
5047 Value *bZero = pushcse(B.CreateCmp(
5048 CmpInst::FCMP_OEQ, b, ConstantFP::get(b->getType(), 0)));
5049 res = pushcse(B.CreateOr(res, bZero));
5050 } else {
5051 Value *bZero = pushcse(B.CreateCmp(
5052 CmpInst::FCMP_ONE, b, ConstantFP::get(b->getType(), 0)));
5053 res = pushcse(B.CreateAnd(res, bZero));
5054 }
5055 replaceAndErase(cur, res);
5056 return "CmpMulfCommon";
5057 }
5058 }
5059 }
5060
5061 // (uitofp a ) ?= (uitofp b) -> a ?= b
5062 for (auto cond : {Instruction::UIToFP, Instruction::SIToFP})
5063 if (mul1->getOpcode() == cond && mul2->getOpcode() == cond &&
5064 mul1->getOperand(0)->getType() ==
5065 mul2->getOperand(0)->getType()) {
5066 Value *res = pushcse(B.CreateICmp(
5067 fcmp->getPredicate() == CmpInst::FCMP_OEQ ? CmpInst::ICMP_EQ
5068 : CmpInst::ICMP_NE,
5069 mul1->getOperand(0), mul2->getOperand(0)));
5070 replaceAndErase(cur, res);
5071 return "CmpUIToFP";
5072 }
5073
5074 // (zext a ) ?= (zext b) -> a ?= b
5075 if (mul1->getOpcode() == Instruction::ZExt &&
5076 mul2->getOpcode() == Instruction::ZExt &&
5077 mul1->getOperand(0)->getType() ==
5078 mul2->getOperand(0)->getType()) {
5079 Value *res =
5080 pushcse(B.CreateICmp(fcmp->getPredicate(), mul1->getOperand(0),
5081 mul2->getOperand(0)));
5082 replaceAndErase(cur, res);
5083 return "CmpZExt";
5084 }
5085
5086 // (zext i1 a ) == (sext i1 b) -> (!a & !b)
5087 // (zext i1 a ) != (sext i1 b) -> (a | b)
5088 if (auto mul1 = dyn_cast<Instruction>(cur->getOperand(0)))
5089 if (auto mul2 = dyn_cast<Instruction>(cur->getOperand(1)))
5090 if (((mul1->getOpcode() == Instruction::ZExt &&
5091 mul2->getOpcode() == Instruction::SExt) ||
5092 (mul1->getOpcode() == Instruction::SExt &&
5093 mul2->getOpcode() == Instruction::ZExt)) &&
5094 mul1->getOperand(0)->getType() ==
5095 mul2->getOperand(0)->getType() &&
5096 mul1->getOperand(0)->getType()->isIntegerTy(1)) {
5097
5098 Value *na = mul1->getOperand(0);
5099 Value *nb = mul2->getOperand(0);
5100
5101 if (fcmp->getPredicate() == ICmpInst::ICMP_EQ) {
5102 na = pushcse(B.CreateNot(na));
5103 nb = pushcse(B.CreateNot(nb));
5104 }
5105
5106 Value *res = nullptr;
5107 if (fcmp->getPredicate() == ICmpInst::ICMP_EQ)
5108 res = pushcse(B.CreateAnd(na, nb));
5109 else
5110 res = pushcse(B.CreateOr(na, nb));
5111
5112 replaceAndErase(cur, res);
5113 return "CmpZExtSExt";
5114 }
5115 }
5116 }
5117 if (fcmp->getPredicate() == ICmpInst::ICMP_EQ) {
5118 for (int i = 0; i < 2; i++) {
5119 if (auto C = dyn_cast<ConstantInt>(fcmp->getOperand(i))) {
5120 if (C->isZero()) {
5121 if (auto fmul = dyn_cast<BinaryOperator>(fcmp->getOperand(1 - i))) {
5122 // (a*b) == 0 -> (a == 0) || (b == 0)
5123 if (fmul->getOpcode() == Instruction::Mul) {
5124 auto ncmp1 = pushcse(
5125 B.CreateICmp(fcmp->getPredicate(), fmul->getOperand(0), C));
5126 auto ncmp2 = pushcse(
5127 B.CreateICmp(fcmp->getPredicate(), fmul->getOperand(1), C));
5128 auto ori = pushcse(B.CreateOr(ncmp1, ncmp2));
5129 replaceAndErase(cur, ori);
5130 return "CmpIMulSplit";
5131 }
5132 }
5133 }
5134 }
5135 }
5136 }
5137 }
5138
5139 if (cur->getOpcode() == Instruction::FAdd) {
5140 // add x, x -> mul 2.0
5141 if (cur->getOperand(0) == cur->getOperand(1) && cur->isFast()) {
5142 auto res = pushcse(B.CreateFMulFMF(
5143 cur->getOperand(0), ConstantFP::get(cur->getType(), 2.0), cur));
5144 replaceAndErase(cur, res);
5145 return "AddToMul2";
5146 }
5147 }
5148
5149 if (cur->getOpcode() == Instruction::Add) {
5150 // add x, (y * -1) -> sub x, y
5151 for (int i = 0; i < 2; i++) {
5152 if (auto mul1 = dyn_cast<Instruction>(cur->getOperand(i)))
5153 if (mul1->getOpcode() == Instruction::Mul) {
5154 for (int j = 0; j < 2; j++) {
5155 if (auto C = dyn_cast<ConstantInt>(mul1->getOperand(j))) {
5156 if (C->isMinusOne()) {
5157 auto res = pushcse(B.CreateSub(cur->getOperand(1 - i),
5158 mul1->getOperand(1 - j)));
5159 push(mul1);
5160
5161 replaceAndErase(cur, res);
5162 return "AddToSub";
5163 }
5164 }
5165 }
5166 }
5167 }
5168 }
5169
5170 if (auto SI = dyn_cast<SelectInst>(cur)) {
5171 auto shouldMove = [](Value *v) { return isa<Constant>(v); };
5172
5173 /*
5174 // select c, 0, x -> fmul (uitofp (!c)), x
5175 if (auto C1 = dyn_cast<ConstantFP>(SI->getTrueValue())) {
5176 if (C1->isZero()) {
5177 auto n = pushcse(B.CreateNot(SI->getCondition()));
5178 auto val = pushcse(B.CreateUIToFP(n, SI->getType()));
5179 auto res = pushcse(B.CreateFMul(val, SI->getFalseValue()));
5180 if (auto I = dyn_cast<Instruction>(res))
5181 I->setFast(true);
5182 replaceAndErase(cur, res);
5183 return true;
5184 }
5185 }
5186 // select c, x, 0 -> fmul (uitofp c), x
5187 if (auto C1 = dyn_cast<ConstantFP>(SI->getFalseValue())) {
5188 if (C1->isZero()) {
5189 auto val = pushcse(B.CreateUIToFP(SI->getCondition(), SI->getType()));
5190 auto res = pushcse(B.CreateFMul(val, SI->getTrueValue()));
5191 if (auto I = dyn_cast<Instruction>(res))
5192 I->setFast(true);
5193 replaceAndErase(cur, res);
5194 return true;
5195 }
5196 }
5197 */
5198
5199 // select c, (mul x y), 0 -> mul x, (select c, y, 0)
5200 for (int i = 0; i < 2; i++)
5201 if (auto inst = dyn_cast<Instruction>(SI->getOperand(1 + i)))
5202 if (inst->getOpcode() == Instruction::Mul)
5203 // inst->getOpcode() == Instruction::FMul)
5204 if (auto C = dyn_cast<Constant>(SI->getOperand(1 + (1 - i))))
5205 if ((isa<ConstantInt>(C) && cast<ConstantInt>(C)->isZero()) ||
5206 (isa<ConstantFP>(C) && cast<ConstantFP>(C)->isZero()))
5207 for (int j = 0; j < 2; j++)
5208 if (shouldMove(inst->getOperand(j))) {
5209 auto x = inst->getOperand(j);
5210 auto y = inst->getOperand(1 - j);
5211 auto isel = pushcse(B.CreateSelect(
5212 SI->getCondition(), (i == 0) ? y : C, (i == 0) ? C : y,
5213 "smulmove." + SI->getName()));
5214 Value *imul;
5215 if (cur->getType()->isIntegerTy())
5216 imul = pushcse(B.CreateMul(isel, x, "",
5217 inst->hasNoUnsignedWrap(),
5218 inst->hasNoSignedWrap()));
5219 else
5220 imul = pushcse(B.CreateFMulFMF(isel, x, inst, ""));
5221
5222 replaceAndErase(cur, imul);
5223 return "SelMulMove";
5224 }
5225
5226 // select c, (sitofp x), (sitofp y) -> sitofp (select c, x, y)
5227 // select c, c5, (sitofp y) -> sitofp (select c, c5, y)
5228 {
5229 Value *ops[2] = {nullptr, nullptr};
5230 bool legal = true;
5231 for (int i = 0; i < 2; i++) {
5232 if (isa<ConstantFP>(SI->getOperand(1 + i))) {
5233 ops[i] = nullptr;
5234 continue;
5235 }
5236 if (auto CI = dyn_cast<CastInst>(SI->getOperand(1 + i))) {
5237 if (CI->getOpcode() == Instruction::SIToFP) {
5238 ops[i] = CI->getOperand(0);
5239 continue;
5240 }
5241 }
5242 legal = false;
5243 break;
5244 }
5245 for (int i = 0; i < 2; i++) {
5246 if (!ops[i] && ops[1 - i])
5247 ops[i] = ConstantInt::get(ops[1 - i]->getType(), 0);
5248 }
5249 for (int i = 0; i < 2; i++) {
5250 if (ops[i] == nullptr || ops[i]->getType() != ops[0]->getType()) {
5251 legal = false;
5252 break;
5253 }
5254 }
5255 if (legal) {
5256 auto isel = pushcse(B.CreateSelect(SI->getCondition(), ops[0], ops[1],
5257 "seltofp." + SI->getName()));
5258 auto res = pushcse(B.CreateSIToFP(isel, SI->getType()));
5259
5260 replaceAndErase(cur, res);
5261 return "SelSIMerge";
5262 }
5263 }
5264 }
5265
5266 if (cur->getOpcode() == Instruction::Mul) {
5267 for (int i = 0; i < 2; i++) {
5268 // mul (x, 1) -> x
5269 if (auto C = dyn_cast<ConstantInt>(cur->getOperand(i)))
5270 if (C->isOne()) {
5271 replaceAndErase(cur, cur->getOperand(1 - i));
5272 return "MulIdent";
5273 }
5274
5275 // mul (zext i1 x), y -> mul (zext i1 x) y[x->1]
5276 if (auto Z = dyn_cast<ZExtInst>(cur->getOperand(i)))
5277 if (Z->getOperand(0)->getType()->isIntegerTy(1)) {
5278 auto prev = cur->getOperand(1 - i);
5279 auto next = replace(prev, Z->getOperand(0),
5280 ConstantInt::getTrue(cur->getContext()));
5281 if (next != prev) {
5282 auto res = pushcse(B.CreateMul(Z, next, "postmul." + cur->getName(),
5283 cur->hasNoUnsignedWrap(),
5284 cur->hasNoSignedWrap()));
5285 replaceAndErase(cur, res);
5286 return "MulReplaceZExt";
5287 }
5288 }
5289 }
5290
5291 /*
5292 // mul x, (select c, 0, y) -> select c (mul x 0), (mul x y)
5293 for (int i=0; i<2; i++)
5294 if (auto SI = dyn_cast<SelectInst>(cur->getOperand(i)))
5295 for (int j=0; j<2; j++)
5296 if (auto CI = dyn_cast<ConstantInt>(SI->getOperand(1+j)))
5297 if (CI->isZero()) {
5298 auto tval = (j == 0) ? CI :
5299 pushcse(B.CreateMul(SI->getTrueValue(), cur->getOperand(1-i), "tval." +
5300 cur->getName(), cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); auto
5301 fval = (j == 1) ? CI : pushcse(B.CreateMul(SI->getFalseValue(),
5302 cur->getOperand(1-i), "fval." + cur->getName(), cur->hasNoUnsignedWrap(),
5303 cur->hasNoSignedWrap()));
5304
5305 auto res = pushcse(B.CreateSelect(SI->getCondition(), tval, fval));
5306
5307 replaceAndErase(cur, res);
5308 return true;
5309 }
5310 */
5311
5312 // mul (sub x, y), -c -> mul (sub, y, x), c
5313 for (int i = 0; i < 2; i++)
5314 if (auto inst = dyn_cast<Instruction>(cur->getOperand(i)))
5315 if (inst->getOpcode() == Instruction::Sub)
5316 if (auto CI = dyn_cast<ConstantInt>(cur->getOperand(1 - i)))
5317 if (CI->isNegative()) {
5318 auto sub2 = pushcse(B.CreateSub(
5319 inst->getOperand(1), inst->getOperand(0), "",
5320 inst->hasNoUnsignedWrap(), inst->hasNoSignedWrap()));
5321 auto mul2 = pushcse(B.CreateMul(
5322 sub2, ConstantInt::get(CI->getType(), -CI->getValue()), "",
5323 cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap()));
5324
5325 replaceAndErase(cur, mul2);
5326 return "MulSubNegConst";
5327 }
5328 }
5329
5330 if (cur->getOpcode() == Instruction::Sub)
5331 if (auto CI = dyn_cast<ConstantInt>(cur->getOperand(0)))
5332 if (CI->isZero())
5333 if (auto zext = dyn_cast<Instruction>(cur->getOperand(1))) {
5334 // sub 0, (zext i1 x) -> sext x
5335 if (zext->getOpcode() == Instruction::ZExt &&
5336 zext->getOperand(0)->getType()->isIntegerTy(1)) {
5337 auto res =
5338 pushcse(B.CreateSExt(zext->getOperand(0), cur->getType()));
5339 replaceAndErase(cur, res);
5340 return "SubZExt";
5341 }
5342 // sub 0, (mul nsw nuw constant, x) -> mul nsw nuw -constant, x
5343 if (zext->getOpcode() == Instruction::Mul &&
5344 zext->hasNoUnsignedWrap() && zext->hasNoSignedWrap()) {
5345 for (int i = 0; i < 2; i++)
5346 if (auto CI = dyn_cast<ConstantInt>(zext->getOperand(i))) {
5347 auto res = pushcse(B.CreateMul(
5348 zext->getOperand(1 - i),
5349 ConstantInt::get(CI->getType(), -CI->getValue()),
5350 "neg." + zext->getName(), true, true));
5351 replaceAndErase(cur, res);
5352 return "SubMulConstant";
5353 }
5354 }
5355 }
5356
5357 // add (zext (and c1, x) ), (zext (and c1, y)) -> select c1, (add (zext x),
5358 // (zext y)), 0
5359 /*
5360 if (cur->getOpcode() == Instruction::Add ||
5361 cur->getOpcode() == Instruction::Sub ||
5362 cur->getOpcode() == Instruction::Mul)
5363 if (auto inst1 = dyn_cast<Instruction>(cur->getOperand(0)))
5364 if (auto inst2 = dyn_cast<Instruction>(cur->getOperand(1)))
5365 if (inst1->getOpcode() == Instruction::ZExt && inst2->getOpcode() ==
5366 Instruction::ZExt) if (auto and1 =
5367 dyn_cast<Instruction>(inst1->getOperand(0))) if (auto and2 =
5368 dyn_cast<Instruction>(inst2->getOperand(0))) if
5369 (and1->getType()->isIntegerTy(1) && and2->getType()->isIntegerTy(1) &&
5370 and1->getOpcode() == Instruction::And && and2->getOpcode() ==
5371 Instruction::And) { bool done = false; for (int i1=0; i1<2; i1++) for (int
5372 i2=0; i2<2; i2++) if (and1->getOperand(i1) == and2->getOperand(i2)) { auto
5373 c1 = and1->getOperand(i1); auto x = and1->getOperand(1-i1); x =
5374 pushcse(B.CreateZExt(x, inst1->getType())); auto y =
5375 and2->getOperand(1-i2);
5376
5377 y = pushcse(B.CreateZExt(y, inst2->getType()));
5378
5379 Value *res = nullptr;
5380 switch (cur->getOpcode()) {
5381 case Instruction::Add:
5382 res = pushcse(B.CreateAdd(x, y, "", cur->hasNoUnsignedWrap(),
5383 cur->hasNoSignedWrap())); break; case Instruction::Sub: res = B.CreateSub(x,
5384 y,
5385 "", cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap()); break; case
5386 Instruction::Mul: res = B.CreateMul(x, y, "", cur->hasNoUnsignedWrap(),
5387 cur->hasNoSignedWrap()); break; default: llvm_unreachable("Illegal opcode");
5388 }
5389 res = pushcse(B.CreateSelect(c1, res,
5390 Constant::getNullValue(cur->getType())));
5391
5392 replaceAndErase(cur, res);
5393 return;
5394 }
5395 }
5396 */
5397
5398 // add (select %c c0, x), (select %c, c1, y) -> select %c, (add c0, c1),
5399 // (add x, y) and for sub/mul/cmp
5400 if (cur->getOpcode() == Instruction::Add ||
5401 cur->getOpcode() == Instruction::Sub ||
5402 cur->getOpcode() == Instruction::Mul ||
5403 cur->getOpcode() == Instruction::FAdd ||
5404 cur->getOpcode() == Instruction::FSub ||
5405 cur->getOpcode() == Instruction::FMul ||
5406 // cur->getOpcode() == Instruction::SIToFP ||
5407 // cur->getOpcode() == Instruction::UIToFP ||
5408 cur->getOpcode() == Instruction::ICmp ||
5409 cur->getOpcode() == Instruction::FCmp) {
5410
5411 Value *SI1cond = nullptr;
5412 Value *SI1tval = nullptr;
5413 Value *SI1fval = nullptr;
5414 if (auto SI1 = dyn_cast<SelectInst>(cur->getOperand(0))) {
5415 SI1cond = SI1->getCondition();
5416 SI1tval = SI1->getTrueValue();
5417 SI1fval = SI1->getFalseValue();
5418 }
5419 if (auto SI1 = dyn_cast<ZExtInst>(cur->getOperand(0)))
5420 if (SI1->getOperand(0)->getType()->isIntegerTy(1)) {
5421 SI1cond = SI1->getOperand(0);
5422 SI1tval = SI1;
5423 SI1fval = ConstantInt::get(SI1->getType(), 0);
5424 }
5425 if (auto SI1 = dyn_cast<SExtInst>(cur->getOperand(0)))
5426 if (SI1->getOperand(0)->getType()->isIntegerTy(1)) {
5427 SI1cond = SI1->getOperand(0);
5428 SI1tval = SI1;
5429 SI1fval = ConstantInt::get(SI1->getType(), 0);
5430 }
5431 Value *SI2cond = nullptr;
5432 Value *SI2tval = nullptr;
5433 Value *SI2fval = nullptr;
5434
5435 auto op2 = cur->getOperand((cur->getOpcode() == Instruction::SIToFP ||
5436 cur->getOpcode() == Instruction::UIToFP)
5437 ? 0
5438 : 1);
5439 if (auto SI2 = dyn_cast<SelectInst>(op2)) {
5440 SI2cond = SI2->getCondition();
5441 SI2tval = SI2->getTrueValue();
5442 SI2fval = SI2->getFalseValue();
5443 }
5444 if (auto SI2 = dyn_cast<ZExtInst>(op2))
5445 if (SI2->getOperand(0)->getType()->isIntegerTy(1)) {
5446 SI2cond = SI2->getOperand(0);
5447 SI2tval = SI2;
5448 SI2fval = ConstantInt::get(SI2->getType(), 0);
5449 }
5450 if (auto SI2 = dyn_cast<SExtInst>(op2))
5451 if (SI2->getOperand(0)->getType()->isIntegerTy(1)) {
5452 SI2cond = SI2->getOperand(0);
5453 SI2tval = SI2;
5454 SI2fval = ConstantInt::get(SI2->getType(), 0);
5455 }
5456
5457 if (SI1cond && SI2cond && (SI1cond == SI2cond || isNot(SI1cond, SI2cond)))
5458 if ((SI1cond == SI2cond &&
5459 ((isa<Constant>(SI1tval) && isa<Constant>(SI2tval)) ||
5460 (isa<Constant>(SI1fval) && isa<Constant>(SI2fval)))) ||
5461 (SI1cond != SI2cond &&
5462 ((isa<Constant>(SI1tval) && isa<Constant>(SI2fval)) ||
5463 (isa<Constant>(SI1fval) && isa<Constant>(SI2tval))))
5464
5465 ) {
5466 Value *tval = nullptr;
5467 Value *fval = nullptr;
5468 bool inverted = SI1cond != SI2cond;
5469 switch (cur->getOpcode()) {
5470 case Instruction::SIToFP:
5471 tval =
5472 B.CreateSIToFP(SI1tval, cur->getType(), "tval." + cur->getName());
5473 fval =
5474 B.CreateSIToFP(SI1fval, cur->getType(), "fval." + cur->getName());
5475 break;
5476 case Instruction::UIToFP:
5477 tval =
5478 B.CreateUIToFP(SI1tval, cur->getType(), "tval." + cur->getName());
5479 fval =
5480 B.CreateUIToFP(SI1fval, cur->getType(), "fval." + cur->getName());
5481 break;
5482 case Instruction::FAdd:
5483 tval = B.CreateFAddFMF(SI1tval, inverted ? SI2fval : SI2tval, cur,
5484 "tval." + cur->getName());
5485 fval = B.CreateFAddFMF(SI1fval, inverted ? SI2tval : SI2fval, cur,
5486 "fval." + cur->getName());
5487 break;
5488 case Instruction::FSub:
5489 tval = B.CreateFSubFMF(SI1tval, inverted ? SI2fval : SI2tval, cur,
5490 "tval." + cur->getName());
5491 fval = B.CreateFSubFMF(SI1fval, inverted ? SI2tval : SI2fval, cur,
5492 "fval." + cur->getName());
5493 break;
5494 case Instruction::FMul:
5495 tval = B.CreateFMulFMF(SI1tval, inverted ? SI2fval : SI2tval, cur,
5496 "tval." + cur->getName());
5497 fval = B.CreateFMulFMF(SI1fval, inverted ? SI2tval : SI2fval, cur,
5498 "fval." + cur->getName());
5499 break;
5500 case Instruction::Add:
5501 tval = B.CreateAdd(SI1tval, inverted ? SI2fval : SI2tval,
5502 "tval." + cur->getName(), cur->hasNoUnsignedWrap(),
5503 cur->hasNoSignedWrap());
5504 fval = B.CreateAdd(SI1fval, inverted ? SI2tval : SI2fval,
5505 "fval." + cur->getName(), cur->hasNoUnsignedWrap(),
5506 cur->hasNoSignedWrap());
5507 break;
5508 case Instruction::Sub:
5509 tval = B.CreateSub(SI1tval, inverted ? SI2fval : SI2tval,
5510 "tval." + cur->getName(), cur->hasNoUnsignedWrap(),
5511 cur->hasNoSignedWrap());
5512 fval = B.CreateSub(SI1fval, inverted ? SI2tval : SI2fval,
5513 "fval." + cur->getName(), cur->hasNoUnsignedWrap(),
5514 cur->hasNoSignedWrap());
5515 break;
5516 case Instruction::Mul:
5517 tval = B.CreateMul(SI1tval, inverted ? SI2fval : SI2tval,
5518 "tval." + cur->getName(), cur->hasNoUnsignedWrap(),
5519 cur->hasNoSignedWrap());
5520 fval = B.CreateMul(SI1fval, inverted ? SI2tval : SI2fval,
5521 "fval." + cur->getName(), cur->hasNoUnsignedWrap(),
5522 cur->hasNoSignedWrap());
5523 break;
5524 case Instruction::ICmp:
5525 case Instruction::FCmp:
5526 tval = B.CreateCmp(cast<CmpInst>(cur)->getPredicate(), SI1tval,
5527 inverted ? SI2fval : SI2tval,
5528 "tval." + cur->getName());
5529 fval = B.CreateCmp(cast<CmpInst>(cur)->getPredicate(), SI1fval,
5530 inverted ? SI2tval : SI2fval,
5531 "fval." + cur->getName());
5532 break;
5533 default:
5534 llvm_unreachable("illegal opcode");
5535 }
5536 tval = pushcse(tval);
5537 fval = pushcse(fval);
5538
5539 auto res = pushcse(
5540 B.CreateSelect(SI1cond, tval, fval, "selmerge." + cur->getName()));
5541
5542 push(cur->getOperand(0));
5543 push(cur->getOperand(1));
5544 replaceAndErase(cur, res);
5545 return "BinopSelFuse";
5546 }
5547 }
5548
5549 /*
5550 // and (i == c), (i != d) -> and (i == c) && (c != d)
5551 if (cur->getOpcode() == Instruction::And) {
5552 auto lhs = replace(cur->getOperand(0), cur->getOperand(1),
5553 ConstantInt::getTrue(cur->getContext()));
5554 auto rhs = replace(cur->getOperand(1), cur->getOperand(0),
5555 ConstantInt::getTrue(cur->getContext()));
5556 if (lhs != cur->getOperand(0) || rhs != cur->getOperand(1)) {
5557 auto res = pushcse(B.CreateAnd(lhs, rhs, "postand." + cur->getName()));
5558 replaceAndErase(cur, res);
5559 return "AndReplace2";
5560 }
5561 }
5562 */
5563
5564 // and a, (or q, (not a)) -> and a q
5565 if (cur->getOpcode() == Instruction::And) {
5566 for (size_t i1 = 0; i1 < 2; i1++)
5567 if (auto inst2 = dyn_cast<Instruction>(cur->getOperand(1 - i1)))
5568 if (inst2->getOpcode() == Instruction::Or)
5569 for (size_t i2 = 0; i2 < 2; i2++)
5570 if (isNot(cur->getOperand(i1), inst2->getOperand(i2))) {
5571 auto q = inst2->getOperand(1 - i2);
5572 cur->setOperand(1 - i1, q);
5573 push(cur);
5574 push(q);
5575 push(inst2);
5576 push(cur->getOperand(i1));
5577 push(inst2->getOperand(i2));
5578 Q.insert(cur);
5579 for (auto U : cur->users())
5580 push(U);
5581 return "AndOrProp";
5582 }
5583 }
5584
5585 // and (and a, b), a) -> and a, b
5586 if (cur->getOpcode() == Instruction::And) {
5587 for (size_t i1 = 0; i1 < 2; i1++)
5588 if (auto inst2 = dyn_cast<Instruction>(cur->getOperand(i1)))
5589 if (inst2->getOpcode() == Instruction::And)
5590 for (size_t i2 = 0; i2 < 2; i2++)
5591 if (inst2->getOperand(i2) == cur->getOperand(1 - i1)) {
5592 replaceAndErase(cur, inst2);
5593 return "AndAndProp";
5594 }
5595 }
5596
5597 // or a, (and q, (not a)) -> and a q
5598 if (cur->getOpcode() == Instruction::And) {
5599 for (size_t i1 = 0; i1 < 2; i1++)
5600 if (auto inst2 = dyn_cast<Instruction>(cur->getOperand(1 - i1)))
5601 if (inst2->getOpcode() == Instruction::Or)
5602 for (size_t i2 = 0; i2 < 2; i2++)
5603 if (isNot(cur->getOperand(i1), inst2->getOperand(i2))) {
5604 auto q = inst2->getOperand(1 - i2);
5605 cur->setOperand(1 - i1, q);
5606 push(cur);
5607 push(q);
5608 push(inst2);
5609 push(cur->getOperand(i1));
5610 push(inst2->getOperand(i2));
5611 Q.insert(cur);
5612 for (auto U : cur->users())
5613 push(U);
5614 return "OrAndProp";
5615 }
5616 }
5617
5618 // and ( (a +/- b) != c ), ( (d +/- b) != c ) -> and ( a != (c -/+ b) ), (
5619 // d != (c -/+ b) )
5620 // also with or
5621 if (cur->getOpcode() == Instruction::And ||
5622 cur->getOpcode() == Instruction::Or) {
5623 for (auto cmpOp : {ICmpInst::ICMP_EQ, ICmpInst::ICMP_NE})
5624 for (auto interOp : {Instruction::Add, Instruction::Sub})
5625 if (auto cmp1 = dyn_cast<ICmpInst>(cur->getOperand(0)))
5626 if (auto cmp2 = dyn_cast<ICmpInst>(cur->getOperand(1)))
5627 for (size_t i1 = 0; i1 < 2; i1++)
5628 for (size_t i2 = 0; i2 < 2; i2++)
5629 if (cmp1->getOperand(1 - i1) == cmp2->getOperand(1 - i2) &&
5630 cmp1->getPredicate() == cmpOp &&
5631 cmp2->getPredicate() == cmpOp)
5632 if (auto add1 = dyn_cast<Instruction>(cmp1->getOperand(i1)))
5633 if (auto add2 = dyn_cast<Instruction>(cmp2->getOperand(i2)))
5634 if (add1->getOpcode() == interOp &&
5635 add2->getOpcode() == interOp)
5636 for (size_t ia = 0; ia < 2; ia++)
5637 if (add1->getOperand(ia) == add2->getOperand(ia)) {
5638
5639 auto b = add1->getOperand(ia);
5640 auto c = cmp1->getOperand(1 - i1);
5641 auto a = add1->getOperand(1 - ia);
5642 auto d = add2->getOperand(1 - ia);
5643
5644 Value *res = nullptr;
5645 if (interOp == Instruction::Add)
5646 res = pushcse(B.CreateSub(ia == 0 ? b : c,
5647 ia == 0 ? c : b));
5648 else
5649 res = pushcse(B.CreateAdd(ia == 0 ? b : c,
5650 ia == 0 ? c : b));
5651
5652 auto lhs = pushcse(B.CreateCmp(cmpOp, a, res));
5653 auto rhs = pushcse(B.CreateCmp(cmpOp, d, res));
5654
5655 Value *fres = nullptr;
5656 if (cur->getOpcode() == Instruction::And)
5657 fres = pushcse(B.CreateAnd(lhs, rhs));
5658 else
5659 fres = pushcse(B.CreateOr(lhs, rhs));
5660
5661 replaceAndErase(cur, fres);
5662 return "AndLinearShift";
5663 }
5664 }
5665
5666 // and ( expr == c1 ), ( expr == c2 ) and c1 != c2 -> false
5667 if (cur->getOpcode() == Instruction::And) {
5668 for (auto cmpOp : {ICmpInst::ICMP_EQ})
5669 if (auto cmp1 = dyn_cast<ICmpInst>(cur->getOperand(0)))
5670 if (auto cmp2 = dyn_cast<ICmpInst>(cur->getOperand(1)))
5671 for (size_t i1 = 0; i1 < 2; i1++)
5672 for (size_t i2 = 0; i2 < 2; i2++)
5673 if (cmp1->getOperand(1 - i1) == cmp2->getOperand(1 - i2) &&
5674 cmp1->getPredicate() == cmpOp &&
5675 cmp2->getPredicate() == cmpOp) {
5676 auto c1 = SE.getSCEV(cmp1->getOperand(i1));
5677 auto c2 = SE.getSCEV(cmp2->getOperand(i2));
5678 auto m = SE.getMinusSCEV(c1, c2, SCEV::NoWrapMask);
5679 if (auto C = dyn_cast<SCEVConstant>(m)) {
5680 // if c1 == c2 don't need the and they are equivalent
5681 if (C->getValue()->isZero()) {
5682 push(cmp1);
5683 push(cmp2);
5684 replaceAndErase(cur, cmp1);
5685 return "AndEQExpr";
5686 } else {
5687 // if non one constant they must be distinct.
5688 replaceAndErase(cur,
5689 ConstantInt::getFalse(cur->getContext()));
5690 return "AndNEExpr";
5691 }
5692 }
5693 }
5694 }
5695
5696 // (a | b) == 0 -> a == 0 & b == 0
5697 if (auto icmp = dyn_cast<ICmpInst>(cur))
5698 if (icmp->getPredicate() == ICmpInst::ICMP_EQ &&
5699 cur->getType()->isIntegerTy(1))
5700 for (int i = 0; i < 2; i++)
5701 if (auto C = dyn_cast<ConstantInt>(icmp->getOperand(i)))
5702 if (C->isZero())
5703 if (auto z = dyn_cast<BinaryOperator>(icmp->getOperand(1 - i)))
5704 if (z->getOpcode() == BinaryOperator::Or) {
5705 auto a0 = pushcse(B.CreateICmpEQ(z->getOperand(0), C));
5706 auto b0 = pushcse(B.CreateICmpEQ(z->getOperand(1), C));
5707 auto res = pushcse(B.CreateAnd(a0, b0));
5708 push(z);
5709 push(icmp);
5710 replaceAndErase(cur, res);
5711 return "OrEQZero";
5712 }
5713
5714 // add (mul a b), (mul c, b) -> mul (add a, c), b
5715 if (cur->getOpcode() == Instruction::Sub ||
5716 cur->getOpcode() == Instruction::Add) {
5717 if (auto mul1 = dyn_cast<Instruction>(cur->getOperand(0)))
5718 if (auto mul2 = dyn_cast<Instruction>(cur->getOperand(1)))
5719 if ((mul1->getOpcode() == Instruction::Mul &&
5720 mul2->getOpcode() == Instruction::Mul) ||
5721 (mul1->getOpcode() == Instruction::FMul &&
5722 mul2->getOpcode() == Instruction::FMul && mul1->isFast() &&
5723 mul2->isFast() && cur->isFast())) {
5724 for (int i1 = 0; i1 < 2; i1++)
5725 for (int i2 = 0; i2 < 2; i2++) {
5726 if (mul1->getOperand(i1) == mul2->getOperand(i2)) {
5727 Value *res = nullptr;
5728 switch (cur->getOpcode()) {
5729 case Instruction::Add:
5730 res = B.CreateAdd(mul1->getOperand(1 - i1),
5731 mul2->getOperand(1 - i2));
5732 break;
5733 case Instruction::Sub:
5734 res = B.CreateSub(mul1->getOperand(1 - i1),
5735 mul2->getOperand(1 - i2));
5736 break;
5737 case Instruction::FAdd:
5738 res = B.CreateFAddFMF(mul1->getOperand(1 - i1),
5739 mul2->getOperand(1 - i2), cur);
5740 break;
5741 case Instruction::FSub:
5742 res = B.CreateFSubFMF(mul1->getOperand(1 - i1),
5743 mul2->getOperand(1 - i2), cur);
5744 break;
5745 default:
5746 llvm_unreachable("Illegal opcode");
5747 }
5748 res = pushcse(res);
5749 Value *res2 = nullptr;
5750 if (cur->getType()->isIntegerTy())
5751 res2 = B.CreateMul(
5752 res, mul1->getOperand(i1), "",
5753 mul1->hasNoUnsignedWrap() && mul1->hasNoUnsignedWrap(),
5754 mul2->hasNoSignedWrap() && mul2->hasNoSignedWrap());
5755 else
5756 res2 = B.CreateFMulFMF(res, mul1->getOperand(i1), cur);
5757
5758 res2 = pushcse(res2);
5759
5760 replaceAndErase(cur, res2);
5761 return "InvDistributive";
5762 }
5763 }
5764 }
5765 }
5766
5767 // fadd (ext a), (ext b) -> ext (a + b)
5768 // fsub (ext a), (ext b) -> ext (a - b)
5769 // fmul (ext a), (ext b) -> ext (a * b)
5770 if (cur->getOpcode() == Instruction::FSub ||
5771 cur->getOpcode() == Instruction::FAdd ||
5772 cur->getOpcode() == Instruction::FMul ||
5773 cur->getOpcode() == Instruction::FNeg ||
5774 (isSum(cur) && callOperands(cast<CallBase>(cur)).size() == 2)) {
5775 auto opcode = cur->getOpcode();
5776 if (isSum(cur))
5777 opcode = Instruction::FAdd;
5778 auto Ty = B.getInt64Ty();
5779 SmallPtrSet<Instruction *, 1> temporaries;
5780 SmallVector<Instruction *, 1> precasts;
5781 Value *lhs = nullptr;
5782
5783 Value *prelhs = (cur->getOpcode() == Instruction::FNeg)
5784 ? ConstantFP::get(cur->getType(), 0.0)
5785 : cur->getOperand(0);
5786 Value *prerhs = (cur->getOpcode() == Instruction::FNeg)
5787 ? cur->getOperand(0)
5788 : cur->getOperand(1);
5789
5790 APInt minval(64, 0);
5791 APInt maxval(64, 0);
5792 if (auto C = dyn_cast<ConstantFP>(prelhs)) {
5793 APSInt Tmp(64);
5794 bool isExact = false;
5795 C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero,
5796 &isExact);
5797 if (isExact || C->isZero()) {
5798 minval = maxval = Tmp;
5799 lhs = ConstantInt::get(Ty, Tmp);
5800 }
5801 }
5802 if (auto ext = dyn_cast<CastInst>(prelhs)) {
5803 if (ext->getOpcode() == Instruction::UIToFP ||
5804 ext->getOpcode() == Instruction::SIToFP) {
5805 precasts.push_back(ext);
5806 auto ity = cast<IntegerType>(ext->getOperand(0)->getType());
5807 bool md = false;
5808 if (auto I = dyn_cast<Instruction>(ext->getOperand(0)))
5809 if (auto MD = hasMetadata(I, LLVMContext::MD_range)) {
5810 md = true;
5811 minval =
5812 cast<ConstantInt>(
5813 cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
5814 ->getValue()
5815 .zextOrTrunc(64);
5816 maxval =
5817 cast<ConstantInt>(
5818 cast<ConstantAsMetadata>(MD->getOperand(1))->getValue())
5819 ->getValue()
5820 .zextOrTrunc(64);
5821 }
5822 if (!md) {
5823 if (ext->getOpcode() == Instruction::UIToFP)
5824 maxval = APInt::getMaxValue(ity->getBitWidth()).zextOrTrunc(64);
5825 else {
5826 maxval =
5827 APInt::getSignedMaxValue(ity->getBitWidth()).zextOrTrunc(64);
5828 minval =
5829 APInt::getSignedMinValue(ity->getBitWidth()).zextOrTrunc(64);
5830 }
5831 }
5832 if (ext->getOperand(0)->getType() == Ty)
5833 lhs = ext->getOperand(0);
5834 else if (ity->getBitWidth() < Ty->getBitWidth()) {
5835 if (ext->getOpcode() == Instruction::UIToFP)
5836 lhs = B.CreateZExt(ext->getOperand(0), Ty);
5837 else
5838 lhs = B.CreateSExt(ext->getOperand(0), Ty);
5839 if (auto I = dyn_cast<Instruction>(lhs))
5840 if (I != ext->getOperand(0))
5841 temporaries.insert(I);
5842 }
5843 }
5844 }
5845
5846 Value *rhs = nullptr;
5847
5848 if (auto C = dyn_cast<ConstantFP>(prerhs)) {
5849 APSInt Tmp(64);
5850 bool isExact = false;
5851 C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero,
5852 &isExact);
5853 if (isExact || C->isZero()) {
5854 rhs = ConstantInt::get(Ty, Tmp);
5855 switch (opcode) {
5856 case Instruction::FAdd:
5857 minval += Tmp;
5858 maxval += Tmp;
5859 break;
5860 case Instruction::FSub:
5861 case Instruction::FNeg:
5862 minval -= Tmp;
5863 maxval -= Tmp;
5864 break;
5865 case Instruction::FMul:
5866 minval *= Tmp;
5867 maxval *= Tmp;
5868 break;
5869 default:
5870 llvm_unreachable("Illegal opcode");
5871 }
5872 }
5873 }
5874 if (auto ext = dyn_cast<CastInst>(prerhs)) {
5875 if (ext->getOpcode() == Instruction::UIToFP ||
5876 ext->getOpcode() == Instruction::SIToFP) {
5877 precasts.push_back(ext);
5878 auto ity = cast<IntegerType>(ext->getOperand(0)->getType());
5879 bool md = false;
5880 APInt rhsMin(64, 0);
5881 APInt rhsMax(64, 0);
5882 if (auto I = dyn_cast<Instruction>(ext->getOperand(0)))
5883 if (auto MD = hasMetadata(I, LLVMContext::MD_range)) {
5884 md = true;
5885 rhsMin =
5886 cast<ConstantInt>(
5887 cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
5888 ->getValue()
5889 .zextOrTrunc(64);
5890 rhsMax =
5891 cast<ConstantInt>(
5892 cast<ConstantAsMetadata>(MD->getOperand(1))->getValue())
5893 ->getValue()
5894 .zextOrTrunc(64);
5895 }
5896 if (!md) {
5897 if (ext->getOpcode() == Instruction::UIToFP) {
5898 rhsMax = APInt::getMaxValue(ity->getBitWidth()).zextOrTrunc(64);
5899 rhsMin = APInt(64, 0);
5900 } else {
5901 rhsMax =
5902 APInt::getSignedMaxValue(ity->getBitWidth()).zextOrTrunc(64);
5903 rhsMin =
5904 APInt::getSignedMinValue(ity->getBitWidth()).zextOrTrunc(64);
5905 }
5906 }
5907 switch (opcode) {
5908 case Instruction::FAdd:
5909 minval += rhsMin;
5910 maxval += rhsMax;
5911 break;
5912 case Instruction::FSub:
5913 case Instruction::FNeg:
5914 minval -= rhsMax;
5915 maxval -= rhsMin;
5916 break;
5917 case Instruction::FMul: {
5918 auto minf = [&](APInt a, APInt b) { return a.sle(b) ? a : b; };
5919 auto maxf = [&](APInt a, APInt b) { return a.sle(b) ? b : b; };
5920 minval = minf(
5921 minval * rhsMin,
5922 minf(minval * rhsMax, minf(maxval * rhsMin, maxval * rhsMax)));
5923 maxval = maxf(
5924 minval * rhsMin,
5925 maxf(minval * rhsMax, maxf(maxval * rhsMin, maxval * rhsMax)));
5926 break;
5927 }
5928 default:
5929 llvm_unreachable("Illegal opcode");
5930 }
5931 if (ext->getOperand(0)->getType() == Ty)
5932 rhs = ext->getOperand(0);
5933 else if (ity->getBitWidth() < Ty->getBitWidth()) {
5934 if (ext->getOpcode() == Instruction::UIToFP)
5935 rhs = B.CreateZExt(ext->getOperand(0), Ty);
5936 else
5937 rhs = B.CreateSExt(ext->getOperand(0), Ty);
5938 if (auto I = dyn_cast<Instruction>(rhs))
5939 if (I != ext->getOperand(0))
5940 temporaries.insert(I);
5941 }
5942 }
5943 }
5944
5945 if (lhs && rhs) {
5946 Value *res = nullptr;
5947 if (temporaries.count(dyn_cast<Instruction>(lhs)))
5948 lhs = pushcse(lhs);
5949 if (temporaries.count(dyn_cast<Instruction>(rhs)))
5950 rhs = pushcse(rhs);
5951 switch (opcode) {
5952 case Instruction::FAdd:
5953 res = B.CreateAdd(lhs, rhs, "", false, true);
5954 break;
5955 case Instruction::FSub:
5956 case Instruction::FNeg:
5957 res = B.CreateSub(lhs, rhs, "", false, true);
5958 break;
5959 case Instruction::FMul:
5960 res = B.CreateMul(lhs, rhs, "", false, true);
5961 break;
5962 default:
5963 llvm_unreachable("Illegal opcode");
5964 }
5965 res = pushcse(res);
5966 for (auto I : precasts)
5967 push(I);
5968 /*
5969 if (auto I = dyn_cast<Instruction>(res)) {
5970 Q.insert(I);
5971 Metadata *vals[] = {(Metadata *)ConstantAsMetadata::get(
5972 ConstantInt::get(Ty, minval)),
5973 (Metadata *)ConstantAsMetadata::get(
5974 ConstantInt::get(Ty, maxval))};
5975 I->setMetadata(LLVMContext::MD_range,
5976 MDNode::get(I->getContext(), vals));
5977 }
5978 */
5979 auto ext = pushcse(B.CreateSIToFP(res, cur->getType()));
5980 replaceAndErase(cur, ext);
5981 return "BinopExtToExtBinop";
5982
5983 } else {
5984 for (auto I : temporaries)
5985 I->eraseFromParent();
5986 }
5987 }
5988
5989 // select(cond, const1, b) ?= const2 -> select(cond, const1 ?= const2, b ?=
5990 // const2)
5991 if (auto fcmp = dyn_cast<FCmpInst>(cur))
5992 for (int i = 0; i < 2; i++)
5993 if (auto const2 = dyn_cast<Constant>(fcmp->getOperand(i)))
5994 if (auto sel = dyn_cast<SelectInst>(fcmp->getOperand(1 - i)))
5995 if (isa<Constant>(sel->getTrueValue()) ||
5996 isa<Constant>(sel->getFalseValue())) {
5997 auto tval = pushcse(B.CreateFCmp(fcmp->getPredicate(),
5998 sel->getTrueValue(), const2));
5999 auto fval = pushcse(B.CreateFCmp(fcmp->getPredicate(),
6000 sel->getFalseValue(), const2));
6001 auto res = pushcse(B.CreateSelect(sel->getCondition(), tval, fval));
6002 replaceAndErase(cur, res);
6003 return "FCmpSelectConst";
6004 }
6005
6006 // mul (mul a, const), b:not_sparse_or_const -> mul (mul a, b), const
6007 // note we avoid the case where b = (mul a, const) since otherwise
6008 // we create an infinite recursion
6009 // and also we make sure b isn't sparse, since sparse is the first
6010 // precedence for pushing, then constant, then others
6011 if (cur->getOpcode() == Instruction::FMul)
6012 if (cur->isFast() && cur->getOperand(0) != cur->getOperand(1))
6013 for (auto ic = 0; ic < 2; ic++)
6014 if (auto mul = dyn_cast<Instruction>(cur->getOperand(ic)))
6015 if (mul->getOpcode() == Instruction::FMul && mul->isFast()) {
6016 auto b = cur->getOperand(1 - ic);
6017 if (!isa<Constant>(b) && !directlySparse(b)) {
6018
6019 for (int i = 0; i < 2; i++)
6020 if (auto C = dyn_cast<Constant>(mul->getOperand(i))) {
6021 auto n0 =
6022 pushcse(B.CreateFMulFMF(mul->getOperand(1 - i), b, mul));
6023 auto n1 = pushcse(B.CreateFMulFMF(n0, C, cur));
6024 push(mul);
6025
6026 replaceAndErase(cur, n1);
6027 return "MulMulConst";
6028 }
6029 }
6030 }
6031
6032 // (mul c, a) +/- (mul c, b) -> mul c, (a +/- b)
6033 if (cur->getOpcode() == Instruction::FAdd ||
6034 cur->getOpcode() == Instruction::FSub) {
6035 if (auto mul1 = dyn_cast<BinaryOperator>(cur->getOperand(0))) {
6036 if (mul1->getOpcode() == Instruction::FMul && mul1->isFast()) {
6037 if (auto mul2 = dyn_cast<BinaryOperator>(cur->getOperand(1))) {
6038 if (mul2->getOpcode() == Instruction::FMul && mul2->isFast()) {
6039 for (int i = 0; i < 2; i++) {
6040 for (int j = 0; j < 2; j++) {
6041 if (mul1->getOperand(i) == mul2->getOperand(j)) {
6042 auto c = mul1->getOperand(i);
6043 auto a = mul1->getOperand(1 - i);
6044 auto b = mul2->getOperand(1 - j);
6045 Value *intermediate = nullptr;
6046
6047 if (cur->getOpcode() == Instruction::FAdd)
6048 intermediate = pushcse(B.CreateFAddFMF(a, b, cur));
6049 else
6050 intermediate = pushcse(B.CreateFSubFMF(a, b, cur));
6051
6052 auto res = pushcse(B.CreateFMulFMF(c, intermediate, cur));
6053 push(mul1);
6054 push(mul2);
6055 replaceAndErase(cur, res);
6056 return "FAddMulConstMulConst";
6057 }
6058 }
6059 }
6060 }
6061 }
6062 }
6063 }
6064 }
6065
6066 // fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))),
6067 // (sitofp b)
6068
6069 if (cur->getOpcode() == Instruction::FMul && cur->isFast()) {
6070 for (int i = 0; i < 2; i++)
6071 if (auto z = dyn_cast<Instruction>(cur->getOperand(i)))
6072 if (isa<SIToFPInst>(z) || isa<UIToFPInst>(z))
6073 if (auto imul = dyn_cast<BinaryOperator>(z->getOperand(0)))
6074 if (imul->getOpcode() == Instruction::Mul)
6075 for (int j = 0; j < 2; j++)
6076 if (auto c = dyn_cast<Constant>(imul->getOperand(j))) {
6077 auto b = imul->getOperand(1 - j);
6078 auto a = cur->getOperand(1 - i);
6079
6080 auto c_fp = pushcse(B.CreateSIToFP(c, cur->getType()));
6081 auto b_fp = pushcse(B.CreateSIToFP(b, cur->getType()));
6082 auto n_mul = pushcse(B.CreateFMulFMF(a, c_fp, cur));
6083 auto res = pushcse(
6084 B.CreateFMulFMF(n_mul, b_fp, cur, cur->getName()));
6085 push(imul);
6086 push(z);
6087 replaceAndErase(cur, res);
6088 return "FMulIMulConstRotate";
6089 }
6090 }
6091
6092 if (cur->getOpcode() == Instruction::FDiv) {
6093 Value *prelhs = cur->getOperand(0);
6094 Value *b = cur->getOperand(1);
6095
6096 // fdiv (sitofp a), b -> select (a == 0), 0 [ (fdiv 1 / b) * sitofp a]
6097 if (auto ext = dyn_cast<CastInst>(prelhs)) {
6098 if (ext->getOpcode() == Instruction::UIToFP ||
6099 ext->getOpcode() == Instruction::SIToFP) {
6100 push(ext);
6101
6102 Value *condition = pushcse(
6103 B.CreateICmpEQ(ext->getOperand(0),
6104 ConstantInt::get(ext->getOperand(0)->getType(), 0),
6105 "sdivcmp." + cur->getName()));
6106
6107 Value *fdiv = pushcse(
6108 B.CreateFMulFMF(pushcse(B.CreateFDivFMF(
6109 ConstantFP::get(cur->getType(), 1.0), b, cur)),
6110 ext, cur));
6111
6112 Value *sel = pushcse(
6113 B.CreateSelect(condition, ConstantFP::get(cur->getType(), 0.0),
6114 fdiv, "sfdiv." + cur->getName()));
6115
6116 replaceAndErase(cur, sel);
6117 return "FDivSIToFPProp";
6118 }
6119 }
6120 // fdiv (select c, 0, a), b -> select c, 0 (fdiv a, b)
6121 if (auto SI = dyn_cast<SelectInst>(prelhs)) {
6122 auto tvalC = dyn_cast<ConstantFP>(SI->getTrueValue());
6123 auto fvalC = dyn_cast<ConstantFP>(SI->getFalseValue());
6124 if ((tvalC && tvalC->isZero()) || (fvalC && fvalC->isZero())) {
6125 push(SI);
6126 auto ntval =
6127 (tvalC && tvalC->isZero())
6128 ? tvalC
6129 : pushcse(B.CreateFDivFMF(SI->getTrueValue(), b, cur,
6130 "sfdiv2_t." + cur->getName()));
6131 auto nfval =
6132 (fvalC && fvalC->isZero())
6133 ? fvalC
6134 : pushcse(B.CreateFDivFMF(SI->getFalseValue(), b, cur,
6135 "sfdiv2_f." + cur->getName()));
6136
6137 // Work around bad fdivfmf, fixed in LLVM 16+
6138 // https://github.com/llvm/llvm-project/commit/4f3b1c6dd6ef6c7b5bb79f058e3b7ba4bcdf4566
6139#if LLVM_VERSION_MAJOR < 16
6140 for (auto v : {ntval, nfval})
6141 if (auto I = dyn_cast<Instruction>(v))
6142 I->setFastMathFlags(cur->getFastMathFlags());
6143#endif
6144
6145 auto res = pushcse(B.CreateSelect(SI->getCondition(), ntval, nfval,
6146 "sfdiv2." + cur->getName()));
6147
6148 replaceAndErase(cur, res);
6149 return "FDivSelectProp";
6150 }
6151 }
6152 }
6153
6154 // div (mul a:not_sparse, b:is_sparse), c -> mul (div, a, c), b:is_sparse
6155 if (cur->getOpcode() == Instruction::FDiv) {
6156 auto c = cur->getOperand(1);
6157 if (auto z = dyn_cast<BinaryOperator>(cur->getOperand(0))) {
6158 if (z->getOpcode() == Instruction::FMul) {
6159 for (int i = 0; i < 2; i++) {
6160
6161 Value *a = z->getOperand(i);
6162 Value *b = z->getOperand(1 - i);
6163 if (directlySparse(a))
6164 continue;
6165 if (!directlySparse(b))
6166 continue;
6167
6168 Value *inner_fdiv = pushcse(B.CreateFDivFMF(a, c, cur));
6169 Value *outer_fmul = pushcse(B.CreateFMulFMF(inner_fdiv, b, z));
6170 push(z);
6171 replaceAndErase(cur, outer_fmul);
6172 return "FDivFMulSparseProp";
6173 }
6174 }
6175 }
6176 }
6177
6178 if (cur->getOpcode() == Instruction::FMul)
6179 for (int i = 0; i < 2; i++) {
6180
6181 Value *prelhs = cur->getOperand(i);
6182 Value *b = cur->getOperand(1 - i);
6183
6184 // fmul (fmul x:constant, y):z, b:constant .
6185 if (isa<Constant>(b))
6186 if (auto z = dyn_cast<BinaryOperator>(prelhs)) {
6187 if (z->getOpcode() == Instruction::FMul) {
6188 for (int j = 0; j < 2; j++) {
6189 auto x = z->getOperand(i);
6190 if (!isa<Constant>(x))
6191 continue;
6192 auto y = z->getOperand(1 - i);
6193 Value *inner_fmul = pushcse(B.CreateFMulFMF(x, b, cur));
6194 Value *outer_fmul = pushcse(B.CreateFMulFMF(inner_fmul, y, z));
6195 push(z);
6196 replaceAndErase(cur, outer_fmul);
6197 return "FMulFMulConstantReorder";
6198 }
6199 }
6200 }
6201
6202 auto integralFloat = [](Value *z) {
6203 if (auto C = dyn_cast<ConstantFP>(z)) {
6204 APSInt Tmp(64);
6205 bool isExact = false;
6206 C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero,
6207 &isExact);
6208 if (isExact || C->isZero()) {
6209 return true;
6210 }
6211 }
6212 return false;
6213 };
6214
6215 // fmul (fmul x:sparse, y):z, b
6216 // 1) If x and y are both sparse, do nothing and let the inner fmul be
6217 // simplified into a single sparse instruction. Thus, we may assume
6218 // y is not sparse.
6219 // 2) if b is sparse, swap it to be fmul (fmul x, b), y so the inner
6220 // sparsity can be simplified.
6221 // 3) otherwise b is not sparse and we should push the sparsity to
6222 // be the outermost value
6223 if (auto z = dyn_cast<BinaryOperator>(prelhs)) {
6224 if (z->getOpcode() == Instruction::FMul) {
6225 for (int j = 0; j < 2; j++) {
6226 auto x = z->getOperand(j);
6227 if (!directlySparse(x))
6228 continue;
6229 auto y = z->getOperand(1 - j);
6230 if (directlySparse(y))
6231 continue;
6232
6233 if (directlySparse(b) || integralFloat(b)) {
6234 push(z);
6235 Value *inner_fmul = pushcse(
6236 B.CreateFMulFMF(x, b, cur, "mulisr." + cur->getName()));
6237 Value *outer_fmul = pushcse(
6238 B.CreateFMulFMF(inner_fmul, y, z, "mulisr." + z->getName()));
6239 replaceAndErase(cur, outer_fmul);
6240 return "FMulFMulSparseReorder";
6241 } else {
6242 push(z);
6243 Value *inner_fmul = pushcse(
6244 B.CreateFMulFMF(y, b, cur, "mulisp." + cur->getName()));
6245 Value *outer_fmul = pushcse(
6246 B.CreateFMulFMF(inner_fmul, x, z, "mulisp." + z->getName()));
6247 replaceAndErase(cur, outer_fmul);
6248 return "FMulFMulSparsePush";
6249 }
6250 }
6251 }
6252 }
6253
6254 /*
6255 auto contains = [](MDNode *MD, Value *V) {
6256 if (!MD)
6257 return false;
6258 for (auto &op : MD->operands()) {
6259 auto V2 = cast<ValueAsMetadata>(op)->getValue();
6260 if (V == V2)
6261 return true;
6262 }
6263 return false;
6264 };
6265
6266 // fmul (sitofp a), b -> select (a == 0), 0 [noprop fmul ( sitofp a), b]
6267 if (true || !contains(hasMetadata(cur, "enzyme_fmulnoprop"), prelhs))
6268 if (auto ext = dyn_cast<CastInst>(prelhs)) {
6269 if (ext->getOpcode() == Instruction::UIToFP ||
6270 ext->getOpcode() == Instruction::SIToFP) {
6271 push(ext);
6272
6273 Value *condition = pushcse(B.CreateICmpEQ(
6274 ext->getOperand(0),
6275 ConstantInt::get(ext->getOperand(0)->getType(), 0),
6276 "mulcsicmp." + cur->getName()));
6277
6278 Value *fmul = pushcse(B.CreateFMulFMF(ext, b, cur));
6279 if (auto I = dyn_cast<Instruction>(fmul)) {
6280 SmallVector<Metadata *, 1> nodes;
6281 if (auto MD = hasMetadata(cur, "enzyme_fmulnoprop")) {
6282 for (auto &M : MD->operands()) {
6283 nodes.push_back(M.get());
6284 }
6285 }
6286 nodes.push_back(ValueAsMetadata::get(ext));
6287 I->setMetadata("enzyme_fmulnoprop",
6288 MDNode::get(I->getContext(), nodes));
6289 }
6290
6291 Value *sel = pushcse(
6292 B.CreateSelect(condition, ConstantFP::get(cur->getType(),
6293 0.0), fmul, "mulcsi." + cur->getName()));
6294
6295 replaceAndErase(cur, sel);
6296 return "FMulSIToFPProp";
6297 }
6298 }
6299 */
6300
6301 // fmul (select c, 0, a), b -> select c, 0 (fmul a, b)
6302 if (auto SI = dyn_cast<SelectInst>(prelhs)) {
6303 auto tvalC = dyn_cast<ConstantFP>(SI->getTrueValue());
6304 auto fvalC = dyn_cast<ConstantFP>(SI->getFalseValue());
6305 if ((tvalC && tvalC->isZero()) || (fvalC && fvalC->isZero())) {
6306 push(SI);
6307 auto ntval =
6308 (tvalC && tvalC->isZero())
6309 ? tvalC
6310 : pushcse(B.CreateFMulFMF(SI->getTrueValue(), b, cur));
6311 auto nfval =
6312 (fvalC && fvalC->isZero())
6313 ? fvalC
6314 : pushcse(B.CreateFMulFMF(SI->getFalseValue(), b, cur));
6315 auto res = pushcse(B.CreateSelect(SI->getCondition(), ntval, nfval,
6316 "mulsi." + cur->getName()));
6317
6318 replaceAndErase(cur, res);
6319 return "FMulSelectProp";
6320 }
6321 }
6322 }
6323
6324 if (auto icmp = dyn_cast<BinaryOperator>(cur)) {
6325 if (icmp->getOpcode() == Instruction::Xor) {
6326 for (int i = 0; i < 2; i++) {
6327 if (auto C = dyn_cast<ConstantInt>(icmp->getOperand(i))) {
6328 // !(cmp a, b) -> inverse(cmp), a, b
6329 if (C->isOne()) {
6330 if (auto scmp = dyn_cast<CmpInst>(icmp->getOperand(1 - i))) {
6331 auto next = pushcse(
6332 B.CreateCmp(scmp->getInversePredicate(), scmp->getOperand(0),
6333 scmp->getOperand(1), "not." + scmp->getName()));
6334 replaceAndErase(cur, next);
6335 return "NotCmp";
6336 }
6337 }
6338 }
6339 }
6340 }
6341 }
6342
6343 // select cmp, (ext tval), (ext fval) -> (cmp & tval) | (!cmp & fval)
6344 if (auto SI = dyn_cast<SelectInst>(cur)) {
6345
6346 Value *trueVal = nullptr;
6347 if (auto C = dyn_cast<ConstantFP>(SI->getTrueValue())) {
6348 if (C->isZero()) {
6349 trueVal = ConstantInt::getFalse(SI->getContext());
6350 }
6351 if (C->isExactlyValue(1.0)) {
6352 trueVal = ConstantInt::getTrue(SI->getContext());
6353 }
6354 }
6355 if (auto ext = dyn_cast<CastInst>(SI->getTrueValue())) {
6356 if (ext->getOperand(0)->getType()->isIntegerTy(1))
6357 trueVal = ext->getOperand(0);
6358 }
6359 Value *falseVal = nullptr;
6360 if (auto C = dyn_cast<ConstantFP>(SI->getFalseValue())) {
6361 if (C->isZero()) {
6362 falseVal = ConstantInt::getFalse(SI->getContext());
6363 }
6364 if (C->isExactlyValue(1.0)) {
6365 falseVal = ConstantInt::getTrue(SI->getContext());
6366 }
6367 }
6368 if (auto ext = dyn_cast<CastInst>(SI->getFalseValue())) {
6369 if (ext->getOperand(0)->getType()->isIntegerTy(1))
6370 falseVal = ext->getOperand(0);
6371 }
6372 if (trueVal && falseVal) {
6373 auto ncmp1 = pushcse(B.CreateAnd(SI->getCondition(), trueVal));
6374 auto notV = pushcse(B.CreateNot(SI->getCondition()));
6375 auto ncmp2 = pushcse(B.CreateAnd(notV, falseVal));
6376 auto ori = pushcse(B.CreateOr(ncmp1, ncmp2));
6377 auto ext = pushcse(B.CreateUIToFP(ori, SI->getType()));
6378 replaceAndErase(cur, ext);
6379 return "SelectI1Ext";
6380 }
6381 }
6382 // select cmp, (i1 tval), (i1 fval) -> (cmp & tval) | (!cmp & fval)
6383 if (cur->getType()->isIntegerTy(1))
6384 if (auto SI = dyn_cast<SelectInst>(cur)) {
6385 auto ncmp1 = pushcse(B.CreateAnd(SI->getCondition(), SI->getTrueValue()));
6386 auto notV = pushcse(B.CreateNot(SI->getCondition()));
6387 auto ncmp2 = pushcse(B.CreateAnd(notV, SI->getFalseValue()));
6388 auto ori = pushcse(B.CreateOr(ncmp1, ncmp2));
6389 replaceAndErase(cur, ori);
6390 return "SelectI1";
6391 }
6392
6393 if (auto PN = dyn_cast<PHINode>(cur)) {
6394 B.SetInsertPoint(PN->getParent()->getFirstNonPHI());
6395 if (SE.isSCEVable(PN->getType())) {
6396 auto S = SE.getSCEV(PN);
6397
6398 bool legal = false;
6399 if (auto SV = dyn_cast<SCEVUnknown>(S)) {
6400 auto val = SV->getValue();
6401 legal |= isa<Constant>(val) || isa<Argument>(val);
6402 if (auto I = dyn_cast<Instruction>(val)) {
6403 auto L = LI.getLoopFor(I->getParent());
6404 if ((!L || L->getCanonicalInductionVariable() != I) && I != PN)
6405 legal = true;
6406 }
6407 }
6408 if (isa<SCEVAddRecExpr>(S)) {
6409 auto L = LI.getLoopFor(PN->getParent());
6410 assert(L);
6411 if (L->getCanonicalInductionVariable() != PN)
6412 legal = true;
6413 }
6414
6415 if (legal) {
6416 for (auto U : cur->users()) {
6417 push(U);
6418 }
6419 auto point = PN->getParent()->getFirstNonPHI();
6420 auto tmp = cast<PHINode>(pushcse(B.CreatePHI(cur->getType(), 1)));
6421 cur->replaceAllUsesWith(tmp);
6422 cur->eraseFromParent();
6423
6424 Value *newIV = nullptr;
6425 {
6426#if LLVM_VERSION_MAJOR >= 22
6427 SCEVExpander Exp(SE, "sparseenzyme");
6428#else
6429 SCEVExpander Exp(SE, DL, "sparseenzyme");
6430#endif
6431 // We place that at first non phi as it may produce a non-phi
6432 // instruction and must thus be expanded after all phi's
6433 newIV = Exp.expandCodeFor(S, tmp->getType(), point);
6434 // sadly this doesn't exist on 11
6435 for (auto I : Exp.getAllInsertedInstructions())
6436 Q.insert(I);
6437 }
6438
6439 tmp->replaceAllUsesWith(newIV);
6440 tmp->eraseFromParent();
6441 return "InductVarSCEV";
6442 }
6443 }
6444 // phi a, a -> a
6445 {
6446 bool legal = true;
6447 for (size_t i = 1; i < PN->getNumIncomingValues(); i++) {
6448 auto v = PN->getIncomingValue(i);
6449 if (v != PN->getIncomingValue(0)) {
6450 legal = false;
6451 break;
6452 }
6453 }
6454 if (legal) {
6455 auto val = PN->getIncomingValue(0);
6456 replaceAndErase(cur, val);
6457 return "PhiMerge";
6458 }
6459 }
6460 // phi (idx=0) ? b, a, a -> select (idx == 0), b, a
6461 if (auto L = LI.getLoopFor(PN->getParent()))
6462 if (L->getHeader() == PN->getParent())
6463 if (auto idx = L->getCanonicalInductionVariable())
6464 if (auto PH = L->getLoopPreheader()) {
6465 bool legal = idx != PN;
6466 auto ph_idx = PN->getBasicBlockIndex(PH);
6467 assert(ph_idx >= 0);
6468 for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
6469 if ((int)i == ph_idx)
6470 continue;
6471 auto v = PN->getIncomingValue(i);
6472 if (v != PN->getIncomingValue(1 - ph_idx)) {
6473 legal = false;
6474 break;
6475 }
6476 // The given var must dominate the loop
6477 if (isa<Constant>(v))
6478 continue;
6479 if (isa<Argument>(v))
6480 continue;
6481 // exception for the induction itself, which we handle specially
6482 if (v == idx)
6483 continue;
6484 auto I = cast<Instruction>(v);
6485 if (!DT.dominates(I, PN)) {
6486 legal = false;
6487 break;
6488 }
6489 }
6490 if (legal) {
6491 auto val = PN->getIncomingValue(1 - ph_idx);
6492 push(val);
6493 if (val == idx) {
6494 val = pushcse(
6495 B.CreateSub(idx, ConstantInt::get(idx->getType(), 1)));
6496 }
6497
6498 auto val2 = PN->getIncomingValue(ph_idx);
6499 push(val2);
6500
6501 auto c0 = ConstantInt::get(idx->getType(), 0);
6502 // if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) {
6503 // val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val);
6504 //} else {
6505 auto eq = pushcse(B.CreateICmpEQ(idx, c0));
6506 val = pushcse(
6507 B.CreateSelect(eq, val2, val, "phisel." + cur->getName()));
6508 //}
6509
6510 replaceAndErase(cur, val);
6511 return "PhiLoop0Sel";
6512 }
6513 }
6514 // phi (sitofp a), (sitofp b) -> sitofp (phi a, b)
6515 {
6516 SmallVector<Value *, 1> negOps;
6517 SmallVector<Instruction *, 1> prevNegOps;
6518 bool legal = true;
6519 for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
6520 auto v = PN->getIncomingValue(i);
6521 if (auto C = dyn_cast<ConstantFP>(v)) {
6522 APSInt Tmp(64);
6523 bool isExact = false;
6524 C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero,
6525 &isExact);
6526 if (isExact || C->isZero()) {
6527 negOps.push_back(ConstantInt::get(B.getInt64Ty(), Tmp));
6528 continue;
6529 }
6530 }
6531 if (auto fneg = dyn_cast<Instruction>(v)) {
6532 if (fneg->getOpcode() == Instruction::SIToFP &&
6533 cast<IntegerType>(fneg->getOperand(0)->getType())
6534 ->getBitWidth() == 64) {
6535 negOps.push_back(fneg->getOperand(0));
6536 prevNegOps.push_back(fneg);
6537 continue;
6538 }
6539 }
6540 legal = false;
6541 }
6542 if (legal) {
6543 auto PN2 = cast<PHINode>(
6544 pushcse(B.CreatePHI(B.getInt64Ty(), PN->getNumIncomingValues())));
6545 PN2->takeName(PN);
6546 for (auto val : llvm::enumerate(negOps))
6547 PN2->addIncoming(val.value(), PN->getIncomingBlock(val.index()));
6548
6549 push(PN2);
6550
6551 auto fneg = pushcse(B.CreateSIToFP(PN2, PN->getType()));
6552
6553 for (auto I : prevNegOps)
6554 push(I);
6555 replaceAndErase(cur, fneg);
6556 return "PhiSIToFP";
6557 }
6558 }
6559 // phi (fneg a), (fneg b) -> fneg (phi a, b)
6560 {
6561 SmallVector<Value *, 1> negOps;
6562 SmallVector<Instruction *, 1> prevNegOps;
6563 bool legal = true;
6564 bool hasNeg = false;
6565 for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
6566 auto v = PN->getIncomingValue(i);
6567 if (auto C = dyn_cast<ConstantFP>(v)) {
6568 negOps.push_back(C->isZero() ? C : pushcse(B.CreateFNeg(C)));
6569 continue;
6570 }
6571 if (auto fneg = dyn_cast<Instruction>(v)) {
6572 if (fneg->getOpcode() == Instruction::FNeg) {
6573 negOps.push_back(fneg->getOperand(0));
6574 prevNegOps.push_back(fneg);
6575 continue;
6576 }
6577 }
6578 legal = false;
6579 }
6580 if (legal && hasNeg) {
6581 for (auto val : llvm::enumerate(negOps))
6582 PN->setIncomingValue(val.index(), val.value());
6583
6584 push(PN);
6585
6586 auto fneg = pushcse(B.CreateFNeg(PN));
6587
6588 for (auto &U : cur->uses()) {
6589 if (U.getUser() == fneg)
6590 continue;
6591 push(U.getUser());
6592 U.set(fneg);
6593 }
6594 for (auto I : prevNegOps)
6595 push(I);
6596 return "PhiFNeg";
6597 }
6598 }
6599 // phi (neg a), (neg b) -> neg (phi a, b)
6600 {
6601 SmallVector<Value *, 1> negOps;
6602 SmallVector<Instruction *, 1> prevNegOps;
6603 bool legal = true;
6604 bool hasNeg = false;
6605 for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
6606 auto v = PN->getIncomingValue(i);
6607 if (auto C = dyn_cast<ConstantInt>(v)) {
6608 negOps.push_back(pushcse(B.CreateNeg(C)));
6609 continue;
6610 }
6611 if (auto fneg = dyn_cast<BinaryOperator>(v)) {
6612 if (auto CI = dyn_cast<ConstantInt>(fneg->getOperand(0)))
6613 if (fneg->getOpcode() == Instruction::Sub && CI->isZero()) {
6614 negOps.push_back(fneg->getOperand(1));
6615 prevNegOps.push_back(fneg);
6616 hasNeg = true;
6617 continue;
6618 }
6619 }
6620 legal = false;
6621 }
6622 if (legal && hasNeg) {
6623 for (auto val : llvm::enumerate(negOps))
6624 PN->setIncomingValue(val.index(), val.value());
6625
6626 push(PN);
6627
6628 auto fneg = pushcse(B.CreateNeg(PN));
6629
6630 for (auto &U : cur->uses()) {
6631 if (U.getUser() == fneg)
6632 continue;
6633 push(U.getUser());
6634 U.set(fneg);
6635 }
6636 for (auto I : prevNegOps)
6637 push(I);
6638 return "PHINeg";
6639 }
6640 }
6641 // p = phi (mul a, c), (mul b, d) -> mul (phi a, b), (phi c, d) if
6642 // a,b,c != p
6643 {
6644 for (auto code :
6645 {(unsigned)Instruction::Mul, (unsigned)Instruction::Sub,
6646 (unsigned)Instruction::Add, (unsigned)Instruction::ZExt,
6647 (unsigned)Instruction::UIToFP, (unsigned)Instruction::ICmp,
6648 (unsigned)Instruction::FMul, (unsigned)Instruction::Or,
6649 (unsigned)Instruction::And}) {
6650 SmallVector<Value *, 1> lhsOps;
6651 SmallVector<Value *, 1> rhsOps;
6652 SmallVector<Instruction *, 1> prevOps;
6653 bool legal = true;
6654 bool fast = false;
6655 bool NUW = false;
6656 bool NSW = false;
6657 size_t numOps = 0;
6658 std::optional<llvm::CmpInst::Predicate> cmpPredicate;
6659 switch (code) {
6660 case Instruction::FMul:
6661 case Instruction::FSub:
6662 case Instruction::FAdd:
6663 fast = true;
6664 numOps = 2;
6665 break;
6666 case Instruction::Mul:
6667 case Instruction::Add:
6668 NUW = NSW = true;
6669 numOps = 2;
6670 break;
6671 case Instruction::Sub:
6672 NSW = true;
6673 numOps = 2;
6674 break;
6675 case Instruction::ICmp:
6676 case Instruction::FCmp:
6677 case Instruction::Or:
6678 case Instruction::And:
6679 numOps = 2;
6680 break;
6681 case Instruction::ZExt:
6682 case Instruction::UIToFP:
6683 numOps = 1;
6684 break;
6685 default:;
6686 llvm_unreachable("unknown opcode");
6687 }
6688 bool changed = false;
6689 for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
6690 auto v = PN->getIncomingValue(i);
6691 if (auto C = dyn_cast<ConstantInt>(v)) {
6692 if (code == Instruction::ZExt) {
6693 lhsOps.push_back(ConstantInt::getFalse(C->getContext()));
6694 continue;
6695 } else if (C->isZero()) {
6696 rhsOps.push_back(C);
6697 lhsOps.push_back(C);
6698 continue;
6699 }
6700 }
6701 if (auto C = dyn_cast<ConstantFP>(v)) {
6702 if (code == Instruction::UIToFP) {
6703 if (C->isZero()) {
6704 lhsOps.push_back(ConstantInt::getFalse(C->getContext()));
6705 }
6706 } else if (code == Instruction::FMul || code == Instruction::FSub ||
6707 code == Instruction::FAdd) {
6708 if (C->isZero()) {
6709 rhsOps.push_back(C);
6710 lhsOps.push_back(C);
6711 continue;
6712 }
6713 }
6714 }
6715 if (auto fneg = dyn_cast<Instruction>(v)) {
6716 if (fneg->getOpcode() == code) {
6717 switch (code) {
6718 case Instruction::FMul:
6719 case Instruction::FSub:
6720 case Instruction::FAdd:
6721 fast &= fneg->isFast();
6722 if (fneg->getOperand(0) == PN)
6723 legal = false;
6724 if (fneg->getOperand(1) == PN)
6725 legal = false;
6726 lhsOps.push_back(fneg->getOperand(0));
6727 rhsOps.push_back(fneg->getOperand(1));
6728 break;
6729 case Instruction::Mul:
6730 case Instruction::Sub:
6731 case Instruction::Add:
6732 NUW &= fneg->hasNoUnsignedWrap();
6733 NSW &= fneg->hasNoSignedWrap();
6734 if (fneg->getOperand(0) == PN)
6735 legal = false;
6736 if (fneg->getOperand(1) == PN)
6737 legal = false;
6738 lhsOps.push_back(fneg->getOperand(0));
6739 rhsOps.push_back(fneg->getOperand(1));
6740 break;
6741 case Instruction::Or:
6742 case Instruction::And:
6743 if (fneg->getOperand(0) == PN)
6744 legal = false;
6745 if (fneg->getOperand(1) == PN)
6746 legal = false;
6747 lhsOps.push_back(fneg->getOperand(0));
6748 rhsOps.push_back(fneg->getOperand(1));
6749 break;
6750 case Instruction::ICmp:
6751 case Instruction::FCmp:
6752 if (fneg->getOperand(0) == PN)
6753 legal = false;
6754 if (fneg->getOperand(1) == PN)
6755 legal = false;
6756 if (cmpPredicate) {
6757 if (*cmpPredicate != cast<CmpInst>(fneg)->getPredicate())
6758 legal = false;
6759 } else {
6760 cmpPredicate = cast<CmpInst>(fneg)->getPredicate();
6761 }
6762 lhsOps.push_back(fneg->getOperand(0));
6763 rhsOps.push_back(fneg->getOperand(1));
6764 break;
6765 case Instruction::ZExt:
6766 case Instruction::UIToFP:
6767 if (cast<IntegerType>(fneg->getOperand(0)->getType())
6768 ->getBitWidth() != 1)
6769 legal = false;
6770 lhsOps.push_back(fneg->getOperand(0));
6771 break;
6772 default:
6773 llvm_unreachable("unhandled opcode");
6774 }
6775 prevOps.push_back(fneg);
6776 changed = true;
6777 continue;
6778 }
6779 }
6780 legal = false;
6781 }
6782
6783 int preheader_fix = -1;
6784
6785 if (code == Instruction::ICmp || code == Instruction::FCmp) {
6786 if (!cmpPredicate)
6787 legal = false;
6788 auto L = LI.getLoopFor(PN->getParent());
6789 if (legal && L && L->getLoopPreheader() &&
6790 L->getCanonicalInductionVariable() &&
6791 L->getHeader() == PN->getParent()) {
6792 auto ph_idx = PN->getBasicBlockIndex(L->getLoopPreheader());
6793 if (isa<ConstantInt>(PN->getIncomingValue(ph_idx))) {
6794 lhsOps[ph_idx] =
6795 Constant::getNullValue(lhsOps[1 - ph_idx]->getType());
6796 rhsOps[ph_idx] =
6797 Constant::getNullValue(rhsOps[1 - ph_idx]->getType());
6798 preheader_fix = ph_idx;
6799 }
6800 }
6801 for (auto v : lhsOps)
6802 if (v->getType() != lhsOps[0]->getType())
6803 legal = false;
6804 for (auto v : rhsOps)
6805 if (v->getType() != rhsOps[0]->getType())
6806 legal = false;
6807 }
6808
6809 if (legal && changed) {
6810 auto lhsPN = cast<PHINode>(pushcse(
6811 B.CreatePHI(lhsOps[0]->getType(), PN->getNumIncomingValues())));
6812 PHINode *rhsPN = nullptr;
6813 if (numOps == 2)
6814 rhsPN = cast<PHINode>(pushcse(
6815 B.CreatePHI(rhsOps[0]->getType(), PN->getNumIncomingValues())));
6816
6817 for (auto val : llvm::enumerate(lhsOps))
6818 lhsPN->addIncoming(val.value(), PN->getIncomingBlock(val.index()));
6819
6820 if (numOps == 2) {
6821 for (auto val : llvm::enumerate(rhsOps))
6822 rhsPN->addIncoming(val.value(),
6823 PN->getIncomingBlock(val.index()));
6824 }
6825
6826 Value *fneg = nullptr;
6827 switch (code) {
6828 case Instruction::FMul:
6829 fneg = B.CreateFMul(lhsPN, rhsPN);
6830 if (auto I = dyn_cast<Instruction>(fneg))
6831 I->setFast(fast);
6832 break;
6833 case Instruction::FAdd:
6834 fneg = B.CreateFAdd(lhsPN, rhsPN);
6835 if (auto I = dyn_cast<Instruction>(fneg))
6836 I->setFast(fast);
6837 break;
6838 case Instruction::FSub:
6839 fneg = B.CreateFSub(lhsPN, rhsPN);
6840 if (auto I = dyn_cast<Instruction>(fneg))
6841 I->setFast(fast);
6842 break;
6843 case Instruction::Mul:
6844 fneg = B.CreateMul(lhsPN, rhsPN, "", NUW, NSW);
6845 break;
6846 case Instruction::Add:
6847 fneg = B.CreateAdd(lhsPN, rhsPN, "", NUW, NSW);
6848 break;
6849 case Instruction::Sub:
6850 fneg = B.CreateSub(lhsPN, rhsPN, "", NUW, NSW);
6851 break;
6852 case Instruction::ZExt:
6853 fneg = B.CreateZExt(lhsPN, PN->getType());
6854 break;
6855 case Instruction::FCmp:
6856 case Instruction::ICmp:
6857 fneg = B.CreateCmp(*cmpPredicate, lhsPN, rhsPN);
6858 break;
6859 case Instruction::UIToFP:
6860 fneg = B.CreateUIToFP(lhsPN, PN->getType());
6861 break;
6862 case Instruction::Or:
6863 fneg = B.CreateOr(lhsPN, rhsPN);
6864 break;
6865 case Instruction::And:
6866 fneg = B.CreateAnd(lhsPN, rhsPN);
6867 break;
6868 default:
6869 llvm_unreachable("unhandled opcode");
6870 }
6871
6872 push(fneg);
6873
6874 if (preheader_fix != -1) {
6875 auto L = LI.getLoopFor(PN->getParent());
6876 auto idx = L->getCanonicalInductionVariable();
6877 auto eq = pushcse(
6878 B.CreateICmpEQ(idx, ConstantInt::get(idx->getType(), 0)));
6879 fneg =
6880 pushcse(B.CreateSelect(eq, PN->getIncomingValue(preheader_fix),
6881 fneg, "phphisel." + cur->getName()));
6882 }
6883
6884 replaceAndErase(cur, fneg);
6885 return "PHIBinop";
6886 }
6887 }
6888 }
6889 // phi -> select
6890 if (PN->getNumIncomingValues() == 2) {
6891 for (int i = 0; i < 2; i++) {
6892 auto prev = PN->getIncomingBlock(i);
6893 if (!DT.dominates(prev, PN->getParent())) {
6894 continue;
6895 }
6896 auto br = dyn_cast<BranchInst>(prev->getTerminator());
6897 if (!br) {
6898 continue;
6899 }
6900 if (!br->isConditional()) {
6901 continue;
6902 }
6903 if (br->getSuccessor(0) != PN->getParent()) {
6904 continue;
6905 }
6906 if (br->getSuccessor(1) != PN->getIncomingBlock(1 - i)) {
6907 continue;
6908 }
6909
6910 Value *specVal = PN->getIncomingValue(1 - i);
6911 SetVector<Value *, std::deque<Value *>> todo;
6912 todo.insert(specVal);
6913 SetVector<Instruction *> toMove;
6914 bool legal = true;
6915 while (!todo.empty()) {
6916 auto cur = *todo.begin();
6917 todo.erase(todo.begin());
6918 auto I = dyn_cast<Instruction>(cur);
6919 if (!I)
6920 continue;
6921 if (I->mayReadOrWriteMemory()) {
6922 legal = false;
6923 break;
6924 }
6925 if (DT.dominates(I, PN))
6926 continue;
6927 for (size_t i = 0; i < I->getNumOperands(); i++)
6928 todo.insert(I->getOperand(i));
6929 toMove.insert(I);
6930 }
6931 if (!legal)
6932 continue;
6933 for (auto iter = toMove.rbegin(), end = toMove.rend(); iter != end;
6934 iter++) {
6935 (*iter)->moveBefore(br);
6936 }
6937 auto sel = pushcse(B.CreateSelect(
6938 br->getCondition(), PN->getIncomingValueForBlock(prev),
6939 PN->getIncomingValueForBlock(br->getSuccessor(1)),
6940 "tphisel." + cur->getName()));
6941
6942 replaceAndErase(cur, sel);
6943 return "TPhiSel";
6944 }
6945 }
6946 }
6947
6948 if (auto SI = dyn_cast<SelectInst>(cur)) {
6949 auto tval = replace(SI->getTrueValue(), SI->getCondition(),
6950 ConstantInt::getTrue(SI->getContext()));
6951 auto fval = replace(SI->getFalseValue(), SI->getCondition(),
6952 ConstantInt::getFalse(SI->getContext()));
6953 if (tval != SI->getTrueValue() || fval != SI->getFalseValue()) {
6954 auto res = pushcse(B.CreateSelect(SI->getCondition(), tval, fval,
6955 "postsel." + SI->getName()));
6956 replaceAndErase(cur, res);
6957 return "SelectReplace";
6958 }
6959 }
6960
6961 // and a, b -> and a b[with a true]
6962 if (cur->getOpcode() == Instruction::And) {
6963 auto lhs = replace(cur->getOperand(0), cur->getOperand(1),
6964 ConstantInt::getTrue(cur->getContext()));
6965 if (lhs != cur->getOperand(0)) {
6966 auto res = pushcse(
6967 B.CreateAnd(lhs, cur->getOperand(1), "postand." + cur->getName()));
6968 replaceAndErase(cur, res);
6969 return "AndReplaceLHS";
6970 }
6971 auto rhs = replace(cur->getOperand(1), cur->getOperand(0),
6972 ConstantInt::getTrue(cur->getContext()));
6973 if (rhs != cur->getOperand(1)) {
6974 auto res = pushcse(
6975 B.CreateAnd(cur->getOperand(0), rhs, "postand." + cur->getName()));
6976 replaceAndErase(cur, res);
6977 return "AndReplaceRHS";
6978 }
6979 }
6980
6981 // or a, b -> or a b[with a false]
6982 if (cur->getOpcode() == Instruction::Or) {
6983 auto lhs = replace(cur->getOperand(0), cur->getOperand(1),
6984 ConstantInt::getFalse(cur->getContext()));
6985 if (lhs != cur->getOperand(0)) {
6986 auto res = pushcse(
6987 B.CreateOr(lhs, cur->getOperand(1), "postor." + cur->getName()));
6988 replaceAndErase(cur, res);
6989 return "OrReplaceLHS";
6990 }
6991 auto rhs = replace(cur->getOperand(1), cur->getOperand(0),
6992 ConstantInt::getFalse(cur->getContext()));
6993 if (rhs != cur->getOperand(1)) {
6994 auto res = pushcse(
6995 B.CreateOr(cur->getOperand(0), rhs, "postor." + cur->getName()));
6996 replaceAndErase(cur, res);
6997 return "OrReplaceRHS";
6998 }
6999 }
7000 return {};
7001}
7002
// Forward declarations: the `Constraints` class is only defined further down,
// but the comparator and context helpers declared before it already hold
// `std::shared_ptr<const Constraints>` values, and `Constraints::dump()` prints
// itself through this stream operator.
class Constraints;
raw_ostream &operator<<(raw_ostream &os, const Constraints &c);
7005
7007 bool operator()(std::shared_ptr<const Constraints> lhs,
7008 std::shared_ptr<const Constraints> rhs) const;
7009};
7010
7012 ScalarEvolution &SE;
7013 const Loop *loopToSolve;
7014 const SmallVectorImpl<Instruction *> &Assumptions;
7015 DominatorTree &DT;
7016 using InnerTy = std::shared_ptr<const Constraints>;
7017 using SetTy = std::set<InnerTy, ConstraintComparator>;
7019 ConstraintContext(ScalarEvolution &SE, const Loop *loopToSolve,
7020 const SmallVectorImpl<Instruction *> &Assumptions,
7021 DominatorTree &DT)
7023 assert(loopToSolve);
7024 }
7028 DT(ctx.DT), seen(ctx.seen) {
7029 seen.insert(lhs);
7030 }
7033 DT(ctx.DT), seen(ctx.seen) {
7034 seen.insert(lhs);
7035 seen.insert(rhs);
7036 }
7037 bool contains(InnerTy x) const { return seen.count(x) != 0; }
7038};
7039
7040bool cannotDependOnLoopIV(const SCEV *S, const Loop *L) {
7041 assert(L);
7042 if (isa<SCEVConstant>(S))
7043 return true;
7044 if (auto M = dyn_cast<SCEVAddExpr>(S)) {
7045 for (auto o : M->operands())
7046 if (!cannotDependOnLoopIV(o, L))
7047 return false;
7048 return true;
7049 }
7050 if (auto M = dyn_cast<SCEVMulExpr>(S)) {
7051 for (auto o : M->operands())
7052 if (!cannotDependOnLoopIV(o, L))
7053 return false;
7054 return true;
7055 }
7056 if (auto M = dyn_cast<SCEVUDivExpr>(S)) {
7057 for (auto o : {M->getLHS(), M->getRHS()})
7058 if (!cannotDependOnLoopIV(o, L))
7059 return false;
7060 return true;
7061 }
7062 if (auto UV = dyn_cast<SCEVUnknown>(S)) {
7063 auto U = UV->getValue();
7064 if (isa<Argument>(U))
7065 return true;
7066 if (isa<Constant>(U))
7067 return true;
7068 auto I = cast<Instruction>(U);
7069 return !L->contains(I->getParent());
7070 }
7071 if (auto addrec = dyn_cast<SCEVAddRecExpr>(S)) {
7072 if (addrec->getLoop() == L)
7073 return false;
7074 for (auto o : addrec->operands())
7075 if (!cannotDependOnLoopIV(o, L))
7076 return false;
7077 return true;
7078 }
7079 if (auto expr = dyn_cast<SCEVSignExtendExpr>(S)) {
7080 return cannotDependOnLoopIV(expr->getOperand(), L);
7081 }
7082 llvm::errs() << " cannot tell if depends on loop iv: " << *S << "\n";
7083 return false;
7084}
7085
7086const SCEV *evaluateAtLoopIter(const SCEV *V, ScalarEvolution &SE,
7087 const Loop *find, const SCEV *replace) {
7088 assert(find);
7089 if (cannotDependOnLoopIV(V, find))
7090 return V;
7091 if (auto addrec = dyn_cast<SCEVAddRecExpr>(V)) {
7092 if (addrec->getLoop() == find) {
7093 auto V2 = addrec->evaluateAtIteration(replace, SE);
7094 return evaluateAtLoopIter(V2, SE, find, replace);
7095 }
7096 }
7097 if (auto div = dyn_cast<SCEVUDivExpr>(V)) {
7098 auto lhs = evaluateAtLoopIter(div->getLHS(), SE, find, replace);
7099 if (!lhs)
7100 return nullptr;
7101 auto rhs = evaluateAtLoopIter(div->getRHS(), SE, find, replace);
7102 if (!rhs)
7103 return nullptr;
7104 return SE.getUDivExpr(lhs, rhs);
7105 }
7106 return nullptr;
7107}
7108
7109class Constraints : public std::enable_shared_from_this<Constraints> {
7110public:
7111 const enum class Type {
7112 Union = 0,
7113 Intersect = 1,
7114 Compare = 2,
7115 All = 3,
7116 None = 4
7118
7119 using InnerTy = std::shared_ptr<const Constraints>;
7120
7121 using SetTy = std::set<InnerTy, ConstraintComparator>;
7122
7124
7125 const SCEV *const node;
7126 // whether equal to the node, or not equal to the node
7128 // the loop of the iv comparing against.
7129 const llvm::Loop *const Loop;
7130 // using SetTy = SmallVector<InnerTy, 0>;
7131 // using SetTy = SetVector<InnerTy, SmallVector<InnerTy, 0>,
7132 // std::set<InnerTy>>;
7133
7135 : ty(Type::Union), values(), node(nullptr), isEqual(false),
7136 Loop(nullptr) {}
7137
7138private:
7139 Constraints(const SCEV *v, bool isEqual, const llvm::Loop *Loop, bool)
7140 : ty(Type::Compare), values(), node(v), isEqual(isEqual), Loop(Loop) {}
7141
7142public:
7143 static InnerTy make_compare(const SCEV *v, bool isEqual,
7144 const llvm::Loop *Loop,
7145 const ConstraintContext &ctx);
7146
7148 : ty(t), values(), node(nullptr), isEqual(false), Loop(nullptr) {
7149 assert(t == Type::All || t == Type::None);
7150 }
7151 Constraints(Type t, const SetTy &c, bool check = true)
7152 : ty(t), values(c), node(nullptr), isEqual(false), Loop(nullptr) {
7153 assert(t != Type::All);
7154 assert(t != Type::None);
7155 assert(c.size() != 0);
7156 assert(c.size() != 1);
7157#ifndef NDEBUG
7158 SmallVector<InnerTy, 1> tmp(c.begin(), c.end());
7159 for (unsigned i = 0; i < tmp.size(); i++)
7160 for (unsigned j = 0; j < i; j++)
7161 assert(*tmp[i] != *tmp[j]);
7162 if (t == Type::Intersect) {
7163 for (auto &v : c) {
7164 assert(v->ty != Type::Intersect);
7165 }
7166 }
7167 if (t == Type::Union) {
7168 for (auto &v : c) {
7169 assert(v->ty != Type::Union);
7170 }
7171 }
7172 if (t == Type::Intersect && check) {
7173 for (unsigned i = 0; i < tmp.size(); i++)
7174 if (tmp[i]->ty == Type::Compare && tmp[i]->isEqual && tmp[i]->Loop)
7175 for (unsigned j = 0; j < tmp.size(); j++)
7176 if (tmp[j]->ty == Type::Compare)
7177 if (auto s = dyn_cast<SCEVAddRecExpr>(tmp[j]->node))
7178 assert(s->getLoop() != tmp[i]->Loop);
7179 }
7180#endif
7181 }
7182
7183 bool operator==(const Constraints &rhs) const {
7184 if (ty != rhs.ty) {
7185 return false;
7186 }
7187 if (node != rhs.node) {
7188 return false;
7189 }
7190 if (isEqual != rhs.isEqual) {
7191 return false;
7192 }
7193 if (Loop != rhs.Loop) {
7194 return false;
7195 }
7196 if (values.size() != rhs.values.size()) {
7197 return false;
7198 }
7199 for (auto pair : llvm::zip(values, rhs.values)) {
7200 if (*std::get<0>(pair) != *std::get<1>(pair))
7201 return false;
7202 }
7203 return true;
7204 //) && !(rhs.values < values)
7205 /*
7206for (size_t i=0; i<values.size(); i++)
7207if (*values[i] != *rhs.values[i]) return false;
7208return true;
7209 */
7210 }
7211 bool operator>(const Constraints &rhs) const { return rhs < *this; }
7212 bool operator<(const Constraints &rhs) const {
7213 if (ty < rhs.ty) {
7214 return true;
7215 }
7216 if (ty > rhs.ty) {
7217 return false;
7218 }
7219 if (node < rhs.node) {
7220 return true;
7221 }
7222 if (node > rhs.node) {
7223 return false;
7224 }
7225 if (isEqual < rhs.isEqual) {
7226 return true;
7227 }
7228 if (isEqual > rhs.isEqual) {
7229 return false;
7230 }
7231 if (Loop < rhs.Loop) {
7232 return true;
7233 }
7234 if (Loop > rhs.Loop) {
7235 return false;
7236 }
7237 if (values.size() < rhs.values.size()) {
7238 return true;
7239 }
7240 if (values.size() > rhs.values.size()) {
7241 return false;
7242 }
7243 for (auto pair : llvm::zip(values, rhs.values)) {
7244 if (*std::get<0>(pair) < *std::get<1>(pair))
7245 return true;
7246 if (*std::get<0>(pair) > *std::get<1>(pair))
7247 return false;
7248 }
7249 return false;
7250 }
7251 unsigned hash() const {
7252 unsigned res = 5 * (unsigned)ty +
7253 DenseMapInfo<const SCEV *>::getHashValue(node) + isEqual;
7254 res = llvm::detail::combineHashValue(res, (unsigned)(size_t)Loop);
7255 for (auto v : values)
7256 res = llvm::detail::combineHashValue(res, v->hash());
7257 return res;
7258 }
7259 bool operator!=(const Constraints &rhs) const { return !(*this == rhs); }
7260 static InnerTy all() {
7261 static auto allv = std::make_shared<Constraints>(Type::All);
7262 return allv;
7263 }
7264 static InnerTy none() {
7265 static auto nonev = std::make_shared<Constraints>(Type::None);
7266 return nonev;
7267 }
7268 bool isNone() const { return ty == Type::None; }
7269 bool isAll() const { return ty == Type::All; }
7270 static void insert(SetTy &set, InnerTy ty) {
7271 set.insert(ty);
7272 int mcount = 0;
7273 for (auto &v : set)
7274 if (*v == *ty)
7275 mcount++;
7276 assert(mcount == 1);
7277 /*
7278 for (auto &v : set)
7279 if (*v == *ty)
7280 return;
7281 set.push_back(ty);
7282 */
7283 }
7284 static SetTy intersect(const SetTy &lhs, const SetTy &rhs) {
7285 SetTy res;
7286 for (auto &v : lhs)
7287 if (rhs.count(v))
7288 res.insert(v);
7289 return res;
7290 }
7291 static void set_subtract(SetTy &set, const SetTy &rhs) {
7292 for (auto &v : rhs)
7293 if (set.count(v))
7294 set.erase(v);
7295 /*
7296 for (const auto &val : rhs)
7297 for (auto I = set.begin(); I != set.end(); I++) {
7298 if (**I == *val) {
7299 set.erase(I);
7300 break;
7301 }
7302 }
7303*/
7304 }
7305 __attribute__((noinline)) void dump() const { llvm::errs() << *this << "\n"; }
7306 InnerTy notB(const ConstraintContext &ctx) const {
7307 switch (ty) {
7308 case Type::None:
7309 return Constraints::all();
7310 case Type::All:
7311 return Constraints::none();
7312 case Type::Compare:
7313 return make_compare(node, !isEqual, Loop, ctx);
7314 case Type::Union: {
7315 // not of or's is and of not's
7316 SetTy next;
7317 for (const auto &v : values)
7318 insert(next, v->notB(ctx));
7319 if (next.size() == 1)
7320 llvm::errs() << " uold : " << *this << "\n";
7321 return std::make_shared<Constraints>(Type::Intersect, next);
7322 }
7323 case Type::Intersect: {
7324 // not of and's is or of not's
7325 SetTy next;
7326 for (const auto &v : values)
7327 insert(next, v->notB(ctx));
7328 if (next.size() == 1)
7329 llvm::errs() << " old : " << *this << "\n";
7330 return std::make_shared<Constraints>(Type::Union, next);
7331 }
7332 }
7333 return Constraints::none();
7334 }
7335 InnerTy orB(InnerTy rhs, const ConstraintContext &ctx) const {
7336 auto notLHS = notB(ctx);
7337 if (!notLHS)
7338 return nullptr;
7339 auto notRHS = rhs->notB(ctx);
7340 if (!notRHS)
7341 return nullptr;
7342 auto andV = notLHS->andB(notRHS, ctx);
7343 if (!andV)
7344 return nullptr;
7345 auto res = andV->notB(ctx);
7346 return res;
7347 }
7348 InnerTy andB(const InnerTy rhs, const ConstraintContext &ctx) const {
7349 assert(rhs);
7350 if (*rhs == *this)
7351 return shared_from_this();
7352 if (rhs->isNone())
7353 return rhs;
7354 if (rhs->isAll())
7355 return shared_from_this();
7356 if (isNone())
7357 return shared_from_this();
7358 if (isAll())
7359 return rhs;
7360
7361 // llvm::errs() << " anding: " << *this << " with " << *rhs << "\n";
7362 if (ctx.contains(shared_from_this()) || ctx.contains(rhs)) {
7363 // llvm::errs() << " %%% stopping recursion\n";
7364 return nullptr;
7365 }
7366 if (ty == Type::Compare && rhs->ty == Type::Compare) {
7367 auto sub = ctx.SE.getMinusSCEV(node, rhs->node);
7368 if (Loop == rhs->Loop) {
7369 // llvm::errs() << " + sameloop, sub=" << *sub << "\n";
7370 if (auto cst = dyn_cast<SCEVConstant>(sub)) {
7371 // the two solves are equivalent to each other
7372 if (cst->getValue()->isZero()) {
7373 // iv = a and iv = a
7374 // also iv != a and iv != a
7375 if (isEqual == rhs->isEqual)
7376 return shared_from_this();
7377 else {
7378 // iv = a and iv != a
7379 return Constraints::none();
7380 }
7381 } else {
7382 // the two solves are guaranteed to be distinct
7383 // iv == 0 and iv == 1
7384 if (isEqual && rhs->isEqual) {
7385 return Constraints::none();
7386
7387 } else if (!isEqual && !rhs->isEqual) {
7388 // iv != 0 and iv != 1
7389 SetTy vals;
7390 insert(vals, shared_from_this());
7391 insert(vals, rhs);
7392 return std::make_shared<Constraints>(Type::Intersect, vals);
7393 } else if (!isEqual) {
7394 assert(rhs->isEqual);
7395 // iv != 0 and iv == 1
7396 return rhs;
7397 ;
7398 } else {
7399 // iv == 0 and iv != 1
7400 assert(isEqual);
7401 assert(!rhs->isEqual);
7402 return shared_from_this();
7403 }
7404 }
7405 } else if (isEqual || rhs->isEqual) {
7406 // llvm::errs() << " + botheq\n";
7407 // eq(i, a) & i ?= b -> eq(i, a) & (a ?= b)
7408 if (auto addrec = dyn_cast<SCEVAddRecExpr>(sub)) {
7409 // we want a ?= b, but we can only represent loopvar ?= something
7410 // so suppose a-b is of the form X + Y * lv then a-b ?= 0 is
7411 // X + Y * lv ?= 0 -> lv ?= - X / Y
7412 if (addrec->isAffine()) {
7413 auto X = addrec->getStart();
7414 auto Y = addrec->getStepRecurrence(ctx.SE);
7415 auto MinusX = X;
7416
7417 if (isa<SCEVConstant>(Y) &&
7418 cast<SCEVConstant>(Y)->getAPInt().isNegative())
7419 Y = ctx.SE.getNegativeSCEV(Y);
7420 else
7421 MinusX = ctx.SE.getNegativeSCEV(X);
7422
7423 auto div = ctx.SE.getUDivExpr(MinusX, Y);
7424 auto div_e = ctx.SE.getUDivExactExpr(MinusX, Y);
7425 // in case of inexact division, check that these exactly equal
7426 // for replacement
7427
7428 if (div == div_e) {
7429 if (isEqual) {
7430 auto res = make_compare(div, /*isEqual*/ rhs->isEqual,
7431 addrec->getLoop(), ctx);
7432 // llvm::errs() << " simplified rhs to: " << *res << "\n";
7433 return andB(res, ctx);
7434 } else {
7435 assert(rhs->isEqual);
7436 auto res = make_compare(div, /*isEqual*/ isEqual,
7437 addrec->getLoop(), ctx);
7438 // llvm::errs() << " simplified lhs to: " << *res << "\n";
7439 return rhs->andB(res, ctx);
7440 }
7441 }
7442 }
7443 }
7444 if (isEqual && rhs->Loop &&
7446 auto res = make_compare(sub, /*isEqual*/ rhs->isEqual,
7447 /*loop*/ nullptr, ctx);
7448 // llvm::errs() << " simplified(noloop) rhs from " << *rhs
7449 // << " to: " << *res << "\n";
7450 return andB(res, ctx);
7451 }
7452 if (rhs->isEqual && Loop &&
7454 auto res =
7455 make_compare(sub, /*isEqual*/ isEqual, /*loop*/ nullptr, ctx);
7456 // llvm::errs() << " simplified(noloop) lhs from " << *rhs
7457 // << " to: " << *res << "\n";
7458 return rhs->andB(res, ctx);
7459 }
7460
7461 llvm::errs() << " warning: potential but unhandled simplification of "
7462 "equalities: "
7463 << *this << " and " << *rhs << " sub: " << *sub << "\n";
7464 }
7465 }
7466
7467 if (isEqual) {
7468 if (Loop)
7469 if (auto rep = evaluateAtLoopIter(rhs->node, ctx.SE, Loop, node))
7470 if (rep != rhs->node) {
7471 auto newrhs = make_compare(rep, rhs->isEqual, rhs->Loop, ctx);
7472 return andB(newrhs, ctx);
7473 }
7474
7475 // not loop -> node == 0
7476 if (!Loop) {
7477 for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node),
7478 ctx.SE.getMinusSCEV(rhs->node, node)}) {
7479 // llvm::errs() << " maybe replace lhs: " << *this << " rhs: " <<
7480 // *rhs
7481 // << " sub1: " << *sub1 << "\n";
7482 auto newrhs = make_compare(sub1, rhs->isEqual, rhs->Loop, ctx);
7483 if (*newrhs == *this)
7484 return shared_from_this();
7485 if (!isa<SCEVConstant>(rhs->node) && isa<SCEVConstant>(sub1)) {
7486 return andB(newrhs, ctx);
7487 }
7488 }
7489 }
7490 }
7491
7492 if (rhs->isEqual) {
7493 if (rhs->Loop)
7494 if (auto rep = evaluateAtLoopIter(node, ctx.SE, rhs->Loop, rhs->node))
7495 if (rep != node) {
7496 auto newlhs = make_compare(rep, isEqual, Loop, ctx);
7497 return newlhs->andB(rhs, ctx);
7498 }
7499
7500 // not loop -> node == 0
7501 if (!rhs->Loop) {
7502 for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node),
7503 ctx.SE.getMinusSCEV(rhs->node, node)}) {
7504 // llvm::errs() << " maybe replace lhs2: " << *this << " rhs: " <<
7505 // *rhs
7506 // << " sub1: " << *sub1 << "\n";
7507 auto newlhs = make_compare(sub1, isEqual, Loop, ctx);
7508 if (*newlhs == *this)
7509 return shared_from_this();
7510 if (!isa<SCEVConstant>(node) && isa<SCEVConstant>(sub1)) {
7511 return newlhs->andB(rhs, ctx);
7512 }
7513 }
7514 }
7515 }
7516
7517 if (!Loop && !rhs->Loop && isEqual == rhs->isEqual) {
7518 if (node == ctx.SE.getNegativeSCEV(rhs->node))
7519 return shared_from_this();
7520 }
7521
7522 SetTy vals;
7523 insert(vals, shared_from_this());
7524 insert(vals, rhs);
7525 if (vals.size() == 1) {
7526 llvm::errs() << "this: " << *this << " rhs: " << *rhs << "\n";
7527 }
7528 auto res = std::make_shared<Constraints>(Type::Intersect, vals);
7529 // llvm::errs() << " naiive comp merge: " << *res << "\n";
7530 return res;
7531 }
7532 if (ty == Type::Intersect && rhs->ty == Type::Intersect) {
7533 auto tmp = shared_from_this();
7534 for (const auto &v : rhs->values) {
7535 auto tmp2 = tmp->andB(v, ctx);
7536 if (!tmp2)
7537 return nullptr;
7538 tmp = std::move(tmp2);
7539 }
7540 return tmp;
7541 }
7542 if (ty == Type::Intersect && rhs->ty == Type::Compare) {
7543 SetTy vals;
7544 // Force internal merging to do individual compares
7545 bool foldedIn = false;
7546 for (auto en : llvm::enumerate(values)) {
7547 auto i = en.index();
7548 auto v = en.value();
7549 assert(v->ty != Type::Intersect);
7550 assert(v->ty != Type::All);
7551 assert(v->ty != Type::None);
7552 assert(v->ty == Type::Compare || v->ty == Type::Union);
7553 if (foldedIn) {
7554 insert(vals, v);
7555 continue;
7556 }
7557 // this is either a compare or a union
7558 auto tmp = rhs->andB(v, ctx);
7559 if (!tmp)
7560 return nullptr;
7561 switch (tmp->ty) {
7562 case Type::Union:
7563 case Type::All:
7564 llvm_unreachable("Impossible");
7565 case Type::None:
7566 return Constraints::none();
7567 case Type::Compare:
7568 insert(vals, tmp);
7569 foldedIn = true;
7570 break;
7571 // if intersected, these two were not foldable, try folding into later
7572 case Type::Intersect: {
7573 SetTy fuse;
7574 insert(fuse, rhs);
7575 insert(fuse, v);
7576
7577 Constraints trivialFuse(Type::Intersect, fuse, false);
7578
7579 // If this is not just making an intersect of the two operands,
7580 // remerge.
7581 if (trivialFuse != *tmp) {
7582 InnerTy newlhs = Constraints::all();
7583 bool legal = true;
7584 for (auto en2 : llvm::enumerate(values)) {
7585 auto i2 = en2.index();
7586 auto v2 = en2.value();
7587 if (i2 == i)
7588 continue;
7589 auto newlhs2 = newlhs->andB(v2, ctx);
7590 if (!newlhs2) {
7591 legal = false;
7592 break;
7593 }
7594 newlhs = std::move(newlhs2);
7595 }
7596 if (legal) {
7597 return newlhs->andB(tmp, ctx);
7598 }
7599 }
7600 insert(vals, v);
7601 }
7602 }
7603 }
7604 if (!foldedIn) {
7605 insert(vals, rhs);
7606 return std::make_shared<Constraints>(Type::Intersect, vals);
7607 } else {
7608 auto cur = Constraints::all();
7609 for (auto &iv : vals) {
7610 auto cur2 = cur->andB(iv, ctx);
7611 if (!cur2)
7612 return nullptr;
7613 cur = std::move(cur2);
7614 }
7615 return cur;
7616 }
7617 }
7618 if ((ty == Type::Intersect || ty == Type::Compare) &&
7619 rhs->ty == Type::Union) {
7620 SetTy unionVals = rhs->values;
7621 bool changed = false;
7622 SetTy ivVals;
7623 if (ty == Type::Intersect)
7624 ivVals = values;
7625 else
7626 insert(ivVals, shared_from_this());
7627
7628 ConstraintContext ctxd(ctx, shared_from_this(), rhs);
7629
7630 for (const auto &iv : ivVals) {
7631 SetTy nextunionVals;
7632 bool midchanged = false;
7633 for (auto &uv : unionVals) {
7634 auto tmp = iv->andB(uv, ctxd);
7635 if (!tmp) {
7636 midchanged = false;
7637 nextunionVals = unionVals;
7638 break;
7639 }
7640 switch (tmp->ty) {
7641 case Type::None:
7642 case Type::Compare:
7643 case Type::Union:
7644 insert(nextunionVals, tmp);
7645 changed |= tmp != uv;
7646 break;
7647 case Type::Intersect: {
7648 SetTy fuse;
7649 if (uv->ty == Type::Intersect)
7650 fuse = uv->values;
7651 else {
7652 assert(uv->ty == Type::Compare);
7653 insert(fuse, uv);
7654 }
7655 insert(fuse, iv);
7656
7657 Constraints trivialFuse(Type::Intersect, fuse, false);
7658 if (trivialFuse != *tmp) {
7659 insert(nextunionVals, tmp);
7660 midchanged = true;
7661 break;
7662 }
7663
7664 insert(nextunionVals, uv);
7665 break;
7666 }
7667 case Type::All:
7668 llvm_unreachable("Impossible");
7669 }
7670 }
7671 if (midchanged) {
7672 unionVals = nextunionVals;
7673 changed = true;
7674 }
7675 }
7676
7677 if (changed) {
7678 auto cur = Constraints::none();
7679 for (auto uv : unionVals) {
7680 cur = cur->orB(uv, ctxd);
7681 if (!cur)
7682 break;
7683 }
7684
7685 if (*cur != *rhs)
7686 return andB(cur, ctx);
7687 }
7688
7689 SetTy vals = ivVals;
7690 insert(vals, rhs);
7691 return std::make_shared<Constraints>(Type::Intersect, vals);
7692 }
7693 // Handled above via symmetry
7694 if (rhs->ty == Type::Intersect || rhs->ty == Type::Compare) {
7695 return rhs->andB(shared_from_this(), ctx);
7696 }
7697 // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d)
7698 // and (c or e))
7699 if (ty == Type::Union && rhs->ty == Type::Union) {
7700 if (*this == *rhs->notB(ctx)) {
7701 return Constraints::none();
7702 }
7703 SetTy intersection = intersect(values, rhs->values);
7704 if (intersection.size() != 0) {
7705 InnerTy other_lhs = remove(intersection);
7706 InnerTy other_rhs = rhs->remove(intersection);
7707 InnerTy remainder;
7708 if (intersection.size() == 1)
7709 remainder = *intersection.begin();
7710 else {
7711 remainder = std::make_shared<Constraints>(Type::Union, intersection);
7712 }
7713 return remainder->orB(other_lhs->andB(other_rhs, ctx), ctx);
7714 }
7715
7716 bool changed = false;
7717 SetTy lhsVals = values;
7718 SetTy rhsVals = rhs->values;
7719
7720 ConstraintContext ctxd(ctx, shared_from_this(), rhs);
7721
7722 SetTy distributedVals;
7723 for (const auto &l1 : lhsVals) {
7724 bool subchanged = false;
7725 SetTy subDistributedVals;
7726 for (auto &r1 : rhsVals) {
7727 auto tmp = l1->andB(r1, ctxd);
7728 if (!tmp) {
7729 subchanged = false;
7730 break;
7731 }
7732
7733 if (l1->ty == Type::Intersect || r1->ty == Type::Intersect) {
7734 subchanged = true;
7735 insert(subDistributedVals, tmp);
7736 } else {
7737
7738 SetTy fuse;
7739 insert(fuse, l1);
7740 insert(fuse, r1);
7741 assert(fuse.size() == 2);
7742 Constraints trivialFuse(Type::Intersect, fuse);
7743 if ((trivialFuse != *tmp) || distributedVals.count(tmp)) {
7744 subchanged = true;
7745 }
7746 insert(subDistributedVals, tmp);
7747 }
7748 }
7749 if (subchanged) {
7750 for (auto sub : subDistributedVals)
7751 insert(distributedVals, sub);
7752 changed = true;
7753 } else {
7754 auto midand = l1->andB(rhs, ctxd);
7755 if (!midand) {
7756 changed = false;
7757 break;
7758 }
7759 insert(distributedVals, midand);
7760 }
7761 }
7762
7763 if (changed) {
7764 auto cur = Constraints::none();
7765 bool legal = true;
7766 for (auto &uv : distributedVals) {
7767 auto cur2 = cur->orB(uv, ctxd);
7768 if (!cur2) {
7769 legal = false;
7770 break;
7771 }
7772 cur = std::move(cur2);
7773 }
7774 if (legal) {
7775 return cur;
7776 }
7777 }
7778
7779 SetTy vals;
7780 insert(vals, shared_from_this());
7781 insert(vals, rhs);
7782 auto res = std::make_shared<Constraints>(Type::Intersect, vals);
7783 return res;
7784 }
7785 llvm::errs() << " andB this: " << *this << " rhs: " << *rhs << "\n";
7786 llvm_unreachable("Illegal predicate state");
7787 }
7788 // what this would be like when removing the following list of constraints
7789 InnerTy remove(const SetTy &sub) const {
7790 assert(ty == Type::Union || ty == Type::Intersect);
7791 SetTy res = values;
7792 set_subtract(res, sub);
7793 // res.set_subtract(sub);
7794 if (res.size() == 0) {
7795 if (ty == Type::Union)
7796 return Constraints::none();
7797 else
7798 return Constraints::all();
7799 } else if (res.size() == 1) {
7800 return *res.begin();
7801 } else {
7802 return std::make_shared<Constraints>(ty, res);
7803 }
7804 }
7805 SmallVector<std::pair<Value *, Value *>, 1>
7806 allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP,
7807 const ConstraintContext &ctx, IRBuilder<> &B) const;
7808};
7809
7810void dump(const Constraints &c) { c.dump(); }
7811void dump(std::shared_ptr<const Constraints> c) { c->dump(); }
7812
7814 std::shared_ptr<const Constraints> lhs,
7815 std::shared_ptr<const Constraints> rhs) const {
7816 return *lhs < *rhs;
7817}
7818
7819raw_ostream &operator<<(raw_ostream &os, const Constraints &c) {
7820 switch (c.ty) {
7822 return os << "All";
7824 return os << "None";
7826 os << "(Union ";
7827 for (auto v : c.values)
7828 os << *v << ", ";
7829 os << ")";
7830 return os;
7831 }
7833 os << "(Intersect ";
7834 for (auto v : c.values)
7835 os << *v << ", ";
7836 os << ")";
7837 return os;
7838 }
7840 if (c.isEqual)
7841 os << "(eq ";
7842 else
7843 os << "(ne ";
7844 os << *c.node << ", L=";
7845 if (c.Loop)
7846 os << c.Loop->getHeader()->getName();
7847 else
7848 os << "nullptr";
7849 return os << ")";
7850 }
7851 }
7852 return os;
7853}
7854
7855SmallVector<std::pair<Value *, Value *>, 1>
7856Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP,
7857 const ConstraintContext &ctx, IRBuilder<> &B) const {
7858 switch (ty) {
7859 case Type::None:
7860 return {};
7861 case Type::All:
7862 llvm::errs() << *this << "\n";
7863 llvm_unreachable("All not handled");
7864 case Type::Compare: {
7865 Value *cond = ConstantInt::getTrue(T->getContext());
7866 if (ctx.loopToSolve != Loop) {
7867 assert(ctx.loopToSolve);
7868 Value *ivVal = Exp.expandCodeFor(node, T, IP);
7869 Value *iv = nullptr;
7870 if (Loop) {
7871 iv = Loop->getCanonicalInductionVariable();
7872 assert(iv);
7873 } else {
7874 iv = ConstantInt::getNullValue(ivVal->getType());
7875 }
7876 if (isEqual)
7877 cond = B.CreateICmpEQ(ivVal, iv);
7878 else
7879 cond = B.CreateICmpNE(ivVal, iv);
7880 return {std::make_pair((Value *)nullptr, cond)};
7881 }
7882 if (isEqual) {
7883 return {std::make_pair(Exp.expandCodeFor(node, T, IP), cond)};
7884 }
7885 EmitFailure("NoSparsification", IP->getDebugLoc(), IP,
7886 "Negated solution not handled: ", *this);
7887 assert(0);
7888 return {};
7889 }
7890 case Type::Union: {
7891 SmallVector<std::pair<Value *, Value *>, 1> vals;
7892 for (auto v : values)
7893 for (auto sol : v->allSolutions(Exp, T, IP, ctx, B))
7894 vals.push_back(sol);
7895 return vals;
7896 }
7897 case Type::Intersect: {
7898 {
7899 SmallVector<InnerTy, 1> vals(values.begin(), values.end());
7900 ssize_t unionidx = -1;
7901 for (unsigned i = 0; i < vals.size(); i++) {
7902 if (vals[i]->ty == Type::Union) {
7903 unionidx = i;
7904 bool allne = true;
7905 for (auto &v : vals[i]->values) {
7906 if (v->ty != Type::Compare || v->isEqual) {
7907 allne = false;
7908 break;
7909 }
7910 }
7911 if (allne)
7912 break;
7913 }
7914 }
7915 if (unionidx != -1) {
7916 auto others = Constraints::all();
7917 for (unsigned j = 0; j < vals.size(); j++)
7918 if (unionidx != j)
7919 others = others->andB(vals[j], ctx);
7920 SmallVector<std::pair<Value *, Value *>, 1> resvals;
7921 for (auto &v : vals[unionidx]->values) {
7922 auto tmp = v->andB(others, ctx);
7923 for (const auto &sol : tmp->allSolutions(Exp, T, IP, ctx, B))
7924 resvals.push_back(sol);
7925 }
7926 return resvals;
7927 }
7928 }
7929 Value *solVal = nullptr;
7930 Value *cond = ConstantInt::getTrue(T->getContext());
7931 for (auto v : values) {
7932 auto sols = v->allSolutions(Exp, T, IP, ctx, B);
7933 if (sols.size() != 1) {
7934 llvm::errs() << *this << "\n";
7935 for (auto s : sols)
7936 if (s.first)
7937 llvm::errs() << " + sol: " << *s.first << " " << *s.second << "\n";
7938 else
7939 llvm::errs() << " + sol: " << s.first << " " << *s.second << "\n";
7940 llvm::errs() << " v: " << *v << " this: " << *this << "\n";
7941 llvm_unreachable("Intersect not handled (solsize>1)");
7942 }
7943 auto sol = sols[0];
7944 if (sol.first) {
7945 if (solVal != nullptr) {
7946 llvm::errs() << *this << "\n";
7947 llvm::errs() << " prevsolVal: " << *solVal << "\n";
7948 llvm_unreachable("Intersect not handled (prevsolval)");
7949 }
7950 assert(solVal == nullptr);
7951 solVal = sol.first;
7952 }
7953 cond = B.CreateAnd(cond, sol.second);
7954 }
7955 return {std::make_pair(solVal, cond)};
7956 }
7957 }
7958 return {};
7959}
7960
7961constexpr bool SparseDebug = false;
7962std::shared_ptr<const Constraints>
7963getSparseConditions(bool &legal, Value *val,
7964 std::shared_ptr<const Constraints> defaultFloat,
7965 Instruction *scope, const ConstraintContext &ctx) {
7966 if (auto I = dyn_cast<Instruction>(val)) {
7967 // Binary `and` is a bit-wise `umin`.
7968 if (I->getOpcode() == Instruction::And) {
7969 auto lhs = getSparseConditions(legal, I->getOperand(0),
7970 Constraints::all(), I, ctx);
7971 auto rhs = getSparseConditions(legal, I->getOperand(1),
7972 Constraints::all(), I, ctx);
7973 auto res = lhs->andB(rhs, ctx);
7974 assert(res);
7975 assert(ctx.seen.size() == 0);
7976 if (SparseDebug) {
7977 llvm::errs() << " getSparse(and, " << *I << "), lhs("
7978 << *I->getOperand(0) << ") = " << *lhs << "\n";
7979 llvm::errs() << " getSparse(and, " << *I << "), rhs("
7980 << *I->getOperand(1) << ") = " << *rhs << "\n";
7981 llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n";
7982 }
7983 return res;
7984 }
7985
7986 // Binary `or` is a bit-wise `umax`.
7987 if (I->getOpcode() == Instruction::Or) {
7988 auto lhs = getSparseConditions(legal, I->getOperand(0),
7989 Constraints::none(), I, ctx);
7990 auto rhs = getSparseConditions(legal, I->getOperand(1),
7991 Constraints::none(), I, ctx);
7992 auto res = lhs->orB(rhs, ctx);
7993 if (SparseDebug) {
7994 llvm::errs() << " getSparse(or, " << *I << "), lhs("
7995 << *I->getOperand(0) << ") = " << *lhs << "\n";
7996 llvm::errs() << " getSparse(or, " << *I << "), rhs("
7997 << *I->getOperand(1) << ") = " << *rhs << "\n";
7998 llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n";
7999 }
8000 return res;
8001 }
8002
8003 if (I->getOpcode() == Instruction::Xor) {
8004 for (int i = 0; i < 2; i++) {
8005 if (auto C = dyn_cast<ConstantInt>(I->getOperand(i)))
8006 if (C->isOne()) {
8007 auto pres =
8008 getSparseConditions(legal, I->getOperand(1 - i),
8009 defaultFloat->notB(ctx), scope, ctx);
8010 auto res = pres->notB(ctx);
8011 if (SparseDebug) {
8012 llvm::errs() << " getSparse(not, " << *I << "), prev ("
8013 << *I->getOperand(0) << ") = " << *pres << "\n";
8014 llvm::errs() << " getSparse(not, " << *I << ") = " << *res
8015 << "\n";
8016 }
8017 return res;
8018 }
8019 }
8020 }
8021
8022 if (auto icmp = dyn_cast<ICmpInst>(I)) {
8023 auto L = ctx.loopToSolve;
8024 auto lhs = ctx.SE.getSCEVAtScope(icmp->getOperand(0), L);
8025 auto rhs = ctx.SE.getSCEVAtScope(icmp->getOperand(1), L);
8026 if (SparseDebug) {
8027 llvm::errs() << " lhs: " << *lhs << "\n";
8028 llvm::errs() << " rhs: " << *rhs << "\n";
8029 }
8030
8031 auto sub1 = ctx.SE.getMinusSCEV(lhs, rhs);
8032
8033 if (icmp->getPredicate() == ICmpInst::ICMP_EQ ||
8034 icmp->getPredicate() == ICmpInst::ICMP_NE) {
8035 if (auto add = dyn_cast<SCEVAddRecExpr>(sub1)) {
8036 if (add->isAffine()) {
8037 // 0 === A + B * inc -> -A / B = inc
8038 auto A = add->getStart();
8039 if (auto B =
8040 dyn_cast<SCEVConstant>(add->getStepRecurrence(ctx.SE))) {
8041
8042 auto MA = A;
8043 if (B->getAPInt().isNegative())
8044 B = cast<SCEVConstant>(ctx.SE.getNegativeSCEV(B));
8045 else
8046 MA = ctx.SE.getNegativeSCEV(A);
8047 auto div = ctx.SE.getUDivExpr(MA, B);
8048 auto div_e = ctx.SE.getUDivExactExpr(MA, B);
8049 if (div == div_e) {
8050 auto res = Constraints::make_compare(
8051 div, icmp->getPredicate() == ICmpInst::ICMP_EQ,
8052 add->getLoop(), ctx);
8053 if (SparseDebug) {
8054 llvm::errs()
8055 << " getSparse(icmp, " << *I << ") = " << *res << "\n";
8056 }
8057 return res;
8058 }
8059 }
8060 }
8061 }
8062 if (cannotDependOnLoopIV(sub1, ctx.loopToSolve)) {
8063 auto res = Constraints::make_compare(
8064 sub1, icmp->getPredicate() == ICmpInst::ICMP_EQ, nullptr, ctx);
8065 llvm::errs() << " getSparse(icmp_noloop, " << *I << ") = " << *res
8066 << "\n";
8067 return res;
8068 }
8069 }
8070 if (scope)
8071 EmitWarning("NoSparsification", *I,
8072 " No sparsification: not sparse solvable(icmp): ", *I,
8073 " via ", *sub1);
8074 if (SparseDebug) {
8075 llvm::errs() << " getSparse(icmp_dflt, " << *I
8076 << ") = " << *defaultFloat << "\n";
8077 }
8078 return defaultFloat;
8079 }
8080
8081 // cmp x, 1.0 -> false/true
8082 if (auto fcmp = dyn_cast<FCmpInst>(I)) {
8083 auto res = defaultFloat;
8084 if (SparseDebug) {
8085 llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n";
8086 }
8087 return res;
8088
8089 if (fcmp->getPredicate() == CmpInst::FCMP_OEQ ||
8090 fcmp->getPredicate() == CmpInst::FCMP_UEQ) {
8091 return Constraints::all();
8092 } else if (fcmp->getPredicate() == CmpInst::FCMP_ONE ||
8093 fcmp->getPredicate() == CmpInst::FCMP_UNE) {
8094 return Constraints::none();
8095 }
8096 }
8097 }
8098
8099 if (scope) {
8100 EmitFailure("NoSparsification", scope->getDebugLoc(), scope,
8101 " No sparsification: not sparse solvable: ", *val);
8102 }
8103 legal = false;
8104 return defaultFloat;
8105}
8106
8108 const llvm::Loop *Loop,
8109 const ConstraintContext &ctx) {
8110 if (!Loop) {
8111 assert(!isa<SCEVAddRecExpr>(v));
8112 SmallVector<Instruction *, 1> noassumption;
8113 ConstraintContext ctx2(ctx.SE, ctx.loopToSolve, noassumption, ctx.DT);
8114 for (auto I : ctx.Assumptions) {
8115 bool legal = true;
8116 if (I->getParent()->getParent() !=
8117 ctx.loopToSolve->getHeader()->getParent())
8118 continue;
8119 auto parsedCond = getSparseConditions(legal, I->getOperand(0),
8120 Constraints::none(), nullptr, ctx2);
8121 bool dominates = ctx.DT.dominates(I, ctx.loopToSolve->getHeader());
8122 if (legal && dominates) {
8123 if (parsedCond->ty == Type::Compare && !parsedCond->Loop) {
8124 if (parsedCond->node == v ||
8125 parsedCond->node == ctx.SE.getNegativeSCEV(v)) {
8126 InnerTy res;
8127 if (parsedCond->isEqual == isEqual)
8128 res = Constraints::all();
8129 else
8130 res = Constraints::none();
8131 return res;
8132 }
8133 }
8134 }
8135 }
8136 }
8137 // cannot have negative loop canonical induction var
8138 if (Loop)
8139 if (auto C = dyn_cast<SCEVConstant>(v))
8140 if (C->getAPInt().isNegative()) {
8141 if (isEqual)
8142 return Constraints::none();
8143 else
8144 return Constraints::all();
8145 }
8146 return InnerTy(new Constraints(v, isEqual, Loop, false));
8147}
8148
8149void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
8150 SetVector<BasicBlock *> &toDenseBlocks) {
8151
8152 auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
8153 auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
8154 auto &LI = FAM.getResult<LoopAnalysis>(F);
8155 auto &DL = F.getParent()->getDataLayout();
8156
8157 QueueType Q(DT, LI);
8158 {
8159 llvm::SetVector<BasicBlock *> todoBlocks;
8160 for (auto b : toDenseBlocks) {
8161 auto L = LI.getLoopFor(b);
8162 if (L) {
8163 for (auto B : L->getBlocks())
8164 todoBlocks.insert(B);
8165 }
8166 }
8167 for (auto BB : todoBlocks)
8168 for (auto &I : *BB)
8169 if (!I.getType()->isVoidTy()) {
8170 Q.insert(&I);
8171 assert(Q.contains(&I));
8172 }
8173 }
8174
8175 // llvm::errs() << " pre fix inner: " << F << "\n";
8176
8177 // Full simplification
8178 while (!Q.empty()) {
8179 auto cur = Q.pop_back_val();
8180 /*
8181 std::set<Instruction *> prev;
8182 for (auto v : Q)
8183 prev.insert(v);
8184 // llvm::errs() << "\n\n\n\n" << F << "\n";
8185 llvm::errs() << "cur: " << *cur << "\n";
8186 */
8187 auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL);
8188 (void)changed;
8189 /*
8190 if (changed) {
8191 llvm::errs() << "changed: " << *changed << "\n";
8192
8193 for (auto I : Q)
8194 if (!prev.count(I))
8195 llvm::errs() << " + " << *I << "\n";
8196 // llvm::errs() << F << "\n\n";
8197 }
8198 */
8199 }
8200
8201 // llvm::errs() << " post fix inner " << F << "\n";
8202
8203 SmallVector<std::pair<BasicBlock *, BranchInst *>, 1> sparseBlocks;
8204 bool legalToSparse = true;
8205 for (auto &B : F)
8206 if (auto br = dyn_cast<BranchInst>(B.getTerminator()))
8207 if (br->isConditional())
8208 for (int bidx = 0; bidx < 2; bidx++)
8209 if (auto uncond_br =
8210 dyn_cast<BranchInst>(br->getSuccessor(bidx)->getTerminator()))
8211 if (!uncond_br->isConditional())
8212 if (uncond_br->getSuccessor(0) == br->getSuccessor(1 - bidx)) {
8213 auto blk = br->getSuccessor(bidx);
8214 int countSparse = 0;
8215 for (auto &I : *blk) {
8216 if (auto CI = dyn_cast<CallInst>(&I)) {
8217 if (auto F = CI->getCalledFunction()) {
8218 if (F->hasFnAttribute("enzyme_sparse_accumulate")) {
8219 countSparse++;
8220 }
8221 }
8222 }
8223 }
8224 if (countSparse == 0)
8225 continue;
8226 if (countSparse > 1) {
8227 legalToSparse = false;
8229 "NoSparsification", br->getDebugLoc(), br, "F: ", F,
8230 "\nMultiple distinct sparse stores in same block: ",
8231 *blk);
8232 break;
8233 }
8234
8235 for (auto &I : *blk) {
8236 if (auto CI = dyn_cast<CallInst>(&I)) {
8237 if (auto F = CI->getCalledFunction()) {
8238 if (F->hasFnAttribute("enzyme_sparse_accumulate")) {
8239 continue;
8240 }
8241 }
8242 if (isReadOnly(CI))
8243 continue;
8244 }
8245 if (!I.mayWriteToMemory())
8246 continue;
8247
8248 legalToSparse = false;
8250 "NoSparsification", br->getDebugLoc(), br, "F: ", F,
8251 "\nIllegal writing instruction in sparse block: ", I);
8252 break;
8253 }
8254
8255 if (!legalToSparse) {
8256 break;
8257 }
8258
8259 auto L = LI.getLoopFor(blk);
8260 if (!L) {
8261 legalToSparse = false;
8262 EmitFailure("NoSparsification", br->getDebugLoc(), br,
8263 "F: ", F, "\nCould not find loop for: ", *blk);
8264 break;
8265 }
8266 auto idx = L->getCanonicalInductionVariable();
8267 if (!idx) {
8268 legalToSparse = false;
8269 EmitFailure("NoSparsification", br->getDebugLoc(), br,
8270 "F: ", F, "\nL:", *L,
8271 "\nCould not find loop index: ", *L->getHeader());
8272 break;
8273 }
8274 assert(idx);
8275 auto preheader = L->getLoopPreheader();
8276 if (!preheader) {
8277 legalToSparse = false;
8278 EmitFailure("NoSparsification", br->getDebugLoc(), br,
8279 "F: ", F, "\nL:", *L,
8280 "\nCould not find loop preheader");
8281 break;
8282 }
8283 sparseBlocks.emplace_back(blk, br);
8284 }
8285
8286 if (!legalToSparse) {
8287 return;
8288 }
8289
8290 // block, bound, scev for indexset
8291 std::map<Loop *,
8292 std::pair<std::pair<PHINode *, PHINode *>,
8293 SmallVector<std::pair<BasicBlock *,
8294 std::shared_ptr<const Constraints>>,
8295 1>>>
8296 forSparsification;
8297
8298 SmallVector<Instruction *, 1> Assumptions;
8299 for (auto &BB : F)
8300 for (auto &I : BB)
8301 if (auto II = dyn_cast<IntrinsicInst>(&I))
8302 if (II->getIntrinsicID() == Intrinsic::assume)
8303 Assumptions.push_back(II);
8304
8305 bool sawError = false;
8306
8307 for (auto [blk, br] : sparseBlocks) {
8308 auto L = LI.getLoopFor(blk);
8309 assert(L);
8310 auto idx = L->getCanonicalInductionVariable();
8311 assert(idx);
8312 auto preheader = L->getLoopPreheader();
8313 assert(preheader);
8314
8315 // default is condition avoids sparse, negated is condition goes
8316 // to sparse
8317 auto cond = br->getCondition();
8318 bool negated = br->getSuccessor(0) == blk;
8319
8320 bool legal = true;
8321 // Whether the i1 value does not contain any icmp's
8322 std::function<bool(Value *)> onlyDataDependentValues = [&](Value *val) {
8323 auto I = cast<Instruction>(val);
8324 if (I->getOpcode() == Instruction::Or) {
8325 return onlyDataDependentValues(I->getOperand(0)) &&
8326 onlyDataDependentValues(I->getOperand(1));
8327 }
8328 if (I->getOpcode() == Instruction::And) {
8329 return onlyDataDependentValues(I->getOperand(0)) &&
8330 onlyDataDependentValues(I->getOperand(1));
8331 }
8332 if (isa<FCmpInst>(I))
8333 return true;
8334 if (isa<ICmpInst>(I))
8335 return false;
8336 EmitFailure("NoSparsification", I->getDebugLoc(), I,
8337 " No sparsification: bad datadepedent values check: ", *I);
8338 legal = false;
8339 return true;
8340 };
8341
8342 // Simplify variable val which is known to branch away from the
8343 // actual store (if not negated) or to the store (if negated)
8344 // if! negated the result may become more false if negated the
8345 // result may become more true
8346
8347 //
8348
8349 // default is condition avoids sparse, negated is condition goes
8350 // to sparse
8351 Instruction *context =
8352 isa<Instruction>(cond) ? cast<Instruction>(cond) : idx;
8353 ConstraintContext cctx(SE, L, Assumptions, DT);
8354 auto solutions = getSparseConditions(
8355 legal, cond, negated ? Constraints::all() : Constraints::none(),
8356 context, cctx);
8357 // llvm::errs() << " solutions pre negate: " << *solutions << "\n";
8358 if (!negated) {
8359 solutions = solutions->notB(cctx);
8360 }
8361 // llvm::errs() << " solutions post negate: " << *solutions << "\n";
8362 if (!legal) {
8363 sawError = true;
8364 continue;
8365 }
8366
8367 if (solutions == Constraints::none() || solutions == Constraints::all()) {
8369 "NoSparsification", context->getDebugLoc(), context, "F: ", F,
8370 "\nL: ", *L, "\ncond: ", *cond, " negated:", negated,
8371 "\n No sparsification: not sparse solvable(nosoltn): solutions:",
8372 *solutions);
8373 sawError = true;
8374 }
8375 // llvm::errs() << " found solvable solutions " << *solutions << "\n";
8376
8377 if (forSparsification.count(L) == 0) {
8378 {
8379 IRBuilder<> PB(preheader->getTerminator());
8380 forSparsification[L].first =
8381 std::make_pair(PB.CreatePHI(idx->getType(), 0, "ph.idx"),
8382 PB.CreatePHI(idx->getType(), 0, "loop.idx"));
8383 }
8384
8385 Value *LoopCount = nullptr;
8386
8387 IRBuilder<> B(L->getHeader()->getFirstNonPHI());
8388 {
8389#if LLVM_VERSION_MAJOR >= 22
8390 SCEVExpander Exp(SE, "sparseenzyme");
8391#else
8392 SCEVExpander Exp(SE, DL, "sparseenzyme");
8393#endif
8394 auto LoopCountS = SE.getBackedgeTakenCount(L);
8395 LoopCount = B.CreateAdd(
8396 ConstantInt::get(idx->getType(), 1),
8397 Exp.expandCodeFor(LoopCountS, idx->getType(), &blk->front()));
8398 }
8399 Value *inbounds = B.CreateAnd(
8400 B.CreateICmpSLT(idx, LoopCount),
8401 B.CreateICmpSGE(idx, ConstantInt::get(idx->getType(), 0)));
8402 Value *args[] = {inbounds, forSparsification[L].first.second};
8403 B.CreateCall(F.getParent()->getOrInsertFunction(
8404 "enzyme.sparse.inbounds", B.getVoidTy(),
8405 inbounds->getType(), idx->getType()),
8406 args);
8407 }
8408
8409 IRBuilder<> B(br);
8410 B.SetInsertPoint(br);
8411 auto nidx = B.CreateICmpEQ(
8412 forSparsification[L].first.first,
8413 ConstantInt::get(idx->getType(), forSparsification[L].second.size()));
8414 // TODO check direction
8415 if (!negated)
8416 nidx = B.CreateNot(nidx);
8417
8418 br->setCondition(nidx);
8419 forSparsification[L].second.emplace_back(blk, solutions);
8420 }
8421
8422 if (sawError) {
8423 for (auto &pair : forSparsification) {
8424 for (auto PN : {pair.second.first.first, pair.second.first.second}) {
8425 PN->replaceAllUsesWith(UndefValue::get(PN->getType()));
8426 PN->eraseFromParent();
8427 }
8428 }
8429 if (llvm::verifyFunction(F, &llvm::errs())) {
8430 llvm::errs() << F << "\n";
8431 report_fatal_error("function failed verification (6)");
8432 }
8433 return;
8434 }
8435
8436 if (forSparsification.size() == 0) {
8437 auto context = &F.getEntryBlock().front();
8438 EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", F,
8439 "\n Found no stores for sparsification");
8440 return;
8441 }
8442
8443 for (const auto &pair : forSparsification) {
8444 auto L = pair.first;
8445 auto [PN, inductPN] = pair.second.first;
8446
8447 auto ph = L->getLoopPreheader();
8448#if LLVM_VERSION_MAJOR >= 20
8449 CodeExtractor ext(L->getBlocks(), &DT);
8450#else
8451 CodeExtractor ext(DT, *L);
8452#endif
8453 CodeExtractorAnalysisCache cache(F);
8454 SetVector<Value *> Inputs, Outputs;
8455 auto F2 = ext.extractCodeRegion(cache, Inputs, Outputs);
8456 assert(F2);
8457 F2->addFnAttr(Attribute::AlwaysInline);
8458
8459 for (auto U : F2->users())
8460 cast<Instruction>(U)->eraseFromParent();
8461
8462 ssize_t induct_idx = -1;
8463 ssize_t off_idx = -1;
8464 for (auto en : llvm::enumerate(Inputs)) {
8465 if (en.value() == inductPN)
8466 induct_idx = en.index();
8467 if (en.value() == PN)
8468 off_idx = en.index();
8469 }
8470 assert(induct_idx != -1);
8471 assert(off_idx != -1);
8472
8473 auto L2 = LI.getLoopFor(F2->getEntryBlock().getSingleSuccessor());
8474 auto new_idx = F2->getArg(induct_idx);
8475 auto L2Header = L2->getHeader();
8476 auto new_lidx = L2->getCanonicalInductionVariable();
8477
8478 auto idxty = new_idx->getType();
8479
8480 auto new_pn = F2->getArg(off_idx);
8481 // Find all sparse accumulates we weren't meant to handle
8482 {
8483 SmallVector<CallInst *, 1> toErase;
8484 // First delete any accumulates in sub loops
8485 for (auto SL : L2->getSubLoops())
8486 for (auto B : SL->getBlocks())
8487 for (auto &I : *B)
8488 if (auto CI = dyn_cast<CallInst>(&I))
8489 if (auto F = CI->getCalledFunction()) {
8490 if (F->hasFnAttribute("enzyme_sparse_accumulate")) {
8491 toErase.push_back(CI);
8492 continue;
8493 }
8494 }
8495 for (auto C : toErase)
8496 C->eraseFromParent();
8497 toErase.clear();
8498 // Next delete any accumulates not in latchany loops
8499 for (auto B : L2->getBlocks()) {
8500 bool guarded = false;
8501 if (auto P = B->getSinglePredecessor())
8502 if (auto S = B->getSingleSuccessor())
8503 if (auto BI = dyn_cast<BranchInst>(P->getTerminator()))
8504 if (BI->isConditional())
8505 for (size_t i = 0; i < 2; i++)
8506 if (BI->getSuccessor(i) == B &&
8507 BI->getSuccessor(1 - i) == S) {
8508 auto val = BI->getCondition();
8509 if (auto xori = dyn_cast<Instruction>(val))
8510 if (xori->getOpcode() == Instruction::Xor)
8511 val = xori->getOperand(0);
8512 if (auto cmp = dyn_cast<ICmpInst>(val))
8513 if (cmp->getOperand(0) == new_pn ||
8514 cmp->getOperand(1) == new_pn)
8515 guarded = true;
8516 }
8517 if (guarded)
8518 continue;
8519 for (auto &I : *B)
8520 if (auto CI = dyn_cast<CallInst>(&I))
8521 if (auto F = CI->getCalledFunction()) {
8522 if (F->hasFnAttribute("enzyme_sparse_accumulate")) {
8523 toErase.push_back(CI);
8524 continue;
8525 }
8526 }
8527 }
8528 for (auto C : toErase)
8529 C->eraseFromParent();
8530 toErase.clear();
8531 }
8532
8533 auto guard = L2->getLoopLatch()->getTerminator();
8534 assert(guard);
8535 IRBuilder<> G(guard);
8536 G.CreateRetVoid();
8537 guard->eraseFromParent();
8538 new_lidx->replaceAllUsesWith(new_idx);
8539 new_lidx->eraseFromParent();
8540
8541 auto phterm = ph->getTerminator();
8542 IRBuilder<> B(phterm);
8543
8544 // We extracted code, reset analyses.
8545 /*
8546 DT.reset();
8547 SE.forgetAllLoops();
8548 */
8549
8550 for (auto en : llvm::enumerate(pair.second.second)) {
8551 auto off = en.index();
8552 auto &solutions = en.value().second;
8553 ConstraintContext ctx(SE, L, Assumptions, DT);
8554#if LLVM_VERSION_MAJOR >= 22
8555 SCEVExpander Exp(SE, "sparseenzyme", /*preservelcssa*/ false);
8556#else
8557 SCEVExpander Exp(SE, DL, "sparseenzyme", /*preservelcssa*/ false);
8558#endif
8559 auto sols = solutions->allSolutions(Exp, idxty, phterm, ctx, B);
8560 SmallVector<Value *, 1> prevSols;
8561 for (auto [sol, condition] : sols) {
8562 SmallVector<Value *, 1> args(Inputs.begin(), Inputs.end());
8563 args[off_idx] = ConstantInt::get(idxty, off);
8564 args[induct_idx] = sol;
8565 for (auto sol2 : prevSols)
8566 condition = B.CreateAnd(condition, B.CreateICmpNE(sol, sol2));
8567 prevSols.push_back(sol);
8568 auto BB = B.GetInsertBlock();
8569 auto B2 = BB->splitBasicBlock(B.GetInsertPoint(), "poststore");
8570 B2->moveAfter(BB);
8571 BB->getTerminator()->eraseFromParent();
8572 B.SetInsertPoint(BB);
8573 auto callB = BasicBlock::Create(BB->getContext(), "tostore",
8574 BB->getParent(), B2);
8575 B.CreateCondBr(condition, callB, B2);
8576 B.SetInsertPoint(callB);
8577 B.CreateCall(F2, args);
8578 B.CreateBr(B2);
8579 B.SetInsertPoint(B2->getTerminator());
8580 }
8581 auto blk = en.value().first;
8582 auto term = blk->getTerminator();
8583 IRBuilder<> B2(blk);
8584 B2.CreateRetVoid();
8585 term->eraseFromParent();
8586 }
8587
8588 PN->eraseFromParent();
8589
8590 for (auto &I : *L2Header) {
8591 auto boundsCheck = dyn_cast<CallInst>(&I);
8592 if (!boundsCheck)
8593 continue;
8594 auto BF = boundsCheck->getCalledFunction();
8595 if (!BF)
8596 continue;
8597 if (BF->getName() != "enzyme.sparse.inbounds")
8598 continue;
8599
8600 auto boundsCond = boundsCheck->getArgOperand(0);
8601
8602 auto next = L2Header->splitBasicBlock(boundsCheck);
8603
8604 auto exit = BasicBlock::Create(F2->getContext(), "bounds.exit", F2,
8605 L2Header->getNextNode());
8606 {
8607 IRBuilder B(exit);
8608 B.CreateRetVoid();
8609 }
8610 L2Header->getTerminator()->eraseFromParent();
8611
8612 {
8613 IRBuilder B(L2Header);
8614 B.CreateCondBr(boundsCond, next, exit);
8615 }
8616 boundsCheck->eraseFromParent();
8617 inductPN->eraseFromParent();
8618
8619 break;
8620 }
8621 }
8622
8623 for (auto &F2 : F.getParent()->functions()) {
8624 if (startsWith(F2.getName(), "__enzyme_product")) {
8625 SmallVector<Instruction *, 1> toErase;
8626 for (llvm::User *I : F2.users()) {
8627 auto CB = cast<CallBase>(I);
8628 IRBuilder<> B(CB);
8629 B.setFastMathFlags(getFast());
8630 Value *res = nullptr;
8631 for (auto v : callOperands(CB)) {
8632 if (res == nullptr)
8633 res = v;
8634 else {
8635 res = B.CreateFMul(res, v);
8636 }
8637 }
8638 CB->replaceAllUsesWith(res);
8639 toErase.push_back(CB);
8640 }
8641 for (auto CB : toErase)
8642 CB->eraseFromParent();
8643 } else if (startsWith(F2.getName(), "__enzyme_sum")) {
8644 SmallVector<Instruction *, 1> toErase;
8645 for (llvm::User *I : F2.users()) {
8646 auto CB = cast<CallBase>(I);
8647 IRBuilder<> B(CB);
8648 B.setFastMathFlags(getFast());
8649 Value *res = nullptr;
8650 for (auto v : callOperands(CB)) {
8651 if (res == nullptr)
8652 res = v;
8653 else {
8654 res = B.CreateFAdd(res, v);
8655 }
8656 }
8657 CB->replaceAllUsesWith(res);
8658 toErase.push_back(CB);
8659 }
8660 for (auto CB : toErase)
8661 CB->eraseFromParent();
8662 }
8663 }
8664}
8665
8666void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
8667 const llvm::DataLayout &DL) {
8668 auto load_fn = cast<Function>(getBaseObject(CI->getArgOperand(0)));
8669 auto store_fn = cast<Function>(getBaseObject(CI->getArgOperand(1)));
8670 size_t argstart = 2;
8671 size_t num_args = CI->arg_size();
8672 SmallVector<std::pair<Instruction *, Value *>, 1> users;
8673
8674 for (auto U : CI->users()) {
8675 users.push_back(std::make_pair(cast<Instruction>(U), CI));
8676 }
8677 IntegerType *intTy = IntegerType::get(CI->getContext(), 64);
8678 auto toInt = [&](IRBuilder<> &B, llvm::Value *V) {
8679 if (auto PT = dyn_cast<PointerType>(V->getType())) {
8680 if (PT->getAddressSpace() != 0) {
8681#if LLVM_VERSION_MAJOR < 17
8682 if (CI->getContext().supportsTypedPointers()) {
8683 V = B.CreateAddrSpaceCast(V, getUnqual(PT->getPointerElementType()));
8684 } else {
8685 V = B.CreateAddrSpaceCast(V,
8686 PointerType::getUnqual(PT->getContext()));
8687 }
8688#else
8689 V = B.CreateAddrSpaceCast(V, PointerType::getUnqual(PT->getContext()));
8690#endif
8691 }
8692 return B.CreatePtrToInt(V, intTy);
8693 }
8694 auto IT = cast<IntegerType>(V->getType());
8695 if (IT == intTy)
8696 return V;
8697 return B.CreateZExtOrTrunc(V, intTy);
8698 };
8699 SmallVector<Instruction *, 1> toErase;
8700
8701 ValueToValueMapTy replacements;
8702 replacements[CI] = Constant::getNullValue(CI->getType());
8703 Instruction *remaining = nullptr;
8704 while (users.size()) {
8705 auto pair = users.back();
8706 users.pop_back();
8707 auto U = pair.first;
8708 auto val = pair.second;
8709 if (replacements.count(U))
8710 continue;
8711
8712 IRBuilder B(U);
8713 if (auto CI = dyn_cast<CastInst>(U)) {
8714 for (auto U : CI->users()) {
8715 users.push_back(std::make_pair(cast<Instruction>(U), CI));
8716 }
8717 auto rep =
8718 B.CreateCast(CI->getOpcode(), replacements[val], CI->getDestTy());
8719 if (auto I = dyn_cast<Instruction>(rep))
8720 I->setDebugLoc(CI->getDebugLoc());
8721 replacements[CI] = rep;
8722 continue;
8723 }
8724 if (auto SI = dyn_cast<SelectInst>(U)) {
8725 for (auto U : SI->users()) {
8726 users.push_back(std::make_pair(cast<Instruction>(U), SI));
8727 }
8728 auto tval = SI->getTrueValue();
8729 auto fval = SI->getFalseValue();
8730 auto rep = B.CreateSelect(
8731 SI->getCondition(),
8732 replacements.count(tval) ? (Value *)replacements[tval] : tval,
8733 replacements.count(fval) ? (Value *)replacements[fval] : fval);
8734 if (auto I = dyn_cast<Instruction>(rep))
8735 I->setDebugLoc(SI->getDebugLoc());
8736 replacements[SI] = rep;
8737 continue;
8738 }
8739 /*
8740 if (auto CI = dyn_cast<PHINode>(U)) {
8741 for (auto U : CI->users()) {
8742 users.push_back(std::make_pair(cast<Instruction>(U), CI));
8743 }
8744 continue;
8745 }
8746 */
8747 if (auto CI = dyn_cast<CallInst>(U)) {
8748 auto funcName = getFuncNameFromCall(CI);
8749 if (funcName == "julia.pointer_from_objref") {
8750 for (auto U : CI->users()) {
8751 users.push_back(std::make_pair(cast<Instruction>(U), CI));
8752 }
8753 auto *F = CI->getCalledOperand();
8754
8755 SmallVector<Value *, 1> args;
8756 for (auto &arg : CI->args())
8757 args.push_back(replacements[arg]);
8758
8759 auto FT = CI->getFunctionType();
8760
8761 auto cal = cast<CallInst>(B.CreateCall(FT, F, args));
8762 cal->setCallingConv(CI->getCallingConv());
8763 cal->setDebugLoc(CI->getDebugLoc());
8764 replacements[CI] = cal;
8765 continue;
8766 }
8767 }
8768 if (auto CI = dyn_cast<GetElementPtrInst>(U)) {
8769 for (auto U : CI->users()) {
8770 users.push_back(std::make_pair(cast<Instruction>(U), CI));
8771 }
8772 SmallVector<Value *, 1> inds;
8773 bool allconst = true;
8774 for (auto &ind : CI->indices()) {
8775 if (!isa<ConstantInt>(ind)) {
8776 allconst = false;
8777 }
8778 inds.push_back(ind);
8779 }
8780 Value *gep;
8781
8782 if (inds.size() == 1) {
8783 gep = ConstantInt::get(
8784 intTy, (DL.getTypeSizeInBits(CI->getSourceElementType()) + 7) / 8);
8785 gep = B.CreateMul(intTy == inds[0]->getType()
8786 ? inds[0]
8787 : B.CreateZExtOrTrunc(inds[0], intTy),
8788 gep, "", true, true);
8789 gep = B.CreateAdd(B.CreatePtrToInt(replacements[val], intTy), gep);
8790 gep = B.CreateIntToPtr(gep, CI->getType());
8791 } else if (!allconst) {
8792 gep = B.CreateGEP(CI->getSourceElementType(), replacements[val], inds);
8793 if (auto ge = cast<GetElementPtrInst>(gep))
8794 ge->setIsInBounds(CI->isInBounds());
8795 } else {
8796 APInt ai(64, 0);
8797 CI->accumulateConstantOffset(DL, ai);
8798 gep = B.CreateIntToPtr(ConstantInt::get(intTy, ai), CI->getType());
8799 }
8800 if (auto I = dyn_cast<Instruction>(gep))
8801 I->setDebugLoc(CI->getDebugLoc());
8802 replacements[CI] = gep;
8803 continue;
8804 }
8805 if (auto LI = dyn_cast<LoadInst>(U)) {
8806 auto diff = toInt(B, replacements[LI->getPointerOperand()]);
8807 SmallVector<Value *, 2> args;
8808 args.push_back(diff);
8809 for (size_t i = argstart; i < num_args; i++)
8810 args.push_back(CI->getArgOperand(i));
8811
8812 if (load_fn->getFunctionType()->getNumParams() != args.size()) {
8813 auto fnName = load_fn->getName();
8814 auto found_numargs = load_fn->getFunctionType()->getNumParams();
8815 auto expected_numargs = args.size();
8816 EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
8817 " incorrect number of arguments to loader function ",
8818 fnName, " expected ", expected_numargs, " found ",
8819 found_numargs, " - ", *load_fn->getFunctionType());
8820 continue;
8821 } else {
8822 bool tocontinue = false;
8823 for (size_t i = 0; i < args.size(); i++) {
8824 if (load_fn->getFunctionType()->getParamType(i) !=
8825 args[i]->getType()) {
8826 auto fnName = load_fn->getName();
8827 EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
8828 " incorrect type of argument ", i,
8829 " to loader function ", fnName, " expected ",
8830 *args[i]->getType(), " found ",
8831 load_fn->getFunctionType()->params()[i]);
8832 tocontinue = true;
8833 args[i] = UndefValue::get(args[i]->getType());
8834 }
8835 }
8836 if (tocontinue)
8837 continue;
8838 }
8839 CallInst *call = B.CreateCall(load_fn, args);
8840 call->setDebugLoc(LI->getDebugLoc());
8841 Value *tmp = call;
8842 if (tmp->getType() != LI->getType()) {
8843 if (CastInst::castIsValid(Instruction::BitCast, tmp, LI->getType()))
8844 tmp = B.CreateBitCast(tmp, LI->getType());
8845 else {
8846 auto fnName = load_fn->getName();
8847 EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
8848 " incorrect return type of loader function ", fnName,
8849 " expected ", *LI->getType(), " found ",
8850 *call->getType());
8851 tmp = UndefValue::get(LI->getType());
8852 }
8853 }
8854 LI->replaceAllUsesWith(tmp);
8855
8856 if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) {
8857 InlineFunctionInfo IFI;
8858 InlineFunction(*call, IFI);
8859 }
8860 toErase.push_back(LI);
8861 continue;
8862 }
8863 if (auto SI = dyn_cast<StoreInst>(U)) {
8864 assert(SI->getValueOperand() != val);
8865 auto diff = toInt(B, replacements[SI->getPointerOperand()]);
8866 SmallVector<Value *, 2> args;
8867 args.push_back(SI->getValueOperand());
8868 auto sty = store_fn->getFunctionType()->getParamType(0);
8869 if (args[0]->getType() != store_fn->getFunctionType()->getParamType(0)) {
8870 if (CastInst::castIsValid(Instruction::BitCast, args[0], sty))
8871 args[0] = B.CreateBitCast(args[0], sty);
8872 else {
8873 auto args0ty = args[0]->getType();
8874 EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
8875 " first argument of store function must be the type of "
8876 "the store found fn arg type ",
8877 *sty, " expected ", *args0ty);
8878 args[0] = UndefValue::get(sty);
8879 }
8880 }
8881 args.push_back(diff);
8882 for (size_t i = argstart; i < num_args; i++)
8883 args.push_back(CI->getArgOperand(i));
8884
8885 if (store_fn->getFunctionType()->getNumParams() != args.size()) {
8886 auto fnName = store_fn->getName();
8887 auto found_numargs = store_fn->getFunctionType()->getNumParams();
8888 auto expected_numargs = args.size();
8889 EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
8890 " incorrect number of arguments to store function ", fnName,
8891 " expected ", expected_numargs, " found ", found_numargs,
8892 " - ", *store_fn->getFunctionType());
8893 continue;
8894 } else {
8895 bool tocontinue = false;
8896 for (size_t i = 0; i < args.size(); i++) {
8897 if (store_fn->getFunctionType()->getParamType(i) !=
8898 args[i]->getType()) {
8899 auto fnName = store_fn->getName();
8900 EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
8901 " incorrect type of argument ", i,
8902 " to storeer function ", fnName, " expected ",
8903 *args[i]->getType(), " found ",
8904 store_fn->getFunctionType()->params()[i]);
8905 tocontinue = true;
8906 args[i] = UndefValue::get(args[i]->getType());
8907 }
8908 }
8909 if (tocontinue)
8910 continue;
8911 }
8912 auto call = B.CreateCall(store_fn, args);
8913 call->setDebugLoc(SI->getDebugLoc());
8914 if (store_fn->hasFnAttribute(Attribute::AlwaysInline)) {
8915 InlineFunctionInfo IFI;
8916 InlineFunction(*call, IFI);
8917 }
8918 toErase.push_back(SI);
8919 continue;
8920 }
8921 remaining = U;
8922 }
8923 for (auto U : toErase)
8924 U->eraseFromParent();
8925
8926 if (!remaining) {
8927 CI->replaceAllUsesWith(Constant::getNullValue(CI->getType()));
8928 CI->eraseFromParent();
8929 } else if (replaceAll) {
8930 EmitFailure("IllegalSparse", remaining->getDebugLoc(), remaining,
8931 " Illegal remaining use (", *remaining, ") of todense (", *CI,
8932 ") in function ", *F);
8933 }
8934}
8935
8936bool LowerSparsification(llvm::Function *F, bool replaceAll) {
8937 auto &DL = F->getParent()->getDataLayout();
8938 bool changed = false;
8939 SmallVector<CallBase *, 1> todo;
8940 SetVector<BasicBlock *> toDenseBlocks;
8941 for (auto &BB : *F) {
8942 for (auto &I : BB) {
8943 if (auto CI = dyn_cast<CallInst>(&I)) {
8944 if (getFuncNameFromCall(CI).contains("__enzyme_todense")) {
8945 todo.push_back(CI);
8946 toDenseBlocks.insert(&BB);
8947 }
8948 }
8949 }
8950 }
8951 for (auto CI : todo) {
8952 changed = true;
8953 replaceToDense(CI, replaceAll, F, DL);
8954 }
8955 todo.clear();
8956
8957 if (changed && EnzymeAutoSparsity) {
8958 PassBuilder PB;
8959 LoopAnalysisManager LAM;
8960 FunctionAnalysisManager FAM;
8961 CGSCCAnalysisManager CGAM;
8962 ModuleAnalysisManager MAM;
8963 PB.registerModuleAnalyses(MAM);
8964 PB.registerFunctionAnalyses(FAM);
8965 PB.registerLoopAnalyses(LAM);
8966 PB.registerCGSCCAnalyses(CGAM);
8967 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
8968
8969 SimplifyCFGPass(SimplifyCFGOptions()).run(*F, FAM);
8970 InstCombinePass().run(*F, FAM);
8971 // required to make preheaders
8972 LoopSimplifyPass().run(*F, FAM);
8973 fixSparseIndices(*F, FAM, toDenseBlocks);
8974 }
8975
8976 for (auto &BB : *F) {
8977 for (auto &I : BB) {
8978 if (auto CI = dyn_cast<CallInst>(&I)) {
8979 if (getFuncNameFromCall(CI).contains("__enzyme_post_sparse_todense")) {
8980 todo.push_back(CI);
8981 }
8982 }
8983 }
8984 }
8985 for (auto CI : todo) {
8986 changed = true;
8987 replaceToDense(CI, replaceAll, F, DL);
8988 }
8989 return changed;
8990}
static RegisterPass< ActivityAnalysisPrinter > X("print-activity-analysis", "Print Activity Analysis Results")
std::pair< PHINode *, Instruction * > InsertNewCanonicalIV(Loop *L, Type *Ty, const llvm::Twine &Name)
void RemoveRedundantIVs(BasicBlock *Header, PHINode *CanonicalIV, Instruction *Increment, MustExitScalarEvolution &SE, llvm::function_ref< void(Instruction *, Value *)> replacer, llvm::function_ref< void(Instruction *)> eraser)
llvm::cl::opt< bool > EnzymeZeroCache
static bool contains(ArrayRef< int > ar, int v)
static bool isZero(llvm::Constant *cst)
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val)
Is the use of value val as an argument of call CI potentially captured.
cl::opt< bool > EnzymeNoAlias("enzyme-noalias", cl::init(false), cl::Hidden, cl::desc("Force noalias of autodiff"))
cl::opt< bool > EnzymeAggressiveAA("enzyme-aggressive-aa", cl::init(false), cl::Hidden, cl::desc("Use more unstable but aggressive LLVM AA"))
static cl::opt< bool > EnzymePHIRestructure("enzyme-phi-restructure", cl::init(false), cl::Hidden, cl::desc("Whether to restructure phi's to have better unwrap behavior"))
std::shared_ptr< const Constraints > getSparseConditions(bool &legal, Value *val, std::shared_ptr< const Constraints > defaultFloat, Instruction *scope, const ConstraintContext &ctx)
static void ForceRecursiveInlining(Function *NewF, size_t Limit)
Perform recursive inlining on NewF up to the given limit.
RecurType
@ DefinitelyRecursive
@ MaybeRecursive
@ NotRecursive
FunctionType * getFunctionTypeForClone(llvm::FunctionType *FTy, DerivativeMode mode, unsigned width, llvm::Type *additionalArg, llvm::ArrayRef< DIFFE_TYPE > constant_args, bool diffeReturnArg, bool returnTape, bool returnPrimal, bool returnShadow)
void RemoveRedundantPHI(Function *F, FunctionAnalysisManager &FAM)
void setFullWillReturn(Function *NewF)
cl::opt< bool > EnzymeSelectOpt("enzyme-select-opt", cl::init(true), cl::Hidden, cl::desc("Run Enzyme select optimization"))
void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, SetVector< BasicBlock * > &toDenseBlocks)
void SplitPHIs(llvm::Function &F)
cl::opt< bool > EnzymePreopt("enzyme-preopt", cl::init(true), cl::Hidden, cl::desc("Run enzyme preprocessing optimizations"))
void ReplaceFunctionImplementation(Module &M)
bool guaranteedDataDependent(Value *z)
cl::opt< bool > EnzymeCoalese("enzyme-coalese", cl::init(false), cl::Hidden, cl::desc("Whether to coalese memory allocations"))
cl::opt< bool > EnzymeInline("enzyme-inline", cl::init(false), cl::Hidden, cl::desc("Force inlining of autodiff"))
bool cannotDependOnLoopIV(const SCEV *S, const Loop *L)
constexpr bool SparseDebug
std::optional< std::string > fixSparse_inner(Instruction *cur, llvm::Function &F, QueueType &Q, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI, const DataLayout &DL)
bool DetectPointerArgOfFn(llvm::Function &F, SmallPtrSetImpl< Function * > &calls_todo)
DominatorOrderSet QueueType
bool directlySparse(Value *z)
void CoaleseTrivialMallocs(Function &F, DominatorTree &DT)
SmallVector< Value *, 1 > callOperands(llvm::CallBase *CB)
void RecursivelyReplaceAddressSpace(SmallVector< std::tuple< Value *, Value *, Instruction * >, 1 > &Todo, SmallVector< Instruction *, 1 > &toErase, bool legal)
Function * CreateMPIWrapper(Function *F)
void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F, const llvm::DataLayout &DL)
cl::opt< bool > EnzymeNameInstructions("enzyme-name-instructions", cl::init(false), cl::Hidden, cl::desc("Have enzyme name all instructions"))
void simplifyExtractions(Function *NewF)
static AllocaInst * OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T, const std::map< CallInst *, Value * > &reallocSizes)
void dump(const Constraints &c)
bool DetectNoUnwindOfFn(llvm::Function &F, SmallPtrSetImpl< Function * > &calls_todo)
static bool OnlyUsedInOMP(AllocaInst *AI)
void CanonicalizeLoops(Function *F, FunctionAnalysisManager &FAM)
bool DetectReadonlyOrThrow(Module &M)
static bool IsFunctionRecursive(Function *F, std::map< const Function *, RecurType > &Results)
Return whether this function eventually calls itself.
CallInst * isSum(llvm::Value *v)
static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM)
bool LowerSparsification(llvm::Function *F, bool replaceAll)
Lower __enzyme_todense, returning whether the function changed.
static bool isNot(Value *a, Value *b)
Function * getProductIntrinsic(llvm::Module &M, llvm::Type *T)
cl::opt< int > EnzymePostInlineOpt("enzyme-post-inline-opt", cl::init(0), cl::Hidden, cl::desc("Force inlining of autodiff"))
cl::opt< int > EnzymeInlineCount("enzyme-inline-count", cl::init(10000), cl::Hidden, cl::desc("Limit of number of functions to inline"))
CallInst * isProduct(llvm::Value *v)
raw_ostream & operator<<(raw_ostream &os, const Constraints &c)
void SelectOptimization(Function *F)
cl::opt< bool > EnzymeAutoSparsity("enzyme-auto-sparsity", cl::init(false), cl::Hidden, cl::desc("Run Enzyme auto sparsity"))
const SCEV * evaluateAtLoopIter(const SCEV *V, ScalarEvolution &SE, const Loop *find, const SCEV *replace)
static void UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode, SmallPtrSetImpl< llvm::BasicBlock * > &Unreachable)
Convert necessary stack allocations into mallocs for use in the reverse pass.
cl::opt< bool > EnzymeLowerGlobals("enzyme-lower-globals", cl::init(false), cl::Hidden, cl::desc("Lower globals to locals assuming the global values are not " "needed outside of this gradient"))
cl::opt< int > EnzymePostOptLevel("enzyme-post-opt-level", cl::init(0), cl::Hidden, cl::desc("Post optimization level within Enzyme differentiated function"))
bool DetectReadonlyOrThrowFn(llvm::Function &F, SmallPtrSetImpl< Function * > &calls_todo, llvm::TargetLibraryInfo &TLI, bool &local)
Function * getSumIntrinsic(llvm::Module &M, llvm::Type *T)
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val)
Is the use of value val as an argument of call CI potentially captured.
llvm::cl::opt< bool > EnzymeAlwaysInlineDiff
void RecursivelyReplaceAddressSpace(llvm::Value *AI, llvm::Value *rep, bool legal)
llvm::FunctionType * getFunctionTypeForClone(llvm::FunctionType *FTy, DerivativeMode mode, unsigned width, llvm::Type *additionalArg, llvm::ArrayRef< DIFFE_TYPE > constant_args, bool diffeReturnArg, bool returnTape, bool returnPrimal, bool returnShadow)
void ReplaceFunctionImplementation(llvm::Module &M)
static llvm::SmallPtrSet< llvm::BasicBlock *, 4 > getGuaranteedUnreachable(llvm::Function *F)
static bool isAllocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI)
static bool isReadOnly(Operation *op)
static Operation * getFunctionFromCall(CallOpInterface iface)
llvm::cl::opt< bool > EnzymePrintActivity
llvm::cl::opt< bool > EnzymePrint
static bool isMemFreeLibMFunction(llvm::StringRef str, llvm::Intrinsic::ID *ID=nullptr)
llvm::SmallVector< llvm::Instruction *, 2 > PostCacheStore(llvm::StoreInst *SI, llvm::IRBuilder<> &B)
Definition Utils.cpp:423
llvm::PointerType * getDefaultAnonymousTapeType(llvm::LLVMContext &C)
Definition Utils.cpp:437
Value * simplifyLoad(Value *V, size_t valSz, size_t preOffset)
Definition Utils.cpp:3386
llvm::Value * nextPowerOfTwo(llvm::IRBuilder<> &B, llvm::Value *V)
Create a function to compute the nearest power of two.
Definition Utils.cpp:2192
Value * CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count, const Twine &Name, CallInst **caller, Instruction **ZeroMem, bool isDefault)
Definition Utils.cpp:619
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
Definition Utils.cpp:62
bool notCaptured(llvm::Value *V)
Check whether value V is captured.
Definition Utils.cpp:4608
llvm::SmallVector< llvm::Value *, 1 > getJuliaObjects(llvm::Value *v, llvm::IRBuilder<> &B)
Definition Utils.cpp:4913
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
Definition Utils.cpp:3731
static llvm::StringRef getFuncName(llvm::Function *called)
Definition Utils.h:1260
static bool isIntelSubscriptIntrinsic(const llvm::IntrinsicInst &II)
Definition Utils.h:1445
ReturnType
Potential return type of generated functions.
Definition Utils.h:355
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
Definition Utils.h:713
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
Definition Utils.h:1174
static llvm::Function * getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id, llvm::ArrayRef< llvm::Type * > Tys={})
Definition Utils.h:2263
static bool isNoCapture(const llvm::CallBase *call, size_t idx)
Definition Utils.h:1840
static void addFunctionNoCapture(llvm::Function *call, size_t idx)
Definition Utils.h:2299
static bool isDebugFunction(llvm::Function *called)
Definition Utils.h:690
static bool isPointerArithmeticInst(const llvm::Value *V, bool includephi=true, bool includebin=true)
Definition Utils.h:1456
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
Definition Utils.h:1507
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
Definition Utils.h:339
void EmitWarning(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::BasicBlock *BB, const Args &...args)
Definition Utils.h:133
static bool isReadOnlyOrThrow(const llvm::Function *F)
Definition Utils.h:1759
llvm::cl::opt< bool > EnzymeJuliaAddrLoad
static void addCallSiteNoCapture(llvm::CallBase *call, size_t idx)
Definition Utils.h:2289
llvm::cl::opt< bool > EnzymePrintPerf
Print additional debug info relevant to performance.
static std::string convertSRetTypeToString(llvm::Type *T)
Definition Utils.h:2428
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
Definition Utils.h:1269
static bool isWriteOnly(const llvm::Function *F, ssize_t arg=-1)
Definition Utils.h:1788
static bool isLocalReadOnlyOrThrow(const llvm::Function *F)
Definition Utils.h:1719
static llvm::Attribute::AttrKind PrimalParamAttrsToPreserve[]
Definition Utils.h:2208
static llvm::Attribute::AttrKind ShadowParamAttrsToPreserve[]
Definition Utils.h:2240
DerivativeMode
Definition Utils.h:390
const SCEV *const node
static SetTy intersect(const SetTy &lhs, const SetTy &rhs)
bool operator==(const Constraints &rhs) const
bool operator!=(const Constraints &rhs) const
std::set< InnerTy, ConstraintComparator > SetTy
unsigned hash() const
Constraints(Type t, const SetTy &c, bool check=true)
static InnerTy make_compare(const SCEV *v, bool isEqual, const llvm::Loop *Loop, const ConstraintContext &ctx)
InnerTy notB(const ConstraintContext &ctx) const
InnerTy orB(InnerTy rhs, const ConstraintContext &ctx) const
static void insert(SetTy &set, InnerTy ty)
static InnerTy none()
const llvm::Loop *const Loop
bool operator>(const Constraints &rhs) const
bool isNone() const
static InnerTy all()
std::shared_ptr< const Constraints > InnerTy
__attribute__((noinline)) void dump() const
InnerTy andB(const InnerTy rhs, const ConstraintContext &ctx) const
enum Constraints::Type ty
bool isAll() const
const SetTy values
static void set_subtract(SetTy &set, const SetTy &rhs)
InnerTy remove(const SetTy &sub) const
SmallVector< std::pair< Value *, Value * >, 1 > allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP, const ConstraintContext &ctx, IRBuilder<> &B) const
bool operator<(const Constraints &rhs) const
void remove(Instruction *I)
bool contains(Instruction *I) const
DominatorOrderSet(DominatorTree &DT, LoopInfo &LI)
Instruction * pop_back_val()
static llvm::Value * extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, unsigned off, const llvm::Twine &name="")
Helper routine to extract a nested element from a struct/array. This is.
static llvm::Type * getShadowType(llvm::Type *ty, unsigned width)
llvm::Function * CloneFunctionWithReturns(DerivativeMode mode, unsigned width, llvm::Function *&F, llvm::ValueToValueMapTy &ptrInputs, llvm::ArrayRef< DIFFE_TYPE > constant_args, llvm::SmallPtrSetImpl< llvm::Value * > &constants, llvm::SmallPtrSetImpl< llvm::Value * > &nonconstant, llvm::SmallPtrSetImpl< llvm::Value * > &returnvals, bool returnTape, bool returnPrimal, bool returnShadow, const llvm::Twine &name, llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > *VMapO, bool diffeReturnArg, llvm::Type *additionalArg=nullptr)
std::map< llvm::Function *, llvm::Function * > CloneOrigin
llvm::Function * preprocessForClone(llvm::Function *F, DerivativeMode mode)
llvm::ModuleAnalysisManager MAM
void optimizeIntermediate(llvm::Function *F)
llvm::LoopAnalysisManager LAM
void AlwaysInline(llvm::Function *NewF)
llvm::AAResults & getAAResultsFromFunction(llvm::Function *NewF)
std::map< std::pair< llvm::Function *, DerivativeMode >, llvm::Function * > cache
llvm::FunctionAnalysisManager FAM
void LowerAllocAddr(llvm::Function *NewF)
void ReplaceReallocs(llvm::Function *NewF, bool mem2reg=false)
Calls to realloc with an appropriate implementation.
bool operator()(std::shared_ptr< const Constraints > lhs, std::shared_ptr< const Constraints > rhs) const
ScalarEvolution & SE
DominatorTree & DT
ConstraintContext(const ConstraintContext &ctx, InnerTy lhs)
ConstraintContext(const ConstraintContext &ctx, InnerTy lhs, InnerTy rhs)
ConstraintContext(const ConstraintContext &)=delete
std::set< InnerTy, ConstraintComparator > SetTy
const Loop * loopToSolve
ConstraintContext(ScalarEvolution &SE, const Loop *loopToSolve, const SmallVectorImpl< Instruction * > &Assumptions, DominatorTree &DT)
const SmallVectorImpl< Instruction * > & Assumptions
bool contains(InnerTy x) const
std::shared_ptr< const Constraints > InnerTy
DominatorTree & DT
compare_insts(DominatorTree &DT, LoopInfo &LI)
bool operator()(Instruction *A, Instruction *B) const