StableHLO Dialect
Refer to the official documentation for more details.
Reactant.MLIR.Dialects.stablehlo.abs Method
abs
Performs element-wise abs operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#abs
Example
%result = stablehlo.abs %operand : tensor<3xi32>
Reactant.MLIR.Dialects.stablehlo.add Method
add
Performs element-wise addition of two tensors lhs and rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#add
Example
%result = stablehlo.add %lhs, %rhs : tensor<2x2xi32>
Reactant.MLIR.Dialects.stablehlo.after_all Method
after_all
Ensures that the operations producing the inputs are executed before any operations that depend on the result.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all
Example
%result = stablehlo.after_all %input0, %input1 : !stablehlo.token
Reactant.MLIR.Dialects.stablehlo.all_gather Method
all_gather
Within each process group in the process grid, concatenates the values of the operand tensor from each process along all_gather_dim and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather
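To make the data movement concrete, here is a minimal pure-Python simulation of all_gather within a single process group. It assumes 2-D operands stored as nested lists (one per process), and the helper name simulate_all_gather is hypothetical, not part of Reactant or StableHLO.

```python
def simulate_all_gather(operands, all_gather_dim):
    """Simulate all_gather for one process group (hypothetical helper).

    operands[p] is the 2-D operand (a list of rows) held by process p.
    Returns the result tensor seen by each process.
    """
    if all_gather_dim == 0:
        # Stack the rows of every process's operand.
        gathered = [row for operand in operands for row in operand]
    elif all_gather_dim == 1:
        # Join row r of every process's operand side by side.
        gathered = [
            [x for operand in operands for x in operand[r]]
            for r in range(len(operands[0]))
        ]
    else:
        raise ValueError("this sketch only handles 2-D operands")
    # Every process in the group receives its own copy of the same result.
    return [[list(row) for row in gathered] for _ in operands]


# Two processes in the group [[0, 1]], gathering along dimension 1.
results = simulate_all_gather([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], all_gather_dim=1)
print(results[0])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```

With a 2x2 operand on each of two processes, each process ends up with a 2x4 result, matching the shapes in the example for this op.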
Example
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
Reactant.MLIR.Dialects.stablehlo.all_reduce Method
all_reduce
Within each process group in the process grid, applies a reduction function computation to the values of the operand tensor from each process and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_reduce
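A minimal pure-Python sketch of all_reduce within one process group, assuming 1-D operands held as lists and addition as the reduction computation. simulate_all_reduce is a hypothetical helper. It folds values left to right, while the specification leaves the exact reduction order to the implementation.

```python
from functools import reduce
import operator


def simulate_all_reduce(operands, computation=operator.add):
    """operands[p] is the 1-D operand held by process p (hypothetical helper)."""
    # Combine the values at each element position across all processes.
    reduced = [reduce(computation, column) for column in zip(*operands)]
    # Every process receives the same reduced tensor.
    return [list(reduced) for _ in operands]


results = simulate_all_reduce([[1, 2, 3, 4], [5, 6, 7, 8]])
print(results[0])  # [6, 8, 10, 12]
```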
Example
%result:2 = "stablehlo.all_reduce"(%operand0, %operand1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
Reactant.MLIR.Dialects.stablehlo.all_to_all Method
all_to_all
Within each process group in the process grid, splits the values of the operand tensor along split_dimension into parts, scatters the split parts between the processes, concatenates the scattered parts along concat_dimension and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_to_all
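The split/scatter/concatenate steps can be traced with a small pure-Python simulation. It covers only the configuration used in the example for this op: split_dimension = 1, concat_dimension = 0, and split_count equal to the number of processes. The helper names are hypothetical.

```python
def split_along_dim1(matrix, parts):
    """Split a 2-D list into `parts` equal column blocks."""
    width = len(matrix[0]) // parts
    return [[row[k * width:(k + 1) * width] for row in matrix] for k in range(parts)]


def simulate_all_to_all(operands):
    """split_dimension = 1, concat_dimension = 0, split_count = number of processes."""
    n = len(operands)
    # pieces[sender][receiver] is the block that `sender` sends to `receiver`.
    pieces = [split_along_dim1(operand, n) for operand in operands]
    # Each receiver stacks (along dim 0) the blocks it got, ordered by sender.
    return [
        [row for sender in range(n) for row in pieces[sender][receiver]]
        for receiver in range(n)
    ]


results = simulate_all_to_all([
    [[1, 2, 3, 4], [5, 6, 7, 8]],         # operand on process 0
    [[9, 10, 11, 12], [13, 14, 15, 16]],  # operand on process 1
])
print(results[0])  # [[1, 2], [5, 6], [9, 10], [13, 14]]
print(results[1])  # [[3, 4], [7, 8], [11, 12], [15, 16]]
```

Each 2x4 operand becomes a 4x2 result, as in the example's type signature.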
Example
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
Reactant.MLIR.Dialects.stablehlo.and Method
and
Performs element-wise AND of two tensors lhs and rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#and
Example
%result = stablehlo.and %lhs, %rhs : tensor<2x2xi32>
Reactant.MLIR.Dialects.stablehlo.atan2 Method
atan2
Performs element-wise atan2 operation on lhs and rhs tensors and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#atan2
Example
%result = stablehlo.atan2 %lhs, %rhs : tensor<3xf64>
Reactant.MLIR.Dialects.stablehlo.batch_norm_grad Method
batch_norm_grad
Computes gradients of several inputs of BatchNormTrainingOp backpropagating from grad_output, and produces grad_operand, grad_scale and grad_offset tensors.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#batch_norm_grad
Example
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
Reactant.MLIR.Dialects.stablehlo.batch_norm_inference Method
batch_norm_inference
Normalizes the operand tensor across all dimensions except for the feature_index dimension and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#batch_norm_inference
Example
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
Reactant.MLIR.Dialects.stablehlo.batch_norm_training Method
batch_norm_training
For each feature in the feature_index dimension, computes the mean and variance across the batch and spatial dimensions, normalizes the operand tensor, and produces output, batch_mean and batch_var tensors.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#batch_norm_training
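A minimal pure-Python sketch of the training-time computation, assuming a 2-D operand whose last axis is the feature axis (the example for this op uses a 3-D operand with feature_index = 2). The variance is the biased one, dividing by the number of reduced elements. batch_norm_training here is a hypothetical helper, not the Reactant binding.

```python
import math


def batch_norm_training(operand, scale, offset, epsilon):
    """operand: list of rows (the batch); the last axis is the feature axis."""
    n = len(operand)
    features = range(len(operand[0]))
    batch_mean = [sum(row[f] for row in operand) / n for f in features]
    # Biased (population) variance: divide by the number of reduced elements.
    batch_var = [sum((row[f] - batch_mean[f]) ** 2 for row in operand) / n for f in features]
    # Normalize each feature, then apply its scale and offset.
    output = [
        [scale[f] * (row[f] - batch_mean[f]) / math.sqrt(batch_var[f] + epsilon) + offset[f]
         for f in features]
        for row in operand
    ]
    return output, batch_mean, batch_var


out, mean, var = batch_norm_training([[1.0, 10.0], [3.0, 30.0]], [1.0, 1.0], [0.0, 0.0], 0.0)
print(out, mean, var)  # [[-1.0, -1.0], [1.0, 1.0]] [2.0, 20.0] [1.0, 100.0]
```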
Example
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
Reactant.MLIR.Dialects.stablehlo.bitcast_convert Method
bitcast_convert
Performs a bitcast operation on operand tensor and produces a result tensor where the bits of the entire operand tensor are reinterpreted using the type of the result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#bitcast_convert
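The reinterpretation in the example for this op (one f64 becoming four f16 values) can be mimicked with Python's struct module. The little-endian order of the four 16-bit pieces is an assumption of this sketch, and bitcast_f64_to_f16x4 is a hypothetical helper.

```python
import struct


def bitcast_f64_to_f16x4(value):
    """Reinterpret the 64 bits of one f64 as four f16 values (no numeric conversion)."""
    raw = struct.pack("<d", value)           # 8 bytes, IEEE 754 binary64, little-endian
    return list(struct.unpack("<4e", raw))   # the same 8 bytes read as four binary16 values


halves = bitcast_f64_to_f16x4(1.0)
print(halves)  # [0.0, 0.0, 0.0, 1.984375]
```

Packing the four halves back produces the original eight bytes, which is what distinguishes a bitcast from convert.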
Example
%result = stablehlo.bitcast_convert %operand : (tensor<f64>) -> tensor<4xf16>
Reactant.MLIR.Dialects.stablehlo.broadcast Method
broadcast
This operation is on its way out of StableHLO, so it is not included in the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
Informally, this operation does the same thing as XLA's Broadcast: https://www.tensorflow.org/xla/operation_semantics#broadcast
Example
%result = stablehlo.broadcast %operand, sizes = [1, 2] : (tensor<3xi32>) -> tensor<1x2x3xi32>
Reactant.MLIR.Dialects.stablehlo.broadcast_in_dim Method
broadcast_in_dim
Expands the dimensions and/or rank of an input tensor by duplicating the data in the operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim
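The index mapping behind broadcast_in_dim can be written out in a few lines. In this sketch, operand dimension d maps to result dimension dims[d], and operand dimensions of size 1 are replicated by always reading index 0. It returns a dict keyed by result index rather than a nested list, for brevity. The helper name and representation are hypothetical.

```python
from itertools import product


def broadcast_in_dim(operand, operand_shape, result_shape, dims):
    """Return {result_index: value}; operand dimension d maps to result dimension dims[d]."""
    def element(index):
        value = operand
        for i in index:
            value = value[i]
        return value

    result = {}
    for res_index in product(*(range(n) for n in result_shape)):
        # Size-1 operand dimensions are replicated (always read index 0).
        op_index = [0 if operand_shape[d] == 1 else res_index[dims[d]]
                    for d in range(len(operand_shape))]
        result[res_index] = element(op_index)
    return result


# Same configuration as the example: tensor<1x3> -> tensor<2x3x2> with dims = [2, 1].
res = broadcast_in_dim([[1, 2, 3]], (1, 3), (2, 3, 2), dims=[2, 1])
print(res[(0, 0, 0)], res[(1, 2, 1)])  # 1 3
```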
Example
%result = stablehlo.broadcast_in_dim %operand, dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
Reactant.MLIR.Dialects.stablehlo.case Method
case
Produces the output from executing exactly one function from branches depending on the value of index.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case
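As we read the specification, an index outside [0, number of branches) selects the last branch. A minimal Python sketch of that dispatch rule follows, with branches modelled as zero-argument callables. The helper name case is hypothetical.

```python
def case(index, branches):
    """Run exactly one branch; an out-of-range index selects the last branch."""
    if not 0 <= index < len(branches):
        index = len(branches) - 1
    return branches[index]()


branches = [lambda: "branch0", lambda: "branch1"]
print(case(0, branches), case(5, branches), case(-1, branches))  # branch0 branch1 branch1
```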
Example
%result0, %result1 = "stablehlo.case"(%index) ({
stablehlo.return %result_branch0, %result_branch0 : tensor<2xi64>, tensor<2xi64>
}, {
stablehlo.return %result_branch1, %result_branch1 : tensor<2xi64>, tensor<2xi64>
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
Reactant.MLIR.Dialects.stablehlo.cbrt Method
cbrt
Performs element-wise cube root operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cbrt
Example
%result = stablehlo.cbrt %operand : tensor<4xf64>
Reactant.MLIR.Dialects.stablehlo.ceil Method
ceil
Performs element-wise ceil of operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#ceil
Example
%result = stablehlo.ceil %operand : tensor<5xf32>
Reactant.MLIR.Dialects.stablehlo.cholesky Method
cholesky
Computes the Cholesky decomposition of a batch of matrices.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cholesky
Example
%result = stablehlo.cholesky %a, lower = true : tensor<3x3xf64>
Reactant.MLIR.Dialects.stablehlo.clamp Method
clamp
Clamps every element of the operand tensor between a minimum and maximum value and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#clamp
Example
%result = stablehlo.clamp %min, %operand, %max : tensor<3xi32>
Reactant.MLIR.Dialects.stablehlo.collective_broadcast Method
collective_broadcast
Within each process group in the process grid, sends the value of the operand tensor from the source process to the target processes and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_broadcast
Example
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<1x2xi64>) -> tensor<1x2xi64>
Reactant.MLIR.Dialects.stablehlo.collective_permute Method
collective_permute
Within each process group in the process grid, sends the value of the operand tensor from the source process to the target process and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_permute
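A minimal pure-Python simulation of collective_permute, assuming 1-D operands and three processes. As we understand the specification, a process that is not the target of any (source, target) pair receives zeros. simulate_collective_permute is a hypothetical helper.

```python
def simulate_collective_permute(operands, source_target_pairs):
    """operands[p] is the 1-D operand on process p (hypothetical helper)."""
    # A process that is not the target of any pair receives zeros.
    results = [[0] * len(operand) for operand in operands]
    for source, target in source_target_pairs:
        results[target] = list(operands[source])
    return results


print(simulate_collective_permute([[1, 2], [3, 4], [5, 6]], [(0, 1), (1, 2)]))
# [[0, 0], [1, 2], [3, 4]]
```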
Example
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
Reactant.MLIR.Dialects.stablehlo.compare Method
compare
Performs element-wise comparison of lhs and rhs tensors according to comparison_direction and compare_type, and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#compare
Example
%result = stablehlo.compare LT, %lhs, %rhs, FLOAT : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
Reactant.MLIR.Dialects.stablehlo.complex Method
complex
Performs element-wise conversion to a complex value from a pair of real and imaginary values, lhs and rhs, and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex
Example
%result = stablehlo.complex %lhs, %rhs : tensor<2xcomplex<f64>>
Reactant.MLIR.Dialects.stablehlo.composite Method
composite
Encapsulates an operation made up (composed) of other StableHLO operations, taking inputs and composite_attributes and producing results. The semantics of the op are implemented by the decomposition attribute. The composite op can be replaced with its decomposition without changing program semantics. In cases where inlining the decomposition does not provide the same op semantics, prefer using custom_call.
The version field (defaults to 0) is used to denote when a composite's semantics change.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#composite
Example
%results = stablehlo.composite "my.op" %input0, %input1 {
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
Reactant.MLIR.Dialects.stablehlo.concatenate Method
concatenate
Concatenates a variadic number of tensors in inputs along the dimension given by dimension, in the same order as the given arguments, and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate
Example
%result = stablehlo.concatenate %input0, %input1, dim = 0 : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
Reactant.MLIR.Dialects.stablehlo.constant Method
constant
Produces an output tensor from a constant value.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
Example
%output = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
Reactant.MLIR.Dialects.stablehlo.convert Method
convert
Performs an element-wise conversion from one element type to another on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert
Example
%result = stablehlo.convert %operand : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
Reactant.MLIR.Dialects.stablehlo.convolution Method
convolution
Computes dot products between windows of lhs and slices of rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution
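The 1x2x2x1 output shape in the example for this op follows from per-dimension window-count arithmetic. The sketch below is our reading of the specification's shape rules for one spatial dimension: the input is dilated by lhs_dilation, then padded, and the window is dilated by rhs_dilation. conv_window_count is a hypothetical helper.

```python
def conv_window_count(input_size, window_size, stride=1, padding=(0, 0),
                      lhs_dilation=1, rhs_dilation=1):
    """Number of output positions along one spatial dimension of a convolution."""
    # lhs_dilation inserts (lhs_dilation - 1) holes between input elements.
    dilated_input = (input_size - 1) * lhs_dilation + 1
    padded_input = padding[0] + dilated_input + padding[1]
    # rhs_dilation spreads the window taps apart in the same way.
    dilated_window = (window_size - 1) * rhs_dilation + 1
    if padded_input <= 0 or dilated_window > padded_input:
        return 0
    return (padded_input - dilated_window) // stride + 1


# Spatial dims from the example: 4-wide input, 3-wide window, stride 4, lhs_dilate 2, no padding.
print(conv_window_count(4, 3, stride=4, lhs_dilation=2))  # 2
```

Here the dilated input has 7 positions and the window spans 3, so (7 - 3) // 4 + 1 = 2 windows fit.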
Example
%result = stablehlo.convolution(%lhs, %rhs)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {
stride = [4, 4],
pad = [[0, 0], [0, 0]],
lhs_dilate = [2, 2],
rhs_dilate = [1, 1],
reverse = [0, 0]
} {
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} :
(tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
Reactant.MLIR.Dialects.stablehlo.cosine Method
cosine
Performs element-wise cosine operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine
Example
%result = stablehlo.cosine %operand : tensor<2xf32>
Reactant.MLIR.Dialects.stablehlo.count_leading_zeros Method
count_leading_zeros
Performs element-wise count of the number of leading zero bits in the operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#count_leading_zeros
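For a 64-bit integer element, the count can be reproduced in Python by masking to the two's-complement bit pattern and using int.bit_length. A minimal sketch follows. count_leading_zeros_i64 is a hypothetical helper that assumes i64 elements, as in the example for this op.

```python
MASK64 = (1 << 64) - 1


def count_leading_zeros_i64(x):
    """Leading zero bits in the 64-bit two's-complement pattern of x."""
    return 64 - (x & MASK64).bit_length()


print([count_leading_zeros_i64(v) for v in [0, 1, 128, -1]])  # [64, 63, 56, 0]
```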
Example
%result = stablehlo.count_leading_zeros %operand : tensor<2x2xi64>
Reactant.MLIR.Dialects.stablehlo.create_token Method
create_token
This operation is on its way out of StableHLO, so it is not included in the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
Informally, this operation does the same thing as AfterAllOp with 0 inputs: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all
Example
%output = stablehlo.create_token : !stablehlo.token
Reactant.MLIR.Dialects.stablehlo.cross_replica_sum Method
cross_replica_sum
This operation is on its way out of StableHLO, so it is not included in the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
Informally, this operation does the same thing as AllReduceOp with channel_id = 0, use_global_device_ids = false and computation implementing addition: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_reduce
Example
%result = "stablehlo.cross-replica-sum"(%operand) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<4xf32>) -> tensor<4xf32>
Reactant.MLIR.Dialects.stablehlo.custom_call Method
custom_call
Encapsulates an implementation-defined operation call_target_name that takes inputs and called_computations and produces results.
Depending on the API version, there are two ways to pass extra bits of static information to the external function: use API_VERSION_TYPED_FFI, which allows passing a dictionary attribute, or use a previous API version with a StringAttr to encode the backend config.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call
Example
%results = stablehlo.custom_call @foo(%input0) {
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
Reactant.MLIR.Dialects.stablehlo.divide Method
divide
Performs element-wise division of dividend lhs and divisor rhs tensors and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide
Example
%result = stablehlo.divide %lhs, %rhs : tensor<4xf32>
Reactant.MLIR.Dialects.stablehlo.dot Method
dot
This operation is on its way out of StableHLO, so it is not included in the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
Informally, this operation does the same thing as XLA's Dot: https://www.tensorflow.org/xla/operation_semantics#dot
Example
%0 = stablehlo.dot %arg0, %arg1 : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<1x1xi32>
Reactant.MLIR.Dialects.stablehlo.dot_general Method
dot_general
Computes dot products between slices of lhs and slices of rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general
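With batching_dims = [0] x [0] and contracting_dims = [2] x [1], as in the example for this op, dot_general reduces to a batched matrix multiply. The result dimensions are the batch dimensions first, then the free lhs dimensions, then the free rhs dimensions. A minimal pure-Python sketch follows. dot_general_bmm is a hypothetical helper covering only this one configuration.

```python
def dot_general_bmm(lhs, rhs):
    """dot_general with batching_dims = [0] x [0] and contracting_dims = [2] x [1].

    lhs has shape [B, M, K] and rhs has shape [B, K, N]; the result is [B, M, N].
    """
    result = []
    for a, b in zip(lhs, rhs):  # iterate over the shared batch dimension
        depth = len(b)          # contracting dimension size K
        result.append([
            [sum(a[i][k] * b[k][j] for k in range(depth)) for j in range(len(b[0]))]
            for i in range(len(a))
        ])
    return result


lhs = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
rhs = [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]
print(dot_general_bmm(lhs, rhs))  # [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
```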
Example
%result = stablehlo.dot_general %lhs, %rhs,
batching_dims = [0] x [0],
contracting_dims = [2] x [1],
precision = [DEFAULT, DEFAULT],
algorithm = <lhs_precision_type = tf32, rhs_precision_type = tf32, accumulation_type = f32, lhs_component_count = 1, rhs_component_count = 1, num_primitive_operations = 1, allow_imprecise_accumulation = false>
: (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
Reactant.MLIR.Dialects.stablehlo.dynamic_broadcast_in_dim Method
dynamic_broadcast_in_dim
This operation is functionally identical to the broadcast_in_dim op, but the result shape is specified dynamically via output_dimensions.
It also accepts optional attributes to express static knowledge about the expanding behavior of dimensions. If not specified, all dimensions are assumed to be possibly expanding. The set of dimensions known to be expanding and the set of dimensions known to be non-expanding must be disjoint, and both must be subsets of the operand's dimensions.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_broadcast_in_dim
Example
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
Reactant.MLIR.Dialects.stablehlo.dynamic_conv Method
dynamic_conv
This operation is functionally identical to the convolution op, but the padding is specified dynamically via the padding operand.
Example
%padding = stablehlo.constant dense<2> : tensor<2x2xi64>
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
Reactant.MLIR.Dialects.stablehlo.dynamic_gather Method
dynamic_gather
This operation is functionally identical to the gather op, with the slice_sizes specified dynamically as an operand.
Example
%slice_sizes = stablehlo.constant dense<[1, 2, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [0, 2],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
Reactant.MLIR.Dialects.stablehlo.dynamic_iota Method
dynamic_iota
This operation is functionally identical to the iota op, but the result shape is specified dynamically via output_shape.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_iota
Example
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%0 = stablehlo.dynamic_iota %output_shape, dim = 0 : (tensor<2xi64>) -> tensor<4x5xi64>
Reactant.MLIR.Dialects.stablehlo.dynamic_pad Method
dynamic_pad
This operation is functionally identical to the pad op (https://github.com/openxla/stablehlo/pull/2306#discussion_r1595669709), but with edge_padding_low, edge_padding_high and interior_padding specified dynamically as values.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_pad
Example
%edge_padding_low = stablehlo.constant dense<[0, 1]> : tensor<2xi64>
%edge_padding_high = stablehlo.constant dense<[2, 1]> : tensor<2xi64>
%interior_padding = stablehlo.constant dense<[1, 2]> : tensor<2xi64>
%result = stablehlo.dynamic_pad %operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
: (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
Reactant.MLIR.Dialects.stablehlo.dynamic_reshape Method
dynamic_reshape
This operation is functionally identical to the reshape op, but the result shape is specified dynamically via output_shape.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_reshape
Example
%output_shape = stablehlo.constant dense<[3, 2]> : tensor<2xi64>
%result = stablehlo.dynamic_reshape %operand, %output_shape : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
Reactant.MLIR.Dialects.stablehlo.dynamic_slice Method
dynamic_slice
Extracts a slice from the operand using dynamically-computed starting indices and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice
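One detail that is easy to miss is that the specification clamps the start indices so that the slice always stays in bounds, rather than failing on out-of-range indices. A minimal 2-D pure-Python sketch follows. dynamic_slice_2d is a hypothetical helper.

```python
def dynamic_slice_2d(operand, start_indices, sizes):
    """2-D sketch: start indices are clamped into [0, dim_size - slice_size]."""
    rows, cols = len(operand), len(operand[0])
    r0 = min(max(start_indices[0], 0), rows - sizes[0])
    c0 = min(max(start_indices[1], 0), cols - sizes[1])
    return [row[c0:c0 + sizes[1]] for row in operand[r0:r0 + sizes[0]]]


operand = [[0, 0, 1, 1],
           [0, 0, 1, 1],
           [0, 0, 0, 0],
           [0, 0, 0, 0]]
# Start (-1, 3) is clamped to (0, 2), so the 2x2 slice is the top-right block.
print(dynamic_slice_2d(operand, (-1, 3), (2, 2)))  # [[1, 1], [1, 1]]
```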
Example
%result = stablehlo.dynamic_slice %operand, %start_indices0, %start_indices1, sizes = [2, 2]
: (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
Reactant.MLIR.Dialects.stablehlo.dynamic_update_slice Method
dynamic_update_slice
Produces a result tensor which is equal to the operand tensor except that the slice starting at start_indices is updated with the values in update.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_update_slice
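dynamic_update_slice uses the same start-index clamping as dynamic_slice, here so that the whole update fits inside the operand. The 2-D pure-Python sketch below returns a new tensor and leaves the input untouched. dynamic_update_slice_2d is a hypothetical helper.

```python
def dynamic_update_slice_2d(operand, update, start_indices):
    """2-D sketch: returns a copy of operand with update written at the clamped start."""
    rows, cols = len(operand), len(operand[0])
    # Start indices are clamped so that the whole update fits inside operand.
    r0 = min(max(start_indices[0], 0), rows - len(update))
    c0 = min(max(start_indices[1], 0), cols - len(update[0]))
    result = [list(row) for row in operand]  # leave the input untouched
    for i, update_row in enumerate(update):
        result[r0 + i][c0:c0 + len(update_row)] = update_row
    return result


operand = [[1, 1, 0, 0],
           [1, 1, 0, 0],
           [1, 1, 1, 1],
           [1, 1, 1, 1]]
# Start (-1, 3) is clamped to (0, 2): the update fills the top-right 2x2 block.
print(dynamic_update_slice_2d(operand, [[1, 1], [1, 1]], (-1, 3)))
# [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
```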
Example
%result = stablehlo.dynamic_update_slice %operand, %update, %start_indices0, %start_indices1
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
Reactant.MLIR.Dialects.stablehlo.einsum Method
einsum
This operation is on its way out of StableHLO, so it is not included in the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
Informally, this operation does the same thing as TF's einsum: https://www.tensorflow.org/api_docs/python/tf/einsum
Example
%result = "stablehlo.einsum"(%lhs, %rhs) {
einsum_config = "ab,bc->ac"
} : (tensor<4x16xf32>, tensor<16x4xf32>) -> tensor<4x4xf32>
Reactant.MLIR.Dialects.stablehlo.exponential Method
exponential
Performs element-wise exponential operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential
Example
%result = stablehlo.exponential %operand : tensor<2x2xf64>
Reactant.MLIR.Dialects.stablehlo.exponential_minus_one Method
exponential_minus_one
Performs element-wise exponential minus one operation, exp(x) - 1, on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_one
Example
%result = stablehlo.exponential_minus_one %operand : tensor<2xf64>
Reactant.MLIR.Dialects.stablehlo.fft Method
fft
Performs the forward and inverse Fourier transforms for real and complex inputs/outputs.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#fft
Example
%result = stablehlo.fft %operand, type = FFT, length = [4] : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
Reactant.MLIR.Dialects.stablehlo.floor Method
floor
Performs element-wise floor of operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor
Example
%result = stablehlo.floor %operand : tensor<2xf32>
Reactant.MLIR.Dialects.stablehlo.gather Method
gather
Gathers slices from operand tensor from offsets specified in start_indices and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather
Example
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi64>
Reactant.MLIR.Dialects.stablehlo.get_dimension_size Method
get_dimension_size
Produces the size of the given dimension of the operand.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#get_dimension_size
Example
%result = stablehlo.get_dimension_size %operand, dim = 1 : (tensor<2x3xi64>) -> tensor<i32>
Reactant.MLIR.Dialects.stablehlo.get_tuple_element Method
get_tuple_element
Extracts the element at position index of the operand tuple and produces a result.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#get_tuple_element
Example
%result = stablehlo.get_tuple_element %operand[0] : (tuple<tensor<2xf64>, tuple<tensor<i64>>>) -> tensor<2xf64>
Reactant.MLIR.Dialects.stablehlo.if_ Method
if_
Produces the output from executing exactly one branch from true_branch or false_branch depending on the value of pred.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#if
Example
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
Reactant.MLIR.Dialects.stablehlo.imag Method
imag
Extracts the imaginary part, element-wise, from the operand and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag
Example
%result = stablehlo.imag %operand : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
Reactant.MLIR.Dialects.stablehlo.infeed Method
infeed
Reads data from the infeed and produces results.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#infeed
Example
%results0:2 = "stablehlo.infeed"(%token) :
(!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
Reactant.MLIR.Dialects.stablehlo.iota Method
iota
Fills an output tensor with values in increasing order starting from zero along the iota_dimension dimension.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota
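In other words, each element of the output takes the value of its own index along iota_dimension. A minimal 2-D Python sketch of this rule follows. iota_2d is a hypothetical helper.

```python
def iota_2d(shape, iota_dimension):
    """Each element takes the value of its own index along iota_dimension."""
    rows, cols = shape
    return [[(r, c)[iota_dimension] for c in range(cols)] for r in range(rows)]


for row in iota_2d((4, 5), iota_dimension=0):
    print(row)
# [0, 0, 0, 0, 0]
# [1, 1, 1, 1, 1]
# [2, 2, 2, 2, 2]
# [3, 3, 3, 3, 3]
```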
Example
%output = stablehlo.iota dim = 0 : tensor<4x5xi32>
Reactant.MLIR.Dialects.stablehlo.is_finite Method
is_finite
Performs element-wise check whether the value in x is finite (i.e. is neither +Inf, -Inf, nor NaN) and produces a y tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite
Example
%y = stablehlo.is_finite %x : (tensor<7xf64>) -> tensor<7xi1>
Reactant.MLIR.Dialects.stablehlo.log Method
log
Performs element-wise logarithm operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log
Example
%result = stablehlo.log %operand : tensor<2x2xf64>
Reactant.MLIR.Dialects.stablehlo.log_plus_one Method
log_plus_one
Performs element-wise logarithm plus one operation, log(1 + x), on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one
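A dedicated log(1 + x) operation matters for small x, where forming 1 + x first loses precision. The sketch uses Python's math.log1p as a stand-in to illustrate the effect. StableHLO backends are not required to use the same algorithm.

```python
import math

x = 1e-16
naive = math.log(1.0 + x)  # 1.0 + 1e-16 rounds to exactly 1.0 in binary64, so this is 0.0
accurate = math.log1p(x)   # evaluates log(1 + x) without forming 1 + x first
print(naive)               # 0.0
print(accurate)            # approximately 1e-16
```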
Example
%result = stablehlo.log_plus_one %operand : tensor<5xf64>
Reactant.MLIR.Dialects.stablehlo.logistic Method
logistic
Performs element-wise logistic operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic
Example
%result = stablehlo.logistic %operand : tensor<2x2xf64>
Reactant.MLIR.Dialects.stablehlo.map Method
map
Applies a map function computation to inputs along the dimensions and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#map
Example
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
Reactant.MLIR.Dialects.stablehlo.maximum Method
maximum
Performs element-wise max operation on tensors lhs and rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum
Example
%result = stablehlo.maximum %lhs, %rhs : tensor<4xf32>
Reactant.MLIR.Dialects.stablehlo.minimum Method
minimum
Performs element-wise min operation on tensors lhs and rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum
Example
%result = stablehlo.minimum %lhs, %rhs : tensor<4xf32>
Reactant.MLIR.Dialects.stablehlo.multiply Method
multiply
Performs element-wise product of two tensors lhs and rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#multiply
Example
%result = stablehlo.multiply %lhs, %rhs : tensor<2xi32>
Reactant.MLIR.Dialects.stablehlo.negate Method
negate
Performs element-wise negation of operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate
Example
%result = stablehlo.negate %operand : tensor<2x3xi32>
Reactant.MLIR.Dialects.stablehlo.not Method
not
Performs element-wise NOT of tensor operand of boolean or integer type (logical NOT for booleans, bitwise NOT for integers) and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not
Example
%result = stablehlo.not %operand : tensor<5x3x1xi1>
Reactant.MLIR.Dialects.stablehlo.optimization_barrier Method
optimization_barrier
Ensures that the operations that produce the operand are executed before any operations that depend on the result and prevents compiler transformations from moving operations across the barrier. Other than that, the operation is an identity, i.e. result = operand.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier
Example
%result0, %result1 = stablehlo.optimization_barrier %operand0, %operand1 : tensor<f32>, tensor<f32>
Reactant.MLIR.Dialects.stablehlo.or Method
or
Performs element-wise OR of two tensors lhs and rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#or
Example
%result = stablehlo.or %lhs, %rhs : tensor<2xi1>
Reactant.MLIR.Dialects.stablehlo.outfeed Method
outfeed
Writes inputs to the outfeed and produces a result token.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#outfeed
Example
%result = "stablehlo.outfeed"(%input0, %token) :
(tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
Reactant.MLIR.Dialects.stablehlo.pad Method
pad
Expands operand by padding around the tensor as well as between the elements of the tensor with the given padding_value.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad
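Each padded dimension has size low + size + high + (size - 1) * interior. For the example for this op, that is 0 + 2 + 2 + 1 = 5 rows and 1 + 3 + 1 + 4 = 9 columns. The pure-Python sketch below reproduces that example. It handles only non-negative padding (the specification also allows negative edge padding), and the helper names are hypothetical.

```python
def pad_1d(values, padding_value, low, high, interior):
    """Pad a list: `low` before, `high` after, `interior` between neighbours."""
    out = []
    for i, value in enumerate(values):
        if i > 0:
            out.extend([padding_value] * interior)
        out.append(value)
    return [padding_value] * low + out + [padding_value] * high


def pad_2d(operand, padding_value, low, high, interior):
    """Non-negative padding only; low/high/interior are per-dimension lists."""
    # Pad along dimension 1 (inside each row) first.
    rows = [pad_1d(row, padding_value, low[1], high[1], interior[1]) for row in operand]
    # Then pad along dimension 0, inserting rows filled with padding_value.
    blank = [padding_value] * len(rows[0])
    padded = pad_1d(rows, blank, low[0], high[0], interior[0])
    return [list(row) for row in padded]  # copy so padding rows are not aliased


result = pad_2d([[1, 2, 3], [4, 5, 6]], 0, low=[0, 1], high=[2, 1], interior=[1, 2])
for row in result:
    print(row)
# [0, 1, 0, 0, 2, 0, 0, 3, 0]
# [0, 0, 0, 0, 0, 0, 0, 0, 0]
# [0, 4, 0, 0, 5, 0, 0, 6, 0]
# [0, 0, 0, 0, 0, 0, 0, 0, 0]
# [0, 0, 0, 0, 0, 0, 0, 0, 0]
```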
Example
%0 = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [1, 2]
: (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
Reactant.MLIR.Dialects.stablehlo.partition_id Method
partition_id
Produces partition_id of the current process.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#partition_id
Example
%result = stablehlo.partition_id : tensor<ui32>
Reactant.MLIR.Dialects.stablehlo.popcnt Method
popcnt
Performs element-wise count of the number of bits set in the operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt
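For i64 elements such as those in the example for this op, the same count can be reproduced in Python by masking to the 64-bit two's-complement pattern and counting the 1 bits. popcnt_i64 is a hypothetical helper.

```python
MASK64 = (1 << 64) - 1


def popcnt_i64(x):
    """Number of set bits in the 64-bit two's-complement pattern of x."""
    return bin(x & MASK64).count("1")


print([popcnt_i64(v) for v in [0, 1, 2, 127]])  # [0, 1, 1, 7]
```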
Example
%result = stablehlo.popcnt %operand : tensor<4xi64>
Reactant.MLIR.Dialects.stablehlo.power Method
power
Performs element-wise exponentiation of lhs tensor by rhs tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power
Example
%result = stablehlo.power %lhs, %rhs : tensor<6xf64>
Reactant.MLIR.Dialects.stablehlo.real Method
real
Extracts the real part, element-wise, from the operand and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real
Example
%result = stablehlo.real %operand : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
Reactant.MLIR.Dialects.stablehlo.real_dynamic_slice Method
real_dynamic_slice
This operation is a work in progress, so it is not yet included in the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
Informally, this operation does the same thing as SliceOp except that start_indices, limit_indices and strides are specified dynamically: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice
Example
%result = stablehlo.real_dynamic_slice %operand,
%start_indices, %limit_indices, %strides
: (tensor<256x?xf32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<256x?xf32>
Reactant.MLIR.Dialects.stablehlo.recv Method
recv
Receives data from a channel with channel_id and produces results.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#recv
Example
%results:2 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 0, type = 1>,
is_host_transfer = false,
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
Reactant.MLIR.Dialects.stablehlo.reduce Method
reduce
Applies a reduction function body to inputs and init_values along the dimensions and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce
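A minimal Python sketch of reduce for a 2-D input and a single reduced dimension, assuming an associative body such as addition. It folds left to right starting from the init value, which is only one possible evaluation order; the helper name is hypothetical.

```python
import operator
from functools import reduce as fold

def reduce_2d(x, init, body, dimension):
    """Reduce a 2-D list along `dimension` (0 or 1), starting each fold from `init`."""
    if dimension == 1:
        return [fold(body, row, init) for row in x]
    if dimension == 0:
        return [fold(body, col, init) for col in zip(*x)]
    raise ValueError("dimension must be 0 or 1")

# Same shapes as the example: tensor<1x6xi64> reduced over dimension 1 -> tensor<1xi64>.
print(reduce_2d([[0, 1, 2, 3, 4, 5]], 0, operator.add, dimension=1))  # [15]
```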
Example
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
Reactant.MLIR.Dialects.stablehlo.reduce_precision Method
reduce_precision
Performs element-wise conversion of operand to another floating-point type that uses exponent_bits and mantissa_bits and back to the original floating-point type and produces an output tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_precision
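For the e5m10 format in the example (5 exponent bits, 10 mantissa bits, i.e. the IEEE half-precision layout), the round trip can be sketched in Python with the `struct` module's half-precision `"e"` code, which rounds to nearest with ties to even. This is a sketch for normal-range values; how a particular backend treats subnormals is not covered, and the function name is hypothetical.

```python
import math
import struct

def reduce_precision_e5m10(x: float) -> float:
    """Round an f64 value to half precision (e5m10) and widen it back to f64."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # Finite values that round beyond the largest half value (65504) become infinity.
        return math.copysign(math.inf, x)

print(reduce_precision_e5m10(0.1))            # 0.0999755859375
print(reduce_precision_e5m10(1.0 + 2**-11))  # 1.0 (tie rounded to even)
```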
Example
%output = stablehlo.reduce_precision %operand, format = e5m10 : tensor<6xf64>
Reactant.MLIR.Dialects.stablehlo.reduce_scatter Method
reduce_scatter
Within each process group in the process grid, performs reduction, using computations, over the values of the operand tensor from each process, splits the reduction result along scatter_dimension into parts, and scatters the split parts between the processes to produce the result.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_scatter
Example
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
Reactant.MLIR.Dialects.stablehlo.reduce_window Method
reduce_window
Applies a reduction function body to windows of inputs and init_values and produces results.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window
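A Python sketch of reduce_window for a single 2-D input, assuming non-negative padding. Padding and the holes created by base dilation read as the init value, and each window is folded starting from init. The input values below are hypothetical; the attributes are the ones from the example.

```python
import operator
from itertools import product

def reduce_window_2d(x, init, body, window_dimensions, window_strides,
                     base_dilations=(1, 1), window_dilations=(1, 1),
                     padding=((0, 0), (0, 0))):
    """Reference-style reduce_window for a 2-D list with non-negative padding."""
    shape = (len(x), len(x[0]))

    def padded_value(pos):
        # Map a coordinate in the padded, base-dilated input back into x.
        src = []
        for d in range(2):
            q = pos[d] - padding[d][0]
            if q < 0 or q % base_dilations[d] != 0 or q // base_dilations[d] >= shape[d]:
                return init  # padding or a base-dilation hole
            src.append(q // base_dilations[d])
        return x[src[0]][src[1]]

    padded = [(shape[d] - 1) * base_dilations[d] + 1 + sum(padding[d]) for d in range(2)]
    span = [(window_dimensions[d] - 1) * window_dilations[d] + 1 for d in range(2)]
    out = [(padded[d] - span[d]) // window_strides[d] + 1 for d in range(2)]

    result = []
    for oi in range(out[0]):
        row = []
        for oj in range(out[1]):
            acc = init
            for wi, wj in product(range(window_dimensions[0]), range(window_dimensions[1])):
                pos = (oi * window_strides[0] + wi * window_dilations[0],
                       oj * window_strides[1] + wj * window_dilations[1])
                acc = body(acc, padded_value(pos))
            row.append(acc)
        result.append(row)
    return result

x = [[1, 2], [3, 4], [5, 6]]  # hypothetical tensor<3x2xi64> input
print(reduce_window_2d(x, 0, operator.add,
                       window_dimensions=(2, 1), window_strides=(4, 1),
                       base_dilations=(2, 1), window_dilations=(3, 1),
                       padding=((2, 1), (0, 0))))
# [[0, 0], [3, 4]]
```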
Example
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
Reactant.MLIR.Dialects.stablehlo.remainder Method
remainder
Performs element-wise remainder of dividend lhs and divisor rhs tensors and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder
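The result of remainder takes the sign of the dividend lhs (truncating division), which differs from Python's `%` operator (floor division). A small sketch for integers, with `math.fmod` showing the same convention for floats; the helper name and values are illustrative.

```python
import math

def remainder_int(lhs: int, rhs: int) -> int:
    """Remainder of truncating division: the result has the sign of lhs."""
    quotient = abs(lhs) // abs(rhs)
    if (lhs < 0) != (rhs < 0):
        quotient = -quotient
    return lhs - rhs * quotient

lhs = [17, -17, 17, -17]
rhs = [3, 3, -3, -3]
print([remainder_int(a, b) for a, b in zip(lhs, rhs)])  # [2, -2, 2, -2]
print([a % b for a, b in zip(lhs, rhs)])                # [2, 1, -1, -2] (Python floors)
print(math.fmod(-17.0, 3.0))                            # -2.0
```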
Example
%result = stablehlo.remainder %lhs, %rhs : tensor<4xi64>
Reactant.MLIR.Dialects.stablehlo.replica_id Method
replica_id
Produces replica_id of the current process.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#replica_id
Example
%result = stablehlo.replica_id : tensor<ui32>
Reactant.MLIR.Dialects.stablehlo.reshape Method
reshape
Performs reshape of operand tensor to a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape
Example
%result = stablehlo.reshape %operand : (tensor<2xf32>) -> tensor<1x2xf32>
Reactant.MLIR.Dialects.stablehlo.reverse Method
reverse
Reverses the order of elements in the operand along the specified dimensions and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reverse
Example
%result = stablehlo.reverse %operand, dims = [1] : tensor<3x2xi32>
Reactant.MLIR.Dialects.stablehlo.rng Method
rng
Generates random numbers using the rng_distribution algorithm and produces a result tensor whose shape is given by the shape operand.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng
Example
%result = stablehlo.rng %a, %b, %shape, distribution = UNIFORM : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
Reactant.MLIR.Dialects.stablehlo.rng_bit_generator Method
rng_bit_generator
Returns an output filled with uniform random data and an updated output state output_state given an initial state initial_state using the pseudorandom number generator algorithm rng_algorithm.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng_bit_generator
Example
%output_state, %output = stablehlo.rng_bit_generator %initial_state, algorithm = THREE_FRY : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
Reactant.MLIR.Dialects.stablehlo.round_nearest_afz Method
round_nearest_afz
Performs element-wise rounding towards the nearest integer, breaking ties away from zero, on the operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz
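A Python sketch of rounding with ties away from zero. Python's built-in `round()` breaks ties to even, so the tie case is handled explicitly here; the function name and inputs are illustrative.

```python
import math

def round_nearest_afz(x: float) -> float:
    """Round to the nearest integer, ties away from zero; keeps the sign of x."""
    if not math.isfinite(x):
        return x
    t = math.trunc(x)
    # x - t is the exact fractional part of a finite double.
    if abs(x - t) >= 0.5:
        t += 1 if x > 0 else -1
    return math.copysign(float(t), x)

print([round_nearest_afz(v) for v in [-2.5, 0.4, 0.5, 0.6, 2.5]])
# [-3.0, 0.0, 1.0, 1.0, 3.0]
```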
Example
%result = stablehlo.round_nearest_afz %operand : tensor<5xf64>
Reactant.MLIR.Dialects.stablehlo.round_nearest_even Method
round_nearest_even
Performs element-wise rounding towards the nearest integer, breaking ties towards the even integer, on the operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even
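Python's built-in `round()` on floats already breaks ties to the even integer, so a sketch of round_nearest_even only needs to handle non-finite inputs and keep the sign of zero. The function name is illustrative.

```python
import math

def round_nearest_even(x: float) -> float:
    """Round to the nearest integer, ties to even; keeps the sign of x."""
    if not math.isfinite(x):
        return x
    return math.copysign(float(round(x)), x)

print([round_nearest_even(v) for v in [-2.5, 0.4, 0.5, 0.6, 2.5]])
# [-2.0, 0.0, 0.0, 1.0, 2.0]
```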
Example
%result = stablehlo.round_nearest_even %operand : tensor<5xf64>
Reactant.MLIR.Dialects.stablehlo.rsqrt Method
rsqrt
Performs element-wise reciprocal square root operation on operand tensor and produces a result tensor, implementing the rSqrt operation from the IEEE-754 specification.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt
Example
%result = stablehlo.rsqrt %operand : tensor<2x2xf32>
Reactant.MLIR.Dialects.stablehlo.scatter Method
scatter
Produces results tensors which are equal to inputs tensors except that several slices specified by scatter_indices are updated with the values updates using update_computation.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter
Example
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
Reactant.MLIR.Dialects.stablehlo.select Method
select
Produces a result tensor where each element is selected from on_true or on_false tensor based on the value of the corresponding element of pred.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select
Example
%result = stablehlo.select %pred, %on_true, %on_false : tensor<2x2xi1>, tensor<2x2xi32>
Reactant.MLIR.Dialects.stablehlo.select_and_scatter Method
select_and_scatter
Scatters the values from the source tensor using scatter based on the outcome of reduce_window of the input tensor using select and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select_and_scatter
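A 1-D Python sketch of select_and_scatter. Because the example uses a window of size 1 along dimension 1, each column can be processed independently. The operand and source values are hypothetical, and skipping padded positions during selection is an assumption of this sketch.

```python
import operator

def select_and_scatter_1d(operand, source, init, select, scatter,
                          window, stride, pad_low=0, pad_high=0):
    """Pick one element per window with `select`, then fold that window's
    source value into the result at the picked position with `scatter`."""
    result = [init] * len(operand)
    num_windows = (len(operand) + pad_low + pad_high - window) // stride + 1
    assert num_windows == len(source), "source needs one value per window"
    for w in range(num_windows):
        chosen = None
        for k in range(window):
            i = w * stride + k - pad_low  # position in the unpadded operand
            if not 0 <= i < len(operand):
                continue  # padded positions never win (assumption)
            # select(current, candidate) returning True keeps the current pick.
            if chosen is None or not select(operand[chosen], operand[i]):
                chosen = i
        if chosen is not None:
            result[chosen] = scatter(result[chosen], source[w])
    return result

operand = [[1, 5], [2, 5], [3, 6], [4, 4]]  # hypothetical tensor<4x2xi64>
source = [[5, 6], [7, 8]]                   # hypothetical tensor<2x2xi64>
columns = [
    select_and_scatter_1d([row[c] for row in operand], [row[c] for row in source],
                          0, operator.ge, operator.add, window=3, stride=2, pad_high=1)
    for c in range(2)
]
result = [list(row) for row in zip(*columns)]
print(result)  # [[0, 0], [0, 0], [5, 14], [7, 0]]
```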
Example
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.compare GE, %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %0 : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
Reactant.MLIR.Dialects.stablehlo.send Method
send
Sends inputs to a channel channel_id and produces a result token.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#send
Example
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 0, type = 1>,
is_host_transfer = false,
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
Reactant.MLIR.Dialects.stablehlo.set_dimension_size Method
set_dimension_size
This operation is a work in progress, so it is not yet included in the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
Informally, this operation does the same thing as XLA's SetDimensionSize: https://www.tensorflow.org/xla/operation_semantics#setdimensionsize
Example
%0 = stablehlo.set_dimension_size %arg0, %arg1, dim = 1 : (tensor<4x2xf32>, tensor<i32>) -> tensor<4x2xf32>
Reactant.MLIR.Dialects.stablehlo.shift_left Method
shift_left
Performs element-wise left-shift operation on the lhs tensor by rhs number of bits and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_left
Example
%result = stablehlo.shift_left %lhs, %rhs : tensor<3xi64>
Reactant.MLIR.Dialects.stablehlo.shift_right_arithmetic Method
shift_right_arithmetic
Performs element-wise arithmetic right-shift operation on the lhs tensor by rhs number of bits and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_arithmetic
Example
%result = stablehlo.shift_right_arithmetic %lhs, %rhs : tensor<3xi64>
Reactant.MLIR.Dialects.stablehlo.shift_right_logical Method
shift_right_logical
Performs element-wise logical right-shift operation on the lhs tensor by rhs number of bits and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_logical
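A Python sketch contrasting the logical and arithmetic right shifts on i64 elements. It assumes lhs fits in 64 bits and that shift amounts lie in [0, 64); out-of-range shift amounts are not covered. The helper names are hypothetical.

```python
BITS = 64
MASK = (1 << BITS) - 1

def to_signed(u):
    """Reinterpret an unsigned BITS-bit pattern as a two's complement int."""
    return u - (1 << BITS) if u >> (BITS - 1) else u

def shift_right_logical(lhs, rhs):
    assert 0 <= rhs < BITS, "sketch only covers in-range shift amounts"
    # Shift the raw bit pattern so zeros, not copies of the sign bit, enter from the left.
    return to_signed((lhs & MASK) >> rhs)

def shift_right_arithmetic(lhs, rhs):
    assert 0 <= rhs < BITS, "sketch only covers in-range shift amounts"
    # Python's >> on a negative int already replicates the sign bit.
    return lhs >> rhs

lhs, rhs = [-1, 0, 8], [1, 2, 3]
print([shift_right_logical(a, b) for a, b in zip(lhs, rhs)])     # [9223372036854775807, 0, 1]
print([shift_right_arithmetic(a, b) for a, b in zip(lhs, rhs)])  # [-1, 0, 1]
```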
Example
%result = stablehlo.shift_right_logical %lhs, %rhs : tensor<3xi64>
Reactant.MLIR.Dialects.stablehlo.sign Method
sign
Returns the sign of the operand element-wise and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign
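A Python sketch of sign for floating-point elements, assuming NaN maps to NaN and signed zeros are returned unchanged, while every other value maps to -1.0 or 1.0. The function name is illustrative.

```python
import math

def sign(x: float) -> float:
    """Element-wise sign for floats: NaN stays NaN, -0.0 and +0.0 are kept as-is."""
    if math.isnan(x) or x == 0.0:
        return x
    return math.copysign(1.0, x)

print([sign(v) for v in [math.nan, -1.0, -0.0, 0.0, 1.0]])
# [nan, -1.0, -0.0, 0.0, 1.0]
```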
Example
%result = stablehlo.sign %operand : tensor<5xf64>
Reactant.MLIR.Dialects.stablehlo.sine Method
sine
Performs element-wise sine operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine
Example
%result = stablehlo.sine %operand : tensor<2xf32>
Reactant.MLIR.Dialects.stablehlo.slice Method
slice
Extracts a slice from the operand using statically-computed starting indices and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice
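The start:limit:stride notation of the pretty-printed form corresponds directly to Python slices, one per dimension. A sketch on hypothetical 3x8 data; note that Python silently clamps out-of-range bounds, whereas slice requires statically valid indices.

```python
def slice_2d(x, start_indices, limit_indices, strides):
    """Mirror a 2-D slice: one Python slice per dimension."""
    (r0, c0), (r1, c1), (rs, cs) = start_indices, limit_indices, strides
    return [row[c0:c1:cs] for row in x[r0:r1:rs]]

operand = [[8 * r + c for c in range(8)] for r in range(3)]  # hypothetical 3x8 data
print(slice_2d(operand, (1, 4), (3, 8), (1, 2)))  # [[12, 14], [20, 22]]
```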
Example
%result = stablehlo.slice %operand [1:3, 4:8:2]
: (tensor<3x8xi64>) -> tensor<2x2xi64>
// Same in generic form: the `1:3` above is mapped to the first entry in
// `start_indices` and `limit_indices`, while `strides` is implicitly 1.
// The `4:8:2` above is parsed into the second entry of `start_indices`,
// `limit_indices` and `strides` respectively.
%result = "stablehlo.slice" (%operand) {
start_indices = array<i64: 1, 4>,
limit_indices = array<i64: 3, 8>,
strides = array<i64: 1, 2>
} : (tensor<3x8xi64>) -> tensor<2x2xi64>
Reactant.MLIR.Dialects.stablehlo.sort Method
sort
Sorts a variadic number of tensors in inputs together, according to a custom comparator, along the given dimension and produces a variadic number of tensors as results.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sort
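A 1-D Python sketch of sorting several tensors together with a custom comparator. The comparator receives the lhs and rhs scalars of each input interleaved (input 0 lhs, input 0 rhs, input 1 lhs, ...) and returns True when lhs should come first. Python's `sorted()` is stable, so this corresponds to `is_stable = true`. The helper name and data are hypothetical.

```python
from functools import cmp_to_key

def sort_together(inputs, comparator):
    """Sort equal-length 1-D lists together, ordered by `comparator`."""
    def interleave(a, b):
        # (a0, a1, ...) and (b0, b1, ...) -> (a0, b0, a1, b1, ...)
        return [v for pair in zip(a, b) for v in pair]

    def cmp(a, b):
        if comparator(*interleave(a, b)):
            return -1
        if comparator(*interleave(b, a)):
            return 1
        return 0  # neither comes first: keep the original relative order

    rows = sorted(zip(*inputs), key=cmp_to_key(cmp))
    return [[row[i] for row in rows] for i in range(len(inputs))]

# Descending by the first input, like a comparator body computing `compare GT` on input 0.
print(sort_together([[1, 2, 3], [3, 2, 1]], lambda l0, r0, l1, r1: l0 > r0))
# [[3, 2, 1], [1, 2, 3]]
```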
Example
Reactant.MLIR.Dialects.stablehlo.sqrt Method
sqrt
Performs element-wise square root operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt
Example
%result = stablehlo.sqrt %operand : tensor<2x2xf32>
Reactant.MLIR.Dialects.stablehlo.subtract Method
subtract
Performs element-wise subtraction of two tensors lhs and rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#subtract
Example
%result = stablehlo.subtract %lhs, %rhs : tensor<2xi32>
Reactant.MLIR.Dialects.stablehlo.tan Method
tan
Performs element-wise tangent operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tan
Example
%result = stablehlo.tan %operand : tensor<2x2xf64>
Reactant.MLIR.Dialects.stablehlo.tanh Method
tanh
Performs element-wise hyperbolic tangent operation on operand tensor and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh
Example
%result = stablehlo.tanh %operand : tensor<2xf32>
Reactant.MLIR.Dialects.stablehlo.torch_index_select Method
torch_index_select
This operation is on its way out of StableHLO, so it is not included in the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
Informally, this operation does the same thing as PyTorch's index_select, augmented with support for batch dimensions: https://pytorch.org/docs/stable/generated/torch.index_select.html.
The batch_dims attribute specifies the number of major batch dimensions (0 or more) that act like a multidimensional loop over both the operand and the index.
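The result shape in the example can be reproduced with a simple rule, stated here as an assumption inferred from that example and from torch.index_select: operand dimensions before dim, then the non-batch index dimensions, then operand dimensions after dim. The helper name is hypothetical.

```python
def torch_index_select_shape(operand_shape, index_shape, dim, batch_dims):
    """Assumed result-shape rule for torch_index_select."""
    assert tuple(operand_shape[:batch_dims]) == tuple(index_shape[:batch_dims]), "batch dims must match"
    return (tuple(operand_shape[:dim])
            + tuple(index_shape[batch_dims:])
            + tuple(operand_shape[dim + 1:]))

print(torch_index_select_shape((8, 128, 3072, 64), (8, 16, 1024), dim=2, batch_dims=1))
# (8, 128, 16, 1024, 64)
```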
Example
%result = "stablehlo.torch_index_select"(%operand, %index) {
dim = 2 : i64,
batch_dims = 1 : i64
} : (tensor<8x128x3072x64xf32>, tensor<8x16x1024xi32>) -> tensor<8x128x16x1024x64xf32>
Reactant.MLIR.Dialects.stablehlo.transpose Method
transpose
Permutes the dimensions of operand tensor using permutation and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
Example
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<1x2x3xi32>) -> tensor<3x2x1xi32>
Reactant.MLIR.Dialects.stablehlo.triangular_solve Method
triangular_solve
Solves batches of systems of linear equations with lower or upper triangular coefficient matrices.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#triangular_solve
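A Python sketch of the configuration used in the example (left_side = true, lower = true, transpose_a = NO_TRANSPOSE): solving a @ x = b by forward substitution, one right-hand-side column at a time. Other attribute combinations are not covered, entries above the diagonal are ignored, and the helper name is hypothetical.

```python
def triangular_solve_lower_left(a, b, unit_diagonal=False):
    """Solve a @ x = b for x, where `a` is lower triangular (forward substitution)."""
    n, cols = len(a), len(b[0])
    x = [[0.0] * cols for _ in range(n)]
    for j in range(cols):
        for i in range(n):
            acc = b[i][j] - sum(a[i][k] * x[k][j] for k in range(i))
            # With unit_diagonal the diagonal of `a` is assumed to be 1 and not read.
            x[i][j] = acc if unit_diagonal else acc / a[i][i]
    return x

a = [[2, 0, 0], [1, 1, 0], [1, 1, 1]]
print(triangular_solve_lower_left(a, a))  # identity, since a @ I = a
```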
Example
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
Reactant.MLIR.Dialects.stablehlo.tuple Method
tuple
Produces a result tuple from values val.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tuple
Example
%result = stablehlo.tuple %val0, %val1 : tuple<tensor<2xf64>, tuple<tensor<i64>>>
Reactant.MLIR.Dialects.stablehlo.unary_einsum Method
unary_einsum
This operation is on its way out of StableHLO, so it is not included in the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
Informally, this operation does the same thing as TF's einsum: https://www.tensorflow.org/api_docs/python/tf/einsum
Example
%result = "stablehlo.unary_einsum"(%operand) {
einsum_config = "ab->a"
} : (tensor<4x16xf32>) -> tensor<4xf32>
Reactant.MLIR.Dialects.stablehlo.uniform_dequantize Method
uniform_dequantize
Performs element-wise conversion of quantized tensor operand to a floating-point tensor result according to the quantization parameters defined by the operand type.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_dequantize
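The operand type in the example, `!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>`, is quantized per axis along dimension 0: element i uses its own scale:zero_point pair, and dequantization computes (q - zero_point) * scale. A Python sketch with hypothetical stored values:

```python
def uniform_dequantize_per_axis(values, scales, zero_points):
    """Dequantize a 1-D per-axis quantized tensor: (q - zero_point) * scale per element."""
    return [(q - zp) * s for q, s, zp in zip(values, scales, zero_points)]

print(uniform_dequantize_per_axis([10, 10], scales=[0.1, 0.5], zero_points=[-30, -20]))
# [4.0, 15.0]
```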
Example
%result = stablehlo.uniform_dequantize %operand : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
Reactant.MLIR.Dialects.stablehlo.uniform_quantize Method
uniform_quantize
Performs element-wise conversion of floating-point tensor or quantized tensor operand to a quantized tensor result according to the quantization parameters defined by the result type.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_quantize
Example
%result = stablehlo.uniform_quantize %operand : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
Reactant.MLIR.Dialects.stablehlo.while_ Method
while_
Produces the output from executing body function 0 or more times while the cond function outputs true.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#while
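The example's control flow maps onto an ordinary Python loop over a tuple of loop-carried values. The concrete values for %init_i, %init_sum, %one and %ten are hypothetical (1, 0, 1 and 10); the helper name is illustrative.

```python
def while_loop(cond, body, init):
    """Run body zero or more times while cond holds; state mirrors the op's iteration arguments."""
    state = init
    while cond(*state):
        state = body(*state)
    return state

one, ten = 1, 10
result = while_loop(
    cond=lambda i, s: i < ten,
    body=lambda i, s: (i + one, s + one),  # returns (%new_i, %new_sum)
    init=(1, 0),                           # hypothetical %init_i, %init_sum
)
print(result)  # (10, 9)
```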
Example
%results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor<i64>, tensor<i64>
cond {
%cond = stablehlo.compare LT, %arg0, %ten : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
} do {
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}
Reactant.MLIR.Dialects.stablehlo.xor Method
xor
Performs element-wise XOR of two tensors lhs and rhs and produces a result tensor.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#xor
Example
%result = stablehlo.xor %lhs, %rhs : tensor<2xi32>