Enzyme main
Loading...
Searching...
No Matches
BuiltinAutoDiffTypeInterfaceImpl.cpp
Go to the documentation of this file.
//===- BuiltinAutoDiffTypeInterfaceImpl.cpp - Interface external model ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the external model implementation of the automatic
10// differentiation type interfaces for the upstream MLIR builtin dialect.
11//
12//===----------------------------------------------------------------------===//
13
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Complex/IR/Complex.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinDialect.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/Support/LLVM.h"
24
25using namespace mlir;
26using namespace mlir::enzyme;
27
28namespace {
29
30static mlir::Type batchType(mlir::Type type, int64_t width) {
31 if (width == 1)
32 return type;
33
34 if (auto TT = dyn_cast<mlir::TensorType>(type)) {
35 SmallVector<int64_t> shape;
36 shape.reserve(TT.getShape().size() + 1);
37 shape.push_back(width);
38 shape.append(TT.getShape().begin(), TT.getShape().end());
39 return TT.clone(shape);
40 }
41
42 return RankedTensorType::get({width}, type);
43}
44
45template <typename ConcreteType>
46class FloatTypeInterface : public AutoDiffTypeInterface::ExternalModel<
47 FloatTypeInterface<ConcreteType>, ConcreteType> {
48public:
49 Attribute createNullAttr(Type self) const {
50 auto fltType = cast<ConcreteType>(self);
51 return FloatAttr::get(fltType, APFloat(fltType.getFloatSemantics(), 0));
52 }
53
54 Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
55 auto fltType = cast<ConcreteType>(self);
56 return arith::ConstantOp::create(builder, loc, fltType,
57 cast<FloatAttr>(createNullAttr(self)));
58 }
59
60 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
61 Value b) const {
62 return arith::AddFOp::create(builder, loc, a, b);
63 }
64 Value createConjOp(Type self, OpBuilder &builder, Location loc,
65 Value a) const {
66 return a;
67 }
68
69 Type getShadowType(Type self, int64_t width) const {
70 return batchType(self, width);
71 }
72
73 bool isMutable(Type self) const { return false; }
74
75 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
76 Value val) const {
77 return failure();
78 }
79
80 bool isZero(Type self, Value val) const {
81 return matchPattern(val, m_AnyZeroFloat());
82 }
83
84 bool isZeroAttr(Type self, Attribute attr) const {
85 return matchPattern(attr, m_AnyZeroFloat());
86 }
87
88 int64_t getApproxSize(Type self) const {
89 return self.getIntOrFloatBitWidth();
90 }
91};
92
93class TensorTypeInterface
94 : public AutoDiffTypeInterface::ExternalModel<TensorTypeInterface,
95 TensorType> {
96public:
97 Attribute createNullAttr(Type self) const {
98 auto tenType = cast<TensorType>(self);
99 auto ET = tenType.getElementType();
100
101 if (auto F = dyn_cast<FloatType>(ET)) {
102 APFloat apvalue(F.getFloatSemantics(), 0);
103 return DenseElementsAttr::get(tenType, apvalue);
104 }
105 if (auto G = dyn_cast<ComplexType>(ET)) {
106 if (auto F = dyn_cast<FloatType>(G.getElementType())) {
107 APFloat apvalue(F.getFloatSemantics(), 0);
108 mlir::Complex<APFloat> c(apvalue, apvalue);
109 return DenseElementsAttr::get(tenType, c);
110 }
111 }
112 if (auto IT = dyn_cast<IntegerType>(ET)) {
113 APInt apvalue(IT.getWidth(), 0);
114 return DenseElementsAttr::get(tenType, apvalue);
115 }
116 llvm::errs() << " cannot create null value of tensor type: " << tenType
117 << "\n";
118 assert(0);
119 return nullptr;
120 }
121 Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
122 auto attr = createNullAttr(self);
123 assert(attr);
124 auto tenType = cast<TensorType>(self);
125 return arith::ConstantOp::create(builder, loc, tenType,
126 cast<TypedAttr>(attr));
127 }
128
129 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
130 Value b) const {
131 auto tenType = cast<TensorType>(self);
132 auto ET = tenType.getElementType();
133 auto iface = cast<AutoDiffTypeInterface>(ET);
134 return iface.createAddOp(builder, loc, a, b);
135 }
136
137 Value createConjOp(Type self, OpBuilder &builder, Location loc,
138 Value a) const {
139 auto tenType = cast<TensorType>(self);
140 auto ET = tenType.getElementType();
141 auto iface = cast<AutoDiffTypeInterface>(ET);
142 auto added = iface.createConjOp(builder, loc, a);
143 return added;
144 }
145
146 Type getShadowType(Type self, int64_t width) const {
147 return batchType(self, width);
148 }
149
150 bool isMutable(Type self) const { return false; }
151 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
152 Value val) const {
153 return failure();
154 }
155
156 bool isZero(Type self, Value val) const {
157 auto tenType = cast<TensorType>(self);
158 auto ET = tenType.getElementType();
159 DenseElementsAttr eAttr;
160
161 if (!matchPattern(val, m_Constant(&eAttr)))
162 return false;
163
164 if (!eAttr.isSplat())
165 return false;
166 // recurse on the individual element type
167 auto splatVal = eAttr.getSplatValue<Attribute>();
168 auto ADET = dyn_cast<AutoDiffTypeInterface>(ET);
169 return ADET && ADET.isZeroAttr(splatVal);
170 }
171
172 bool isZeroAttr(Type self, Attribute attr) const {
173 auto eAttr = dyn_cast<DenseElementsAttr>(attr);
174 if (!eAttr)
175 return false;
176
177 if (!eAttr.isSplat())
178 return false;
179
180 auto ET = eAttr.getType().getElementType();
181 auto ADET = dyn_cast<AutoDiffTypeInterface>(ET);
182
183 if (!ADET)
184 return false;
185
186 return ADET.isZeroAttr(eAttr.getSplatValue<Attribute>());
187 }
188
189 int64_t getApproxSize(Type self) const {
190 auto tenType = cast<TensorType>(self);
191 auto elType = cast<AutoDiffTypeInterface>(tenType.getElementType());
192 if (!elType)
193 return INT64_MAX;
194 int64_t sz = elType.getApproxSize();
195 if (sz == INT64_MAX)
196 return sz;
197 for (auto n : tenType.getShape())
198 sz *= n;
199 return sz;
200 }
201};
202
203template <typename T>
204class IntegerTypeInterface
205 : public AutoDiffTypeInterface::ExternalModel<IntegerTypeInterface<T>, T> {
206public:
207 Attribute createNullAttr(Type self) const {
208 if (isa<IndexType>(self)) {
209 return IntegerAttr::get(self, APInt(64, 0));
210 } else {
211 return IntegerAttr::get(self, APInt(self.getIntOrFloatBitWidth(), 0));
212 }
213 }
214
215 Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
216 if (isa<IndexType>(self)) {
217 return arith::ConstantIndexOp::create(builder, loc, 0);
218 }
219 return arith::ConstantIntOp::create(builder, loc, self, 0);
220 }
221
222 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
223 Value b) const {
224 return arith::AddIOp::create(builder, loc, a, b);
225 }
226
227 Value createConjOp(Type self, OpBuilder &builder, Location loc,
228 Value a) const {
229 return a;
230 }
231
232 Type getShadowType(Type self, int64_t width) const {
233 return batchType(self, width);
234 }
235
236 bool isMutable(Type self) const { return false; }
237 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
238 Value val) const {
239 return failure();
240 }
241
242 bool isZero(Type self, Value val) const {
243 return matchPattern(val, m_Zero());
244 }
245
246 bool isZeroAttr(Type self, Attribute attr) const {
247 return matchPattern(attr, m_Zero());
248 }
249
250 int64_t getApproxSize(Type self) const {
251 // Assume index is 64-bit for ease
252 if (self.isIndex())
253 return 64;
254
255 return self.getIntOrFloatBitWidth();
256 }
257};
258
259class ComplexTypeInterface
260 : public AutoDiffTypeInterface::ExternalModel<ComplexTypeInterface,
261 ComplexType> {
262public:
263 Attribute createNullAttr(Type self) const {
264 auto fltType = cast<FloatType>(cast<ComplexType>(self).getElementType());
265 auto zattr = cast<AutoDiffTypeInterface>(fltType).createNullAttr();
266 mlir::Attribute attrs[2] = {zattr, zattr};
267 return ArrayAttr::get(self.getContext(), attrs);
268 }
269 Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
270 return complex::ConstantOp::create(builder, loc, self,
271 cast<ArrayAttr>(createNullAttr(self)));
272 }
273
274 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
275 Value b) const {
276 return complex::AddOp::create(builder, loc, a, b)->getResult(0);
277 }
278 Value createConjOp(Type self, OpBuilder &builder, Location loc,
279 Value a) const {
280 return complex::ConjOp::create(builder, loc, a)->getResult(0);
281 }
282
283 Type getShadowType(Type self, int64_t width) const {
284 return batchType(self, width);
285 }
286
287 bool isMutable(Type self) const { return false; }
288 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
289 Value val) const {
290 return failure();
291 }
292
293 bool isZero(Type self, Value val) const {
294 ArrayAttr arrayAttr;
295
296 if (!matchPattern(val, m_Constant(&arrayAttr))) {
297 return false;
298 }
299 // reuse attributr check
300 return this->isZeroAttr(self, arrayAttr);
301 }
302
303 bool isZeroAttr(Type self, Attribute attr) const {
304 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
305 if (!arrayAttr || arrayAttr.size() != 2)
306 return false;
307
308 // get the element type
309 auto compType = cast<ComplexType>(self);
310 auto elType = compType.getElementType();
311 auto eltIntf = dyn_cast<AutoDiffTypeInterface>(elType);
312
313 if (!eltIntf)
314 return false;
315
316 // recurse and accumulate info per attribute
317 for (auto eltAttr : arrayAttr) {
318 if (!eltIntf.isZeroAttr(eltAttr))
319 return false;
320 }
321
322 return true;
323 }
324
325 int64_t getApproxSize(Type self) const {
326 auto elType =
327 cast<AutoDiffTypeInterface>(cast<ComplexType>(self).getElementType());
328 auto elSize = elType.getApproxSize();
329 if (elSize == INT64_MAX)
330 return elSize;
331 return 2 * elSize;
332 }
333};
334} // namespace
335
337 DialectRegistry &registry) {
338 registry.addExtension(+[](MLIRContext *context, BuiltinDialect *) {
339 BFloat16Type::attachInterface<FloatTypeInterface<BFloat16Type>>(*context);
340 Float16Type::attachInterface<FloatTypeInterface<Float16Type>>(*context);
341 Float32Type::attachInterface<FloatTypeInterface<Float32Type>>(*context);
342 Float64Type::attachInterface<FloatTypeInterface<Float64Type>>(*context);
343 IntegerType::attachInterface<IntegerTypeInterface<IntegerType>>(*context);
344 IndexType::attachInterface<IntegerTypeInterface<IndexType>>(*context);
345 UnrankedTensorType::attachInterface<TensorTypeInterface>(*context);
346 RankedTensorType::attachInterface<TensorTypeInterface>(*context);
347 ComplexType::attachInterface<ComplexTypeInterface>(*context);
348 });
349}
Type getShadowType(Type type, unsigned width)
static bool isZero(llvm::Constant *cst)
static bool isMutable(Type type)
Definition Ops.cpp:235
void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry)