Enzyme main
Loading...
Searching...
No Matches
LibraryFuncs.h
Go to the documentation of this file.
1//===- LibraryFuncs.h - Utilities for handling library functions ---------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients}, author = {Moses, William S. and
13// Churavy, Valentin}, booktitle = {Advances in Neural Information Processing
14// Systems 33}, year = {2020}, note = {To appear in},
15// }
16//
17//===----------------------------------------------------------------------===//
18//
// This file defines miscellaneous utilities for handling library functions.
20//
21//===----------------------------------------------------------------------===//
22
23#ifndef LIBRARYFUNCS_H_
24#define LIBRARYFUNCS_H_
25
26#include <llvm/ADT/StringMap.h>
27#include <llvm/Analysis/AliasAnalysis.h>
28#include <llvm/Analysis/TargetLibraryInfo.h>
29#include <llvm/IR/IRBuilder.h>
30#include <llvm/IR/InlineAsm.h>
31#include <llvm/IR/Instructions.h>
32
33#include "Utils.h"
34
35class GradientUtils;
36extern llvm::StringMap<std::function<llvm::Value *(
37 llvm::IRBuilder<> &, llvm::CallInst *, llvm::ArrayRef<llvm::Value *>,
38 GradientUtils *)>>
40extern llvm::StringMap<
41 std::function<llvm::CallInst *(llvm::IRBuilder<> &, llvm::Value *)>>
43
44/// Return whether a given function is a known C/C++ memory allocation function
45/// For updating below one should read MemoryBuiltins.cpp, TargetLibraryInfo.cpp
46static inline bool isAllocationFunction(const llvm::StringRef name,
47 const llvm::TargetLibraryInfo &TLI) {
48 if (name == "enzyme_allocator")
49 return true;
50 if (name == "calloc" || name == "malloc")
51 return true;
52 if (name == "_mlir_memref_to_llvm_alloc")
53 return true;
54 if (name == "swift_allocObject")
55 return true;
56 if (name == "__size_returning_new_experiment")
57 return true;
58 if (name == "__rust_alloc" || name == "__rust_alloc_zeroed")
59 return true;
60 if (name == "julia.gc_alloc_obj" || name == "jl_gc_alloc_typed" ||
61 name == "ijl_gc_alloc_typed")
62 return true;
63 if (shadowHandlers.find(name) != shadowHandlers.end())
64 return true;
65
66 using namespace llvm;
67 llvm::LibFunc libfunc;
68 if (!TLI.getLibFunc(name, libfunc))
69 return false;
70
71 switch (libfunc) {
72 case LibFunc_malloc: // malloc(unsigned int);
73 case LibFunc_valloc: // valloc(unsigned int);
74
75 case LibFunc_Znwj: // new(unsigned int);
76 case LibFunc_ZnwjRKSt9nothrow_t: // new(unsigned int, nothrow);
77 case LibFunc_ZnwjSt11align_val_t: // new(unsigned int, align_val_t)
78 case LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t: // new(unsigned int,
79 // align_val_t, nothrow)
80
81 case LibFunc_Znwm: // new(unsigned long);
82 case LibFunc_ZnwmRKSt9nothrow_t: // new(unsigned long, nothrow);
83 case LibFunc_ZnwmSt11align_val_t: // new(unsigned long, align_val_t)
84 case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: // new(unsigned long,
85 // align_val_t, nothrow)
86
87 case LibFunc_Znaj: // new[](unsigned int);
88 case LibFunc_ZnajRKSt9nothrow_t: // new[](unsigned int, nothrow);
89 case LibFunc_ZnajSt11align_val_t: // new[](unsigned int, align_val_t)
90 case LibFunc_ZnajSt11align_val_tRKSt9nothrow_t: // new[](unsigned int,
91 // align_val_t, nothrow)
92
93 case LibFunc_Znam: // new[](unsigned long);
94 case LibFunc_ZnamRKSt9nothrow_t: // new[](unsigned long, nothrow);
95 case LibFunc_ZnamSt11align_val_t: // new[](unsigned long, align_val_t)
96 case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: // new[](unsigned long,
97 // align_val_t, nothrow)
98
99 case LibFunc_msvc_new_int: // new(unsigned int);
100 case LibFunc_msvc_new_int_nothrow: // new(unsigned int, nothrow);
101 case LibFunc_msvc_new_longlong: // new(unsigned long long);
102 case LibFunc_msvc_new_longlong_nothrow: // new(unsigned long long, nothrow);
103 case LibFunc_msvc_new_array_int: // new[](unsigned int);
104 case LibFunc_msvc_new_array_int_nothrow: // new[](unsigned int, nothrow);
105 case LibFunc_msvc_new_array_longlong: // new[](unsigned long long);
106 case LibFunc_msvc_new_array_longlong_nothrow: // new[](unsigned long long,
107 // nothrow);
108
109 // TODO strdup, strndup
110
111 // TODO call, realloc, reallocf
112
113 // TODO (perhaps) posix_memalign
114 return true;
115 default:
116 return false;
117 }
118}
119
120/// Return whether a given function is a known C/C++ memory deallocation
121/// function For updating below one should read MemoryBuiltins.cpp,
122/// TargetLibraryInfo.cpp
123static inline bool isDeallocationFunction(const llvm::StringRef name,
124 const llvm::TargetLibraryInfo &TLI) {
125 using namespace llvm;
126 llvm::LibFunc libfunc;
127 if (name == "_ZdlPvmSt11align_val_t")
128 return true;
129 if (!TLI.getLibFunc(name, libfunc)) {
130 if (name == "free")
131 return true;
132 if (name == "_mlir_memref_to_llvm_free")
133 return true;
134 if (name == "__rust_dealloc")
135 return true;
136 if (name == "swift_release")
137 return true;
138 return false;
139 }
140
141 switch (libfunc) {
142 // void free(void*);
143 case LibFunc_free:
144
145 // void operator delete[](void*);
146 case LibFunc_ZdaPv:
147 // void operator delete(void*);
148 case LibFunc_ZdlPv:
149 // void operator delete[](void*);
150 case LibFunc_msvc_delete_array_ptr32:
151 // void operator delete[](void*);
152 case LibFunc_msvc_delete_array_ptr64:
153 // void operator delete(void*);
154 case LibFunc_msvc_delete_ptr32:
155 // void operator delete(void*);
156 case LibFunc_msvc_delete_ptr64:
157
158 // void operator delete[](void*, nothrow);
159 case LibFunc_ZdaPvRKSt9nothrow_t:
160 // void operator delete[](void*, unsigned int);
161 case LibFunc_ZdaPvj:
162 // void operator delete[](void*, unsigned long);
163 case LibFunc_ZdaPvm:
164 // void operator delete(void*, nothrow);
165 case LibFunc_ZdlPvRKSt9nothrow_t:
166 // void operator delete(void*, unsigned int);
167 case LibFunc_ZdlPvj:
168 // void operator delete(void*, unsigned long);
169 case LibFunc_ZdlPvm:
170 // void operator delete(void*, align_val_t)
171 case LibFunc_ZdlPvSt11align_val_t:
172 // void operator delete[](void*, align_val_t)
173 case LibFunc_ZdaPvSt11align_val_t:
174 // void operator delete[](void*, unsigned int);
175 case LibFunc_msvc_delete_array_ptr32_int:
176 // void operator delete[](void*, nothrow);
177 case LibFunc_msvc_delete_array_ptr32_nothrow:
178 // void operator delete[](void*, unsigned long long);
179 case LibFunc_msvc_delete_array_ptr64_longlong:
180 // void operator delete[](void*, nothrow);
181 case LibFunc_msvc_delete_array_ptr64_nothrow:
182 // void operator delete(void*, unsigned int);
183 case LibFunc_msvc_delete_ptr32_int:
184 // void operator delete(void*, nothrow);
185 case LibFunc_msvc_delete_ptr32_nothrow:
186 // void operator delete(void*, unsigned long long);
187 case LibFunc_msvc_delete_ptr64_longlong:
188 // void operator delete(void*, nothrow);
189 case LibFunc_msvc_delete_ptr64_nothrow:
190 // void operator delete(void*, align_val_t, nothrow)
191 case LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t:
192 // void operator delete[](void*, align_val_t, nothrow)
193 case LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t:
194 return true;
195 default:
196 return false;
197 }
198}
199
200static inline void zeroKnownAllocation(llvm::IRBuilder<> &bb,
201 llvm::Value *toZero,
202 llvm::ArrayRef<llvm::Value *> argValues,
203 const llvm::StringRef funcName,
204 const llvm::TargetLibraryInfo &TLI,
205 llvm::CallInst *orig) {
206 using namespace llvm;
207 assert(isAllocationFunction(funcName, TLI));
208
209 // Don't re-zero an already-zero buffer
210 if (funcName == "calloc" || funcName == "__rust_alloc_zeroed")
211 return;
212
213 Value *allocSize = argValues[0];
214 if (funcName == "julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" ||
215 funcName == "ijl_gc_alloc_typed") {
216 allocSize = argValues[1];
217 }
218 if (funcName == "enzyme_allocator") {
219 auto index = getAllocationIndexFromCall(orig);
220 allocSize = argValues[*index];
221 }
222 Value *dst_arg = toZero;
223
224 if (funcName == "__size_returning_new_experiment")
225 dst_arg = bb.CreateExtractValue(dst_arg, 0);
226
227 if (dst_arg->getType()->isIntegerTy())
228 dst_arg = bb.CreateIntToPtr(dst_arg, getInt8PtrTy(toZero->getContext()));
229 else
230 dst_arg = bb.CreateBitCast(
231 dst_arg, getInt8PtrTy(toZero->getContext(),
232 toZero->getType()->getPointerAddressSpace()));
233
234 auto val_arg = ConstantInt::get(Type::getInt8Ty(toZero->getContext()), 0);
235 auto len_arg =
236 bb.CreateZExtOrTrunc(allocSize, Type::getInt64Ty(toZero->getContext()));
237
238 auto memset = bb.CreateMemSet(dst_arg, val_arg, len_arg, MaybeAlign());
239 memset->addParamAttr(0, Attribute::NonNull);
240 if (auto CI = dyn_cast<ConstantInt>(allocSize)) {
241 auto derefBytes = CI->getLimitedValue();
242#if LLVM_VERSION_MAJOR >= 14
243 memset->addDereferenceableParamAttr(0, derefBytes);
244 memset->setAttributes(
245 memset->getAttributes().addDereferenceableOrNullParamAttr(
246 memset->getContext(), 0, derefBytes));
247#else
248 memset->addDereferenceableAttr(llvm::AttributeList::FirstArgIndex,
249 derefBytes);
250 memset->addDereferenceableOrNullAttr(llvm::AttributeList::FirstArgIndex,
251 derefBytes);
252#endif
253 }
254}
255
256/// Perform the corresponding deallocation of tofree, given it was allocated by
257/// allocationfn
258// For updating below one should read MemoryBuiltins.cpp, TargetLibraryInfo.cpp
259llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder,
260 llvm::Value *tofree,
261 llvm::StringRef allocationfn,
262 const llvm::DebugLoc &debuglocation,
263 const llvm::TargetLibraryInfo &TLI,
264 llvm::CallInst *orig,
265 GradientUtils *gutils);
266
267static inline bool isAllocationCall(const llvm::Value *TmpOrig,
268 llvm::TargetLibraryInfo &TLI) {
269 if (auto *CI = llvm::dyn_cast<llvm::CallBase>(TmpOrig)) {
270 auto AttrList =
271 CI->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
272 if (AttrList.hasAttribute("enzyme_allocation"))
273 return true;
274 if (auto Fn = getFunctionFromCall(CI))
275 if (Fn->hasFnAttribute("enzyme_allocation"))
276 return true;
278 }
279 return false;
280}
281
282static inline bool isDeallocationCall(const llvm::Value *TmpOrig,
283 llvm::TargetLibraryInfo &TLI) {
284 if (auto *CI = llvm::dyn_cast<llvm::CallBase>(TmpOrig)) {
286 }
287 return false;
288}
289
290#endif
static bool isDeallocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory deallocation function For updating below one ...
static bool isDeallocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI)
llvm::StringMap< std::function< llvm::CallInst *(llvm::IRBuilder<> &, llvm::Value *)> > shadowErasers
llvm::StringMap< std::function< llvm::Value *(llvm::IRBuilder<> &, llvm::CallInst *, llvm::ArrayRef< llvm::Value * >, GradientUtils *)> > shadowHandlers
llvm::CallInst * freeKnownAllocation(llvm::IRBuilder<> &builder, llvm::Value *tofree, llvm::StringRef allocationfn, const llvm::DebugLoc &debuglocation, const llvm::TargetLibraryInfo &TLI, llvm::CallInst *orig, GradientUtils *gutils)
Perform the corresponding deallocation of tofree, given it was allocated by allocationfn.
static bool isAllocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory allocation function For updating below one sh...
static void zeroKnownAllocation(llvm::IRBuilder<> &bb, llvm::Value *toZero, llvm::ArrayRef< llvm::Value * > argValues, const llvm::StringRef funcName, const llvm::TargetLibraryInfo &TLI, llvm::CallInst *orig)
static bool isAllocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI)
static Operation * getFunctionFromCall(CallOpInterface iface)
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
Definition Utils.h:1174
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
Definition Utils.h:1269
static llvm::Optional< size_t > getAllocationIndexFromCall(const llvm::CallBase *op)
Definition Utils.h:1318
llvm::TargetLibraryInfo & TLI
Various analysis results of newFunc.