24 auto xType = cast<RankedTensorType>(x.getType());
25 auto elemType = xType.getElementType();
27 auto oneConst = arith::ConstantOp::create(
29 DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, 1.0)));
30 auto oneMinusX = arith::SubFOp::create(builder, loc, oneConst, x);
31 auto logX = math::LogOp::create(builder, loc, x);
32 auto logOneMinusX = math::LogOp::create(builder, loc, oneMinusX);
33 return arith::SubFOp::create(builder, loc, logX, logOneMinusX);
79 Value constrained, impulse::SupportAttr support) {
80 auto kind = support.getKind();
81 auto xType = cast<RankedTensorType>(constrained.getType());
82 auto elemType = xType.getElementType();
85 case impulse::SupportKind::REAL:
88 case impulse::SupportKind::POSITIVE:
89 return math::LogOp::create(builder, loc, constrained);
90 case impulse::SupportKind::UNIT_INTERVAL:
92 case impulse::SupportKind::INTERVAL: {
94 auto lowerAttr = support.getLowerBound();
95 auto upperAttr = support.getUpperBound();
96 if (!lowerAttr || !upperAttr) {
97 llvm_unreachable(
"INTERVAL support requires lower and upper bounds");
99 double lower = lowerAttr.getValueAsDouble();
100 double upper = upperAttr.getValueAsDouble();
102 auto lowerConst = arith::ConstantOp::create(
104 DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, lower)));
105 auto scaleConst = arith::ConstantOp::create(
107 DenseElementsAttr::get(xType,
108 builder.getFloatAttr(elemType, upper - lower)));
110 auto shifted = arith::SubFOp::create(builder, loc, constrained, lowerConst);
111 auto normalized = arith::DivFOp::create(builder, loc, shifted, scaleConst);
114 case impulse::SupportKind::GREATER_THAN: {
116 auto lowerAttr = support.getLowerBound();
118 llvm_unreachable(
"GREATER_THAN support requires lower bound");
120 double lower = lowerAttr.getValueAsDouble();
122 auto lowerConst = arith::ConstantOp::create(
124 DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, lower)));
125 auto shifted = arith::SubFOp::create(builder, loc, constrained, lowerConst);
126 return math::LogOp::create(builder, loc, shifted);
128 case impulse::SupportKind::LESS_THAN: {
130 auto upperAttr = support.getUpperBound();
132 llvm_unreachable(
"LESS_THAN support requires upper bound");
134 double upper = upperAttr.getValueAsDouble();
136 auto upperConst = arith::ConstantOp::create(
138 DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, upper)));
139 auto shifted = arith::SubFOp::create(builder, loc, upperConst, constrained);
140 return math::LogOp::create(builder, loc, shifted);
143 llvm_unreachable(
"Unknown SupportKind");
147 Value unconstrained, impulse::SupportAttr support) {
148 auto kind = support.getKind();
149 auto zType = cast<RankedTensorType>(unconstrained.getType());
150 auto elemType = zType.getElementType();
153 case impulse::SupportKind::REAL:
155 return unconstrained;
156 case impulse::SupportKind::POSITIVE:
157 return math::ExpOp::create(builder, loc, unconstrained);
158 case impulse::SupportKind::UNIT_INTERVAL:
160 return impulse::LogisticOp::create(builder, loc, unconstrained.getType(),
162 case impulse::SupportKind::INTERVAL: {
164 auto lowerAttr = support.getLowerBound();
165 auto upperAttr = support.getUpperBound();
166 if (!lowerAttr || !upperAttr) {
167 llvm_unreachable(
"INTERVAL support requires lower and upper bounds");
169 double lower = lowerAttr.getValueAsDouble();
170 double upper = upperAttr.getValueAsDouble();
172 auto lowerConst = arith::ConstantOp::create(
174 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, lower)));
175 auto scaleConst = arith::ConstantOp::create(
177 DenseElementsAttr::get(zType,
178 builder.getFloatAttr(elemType, upper - lower)));
180 auto sigmoid = impulse::LogisticOp::create(
181 builder, loc, unconstrained.getType(), unconstrained);
182 auto scaled = arith::MulFOp::create(builder, loc, scaleConst, sigmoid);
183 return arith::AddFOp::create(builder, loc, lowerConst, scaled);
185 case impulse::SupportKind::GREATER_THAN: {
187 auto lowerAttr = support.getLowerBound();
189 llvm_unreachable(
"GREATER_THAN support requires lower bound");
191 double lower = lowerAttr.getValueAsDouble();
193 auto lowerConst = arith::ConstantOp::create(
195 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, lower)));
196 auto expZ = math::ExpOp::create(builder, loc, unconstrained);
197 return arith::AddFOp::create(builder, loc, lowerConst, expZ);
199 case impulse::SupportKind::LESS_THAN: {
201 auto upperAttr = support.getUpperBound();
203 llvm_unreachable(
"LESS_THAN support requires upper bound");
205 double upper = upperAttr.getValueAsDouble();
207 auto upperConst = arith::ConstantOp::create(
209 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, upper)));
210 auto expZ = math::ExpOp::create(builder, loc, unconstrained);
211 return arith::SubFOp::create(builder, loc, upperConst, expZ);
215 llvm_unreachable(
"Unknown SupportKind");
220 impulse::SupportAttr support) {
221 auto kind = support.getKind();
222 auto zType = cast<RankedTensorType>(unconstrained.getType());
223 auto elemType = zType.getElementType();
224 auto scalarType = RankedTensorType::get({}, elemType);
227 case impulse::SupportKind::REAL: {
229 return arith::ConstantOp::create(
230 builder, loc, scalarType,
231 DenseElementsAttr::get(scalarType,
232 builder.getFloatAttr(elemType, 0.0)));
234 case impulse::SupportKind::POSITIVE: {
237 auto ones = arith::ConstantOp::create(
239 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0)));
240 return impulse::DotOp::create(
241 builder, loc, scalarType, unconstrained, ones,
242 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
243 builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0}));
245 case impulse::SupportKind::UNIT_INTERVAL: {
250 auto negZ = arith::NegFOp::create(builder, loc, unconstrained);
252 auto logProduct = arith::AddFOp::create(builder, loc, logSigZ, logSigNegZ);
253 auto ones = arith::ConstantOp::create(
255 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0)));
256 return impulse::DotOp::create(
257 builder, loc, scalarType, logProduct, ones,
258 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
259 builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0}));
261 case impulse::SupportKind::INTERVAL: {
263 auto lowerAttr = support.getLowerBound();
264 auto upperAttr = support.getUpperBound();
265 if (!lowerAttr || !upperAttr) {
266 llvm_unreachable(
"INTERVAL support requires lower and upper bounds");
268 double scale = upperAttr.getValueAsDouble() - lowerAttr.getValueAsDouble();
271 auto negZ = arith::NegFOp::create(builder, loc, unconstrained);
273 auto logProduct = arith::AddFOp::create(builder, loc, logSigZ, logSigNegZ);
274 auto ones = arith::ConstantOp::create(
276 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0)));
277 auto sumLogProduct = impulse::DotOp::create(
278 builder, loc, scalarType, logProduct, ones,
279 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
280 builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0}));
282 int64_t n = zType.getNumElements();
283 double logScaleTerm = n * std::log(scale);
284 auto logScaleConst = arith::ConstantOp::create(
285 builder, loc, scalarType,
286 DenseElementsAttr::get(scalarType,
287 builder.getFloatAttr(elemType, logScaleTerm)));
288 return arith::AddFOp::create(builder, loc, sumLogProduct, logScaleConst);
290 case impulse::SupportKind::GREATER_THAN:
291 case impulse::SupportKind::LESS_THAN: {
293 auto ones = arith::ConstantOp::create(
295 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0)));
296 return impulse::DotOp::create(
297 builder, loc, scalarType, unconstrained, ones,
298 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
299 builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0}));
303 llvm_unreachable(
"Unknown SupportKind");