Skip to content

Commit

Permalink
handle i64 for scatter and cumsum
Browse files Browse the repository at this point in the history
  • Loading branch information
lkarthee committed May 6, 2024
1 parent 5a3542b commit e43f391
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
10 changes: 8 additions & 2 deletions keras/src/backend/mlx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def convert_to_tensor(x, dtype=None, sparse=None):
return x.value

if isinstance(x, np.ndarray):
if x.dtype == np.int64:
x = x.astype(np.int32)
x = x.astype(standardize_dtype(x.dtype))
return mx.array(x, dtype=mlx_dtype)

Expand Down Expand Up @@ -211,6 +209,10 @@ def vectorized_map(function, elements):
def scatter(indices, values, shape):
indices = convert_to_tensor(indices)
values = convert_to_tensor(values)
if values.dtype == mx.int64:
values = values.astype(mx.int32)
elif values.dtype == mx.uint64:
values = values.astype(mx.uint32)
zeros = mx.zeros(shape, dtype=values.dtype)
indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
zeros = zeros.at[indices].add(values)
Expand All @@ -222,6 +224,10 @@ def scatter_update(inputs, indices, updates):
inputs = convert_to_tensor(inputs)
indices = convert_to_tensor(indices)
updates = convert_to_tensor(updates)
if inputs.dtype == mx.int64:
inputs = inputs.astype(mx.int32)
elif inputs.dtype == mx.uint64:
inputs = inputs.astype(mx.uint32)
indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
inputs[indices] = updates

Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/mlx/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,19 @@ def cumprod(x, axis=None, dtype=None):
x = convert_to_tensor(x)
if dtype is not None:
x = cast(x, dtype)
if x.dtype in [mx.int64, mx.uint64]:
return mx.cumprod(
x, axis=axis, stream=mx.Device(type=mx.DeviceType.cpu)
)
return mx.cumprod(x, axis=axis)


def cumsum(x, axis=None, dtype=None):
x = convert_to_tensor(x)
if dtype is not None:
x = cast(x, dtype)
if x.dtype in [mx.int64, mx.uint64]:
return mx.cumsum(x, axis=axis, stream=mx.Device(type=mx.DeviceType.cpu))
return mx.cumsum(x, axis=axis)


Expand Down
2 changes: 1 addition & 1 deletion keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4943,7 +4943,7 @@ class NumpyDtypeTest(testing.TestCase, parameterized.TestCase):
]
elif backend.backend() == "mlx":
ALL_DTYPES = [x for x in ALL_DTYPES if x != "float64"]
# FLOAT_DTYPES = [x for x in FLOAT_DTYPES if x != "float64" ]
FLOAT_DTYPES = tuple([x for x in FLOAT_DTYPES if x != "float64"])
# Remove float8 dtypes for the following tests
ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]

Expand Down

0 comments on commit e43f391

Please sign in to comment.