Skip to content

Commit

Permalink
Preliminary parts needed for ragged support, including densification.
Browse files Browse the repository at this point in the history
Added `ragged` option to `KerasTensor`, `InputLayer` and `convert_to_tensor`. The logic is the same as for sparse tensors.

Fixes keras-team#20731
  • Loading branch information
hertschuh committed Jan 14, 2025
1 parent 509c92b commit d10c129
Show file tree
Hide file tree
Showing 16 changed files with 139 additions and 19 deletions.
26 changes: 22 additions & 4 deletions keras/src/backend/common/keras_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
shape,
dtype="float32",
sparse=False,
ragged=False,
record_history=True,
name=None,
):
Expand All @@ -40,6 +41,12 @@ def __init__(
self._shape = backend.standardize_shape(shape)
self._dtype = backend.standardize_dtype(dtype)
self._sparse = bool(sparse)
self._ragged = bool(ragged)
if self._sparse and self._ragged:
raise ValueError(
"KerasTensor cannot have `sparse=True` and `ragged=True` at "
"the same time."
)
self.name = name or auto_name(self.__class__.__name__)
self.record_history = record_history

Expand All @@ -50,7 +57,7 @@ def shape(self):
@shape.setter
def shape(self, value):
raise AttributeError(
f"The shape of {self.__class__.__name__} is immutable. One should "
"The `shape` attribute of KerasTensor is immutable. One should "
"create a new instance of KerasTensor for this."
)

Expand All @@ -61,7 +68,7 @@ def dtype(self):
@dtype.setter
def dtype(self, value):
raise AttributeError(
f"The dtype of {self.__class__.__name__} is immutable. One should "
"The `dtype` attribute of KerasTensor is immutable. One should "
"create a new instance of KerasTensor for this."
)

Expand All @@ -72,7 +79,18 @@ def sparse(self):
@sparse.setter
def sparse(self, value):
raise AttributeError(
f"The sparse of {self.__class__.__name__} is immutable. One should "
"The `sparse` attribute of KerasTensor is immutable. One should "
"create a new instance of KerasTensor for this."
)

@property
def ragged(self):
    """Return the stored ragged flag (read-only)."""
    return self._ragged

@ragged.setter
def ragged(self, value):
    """Reject any assignment to `ragged`; the attribute cannot be changed."""
    message = (
        "The `ragged` attribute of KerasTensor is immutable. One should "
        "create a new instance of KerasTensor for this."
    )
    raise AttributeError(message)

Expand Down Expand Up @@ -160,7 +178,7 @@ def __tf_tensor__(self, dtype=None, name=None):
def __repr__(self):
return (
f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
f"sparse={self.sparse}, name={self.name}>"
f"sparse={self.sparse}, ragged={self.ragged}, name={self.name}>"
)

def __iter__(self):
Expand Down
22 changes: 22 additions & 0 deletions keras/src/backend/common/keras_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,33 @@ def test_attributes(self):
AttributeError, "The dtype of KerasTensor is immutable."
):
x.dtype = "int32"

def test_attributes_sparse(self):
    """A sparse KerasTensor reports `sparse=True` and rejects reassignment."""
    x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", sparse=True)
    self.assertEqual(x.sparse, True)

    # Assigning to `sparse` must raise. The expected text follows the
    # "The `<name>` attribute of KerasTensor is immutable." wording used by
    # the setter, so the regex has to include the backticks and "attribute".
    with self.assertRaisesRegex(
        AttributeError,
        "The `sparse` attribute of KerasTensor is immutable.",
    ):
        x.sparse = False

def test_attributes_ragged(self):
    """A ragged KerasTensor reports `ragged=True` and rejects reassignment."""
    x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", ragged=True)
    self.assertEqual(x.ragged, True)

    # Assigning to `ragged` must raise. The `ragged` setter raises
    # "The `ragged` attribute of KerasTensor is immutable. ...", so the regex
    # has to use that exact wording (the previous "The ragged of ..." pattern
    # could never match).
    with self.assertRaisesRegex(
        AttributeError,
        "The `ragged` attribute of KerasTensor is immutable.",
    ):
        x.ragged = False

def test_init_sparse_ragged_raises(self):
    """Requesting both `sparse=True` and `ragged=True` must raise ValueError."""
    expected_error = "cannot have `sparse=True` and `ragged=True`"
    with self.assertRaisesRegex(ValueError, expected_error):
        keras_tensor.KerasTensor(shape=(3,), sparse=True, ragged=True)

def test_numpy_methods(self):
x = keras_tensor.KerasTensor(shape=(3, 2), dtype="float32")

Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras.src.backend.jax import random
from keras.src.backend.jax import tensorboard
from keras.src.backend.jax.core import IS_THREAD_SAFE
from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS
from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.jax.core import Variable
from keras.src.backend.jax.core import cast
Expand Down
5 changes: 4 additions & 1 deletion keras/src/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from keras.src.backend.jax import distribution_lib

SUPPORTS_SPARSE_TENSORS = True
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True


Expand Down Expand Up @@ -46,7 +47,9 @@ def __jax_array__(self):
return self.value


def convert_to_tensor(x, dtype=None, sparse=True):
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
if ragged:
raise ValueError("`ragged=True` is not supported with jax backend")
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, (jnp.ndarray, jax.Array)) and (
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from keras.src.backend.numpy import numpy
from keras.src.backend.numpy import random
from keras.src.backend.numpy.core import IS_THREAD_SAFE
from keras.src.backend.numpy.core import SUPPORTS_RAGGED_TENSORS
from keras.src.backend.numpy.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.numpy.core import Variable
from keras.src.backend.numpy.core import cast
Expand Down
5 changes: 4 additions & 1 deletion keras/src/backend/numpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from keras.src.backend.common.symbolic_scope import SymbolicScope

SUPPORTS_SPARSE_TENSORS = False
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True


Expand All @@ -33,9 +34,11 @@ def __array__(self):
return self.value


def convert_to_tensor(x, dtype=None, sparse=None):
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
if sparse:
raise ValueError("`sparse=True` is not supported with numpy backend")
if ragged:
raise ValueError("`ragged=True` is not supported with numpy backend")
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, Variable):
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from keras.src.backend.openvino import numpy
from keras.src.backend.openvino import random
from keras.src.backend.openvino.core import IS_THREAD_SAFE
from keras.src.backend.openvino.core import SUPPORTS_RAGGED_TENSORS
from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.openvino.core import Variable
from keras.src.backend.openvino.core import cast
Expand Down
5 changes: 4 additions & 1 deletion keras/src/backend/openvino/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras.src.backend.common.stateless_scope import StatelessScope

SUPPORTS_SPARSE_TENSORS = False
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True

OPENVINO_DTYPES = {
Expand Down Expand Up @@ -367,9 +368,11 @@ def _get_first_element(x):
return None


def convert_to_tensor(x, dtype=None, sparse=None):
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
if sparse:
raise ValueError("`sparse=True` is not supported with openvino backend")
if ragged:
raise ValueError("`ragged=True` is not supported with openvino backend")
if isinstance(x, OpenVINOKerasTensor):
return x
elif isinstance(x, np.ndarray):
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras.src.backend.tensorflow import random
from keras.src.backend.tensorflow import tensorboard
from keras.src.backend.tensorflow.core import IS_THREAD_SAFE
from keras.src.backend.tensorflow.core import SUPPORTS_RAGGED_TENSORS
from keras.src.backend.tensorflow.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.tensorflow.core import Variable
from keras.src.backend.tensorflow.core import cast
Expand Down
5 changes: 4 additions & 1 deletion keras/src/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras.src.utils.naming import auto_name

SUPPORTS_SPARSE_TENSORS = True
SUPPORTS_RAGGED_TENSORS = True
# https://github.com/tensorflow/tensorflow/issues/78338
IS_THREAD_SAFE = False

Expand Down Expand Up @@ -122,9 +123,11 @@ def _map_aggregation(self, aggregation):
return mapping[aggregation]


def convert_to_tensor(x, dtype=None, sparse=None):
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
if isinstance(x, tf.SparseTensor) and sparse is not None and not sparse:
x = sparse_to_dense(x)
if isinstance(x, tf.RaggedTensor) and ragged is not None and not ragged:
x = x.to_tensor()
if dtype is not None:
dtype = standardize_dtype(dtype)
if not tf.is_tensor(x):
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from keras.src.backend.torch import numpy
from keras.src.backend.torch import random
from keras.src.backend.torch.core import IS_THREAD_SAFE
from keras.src.backend.torch.core import SUPPORTS_RAGGED_TENSORS
from keras.src.backend.torch.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.torch.core import Variable
from keras.src.backend.torch.core import cast
Expand Down
5 changes: 4 additions & 1 deletion keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras.src.backend.config import floatx

SUPPORTS_SPARSE_TENSORS = False
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True

# Some operators such as 'aten::_foreach_mul_.Scalar'
Expand Down Expand Up @@ -185,9 +186,11 @@ def __eq__(self, other):
return False


def convert_to_tensor(x, dtype=None, sparse=None):
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
if sparse:
raise ValueError("`sparse=True` is not supported with torch backend")
if ragged:
raise ValueError("`ragged=True` is not supported with torch backend")
if isinstance(x, Variable):
# TorchDynamo has bugs supporting nn.Parameter type check.
# Return it directly instead of pass it to the rest of the logic in the
Expand Down
27 changes: 23 additions & 4 deletions keras/src/layers/core/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def __init__(
batch_size=None,
dtype=None,
sparse=None,
ragged=None,
batch_shape=None,
input_tensor=None,
optional=False,
name=None,
**kwargs,
):
# TODO: support for ragged.
super().__init__(name=name)

if "input_shape" in kwargs:
Expand Down Expand Up @@ -97,12 +97,23 @@ def __init__(
self.sparse = bool(sparse)
if self.sparse and not backend.SUPPORTS_SPARSE_TENSORS:
raise ValueError(
"`sparse=True` is not supported with backend: "
f"{backend.backend()}"
f"`sparse=True` is not supported with the {backend.backend()} "
"backend"
)
self.ragged = bool(ragged)
if self.ragged and not backend.SUPPORTS_RAGGED_TENSORS:
raise ValueError(
f"`ragged=True` is not supported with the {backend.backend()} "
"backend"
)

if input_tensor is None:
input_tensor = backend.KerasTensor(
shape=batch_shape, dtype=dtype, sparse=sparse, name=name
shape=batch_shape,
dtype=dtype,
sparse=sparse,
ragged=ragged,
name=name,
)
self._input_tensor = input_tensor
Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor)
Expand All @@ -125,6 +136,7 @@ def get_config(self):
"batch_shape": self.batch_shape,
"dtype": self.dtype,
"sparse": self.sparse,
"ragged": self.ragged,
"name": self.name,
}

Expand All @@ -135,6 +147,7 @@ def Input(
batch_size=None,
dtype=None,
sparse=None,
ragged=None,
batch_shape=None,
name=None,
tensor=None,
Expand Down Expand Up @@ -163,6 +176,11 @@ def Input(
sparse: A boolean specifying whether the expected input will be sparse
tensors. Note that, if `sparse` is `False`, sparse tensors can still
be passed into the input - they will be densified with a default
value of 0. This feature is only supported with the TensorFlow and
the JAX backends. Defaults to `False`.
ragged: A boolean specifying whether the expected input will be ragged
tensors. Note that, if `ragged` is `False`, ragged tensors can still
be passed into the input - they will be densified with a default
value of 0. This feature is only supported with the TensorFlow
backend. Defaults to `False`.
batch_shape: Optional shape tuple (tuple of integers or `None` objects),
Expand Down Expand Up @@ -193,6 +211,7 @@ def Input(
batch_size=batch_size,
dtype=dtype,
sparse=sparse,
ragged=ragged,
batch_shape=batch_shape,
name=name,
input_tensor=tensor,
Expand Down
14 changes: 12 additions & 2 deletions keras/src/layers/core/input_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ class InputLayerTest(testing.TestCase):
# Testing happy path for layer without input tensor
@parameterized.named_parameters(
[
{"testcase_name": "dense", "sparse": False},
{"testcase_name": "dense"},
{"testcase_name": "sparse", "sparse": True},
{"testcase_name": "ragged", "ragged": True},
]
)
def test_input_basic(self, sparse):
def test_input_basic(self, sparse=False, ragged=False):
input_shape = (2, 3)
batch_size = 4
dtype = "float32"
Expand All @@ -26,6 +27,7 @@ def test_input_basic(self, sparse):
"batch_size": batch_size,
"dtype": dtype,
"sparse": sparse,
"ragged": ragged,
}

if sparse and not backend.SUPPORTS_SPARSE_TENSORS:
Expand All @@ -34,18 +36,26 @@ def test_input_basic(self, sparse):
):
InputLayer(**init_kwargs)
return
if ragged and not backend.SUPPORTS_RAGGED_TENSORS:
with self.assertRaisesRegex(
ValueError, "`ragged=True` is not supported"
):
InputLayer(**init_kwargs)
return

values = InputLayer(**init_kwargs)

self.assertEqual(values.dtype, dtype)
self.assertEqual(values.batch_shape[0], batch_size)
self.assertEqual(values.batch_shape[1:], input_shape)
self.assertEqual(values.sparse, sparse)
self.assertEqual(values.ragged, ragged)
self.assertEqual(values.trainable, True)
self.assertIsInstance(values.output, KerasTensor)
self.assertEqual(values.output.ndim, ndim)
self.assertEqual(values.output.dtype, dtype)
self.assertEqual(values.output.sparse, sparse)
self.assertEqual(values.output.ragged, ragged)

# Testing shape is not None and batch_shape is not None condition
def test_input_error1(self):
Expand Down
Loading

0 comments on commit d10c129

Please sign in to comment.