Sharding API
The Reactant.Sharding module provides a high-level API to construct MLIR operations with support for sharding.
Currently we haven't documented all the functions in Reactant.Sharding.
Reactant.Sharding.DimsSharding Type
DimsSharding(
mesh::Mesh{M},
dims::NTuple{D,Int},
partition_spec;
is_closed::NTuple{D,Bool}=ntuple(Returns(true), D),
priority::NTuple{D,Int}=ntuple(i -> -1, D),
)
Similar to NamedSharding but works for an array of arbitrary dimensionality. Dimensions not specified in dims are replicated. If any dimension in dims is greater than the total number of dimensions in the array, the corresponding partition_spec, is_closed and priority
are ignored. Additionally, for any negative dimension in dims, the true dimension is calculated as ndims(x) + dim + 1 (so -1 refers to the last dimension). A dims value of 0 will throw an error.
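The rules for dims can be sketched in plain Julia. The helper resolve_dims below is hypothetical and is not part of Reactant; it assumes that negative entries count from the end of the array (so -1 is the last dimension) and that entries falling outside the array's dimensions are simply dropped.

```julia
# Hypothetical helper, not part of Reactant: resolve a `dims` tuple against an
# array with `n` dimensions, following the rules described above.
function resolve_dims(dims::NTuple{D,Int}, n::Int) where {D}
    resolved = Int[]
    for d in dims
        d == 0 && throw(ArgumentError("dims cannot contain 0"))
        # Assumption: negative entries count from the end, so -1 maps to n.
        r = d < 0 ? n + d + 1 : d
        # Entries that fall outside 1:n are ignored.
        1 <= r <= n && push!(resolved, r)
    end
    return resolved
end

resolve_dims((1, -1), 3)  # [1, 3]
resolve_dims((2, 5), 3)   # [2]; dimension 5 does not exist and is skipped
```

With this reading, a DimsSharding built with dims = (1, -1) would shard the first and last dimensions of whatever array it is applied to, regardless of the array's rank.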
Reactant.Sharding.Mesh Type
Mesh(devices::AbstractArray{XLA.AbstractDevice}, axis_names)
Construct a Mesh from an array of devices and a tuple of axis names. The size of the i-th axis is given by size(devices, i). All the axis names must be unique, and cannot be nothing.
Examples
Assuming that we have a total of 8 devices, we can construct a mesh with the following:
julia> using Reactant
julia> using Reactant.Sharding: Mesh
julia> devices = Reactant.devices();
julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z));
julia> mesh = Mesh(reshape(devices, 4, 2), (:x, :y));
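To make the relation between the device array and the axis sizes concrete, here is a minimal sketch in plain Julia that runs without any accelerator. The helper axis_sizes is hypothetical and is not part of Reactant; integers stand in for real devices, axis names are assumed to be symbols, and the check that there is one name per array dimension is an assumption rather than documented behavior.

```julia
# Hypothetical helper, not part of Reactant: map axis names to axis sizes the
# way the Mesh description states (the size of axis i is size(devices, i)).
function axis_sizes(devices::AbstractArray, axis_names::Tuple)
    # Assumption: one axis name per dimension of the device array.
    length(axis_names) == ndims(devices) ||
        throw(ArgumentError("expected one axis name per device-array dimension"))
    any(isnothing, axis_names) && throw(ArgumentError("axis names cannot be nothing"))
    allunique(axis_names) || throw(ArgumentError("axis names must be unique"))
    return NamedTuple{axis_names}(size(devices))
end

fake_devices = collect(1:8)  # integers standing in for 8 real devices
axis_sizes(reshape(fake_devices, 2, 2, 2), (:x, :y, :z))  # (x = 2, y = 2, z = 2)
axis_sizes(reshape(fake_devices, 4, 2), (:x, :y))         # (x = 4, y = 2)
```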
Reactant.Sharding.NamedSharding Type
NamedSharding(
mesh::Mesh, partition_spec::Tuple;
is_closed::NTuple{N,Bool}=ntuple(Returns(true), length(partition_spec)),
priority::NTuple{N,Int}=ntuple(i -> -1, length(partition_spec)),
)
Sharding annotation that indicates that the array is sharded along the given partition_spec. For details on the sharding representation see the Shardy documentation.
Arguments
mesh: Sharding.Mesh that describes the mesh of the devices.
partition_spec: Its length must be equal to the ndims of the array being sharded. Each element can be:
    nothing: indicating the corresponding dimension is replicated along the axis.
    A tuple of axis names indicating the axis names that the corresponding dimension is sharded along.
    A single axis name indicating the axis name that the corresponding dimension is sharded along.
Keyword Arguments
is_closed: A tuple of booleans indicating whether the corresponding dimension is closed along the axis. Defaults to true for all dimensions.
priority: A tuple of integers indicating the priority of the corresponding dimension. Defaults to -1 for all dimensions. A negative priority means that the priority is not considered by Shardy.
Examples
julia> using Reactant
julia> using Reactant.Sharding: Mesh, NamedSharding
julia> devices = Reactant.devices();
julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z));
julia> sharding = NamedSharding(mesh, (:x, :y, nothing)); # 3D Array sharded along x and y on dim 1 and 2 respectively, while dim 3 is replicated
julia> sharding = NamedSharding(mesh, ((:x, :y), nothing, nothing)); # 3D Array sharded along both x and y on dim 1, while dims 2 and 3 are replicated
julia> sharding = NamedSharding(mesh, (nothing, nothing)); # fully replicated Matrix
See also: Sharding.NoSharding
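As an illustration of what a partition_spec implies for the data each device holds, the following plain-Julia sketch computes the per-device shard shape. The helper shard_shape is hypothetical and is not part of Reactant; the mesh is represented only by a NamedTuple of axis sizes, and every sharded dimension is assumed to divide evenly across its axes.

```julia
# Hypothetical helper, not part of Reactant: per-device shard shape implied by
# a partition spec, given the mesh axis sizes as a NamedTuple.
function shard_shape(
    array_size::NTuple{N,Int}, mesh_axes::NamedTuple, partition_spec::Tuple
) where {N}
    length(partition_spec) == N ||
        throw(ArgumentError("partition_spec needs one entry per array dimension"))
    return ntuple(N) do i
        spec = partition_spec[i]
        # Normalize: nothing -> no axes, a single name -> one axis, a tuple -> several.
        axis_list = spec === nothing ? () : (spec isa Tuple ? spec : (spec,))
        nshards = prod((mesh_axes[a] for a in axis_list); init=1)
        # Assumption: every sharded dimension divides evenly across its axes.
        array_size[i] % nshards == 0 ||
            throw(ArgumentError("dimension $i is not divisible by $nshards"))
        array_size[i] ÷ nshards
    end
end

mesh_axes = (x = 2, y = 2, z = 2)
shard_shape((8, 4, 6), mesh_axes, (:x, :y, nothing))             # (4, 2, 6)
shard_shape((8, 4, 6), mesh_axes, ((:x, :y), nothing, nothing))  # (2, 4, 6)
```

In the second call, dimension 1 is split across both x and y, so it is divided by 2 * 2 = 4, while the replicated dimensions keep their full size on every device.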
Reactant.Sharding.NoSharding Type
NoSharding()
Sharding annotation that indicates that the array is not sharded.
See also: Sharding.NamedSharding
Reactant.Sharding.is_sharded Method
is_sharded(sharding)
is_sharded(x::AbstractArray)
Checks whether the given sharding, or the sharding of the array x, refers to no sharding.
Reactant.Sharding.unwrap_shardinfo Method
unwrap_shardinfo(x)
Unwraps a sharding info object, returning the sharding object itself.