From d74feb74e2ccf199fb4f7f97ff6bf0fdaa55a313 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Fri, 28 Jun 2024 05:40:41 -0700 Subject: [PATCH] [JAX] add support for gather/scatter batching dims following the new attributes in stablehlo. This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota. See https://github.com/openxla/stablehlo/pull/2259 PiperOrigin-RevId: 647647825 --- tf2jax/_src/xla_utils.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tf2jax/_src/xla_utils.py b/tf2jax/_src/xla_utils.py index e97f53e..9793239 100644 --- a/tf2jax/_src/xla_utils.py +++ b/tf2jax/_src/xla_utils.py @@ -68,16 +68,24 @@ def gather_dimension_numbers_from_proto( message) -> jax.lax.GatherDimensionNumbers: proto = xla_data_pb2.GatherDimensionNumbers().FromString(message) return jax.lax.GatherDimensionNumbers( - tuple(proto.offset_dims), tuple(proto.collapsed_slice_dims), - tuple(proto.start_index_map)) + tuple(proto.offset_dims), + tuple(proto.collapsed_slice_dims), + tuple(proto.start_index_map), + tuple(proto.operand_batching_dims), + tuple(proto.start_indices_batching_dims), + ) def scatter_dimension_numbers_from_proto( message) -> jax.lax.ScatterDimensionNumbers: proto = xla_data_pb2.ScatterDimensionNumbers().FromString(message) return jax.lax.ScatterDimensionNumbers( - tuple(proto.update_window_dims), tuple(proto.inserted_window_dims), - tuple(proto.scatter_dims_to_operand_dims)) + tuple(proto.update_window_dims), + tuple(proto.inserted_window_dims), + tuple(proto.scatter_dims_to_operand_dims), + tuple(proto.input_batching_dims), + tuple(proto.scatter_indices_batching_dims), + ) def precision_config_from_proto(