From e43f391954e112a52b5a5c65f952f00484495615 Mon Sep 17 00:00:00 2001 From: Kartheek Date: Sun, 5 May 2024 20:26:50 +0530 Subject: [PATCH] handle i64 for scatter and cumsum --- keras/src/backend/mlx/core.py | 10 ++++++++-- keras/src/backend/mlx/numpy.py | 6 ++++++ keras/src/ops/numpy_test.py | 2 +- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/keras/src/backend/mlx/core.py b/keras/src/backend/mlx/core.py index 3625825aa4a..d8d763d7dbc 100644 --- a/keras/src/backend/mlx/core.py +++ b/keras/src/backend/mlx/core.py @@ -71,8 +71,6 @@ def convert_to_tensor(x, dtype=None, sparse=None): return x.value if isinstance(x, np.ndarray): - if x.dtype == np.int64: - x = x.astype(np.int32) x = x.astype(standardize_dtype(x.dtype)) return mx.array(x, dtype=mlx_dtype) @@ -211,6 +209,10 @@ def vectorized_map(function, elements): def scatter(indices, values, shape): indices = convert_to_tensor(indices) values = convert_to_tensor(values) + if values.dtype == mx.int64: + values = values.astype(mx.int32) + elif values.dtype == mx.uint64: + values = values.astype(mx.uint32) zeros = mx.zeros(shape, dtype=values.dtype) indices = tuple(indices[..., i] for i in range(indices.shape[-1])) zeros = zeros.at[indices].add(values) @@ -222,6 +224,10 @@ def scatter_update(inputs, indices, updates): inputs = convert_to_tensor(inputs) indices = convert_to_tensor(indices) updates = convert_to_tensor(updates) + if inputs.dtype == mx.int64: + inputs = inputs.astype(mx.int32) + elif inputs.dtype == mx.uint64: + inputs = inputs.astype(mx.uint32) indices = tuple(indices[..., i] for i in range(indices.shape[-1])) inputs[indices] = updates diff --git a/keras/src/backend/mlx/numpy.py b/keras/src/backend/mlx/numpy.py index baff911d877..09c02f81e8f 100644 --- a/keras/src/backend/mlx/numpy.py +++ b/keras/src/backend/mlx/numpy.py @@ -278,6 +278,10 @@ def cumprod(x, axis=None, dtype=None): x = convert_to_tensor(x) if dtype is not None: x = cast(x, dtype) + if x.dtype in [mx.int64, mx.uint64]: + return mx.cumprod( + x, axis=axis, stream=mx.Device(type=mx.DeviceType.cpu) + ) return mx.cumprod(x, axis=axis) @@ -285,6 +289,8 @@ def cumsum(x, axis=None, dtype=None): x = convert_to_tensor(x) if dtype is not None: x = cast(x, dtype) + if x.dtype in [mx.int64, mx.uint64]: + return mx.cumsum(x, axis=axis, stream=mx.Device(type=mx.DeviceType.cpu)) return mx.cumsum(x, axis=axis) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index aee3a63e8ee..690be9e83eb 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -4943,7 +4943,7 @@ class NumpyDtypeTest(testing.TestCase, parameterized.TestCase): ] elif backend.backend() == "mlx": ALL_DTYPES = [x for x in ALL_DTYPES if x != "float64"] - # FLOAT_DTYPES = [x for x in FLOAT_DTYPES if x != "float64" ] + FLOAT_DTYPES = tuple([x for x in FLOAT_DTYPES if x != "float64"]) # Remove float8 dtypes for the following tests ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]