
Sharding API

The Reactant.Sharding module provides a high-level API for constructing MLIR operations with sharding support.

Not all functions in Reactant.Sharding are documented yet.

Reactant.Sharding.DimsSharding Type
julia
DimsSharding(
    mesh::Mesh{M},
    dims::NTuple{D,Int},
    partition_spec;
    is_closed::NTuple{D,Bool}=ntuple(Returns(true), D),
    priority::NTuple{D,Int}=ntuple(i -> -1, D),
)

Similar to NamedSharding, but works for arrays of arbitrary dimensionality. Dimensions not listed in dims are replicated. If any dimension in dims is greater than the total number of dimensions of the array, the corresponding entries of partition_spec, is_closed, and priority are ignored. Negative values in dims count from the end: the true dimension is ndims(x) + dim + 1, so -1 refers to the last dimension. A dims value of 0 throws an error.
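The dimension rules above can be illustrated with a small pure-Julia sketch. The helper `resolve_partition_spec` is hypothetical: it is not part of Reactant.Sharding and is not its implementation. It expands `dims` and `partition_spec` into one entry per array dimension, with `nothing` meaning replicated. It assumes negative dims count from the end (-1 is the last dimension), and treating a negative dim that resolves below 1 as ignored is a further assumption, since that case is not specified.

```julia
# Hypothetical sketch of the DimsSharding dimension rules; not Reactant's implementation.
function resolve_partition_spec(nd::Int, dims::Tuple, partition_spec::Tuple)
    length(dims) == length(partition_spec) ||
        throw(ArgumentError("dims and partition_spec must have the same length"))
    # Every dimension not listed in `dims` is replicated (`nothing`).
    spec = Any[nothing for _ in 1:nd]
    for (d, p) in zip(dims, partition_spec)
        d == 0 && throw(ArgumentError("a dims value of 0 is not allowed"))
        # Negative dims count from the end: -1 maps to the last dimension.
        true_dim = d < 0 ? nd + d + 1 : d
        # Dims beyond ndims are ignored; ignoring results below 1 is an assumption.
        1 <= true_dim <= nd || continue
        spec[true_dim] = p
    end
    return Tuple(spec)
end
```

For example, `resolve_partition_spec(3, (1, -1), (:x, :y))` returns `(:x, nothing, :y)`: dimension 1 is sharded along `:x`, the last dimension along `:y`, and dimension 2 is replicated. The same call on a 2D array with `dims = (1, 3)` simply drops the entry for dimension 3.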

Reactant.Sharding.Mesh Type
julia
Mesh(devices::AbstractArray{XLA.AbstractDevice}, axis_names)

Construct a Mesh from an array of devices and a tuple of axis names. The size of the i-th axis is given by size(devices, i). All axis names must be unique and cannot be nothing.
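As a hedged illustration of these rules, the hypothetical helper `mesh_axis_sizes` below (not a Reactant function) pairs each axis name with `size(devices, i)` and checks the naming constraints. Plain integers stand in for XLA devices so the sketch runs without accelerators, and requiring exactly one axis name per dimension of `devices` is an assumption.

```julia
# Hypothetical sketch of the Mesh rules; integers stand in for XLA devices.
function mesh_axis_sizes(devices::AbstractArray, axis_names::Tuple)
    # Assumption: one axis name per dimension of `devices`.
    length(axis_names) == ndims(devices) ||
        throw(ArgumentError("need one axis name per dimension of devices"))
    any(isnothing, axis_names) &&
        throw(ArgumentError("axis names cannot be nothing"))
    allunique(axis_names) ||
        throw(ArgumentError("axis names must be unique"))
    # The size of the i-th axis is size(devices, i).
    return Tuple(name => size(devices, i) for (i, name) in enumerate(axis_names))
end

stand_in_devices = reshape(collect(0:7), 4, 2)   # 8 stand-in "devices"
mesh_axis_sizes(stand_in_devices, (:x, :y))      # (:x => 4, :y => 2)
```

Reshaping the same eight entries to `2, 2, 2` with names `(:x, :y, :z)` gives three axes of size 2, matching the first mesh in the examples.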

Examples

Assuming that we have a total of 8 devices, we can construct a mesh with the following:

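julia
julia> using Reactant

julia> using Reactant.Sharding: Mesh

julia> devices = Reactant.devices();

julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z)); # 2×2×2 mesh with axes :x, :y, :z

julia> mesh = Mesh(reshape(devices, 4, 2), (:x, :y)); # 4×2 mesh with axes :x, :y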
Reactant.Sharding.NamedSharding Type
julia
NamedSharding(
    mesh::Mesh, partition_spec::Tuple;
    is_closed::NTuple{N,Bool}=ntuple(Returns(true), length(partition_spec)),
    priority::NTuple{N,Int}=ntuple(i -> -1, length(partition_spec)),
)

Sharding annotation that indicates that the array is sharded along the given partition_spec. For details on the sharding representation see the Shardy documentation.

Arguments

  • mesh: Sharding.Mesh that describes the mesh of the devices.

  • partition_spec: A tuple whose length must equal the ndims of the array being sharded. Each element can be:

    1. nothing: the corresponding dimension is replicated (not sharded along any mesh axis).

    2. A tuple of axis names: the corresponding dimension is sharded along all of these axes.

    3. A single axis name: the corresponding dimension is sharded along that axis.
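To make the three cases concrete, the hypothetical helper `local_shape` below (not a Reactant function) computes the per-device block shape implied by a partition spec, given the axis sizes of a mesh: `nothing` keeps the full dimension, a single axis name splits it by that axis size, and a tuple of names splits it by the product of their sizes. Passing axis sizes as a `Dict` and requiring every dimension to divide evenly are assumptions of this sketch.

```julia
# Hypothetical sketch: per-device block shape implied by a partition spec.
# axis_sizes maps mesh axis names to sizes, e.g. Dict(:x => 2, :y => 2, :z => 2).
function local_shape(array_size::Tuple, axis_sizes::AbstractDict, partition_spec::Tuple)
    length(partition_spec) == length(array_size) ||
        throw(ArgumentError("partition_spec must have one entry per array dimension"))
    return map(array_size, partition_spec) do n, p
        # Number of pieces this dimension is split into.
        pieces = p === nothing ? 1 :
                 p isa Symbol ? axis_sizes[p] :
                 prod(axis_sizes[a] for a in p)
        n % pieces == 0 ||
            throw(ArgumentError("dimension of size $n is not divisible by $pieces"))
        n ÷ pieces
    end
end

sizes = Dict(:x => 2, :y => 2, :z => 2)
local_shape((8, 8, 4), sizes, (:x, :y, nothing))          # (4, 4, 4)
local_shape((8, 8, 4), sizes, ((:x, :y), nothing, nothing)) # (2, 8, 4)
```

With a tuple entry such as `(:x, :y)`, dimension 1 is split into 2 × 2 = 4 pieces, which is why the second call leaves blocks of length 2 along that dimension.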

Keyword Arguments

  • is_closed: A tuple of booleans indicating whether the sharding of the corresponding dimension is closed, i.e. cannot be further sharded along additional axes during sharding propagation. Defaults to true for all dimensions.

  • priority: A tuple of integers indicating the priority of the corresponding dimension. Defaults to -1 for all dimensions. A negative priority means that the priority is not considered by Shardy.
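The default expressions from the signature can be evaluated directly in plain Julia. This snippet only reproduces those defaults for a 3-element partition spec; `custom_is_closed` is an illustrative value showing that any non-default tuple must still have one entry per partition_spec element, not a recommendation.

```julia
# Default keyword values from the NamedSharding signature, evaluated for a
# 3-element partition spec (plain Julia, no Reactant needed).
partition_spec = (:x, :y, nothing)
N = length(partition_spec)

is_closed = ntuple(Returns(true), N)  # (true, true, true): every dimension closed
priority = ntuple(i -> -1, N)         # (-1, -1, -1): priorities not considered

# Illustrative non-default: leave only the third dimension open.
custom_is_closed = ntuple(i -> i != 3, N)  # (true, true, false)
```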

Examples

julia
julia> using Reactant

julia> using Reactant.Sharding: Mesh, NamedSharding

julia> devices = Reactant.devices();

julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z));

julia> sharding = NamedSharding(mesh, (:x, :y, nothing)); # 3D array: dim 1 sharded along x, dim 2 along y, dim 3 replicated

julia> sharding = NamedSharding(mesh, ((:x, :y), nothing, nothing)); # 3D array: dim 1 sharded along both x and y, dims 2 and 3 replicated

julia> sharding = NamedSharding(mesh, (nothing, nothing)); # fully replicated matrix

See also: Sharding.NoSharding

Reactant.Sharding.NoSharding Type
julia
NoSharding()

Sharding annotation that indicates that the array is not sharded.

See also: Sharding.NamedSharding

Reactant.Sharding.is_sharded Method
julia
is_sharded(sharding)
is_sharded(x::AbstractArray)

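Checks whether the given sharding, or the sharding attached to the array x, is an actual sharding: returns false when it refers to no sharding (NoSharding) and true otherwise.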

Reactant.Sharding.unwrap_shardinfo Method
julia
unwrap_shardinfo(x)

Unwraps a sharding info object, returning the sharding object itself.
