Enzyme main
Loading...
Searching...
No Matches
AddToOpToIndexAndLoad.cpp
Go to the documentation of this file.
//===- AddToOpToIndexAndLoad.cpp - Lower enzyme.addTo ops -----------------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
// This file implements a pass that converts each enzyme.genericAdjoint op into
// a linalg.generic and lowers the enzyme.addTo op in its body to MemRef
// load/add/store accumulation into the op's output buffers.
12//===----------------------------------------------------------------------===//
13
14#include "Dialect/Ops.h"
15#include "PassDetails.h"
16#include "Passes/Passes.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Transforms/DialectConversion.h"
19
20#include "llvm/Support/raw_ostream.h"
21
23#include "mlir/Dialect/Affine/IR/AffineOps.h"
24#include "mlir/Dialect/Linalg/IR/Linalg.h"
25
26#include "Utils.h"
27
28namespace mlir {
29namespace enzyme {
30#define GEN_PASS_DEF_ADDTOOPTOINDEXANDLOADPASS
31#include "Passes/Passes.h.inc"
32} // namespace enzyme
33} // namespace mlir
34
35using namespace mlir;
36using namespace enzyme;
37using llvm::errs;
38namespace {
39
40SmallVector<Value> applyAffineMap(AffineMap aMap, SmallVector<Value> indices,
41 OpBuilder &builder, Location loc) {
42 SmallVector<Value> appliedAffineMap;
43 for (unsigned int i = 0; i < aMap.getNumResults(); i++) {
44 AffineMap subMap = aMap.getSubMap({i});
45 auto mapApplied = affine::AffineApplyOp::create(builder, loc, subMap,
46 ValueRange(indices));
47 appliedAffineMap.push_back(mapApplied);
48 }
49 return appliedAffineMap;
50}
51
52struct AddToOpToIndexAndLoadPass
53 : public enzyme::impl::AddToOpToIndexAndLoadPassBase<
54 AddToOpToIndexAndLoadPass> {
55 void runOnOperation() override {
56 getOperation()->walk([&](Operation *op) {
57 auto loc = op->getLoc();
58 auto enzymeAdjoint = dyn_cast<enzyme::GenericAdjointOp>(op);
59 if (!enzymeAdjoint)
60 return;
61
62 OpBuilder cacheBuilder(enzymeAdjoint);
63 auto adjoint = Utils::adjointToGeneric(enzymeAdjoint, cacheBuilder, loc);
64
65 // check if adjoint contains a enzyme.addToOp
66 Operation *addToOp = nullptr;
67 adjoint.walk([&](Operation *op) {
68 if (isa<enzyme::AddToOp>(op)) {
69 addToOp = op;
70 }
71 });
72 if (!addToOp)
73 return;
74
75 Operation *terminator = adjoint.getBodyRegion().front().getTerminator();
76 SmallVector<Value> indices;
77 SmallVector<Value> retargs;
78 auto outs = adjoint.getOutputs();
79 auto num_ins = adjoint.getInputs().size();
80 for (auto val : addToOp->getOperands()) {
81 retargs.push_back(val);
82 }
83 auto map = adjoint.getIndexingMapsArray();
84 cacheBuilder.setInsertionPoint(terminator);
85
86 // Is it a fine assumption that all indexing maps are the same?
87 for (size_t i = 0; i < map[0].getNumDims(); i++) {
88 indices.push_back(linalg::IndexOp::create(cacheBuilder, loc, i));
89 }
90
91 SmallVector<Value> rets;
92 for (size_t i = 0; i < retargs.size(); i++) {
93 // auto load = AffineLoadOp::create(cacheBuilder, loc, inputs[i],
94 // map[i], indices); auto store = AffineStoreOp::create(cacheBuilder,
95 // loc, load, inputs[i], map[i], indices);
96 SmallVector<Value> mapAppliedIndices =
97 applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc);
98 auto load = memref::LoadOp::create(cacheBuilder, loc, outs[i],
99 mapAppliedIndices);
100 auto added = cast<enzyme::AutoDiffTypeInterface>(load.getType())
101 .createAddOp(cacheBuilder, loc, load, retargs[i]);
102 memref::StoreOp::create(cacheBuilder, loc, added, outs[i],
103 mapAppliedIndices);
104 }
105
106 for (size_t i = 0; i < retargs.size(); i++) {
107 SmallVector<Value> mapAppliedIndices =
108 applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc);
109 auto load = memref::LoadOp::create(cacheBuilder, loc, outs[i],
110 mapAppliedIndices);
111 retargs[i] = load;
112 }
113
114 linalg::YieldOp::create(cacheBuilder, loc, ValueRange{retargs});
115 addToOp->erase();
116 });
117 };
118};
119} // end anonymous namespace
static mlir::linalg::GenericOp adjointToGeneric(enzyme::GenericAdjointOp &op, OpBuilder &builder, Location loc)
Definition Utils.cpp:15