Reactant.Ops API

The Reactant.Ops module provides a high-level API for constructing MLIR operations without interacting directly with the individual dialects.

Currently we haven't documented all the functions in Reactant.Ops.

Reactant.Ops.gather_getindex Method
julia
gather_getindex(src, gather_indices)

Uses MLIR.Dialects.stablehlo.gather to get the values of src at the indices specified by gather_indices. If the indices are contiguous, it is recommended to use MLIR.Dialects.stablehlo.dynamic_slice directly instead.
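
As a rough mental model, the gather can be written in plain Julia. The sketch below does not call Reactant. It assumes that gather_indices is a matrix with one 1-based index tuple per row, which is an assumption about the layout rather than something stated above, and gather_reference is a hypothetical helper name.

```julia
# Plain-Julia reference for the gather semantics (hypothetical helper, not Reactant API).
# Assumption: each row of `gather_indices` holds one 1-based index tuple into `src`.
function gather_reference(src::AbstractArray, gather_indices::AbstractMatrix{<:Integer})
    @assert size(gather_indices, 2) == ndims(src)
    # Look up one element of `src` per row of indices.
    return [src[row...] for row in eachrow(gather_indices)]
end

src = [10 20 30; 40 50 60]
gather_indices = [1 1; 2 3; 1 2]   # picks src[1, 1], src[2, 3], src[1, 2]
gathered = gather_reference(src, gather_indices)
```

When the rows of gather_indices describe one contiguous block, a single slice covers the same elements, which is why the slice-based operation is the recommended choice in that case.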

source
Reactant.Ops.hlo_call Method
julia
hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}

Given an MLIR module as a string, calls the function identified by the func_name keyword argument (default "main") with the provided arguments and returns a tuple with one element for each result of the call.

julia
julia> Reactant.@jit(
          hlo_call(
              """
              module {
                func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
                  %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
                  return %0 : tensor<3xf32>
                }
              }
              """,
              Reactant.to_rarray(Float32[1, 2, 3]),
              Reactant.to_rarray(Float32[1, 2, 3]),
          )
       )
(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
source
Reactant.Ops.lu Method
julia
lu(
    x::TracedRArray{T},
    ::Type{pT}=Int32;
    location=mlir_stacktrace("lu", @__FILE__, @__LINE__)
) where {T,pT}

Computes the LU factorization of x with row-maximum pivoting and returns the factors LU, ipiv, the permutation tensor, and info.
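
For orientation, Julia's standard library exposes the same four pieces for an ordinary matrix. The sketch below uses LinearAlgebra.lu, not Reactant.Ops.lu, and the mapping of its fields to the values returned here is an analogy rather than a statement about Reactant's exact output layout.

```julia
using LinearAlgebra

A = [4.0 3.0; 6.0 3.0]
F = lu(A)              # row-maximum (partial) pivoting is the default

packed = F.factors     # L (strictly lower part) and U (upper part) stored in one matrix
ipiv = F.ipiv          # at step i, row i was swapped with row ipiv[i]
perm = F.p             # the same pivoting expressed as a permutation vector
info = F.info          # 0 means the factorization succeeded

# The factors reproduce the row-permuted input.
reconstructed = F.L * F.U
```

Here the largest entry in the first column sits in row 2, so the rows are swapped and perm is [2, 1].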

source
Reactant.Ops.mesh Method
julia
mesh(
    mesh::Reactant.Sharding.Mesh; mod::MLIR.IR.Module=MLIR.IR.mmodule(),
    sym_name::String="mesh",
    location=mlir_stacktrace("mesh", @__FILE__, @__LINE__)
)
mesh(
    mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}},
    logical_device_ids::Vector{Int64};
    sym_name::String="mesh",
    mod::MLIR.IR.Module=MLIR.IR.mmodule(),
    location=mlir_stacktrace("mesh", @__FILE__, @__LINE__)
)

Produces a Reactant.MLIR.Dialects.sdy.mesh operation with the given mesh and logical_device_ids.

Based on the provided sym_name, we generate a unique name for the mesh in the module's SymbolTable. Note that users shouldn't use this sym_name directly; instead, they should use the returned sym_name to refer to the mesh in the module.

Warning

The logical_device_ids argument contains the logical device ids, not the physical device ids. For example, if the physical device ids are [2, 4, 123, 293], the corresponding logical device ids are [0, 1, 2, 3].
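
The relationship in this example can be sketched in plain Julia. The snippet assumes, based only on the example above, that logical ids are the zero-based positions of the devices in the list; the variable names are illustrative and not part of the API.

```julia
# Physical ids as reported by the runtime (values taken from the example above).
physical_device_ids = [2, 4, 123, 293]

# Assumption: logical ids are the zero-based positions in that list.
logical_device_ids = collect(0:(length(physical_device_ids) - 1))

# Lookup table from logical id back to physical id, for illustration only.
logical_to_physical = Dict(zip(logical_device_ids, physical_device_ids))
```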

Returned Value

We return a NamedTuple with the following fields:

  • sym_name: The unique name of the mesh in the module's SymbolTable.

  • mesh_attr: sdy::mlir::MeshAttr representing the mesh.

  • mesh_op: The sdy.mesh operation.

source
Reactant.Ops.randexp Method
julia
randexp(
    ::Type{T},
    seed::TracedRArray{UInt64,1},
    shape;
    algorithm::String="DEFAULT",
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type T with the given shape and seed from an exponential distribution with rate 1. Returns a NamedTuple with the following fields:

  • output_state: The state of the random number generator after the operation.

  • output: The generated array.

Arguments

  • T: The element type of the generated array.

  • seed: The seed for the random number generator.

  • shape: The shape of the generated array.

  • algorithm: The algorithm to use for generating the random numbers. Defaults to "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
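
One standard way to obtain samples from an exponential distribution with rate 1 is inverse-transform sampling: if u is uniform on [0, 1), then -log(1 - u) is exponentially distributed. The sketch below illustrates this with Julia's Random standard library. It does not call Reactant and is not necessarily how randexp is implemented.

```julia
using Random

rng = MersenneTwister(0)
u = rand(rng, Float32, 1000)   # uniform samples in [0, 1)

# Inverse-transform sampling: -log(1 - u), written with log1p for accuracy near 0.
samples = -log1p.(-u)

sample_mean = sum(samples) / length(samples)   # should be close to 1 for rate 1
```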

source
Reactant.Ops.randn Method
julia
randn(
    ::Type{T},
    seed::TracedRArray{UInt64,1},
    shape;
    algorithm::String="DEFAULT",
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type T with the given shape and seed from a standard normal distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following fields:

  • output_state: The state of the random number generator after the operation.

  • output: The generated array.

Arguments

  • T: The element type of the generated array.

  • seed: The seed for the random number generator.

  • shape: The shape of the generated array.

  • algorithm: The algorithm to use for generating the random numbers. Defaults to "DEFAULT". Other options include "PHILOX" and "THREE_FRY".

source
Reactant.Ops.reduce Method
julia
reduce(
    x::TracedRArray{T},
    init_values::TracedRNumber{T},
    dimensions::Vector{Int},
    fn::Function,
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Applies a reduction function fn along the specified dimensions of input x, starting from init_values.

Arguments

  • x: The input array.

  • init_values: The initial value.

  • dimensions: The dimensions to reduce along.

  • fn: A binary operator.

Warning

This reduction operation follows StableHLO semantics. The key difference between this operation and Julia's built-in reduce is explained below:

  • The function fn and the initial value init_values must form a monoid, meaning:

    • fn must be an associative binary operation.

    • init_values must be the identity element associated with fn.

  • This constraint ensures consistent results across all implementations.

If init_values is not the identity element of fn, the results may vary between CPU and GPU executions. For example:

julia
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
init_values = 2
dimensions = [1, 3]

  • CPU version & Julia's reduce:

    • Reduce along dimension 3 → [(15) (21); (18) (24)]

    • Reduce along dimension 1 → [(33 + 2) (45 + 2)] → [35 47]

  • GPU version:

    • Reduce along dimension 3 → [(15 + 2) (21 + 2); (18 + 2) (24 + 2)]

    • Reduce along dimension 1 → [37 49]
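
The discrepancy comes from where init_values is folded in. The plain-Julia sketch below reproduces both results for the sum by performing the reduction in two passes. It does not call Reactant.Ops.reduce, and the variable names are illustrative.

```julia
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]   # 2×2×3 array
init = 2

# First pass: sum over dimension 3, giving [15 21; 18 24].
partial = sum(A; dims=3)

# init folded in once, after the final pass.
once = sum(partial; dims=1) .+ init

# init folded into every partial result before the final pass.
per_partial = sum(partial .+ init; dims=1)

# With the identity element of + (zero), both orders agree.
identity_once = sum(partial; dims=1) .+ 0
identity_per_partial = sum(partial .+ 0; dims=1)
```

Because 2 is not the identity element of +, the number of times it enters the sum changes the answer. With 0 the evaluation order no longer matters.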

source
Reactant.Ops.rng_bit_generator Method
julia
rng_bit_generator(
    ::Type{T},
    seed::TracedRArray{UInt64,1},
    shape;
    algorithm::String="DEFAULT",
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type T with the given shape and seed from a uniform random distribution between 0 and 1 (for floating point types). Returns a NamedTuple with the following fields:

  • output_state: The state of the random number generator after the operation.

  • output: The generated array.

Arguments

  • T: The element type of the generated array.

  • seed: The seed for the random number generator.

  • shape: The shape of the generated array.

  • algorithm: The algorithm to use for generating the random numbers. Defaults to "DEFAULT". Other options include "PHILOX" and "THREE_FRY".

source
Reactant.Ops.scatter_setindex Method
julia
scatter_setindex(dest, scatter_indices, updates)

Uses MLIR.Dialects.stablehlo.scatter to set the values of dest at the indices specified by scatter_indices to the values in updates. If the indices are contiguous it is recommended to directly use MLIR.Dialects.stablehlo.dynamic_update_slice instead.
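
The scatter can be pictured with a plain-Julia reference that does not call Reactant. The sketch assumes that scatter_indices is a matrix with one 1-based index tuple per row and that updates holds one value per row; both are assumptions about the layout, and scatter_reference is a hypothetical helper name.

```julia
# Plain-Julia reference for the scatter semantics (hypothetical helper, not Reactant API).
# Assumption: row i of `scatter_indices` is the index tuple that receives `updates[i]`.
function scatter_reference(
    dest::AbstractArray, scatter_indices::AbstractMatrix{<:Integer}, updates::AbstractVector
)
    @assert size(scatter_indices, 2) == ndims(dest)
    @assert size(scatter_indices, 1) == length(updates)
    out = copy(dest)   # work on a copy so the input stays untouched
    for (row, value) in zip(eachrow(scatter_indices), updates)
        out[row...] = value
    end
    return out
end

dest = zeros(Int, 2, 3)
scatter_indices = [1 1; 2 3]   # targets dest[1, 1] and dest[2, 3]
updates = [7, 9]
scattered = scatter_reference(dest, scatter_indices, updates)
```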

source
Reactant.Ops.sharding_constraint Method
julia
sharding_constraint(
    input::Union{TracedRArray,TracedRNumber},
    sharding::Reactant.Sharding.AbstractSharding;
    location=mlir_stacktrace("sharding_constraint", @__FILE__, @__LINE__)
)

Produces a Reactant.MLIR.Dialects.sdy.sharding_constraint operation with the given input and sharding.

source
Reactant.Ops.@opcall Macro
julia
@opcall fn(args...; kwargs...)

This call expands to Reactant.Ops.fn(args...; kwargs..., location) with the location of the call site. This allows better debug info to be propagated about the source location of different operations. It is recommended to use this macro when calling any function in Reactant.Ops.<function name>.
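
The mechanism can be illustrated with a toy macro that captures its call site through Julia's __source__ variable. Everything in the sketch is hypothetical: ToyOps, its add function, and its @opcall are stand-ins written for this example, they handle positional arguments only, and they are not Reactant's implementation.

```julia
module ToyOps

# Stand-in for an op: returns its result together with the location it was given.
add(x, y; location=nothing) = (value=x + y, location=location)

macro opcall(expr)
    Meta.isexpr(expr, :call) || error("@opcall expects a function call")
    fn = expr.args[1]
    fn isa Symbol || error("@opcall expects a plain function name")
    opfn = GlobalRef(@__MODULE__, fn)      # resolve the name inside ToyOps
    args = map(esc, expr.args[2:end])      # evaluate arguments in the caller's scope
    # __source__ is the LineNumberNode of the macro call site.
    loc = (file=string(__source__.file), line=__source__.line)
    return :($opfn($(args...); location=$loc))
end

end # module ToyOps

# Expands to ToyOps.add(1, 2; location=<this file and line>).
result = ToyOps.@opcall add(1, 2)
```

The caller never writes the location by hand, yet the callee receives the file and line of the macro invocation rather than of its own definition.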

source