Enzyme main
Loading...
Searching...
No Matches
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - General Utilities -------* C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Utils.h"
10#include "mlir/Dialect/Affine/IR/AffineOps.h"
11
12using namespace mlir;
13using namespace mlir::enzyme;
14
15linalg::GenericOp Utils::adjointToGeneric(enzyme::GenericAdjointOp &op,
16 OpBuilder &builder, Location loc) {
17 auto inputs = op.getInputs();
18 auto outputs = op.getOutputs();
19 auto resultTensors = op.getResultTensors();
20 auto indexingMaps = op.getIndexingMapsAttr();
21 auto iteratorTypes = op.getIteratorTypesAttr();
22
23 auto genericOp = mlir::linalg::GenericOp::create(
24 builder, loc, TypeRange(resultTensors), ValueRange(inputs),
25 ValueRange(outputs), ArrayAttr(indexingMaps), ArrayAttr(iteratorTypes),
26 StringAttr(), StringAttr());
27
28 auto &body = genericOp.getRegion();
29 body.takeBody(op.getRegion());
30
31 op.erase();
32
33 return genericOp;
34}
35
36bool mlir::enzyme::opCmp(Operation *a, Operation *b) {
37 if (a == b)
38 return false;
39
40 // Ancestors are less than their descendants.
41 if (a->isProperAncestor(b)) {
42 return true;
43 } else if (b->isProperAncestor(a->getParentOp())) {
44 return false;
45 }
46
47 // Move a and b to be direct descendents of the same op
48 while (!a->getParentOp()->isAncestor(b))
49 a = a->getParentOp();
50
51 while (!b->getParentOp()->isAncestor(a))
52 b = b->getParentOp();
53
54 assert(a->getParentOp() == b->getParentOp());
55
56 if (a->getBlock() == b->getBlock()) {
57 return a->isBeforeInBlock(b);
58 } else {
59 return blockCmp(a->getBlock(), b->getBlock());
60 }
61}
62
63bool mlir::enzyme::regionCmp(Region *a, Region *b) {
64 if (a == b)
65 return false;
66
67 // Ancestors are less than their descendants.
68 if (a->getParentOp()->isProperAncestor(b->getParentOp())) {
69 return true;
70 } else if (b->getParentOp()->isProperAncestor(a->getParentOp())) {
71 return false;
72 }
73
74 if (a->getParentOp() == b->getParentOp()) {
75 return a->getRegionNumber() < b->getRegionNumber();
76 }
77 return opCmp(a->getParentOp(), b->getParentOp());
78}
79
80bool mlir::enzyme::blockCmp(Block *a, Block *b) {
81 if (a == b)
82 return false;
83
84 // Ancestors are less than their descendants.
85 if (a->getParent()->isProperAncestor(b->getParent())) {
86 return true;
87 } else if (b->getParent()->isProperAncestor(a->getParent())) {
88 return false;
89 }
90
91 if (a->getParent() == b->getParent()) {
92 // If the blocks are in the same region, then the first one in
93 // the region is less than the second one.
94 for (auto &bb : *b->getParent()) {
95 if (&bb == a)
96 return true;
97 }
98 return false;
99 }
100
101 return regionCmp(a->getParent(), b->getParent());
102}
103
104bool mlir::enzyme::valueCmp(mlir::Value a, mlir::Value b) {
105 // Equal values are not less than each other.
106 if (a == b)
107 return false;
108
109 auto ba = dyn_cast<BlockArgument>(a);
110 auto bb = dyn_cast<BlockArgument>(b);
111 // Define block arguments are less than non-block arguments.
112 if (ba && !bb)
113 return true;
114 if (!ba && bb)
115 return false;
116 if (ba && bb) {
117 if (ba.getOwner() == bb.getOwner()) {
118 return ba.getArgNumber() < bb.getArgNumber();
119 }
120 return blockCmp(ba.getOwner(), bb.getOwner());
121 }
122
123 OpResult ra = cast<OpResult>(a);
124 OpResult rb = cast<OpResult>(b);
125
126 if (ra.getOwner() == rb.getOwner()) {
127 return ra.getResultNumber() < rb.getResultNumber();
128 } else {
129 return opCmp(ra.getOwner(), rb.getOwner());
130 }
131}
132
133Type mlir::enzyme::getConcatType(Value val, int64_t width) {
134 auto valTy = val.getType();
135 if (auto valTensorTy = dyn_cast<TensorType>(valTy)) {
136 // val is a tensor, prepend batch width to shape
137 SmallVector<int64_t> out_shape = {width};
138 out_shape.append(valTensorTy.getShape().begin(),
139 valTensorTy.getShape().end());
140 auto outTy = valTensorTy.clone(out_shape);
141 return outTy;
142 } else if (auto valMemrefTy = dyn_cast<MemRefType>(valTy)) {
143 // val is a memref, prepend batch width
144 SmallVector<int64_t> out_shape = {width};
145 out_shape.append(valMemrefTy.getShape().begin(),
146 valMemrefTy.getShape().end());
147 auto outTy = valMemrefTy.clone(out_shape);
148 return outTy;
149 } else {
150 // val is a scalar
151 return RankedTensorType::get(width, valTy);
152 }
153}
154
155Value mlir::enzyme::getConcatValue(OpBuilder &builder, Location loc,
156 ArrayRef<Value> argList) {
157 int64_t width = argList.size();
158 Type out_type = mlir::enzyme::getConcatType(argList.front(), width);
159 mlir::Value out = enzyme::ConcatOp::create(builder, loc, out_type, argList);
160 return out;
161}
162
163Value mlir::enzyme::getExtractValue(OpBuilder &builder, Location loc,
164 Type argTy, Value val, int64_t index) {
165 // Extract the original output from the tensorized output at the given index.
166 IntegerAttr indexAttr = builder.getI64IntegerAttr(index);
167 Value out = enzyme::ExtractOp::create(builder, loc, argTy, val, indexAttr);
168 return out;
169}
170
171void mlir::enzyme::computeAffineIndices(OpBuilder &builder, Location loc,
172 AffineMap map, ValueRange operands,
173 SmallVectorImpl<Value> &indices) {
174 for (unsigned i = 0; i < map.getNumResults(); i++) {
175 indices.push_back(affine::AffineApplyOp::create(
176 builder, loc, map.getSubMap({i}), operands));
177 }
178}
static mlir::linalg::GenericOp adjointToGeneric(enzyme::GenericAdjointOp &op, OpBuilder &builder, Location loc)
Definition Utils.cpp:15
bool valueCmp(mlir::Value a, mlir::Value b)
Definition Utils.cpp:104
Type getConcatType(Value val, int64_t width)
Definition Utils.cpp:133
bool blockCmp(mlir::Block *a, mlir::Block *b)
Value getConcatValue(OpBuilder &builder, Location loc, ArrayRef< Value > argList)
Definition Utils.cpp:155
bool opCmp(mlir::Operation *a, mlir::Operation *b)
bool regionCmp(mlir::Region *a, mlir::Region *b)
void computeAffineIndices(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands, SmallVectorImpl< Value > &indices)
Definition Utils.cpp:171
Value getExtractValue(OpBuilder &builder, Location loc, Type argTy, Value val, int64_t index)
Definition Utils.cpp:163