16#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#define GEN_PASS_DEF_LOWERAFFINEATOMICRMWPASS
21#include "Passes/Passes.h.inc"
28struct LowerAffineAtomicRmwPass
29 :
public enzyme::impl::LowerAffineAtomicRmwPassBase<
30 LowerAffineAtomicRmwPass> {
31 void runOnOperation()
override {
32 getOperation()->walk([&](enzyme::AffineAtomicRMWOp rmw) {
33 OpBuilder builder(rmw);
34 SmallVector<Value> indices;
36 rmw.getIndices(), indices);
37 rmw.getResult().replaceAllUsesWith(
38 memref::AtomicRMWOp::create(builder, rmw.getLoc(),
39 arith::AtomicRMWKind::addf,
40 rmw.getValue(), rmw.getMemref(), indices)
void computeAffineIndices(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands, SmallVectorImpl< Value > &indices)