Sharding API
The Reactant.Sharding module provides a high-level API for constructing MLIR operations with sharding support.
Not all functions in Reactant.Sharding are documented yet.
Reactant.Sharding.DimsSharding Type
DimsSharding(
mesh::Mesh,
dims::NTuple{D,Int},
partition_spec;
is_closed::NTuple{D,Bool}=ntuple(Returns(true), D),
priority::NTuple{D,Int}=ntuple(i -> -1, D),
)
Similar to NamedSharding, but works for arrays of arbitrary dimensionality. Dimensions not listed in dims are replicated. If an entry of dims is greater than the number of dimensions of the array, the corresponding partition_spec, is_closed and priority entries are ignored. Negative entries of dims count from the end: the true dimension is ndims(x) + dim + 1, so -1 refers to the last dimension. A dims value of 0 throws an error.
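The plain-Julia sketch below illustrates the dims rules by expanding a (dims, partition_spec) pair into one entry per array dimension. expand_dims_spec is a hypothetical helper and is not part of Reactant. It only expands partition_spec; is_closed and priority would be expanded the same way. It assumes negative entries count from the end, so that -1 is the last dimension (true dimension ndims(x) + dim + 1). It also skips negative entries that fall out of range, which the docstring does not specify.

```julia
# Hypothetical helper (not part of Reactant): expand a DimsSharding-style
# (dims, partition_spec) pair into one partition entry per array dimension.
# Dimensions not listed in `dims` get `nothing`, i.e. they are replicated.
function expand_dims_spec(nd::Int, dims::NTuple{D,Int}, partition_spec) where {D}
    full_spec = Any[nothing for _ in 1:nd]
    for (d, p) in zip(dims, partition_spec)
        d == 0 && throw(ArgumentError("dims cannot contain 0"))
        # Assumption: negative dims count from the end, so -1 is the last dimension.
        true_dim = d < 0 ? nd + d + 1 : d
        # Entries outside 1:nd are ignored (out-of-range negatives are an assumption).
        1 <= true_dim <= nd || continue
        full_spec[true_dim] = p
    end
    return Tuple(full_spec)
end

# Shard dim 1 along :x and the last dim along :y of a 3D array.
expand_dims_spec(3, (1, -1), (:x, :y))  # (:x, nothing, :y)
```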
Reactant.Sharding.Mesh Type
Mesh(devices::AbstractArray{XLA.AbstractDevice}, axis_names)
Construct a Mesh from an array of devices and a tuple of axis names. The size of the i-th axis is given by size(devices, i). All axis names must be unique and cannot be nothing.
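The constraints on a Mesh can be illustrated with a small sketch in which integer ids stand in for XLA devices. mesh_axis_sizes is a hypothetical helper and is not Reactant's constructor. The check that the device array has exactly one dimension per axis name is an assumption made here.

```julia
# Hypothetical sketch (not Reactant's implementation) of the Mesh rules,
# checked on a plain array whose elements stand in for devices.
function mesh_axis_sizes(devices::AbstractArray, axis_names::Tuple)
    # Assumption: one axis name per dimension of the device array.
    ndims(devices) == length(axis_names) ||
        throw(ArgumentError("need one axis name per device-array dimension"))
    any(isnothing, axis_names) && throw(ArgumentError("axis names cannot be nothing"))
    allunique(axis_names) || throw(ArgumentError("axis names must be unique"))
    # The size of the i-th axis is size(devices, i).
    return NamedTuple{axis_names}(size(devices))
end

ids = collect(0:7)                                  # 8 stand-in "devices"
mesh_axis_sizes(reshape(ids, 2, 2, 2), (:x, :y, :z)) # (x = 2, y = 2, z = 2)
```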
Examples
Assuming that we have a total of 8 devices, we can construct a mesh with the following:
julia> devices = Reactant.devices();
julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z));
julia> mesh = Mesh(reshape(devices, 4, 2), (:x, :y));
Reactant.Sharding.NamedSharding Type
NamedSharding(
mesh::Mesh, partition_spec::Tuple;
is_closed::NTuple{N,Bool}=ntuple(Returns(true), length(partition_spec)),
priority::NTuple{N,Int}=ntuple(i -> -1, length(partition_spec)),
)
Sharding annotation that indicates that the array is sharded according to the given partition_spec. For details on the sharding representation, see the Shardy documentation.
Arguments
- mesh: Sharding.Mesh that describes the mesh of the devices.
- partition_spec: must have one entry per dimension of the array being sharded (its length equals ndims of the array). Each entry can be:
  - nothing: the corresponding dimension is replicated.
  - A tuple of axis names: the corresponding dimension is sharded along those axes.
  - A single axis name: the corresponding dimension is sharded along that axis.
Keyword Arguments
- is_closed: A tuple of booleans indicating whether the corresponding dimension is closed along the axis. Defaults to true for all dimensions.
- priority: A tuple of integers indicating the priority of the corresponding dimension. Defaults to -1 for all dimensions. A negative priority means that the priority is not considered by Shardy.
Examples
julia> devices = Reactant.devices();
julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z));
julia> sharding = NamedSharding(mesh, (:x, :y, nothing)); # 3D array: dim 1 sharded along x, dim 2 along y, dim 3 replicated
julia> sharding = NamedSharding(mesh, ((:x, :y), nothing, nothing)); # 3D array: dim 1 sharded along both x and y, dims 2 and 3 replicated
julia> sharding = NamedSharding(mesh, (nothing, nothing)); # fully replicated matrix
See also: Sharding.NoSharding
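The following sketch shows what a partition_spec means for the data each device holds. It computes how many shards each dimension is split into and the resulting per-device block shape, given the axis sizes of a mesh. shards_per_dim and local_block_size are hypothetical helpers and are not Reactant API. The sketch also assumes that every dimension divides evenly, so no padding is needed.

```julia
# Hypothetical helper (not part of Reactant): number of shards per array
# dimension. `nothing` means replicated (1 shard); an axis name or a tuple of
# axis names splits the dimension across the product of those axis sizes.
function shards_per_dim(axis_sizes::NamedTuple, partition_spec::Tuple)
    return map(partition_spec) do p
        p === nothing && return 1
        mesh_axes = p isa Symbol ? (p,) : p
        return prod(a -> axis_sizes[a], mesh_axes)
    end
end

# Shape of the block each device holds, assuming every dimension divides evenly.
function local_block_size(array_size::Tuple, axis_sizes::NamedTuple, partition_spec::Tuple)
    length(array_size) == length(partition_spec) ||
        throw(ArgumentError("partition_spec must have one entry per array dimension"))
    n = shards_per_dim(axis_sizes, partition_spec)
    all(array_size .% n .== 0) || throw(ArgumentError("uneven sharding would need padding"))
    return array_size .÷ n
end

sizes = (x = 2, y = 2, z = 2)  # axis sizes of the 2x2x2 mesh in the examples
local_block_size((8, 8, 8), sizes, (:x, :y, nothing))
```

For example, with those axis sizes and an 8x8x8 array, (:x, :y, nothing) gives 4x4x8 blocks, while ((:x, :y), nothing, nothing) gives 2x8x8 blocks.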
Reactant.Sharding.NoSharding Type
NoSharding()
Sharding annotation that indicates that the array is not sharded.
See also: Sharding.NamedSharding
Reactant.Sharding.Replicated Type
Replicated(mesh::Mesh)
Sharding annotation that indicates that the array is fully replicated along all dimensions.
Reactant.Sharding.is_sharded Method
is_sharded(sharding)
is_sharded(x::AbstractArray)
Checks whether the given sharding, or the sharding of array x, refers to an actual sharding as opposed to no sharding (NoSharding).
Reactant.Sharding.sharding_to_array_slices Function
sharding_to_array_slices(
sharding, size_x; client=nothing, return_updated_sharding=Val(false)
)
Given a sharding and an array size, returns the mapping from devices to array slices. If return_updated_sharding is Val(true), the updated sharding is returned as well (for inputs that require padding).
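The core idea of a device-to-slices mapping can be sketched in plain Julia. array_slices is a hypothetical helper and is not Reactant's sharding_to_array_slices. It uses integer ids in place of devices and ignores client and return_updated_sharding. It handles only evenly divisible sizes, so there is no padding. It also assumes that when a dimension is sharded along several axes, the first axis is the most significant one.

```julia
# Hypothetical sketch (not Reactant's implementation): map each stand-in
# device id to the index ranges of the array it holds under a
# NamedSharding-style partition_spec. No padding: sizes must divide evenly.
function array_slices(device_ids::AbstractArray, axis_names::Tuple,
                      partition_spec::Tuple, size_x::Tuple)
    axis_sizes = NamedTuple{axis_names}(size(device_ids))
    slices = Dict{eltype(device_ids),Tuple}()
    for c in CartesianIndices(device_ids)
        # 0-based coordinate of this device along each named mesh axis.
        coord = NamedTuple{axis_names}(Tuple(c) .- 1)
        slice = ntuple(length(size_x)) do d
            p = partition_spec[d]
            p === nothing && return 1:size_x[d]  # replicated dimension
            mesh_axes = p isa Symbol ? (p,) : p
            nshards = prod(a -> axis_sizes[a], mesh_axes)
            size_x[d] % nshards == 0 || throw(ArgumentError("padding not handled"))
            # Mixed-radix shard index; assumption: first axis is most significant.
            idx = foldl((acc, a) -> acc * axis_sizes[a] + coord[a], mesh_axes; init=0)
            len = size_x[d] ÷ nshards
            return (idx * len + 1):((idx + 1) * len)
        end
        slices[device_ids[c]] = slice
    end
    return slices
end

ids = reshape(collect(0:3), 2, 2)  # 2x2 stand-in mesh with axes (:x, :y)
array_slices(ids, (:x, :y), (:x, nothing), (4, 6))
```

With an all-nothing partition_spec, every device receives the full index range of every dimension, which corresponds to full replication.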
Reactant.Sharding.unwrap_shardinfo Method
unwrap_shardinfo(x)
Unwraps a sharding info object, returning the sharding object itself.
Distributed API
The Reactant.Distributed module provides a high-level API for running Reactant on multiple hosts.
Not all functions in Reactant.Distributed are documented yet.
Reactant.Distributed.is_initialized Method
is_initialized()
Returns true if the distributed environment has been initialized.
Reactant.Distributed.local_rank Method
local_rank()
Returns the local rank of the current process.