Reactant.Ops API

The Reactant.Ops module provides a high-level API for constructing MLIR operations without having to interact directly with the individual dialects.

Currently we haven't documented all the functions in Reactant.Ops.

Reactant.Ops.gather_getindex Method
julia
gather_getindex(src, gather_indices)

Uses MLIR.Dialects.stablehlo.gather to get the values of src at the indices specified by gather_indices. If the indices are contiguous, it is recommended to use MLIR.Dialects.stablehlo.dynamic_slice directly instead.
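The following plain-Julia CPU reference makes the gather semantics concrete. It is a sketch and not Reactant's implementation. gather_reference is a hypothetical name, and the sketch assumes gather_indices is an M×N integer matrix whose i-th row is a 1-based index into an N-dimensional src. Check the actual method before relying on this layout.

```julia
# Hypothetical CPU reference for gather semantics (not Reactant code).
# Assumption: row i of `gather_indices` is a 1-based N-dimensional index into `src`,
# so the result has one element per row.
function gather_reference(src::AbstractArray{T,N}, gather_indices::AbstractMatrix{<:Integer}) where {T,N}
    @assert size(gather_indices, 2) == N "each row must index all $N dimensions"
    return [src[CartesianIndex(Tuple(gather_indices[i, :]))] for i in axes(gather_indices, 1)]
end

src = reshape(collect(1:12), 3, 4)   # column-major: src[i, j] == i + 3(j - 1)
idx = [1 1; 3 2; 2 4]                # picks src[1, 1], src[3, 2], src[2, 4]
gather_reference(src, idx)           # → [1, 6, 11]
```

When the rows of idx describe a contiguous block, a single slice expresses the same access more cheaply. This is the reason the docstring recommends dynamic_slice for that case.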

Reactant.Ops.hlo_call Method
julia
Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}

Given an MLIR module as a string, calls the function identified by the func_name keyword argument (default "main") with the provided arguments and returns a tuple containing one element for each result of the call.

julia
julia> Reactant.@jit(
          Ops.hlo_call(
              """
              module {
                func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
                  %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
                  return %0 : tensor<3xf32>
                }
              }
              """,
              Reactant.to_rarray(Float32[1, 2, 3]),
              Reactant.to_rarray(Float32[1, 2, 3]),
          )
       )
(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
Reactant.Ops.mesh Method
julia
mesh(
    mesh::Reactant.Sharding.Mesh; mod::MLIR.IR.Module=MLIR.IR.mmodule(),
    sym_name::String="mesh",
    location=mlir_stacktrace("mesh", @__FILE__, @__LINE__)
)
mesh(
    mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}},
    device_ids::Vector{Int64};
    sym_name::String="mesh",
    mod::MLIR.IR.Module=MLIR.IR.mmodule(),
    location=mlir_stacktrace("mesh", @__FILE__, @__LINE__)
)

Produces a Reactant.MLIR.Dialects.sdy.mesh operation with the given mesh and device_ids.

Based on the provided sym_name, we generate a unique name for the mesh in the module's SymbolTable. Users shouldn't refer to the mesh by the sym_name they passed in. Instead, they should use the sym_name returned by this function to refer to the mesh in the module.

Warning

The device_ids argument contains logical device ids, not physical device ids. For example, if the physical device ids are [2, 4, 123, 293], the corresponding logical device ids are [0, 1, 2, 3].
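The plain-Julia sketch below builds the arguments for the second mesh method from the physical device ids in the warning. It makes two assumptions. First, a device's logical id is its 0-based position in the list of physical devices, as in the example. Second, the product of the mesh axis sizes must equal the number of devices. physical_of is a hypothetical helper, and the Ops.mesh call itself is omitted.

```julia
# Physical device ids from the example in the warning above.
physical_ids = [2, 4, 123, 293]

# Assumption: a device's logical id is its 0-based position in physical_ids.
device_ids = collect(0:length(physical_ids) - 1)   # [0, 1, 2, 3]

# A 2×2 mesh with axes "x" and "y"; we assume the axis sizes must multiply
# to the number of devices.
mesh_axes = ["x" => 2, "y" => 2]
@assert prod(last.(mesh_axes)) == length(device_ids)

# Hypothetical helper: recover the physical id behind a logical id
# (+ 1 because Julia arrays are 1-based).
physical_of(logical_id) = physical_ids[logical_id + 1]
physical_of(2)   # → 123
```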

Returned Value

We return a NamedTuple with the following fields:

  • sym_name: The unique name of the mesh in the module's SymbolTable.

  • mesh_attr: sdy::mlir::MeshAttr representing the mesh.

  • mesh_op: The sdy.mesh operation.

Reactant.Ops.randexp Method
julia
randexp(
    ::Type{T},
    seed::TracedRArray{UInt64,1},
    shape;
    algorithm::String="DEFAULT",
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type T with the given shape and seed from an exponential distribution with rate 1. Returns a NamedTuple with the following fields:

  • output_state: The state of the random number generator after the operation.

  • output: The generated array.

Arguments

  • T: The type of the generated array.

  • seed: The seed for the random number generator.

  • shape: The shape of the generated array.

  • algorithm: The algorithm to use for generating the random numbers. Defaults to "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
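Inverse-transform sampling is one standard way to turn uniform samples into Exponential(1) samples. If U is uniform on [0, 1), then -log(1 - U) is exponentially distributed with rate 1. The plain-Julia sketch below only illustrates the target distribution. randexp_sketch is a hypothetical name, and nothing here claims to describe how Ops.randexp works internally.

```julia
using Random, Statistics

# Inverse-transform sampling: U ~ Uniform[0, 1) implies -log(1 - U) ~ Exponential(rate = 1).
# Illustrative only; not necessarily how Reactant.Ops.randexp is implemented.
function randexp_sketch(rng::AbstractRNG, ::Type{T}, shape::Dims) where {T<:AbstractFloat}
    u = rand(rng, T, shape)   # uniform samples in [0, 1)
    return @. -log1p(-u)      # -log(1 - u), finite because 1 - u > 0
end

x = randexp_sketch(MersenneTwister(0), Float64, (10_000,))
mean(x)   # close to 1, the mean of Exponential(1)
```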

Reactant.Ops.randn Method
julia
randn(
    ::Type{T},
    seed::TracedRArray{UInt64,1},
    shape;
    algorithm::String="DEFAULT",
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type T with the given shape and seed from a standard normal distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following fields:

  • output_state: The state of the random number generator after the operation.

  • output: The generated array.

Arguments

  • T: The type of the generated array.

  • seed: The seed for the random number generator.

  • shape: The shape of the generated array.

  • algorithm: The algorithm to use for generating the random numbers. Defaults to "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
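The Box-Muller transform is one textbook way to produce standard normal samples from uniform ones. Given independent u1 in (0, 1] and u2 in [0, 1), the value sqrt(-2 log(u1)) cos(2π u2) has mean 0 and standard deviation 1. The plain-Julia sketch below is illustrative only. randn_sketch is a hypothetical name, and Ops.randn may use a different method.

```julia
using Random, Statistics

# Box-Muller transform: z = sqrt(-2 log(u1)) * cos(2π u2) ~ Normal(0, 1).
# Illustrative only; not necessarily how Reactant.Ops.randn is implemented.
function randn_sketch(rng::AbstractRNG, ::Type{T}, shape::Dims) where {T<:AbstractFloat}
    u1 = 1 .- rand(rng, T, shape)   # in (0, 1], so log(u1) is finite
    u2 = rand(rng, T, shape)        # in [0, 1)
    return @. sqrt(-2 * log(u1)) * cospi(2 * u2)   # cospi(x) == cos(π * x)
end

z = randn_sketch(MersenneTwister(0), Float64, (10_000,))
(mean(z), std(z))   # close to (0, 1)
```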

Reactant.Ops.reduce Method
julia
reduce(
    x::TracedRArray{T},
    init_values::Union{Nothing,TracedRNumber{T}},
    dimensions::Vector{Int},
    fn::Function;
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Applies a reduction function fn along the specified dimensions of input x, starting from init_values.

Arguments

  • x: The input array.

  • init_values: The initial value.

  • dimensions: The dimensions to reduce along.

  • fn: A binary operator.

Warning

This reduction operation follows StableHLO semantics. The key difference between this operation and Julia's built-in reduce is explained below:

  • The function fn and the initial value init_values must form a monoid, meaning:

    • fn must be an associative binary operation.

    • init_values must be the identity element associated with fn.

  • This constraint ensures consistent results across all implementations.

If init_values is not the identity element of fn, the results may vary between CPU and GPU executions. For example:

julia
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
init_values = 2
dimensions = [1, 3]
  • CPU version & Julia's reduce:

    • Reduce along dimension 3 → [(15) (21); (18) (24)]

    • Reduce along dimension 1 → [(33 + 2) (45 + 2)] → [35 47]

  • GPU version:

    • Reduce along dimension 3 → [(15 + 2) (21 + 2); (18 + 2) (24 + 2)]

    • Reduce along dimension 1 → [37 49]
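You can reproduce both outcomes of the example in plain Julia by modeling where init is folded in. init_once and init_per_partial are hypothetical helpers. They model only the arithmetic of the example, not what any particular backend does internally. With init = 0, the identity element of +, both models agree, which is the guarantee the monoid requirement provides.

```julia
A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]   # 2×2×3 array from the example above

# Model 1: fold `init` into each output element exactly once.
init_once(A, init) = vec(sum(A; dims=(1, 3)) .+ init)

# Model 2 (hypothetical): a two-stage reduction that folds `init` into every
# partial result of the first stage (dimension 3), then combines the partials.
function init_per_partial(A, init)
    partial = dropdims(sum(A; dims=3); dims=3) .+ init   # 2×2 partial results
    return vec(sum(partial; dims=1))
end

init_once(A, 2), init_per_partial(A, 2)   # ([35, 47], [37, 49]): results differ
init_once(A, 0), init_per_partial(A, 0)   # ([33, 45], [33, 45]): identity element, results agree
```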

Reactant.Ops.rng_bit_generator Method
julia
rng_bit_generator(
    ::Type{T},
    seed::TracedRArray{UInt64,1},
    shape;
    algorithm::String="DEFAULT",
    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Generate a random array of type T with the given shape and seed from a uniform random distribution between 0 and 1. Returns a NamedTuple with the following fields:

  • output_state: The state of the random number generator after the operation.

  • output: The generated array.

Arguments

  • T: The type of the generated array.

  • seed: The seed for the random number generator.

  • shape: The shape of the generated array.

  • algorithm: The algorithm to use for generating the random numbers. Defaults to "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
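A common way to map raw 64-bit random words to floating-point values in [0, 1) is to keep the top 53 bits, the width of a Float64 significand, and scale them by 2^-53. In the sketch below, MersenneTwister stands in for the generator's raw output. bits_to_unit_float is a hypothetical helper, and this is not necessarily how Ops.rng_bit_generator converts bits to floats.

```julia
using Random

# Keep the top 53 bits of a 64-bit word and scale by 2^-53, giving a Float64 in [0, 1).
# Illustrative only; not necessarily how Reactant.Ops.rng_bit_generator does it.
bits_to_unit_float(b::UInt64) = Float64(b >> 11) * 0x1p-53

bits = rand(MersenneTwister(0), UInt64, 5)   # stand-in for raw generator output
u = bits_to_unit_float.(bits)                # five values in [0, 1)
```

Because at most 53 bits survive, every result is exactly representable. The largest possible value is 1 - 2^-53, so 1.0 itself is never produced.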

Reactant.Ops.scatter_setindex Method
julia
scatter_setindex(dest, scatter_indices, updates)

Uses MLIR.Dialects.stablehlo.scatter to set the values of dest at the indices specified by scatter_indices to the values in updates. If the indices are contiguous, it is recommended to use MLIR.Dialects.stablehlo.dynamic_update_slice directly instead.
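As with gather, a plain-Julia CPU reference clarifies the scatter semantics. This is a sketch, not Reactant's implementation. scatter_reference is a hypothetical name, and the sketch assumes scatter_indices is an M×N integer matrix of 1-based indices into dest, with one update per row. Like a StableHLO op, it returns a new array instead of mutating dest.

```julia
# Hypothetical CPU reference for scatter semantics (not Reactant code).
# Assumption: row i of `scatter_indices` is a 1-based index into `dest`,
# and updates[i] is written at that position in a copy of `dest`.
function scatter_reference(dest::AbstractArray{T,N}, scatter_indices::AbstractMatrix{<:Integer},
                           updates::AbstractVector) where {T,N}
    @assert size(scatter_indices) == (length(updates), N)
    out = copy(dest)   # leave `dest` untouched
    for i in axes(scatter_indices, 1)
        out[CartesianIndex(Tuple(scatter_indices[i, :]))] = updates[i]
    end
    return out
end

dest = zeros(Int, 2, 3)
scatter_reference(dest, [1 1; 2 3], [10, 20])   # → [10 0 0; 0 0 20]
```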

Reactant.Ops.sharding_constraint Method
julia
sharding_constraint(
    input::Union{TracedRArray,TracedRNumber},
    sharding::Reactant.Sharding.AbstractSharding;
    location=mlir_stacktrace("sharding_constraint", @__FILE__, @__LINE__)
)

Produces a Reactant.MLIR.Dialects.sdy.sharding_constraint operation with the given input and sharding.
