From 315cd25511aa6a0e700fdd1e7a16f275ca6ec994 Mon Sep 17 00:00:00 2001 From: "Andrew C. Sweet" Date: Tue, 28 Jan 2025 02:50:37 -0800 Subject: [PATCH 1/2] fix error with loading models with h5 and core while_loop --- keras/src/backend/mlx/core.py | 48 ++++++++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/keras/src/backend/mlx/core.py b/keras/src/backend/mlx/core.py index 4c51050a9fc..e4b19b22391 100644 --- a/keras/src/backend/mlx/core.py +++ b/keras/src/backend/mlx/core.py @@ -7,6 +7,11 @@ from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +try: + import h5py +except ImportError: + h5py = None + SUPPORTS_SPARSE_TENSORS = False MLX_DTYPES = { @@ -55,6 +60,13 @@ def __array__(self, dtype=None): return value +def _is_h5py_dataset(obj): + return ( + type(obj).__module__.startswith("h5py.") + and type(obj).__name__ == "Dataset" + ) + + def convert_to_tensor(x, dtype=None, sparse=None): if sparse: raise ValueError("`sparse=True` is not supported with mlx backend") @@ -89,6 +101,14 @@ def to_scalar_list(x): return mx.array(to_scalar_list(x), dtype=mlx_dtype) + if _is_h5py_dataset(x): + if h5py is None: + raise ImportError( + "h5py must be installed in order to load a model." + ) + # load h5py._hl.dataset.Dataset object with numpy + x = np.array(x) + return mx.array(x, dtype=mlx_dtype) @@ -279,18 +299,32 @@ def while_loop( loop_vars, maximum_iterations=None, ): - # TODO: How should we avoid evaluating cond when tracing? current_iter = 0 iteration_check = ( lambda iter: maximum_iterations is None or iter < maximum_iterations ) - loop_vars = tuple([convert_to_tensor(v) for v in loop_vars]) - while cond(*loop_vars) and iteration_check(current_iter): - loop_vars = body(*loop_vars) - if not isinstance(loop_vars, (list, tuple)): - loop_vars = (loop_vars,) - loop_vars = tuple(loop_vars) + + is_sequence = isinstance(loop_vars, (tuple, list)) + + if is_sequence: + loop_vars = tuple(convert_to_tensor(v) for v in loop_vars) + else: + loop_vars = tree.map_structure(convert_to_tensor, loop_vars) + + while ( + cond(*loop_vars) if is_sequence else cond(loop_vars) + ) and iteration_check(current_iter): + new_vars = body(*loop_vars) if is_sequence else body(loop_vars) + + if is_sequence: + if not isinstance(new_vars, (tuple, list)): + new_vars = (new_vars,) + loop_vars = tuple(convert_to_tensor(v) for v in new_vars) + else: + loop_vars = tree.map_structure(convert_to_tensor, new_vars) + current_iter += 1 + return loop_vars From 45929777dd55d08648467f5c33e69e938517ecba Mon Sep 17 00:00:00 2001 From: "Andrew C. Sweet" Date: Tue, 28 Jan 2025 15:14:44 -0800 Subject: [PATCH 2/2] adjust h5py import error --- keras/src/backend/mlx/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/mlx/core.py b/keras/src/backend/mlx/core.py index e4b19b22391..011e93149c3 100644 --- a/keras/src/backend/mlx/core.py +++ b/keras/src/backend/mlx/core.py @@ -104,7 +104,7 @@ def to_scalar_list(x): if _is_h5py_dataset(x): if h5py is None: raise ImportError( - "h5py must be installed in order to load a model." + "h5py must be installed in order to load HDF5 datasets." ) # load h5py._hl.dataset.Dataset object with numpy x = np.array(x)