From d10c129dc0b14e503de5a4d93eab88c69c78f0d7 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Mon, 13 Jan 2025 10:36:43 -0800 Subject: [PATCH] Preliminary parts needed for ragged support, including densification. Added `ragged` option to `KerasTensor`, `InputLayer` and `convert_to_tensor`. The logic is the same as for sparse tensors. Fixes https://github.com/keras-team/keras/issues/20731 --- keras/src/backend/common/keras_tensor.py | 26 +++++++++++++++--- keras/src/backend/common/keras_tensor_test.py | 22 +++++++++++++++ keras/src/backend/jax/__init__.py | 1 + keras/src/backend/jax/core.py | 5 +++- keras/src/backend/numpy/__init__.py | 1 + keras/src/backend/numpy/core.py | 5 +++- keras/src/backend/openvino/__init__.py | 1 + keras/src/backend/openvino/core.py | 5 +++- keras/src/backend/tensorflow/__init__.py | 1 + keras/src/backend/tensorflow/core.py | 5 +++- keras/src/backend/torch/__init__.py | 1 + keras/src/backend/torch/core.py | 5 +++- keras/src/layers/core/input_layer.py | 27 ++++++++++++++++--- keras/src/layers/core/input_layer_test.py | 14 ++++++++-- keras/src/ops/core.py | 14 +++++++--- keras/src/ops/core_test.py | 25 ++++++++++++++++- 16 files changed, 139 insertions(+), 19 deletions(-) diff --git a/keras/src/backend/common/keras_tensor.py b/keras/src/backend/common/keras_tensor.py index 1314a266dfc..75bf51b4884 100644 --- a/keras/src/backend/common/keras_tensor.py +++ b/keras/src/backend/common/keras_tensor.py @@ -32,6 +32,7 @@ def __init__( shape, dtype="float32", sparse=False, + ragged=False, record_history=True, name=None, ): @@ -40,6 +41,12 @@ def __init__( self._shape = backend.standardize_shape(shape) self._dtype = backend.standardize_dtype(dtype) self._sparse = bool(sparse) + self._ragged = bool(ragged) + if self._sparse and self._ragged: + raise ValueError( + "KerasTensor cannot have `sparse=True` and `ragged=True` at " + "the same time." + ) self.name = name or auto_name(self.__class__.__name__) self.record_history = record_history @@ -50,7 +57,7 @@ def shape(self): @shape.setter def shape(self, value): raise AttributeError( - f"The shape of {self.__class__.__name__} is immutable. One should " + "The `shape` attribute of KerasTensor is immutable. One should " "create a new instance of KerasTensor for this." ) @@ -61,7 +68,7 @@ def dtype(self): @dtype.setter def dtype(self, value): raise AttributeError( - f"The dtype of {self.__class__.__name__} is immutable. One should " + "The `dtype` attribute of KerasTensor is immutable. One should " "create a new instance of KerasTensor for this." ) @@ -72,7 +79,18 @@ def sparse(self): @sparse.setter def sparse(self, value): raise AttributeError( - f"The sparse of {self.__class__.__name__} is immutable. One should " + "The `sparse` attribute of KerasTensor is immutable. One should " + "create a new instance of KerasTensor for this." + ) + + @property + def ragged(self): + return self._ragged + + @ragged.setter + def ragged(self, value): + raise AttributeError( + "The `ragged` attribute of KerasTensor is immutable. One should " "create a new instance of KerasTensor for this." ) @@ -160,7 +178,7 @@ def __tf_tensor__(self, dtype=None, name=None): def __repr__(self): return ( f"" + f"sparse={self.sparse}, ragged={self.ragged}, name={self.name}>" ) def __iter__(self): diff --git a/keras/src/backend/common/keras_tensor_test.py b/keras/src/backend/common/keras_tensor_test.py index fee82223353..f4566d65449 100644 --- a/keras/src/backend/common/keras_tensor_test.py +++ b/keras/src/backend/common/keras_tensor_test.py @@ -26,11 +26,33 @@ def test_attributes(self): AttributeError, "The dtype of KerasTensor is immutable." ): x.dtype = "int32" + + def test_attributes_sparse(self): + x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", sparse=True) + self.assertEqual(x.sparse, True) + + # Raise error if trying to set attributes with self.assertRaisesRegex( AttributeError, "The sparse of KerasTensor is immutable." ): x.sparse = False + def test_attributes_ragged(self): + x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", ragged=True) + self.assertEqual(x.ragged, True) + + # Raise error if trying to set attributes + with self.assertRaisesRegex( + AttributeError, "The ragged of KerasTensor is immutable." + ): + x.ragged = False + + def test_init_sparse_ragged_raises(self): + with self.assertRaisesRegex( + ValueError, "cannot have `sparse=True` and `ragged=True`" + ): + keras_tensor.KerasTensor(shape=(3,), sparse=True, ragged=True) + def test_numpy_methods(self): x = keras_tensor.KerasTensor(shape=(3, 2), dtype="float32") diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 547a737b264..12d25effa6f 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -8,6 +8,7 @@ from keras.src.backend.jax import random from keras.src.backend.jax import tensorboard from keras.src.backend.jax.core import IS_THREAD_SAFE +from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS from keras.src.backend.jax.core import Variable from keras.src.backend.jax.core import cast diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index c9a0c7fe083..e5d8757585c 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -15,6 +15,7 @@ from keras.src.backend.jax import distribution_lib SUPPORTS_SPARSE_TENSORS = True +SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True @@ -46,7 +47,9 @@ def __jax_array__(self): return self.value -def convert_to_tensor(x, dtype=None, sparse=True): +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): + if ragged: + raise ValueError("`ragged=True` is not supported with jax backend") if dtype is not None: dtype = standardize_dtype(dtype) if isinstance(x, (jnp.ndarray, jax.Array)) and ( diff --git a/keras/src/backend/numpy/__init__.py b/keras/src/backend/numpy/__init__.py index 00d7d587ed1..1a9d8eeb791 100644 --- a/keras/src/backend/numpy/__init__.py +++ b/keras/src/backend/numpy/__init__.py @@ -7,6 +7,7 @@ from keras.src.backend.numpy import numpy from keras.src.backend.numpy import random from keras.src.backend.numpy.core import IS_THREAD_SAFE +from keras.src.backend.numpy.core import SUPPORTS_RAGGED_TENSORS from keras.src.backend.numpy.core import SUPPORTS_SPARSE_TENSORS from keras.src.backend.numpy.core import Variable from keras.src.backend.numpy.core import cast diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index b4b34755a0e..2064d81734b 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -15,6 +15,7 @@ from keras.src.backend.common.symbolic_scope import SymbolicScope SUPPORTS_SPARSE_TENSORS = False +SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True @@ -33,9 +34,11 @@ def __array__(self): return self.value -def convert_to_tensor(x, dtype=None, sparse=None): +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): if sparse: raise ValueError("`sparse=True` is not supported with numpy backend") + if ragged: + raise ValueError("`ragged=True` is not supported with numpy backend") if dtype is not None: dtype = standardize_dtype(dtype) if isinstance(x, Variable): diff --git a/keras/src/backend/openvino/__init__.py b/keras/src/backend/openvino/__init__.py index d9148e0a049..0612260452e 100644 --- a/keras/src/backend/openvino/__init__.py +++ b/keras/src/backend/openvino/__init__.py @@ -7,6 +7,7 @@ from keras.src.backend.openvino import numpy from keras.src.backend.openvino import random from keras.src.backend.openvino.core import IS_THREAD_SAFE +from keras.src.backend.openvino.core import SUPPORTS_RAGGED_TENSORS from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS from keras.src.backend.openvino.core import Variable from keras.src.backend.openvino.core import cast diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index 4daf2a61bbb..44866aa188b 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -19,6 +19,7 @@ from keras.src.backend.common.stateless_scope import StatelessScope SUPPORTS_SPARSE_TENSORS = False +SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True OPENVINO_DTYPES = { @@ -367,9 +368,11 @@ def _get_first_element(x): return None -def convert_to_tensor(x, dtype=None, sparse=None): +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): if sparse: raise ValueError("`sparse=True` is not supported with openvino backend") + if ragged: + raise ValueError("`ragged=True` is not supported with openvino backend") if isinstance(x, OpenVINOKerasTensor): return x elif isinstance(x, np.ndarray): diff --git a/keras/src/backend/tensorflow/__init__.py b/keras/src/backend/tensorflow/__init__.py index 56211cebb50..ea4eed39b8d 100644 --- a/keras/src/backend/tensorflow/__init__.py +++ b/keras/src/backend/tensorflow/__init__.py @@ -8,6 +8,7 @@ from keras.src.backend.tensorflow import random from keras.src.backend.tensorflow import tensorboard from keras.src.backend.tensorflow.core import IS_THREAD_SAFE +from keras.src.backend.tensorflow.core import SUPPORTS_RAGGED_TENSORS from keras.src.backend.tensorflow.core import SUPPORTS_SPARSE_TENSORS from keras.src.backend.tensorflow.core import Variable from keras.src.backend.tensorflow.core import cast diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index 79978250fdc..13ab042d6ff 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -19,6 +19,7 @@ from keras.src.utils.naming import auto_name SUPPORTS_SPARSE_TENSORS = True +SUPPORTS_RAGGED_TENSORS = True # https://github.com/tensorflow/tensorflow/issues/78338 IS_THREAD_SAFE = False @@ -122,9 +123,11 @@ def _map_aggregation(self, aggregation): return mapping[aggregation] -def convert_to_tensor(x, dtype=None, sparse=None): +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): if isinstance(x, tf.SparseTensor) and sparse is not None and not sparse: x = sparse_to_dense(x) + if isinstance(x, tf.RaggedTensor) and ragged is not None and not ragged: + x = x.to_tensor() if dtype is not None: dtype = standardize_dtype(dtype) if not tf.is_tensor(x): diff --git a/keras/src/backend/torch/__init__.py b/keras/src/backend/torch/__init__.py index 2632d6fd896..371a62cd0f5 100644 --- a/keras/src/backend/torch/__init__.py +++ b/keras/src/backend/torch/__init__.py @@ -23,6 +23,7 @@ from keras.src.backend.torch import numpy from keras.src.backend.torch import random from keras.src.backend.torch.core import IS_THREAD_SAFE +from keras.src.backend.torch.core import SUPPORTS_RAGGED_TENSORS from keras.src.backend.torch.core import SUPPORTS_SPARSE_TENSORS from keras.src.backend.torch.core import Variable from keras.src.backend.torch.core import cast diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 7188f1e45c1..bc655e8e207 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -20,6 +20,7 @@ from keras.src.backend.config import floatx SUPPORTS_SPARSE_TENSORS = False +SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True # Some operators such as 'aten::_foreach_mul_.Scalar' @@ -185,9 +186,11 @@ def __eq__(self, other): return False -def convert_to_tensor(x, dtype=None, sparse=None): +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): if sparse: raise ValueError("`sparse=True` is not supported with torch backend") + if ragged: + raise ValueError("`ragged=True` is not supported with torch backend") if isinstance(x, Variable): # TorchDynamo has bugs supporting nn.Parameter type check. # Return it directly instead of pass it to the rest of the logic in the diff --git a/keras/src/layers/core/input_layer.py b/keras/src/layers/core/input_layer.py index 8a45178456c..abad4617e90 100644 --- a/keras/src/layers/core/input_layer.py +++ b/keras/src/layers/core/input_layer.py @@ -14,13 +14,13 @@ def __init__( batch_size=None, dtype=None, sparse=None, + ragged=None, batch_shape=None, input_tensor=None, optional=False, name=None, **kwargs, ): - # TODO: support for ragged. super().__init__(name=name) if "input_shape" in kwargs: @@ -97,12 +97,23 @@ def __init__( self.sparse = bool(sparse) if self.sparse and not backend.SUPPORTS_SPARSE_TENSORS: raise ValueError( - "`sparse=True` is not supported with backend: " - f"{backend.backend()}" + f"`sparse=True` is not supported with the {backend.backend()} " + "backend" ) + self.ragged = bool(ragged) + if self.ragged and not backend.SUPPORTS_RAGGED_TENSORS: + raise ValueError( + f"`ragged=True` is not supported with the {backend.backend()} " + "backend" + ) + if input_tensor is None: input_tensor = backend.KerasTensor( - shape=batch_shape, dtype=dtype, sparse=sparse, name=name + shape=batch_shape, + dtype=dtype, + sparse=sparse, + ragged=ragged, + name=name, ) self._input_tensor = input_tensor Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor) @@ -125,6 +136,7 @@ def get_config(self): "batch_shape": self.batch_shape, "dtype": self.dtype, "sparse": self.sparse, + "ragged": self.ragged, "name": self.name, } @@ -135,6 +147,7 @@ def Input( batch_size=None, dtype=None, sparse=None, + ragged=None, batch_shape=None, name=None, tensor=None, @@ -163,6 +176,11 @@ def Input( sparse: A boolean specifying whether the expected input will be sparse tensors. Note that, if `sparse` is `False`, sparse tensors can still be passed into the input - they will be densified with a default + value of 0. This feature is only supported with the TensorFlow and + the JAX backends. Defaults to `False`. + ragged: A boolean specifying whether the expected input will be ragged + tensors. Note that, if `ragged` is `False`, ragged tensors can still + be passed into the input - they will be densified with a default value of 0. This feature is only supported with the TensorFlow backend. Defaults to `False`. batch_shape: Optional shape tuple (tuple of integers or `None` objects), @@ -193,6 +211,7 @@ def Input( batch_size=batch_size, dtype=dtype, sparse=sparse, + ragged=ragged, batch_shape=batch_shape, name=name, input_tensor=tensor, diff --git a/keras/src/layers/core/input_layer_test.py b/keras/src/layers/core/input_layer_test.py index d3389440d2d..766a07edb63 100644 --- a/keras/src/layers/core/input_layer_test.py +++ b/keras/src/layers/core/input_layer_test.py @@ -11,11 +11,12 @@ class InputLayerTest(testing.TestCase): # Testing happy path for layer without input tensor @parameterized.named_parameters( [ - {"testcase_name": "dense", "sparse": False}, + {"testcase_name": "dense"}, {"testcase_name": "sparse", "sparse": True}, + {"testcase_name": "ragged", "ragged": True}, ] ) - def test_input_basic(self, sparse): + def test_input_basic(self, sparse=False, ragged=False): input_shape = (2, 3) batch_size = 4 dtype = "float32" @@ -26,6 +27,7 @@ def test_input_basic(self, sparse): "batch_size": batch_size, "dtype": dtype, "sparse": sparse, + "ragged": ragged, } if sparse and not backend.SUPPORTS_SPARSE_TENSORS: @@ -34,6 +36,12 @@ def test_input_basic(self, sparse): ): InputLayer(**init_kwargs) return + if ragged and not backend.SUPPORTS_RAGGED_TENSORS: + with self.assertRaisesRegex( + ValueError, "`ragged=True` is not supported" + ): + InputLayer(**init_kwargs) + return values = InputLayer(**init_kwargs) @@ -41,11 +49,13 @@ def test_input_basic(self, sparse): self.assertEqual(values.batch_shape[0], batch_size) self.assertEqual(values.batch_shape[1:], input_shape) self.assertEqual(values.sparse, sparse) + self.assertEqual(values.ragged, ragged) self.assertEqual(values.trainable, True) self.assertIsInstance(values.output, KerasTensor) self.assertEqual(values.output.ndim, ndim) self.assertEqual(values.output.dtype, dtype) self.assertEqual(values.output.sparse, sparse) + self.assertEqual(values.output.ragged, ragged) # Testing shape is not None and batch_shape is not None condition def test_input_error1(self): diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index 65fab841342..78d38d0a397 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -929,8 +929,11 @@ def compute_output_spec(self, x): @keras_export("keras.ops.convert_to_tensor") -def convert_to_tensor(x, dtype=None, sparse=None): - """Convert a NumPy array to a tensor. +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): + """Convert a NumPy array or Python array to a tensor. + + Native tensors for the current backend or left unchanged unless the `dtype`, + `sparse` or `ragged` arguments are set. Args: x: A NumPy array, Python array (can be nested) or a backend tensor. @@ -938,6 +941,9 @@ def convert_to_tensor(x, dtype=None, sparse=None): sparse: Whether to keep sparse tensors. `False` will cause sparse tensors to be densified. The default value of `None` means that sparse tensors are kept only if the backend supports them. + ragged: Whether to keep ragged tensors. `False` will cause ragged + tensors to be densified. The default value of `None` means that + ragged tensors are kept only if the backend supports them. Returns: A backend tensor of the specified `dtype` and sparseness. @@ -949,7 +955,9 @@ def convert_to_tensor(x, dtype=None, sparse=None): """ if any_symbolic_tensors((x,)): return ConvertToTensor(dtype=dtype, sparse=sparse)(x) - return backend.core.convert_to_tensor(x, dtype=dtype, sparse=sparse) + return backend.core.convert_to_tensor( + x, dtype=dtype, sparse=sparse, ragged=ragged + ) @keras_export("keras.ops.convert_to_numpy") diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 58072daf7e2..e4b220bf986 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -644,7 +644,7 @@ def test_shape_sparse(self): self.assertAllEqual(core.shape(x), (2, 3)) @pytest.mark.skipif( - backend.backend() != "tensorflow", + not backend.SUPPORTS_SPARSE_TENSORS, reason="Backend does not support ragged tensors.", ) def test_shape_ragged(self): @@ -704,6 +704,29 @@ def test_convert_to_tensor_sparse(self): self.assertIsInstance(x_numpy, np.ndarray) self.assertAllClose(x_numpy, x_dense) + @pytest.mark.skipif( + not backend.SUPPORTS_RAGGED_TENSORS, + reason="Backend does not support ragged tensors.", + ) + def test_convert_to_tensor_ragged(self): + import tensorflow as tf + + x = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) + + x_default = ops.convert_to_tensor(x) + self.assertIsInstance(x_default, tf.RaggedTensor) + self.assertAllClose(x, x_default) + x_ragged = ops.convert_to_tensor(x, ragged=True) + self.assertIsInstance(x_ragged, tf.RaggedTensor) + self.assertAllClose(x, x_ragged) + x_dense = ops.convert_to_tensor(x, ragged=False) + self.assertNotIsInstance(x_dense, tf.RaggedTensor) + self.assertAllClose(x, x_dense) + + x_numpy = ops.convert_to_numpy(x) + self.assertIsInstance(x_numpy, np.ndarray) + self.assertAllClose(x_numpy, x_dense) + def test_cond(self): t = ops.cond(True, lambda: 0, lambda: 1) self.assertEqual(t, 0)