@@ -96,11 +96,16 @@ See also: [`Sharding.NamedSharding`](@ref)
9696"""
9797struct NoSharding <: AbstractSharding end
9898
# A `NoSharding` value always reports exactly one device.
@inline function ndevices(::NoSharding)
    return 1
end
100+
# Shard-info type associated with `NoSharding`. The second positional argument is
# accepted but ignored.
@inline function shard_type(::Type{NoSharding}, _)
    return ShardInfo{NoSharding,Nothing}
end
102+
# Accessing any property of a `NoSharding` yields another `NoSharding`. This allows us to
# mark entire branches as NoSharding.
function Base.getproperty(::NoSharding, x)
    return NoSharding()
end
# Same result for `Symbol` property names, spelled out as a dedicated method
# (presumably to keep dispatch unambiguous -- confirm against other `getproperty` methods).
function Base.getproperty(::NoSharding, x::Symbol)
    return NoSharding()
end
102106
# Calling a `NoSharding` creates one `XLA.PJRT.AsyncBuffer` holding `x` on `device`.
# When `device` is `nothing`, `XLA.default_device(client)` is used instead.
# Returns a 1-tuple with the buffer, and a `ShardInfo(NoSharding(), nothing)`.
function (::NoSharding)(client::XLA.PJRT.Client, device, x::Union{AbstractArray,Number})
    target = device === nothing ? XLA.default_device(client) : device
    buffer = XLA.PJRT.AsyncBuffer(client, x, target)
    return (buffer,), ShardInfo(NoSharding(), nothing)
end
@@ -185,6 +190,12 @@ struct NamedSharding{D1,D2,P<:Tuple} <: AbstractSharding
185190 end
186191end
187192
# Number of devices of a `NamedSharding`: the number of entries in its mesh's
# `device_ids`.
@inline function ndevices(sharding::NamedSharding)
    mesh = sharding.mesh
    return length(mesh.device_ids)
end
194+
# The shard-info type of a `NamedSharding{D1,D2,P}` is that of `HloSharding{D1,D2}` for
# the same `N`.
@inline shard_type(::Type{NamedSharding{D1,D2,P}}, N) where {D1,D2,P} =
    shard_type(HloSharding{D1,D2}, N)
198+
188199function (sharding:: NamedSharding )(
189200 client:: XLA.PJRT.Client , device:: Nothing , x:: Union{AbstractArray,Number}
190201)
@@ -226,6 +237,84 @@ function get_shardy_tensor_sharding_attribute(
226237 )
227238end
228239
240+ # TODO : Something like NamedDims.jl will allow us to support NamedDimsSharding similar to
241+ # `levanter`
242+
"""
    DimsSharding(
        mesh::Mesh{M},
        dims::NTuple{D,Int},
        partition_spec;
        is_closed::NTuple{D,Bool}=ntuple(Returns(true), D),
        priority::NTuple{D,Int}=ntuple(i -> -1, D),
    )

Similar to [`NamedSharding`](@ref) but works for an arbitrary dimensional array. Dimensions
not specified in `dims` are replicated. If any dimension in `dims` is greater than the total
number of dimensions in the array, the corresponding `partition_spec`, `is_closed` and
`priority` are ignored. Additionally for any negative dimensions in `dims`, the true
dims are calculated as `ndims(x) + dim + 1` (so `-1` refers to the last dimension). A dims
value of `0` will throw an error.

`partition_spec` must contain exactly one entry per entry of `dims`. All other validity
checks on the inputs are deferred to [`NamedSharding`](@ref).
"""
struct DimsSharding{M,D,P} <: AbstractSharding
    mesh::Mesh{M}
    dims::NTuple{D,Int}
    partition_spec::P
    is_closed::NTuple{D,Bool}
    priority::NTuple{D,Int}

    function DimsSharding(
        mesh::Mesh{M},
        dims::NTuple{D,Int},
        partition_spec;
        # Defaults are sized by `D` (the length of `dims`), as documented. Sizing them by
        # `length(partition_spec)` would make the keyword type assertion fail before the
        # length check below could run whenever the two lengths differ.
        is_closed::NTuple{D,Bool}=ntuple(Returns(true), Val(D)),
        priority::NTuple{D,Int}=ntuple(Returns(-1), Val(D)),
    ) where {M,D}
        @assert length(partition_spec) == length(dims) "partition_spec must have one entry per entry of dims"
        # Validity checks on the inputs are deferred to NamedSharding
        return new{M,D,typeof(partition_spec)}(
            mesh, dims, partition_spec, is_closed, priority
        )
    end
end
279+
# Number of devices of a `DimsSharding`: the number of entries in its mesh's
# `device_ids`.
@inline function ndevices(sharding::DimsSharding)
    mesh = sharding.mesh
    return length(mesh.device_ids)
end
281+
# The shard-info type of a `DimsSharding{M,D,P}` for an `N`-argument is that of
# `HloSharding{M,N}`; `D` and `P` do not influence the result.
@inline shard_type(::Type{DimsSharding{M,D,P}}, N) where {M,D,P} =
    shard_type(HloSharding{M,N}, N)
285+
# Convert a `DimsSharding` into a `NamedSharding` with one entry per dimension of `x`.
#
# Negative entries of `sharding.dims` are resolved as `ndims(x) + d + 1`; a `0` entry
# trips the assertion. Each array dimension `i` takes the `partition_spec`, `is_closed`
# and `priority` of the first entry of `dims` that resolves to `i`. Dimensions with no
# such entry get `nothing`, `true` and `-1` respectively.
function standardize_sharding(sharding::DimsSharding, x::Union{AbstractArray,Number})
    nd = ndims(x)
    resolved = map(sharding.dims) do d
        @assert !iszero(d) " dims cannot contain 0"
        return d < 0 ? nd + d + 1 : d
    end

    # For every array dimension, the position in `dims` that refers to it (or `nothing`).
    slots = ntuple(i -> findfirst(==(i), resolved), nd)

    # Gather per-dimension values, substituting `fallback` for unreferenced dimensions.
    pick(values, fallback) = map(slots) do slot
        return slot === nothing ? fallback : values[slot]
    end

    return NamedSharding(
        sharding.mesh,
        pick(sharding.partition_spec, nothing);
        is_closed=pick(sharding.is_closed, true),
        priority=pick(sharding.priority, -1),
    )
end
311+
# Calling a `DimsSharding` first converts it with `standardize_sharding` for `x`, then
# calls the resulting sharding with the same arguments.
function (sharding::DimsSharding)(
    client::XLA.PJRT.Client, device::Nothing, x::Union{AbstractArray,Number}
)
    named = standardize_sharding(sharding, x)
    return named(client, device, x)
end
317+
229318# HloSharding
230319# This stores the sharding information in the form of XLA.HloSharding, and provides a
231320# central type for the final storage. It also potentially saves us the pain of not having
@@ -244,6 +333,12 @@ struct HloSharding{D1,D2} <: AbstractSharding
244333 end
245334end
246335
# Number of devices of an `HloSharding`: the number of entries in its mesh's
# `device_ids`.
@inline function ndevices(sharding::HloSharding)
    mesh = sharding.mesh
    return length(mesh.device_ids)
end
337+
# The shard-info type of an `HloSharding{D1,D2}`: a `ShardInfo` whose second parameter is
# a vector of `N`-tuples of `UnitRange{Int64}`.
@inline shard_type(::Type{HloSharding{D1,D2}}, N) where {D1,D2} =
    ShardInfo{HloSharding{D1,D2},Vector{NTuple{N,UnitRange{Int64}}}}
341+
247342function Base. convert (:: Type{HloSharding} , sharding:: NamedSharding )
248343 if MLIR. IR. _has_context ()
249344 ctx = MLIR. IR. context ()
@@ -321,6 +416,10 @@ struct ShardInfo{S,D} <: AbstractSharding
321416 device_to_array_slices:: D
322417end
323418
# Number of devices of a `ShardInfo`: delegated to the wrapped sharding, in the same way
# `is_sharded` and `shard_type` delegate. Going through `sharding.mesh` instead would
# break for a wrapped `NoSharding`, whose property access returns `NoSharding()`.
@inline ndevices(sharding::ShardInfo) = ndevices(sharding.sharding)
420+
# The shard-info type of a `ShardInfo{S,D}` is the shard-info type of its wrapped
# sharding type `S`.
@inline function shard_type(::Type{ShardInfo{S,D}}, N) where {S,D}
    return shard_type(S, N)
end
422+
324423function Base. getproperty (sharding:: ShardInfo , name:: Symbol )
325424 name ∈ (:sharding , :device_to_array_slices ) && return getfield (sharding, name)
326425 return getproperty (sharding. sharding, name)
@@ -348,6 +447,7 @@ Checks whether the given sharding refers to no sharding.
348447"""
function is_sharded(::NoSharding)
    return false
end
# Every mesh-based sharding counts as sharded.
function is_sharded(::NamedSharding)
    return true
end
function is_sharded(::DimsSharding)
    return true
end
function is_sharded(::HloSharding)
    return true
end
# A `ShardInfo` is sharded exactly when the sharding it wraps is.
function is_sharded(s::ShardInfo)
    return is_sharded(s.sharding)
end
353453
0 commit comments