Enzyme main
Loading...
Searching...
No Matches
LowerAffineAtomicRmwPass.cpp
Go to the documentation of this file.
1//===------------------------------------------------------------------------ //
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass that lowers enzyme::AffineAtomicRMWOp to
10// memref::AtomicRMWOp by materializing the affine map's index values.
11//===----------------------------------------------------------------------===//
12
13#include "Dialect/Ops.h"
14#include "Passes/Passes.h"
15#include "Passes/Utils.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17
18namespace mlir {
19namespace enzyme {
20#define GEN_PASS_DEF_LOWERAFFINEATOMICRMWPASS
21#include "Passes/Passes.h.inc"
22} // namespace enzyme
23} // namespace mlir
24
25using namespace mlir;
26
27namespace {
28struct LowerAffineAtomicRmwPass
29 : public enzyme::impl::LowerAffineAtomicRmwPassBase<
30 LowerAffineAtomicRmwPass> {
31 void runOnOperation() override {
32 getOperation()->walk([&](enzyme::AffineAtomicRMWOp rmw) {
33 OpBuilder builder(rmw);
34 SmallVector<Value> indices;
35 enzyme::computeAffineIndices(builder, rmw.getLoc(), rmw.getMap(),
36 rmw.getIndices(), indices);
37 rmw.getResult().replaceAllUsesWith(
38 memref::AtomicRMWOp::create(builder, rmw.getLoc(),
39 arith::AtomicRMWKind::addf,
40 rmw.getValue(), rmw.getMemref(), indices)
41 .getResult());
42 rmw->erase();
43 });
44 };
45};
46} // end anonymous namespace
void computeAffineIndices(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands, SmallVectorImpl< Value > &indices)
Definition Utils.cpp:171