Enzyme main
Loading...
Searching...
No Matches
JLInstSimplify.cpp
Go to the documentation of this file.
1//=- JLInstSimplify.cpp - Additional instsimplify rules for Julia programs =//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains an LLVM function pass that applies additional
22// instruction simplifications (freeze removal before branches and folding of
23// pointer-alias comparisons) to Julia programs.
23//
24//===----------------------------------------------------------------------===//
25#include <cstdint>
26#include <llvm/Config/llvm-config.h>
27
28#include "llvm/ADT/SmallVector.h"
29
30#include "llvm/IR/BasicBlock.h"
31#include "llvm/IR/Constants.h"
32#include "llvm/IR/DebugInfoMetadata.h"
33#include "llvm/IR/Function.h"
34#include "llvm/IR/IRBuilder.h"
35#include "llvm/IR/InstrTypes.h"
36#include "llvm/IR/Instructions.h"
37#include "llvm/IR/MDBuilder.h"
38#include "llvm/IR/Metadata.h"
39
40#include "llvm/IR/LegacyPassManager.h"
41
42#include "llvm/Support/Debug.h"
43
44#include "llvm/Analysis/TargetLibraryInfo.h"
45
46#include "llvm-c/Core.h"
47#include "llvm-c/DataTypes.h"
48
49#include "llvm-c/ExternC.h"
50#include "llvm-c/Types.h"
51
52#include "JLInstSimplify.h"
53#include "LibraryFuncs.h"
54#include "Utils.h"
55
56using namespace llvm;
57#ifdef DEBUG_TYPE
58#undef DEBUG_TYPE
59#endif
60#define DEBUG_TYPE "jl-inst-simplify"
61namespace {
62
63bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
64 llvm::AAResults &AA, llvm::LoopInfo &LI) {
65 bool changed = false;
66
67 for (auto &BB : F)
68 for (auto &I : BB) {
69 if (auto FI = dyn_cast<FreezeInst>(&I)) {
70 if (FI->hasOneUse()) {
71 bool allBranch = true;
72 for (auto user : FI->users()) {
73 if (!isa<BranchInst>(user)) {
74 allBranch = false;
75 break;
76 }
77 }
78 if (allBranch) {
79 FI->replaceAllUsesWith(FI->getOperand(0));
80 changed = true;
81 continue;
82 }
83 }
84 }
85 if (I.use_empty())
86 continue;
87
88 bool legal = false;
89 ICmpInst::Predicate pred;
90 if (auto cmp = dyn_cast<ICmpInst>(&I)) {
91 pred = cmp->getPredicate();
92 legal = true;
93 } else if (auto CI = dyn_cast<CallBase>(&I)) {
94 if (getFuncNameFromCall(CI) == "jl_mightalias") {
95#if LLVM_VERSION_MAJOR >= 14
96 size_t numargs = CI->arg_size();
97#else
98 size_t numargs = CI->getNumArgOperands();
99#endif
100 if (numargs == 2 && isa<PointerType>(I.getOperand(0)->getType()) &&
101 isa<PointerType>(I.getOperand(0)->getType())) {
102 legal = true;
103 pred = ICmpInst::Predicate::ICMP_EQ;
104 }
105 }
106 }
107
108 if (legal) {
109 if (auto alias = arePointersGuaranteedNoAlias(
110 TLI, AA, LI, I.getOperand(0), I.getOperand(1), false)) {
111
112 bool val = *alias;
113 auto repval = ICmpInst::isTrueWhenEqual(pred)
114 ? ConstantInt::get(I.getType(), 1 - val)
115 : ConstantInt::get(I.getType(), val);
116 I.replaceAllUsesWith(repval);
117 changed = true;
118 continue;
119 }
120 }
121 }
122
123 return changed;
124}
125
126class JLInstSimplify final : public FunctionPass {
127public:
128 static char ID;
129 JLInstSimplify() : FunctionPass(ID) {}
130
131 void getAnalysisUsage(AnalysisUsage &AU) const override {
132 AU.addRequired<TargetLibraryInfoWrapperPass>();
133 AU.addRequired<AAResultsWrapperPass>();
134 AU.addRequired<LoopInfoWrapperPass>();
135 }
136
137 bool runOnFunction(Function &F) override {
138 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
139 auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
140 auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
141 return jlInstSimplify(F, TLI, AA, LI);
142 }
143};
144
145} // namespace
146
147FunctionPass *createJLInstSimplifyPass() { return new JLInstSimplify(); }
148
149extern "C" void LLVMAddJLInstSimplifyPass(LLVMPassManagerRef PM) {
150 unwrap(PM)->add(createJLInstSimplifyPass());
151}
152
153char JLInstSimplify::ID = 0;
154
155static RegisterPass<JLInstSimplify> X("jl-inst-simplify",
156 "JL instruction simplification");
157
159JLInstSimplifyNewPM::run(llvm::Function &F,
160 llvm::FunctionAnalysisManager &FAM) {
161 bool changed = false;
162 changed = jlInstSimplify(F, FAM.getResult<TargetLibraryAnalysis>(F),
163 FAM.getResult<AAManager>(F),
164 FAM.getResult<LoopAnalysis>(F));
165 return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
166}
167llvm::AnalysisKey JLInstSimplifyNewPM::Key;
FunctionPass * createJLInstSimplifyPass()
void LLVMAddJLInstSimplifyPass(LLVMPassManagerRef PM)
static RegisterPass< JLInstSimplify > X("jl-inst-simplify", "JL instruction simplification")
llvm::Optional< bool > arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA, llvm::LoopInfo &LI, llvm::Value *op0, llvm::Value *op1, bool offsetAllowed)
Definition Utils.cpp:4618
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
Definition Utils.h:1269
Result run(llvm::Function &M, llvm::FunctionAnalysisManager &MAM)
llvm::PreservedAnalyses Result