Skip to content

Commit c7aba73

Browse files
tomnatan30TF2JAXDev
authored andcommitted
[JAX] add support for gather/scatter batching dims following the new attributes in stablehlo.
This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota. See openxla/stablehlo#2259 PiperOrigin-RevId: 647647825
1 parent db3f7d1 commit c7aba73

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

tf2jax/_src/xla_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,24 @@ def gather_dimension_numbers_from_proto(
6868
message) -> jax.lax.GatherDimensionNumbers:
6969
proto = xla_data_pb2.GatherDimensionNumbers().FromString(message)
7070
return jax.lax.GatherDimensionNumbers(
71-
tuple(proto.offset_dims), tuple(proto.collapsed_slice_dims),
72-
tuple(proto.start_index_map))
71+
tuple(proto.offset_dims),
72+
tuple(proto.collapsed_slice_dims),
73+
tuple(proto.start_index_map),
74+
tuple(proto.operand_batching_dims),
75+
tuple(proto.start_indices_batching_dims),
76+
)
7377

7478

7579
def scatter_dimension_numbers_from_proto(
7680
message) -> jax.lax.ScatterDimensionNumbers:
7781
proto = xla_data_pb2.ScatterDimensionNumbers().FromString(message)
7882
return jax.lax.ScatterDimensionNumbers(
79-
tuple(proto.update_window_dims), tuple(proto.inserted_window_dims),
80-
tuple(proto.scatter_dims_to_operand_dims))
83+
tuple(proto.update_window_dims),
84+
tuple(proto.inserted_window_dims),
85+
tuple(proto.scatter_dims_to_operand_dims),
86+
tuple(proto.input_batching_dims),
87+
tuple(proto.scatter_indices_batching_dims),
88+
)
8189

8290

8391
def precision_config_from_proto(

0 commit comments

Comments
 (0)