17#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/IR/Matchers.h"
20#include "mlir/IR/PatternMatch.h"
25#define GEN_PASS_DEF_LOWERLLVMEXTPASS
26#include "Passes/Passes.h.inc"
32using namespace enzyme;
/// Attempts to lower an `llvm_ext.alloc` to a stack `llvm.alloca`.
///
/// Checks visible below (the rejecting `return` statements themselves are
/// not visible in this view):
///   - the allocation size must match a constant integer (m_ConstantInt);
///   - its zero-extended value is compared against `staticThreshold`, and
///     larger sizes are presumably rejected;
///   - the users of the alloc result are scanned for an `llvm_ext.free`,
///     which must live in the same block as the alloc and come after it.
///
/// On the success path the alloca (element type i8, array size = the alloc's
/// size operand) is created immediately before `alloc`. All uses of the alloc
/// result are redirected to it. A lifetime.start marker is placed right after
/// the alloca, and a lifetime.end marker right before the matching free.
///
/// NOTE(review): several lines of this function are not visible here (the
/// `size` declaration, the failure returns, how `free` is assigned from
/// `oFree`, and whether `alloc` / the free op are erased). Confirm against the
/// full source.
34LogicalResult tryLoweringToAlloca(llvm_ext::AllocOp alloc,
35 uint64_t staticThreshold) {
// Only a compile-time constant size can be considered for the stack.
37 if (!matchPattern(alloc.getSize(), m_ConstantInt(&size)))
// Sizes above the caller-supplied threshold stay on the heap.
40 if (size.getZExtValue() > staticThreshold)
// Scan the users of the allocation for the llvm_ext.free that releases it.
43 Operation *free =
nullptr;
44 for (
auto user : alloc.getResult().getUsers()) {
45 auto oFree = dyn_cast<llvm_ext::FreeOp>(user);
// The free must be in the same block as the alloc...
52 if (oFree->getBlock() != alloc->getBlock())
// ...and must come after it, so the lifetime.end emitted below follows the
// lifetime.start.
57 if (!alloc->isBeforeInBlock(free))
// New ops are inserted immediately before the original alloc.
// NOTE(review): the alloca is created at the alloc's position, not in the
// function entry block. If this block can execute repeatedly (e.g. a loop
// body), LLVM treats it as a dynamic alloca whose stack space is only
// reclaimed on function return. Confirm this is acceptable.
61 OpBuilder builder(alloc);
// Allocate `size` bytes: element type i8 with the alloc's size as the array
// size.
62 auto alloca = LLVM::AllocaOp::create(builder, alloc.getLoc(),
63 alloc.getResult().getType(),
64 builder.getI8Type(), alloc.getSize());
65 alloc.getResult().replaceAllUsesWith(alloca.getResult());
// The builder still points before `alloc`, so the start marker lands right
// after the alloca.
67 LLVM::LifetimeStartOp::create(builder, alloc.getLoc(), alloca.getResult());
// End the lifetime immediately before the matching free.
69 builder.setInsertionPoint(free);
70 LLVM::LifetimeEndOp::create(builder, alloc.getLoc(), alloca.getResult());
/// Lowers an `llvm_ext.alloc` either to a stack allocation (via
/// tryLoweringToAlloca) or to a call to `malloc`.
///
/// The `malloc` declaration is looked up in the nearest enclosing symbol
/// table. The lines below build an `llvm.func @malloc` returning `!llvm.ptr`
/// and taking an i64, inserted at the start of that symbol table's first
/// block (presumably only when the lookup fails). The call receives the
/// alloc's operands unchanged, and its result replaces all uses of the alloc
/// result.
///
/// NOTE(review): parts of this function are not visible here (the early
/// return after a successful stack lowering, the guard around the
/// declaration creation, and whether `alloc` is erased afterwards).
78void lowerAlloc(llvm_ext::AllocOp alloc, uint64_t staticThreshold) {
// Prefer a stack allocation when the size is a small enough constant.
79 if (tryLoweringToAlloca(alloc, staticThreshold).succeeded())
82 SymbolTable symtable(SymbolTable::getNearestSymbolTable(alloc));
// Reuse an existing `malloc` declaration if the symbol table already has one.
84 auto mallocFn = symtable.lookup<LLVM::LLVMFuncOp>(
"malloc");
// Declarations go at the very start of the symbol table op's first block.
86 Block *b = &symtable.getOp()->getRegion(0).front();
87 OpBuilder builder(b, b->begin());
// Signature: ptr (i64). The remaining arguments of this get() call are not
// visible in this view.
// NOTE(review): in the visible lines the LLVMFuncOp::create result is not
// assigned to `mallocFn` (contrast lowerFree, which writes `freeFn = ...`).
// Unless the assignment sits on an elided line, the CallOp below would get a
// null callee whenever no declaration existed beforehand. Verify.
89 auto fnType = mlir::LLVM::LLVMFunctionType::get(
90 LLVM::LLVMPointerType::get(alloc.getContext()), builder.getI64Type(),
94 LLVM::LLVMFuncOp::create(builder, alloc.getLoc(),
"malloc", fnType);
// Emit the malloc call immediately before the original alloc.
// NOTE(review): the alloc's operands are forwarded as-is to a callee declared
// as taking i64. This assumes the size operand is i64; confirm.
97 OpBuilder builder(alloc);
98 auto mallocCall = LLVM::CallOp::create(builder, alloc.getLoc(), mallocFn,
99 alloc->getOperands());
100 alloc.getResult().replaceAllUsesWith(mallocCall.getResult());
/// Lowers an `llvm_ext.free` to a call to the C `free` function.
///
/// `free` is looked up in the nearest enclosing symbol table. The lines below
/// declare `llvm.func @free` with a void result and a `!llvm.ptr` parameter at
/// the start of that symbol table's first block (presumably only when the
/// lookup fails). They then emit a call immediately before the op, forwarding
/// its operands.
///
/// NOTE(review): the guard around the declaration and any erasure of the
/// original op are not visible in this view.
104void lowerFree(llvm_ext::FreeOp free) {
105 SymbolTable symtable(SymbolTable::getNearestSymbolTable(free));
// Reuse an existing `free` declaration if the symbol table already has one.
107 auto freeFn = symtable.lookup<LLVM::LLVMFuncOp>(
"free");
// Declarations go at the very start of the symbol table op's first block.
109 Block *b = &symtable.getOp()->getRegion(0).front();
110 OpBuilder builder(b, b->begin());
// Signature: void (ptr). The remaining arguments of this get() call are not
// visible in this view.
112 auto fnType = mlir::LLVM::LLVMFunctionType::get(
113 LLVM::LLVMVoidType::get(free.getContext()),
114 LLVM::LLVMPointerType::get(free.getContext()),
117 freeFn = LLVM::LLVMFuncOp::create(builder, free.getLoc(),
"free", fnType);
// Emit the call immediately before the llvm_ext.free being lowered.
120 OpBuilder builder(free);
121 LLVM::CallOp::create(builder, free.getLoc(), freeFn, free->getOperands());
126struct LowerLLVMExtPass
127 :
public enzyme::impl::LowerLLVMExtPassBase<LowerLLVMExtPass> {
128 using LowerLLVMExtPassBase::LowerLLVMExtPassBase;
130 void runOnOperation()
override {
131 Operation *op = getOperation();
133 op->walk([](llvm_ext::PtrSizeHintOp psh) { psh.erase(); });
135 SmallVector<llvm_ext::AllocOp> allocs;
136 op->walk([&](llvm_ext::AllocOp alloc) { allocs.push_back(alloc); });
138 for (
auto alloc : allocs)
139 lowerAlloc(alloc, lowerToAllocaThreshold);
141 SmallVector<llvm_ext::FreeOp> frees;
142 op->walk([&](llvm_ext::FreeOp free) { frees.push_back(free); });
144 for (
auto free : frees)