3636 @nospecialize (obj:: AbstractArray{T} ), field, val
3737) where {T}
3838 ancestor_obj = ancestor (obj)
39- if isbitstype (T) || ancestor_obj isa RArray
40- if val isa XLA. AsyncBuffer
41- if Reactant. Sharding. is_sharded (ancestor_obj)
42- error (" `val` can't be a buffer if `obj` is sharded" )
43- else
44- return Base. setfield! (obj, field, (val,))
45- end
46- end
47- return Base. setfield! (obj, field, val)
48- end
39+ (isbitstype (T) || ancestor_obj isa RArray) && return Base. setfield! (obj, field, val)
4940 return Base. setindex! (obj, val, field)
5041end
5142
@@ -75,40 +66,48 @@ function create_result(
7566 return Expr (:new , T, elems... )
7667end
7768
"""
    __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)

Consume the sharding entry recorded for `path` in `path_to_shard_info` and
rebuild a `Reactant.Sharding.ShardInfo` from it.

# Arguments
- `path`: key identifying the result whose sharding is being restored.
- `path_to_shard_info`: dict mapping a path to a
  `(device_to_array_slices, partition_spec)` tuple; the entry for `path` is
  removed as a side effect (each entry is consumed exactly once).
- `sharding_mesh`: mesh used to build the `NamedSharding`.

# Returns
A `Reactant.Sharding.ShardInfo` wrapping the reconstructed `NamedSharding`
and the per-device array slices.
"""
function __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
    # pop! fetches and deletes in a single lookup (idiomatic form of
    # `x = d[k]; delete!(d, k)`).
    device_to_array_slices, partition_spec = pop!(path_to_shard_info, path)
    sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec)
    return Reactant.Sharding.ShardInfo(sharding, device_to_array_slices)
end
75+
"""
    create_result(tocopy::ConcreteRNumber, path, result_stores, path_to_shard_info, sharding_mesh)

Build the quoted expression that reconstructs a `ConcreteRNumber` result.

When `result_stores` holds an entry for `path`, that stored symbol supplies the
data (and the entry is consumed). When `path_to_shard_info` is not `nothing`,
the sharding for this result is restored via `__reconstruct_shardinfo` and
baked into the constructed type's parameters; otherwise the plain
`ConcreteRNumber{T}` constructor is emitted and the data is filled in later.
"""
function create_result(
    tocopy::ConcreteRNumber{T,D,S}, path, result_stores, path_to_shard_info, sharding_mesh
) where {T,D,S}
    stored = haskey(result_stores, path)
    if stored
        restore = result_stores[path]
        delete!(result_stores, path)
        if path_to_shard_info === nothing
            return :(ConcreteRNumber{$T}($restore))
        end
        # Restore sharding: rebuild the ShardInfo and emit a fully
        # parameterized constructor call over the stored buffers.
        sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
        return :(ConcreteRNumber{$T,length($(restore)),$(typeof(sharding))}(
            ($(restore)...,), $sharding
        ))
    end

    if path_to_shard_info !== nothing
        # Restore sharding from the data already attached to `tocopy`.
        sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
        return :(ConcreteRNumber{$T,length($(tocopy.data)),$(typeof(sharding))}(
            ($(tocopy.data...,)), $sharding
        ))
    end

    # We will set the data for this later
    return :(ConcreteRNumber{$T}($(tocopy.data)))
end
89101
90- function __construct_sharding_for_carray (
91- :: ConcreteRArray{T,N,D,S} , path, _, path_to_shard_info, sharding_mesh
92- ) where {T,N,D,S}
93- device_to_array_slices, partition_spec = path_to_shard_info[path]
94- delete! (path_to_shard_info, path)
95- sharding = Reactant. Sharding. NamedSharding (sharding_mesh, partition_spec)
96- return Reactant. Sharding. FinalizedNamedSharding {typeof(sharding),ndims(sharding_mesh)} (
97- sharding, device_to_array_slices
98- )
99- end
100-
101102function create_result (
102103 tocopy:: ConcreteRArray{T,N,D,S} , path, result_stores, path_to_shard_info, sharding_mesh
103104) where {T,N,D,S}
104105 if haskey (result_stores, path)
105106 restore = result_stores[path]
106107 delete! (result_stores, path)
107108 if path_to_shard_info != = nothing # restore sharding
108- sharding = __construct_sharding_for_carray (
109- tocopy, path, result_stores, path_to_shard_info, sharding_mesh
110- )
111- return :(ConcreteRArray {$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))} (
109+ sharding = __reconstruct_shardinfo (path, path_to_shard_info, sharding_mesh)
110+ return :(ConcreteRArray {$T,$N,length($(restore)),$(typeof(sharding))} (
112111 ($ (restore). .. ,), $ (tocopy. shape), $ sharding
113112 ))
114113 else
@@ -117,10 +116,8 @@ function create_result(
117116 end
118117
119118 if path_to_shard_info != = nothing # restore sharding
120- sharding = __construct_sharding_for_carray (
121- tocopy, path, result_stores, path_to_shard_info, sharding_mesh
122- )
123- return :(ConcreteRArray {$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))} (
119+ sharding = __reconstruct_shardinfo (path, path_to_shard_info, sharding_mesh)
120+ return :(ConcreteRArray {$T,$N,length($(tocopy.data)),$(typeof(sharding))} (
124121 ($ (tocopy. data). .. ,), $ (tocopy. shape), $ sharding
125122 ))
126123 end
@@ -365,6 +362,7 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
365362 " binary_op_transpose_simplify_or" ,
366363 " binary_op_transpose_simplify_and" ,
367364 " binary_op_transpose_simplify_xor" ,
365+ " associative_binary_op_reordering<1>" ,
368366 " transpose_unary_transpose_abs" ,
369367 " transpose_unary_transpose_neg" ,
370368 " transpose_unary_transpose_sqrt" ,
@@ -380,12 +378,15 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
380378 " transpose_unary_transpose_sine" ,
381379 " transpose_unary_transpose_tanh" ,
382380 " transpose_broadcast_in_dim_to_broadcast_in_dim<16>" ,
381+ " scatter_indices_are_unique" ,
382+ " transpose_reduce_simplify" ,
383383 " replace_neg_add_with_subtract" ,
384384 " log_const_prop<1>" ,
385385 " log_plus_one_const_prop<1>" ,
386386 " binop_const_simplify" ,
387387 " transpose_broadcast_in_dim_to_broadcast_in_dim" ,
388388 " not_select_simplify" ,
389+ " scatter_update_computation_const_prop" ,
389390 " common_compare_expression_rewrite" ,
390391 " compare_select_simplify" ,
391392 " while_simplify<1>" ,
@@ -794,10 +795,12 @@ function compile_mlir!(
794795 results = [MLIR. IR. operand (ret, i) for i in 1 : MLIR. IR. noperands (ret)]
795796 nresults = MLIR. IR. Value[]
796797 linear_results2 = TracedType[]
798+ results_mask = falses (length (results))
797799 for (i, op) in enumerate (results)
798800 if ! MLIR. IR. is_block_arg (op)
799801 push! (nresults, op)
800802 push! (linear_results2, linear_results[i])
803+ results_mask[i] = true
801804 continue
802805 end
803806 push! (preserved_args, (linear_results[i], MLIR. IR. block_arg_num (op)))
@@ -812,11 +815,18 @@ function compile_mlir!(
812815
813816 out_tys2 = [MLIR. IR. type (a) for a in nresults]
814817
818+ res_attrs = MLIR. IR. attr (compiled_f, " res_attrs" )
819+ if res_attrs isa MLIR. IR. Attribute
820+ res_attrs = [
821+ res_attrs[i - 1 ] for (i, present) in enumerate (results_mask) if present
822+ ]
823+ end
824+
815825 func3 = MLIR. Dialects. func. func_ (;
816826 sym_name= " main" ,
817827 function_type= MLIR. IR. FunctionType (in_tys, out_tys2),
818828 arg_attrs= MLIR. IR. attr (compiled_f, " arg_attrs" ),
819- res_attrs= MLIR . IR . attr (compiled_f, " res_attrs " ) ,
829+ res_attrs,
820830 no_inline= MLIR. IR. attr (compiled_f, " no_inline" ),
821831 body= MLIR. IR. Region (),
822832 )
@@ -837,7 +847,6 @@ function compile_mlir!(
837847 linear_args,
838848 in_tys,
839849 linear_results2,
840- mlir_fn_res. linear_result_shard_info,
841850 mlir_fn_res. num_partitions,
842851 mlir_fn_res. num_replicas,
843852 mlir_fn_res. is_sharded,
@@ -862,6 +871,22 @@ macro code_hlo(args...)
862871 $ (first)($ (compiled))))
863872end
864873
"""
    @code_mhlo [optimize = ...] [no_nan = <true/false>] f(args...)

Similar to `@code_hlo`, but prints the module after running the XLA compiler.
"""
macro code_mhlo(args...)
    # Defaults mirror @code_hlo; `client` may be overridden at the call site.
    opts = Dict{Symbol,Any}(:optimize => true, :no_nan => false, :client => nothing)
    compile_expr, (; compiled) = compile_call_expr(__module__, compile_xla, opts, args...)
    # compile_xla returns (mod, exec, ...); `first` extracts the MLIR module.
    return esc(quote
        $compile_expr
        $(first)($compiled)
    end)
end
889+
865890"""
866891 @compile [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)
867892"""
@@ -998,7 +1023,7 @@ function codegen_flatten!(
9981023
9991024 if is_sharded
10001025 carg = inv_seen_args[arg]
1001- if carg isa ConcreteRArray && Reactant. Sharding. is_sharded (carg)
1026+ if Reactant. Sharding. is_sharded (carg)
10021027 for j in 1 : length (mesh)
10031028 sbuf = Symbol (:sbuf_ , i, " _" , j)
10041029 push! (flatten_names, sbuf)
@@ -1007,17 +1032,11 @@ function codegen_flatten!(
10071032 else
10081033 # Warn here first and then replicate the input across all devices on the
10091034 # mesh
1010- if carg isa ConcreteRArray
1011- @warn " Input $carg is not sharded, replicating across all devices. It \
1012- is recommended to replicate the input across all devices on the \
1013- mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1
1014- end
1035+ @warn " Input $carg is not sharded, replicating across all devices. It \
1036+ is recommended to replicate the input across all devices on the \
1037+ mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1
10151038 buf = Symbol (:buf_ , i)
1016- if carg isa ConcreteRArray
1017- push! (flatten_code, :($ buf = XLA. synced_buffer (only ($ usbuf))))
1018- else
1019- push! (flatten_code, :($ buf = XLA. synced_buffer ($ usbuf)))
1020- end
1039+ push! (flatten_code, :($ buf = XLA. synced_buffer (only ($ usbuf))))
10211040 for j in 1 : length (mesh)
10221041 device_id = mesh. device_ids[j]
10231042 device_ordinal = XLA. device_ordinal (client, device_id)
@@ -1030,9 +1049,7 @@ function codegen_flatten!(
10301049 else
10311050 sbuf = Symbol (:sbuf_ , i)
10321051 push! (flatten_names, sbuf)
1033- if arg isa TracedRNumber
1034- push! (flatten_code, :($ sbuf = XLA. synced_buffer ($ usbuf)))
1035- elseif arg isa TracedRArray
1052+ if arg isa TracedRArray || arg isa TracedRNumber
10361053 push! (flatten_code, :($ sbuf = only (XLA. synced_buffer ($ usbuf))))
10371054 else
10381055 error (" Unsupported type $(typeof (arg)) " )
@@ -1061,7 +1078,6 @@ function codegen_unflatten!(
10611078 concrete_result,
10621079 result_stores,
10631080 path_to_shard_info,
1064- is_sharded:: Bool ,
10651081 linear_result_shard_info,
10661082 sharding_mesh,
10671083)
@@ -1369,26 +1385,28 @@ function compile_xla(f, args; client=nothing, kwargs...)
13691385 mlir_fn_res. is_sharded,
13701386 )
13711387
1372- mlir_fn_res. num_partitions > 1 && (device = nothing )
1373-
13741388 # Attach a name, and partitioning attributes to the module
13751389 __add_mhlo_attributes_and_name! (
13761390 mod, f; mlir_fn_res. num_partitions, mlir_fn_res. num_replicas
13771391 )
13781392
13791393 # compile MLIR module to XLA executable
1380- is_sharded = mlir_fn_res. num_partitions > 1
1381- if is_sharded
1382- # mesh_shape = collect(Int64, size(mlir_fn_res.sharding_mesh))
1383- mesh_ids = collect (Int64, vec (mlir_fn_res. sharding_mesh. device_ids))
1394+ mlir_fn_res. is_sharded && (device = nothing )
1395+ mesh_ids = if mlir_fn_res. is_sharded
1396+ collect (Int64, mlir_fn_res. sharding_mesh. device_ids)
13841397 else
1385- # mesh_shape = Int64[]
1386- mesh_ids = Int64[]
1398+ Int64[]
13871399 end
1388- # exec = XLA.Compile(client, device, mod; is_sharded, mesh_ids, mesh_shape)
1389- exec = XLA. Compile (client, device, mod; is_sharded, mesh_ids)
1400+ exec = XLA. Compile (
1401+ client,
1402+ device,
1403+ mod;
1404+ num_results= length (mlir_fn_res. linear_results),
1405+ mlir_fn_res. is_sharded,
1406+ mesh_ids,
1407+ )
13901408
1391- return exec, mlir_fn_res, device, client
1409+ return mod, exec, mlir_fn_res, device, client
13921410 finally
13931411 MLIR. IR. deactivate! (ctx)
13941412 end
@@ -1398,7 +1416,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
13981416end
13991417
14001418function compile (f, args; sync= false , kwargs... )
1401- exec, mlir_fn_res, device, client = compile_xla (f, args; kwargs... )
1419+ _, exec, mlir_fn_res, device, client = compile_xla (f, args; kwargs... )
14021420 (; linear_args, seen_args, linear_results, preserved_args, concrete_result) =
14031421 mlir_fn_res
14041422
@@ -1408,11 +1426,7 @@ function compile(f, args; sync=false, kwargs...)
14081426 end
14091427
14101428 result_stores = Dict {Tuple,Symbol} ()
1411- path_to_shard_info = if mlir_fn_res. is_sharded
1412- Dict{Tuple,Tuple{Array{Vector{UnitRange{Int}}},Tuple}}()
1413- else
1414- nothing
1415- end
1429+ path_to_shard_info = mlir_fn_res. is_sharded ? Dict {Tuple,Tuple} () : nothing
14161430
14171431 # generate Julia `Thunk` code
14181432 flatten_arg_names, flatten_code = codegen_flatten! (
@@ -1431,9 +1445,25 @@ function compile(f, args; sync=false, kwargs...)
14311445 donated_args_mask,
14321446 length (linear_results),
14331447 mlir_fn_res. is_sharded,
1434- mlir_fn_res. is_sharded ? vec (mlir_fn_res. sharding_mesh. device_ids) : Int64[],
1448+ if mlir_fn_res. is_sharded
1449+ collect (Int64, mlir_fn_res. sharding_mesh. device_ids)
1450+ else
1451+ Int64[]
1452+ end ,
14351453 )
14361454
1455+ linear_result_shard_info = if mlir_fn_res. is_sharded
1456+ # Generate a tuple of DeviceToArraySlices and PartitionSpecs
1457+ output_shardings = XLA. get_output_shardings (exec)
1458+ XLA. compute_array_indices_and_partition_spec .(
1459+ output_shardings,
1460+ size .(mlir_fn_res. linear_results),
1461+ (mlir_fn_res. sharding_mesh,),
1462+ )
1463+ else
1464+ ntuple (Returns (nothing ), length (linear_results))
1465+ end
1466+
14371467 unflatten_code = codegen_unflatten! (
14381468 linear_args,
14391469 preserved_args,
@@ -1442,8 +1472,7 @@ function compile(f, args; sync=false, kwargs...)
14421472 concrete_result,
14431473 result_stores,
14441474 path_to_shard_info,
1445- mlir_fn_res. is_sharded,
1446- mlir_fn_res. linear_result_shard_info,
1475+ linear_result_shard_info,
14471476 mlir_fn_res. sharding_mesh,
14481477 )
14491478
0 commit comments