Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mlx - fix error with loading models with h5 and update mlx.core.while_loop #20819

Merged
merged 2 commits into from
Jan 29, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions keras/src/backend/mlx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.stateless_scope import StatelessScope

try:
import h5py
except ImportError:
h5py = None

SUPPORTS_SPARSE_TENSORS = False

MLX_DTYPES = {
Expand Down Expand Up @@ -55,6 +60,13 @@ def __array__(self, dtype=None):
return value


def _is_h5py_dataset(obj):
return (
type(obj).__module__.startswith("h5py.")
and type(obj).__name__ == "Dataset"
)


def convert_to_tensor(x, dtype=None, sparse=None):
if sparse:
raise ValueError("`sparse=True` is not supported with mlx backend")
Expand Down Expand Up @@ -89,6 +101,14 @@ def to_scalar_list(x):

return mx.array(to_scalar_list(x), dtype=mlx_dtype)

if _is_h5py_dataset(x):
if h5py is None:
raise ImportError(
"h5py must be installed in order to load HDF5 datasets."
)
# load h5py._hl.dataset.Dataset object with numpy
x = np.array(x)

return mx.array(x, dtype=mlx_dtype)


Expand Down Expand Up @@ -279,18 +299,32 @@ def while_loop(
loop_vars,
maximum_iterations=None,
):
# TODO: How should we avoid evaluating cond when tracing?
current_iter = 0
iteration_check = (
lambda iter: maximum_iterations is None or iter < maximum_iterations
)
loop_vars = tuple([convert_to_tensor(v) for v in loop_vars])
while cond(*loop_vars) and iteration_check(current_iter):
loop_vars = body(*loop_vars)
if not isinstance(loop_vars, (list, tuple)):
loop_vars = (loop_vars,)
loop_vars = tuple(loop_vars)

is_sequence = isinstance(loop_vars, (tuple, list))

if is_sequence:
loop_vars = tuple(convert_to_tensor(v) for v in loop_vars)
else:
loop_vars = tree.map_structure(convert_to_tensor, loop_vars)

while (
cond(*loop_vars) if is_sequence else cond(loop_vars)
) and iteration_check(current_iter):
new_vars = body(*loop_vars) if is_sequence else body(loop_vars)

if is_sequence:
if not isinstance(new_vars, (tuple, list)):
new_vars = (new_vars,)
loop_vars = tuple(convert_to_tensor(v) for v in new_vars)
else:
loop_vars = tree.map_structure(convert_to_tensor, new_vars)

current_iter += 1

return loop_vars


Expand Down
Loading