Enzyme main
Loading...
Searching...
No Matches
TransformUtils.h
Go to the documentation of this file.
1//===- TransformUtils.h - Constraint transforms for HMC -------*- C++ -*-===//
2//
3// This file declares utility functions implementing the constraint
4// transforms used by Hamiltonian Monte Carlo (HMC) inference.
5//
6// Reference:
7// https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef ENZYME_MLIR_INTERFACES_TRANSFORM_UTILS_H
12#define ENZYME_MLIR_INTERFACES_TRANSFORM_UTILS_H
13
15#include "Dialect/Ops.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Math/IR/Math.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/Value.h"
21
22namespace mlir {
23namespace enzyme {
24namespace transforms {
25
26/// Get the unconstrained size given a constrained size and support kind.
///
/// \param constrainedSize size in the constrained space.
/// \param kind            support kind that selects the transform.
/// \returns the corresponding size in the unconstrained space.
/// NOTE(review): the per-kind size mapping lives in the implementation and is
/// not visible from this header.
27int64_t getUnconstrainedSize(int64_t constrainedSize,
28 impulse::SupportKind kind);
29
30/// Get the constrained size given an unconstrained size and support kind.
///
/// \param unconstrainedSize size in the unconstrained space.
/// \param kind              support kind that selects the transform.
/// \returns the corresponding size in the constrained space.
/// NOTE(review): presumably the inverse of getUnconstrainedSize() for the same
/// kind -- confirm against the implementation.
31int64_t getConstrainedSize(int64_t unconstrainedSize,
32 impulse::SupportKind kind);
33
34/// Transform from constrained to unconstrained space.
///
/// \param builder     builder passed through to the implementation;
///                    presumably used to emit the ops computing the result --
///                    TODO confirm.
/// \param loc         location passed through to the implementation;
///                    presumably attached to any emitted ops -- TODO confirm.
/// \param constrained value in the constrained space.
/// \param support     support attribute describing the constraint.
/// \returns the value in the unconstrained space.
35Value unconstrain(OpBuilder &builder, Location loc, Value constrained,
36 impulse::SupportAttr support);
37
38/// Transform from unconstrained to constrained space.
///
/// \param builder       see unconstrain().
/// \param loc           see unconstrain().
/// \param unconstrained value in the unconstrained space.
/// \param support       support attribute describing the constraint.
/// \returns the value in the constrained space.
39Value constrain(OpBuilder &builder, Location loc, Value unconstrained,
40 impulse::SupportAttr support);
41
42/// Compute log |det J| of the transform from unconstrained to constrained.
///
/// \param builder       see unconstrain().
/// \param loc           see unconstrain().
/// \param unconstrained value in the unconstrained space at which the
///                      Jacobian term is evaluated.
/// \param support       support attribute describing the constraint.
/// \returns the log-absolute-determinant value.
/// NOTE(review): the shape of the result (scalar vs. elementwise) is not
/// established by this header -- confirm against the implementation.
43Value logAbsDetJacobian(OpBuilder &builder, Location loc, Value unconstrained,
44 impulse::SupportAttr support);
45
/// NOTE(review): undocumented in the original header. The name suggests this
/// builds the logit of \p x, i.e. log(x / (1 - x)) -- confirm against the
/// implementation before relying on it.
46Value createLogit(OpBuilder &builder, Location loc, Value x);
/// NOTE(review): undocumented in the original header. The name suggests this
/// builds log(sigmoid(x)) for \p x -- confirm against the implementation
/// before relying on it.
47Value createLogSigmoid(OpBuilder &builder, Location loc, Value x);
48} // namespace transforms
49} // namespace enzyme
50} // namespace mlir
51
52#endif // ENZYME_MLIR_INTERFACES_TRANSFORM_UTILS_H
Value logAbsDetJacobian(OpBuilder &builder, Location loc, Value unconstrained, impulse::SupportAttr support)
Compute log |det J| of the transform from unconstrained to constrained.
Value unconstrain(OpBuilder &builder, Location loc, Value constrained, impulse::SupportAttr support)
Transform from constrained to unconstrained space.
int64_t getConstrainedSize(int64_t unconstrainedSize, impulse::SupportKind kind)
Get the constrained size given an unconstrained size and support kind.
Value constrain(OpBuilder &builder, Location loc, Value unconstrained, impulse::SupportAttr support)
Transform from unconstrained to constrained space.
int64_t getUnconstrainedSize(int64_t constrainedSize, impulse::SupportKind kind)
Get the unconstrained size given a constrained size and support kind.
Value createLogSigmoid(OpBuilder &builder, Location loc, Value x)
Value createLogit(OpBuilder &builder, Location loc, Value x)