Reactant.Ops API
The Reactant.Ops module provides a high-level API for constructing MLIR operations without interacting directly with the individual dialects.
Not all functions in Reactant.Ops are documented yet.
Reactant.Ops.gather_getindex Method
gather_getindex(src, gather_indices)
Uses MLIR.Dialects.stablehlo.gather to get the values of src at the indices specified by gather_indices. If the indices are contiguous, it is recommended to use MLIR.Dialects.stablehlo.dynamic_slice directly instead.
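To make the gather versus dynamic_slice recommendation concrete, here is a plain-Julia sketch (no Reactant involved) of what a gather computes compared with a contiguous slice. The helper gather_reference and its index layout (one row of idx per fetched element, one column per dimension) are hypothetical and chosen only for illustration; the layout Reactant expects for gather_indices may differ.

```julia
# Hypothetical reference helper (not part of Reactant): each row of `idx`
# holds the (row, column) coordinates of one element to fetch from `src`.
function gather_reference(src::AbstractMatrix, idx::AbstractMatrix{<:Integer})
    # One output entry per index row, in the order the rows appear.
    return [src[idx[k, 1], idx[k, 2]] for k in axes(idx, 1)]
end

src = reshape(collect(1:12), 3, 4)   # 3×4, column-major: src[i, j] == i + 3(j - 1)

# Arbitrary, non-contiguous points: this is the case that needs a gather.
scattered = gather_reference(src, [1 1; 3 2; 2 4])   # [1, 6, 11]

# A contiguous block: a (dynamic) slice is enough, no gather needed.
contiguous = src[2:3, 2:3]                           # [5 8; 6 9]
```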
Reactant.Ops.hlo_call Method
Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
Given an MLIR module as a string, calls the function identified by the func_name keyword argument (default "main") with the provided arguments and returns a tuple containing one element per result of the call.
julia> Reactant.@jit(
           Ops.hlo_call(
               """
               module {
                 func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
                   %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
                   return %0 : tensor<3xf32>
                 }
               }
               """,
               Reactant.to_rarray(Float32[1, 2, 3]),
               Reactant.to_rarray(Float32[1, 2, 3]),
           )
       )
(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
Reactant.Ops.mesh Method
mesh(
    mesh::Reactant.Sharding.Mesh;
    mod::MLIR.IR.Module=MLIR.IR.mmodule(),
    sym_name::String="mesh",
    location=mlir_stacktrace("mesh", @__FILE__, @__LINE__)
)

mesh(
    mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}},
    device_ids::Vector{Int64};
    sym_name::String="mesh",
    mod::MLIR.IR.Module=MLIR.IR.mmodule(),
    location=mlir_stacktrace("mesh", @__FILE__, @__LINE__)
)
Produces a Reactant.MLIR.Dialects.sdy.mesh operation with the given mesh and device_ids.
Based on the provided sym_name, a unique name is generated for the mesh in the module's SymbolTable. Users should therefore not refer to the mesh by the sym_name they passed in; instead, they should use the returned sym_name to refer to the mesh in the module.
Warning
The device_ids argument contains logical device IDs, not physical device IDs. For example, if the physical device IDs are [2, 4, 123, 293], the corresponding logical device IDs are [0, 1, 2, 3].
Returned Value
We return a NamedTuple with the following fields:
- sym_name: The unique name of the mesh in the module's SymbolTable.
- mesh_attr: An sdy::mlir::MeshAttr representing the mesh.
- mesh_op: The sdy.mesh operation.
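The following plain-Julia sketch builds arguments with the shapes expected by the second mesh method; it does not call mesh itself, which needs an MLIR module and context. The helper physical_to_logical is hypothetical and assumes a device's logical ID is its zero-based position in the list of physical devices, which matches the example in the Warning. The last check rests on a further assumption: the product of the mesh axis sizes must equal the number of devices.

```julia
# Hypothetical helper (not part of Reactant): assumes a device's logical id is
# its zero-based position in the list of physical device ids.
physical_to_logical(physical_ids::Vector{Int64}) =
    Dict(p => i - 1 for (i, p) in enumerate(physical_ids))

physical_ids = Int64[2, 4, 123, 293]
mapping = physical_to_logical(physical_ids)
device_ids = [mapping[p] for p in physical_ids]   # [0, 1, 2, 3]

# mesh_axes pairs an axis name (String or Symbol) with the axis size.
mesh_axes = [:x => 2, :y => 2]
@assert mesh_axes isa Vector{<:Pair{<:Union{String,Symbol},Int64}}

# Assumption: a 2×2 mesh needs exactly 2 * 2 = 4 devices.
@assert prod(last.(mesh_axes)) == length(device_ids)
```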
Reactant.Ops.randexp Method
randexp(
    ::Type{T},
    seed::TracedRArray{UInt64,1},
    shape;
    algorithm::String="DEFAULT",
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type T with the given shape and seed, drawn from an exponential distribution with rate 1. Returns a NamedTuple with the following fields:
- output_state: The state of the random number generator after the operation.
- output: The generated array.
Arguments
- T: The element type of the generated array.
- seed: The seed for the random number generator.
- shape: The shape of the generated array.
- algorithm: The algorithm used to generate the random numbers. Defaults to "DEFAULT"; other options include "PHILOX" and "THREE_FRY".
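To clarify what "exponential distribution with rate 1" means, here is a plain-Julia sketch using inverse-transform sampling with Julia's Random standard library. It is a textbook technique shown for illustration, not necessarily how Reactant or XLA generates the numbers, and exp_rate1_reference is a hypothetical name.

```julia
using Random

# Inverse-transform sampling: if U ~ Uniform(0, 1], then -log(U) is
# exponentially distributed with rate 1 (so its mean is 1).
function exp_rate1_reference(rng::AbstractRNG, shape::Dims)
    u = 1 .- rand(rng, Float64, shape)   # rand is in [0, 1), so u is in (0, 1]
    return -log.(u)                      # finite because u > 0
end

samples = exp_rate1_reference(MersenneTwister(0), (1000, 1000))
sample_mean = sum(samples) / length(samples)   # close to 1.0
```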
Reactant.Ops.randn Method
randn(
    ::Type{T},
    seed::TracedRArray{UInt64,1},
    shape;
    algorithm::String="DEFAULT",
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type T with the given shape and seed, drawn from a standard normal distribution with mean 0 and standard deviation 1. Returns a NamedTuple with the following fields:
- output_state: The state of the random number generator after the operation.
- output: The generated array.
Arguments
- T: The element type of the generated array.
- seed: The seed for the random number generator.
- shape: The shape of the generated array.
- algorithm: The algorithm used to generate the random numbers. Defaults to "DEFAULT"; other options include "PHILOX" and "THREE_FRY".
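As an illustration of "standard normal with mean 0 and standard deviation 1", the plain-Julia sketch below uses the Box-Muller transform on uniform samples from Julia's Random standard library. This is a textbook method chosen for clarity, not a description of Reactant's or XLA's internals, and normal_reference is a hypothetical name.

```julia
using Random

# Box-Muller transform: for independent U1 in (0, 1] and U2 in [0, 1),
# sqrt(-2 log U1) * cos(2π U2) is standard normal (mean 0, std 1).
function normal_reference(rng::AbstractRNG, shape::Dims)
    u1 = 1 .- rand(rng, Float64, shape)   # in (0, 1], so log(u1) is finite
    u2 = rand(rng, Float64, shape)        # in [0, 1)
    return sqrt.(-2 .* log.(u1)) .* cos.(2π .* u2)
end

samples = normal_reference(MersenneTwister(0), (1000, 1000))
sample_mean = sum(samples) / length(samples)                                  # close to 0
sample_std = sqrt(sum(abs2, samples .- sample_mean) / (length(samples) - 1))  # close to 1
```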
Reactant.Ops.reduce Method
reduce(
    x::TracedRArray{T},
    init_values::Union{Nothing,TracedRNumber{T}},
    dimensions::Vector{Int},
    fn::Function;
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Applies a reduction function fn along the specified dimensions of input x, starting from init_values.
Arguments
- x: The input array.
- init_values: The initial value.
- dimensions: The dimensions to reduce along.
- fn: A binary operator.
Warning
This reduction operation follows StableHLO semantics. The key difference from Julia's built-in reduce is that the function fn and the initial value init_values must form a monoid, meaning:
- fn must be an associative binary operation.
- init_values must be the identity element of fn.
This constraint ensures consistent results across all implementations.
If init_values is not the identity element of fn, the results may differ between CPU and GPU executions. For example, with fn = +:
    A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
    init_values = 2
    dimensions = [1, 3]
CPU version and Julia's reduce:
- Reduce along dimension 3 → [15 21; 18 24]
- Reduce along dimension 1 → [(33 + 2) (45 + 2)] → [35 47]
GPU version:
- Reduce along dimension 3 → [(15 + 2) (21 + 2); (18 + 2) (24 + 2)]
- Reduce along dimension 1 → [37 49]
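The CPU and GPU numbers above can be reproduced in plain Julia, without Reactant, by modelling the two evaluation orders explicitly. The helpers init_once and init_per_stage are hypothetical models of the behaviour described above, not Reactant code. With the identity element 0 both orders agree, which is exactly why the monoid requirement matters.

```julia
# Same 2×2×3 array as A in the example above.
A = cat([1 3; 2 4], [5 7; 6 8], [9 11; 10 12]; dims = 3)

# Model of the "CPU" column: init is applied once per output element.
init_once(A, init) = init .+ sum(sum(A; dims = 3); dims = 1)

# Model of the "GPU" column: each partial sum over dim 3 is seeded with init,
# then the partials are combined over dim 1.
function init_per_stage(A, init)
    stage1 = init .+ sum(A; dims = 3)   # [17 23; 20 26] when init == 2
    return sum(stage1; dims = 1)
end

cpu_like = init_once(A, 2)        # 1×2×1 array with entries 35, 47
gpu_like = init_per_stage(A, 2)   # 1×2×1 array with entries 37, 49

# With the identity element of +, both orders match an ordinary sum.
agrees = init_once(A, 0) == init_per_stage(A, 0) == sum(A; dims = (1, 3))
```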
Reactant.Ops.rng_bit_generator Method
rng_bit_generator(
    ::Type{T},
    seed::TracedRArray{UInt64,1},
    shape;
    algorithm::String="DEFAULT",
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type T with the given shape and seed, drawn from a uniform distribution between 0 and 1. Returns a NamedTuple with the following fields:
- output_state: The state of the random number generator after the operation.
- output: The generated array.
Arguments
- T: The element type of the generated array.
- seed: The seed for the random number generator.
- shape: The shape of the generated array.
- algorithm: The algorithm used to generate the random numbers. Defaults to "DEFAULT"; other options include "PHILOX" and "THREE_FRY".
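The (output_state, output) return value is presumably meant for threading generator state between calls, as with StableHLO's rng_bit_generator: the output_state of one call becomes the seed of the next. The plain-Julia sketch below mimics that convention with a hypothetical uniform_reference function. SplitMix64 stands in for "PHILOX" or "THREE_FRY" purely for illustration, and only the first element of the seed vector is used.

```julia
# SplitMix64 step: advances a 64-bit state and returns (new_state, random_bits).
# UInt64 arithmetic wraps on overflow, which this mixer relies on.
function splitmix64(state::UInt64)
    state += 0x9e3779b97f4a7c15
    z = state
    z = (z ⊻ (z >> 30)) * 0xbf58476d1ce4e5b9
    z = (z ⊻ (z >> 27)) * 0x94d049bb133111eb
    return state, z ⊻ (z >> 31)
end

# Hypothetical pure function mirroring the documented return convention.
function uniform_reference(seed::Vector{UInt64}, shape::Dims)
    state = seed[1]
    output = Array{Float64}(undef, shape)
    for i in eachindex(output)
        state, bits = splitmix64(state)
        output[i] = (bits >> 11) * 2.0^-53   # top 53 bits mapped into [0, 1)
    end
    return (output_state = [state], output = output)
end

seed = UInt64[42]
r1 = uniform_reference(seed, (2, 3))
r2 = uniform_reference(r1.output_state, (2, 3))   # thread the state forward
r1_again = uniform_reference(seed, (2, 3))        # same seed gives the same numbers
```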
Reactant.Ops.scatter_setindex Method
scatter_setindex(dest, scatter_indices, updates)
Uses MLIR.Dialects.stablehlo.scatter to set the values of dest at the indices specified by scatter_indices to the values in updates. If the indices are contiguous, it is recommended to use MLIR.Dialects.stablehlo.dynamic_update_slice directly instead.
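Mirroring the gather case, here is a plain-Julia sketch (no Reactant involved) of what a scatter computes compared with a contiguous block update. The helper scatter_reference and its index layout (one row of idx per updated element, one column per dimension) are hypothetical; the layout Reactant expects for scatter_indices may differ. The helper returns a new array rather than mutating dest, in the spirit of StableHLO operations producing new values.

```julia
# Hypothetical reference helper (not part of Reactant): row k of `idx` gives the
# (row, column) position that receives updates[k].
function scatter_reference(dest::AbstractMatrix, idx::AbstractMatrix{<:Integer},
                           updates::AbstractVector)
    out = copy(dest)                          # leave `dest` itself untouched
    for k in axes(idx, 1)
        out[idx[k, 1], idx[k, 2]] = updates[k]
    end
    return out
end

dest = zeros(Int, 3, 4)

# Arbitrary, non-contiguous positions: this is the case that needs a scatter.
scattered = scatter_reference(dest, [1 1; 3 2; 2 4], [10, 20, 30])

# A contiguous block write: a (dynamic) update slice is enough.
contiguous = copy(dest)
contiguous[2:3, 2:3] .= [1 2; 3 4]
```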
Reactant.Ops.sharding_constraint Method
sharding_constraint(
    input::Union{TracedRArray,TracedRNumber},
    sharding::Reactant.Sharding.AbstractSharding;
    location=mlir_stacktrace("sharding_constraint", @__FILE__, @__LINE__)
)
Produces a Reactant.MLIR.Dialects.sdy.sharding_constraint operation with the given input and sharding.