16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Complex/IR/Complex.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinDialect.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/Support/LLVM.h"
// Returns the "batched" form of `type` for a batch of `width` values:
//  - tensor types: a clone of the tensor whose shape gains a new leading
//    dimension of size `width`, followed by the original dimensions;
//  - any other type: a rank-1 RankedTensorType of shape {width} whose
//    element type is `type`.
// NOTE(review): this view of the function is incomplete. Some source lines,
// including the closing braces, are not shown, so any early return or special
// case they contain is not documented here. The numeric prefixes fused onto
// lines (e.g. `30static`, `34 if (`) look like leftover line numbers from
// extraction and are not valid C++. Confirm against the real file.
30static mlir::Type batchType(mlir::Type type, int64_t width) {
34 if (
auto TT = dyn_cast<mlir::TensorType>(type)) {
35 SmallVector<int64_t> shape;
// Room for the new leading batch dimension plus every existing dimension.
36 shape.reserve(TT.getShape().size() + 1);
37 shape.push_back(width);
38 shape.append(TT.getShape().begin(), TT.getShape().end());
// clone() keeps the tensor's element type and applies the widened shape.
39 return TT.clone(shape);
// Non-tensor types become a 1-D tensor of `width` elements of that type.
42 return RankedTensorType::get({width}, type);
// External model that attaches AutoDiffTypeInterface to a concrete builtin
// floating-point type (`ConcreteType`, e.g. Float32Type).
// NOTE(review): several source lines of this class are not visible in this view
// (parameter-list tails, the zeroInPlace body, part of createConjOp, the
// closing braces). Only the visible lines are documented.
45template <
typename ConcreteType>
46class FloatTypeInterface :
public AutoDiffTypeInterface::ExternalModel<
47 FloatTypeInterface<ConcreteType>, ConcreteType> {
  // Zero attribute for this float type: a FloatAttr holding an APFloat zero
  // in the type's own float semantics.
49 Attribute createNullAttr(Type self)
const {
50 auto fltType = cast<ConcreteType>(self);
51 return FloatAttr::get(fltType, APFloat(fltType.getFloatSemantics(), 0));
  // Materializes the zero attribute from createNullAttr as an arith.constant.
54 Value createNullValue(Type self, OpBuilder &builder, Location loc)
const {
55 auto fltType = cast<ConcreteType>(self);
56 return arith::ConstantOp::create(builder, loc, fltType,
57 cast<FloatAttr>(createNullAttr(self)));
  // Accumulation of two values of this type: emits arith.addf(a, b).
60 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
62 return arith::AddFOp::create(builder, loc, a, b);
  // Conjugate hook. Its body is not visible in this view.
64 Value createConjOp(Type self, OpBuilder &builder, Location loc,
  // NOTE(review): the signature of the method owning the next `return` is not
  // visible. Presumably it is the shadow-type hook and returns the batched type.
70 return batchType(self, width);
  // Float values are treated as immutable (SSA) values.
73 bool isMutable(Type self)
const {
return false; }
  // In-place zeroing hook. Its body is not visible in this view.
75 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
  // True if `val` matches MLIR's m_AnyZeroFloat constant matcher.
80 bool isZero(Type self, Value val)
const {
81 return matchPattern(val, m_AnyZeroFloat());
  // Attribute form of isZero, using the same m_AnyZeroFloat matcher.
84 bool isZeroAttr(Type self, Attribute attr)
const {
85 return matchPattern(attr, m_AnyZeroFloat());
  // Approximate size of a value of this type: its bit width.
88 int64_t getApproxSize(Type self)
const {
89 return self.getIntOrFloatBitWidth();
// External model that attaches AutoDiffTypeInterface to builtin tensor types.
// Most hooks defer to the AutoDiffTypeInterface of the tensor's element type.
// NOTE(review): several source lines of this class are not visible in this view
// (the template argument list tail, early returns, the zeroInPlace body, the
// getApproxSize tail, the closing braces). Only the visible lines are documented.
93class TensorTypeInterface
94 :
public AutoDiffTypeInterface::ExternalModel<TensorTypeInterface,
  // Zero attribute for the tensor: a splat DenseElementsAttr of zero. The
  // supported element types are float, complex-of-float, and IntegerType.
  // Other element types fall through to an error message on llvm::errs().
97 Attribute createNullAttr(Type self)
const {
98 auto tenType = cast<TensorType>(self);
99 auto ET = tenType.getElementType();
  // Float elements: splat of an APFloat zero in the element's semantics.
101 if (
auto F = dyn_cast<FloatType>(ET)) {
102 APFloat apvalue(F.getFloatSemantics(), 0);
103 return DenseElementsAttr::get(tenType, apvalue);
  // Complex-of-float elements: splat of (0, 0).
105 if (
auto G = dyn_cast<ComplexType>(ET)) {
106 if (
auto F = dyn_cast<FloatType>(G.getElementType())) {
107 APFloat apvalue(F.getFloatSemantics(), 0);
108 mlir::Complex<APFloat> c(apvalue, apvalue);
109 return DenseElementsAttr::get(tenType, c);
  // IntegerType elements: splat of an APInt zero of the element's width.
  // NOTE(review): index-typed elements are not an IntegerType and so do not
  // take this branch. Confirm the hidden lines handle them, if needed.
112 if (
auto IT = dyn_cast<IntegerType>(ET)) {
113 APInt apvalue(IT.getWidth(), 0);
114 return DenseElementsAttr::get(tenType, apvalue);
116 llvm::errs() <<
" cannot create null value of tensor type: " << tenType
  // Materializes createNullAttr's result as an arith.constant of the tensor
  // type. cast<TypedAttr> asserts if the attribute is not typed.
121 Value createNullValue(Type self, OpBuilder &builder, Location loc)
const {
122 auto attr = createNullAttr(self);
124 auto tenType = cast<TensorType>(self);
125 return arith::ConstantOp::create(builder, loc, tenType,
126 cast<TypedAttr>(attr));
  // Forwards the tensor operands to the element type's createAddOp.
  // cast<AutoDiffTypeInterface> asserts if the element type lacks the interface.
  // NOTE(review): for complex elements this would build complex.add on tensor
  // operands. Confirm that op accepts tensor-of-complex operands.
129 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
131 auto tenType = cast<TensorType>(self);
132 auto ET = tenType.getElementType();
133 auto iface = cast<AutoDiffTypeInterface>(ET);
134 return iface.createAddOp(builder, loc, a, b);
  // Forwards `a` to the element type's createConjOp. The rest of the body is
  // not visible. NOTE(review): the local `added` holds the conjugate result,
  // not a sum, so the name is misleading.
137 Value createConjOp(Type self, OpBuilder &builder, Location loc,
139 auto tenType = cast<TensorType>(self);
140 auto ET = tenType.getElementType();
141 auto iface = cast<AutoDiffTypeInterface>(ET);
142 auto added = iface.createConjOp(builder, loc, a);
  // NOTE(review): the owning method's signature is not visible. Presumably it
  // is the shadow-type hook and returns the batched type.
147 return batchType(self, width);
  // Tensors are treated as immutable (SSA) values.
150 bool isMutable(Type self)
const {
return false; }
  // In-place zeroing hook. Its body is not visible in this view.
151 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
  // True when `val` is a constant DenseElementsAttr that is a splat, and the
  // element type's interface reports the splat value as zero.
156 bool isZero(Type self, Value val)
const {
157 auto tenType = cast<TensorType>(self);
158 auto ET = tenType.getElementType();
159 DenseElementsAttr eAttr;
161 if (!matchPattern(val, m_Constant(&eAttr)))
164 if (!eAttr.isSplat())
167 auto splatVal = eAttr.getSplatValue<Attribute>();
  // dyn_cast, so element types without the interface are reported as non-zero.
168 auto ADET = dyn_cast<AutoDiffTypeInterface>(ET);
169 return ADET && ADET.isZeroAttr(splatVal);
  // Attribute form of isZero. The element type is taken from the attribute's
  // own type rather than from `self`.
172 bool isZeroAttr(Type self, Attribute attr)
const {
173 auto eAttr = dyn_cast<DenseElementsAttr>(attr);
177 if (!eAttr.isSplat())
180 auto ET = eAttr.getType().getElementType();
181 auto ADET = dyn_cast<AutoDiffTypeInterface>(ET);
186 return ADET.isZeroAttr(eAttr.getSplatValue<Attribute>());
  // Approximate size: starts from the element type's approximate size and
  // iterates over the tensor's shape.
  // NOTE(review): how each dimension is combined (and how dynamic dims are
  // handled) is in lines not visible here. Confirm.
189 int64_t getApproxSize(Type self)
const {
190 auto tenType = cast<TensorType>(self);
191 auto elType = cast<AutoDiffTypeInterface>(tenType.getElementType());
194 int64_t sz = elType.getApproxSize();
197 for (
auto n : tenType.getShape())
// External model that attaches AutoDiffTypeInterface to builtin integer-like
// types. It is instantiated for both IntegerType and IndexType (see
// registration); `T` comes from a template header not visible in this view.
// NOTE(review): several source lines of this class are not visible here
// (parameter-list tails, the zeroInPlace body, part of createConjOp, the start
// of getApproxSize, the closing braces).
204class IntegerTypeInterface
205 :
public AutoDiffTypeInterface::ExternalModel<IntegerTypeInterface<T>, T> {
  // Zero IntegerAttr. `index` has no intrinsic bit width, so a 64-bit APInt is
  // used for it. Other integer types use their own bit width.
207 Attribute createNullAttr(Type self)
const {
208 if (isa<IndexType>(self)) {
209 return IntegerAttr::get(self, APInt(64, 0));
211 return IntegerAttr::get(self, APInt(self.getIntOrFloatBitWidth(), 0));
  // Materializes zero as arith.constant: the index-specific builder for
  // `index`, and the typed integer builder otherwise.
215 Value createNullValue(Type self, OpBuilder &builder, Location loc)
const {
216 if (isa<IndexType>(self)) {
217 return arith::ConstantIndexOp::create(builder, loc, 0);
219 return arith::ConstantIntOp::create(builder, loc, self, 0);
  // Accumulation of two values of this type: emits arith.addi(a, b).
222 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
224 return arith::AddIOp::create(builder, loc, a, b);
  // Conjugate hook. Its body is not visible in this view.
227 Value createConjOp(Type self, OpBuilder &builder, Location loc,
  // NOTE(review): the owning method's signature is not visible. Presumably it
  // is the shadow-type hook and returns the batched type.
233 return batchType(self, width);
  // Integer values are treated as immutable (SSA) values.
236 bool isMutable(Type self)
const {
return false; }
  // In-place zeroing hook. Its body is not visible in this view.
237 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
  // True if `val` matches MLIR's m_Zero integer-zero constant matcher.
242 bool isZero(Type self, Value val)
const {
243 return matchPattern(val, m_Zero());
  // Attribute form of isZero, using the same m_Zero matcher.
246 bool isZeroAttr(Type self, Attribute attr)
const {
247 return matchPattern(attr, m_Zero());
  // Approximate size of a value of this type, falling back to its bit width.
  // NOTE(review): getIntOrFloatBitWidth is only valid for int/float types, not
  // `index`. Confirm the hidden lines before this return handle IndexType.
250 int64_t getApproxSize(Type self)
const {
255 return self.getIntOrFloatBitWidth();
// External model that attaches AutoDiffTypeInterface to the builtin ComplexType.
// Complex constants are represented as a two-element ArrayAttr.
// NOTE(review): several source lines of this class are not visible here
// (template argument tail, parameter-list tails, the zeroInPlace body, early
// returns, the getApproxSize tail, the closing braces).
259class ComplexTypeInterface
260 :
public AutoDiffTypeInterface::ExternalModel<ComplexTypeInterface,
  // Zero attribute: an ArrayAttr of two copies of the element type's zero
  // attribute, one for each component of the complex value.
  // NOTE(review): cast<FloatType> asserts if the complex element type is not a
  // float type. Confirm that only float complexes reach this hook.
263 Attribute createNullAttr(Type self)
const {
264 auto fltType = cast<FloatType>(cast<ComplexType>(self).getElementType());
265 auto zattr = cast<AutoDiffTypeInterface>(fltType).createNullAttr();
266 mlir::Attribute attrs[2] = {zattr, zattr};
267 return ArrayAttr::get(self.getContext(), attrs);
  // Materializes createNullAttr's result as a complex.constant.
269 Value createNullValue(Type self, OpBuilder &builder, Location loc)
const {
270 return complex::ConstantOp::create(builder, loc, self,
271 cast<ArrayAttr>(createNullAttr(self)));
  // Accumulation of two complex values: emits complex.add(a, b).
274 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
276 return complex::AddOp::create(builder, loc, a, b)->getResult(0);
  // Conjugate of `a`: emits complex.conj(a).
278 Value createConjOp(Type self, OpBuilder &builder, Location loc,
280 return complex::ConjOp::create(builder, loc, a)->getResult(0);
  // NOTE(review): the owning method's signature is not visible. Presumably it
  // is the shadow-type hook and returns the batched type.
284 return batchType(self, width);
  // Complex values are treated as immutable (SSA) values.
287 bool isMutable(Type self)
const {
return false; }
  // In-place zeroing hook. Its body is not visible in this view.
288 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
  // Matches `val` against a constant ArrayAttr (the declaration of `arrayAttr`
  // is not visible) and defers to isZeroAttr.
293 bool isZero(Type self, Value val)
const {
296 if (!matchPattern(val, m_Constant(&arrayAttr))) {
300 return this->isZeroAttr(self, arrayAttr);
  // Zero check on the attribute form. The attribute must be a two-element
  // ArrayAttr, and each element must be zero according to the complex
  // element type's AutoDiffTypeInterface.
303 bool isZeroAttr(Type self, Attribute attr)
const {
304 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
305 if (!arrayAttr || arrayAttr.size() != 2)
309 auto compType = cast<ComplexType>(self);
310 auto elType = compType.getElementType();
311 auto eltIntf = dyn_cast<AutoDiffTypeInterface>(elType);
317 for (
auto eltAttr : arrayAttr) {
318 if (!eltIntf.isZeroAttr(eltAttr))
  // Approximate size, derived from the element type's approximate size.
  // INT64_MAX from the element is special-cased. How the result is computed
  // otherwise is in lines not visible here.
325 int64_t getApproxSize(Type self)
const {
327 cast<AutoDiffTypeInterface>(cast<ComplexType>(self).getElementType());
328 auto elSize = elType.getApproxSize();
329 if (elSize == INT64_MAX)
337 DialectRegistry ®istry) {
// Registers a BuiltinDialect extension that attaches the AutoDiffTypeInterface
// external models above to builtin types:
//  - FloatTypeInterface: bf16, f16, f32, f64;
//  - IntegerTypeInterface: IntegerType and IndexType;
//  - TensorTypeInterface: unranked and ranked tensors;
//  - ComplexTypeInterface: ComplexType.
// NOTE(review): `®istry` on the parameter line looks like a mis-encoded
// `&registry` (an HTML `&reg;` entity decoded to the registered sign). As
// written it will not compile. The function's name line and closing braces
// are also not visible in this view.
338 registry.addExtension(+[](MLIRContext *context, BuiltinDialect *) {
339 BFloat16Type::attachInterface<FloatTypeInterface<BFloat16Type>>(*context);
340 Float16Type::attachInterface<FloatTypeInterface<Float16Type>>(*context);
341 Float32Type::attachInterface<FloatTypeInterface<Float32Type>>(*context);
342 Float64Type::attachInterface<FloatTypeInterface<Float64Type>>(*context);
343 IntegerType::attachInterface<IntegerTypeInterface<IntegerType>>(*context);
344 IndexType::attachInterface<IntegerTypeInterface<IndexType>>(*context);
345 UnrankedTensorType::attachInterface<TensorTypeInterface>(*context);
346 RankedTensorType::attachInterface<TensorTypeInterface>(*context);
347 ComplexType::attachInterface<ComplexTypeInterface>(*context);
// NOTE(review): these four lines are bare signatures with neither bodies nor
// terminating semicolons. They look like an outline/index appended during
// extraction rather than real declarations. One of them takes an
// `llvm::Constant *`, which nothing else in this view uses. They will not
// compile as-is. Confirm, then either remove them or complete them as proper
// declarations. `®istry` is again a likely mis-encoding of `&registry`.
Type getShadowType(Type type, unsigned width)
static bool isZero(llvm::Constant *cst)
static bool isMutable(Type type)
void registerBuiltinDialectAutoDiffInterface(DialectRegistry ®istry)