Enzyme main
Loading...
Searching...
No Matches
RaiseLLVMExtPass.cpp
Go to the documentation of this file.
//===- RaiseLLVMExtPass.cpp - Raise LLVM Ext operations -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that raises calls to the external marker
// function `__enzyme_ptr_size_hint` into operations of the LLVM Ext dialect.
//
//===----------------------------------------------------------------------===//
13
15#include "Passes/Passes.h"
16
17#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18
19#include "mlir/IR/Matchers.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/IR/SymbolTable.h"
22
23namespace mlir {
24namespace enzyme {
25using namespace mlir::enzyme;
26#define GEN_PASS_DEF_RAISELLVMEXTPASS
27#include "Passes/Passes.h.inc"
28} // namespace enzyme
29} // namespace mlir
30
31namespace {
32using namespace mlir;
33using namespace enzyme;
34
35struct RaiseLLVMExtPass
36 : public enzyme::impl::RaiseLLVMExtPassBase<RaiseLLVMExtPass> {
37 using RaiseLLVMExtPassBase::RaiseLLVMExtPassBase;
38
39 void runOnOperation() override {
40 bool failed = false;
41
42 SymbolTable::walkSymbolTables(
43 getOperation(),
44 /*allUsesVisible*/ true, [&](Operation *st, bool allUsesVisible) {
45 SymbolTable symtable(st);
46
47 auto name = StringAttr::get(&getContext(), "__enzyme_ptr_size_hint");
48 auto uses = SymbolTable::getSymbolUses(name, st);
49
50 if (!uses)
51 return;
52
53 auto fn = cast<FunctionOpInterface>(symtable.lookup(name));
54 if (!fn.isExternal()) {
55 failed = true;
56 fn.emitError() << "__enzyme_ptr_size_hint is not declared external";
57 return;
58 }
59
60 for (auto use : *uses) {
61 auto call = dyn_cast<LLVM::CallOp>(use.getUser());
62 if (!call) {
63 failed = true;
64 use.getUser()->emitError()
65 << "user of __enzyme_ptr_size_hint is not a llvm.call";
66 return;
67 }
68
69 OpBuilder builder(call);
70 llvm_ext::PtrSizeHintOp::create(
71 builder, call.getLoc(), call.getOperand(0), call.getOperand(1));
72
73 call.erase();
74 }
75
76 symtable.erase(fn);
77 });
78
79 if (failed)
80 signalPassFailure();
81 }
82};
83
84} // end anonymous namespace