20#include "mlir/Dialect/Affine/IR/AffineOps.h"
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Linalg/IR/Linalg.h"
23#include "mlir/Dialect/MemRef/IR/MemRef.h"
25#include "mlir/IR/DialectRegistry.h"
26#include "mlir/Support/LogicalResult.h"
29#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
30#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
31#include "mlir/IR/PatternMatch.h"
32#include "mlir/Transforms/DialectConversion.h"
// Emits IR that takes a memref.subview of `inp` whose per-dimension operands
// are built from the runtime dimension sizes: (dim - 1), dim, and -1.
// NOTE(review): this excerpt skips several numbered lines (48, 51-53, 56+), so
// the append to `dims`, the loop's closing brace and the return statement are
// not visible here — confirm against the full file.
39Value invertMemref(Value inp, OpBuilder &builder, Location loc) {
40 MemRefType iType = cast<MemRefType>(inp.getType());
41 SmallVector<Value> dims;
42 SmallVector<Value> dimSubOnes;
43 SmallVector<Value> strides;
// Index constant -1: added to each dim to form dim - 1, and reused as stride.
44 Value negOne = arith::ConstantIndexOp::create(builder, loc, -1);
45 int shapeDim = iType.getShape().size();
// One memref.dim query per dimension of the input's shape.
46 for (
int i = 0; i < shapeDim; i++) {
47 Value dim = memref::DimOp::create(builder, loc, inp, i);
49 auto dimSubOne = arith::AddIOp::create(builder, loc, dim, negOne);
50 dimSubOnes.push_back(dimSubOne);
51 strides.push_back(negOne);
// Operand lists are passed in the order dimSubOnes, dims, strides —
// presumably offsets, sizes, strides of the subview; TODO confirm against the
// SubViewOp builder overload in use.
54 memref::SubViewOp::create(builder, loc, inp, ValueRange(dimSubOnes),
55 ValueRange(dims), ValueRange(strides));
// Unpacks the indexing-maps array attribute of an enzyme::GenericAdjointOp
// into a vector of AffineMap values, one per attribute element.
// NOTE(review): the return statement and closing braces fall on lines missing
// from this excerpt; presumably `indexingMaps` is returned — confirm.
59SmallVector<AffineMap> getIndexingMapsArray(enzyme::GenericAdjointOp &op) {
60 auto attr = op.getIndexingMapsAttr();
61 SmallVector<AffineMap> indexingMaps;
62 for (
auto map : attr.getValue()) {
// Each element is cast to AffineMapAttr (cast<> asserts on a mismatch).
63 indexingMaps.push_back(cast<AffineMapAttr>(map).getValue());
// External model of ReverseAutoDiffOpInterface for a linalg op type T_.
// NOTE(review): this excerpt skips many numbered source lines (the template
// header, parts of signatures, closing braces, several statements), so the
// comments below describe only the statements that are visible.
69struct GenericOpInterfaceReverse
70 :
public ReverseAutoDiffOpInterface::ExternalModel<
71 GenericOpInterfaceReverse<T_>, T_> {
// Emits an enzyme::GenericAdjointOp for `op` and connects it to values cached
// from the (generalized) forward op.
72 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
74 SmallVector<Value> caches)
const {
75 auto linalgOp = cast<linalg::LinalgOp>(op);
// Ops without pure buffer semantics are reported through llvm::errs().
76 if (!linalgOp.hasPureBufferSemantics()) {
77 llvm::errs() <<
"Linalg op with tensor semantics not yet supported\n";
// NOTE(review): the initializer of `newOp` is on a line missing from this
// excerpt; presumably the forward-pass clone of `op` — confirm.
81 linalg::LinalgOp newOp =
85 IRRewriter rewriter(builder.getContext(), builder.getListener());
// Try to turn a named linalg op into linalg.generic; on success the generic
// replacement is inserted right after `newOp` under an insertion guard.
86 auto failiureOrLinalgOp = generalizeNamedOp(rewriter, newOp);
87 if (!failed(failiureOrLinalgOp)) {
88 linalg::GenericOp replacement = failiureOrLinalgOp.value();
89 auto scope = OpBuilder::InsertionGuard(builder);
90 builder.setInsertionPointAfter(newOp);
91 builder.insert(replacement);
// Separate builder constructed from `newOp`; used for the dim queries, the
// affine.apply ops and the cache allocations below.
96 auto cacheBuilder = OpBuilder(newOp, builder.getListener());
// Map from operand shapes to loop ranges of the forward op.
99 AffineMap aMap = newOp.getShapesToLoopsMap();
100 SmallVector<Value> dims;
// Emit a memref.dim for every dimension of every input, then of every init.
// NOTE(review): the appends to `dims` are on lines missing from this excerpt.
101 for (OpOperand *input : newOp.getDpsInputOperands()) {
102 auto shape = cast<MemRefType>(input->get().getType()).getShape();
103 for (
unsigned i = 0; i < shape.size(); i++) {
105 arith::ConstantIndexOp::create(cacheBuilder, op->getLoc(), i);
106 auto dim = memref::DimOp::create(cacheBuilder, op->getLoc(),
111 for (Value output : newOp.getDpsInits()) {
112 auto shape = cast<MemRefType>(output.getType()).getShape();
113 for (
unsigned i = 0; i < shape.size(); i++) {
115 arith::ConstantIndexOp::create(cacheBuilder, op->getLoc(), i);
117 memref::DimOp::create(cacheBuilder, op->getLoc(), output, dimI);
// One affine.apply per result of the shapes-to-loops map, applied to `dims`.
// `shapes` gets a dynamic extent for each, so caches are fully dynamic memrefs.
122 SmallVector<Value> iterationDomains;
123 SmallVector<int64_t> shapes;
124 for (
unsigned int i = 0; i < aMap.getNumResults(); i++) {
125 AffineMap subMap = aMap.getSubMap({i});
126 Value domain = affine::AffineApplyOp::create(cacheBuilder, op->getLoc(),
127 subMap, ValueRange(dims));
128 iterationDomains.push_back(domain);
129 shapes.push_back(ShapedType::kDynamic);
// Operands, indexing maps and iterator types of the adjoint op.
133 SmallVector<Value> inputs, outputs;
134 SmallVector<AffineMap> indexingMaps;
// Initialised from the loop count and IteratorType::parallel — presumably one
// `parallel` entry per loop (size/value constructor); TODO confirm.
135 SmallVector<utils::IteratorType> iteratorTypes{
136 linalgOp.getNumLoops(), utils::IteratorType::parallel};
// For each init operand: record its indexing map and an invertMemref view.
// NOTE(review): `out` is defined on a line missing from this excerpt.
138 for (OpOperand &output : linalgOp.getDpsInitsMutable()) {
142 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&output));
144 Value view = invertMemref(out, builder, op->getLoc());
145 outputs.push_back(view);
// Same for each input operand; `inp` is likewise defined on a missing line.
148 for (OpOperand *input : linalgOp.getDpsInputOperands()) {
152 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(input));
154 Value view = invertMemref(inp, builder, op->getLoc());
155 inputs.push_back(view);
158 ArrayAttr indexingMapsArrayAttr =
159 builder.getAffineMapArrayAttr(indexingMaps);
160 ArrayAttr iteratorTypesArrayAttr =
161 builder.getArrayAttr(llvm::to_vector(llvm::map_range(
162 iteratorTypes, [&](utils::IteratorType iter) -> mlir::Attribute {
163 return linalg::IteratorTypeAttr::get(builder.getContext(), iter);
// The adjoint op has no results; the views collected in `outputs` are passed
// as its first operand range and those in `inputs` as its second.
165 auto adjoint = enzyme::GenericAdjointOp::create(
166 builder, op->getLoc(), TypeRange(), ValueRange(outputs),
167 ValueRange(inputs), indexingMapsArrayAttr, iteratorTypesArrayAttr,
168 StringAttr(), StringAttr());
170 int numInputs = inputs.size();
// Callback that terminates the adjoint body with an enzyme::AddToOp over the
// first `numInputs` entries of `retargs`.
// NOTE(review): how `retargs` is filled is on lines missing from this excerpt.
171 auto buildFuncReturnOp = [&gutils, numInputs](OpBuilder &builder,
173 auto loc = oBB->rbegin()->getLoc();
174 SmallVector<Value> retargs;
175 for (
auto arg : oBB->getArguments()) {
178 enzyme::AddToOp::create(builder, loc,
179 ValueRange{retargs}.take_front(numInputs));
183 Region *newOpRegion = newOp.getBlock()->getParent();
184 Region *adjointRegion = &adjoint.getRegion();
185 int numInputsAdjoint = adjoint.getInputs().size();
186 Location loc = op->getLoc();
188 SmallVector<Value> pushCaches;
// Cache-creation callback: for a cached type `t` it creates an
// enzyme::InitOp with a builder constructed on newOpRegion, records it in
// `pushCaches`, adds a `t` argument to newOpRegion and inserts a `t` argument
// into adjointRegion at index numInputsAdjoint + numCaches.
// NOTE(review): the declaration/update of `numCaches` and the definition of
// `popCache` are on lines missing from this excerpt.
190 auto hook = [newOpRegion, adjointRegion, loc, &numCaches = numCaches,
191 numInputsAdjoint, &pushCaches = pushCaches](Type t) {
192 OpBuilder builder(newOpRegion);
193 Value pushCache = enzyme::InitOp::create(builder, loc, t);
194 pushCaches.push_back(pushCache);
195 newOpRegion->addArgument(t, loc);
198 adjointRegion->insertArgument(numInputsAdjoint + numCaches, t, loc);
200 return std::make_pair(pushCache, popCache);
// Arguments of a call whose callee name is on a missing line; the result is
// stored in `sub` and checked for success below.
204 gutils, *linalgOp.getBlock()->getParent(), adjoint.getRegion(),
205 buildFuncReturnOp, hook);
206 if (!sub.succeeded())
// Every value recorded by the hook is additionally yielded by the forward op.
209 auto newOpYield = cast<linalg::YieldOp>(
210 cast<linalg::GenericOp>(newOp).getBodyRegion().front().getTerminator());
211 for (Value pc : pushCaches) {
212 newOpYield.getValuesMutable().append(pc);
// The adjoint body gets one extra block argument per operand of its
// terminating AddToOp, typed like that operand.
215 Block *body = &(adjoint.getRegion().front());
216 auto yieldOp = cast<enzyme::AddToOp>(body->getTerminator());
217 for (
auto opOperand : yieldOp.getOperands()) {
218 body->addArgument(opOperand.getType(), opOperand.getLoc());
221 OpBuilder builderAdd(yieldOp);
223 auto newIndexingMaps = newOp.getIndexingMapsArray();
224 auto indexingMapsAdjoint = getIndexingMapsArray(adjoint);
// Per cache: allocate a dynamically shaped memref sized by the iteration
// domain, make it an extra output of the forward generic op, then feed the
// popped and inverted cache into the adjoint as an extra input.
225 for (
int i = 0; i < numCaches; i++) {
226 Value cacheArg = body->getArgument(outputs.size() + i);
228 Type ct = cacheArg.getType();
229 Type type = MemRefType::get(shapes, ct);
230 auto alloc = memref::AllocOp::create(cacheBuilder, op->getLoc(), type,
231 ValueRange(iterationDomains));
// Operand segment sizes are set explicitly to {iterationDomains.size(), 0}.
234 alloc->setAttr(alloc.getOperandSegmentSizesAttrName(),
235 cacheBuilder.getDenseI32ArrayAttr(
236 {static_cast<int32_t>(iterationDomains.size()), 0}));
238 cast<linalg::GenericOp>(newOp).getOutputsMutable().append(
239 ValueRange({alloc}));
// The cache is indexed with an identity map over the iteration domain.
240 newIndexingMaps.push_back(AffineMap::getMultiDimIdentityMap(
241 iterationDomains.size(), cacheBuilder.getContext()));
// NOTE(review): `cache` is defined on a line missing from this excerpt.
243 builderAdd.setInsertionPoint(adjoint);
244 Value retrievedValue = gutils->
popCache(cache, builderAdd);
245 retrievedValue = invertMemref(retrievedValue, builderAdd, op->getLoc());
246 adjoint.getInputsMutable().append(ValueRange({retrievedValue}));
247 indexingMapsAdjoint.insert(
248 indexingMapsAdjoint.begin() + numInputsAdjoint + i,
249 AffineMap::getMultiDimIdentityMap(iterationDomains.size(),
250 builderAdd.getContext()));
// Write the extended indexing-map lists back onto both ops as attributes.
252 SmallVector<Attribute> indexingMapsAttr;
253 SmallVector<Attribute> indexingMapsAttrAdjoint;
254 for (
auto &map : newIndexingMaps) {
255 indexingMapsAttr.push_back(AffineMapAttr::get(map));
257 for (
auto &map : indexingMapsAdjoint) {
258 indexingMapsAttrAdjoint.push_back(AffineMapAttr::get(map));
260 cast<linalg::GenericOp>(newOp).setIndexingMapsAttr(
261 cacheBuilder.getArrayAttr(indexingMapsAttr));
262 adjoint->setAttr(adjoint.getIndexingMapsAttrName(),
263 builder.getArrayAttr(indexingMapsAttrAdjoint));
// Returns an empty vector: no values are cached through this entry point.
267 SmallVector<Value> cacheValues(Operation *op,
269 return SmallVector<Value>();
// NOTE(review): the remaining parameters and the body of createShadowValues
// are on lines missing from this excerpt.
272 void createShadowValues(Operation *op, OpBuilder &builder,
// Forward-mode external model of AutoDiffOpInterface for linalg::GenericOp.
// NOTE(review): the `struct` header line and many statements (conditions,
// index bookkeeping, closing braces) are missing from this excerpt; comments
// describe only what is visible.
277 :
public AutoDiffOpInterface::ExternalModel<GenericFwd, linalg::GenericOp> {
// Builds a replacement linalg.generic that carries shadow results/operands
// alongside the original ones.
279 LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder,
282 auto op = cast<linalg::GenericOp>(orig);
// Result types of the replacement: each original result type, plus a shadow
// type obtained from AutoDiffTypeInterface (space is reserved for two per
// result). NOTE(review): the conditions guarding the shadow push and the
// error path are on missing lines.
287 SmallVector<Type> newOpResultTypes;
288 newOpResultTypes.reserve(op->getNumResults() * 2);
289 for (
auto &&[result, init] :
290 llvm::zip_equal(op->getResults(), op.getOutputs())) {
291 newOpResultTypes.push_back(result.getType());
295 auto typeIface = dyn_cast<AutoDiffTypeInterface>(result.getType());
297 op->emitError() <<
" AutoDiffTypeInterface not implemented for "
298 << result.getType() <<
"\n";
301 newOpResultTypes.push_back(typeIface.getShadowType(gutils->
width));
304 SmallVector<Value> newInputs;
305 SmallVector<Value> newOutputs;
306 SmallVector<AffineMap> indexingMaps;
// Inputs: the operand's indexing map is pushed twice in the visible lines —
// presumably once for the primal and once for its shadow; TODO confirm.
309 for (Value operand : op.getInputs()) {
311 indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
314 indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
// Outputs are walked together with the matching result and the matching
// block argument (block arguments after the input ones).
318 for (
auto &&[operand, res, oarg] :
319 llvm::zip_equal(op.getOutputs(), op->getResults(),
320 op.getRegion().front().getArguments().slice(
321 op.getInputs().size()))) {
323 indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
328 indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
// One visible path uses a null value of the operand's type as an output.
330 auto typeIface = dyn_cast<AutoDiffTypeInterface>(operand.getType());
332 newOutputs.push_back(
333 typeIface.createNullValue(builder, operand.getLoc()));
334 indexingMaps.push_back(op.getIndexingMapsArray()[idx]);
// A shadow-typed block argument is inserted directly after `newBA`.
// NOTE(review): `newBA` is defined on a line missing from this excerpt.
338 auto typeIface = dyn_cast<AutoDiffTypeInterface>(oarg.getType());
340 newBA.getOwner()->insertArgument(newBA.getArgNumber() + 1,
341 typeIface.getShadowType(),
// Create the widened generic op and move the bodies of `newOp`'s regions
// into it.
355 auto replacement = linalg::GenericOp::create(
356 builder, op.getLoc(), newOpResultTypes, newInputs, newOutputs,
357 indexingMaps, op.getIteratorTypesArray(),
361 assert(replacement->getNumResults() == newOpResultTypes.size());
362 for (
auto &&[region, replacementRegion] :
363 llvm::zip(newOp->getRegions(), replacement->getRegions())) {
364 replacementRegion.takeBody(region);
// Collect replacement results for the original results; uses of `inverted`
// are redirected to a replacement result and its defining op is erased.
// NOTE(review): `idx` and `inverted` are maintained on missing lines.
369 SmallVector<Value> reps;
371 for (OpResult r : op->getResults()) {
373 reps.push_back(replacement->getResult(idx));
379 inverted.replaceAllUsesWith(replacement->getResult(idx));
380 gutils->
erase(inverted.getDefiningOp());
// Walk every operation of every block in the original op's regions; the
// per-operation action is on lines missing from this excerpt.
386 for (
auto &origRegion : op->getRegions()) {
387 for (
auto &origBlock : origRegion) {
388 for (Operation &o : origBlock) {
// `newOp` is erased through gutils.
398 gutils->
erase(newOp);
405#include "Implementations/LinalgDerivatives.inc"
// Fold expression over the pack `Ts`: attaches GenericOpInterfaceReverse<Ts>
// to each op type in `Ts` on `*context`.
// NOTE(review): the enclosing template's header lines are not part of this
// excerpt.
409 (Ts::template attachInterface<GenericOpInterfaceReverse<Ts>>(*context), ...);
// Visible tail of the registration function: adds a dialect extension for
// linalg::LinalgDialect to `registry` whose callback calls
// registerInterfaces(context) and attaches GenericFwd to linalg::GenericOp.
// The parameter is a `DialectRegistry &` named `registry`, as used on the next
// line; the `&reg` prefix had been mis-decoded into a single '®' character.
// NOTE(review): the line carrying the function name and the closing lines are
// not part of this excerpt.
413 DialectRegistry &registry) {
414 registry.addExtension(+[](MLIRContext *context, linalg::LinalgDialect *) {
415 registerInterfaces(context);
416 linalg::GenericOp::attachInterface<GenericFwd>(*context);
419#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
void attachAllInterfaces(MLIRContext *context)
LogicalResult differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, llvm::function_ref< buildReturnFunction > buildFuncRetrunOp, std::function< std::pair< Value, Value >(Type)> cacheCreator)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
IRMapping invertedPointers
void replaceOrigOpWith(Operation *op, ValueRange vals)
std::map< Operation *, Operation * > originalToNewFnOps
LogicalResult visitChild(Operation *op)
void erase(Operation *op)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
bool isConstantValue(mlir::Value v) const
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry)