Reactant.Ops API
The Reactant.Ops module provides a high-level API for constructing MLIR operations without interacting directly with the individual dialects.
Currently, not all functions in Reactant.Ops are documented.
Reactant.Ops.gather_getindex Method
gather_getindex(src, gather_indices)
Uses MLIR.Dialects.stablehlo.gather to get the values of src at the indices specified by gather_indices. If the indices are contiguous, it is recommended to use MLIR.Dialects.stablehlo.dynamic_slice directly instead.
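As a rough illustration of the difference between a gather and a contiguous slice, the following plain-Julia sketch (no Reactant, no MLIR) reads values at arbitrary positions and at a run of consecutive positions. The is_consecutive helper is hypothetical and not part of Reactant; it only shows the kind of check that decides whether a dynamic slice would suffice.

```julia
# Plain-Julia analogue of gather vs. slice semantics (no Reactant involved).

# Hypothetical helper: true when a 1-D index list is a run of consecutive
# positions, i.e. the case where a dynamic slice would be enough.
is_consecutive(idx::AbstractVector{<:Integer}) =
    isempty(idx) || idx == first(idx):(first(idx) + length(idx) - 1)

src = [10, 20, 30, 40, 50]

# Arbitrary (non-contiguous, reordered) positions: the gather case.
gather_indices = [4, 1, 5]
gathered = src[gather_indices]                             # [40, 10, 50]

# Consecutive positions: a slice of length 3 starting at index 2 suffices.
slice_indices = [2, 3, 4]
sliced = src[first(slice_indices):last(slice_indices)]     # [20, 30, 40]
```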
Reactant.Ops.hlo_call Method
hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
Given an MLIR module as a string, calls the function identified by the func_name keyword argument (default "main") with the provided arguments and returns a tuple containing each result of the call.
julia> Reactant.@jit(
           hlo_call(
               """
               module {
                 func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
                   %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
                   return %0 : tensor<3xf32>
                 }
               }
               """,
               Reactant.to_rarray(Float32[1, 2, 3]),
               Reactant.to_rarray(Float32[1, 2, 3]),
           )
       )
(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
Reactant.Ops.lu Method
lu(
x::TracedRArray{T},
::Type{pT}=Int32;
location=mlir_stacktrace("lu", @__FILE__, @__LINE__)
) where {T,pT}
Compute the row-maximum pivoted LU factorization of x and return the factors LU, ipiv, the permutation tensor, and info.
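To show what the four returned quantities usually mean, the sketch below applies Julia's standard LinearAlgebra.lu to an ordinary matrix instead of a traced one. It assumes Reactant's outputs follow the same LAPACK-style conventions (packed LU factors, ipiv as successive row swaps, a permutation vector, and info == 0 on success); the integer type (pT defaults to Int32 in Reactant) and details such as the indexing base may differ.

```julia
using LinearAlgebra

# A small nonsingular matrix whose first column forces a row swap:
# the largest-magnitude entry in column 1 is in row 3.
A = [1.0 2.0 3.0;
     4.0 5.0 6.0;
     7.0 8.0 10.0]

F = lu(A)                   # row (partial) pivoting is Julia's default

lu_factors = F.factors      # L (unit diagonal, strictly lower part) and U packed together
ipiv = F.ipiv               # LAPACK-style pivots: row i was swapped with row ipiv[i]
perm = F.p                  # the same pivoting expressed as a permutation vector
lu_info = F.info            # 0 on success

# The factors reproduce the row-permuted input.
reconstructs = F.L * F.U ≈ A[perm, :]
```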
Reactant.Ops.mesh Method
mesh(
mesh::Reactant.Sharding.Mesh; mod::MLIR.IR.Module=MLIR.IR.mmodule(),
sym_name::String="mesh",
location=mlir_stacktrace("mesh", @__FILE__, @__LINE__)
)
mesh(
mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}},
logical_device_ids::Vector{Int64};
sym_name::String="mesh",
mod::MLIR.IR.Module=MLIR.IR.mmodule(),
location=mlir_stacktrace("mesh", @__FILE__, @__LINE__)
)
Produces a Reactant.MLIR.Dialects.sdy.mesh operation with the given mesh and logical_device_ids.
Based on the provided sym_name, we generate a unique name for the mesh in the module's SymbolTable. Users should not use this sym_name directly; instead, they should use the returned sym_name to refer to the mesh in the module.
Warning
The logical_device_ids argument contains logical device IDs, not physical device IDs. For example, if the physical device IDs are [2, 4, 123, 293], the corresponding logical device IDs are [0, 1, 2, 3].
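The relationship between physical and logical device IDs can be sketched in plain Julia. The physical IDs are the example values from the warning above; the axis names "x" and "y" and the 2×2 layout are assumptions chosen only to match the shape of the mesh_axes argument in the second signature. The check that the product of the axis sizes equals the number of devices is likewise an assumption about what a valid mesh requires.

```julia
# Physical device IDs (the example values from the warning).
physical_ids = [2, 4, 123, 293]

# Logical IDs are simply the positions 0, 1, 2, ... in the device list.
logical_ids = collect(0:length(physical_ids) - 1)    # [0, 1, 2, 3]

# Lookup table from a physical ID to its logical ID.
phys_to_logical = Dict(zip(physical_ids, logical_ids))

# A 2×2 mesh described in the `mesh_axes` form of the second signature
# (axis names are hypothetical).
mesh_axes = ["x" => 2, "y" => 2]

# Assumption: the product of the axis sizes should equal the number of devices.
mesh_size = prod(last.(mesh_axes))
```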
Returned Value
We return a NamedTuple with the following fields:
- sym_name: The unique name of the mesh in the module's SymbolTable.
- mesh_attr: The sdy::mlir::MeshAttr representing the mesh.
- mesh_op: The sdy.mesh operation.
Reactant.Ops.randexp Method
randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array with element type T and the given shape from an exponential distribution with rate 1, using the given seed. Returns a NamedTuple with the following fields:
- output_state: The state of the random number generator after the operation.
- output: The generated array.
Arguments
- T: The element type of the generated array.
- seed: The seed for the random number generator.
- shape: The shape of the generated array.
- algorithm: The algorithm used to generate the random numbers. Defaults to "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
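One standard way to turn uniform samples into Exponential(1) samples is the inverse-CDF transform x = -log(1 - u). The plain-Julia sketch below illustrates that transform with Julia's own RNG; it is not necessarily how Reactant or StableHLO implements randexp, and the function name exp_from_uniform is hypothetical.

```julia
using Random

# Inverse-CDF sampling for an exponential distribution with rate 1:
# if U is uniform on [0, 1), then -log(1 - U) follows Exponential(1).
function exp_from_uniform(rng::AbstractRNG, shape::Tuple)
    u = rand(rng, Float64, shape)   # uniform samples in [0, 1)
    return @. -log1p(-u)            # -log(1 - u); 1 - u is in (0, 1], so no log(0)
end

samples = exp_from_uniform(MersenneTwister(1234), (100_000,))
sample_mean = sum(samples) / length(samples)   # should be close to 1, the mean of Exp(1)
```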
Reactant.Ops.randn Method
randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array with element type T and the given shape from a standard normal distribution (mean 0, standard deviation 1), using the given seed. Returns a NamedTuple with the following fields:
- output_state: The state of the random number generator after the operation.
- output: The generated array.
Arguments
- T: The element type of the generated array.
- seed: The seed for the random number generator.
- shape: The shape of the generated array.
- algorithm: The algorithm used to generate the random numbers. Defaults to "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
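Standard normal samples can likewise be derived from uniform ones, for example with the Box-Muller transform. This plain-Julia sketch uses Julia's RNG and a hypothetical normal_from_uniform function; Reactant's randn may use a different transform internally.

```julia
using Random

# Box-Muller transform: two independent uniforms give two independent
# standard normal samples.
function normal_from_uniform(rng::AbstractRNG, shape::Tuple)
    n = prod(shape)
    out = Vector{Float64}(undef, n)
    for i in 1:2:n
        u1 = 1.0 - rand(rng)          # in (0, 1], so log(u1) is finite
        u2 = rand(rng)                # in [0, 1)
        r = sqrt(-2.0 * log(u1))
        out[i] = r * cospi(2 * u2)    # r * cos(2π * u2)
        if i + 1 <= n                 # odd n: the last sine sample is dropped
            out[i + 1] = r * sinpi(2 * u2)
        end
    end
    return reshape(out, shape)
end

x = normal_from_uniform(MersenneTwister(42), (100_000,))
m = sum(x) / length(x)                      # close to 0
v = sum(abs2, x .- m) / (length(x) - 1)     # close to 1
```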
Reactant.Ops.reduce Method
reduce(
x::TracedRArray{T},
init_values::TracedRNumber{T},
dimensions::Vector{Int},
fn::Function,
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Applies a reduction function fn along the specified dimensions of input x, starting from init_values.
Arguments
- x: The input array.
- init_values: The initial value.
- dimensions: The dimensions to reduce along.
- fn: A binary operator.
Warning
This reduction operation follows StableHLO semantics. The key difference between this operation and Julia's built-in reduce is explained below:
The function fn and the initial value init_values must form a monoid, meaning:
- fn must be an associative binary operation.
- init_values must be the identity element associated with fn.
This constraint ensures consistent results across all implementations.
If init_values is not the identity element of fn, the results may vary between CPU and GPU executions. For example:
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
init_values = 2
dimensions = [1, 3]
CPU version and Julia's reduce:
- Reduce along dimension 3 → [15 21; 18 24]
- Reduce along dimension 1 → [(33 + 2) (45 + 2)] → [35 47]
GPU version:
- Reduce along dimension 3 → [(15 + 2) (21 + 2); (18 + 2) (24 + 2)]
- Reduce along dimension 1 → [37 49]
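The CPU/GPU discrepancy above can be reproduced in plain Julia by modelling two places where an implementation might apply init_values. Both functions are hypothetical models of those strategies, not Reactant or XLA code: one applies init once per output element, the other applies it to every partial result of the first reduction step. When init is the identity of fn (0 for +, 1 for *), the two strategies agree, which is why the monoid requirement matters.

```julia
# The example from the text: a 2×2×3 array reduced over dimensions 1 and 3.
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]

# Hypothetical strategy 1: apply `init` once per output element.
function reduce_init_once(op, A, init)
    s = dropdims(reduce(op, A; dims=3); dims=3)   # 2×2: [15 21; 18 24] for +
    s = reduce(op, s; dims=1)                     # 1×2: [33 45] for +
    return map(x -> op(x, init), s)
end

# Hypothetical strategy 2: apply `init` to every partial result of the
# first reduction step, as a split or parallel reduction might.
function reduce_init_per_partial(op, A, init)
    s = dropdims(reduce(op, A; dims=3); dims=3)
    s = map(x -> op(x, init), s)                  # [17 23; 20 26] for +, init = 2
    return reduce(op, s; dims=1)
end

result_once = reduce_init_once(+, A, 2)           # [35 47]
result_partial = reduce_init_per_partial(+, A, 2) # [37 49]
```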
Reactant.Ops.rng_bit_generator Method
rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array with element type T and the given shape from a uniform distribution between 0 and 1 (for floating-point types), using the given seed. Returns a NamedTuple with the following fields:
- output_state: The state of the random number generator after the operation.
- output: The generated array.
Arguments
- T: The element type of the generated array.
- seed: The seed for the random number generator.
- shape: The shape of the generated array.
- algorithm: The algorithm used to generate the random numbers. Defaults to "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
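The random-number functions return the generator state alongside the output, so the caller threads the state into the next call. The sketch below illustrates that pattern in plain Julia using SplitMix64, a simple generator swapped in purely for illustration; it is not one of the "DEFAULT", "PHILOX", or "THREE_FRY" algorithms, and uniform_with_state is a hypothetical name. Keeping the top 53 bits of each 64-bit output gives a Float64 in [0, 1).

```julia
# One SplitMix64 step: returns the advanced state and 64 output bits.
function splitmix64(state::UInt64)
    state += 0x9e3779b97f4a7c15              # UInt64 arithmetic wraps around
    z = state
    z = (z ⊻ (z >> 30)) * 0xbf58476d1ce4e5b9
    z = (z ⊻ (z >> 27)) * 0x94d049bb133111eb
    return state, z ⊻ (z >> 31)
end

# Hypothetical generator in the same shape as the ops above:
# state in, (output_state, output) out.
function uniform_with_state(state::UInt64, shape::Tuple)
    out = Array{Float64}(undef, shape)
    for i in eachindex(out)
        state, bits = splitmix64(state)
        out[i] = (bits >> 11) * 2.0^-53      # top 53 bits → a float in [0, 1)
    end
    return (output_state = state, output = out)
end

r1 = uniform_with_state(UInt64(42), (2, 3))
r2 = uniform_with_state(r1.output_state, (2, 3))   # thread the state forward
```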
Reactant.Ops.scatter_setindex Method
scatter_setindex(dest, scatter_indices, updates)
Uses MLIR.Dialects.stablehlo.scatter to set the values of dest at the indices specified by scatter_indices to the values in updates. If the indices are contiguous, it is recommended to use MLIR.Dialects.stablehlo.dynamic_update_slice directly instead.
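In plain Julia, the effect of a scatter can be sketched as writing updates at arbitrary positions of a copy of dest. The copy mirrors StableHLO's scatter, which produces a new tensor rather than mutating its operand; scatter_set is a hypothetical helper, and it ignores details such as how duplicate indices are combined.

```julia
# Functional "scatter": return a new array instead of mutating `dest`.
function scatter_set(dest::AbstractVector, scatter_indices::AbstractVector{<:Integer},
                     updates::AbstractVector)
    length(scatter_indices) == length(updates) ||
        throw(ArgumentError("need exactly one update per index"))
    out = copy(dest)
    out[scatter_indices] = updates    # positions may be arbitrary, not just a contiguous run
    return out
end

dest = zeros(Int, 5)
result = scatter_set(dest, [5, 2], [7, 9])   # [0, 9, 0, 0, 7]
```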
Reactant.Ops.sharding_constraint Method
sharding_constraint(
input::Union{TracedRArray,TracedRNumber},
sharding::Reactant.Sharding.AbstractSharding;
location=mlir_stacktrace("sharding_constraint", @__FILE__, @__LINE__)
)
Produces a Reactant.MLIR.Dialects.sdy.sharding_constraint operation with the given input and sharding.
Reactant.Ops.@opcall Macro
@opcall fn(args...; kwargs...)
This call is expanded to Reactant.Ops.fn(args...; kwargs..., location), where location is the location of the call site. This propagates better debug information about the source location of each operation. We recommend using this macro whenever calling a function Reactant.Ops.<function name>.
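The idea behind @opcall, capturing the call site and passing it on as a location keyword, can be sketched with an ordinary Julia macro. @withloc below is hypothetical and is not Reactant's implementation: it records the call site as a (file, line) tuple rather than an MLIR location, and it only handles plain function calls.

```julia
# Hypothetical macro: rewrite `f(args...; kwargs...)` into
# `f(args...; kwargs..., location=(file, line))` using the macro's call site.
macro withloc(call)
    (call isa Expr && call.head === :call) ||
        error("@withloc expects a function call")
    loc = (string(__source__.file), __source__.line)      # call-site file and line
    newcall = copy(call)
    kw = Expr(:kw, :location, loc)
    if length(newcall.args) >= 2 && newcall.args[2] isa Expr &&
       newcall.args[2].head === :parameters
        push!(newcall.args[2].args, kw)                   # append to existing `; kwargs`
    else
        insert!(newcall.args, 2, Expr(:parameters, kw))   # create a `; location=...` block
    end
    return esc(newcall)
end

# A stand-in "op" that just reports what it received.
scaled(x; scale=1, location=nothing) = (x * scale, location)

a = @withloc scaled(2)
b = @withloc scaled(2; scale=3)
```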