17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Transforms/DialectConversion.h"
20#include "llvm/Support/raw_ostream.h"
23#include "mlir/Dialect/Affine/IR/AffineOps.h"
24#include "mlir/Dialect/Linalg/IR/Linalg.h"
30#define GEN_PASS_DEF_ADDTOOPTOINDEXANDLOADPASS
31#include "Passes/Passes.h.inc"
36using namespace enzyme;
// Materializes every result expression of `aMap` as its own affine.apply op
// and returns the produced values.
//
// For each result index i of `aMap`, the single-result sub-map
// `aMap.getSubMap({i})` is extracted and an `affine::AffineApplyOp` is created
// through `builder` at `loc`; the created values are collected in result order.
//
// Params:
//   aMap    - affine map whose results are split into one affine.apply each.
//   indices - operand values for the apply ops. NOTE(review): the operand
//             argument of the create call is not visible in this view;
//             confirm `indices` is what gets passed. Taken by value, so the
//             vector is copied on every call; a ValueRange / ArrayRef<Value>
//             parameter would avoid the copy.
//   builder - builder through which the affine.apply ops are created.
//   loc     - location attached to every created op.
// Returns: one Value per result of `aMap`, in the same order as the results.
40SmallVector<Value> applyAffineMap(AffineMap aMap, SmallVector<Value> indices,
41 OpBuilder &builder, Location loc) {
42 SmallVector<Value> appliedAffineMap;
// One apply op per map result; presumably split because affine.apply expects
// a single-result map (confirm against the affine dialect docs).
43 for (
unsigned int i = 0; i < aMap.getNumResults(); i++) {
44 AffineMap subMap = aMap.getSubMap({i});
45 auto mapApplied = affine::AffineApplyOp::create(builder, loc, subMap,
47 appliedAffineMap.push_back(mapApplied);
49 return appliedAffineMap;
52struct AddToOpToIndexAndLoadPass
53 :
public enzyme::impl::AddToOpToIndexAndLoadPassBase<
54 AddToOpToIndexAndLoadPass> {
55 void runOnOperation()
override {
56 getOperation()->walk([&](Operation *op) {
57 auto loc = op->getLoc();
58 auto enzymeAdjoint = dyn_cast<enzyme::GenericAdjointOp>(op);
62 OpBuilder cacheBuilder(enzymeAdjoint);
66 Operation *addToOp =
nullptr;
67 adjoint.walk([&](Operation *op) {
68 if (isa<enzyme::AddToOp>(op)) {
75 Operation *terminator = adjoint.getBodyRegion().front().getTerminator();
76 SmallVector<Value> indices;
77 SmallVector<Value> retargs;
78 auto outs = adjoint.getOutputs();
79 auto num_ins = adjoint.getInputs().size();
80 for (
auto val : addToOp->getOperands()) {
81 retargs.push_back(val);
83 auto map = adjoint.getIndexingMapsArray();
84 cacheBuilder.setInsertionPoint(terminator);
87 for (
size_t i = 0; i < map[0].getNumDims(); i++) {
88 indices.push_back(linalg::IndexOp::create(cacheBuilder, loc, i));
91 SmallVector<Value> rets;
92 for (
size_t i = 0; i < retargs.size(); i++) {
96 SmallVector<Value> mapAppliedIndices =
97 applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc);
98 auto load = memref::LoadOp::create(cacheBuilder, loc, outs[i],
100 auto added = cast<enzyme::AutoDiffTypeInterface>(load.getType())
101 .createAddOp(cacheBuilder, loc, load, retargs[i]);
102 memref::StoreOp::create(cacheBuilder, loc, added, outs[i],
106 for (
size_t i = 0; i < retargs.size(); i++) {
107 SmallVector<Value> mapAppliedIndices =
108 applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc);
109 auto load = memref::LoadOp::create(cacheBuilder, loc, outs[i],
114 linalg::YieldOp::create(cacheBuilder, loc, ValueRange{retargs});
static mlir::linalg::GenericOp adjointToGeneric(enzyme::GenericAdjointOp &op, OpBuilder &builder, Location loc)