Enzyme main
Loading...
Searching...
No Matches
LowerLLVMExtPass.cpp
Go to the documentation of this file.
1//===- LowerLLVMExtPass.cpp - Lower LLVM Ext operations ------------------ //
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to lower custom ops generated by the LLVM Ext
10// dialect.
11//
12//===----------------------------------------------------------------------===//
13
15#include "Passes/Passes.h"
16
17#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18
19#include "mlir/IR/Matchers.h"
20#include "mlir/IR/PatternMatch.h"
21
22namespace mlir {
23namespace enzyme {
24using namespace mlir::enzyme;
25#define GEN_PASS_DEF_LOWERLLVMEXTPASS
26#include "Passes/Passes.h.inc"
27} // namespace enzyme
28} // namespace mlir
29
30namespace {
31using namespace mlir;
32using namespace enzyme;
33
34LogicalResult tryLoweringToAlloca(llvm_ext::AllocOp alloc,
35 uint64_t staticThreshold) {
36 llvm::APInt size;
37 if (!matchPattern(alloc.getSize(), m_ConstantInt(&size)))
38 return failure();
39
40 if (size.getZExtValue() > staticThreshold)
41 return failure();
42
43 Operation *free = nullptr;
44 for (auto user : alloc.getResult().getUsers()) {
45 auto oFree = dyn_cast<llvm_ext::FreeOp>(user);
46 if (!oFree)
47 continue;
48
49 if (free)
50 return failure(); // multiple frees
51
52 if (oFree->getBlock() != alloc->getBlock()) // free not in same block
53 return failure();
54
55 free = user;
56
57 if (!alloc->isBeforeInBlock(free))
58 return failure();
59 }
60
61 OpBuilder builder(alloc);
62 auto alloca = LLVM::AllocaOp::create(builder, alloc.getLoc(),
63 alloc.getResult().getType(),
64 builder.getI8Type(), alloc.getSize());
65 alloc.getResult().replaceAllUsesWith(alloca.getResult());
66
67 LLVM::LifetimeStartOp::create(builder, alloc.getLoc(), alloca.getResult());
68
69 builder.setInsertionPoint(free);
70 LLVM::LifetimeEndOp::create(builder, alloc.getLoc(), alloca.getResult());
71
72 free->erase();
73 alloc->erase();
74
75 return success();
76}
77
78void lowerAlloc(llvm_ext::AllocOp alloc, uint64_t staticThreshold) {
79 if (tryLoweringToAlloca(alloc, staticThreshold).succeeded())
80 return;
81
82 SymbolTable symtable(SymbolTable::getNearestSymbolTable(alloc));
83
84 auto mallocFn = symtable.lookup<LLVM::LLVMFuncOp>("malloc");
85 if (!mallocFn) {
86 Block *b = &symtable.getOp()->getRegion(0).front();
87 OpBuilder builder(b, b->begin());
88
89 auto fnType = mlir::LLVM::LLVMFunctionType::get(
90 LLVM::LLVMPointerType::get(alloc.getContext()), builder.getI64Type(),
91 /*isVarArg=*/false);
92
93 mallocFn =
94 LLVM::LLVMFuncOp::create(builder, alloc.getLoc(), "malloc", fnType);
95 }
96
97 OpBuilder builder(alloc);
98 auto mallocCall = LLVM::CallOp::create(builder, alloc.getLoc(), mallocFn,
99 alloc->getOperands());
100 alloc.getResult().replaceAllUsesWith(mallocCall.getResult());
101 alloc.erase();
102}
103
104void lowerFree(llvm_ext::FreeOp free) {
105 SymbolTable symtable(SymbolTable::getNearestSymbolTable(free));
106
107 auto freeFn = symtable.lookup<LLVM::LLVMFuncOp>("free");
108 if (!freeFn) {
109 Block *b = &symtable.getOp()->getRegion(0).front();
110 OpBuilder builder(b, b->begin());
111
112 auto fnType = mlir::LLVM::LLVMFunctionType::get(
113 LLVM::LLVMVoidType::get(free.getContext()),
114 LLVM::LLVMPointerType::get(free.getContext()),
115 /*isVarArg=*/false);
116
117 freeFn = LLVM::LLVMFuncOp::create(builder, free.getLoc(), "free", fnType);
118 }
119
120 OpBuilder builder(free);
121 LLVM::CallOp::create(builder, free.getLoc(), freeFn, free->getOperands());
122
123 free.erase();
124}
125
126struct LowerLLVMExtPass
127 : public enzyme::impl::LowerLLVMExtPassBase<LowerLLVMExtPass> {
128 using LowerLLVMExtPassBase::LowerLLVMExtPassBase;
129
130 void runOnOperation() override {
131 Operation *op = getOperation();
132
133 op->walk([](llvm_ext::PtrSizeHintOp psh) { psh.erase(); });
134
135 SmallVector<llvm_ext::AllocOp> allocs;
136 op->walk([&](llvm_ext::AllocOp alloc) { allocs.push_back(alloc); });
137
138 for (auto alloc : allocs)
139 lowerAlloc(alloc, lowerToAllocaThreshold);
140
141 SmallVector<llvm_ext::FreeOp> frees;
142 op->walk([&](llvm_ext::FreeOp free) { frees.push_back(free); });
143
144 for (auto free : frees)
145 lowerFree(free);
146 }
147};
148
149} // end anonymous namespace