Enzyme main
Loading...
Searching...
No Matches
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utilities for operation interfaces
2//--------------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10#include "Interfaces/Utils.h"
11#include "Dialect/Ops.h"
13#include "mlir/Analysis/AliasAnalysis.h"
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Interfaces/FunctionInterfaces.h"
19#include <optional>
20
21using namespace mlir;
22using namespace mlir::enzyme;
23namespace mlir {
24namespace enzyme {
25namespace oputils {
26
/// Returns the names of the external functions that isCaptured() accepts as
/// users of a pointer without treating that use as a capture.
///
/// The set is created on first use and the same object is returned by every
/// call.
const std::set<std::string> &getNonCapturingFunctions() {
  static const std::set<std::string> knownNames{
      "free",
      "printf",
      "fprintf",
      "scanf",
      "fscanf",
      "gettimeofday",
      "clock_gettime",
      "getenv",
      "strrchr",
      "strlen",
      "sprintf",
      "sscanf",
      "mkdir",
      "fwrite",
      "fread",
      "memcpy",
      "cudaMemcpy",
      "memset",
      "cudaMemset",
      "__isoc99_scanf",
      "__isoc99_fscanf",
  };
  return knownNames;
}
37
38static bool isCaptured(Value v, Operation *potentialUser = nullptr,
39 bool *seenuse = nullptr) {
40 SmallVector<Value> todo = {v};
41 while (todo.size()) {
42 Value v = todo.pop_back_val();
43 for (auto u : v.getUsers()) {
44 if (seenuse && u == potentialUser)
45 *seenuse = true;
46 if (isa<memref::LoadOp, LLVM::LoadOp, affine::AffineLoadOp>(u))
47 continue;
48 if (auto s = dyn_cast<memref::StoreOp>(u)) {
49 if (s.getValue() == v)
50 return true;
51 continue;
52 }
53 if (auto s = dyn_cast<affine::AffineStoreOp>(u)) {
54 if (s.getValue() == v)
55 return true;
56 continue;
57 }
58 if (auto s = dyn_cast<LLVM::StoreOp>(u)) {
59 if (s.getValue() == v)
60 return true;
61 continue;
62 }
63 if (auto sub = dyn_cast<LLVM::GEPOp>(u)) {
64 todo.push_back(sub);
65 }
66 if (auto sub = dyn_cast<LLVM::BitcastOp>(u)) {
67 todo.push_back(sub);
68 }
69 if (auto sub = dyn_cast<LLVM::AddrSpaceCastOp>(u)) {
70 todo.push_back(sub);
71 }
72 if (auto sub = dyn_cast<func::ReturnOp>(u)) {
73 continue;
74 }
75 if (auto sub = dyn_cast<LLVM::MemsetOp>(u)) {
76 continue;
77 }
78 if (auto sub = dyn_cast<LLVM::MemcpyOp>(u)) {
79 continue;
80 }
81 if (auto sub = dyn_cast<LLVM::MemmoveOp>(u)) {
82 continue;
83 }
84 if (auto sub = dyn_cast<memref::CastOp>(u)) {
85 todo.push_back(sub);
86 }
87 if (auto sub = dyn_cast<memref::DeallocOp>(u)) {
88 continue;
89 }
90 if (auto cop = dyn_cast<LLVM::CallOp>(u)) {
91 if (auto callee = cop.getCallee()) {
92 if (getNonCapturingFunctions().count(callee->str()))
93 continue;
94 }
95 }
96 if (auto cop = dyn_cast<func::CallOp>(u)) {
97 if (getNonCapturingFunctions().count(cop.getCallee().str()))
98 continue;
99 }
100 return true;
101 }
102 }
103
104 return false;
105}
106
107static Value getBase(Value v) {
108 while (true) {
109 if (auto s = v.getDefiningOp<LLVM::GEPOp>()) {
110 v = s.getBase();
111 continue;
112 }
113 if (auto s = v.getDefiningOp<LLVM::BitcastOp>()) {
114 v = s.getArg();
115 continue;
116 }
117 if (auto s = v.getDefiningOp<LLVM::AddrSpaceCastOp>()) {
118 v = s.getArg();
119 continue;
120 }
121 if (auto s = v.getDefiningOp<memref::CastOp>()) {
122 v = s.getSource();
123 continue;
124 }
125 break;
126 }
127 return v;
128}
129
130static bool isStackAlloca(Value v) {
131 return v.getDefiningOp<memref::AllocaOp>() ||
132 v.getDefiningOp<memref::AllocOp>() ||
133 v.getDefiningOp<LLVM::AllocaOp>();
134}
135
136bool mayAlias(Value v1, Value v2) {
137 v1 = getBase(v1);
138 v2 = getBase(v2);
139 if (v1 == v2)
140 return true;
141
142 // We may now assume neither v1 nor v2 are subindices
143
144 if (auto glob = v1.getDefiningOp<memref::GetGlobalOp>()) {
145 if (auto Aglob = v2.getDefiningOp<memref::GetGlobalOp>()) {
146 return glob.getName() == Aglob.getName();
147 }
148 }
149
150 if (auto glob = v1.getDefiningOp<LLVM::AddressOfOp>()) {
151 if (auto Aglob = v2.getDefiningOp<LLVM::AddressOfOp>()) {
152 return glob.getGlobalName() == Aglob.getGlobalName();
153 }
154 }
155
156 bool isAlloca[2];
157 bool isGlobal[2];
158
159 isAlloca[0] = isStackAlloca(v1);
160 isGlobal[0] = v1.getDefiningOp<memref::GetGlobalOp>() ||
161 v1.getDefiningOp<LLVM::AddressOfOp>();
162
163 isAlloca[1] = isStackAlloca(v2);
164
165 isGlobal[1] = v2.getDefiningOp<memref::GetGlobalOp>() ||
166 v2.getDefiningOp<LLVM::AddressOfOp>();
167
168 // Non-equivalent allocas/global's cannot conflict with each other
169 if ((isAlloca[0] || isGlobal[0]) && (isAlloca[1] || isGlobal[1]))
170 return false;
171
172 BlockArgument barg1 = dyn_cast<BlockArgument>(v1);
173 BlockArgument barg2 = dyn_cast<BlockArgument>(v2);
174
175 FunctionOpInterface f1 =
176 barg1 ? dyn_cast<FunctionOpInterface>(barg1.getOwner()->getParentOp())
177 : nullptr;
178 FunctionOpInterface f2 =
179 barg2 ? dyn_cast<FunctionOpInterface>(barg2.getOwner()->getParentOp())
180 : nullptr;
181
182 bool isNoAlias1 =
183 f1 ? !!f1.getArgAttr(barg1.getArgNumber(),
184 LLVM::LLVMDialect::getNoAliasAttrName())
185 : false;
186 bool isNoAlias2 =
187 f2 ? !!f2.getArgAttr(barg2.getArgNumber(),
188 LLVM::LLVMDialect::getNoAliasAttrName())
189 : false;
190
191 if (!isCaptured(v1) && isNoAlias1)
192 return false;
193 if (!isCaptured(v2) && isNoAlias2)
194 return false;
195
196 bool isArg[2];
197 isArg[0] = f1;
198 isArg[1] = f2;
199
200 // Stack allocations cannot have been passed as an argument.
201 if ((isAlloca[0] && isArg[1]) || (isAlloca[1] && isArg[0]))
202 return false;
203
204 // Non captured base allocas cannot conflict with another base value.
205 if (isAlloca[0] && !isCaptured(v1))
206 return false;
207
208 if (isAlloca[1] && !isCaptured(v2))
209 return false;
210
211 return true;
212}
213
214bool mayAlias(MemoryEffects::EffectInstance a, Value v2) {
215 if (Value v = a.getValue()) {
216 return mayAlias(v, v2);
217 }
218 return true;
219}
220
221bool mayAlias(MemoryEffects::EffectInstance &a,
222 MemoryEffects::EffectInstance &b) {
223 if (a.getResource()->getResourceID() != b.getResource()->getResourceID())
224 return false;
225 Value valA = a.getValue();
226 Value valB = b.getValue();
227
228 // unknown effects may always alias
229 if (!valA || !valB) {
230 return true;
231 }
232
233 auto valResult = oputils::mayAlias(valA, valB);
234 return valResult;
235}
236
237bool isReadOnly(Operation *op) {
238 // If the op has memory effects, try to characterize them to see if the op
239 // is trivially dead here.
240 if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
241 // Check to see if this op either has no effects, or only reads from memory.
242 SmallVector<MemoryEffects::EffectInstance, 1> effects;
243 effectInterface.getEffects(effects);
244 if (!llvm::all_of(effects, [op](const MemoryEffects::EffectInstance &it) {
245 return isa<MemoryEffects::Read>(it.getEffect());
246 })) {
247 return false;
248 }
249 }
250
251 bool isRecursiveContainer =
252 op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
253 if (isRecursiveContainer) {
254 for (Region &region : op->getRegions()) {
255 for (auto &block : region) {
256 for (auto &nestedOp : block)
257 if (!isReadOnly(&nestedOp))
258 return false;
259 }
260 }
261 }
262
263 return true;
264}
265
266bool isReadNone(Operation *op) {
267 bool hasRecursiveEffects = op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
268 if (hasRecursiveEffects) {
269 for (Region &region : op->getRegions()) {
270 for (auto &block : region) {
271 for (auto &nestedOp : block)
272 if (!isReadNone(&nestedOp))
273 return false;
274 }
275 }
276 return true;
277 }
278
279 // If the op has memory effects, try to characterize them to see if the op
280 // is trivially dead here.
281 if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
282 // Check to see if this op either has no effects, or only allocates/reads
283 // memory.
284 SmallVector<MemoryEffects::EffectInstance, 1> effects;
285 effectInterface.getEffects(effects);
286 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &it) {
287 return isa<MemoryEffects::Read>(it.getEffect()) ||
288 isa<MemoryEffects::Write>(it.getEffect());
289 })) {
290 return false;
291 }
292 return true;
293 }
294 return false;
295}
296
297bool collectOpEffects(Operation *rootOp,
298 SmallVector<MemoryEffects::EffectInstance> &effects) {
299 SmallVector<Operation *> effectingOps(1, rootOp);
300 bool couldCollectEffects = true;
301
302 while (!effectingOps.empty()) {
303 Operation *op = effectingOps.pop_back_val();
304 bool isRecursiveContainer =
305 op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
306
307 if (isRecursiveContainer) {
308 for (Region &region : op->getRegions()) {
309 for (Block &block : region) {
310 for (Operation &nestedOp : block) {
311 effectingOps.push_back(&nestedOp);
312 }
313 }
314 }
315 }
316
317 if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
318 SmallVector<MemoryEffects::EffectInstance> localEffects;
319 effectInterface.getEffects(localEffects);
320 llvm::append_range(effects, localEffects);
321 } else if (!isRecursiveContainer) {
322 // Handle specific operations which are not recursive containers, but
323 // still may have memory effects(eg. autodiff calls, llvm calls to libc
324 // functions). If it's none of these, then the operation may not have any
325 // memory effects
326 if (auto cop = dyn_cast<LLVM::CallOp>(op)) {
327 if (auto callee = cop.getCallee()) {
328 if (*callee == "scanf" || *callee == "__isoc99_scanf") {
329 // Global read
330 effects.emplace_back(
331 MemoryEffects::Effect::get<MemoryEffects::Read>());
332
333 bool first = true;
334 for (auto &arg : cop.getArgOperandsMutable()) {
335 if (first)
336 effects.emplace_back(MemoryEffects::Read::get(), &arg);
337 else
338 effects.emplace_back(MemoryEffects::Write::get(), &arg,
339 SideEffects::DefaultResource::get());
340 first = false;
341 }
342 }
343 if (*callee == "fscanf" || *callee == "__isoc99_fscanf") {
344 // Global read
345 effects.emplace_back(
346 MemoryEffects::Effect::get<MemoryEffects::Read>());
347
348 for (auto &&[idx, arg] :
349 llvm::enumerate(cop.getArgOperandsMutable())) {
350 if (idx == 0) {
351 effects.emplace_back(MemoryEffects::Read::get(), &arg,
352 SideEffects::DefaultResource::get());
353 effects.emplace_back(MemoryEffects::Write::get(), &arg,
354 SideEffects::DefaultResource::get());
355 } else if (idx == 1) {
356 effects.emplace_back(MemoryEffects::Read::get(), &arg,
357 SideEffects::DefaultResource::get());
358 } else
359 effects.emplace_back(MemoryEffects::Write::get(), &arg,
360 SideEffects::DefaultResource::get());
361 }
362 }
363 if (*callee == "printf") {
364 // Global read
365 effects.emplace_back(
366 MemoryEffects::Effect::get<MemoryEffects::Write>());
367 for (auto &arg : cop.getArgOperandsMutable()) {
368 effects.emplace_back(MemoryEffects::Read::get(), &arg,
369 SideEffects::DefaultResource::get());
370 }
371 }
372 if (*callee == "free") {
373 for (auto &arg : cop.getArgOperandsMutable()) {
374 effects.emplace_back(MemoryEffects::Free::get(), &arg,
375 SideEffects::DefaultResource::get());
376 }
377 }
378 if (*callee == "strlen") {
379 for (auto &arg : cop.getArgOperandsMutable()) {
380 effects.emplace_back(MemoryEffects::Read::get(), &arg,
381 SideEffects::DefaultResource::get());
382 }
383 }
384 }
385 } else {
386 // TODO: handle AutoDiffOp, ForwardDiffOp and AutoDiffRegionOp effects.
387 // Just conservatively add all effects for now
388
389 // We need to be conservative here in case the op doesn't have the
390 // interface and assume it can have any possible effect.
391
392 effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
393 effects.emplace_back(
394 MemoryEffects::Effect::get<MemoryEffects::Write>());
395 effects.emplace_back(
396 MemoryEffects::Effect::get<MemoryEffects::Allocate>());
397 effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Free>());
398 couldCollectEffects = false;
399
400 // no use in exploring other ops so break
401 break;
402 }
403 }
404 }
405 return couldCollectEffects;
406}
407
408SmallVector<MemoryEffects::EffectInstance>
409collectFnEffects(FunctionOpInterface fnOp) {
410 SmallVector<MemoryEffects::EffectInstance> innerEffects;
411 for (auto &blk : fnOp.getBlocks()) {
412 for (auto &op : blk) {
413 SmallVector<MemoryEffects::EffectInstance> opEffects;
414 (void)collectOpEffects(&op, opEffects);
415 innerEffects.append(opEffects.begin(), opEffects.end());
416 }
417 }
418
419 return innerEffects;
420}
421
422MemoryEffects::EffectInstance getEffectOfVal(Value val,
423 MemoryEffects::Effect *effect,
424 SideEffects::Resource *resource) {
425
426 if (auto valOR = dyn_cast<OpResult>(val))
427 return MemoryEffects::EffectInstance(effect, valOR, resource);
428 else if (auto valBA = dyn_cast<BlockArgument>(val)) {
429 return MemoryEffects::EffectInstance(effect, valBA, resource);
430 } else {
431 llvm_unreachable("Provided Value is neither an argument nor a result of an "
432 "op. This is not allowed by SSA");
433 return nullptr;
434 }
435}
436
437} // namespace oputils
438} // namespace enzyme
439} // namespace mlir
bool mayAlias(Value v1, Value v2)
Definition Utils.cpp:136
static bool isStackAlloca(Value v)
Definition Utils.cpp:130
static Value getBase(Value v)
Definition Utils.cpp:107
const std::set< std::string > & getNonCapturingFunctions()
Definition Utils.cpp:27
bool collectOpEffects(Operation *rootOp, SmallVector< MemoryEffects::EffectInstance > &effects)
Returns the side effects of an operation (similar to mlir::getEffectsRecursively).
Definition Utils.cpp:297
SmallVector< MemoryEffects::EffectInstance > collectFnEffects(FunctionOpInterface fnOp)
Definition Utils.cpp:409
MemoryEffects::EffectInstance getEffectOfVal(Value val, MemoryEffects::Effect *effect, SideEffects::Resource *resource)
Definition Utils.cpp:422
static bool isCaptured(Value v, Operation *potentialUser=nullptr, bool *seenuse=nullptr)
Definition Utils.cpp:38
bool isReadNone(Operation *op)
Definition Utils.cpp:266
bool isReadOnly(Operation *op)
Definition Utils.cpp:237