From b8488f7d68613004792f2ce6b7be17fc218a9141 Mon Sep 17 00:00:00 2001 From: "Andrew C. Sweet" Date: Mon, 27 Jan 2025 01:13:36 -0800 Subject: [PATCH] formatting --- keras/src/backend/mlx/linalg.py | 2 +- keras/src/backend/mlx/numpy.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/keras/src/backend/mlx/linalg.py b/keras/src/backend/mlx/linalg.py index 7d9c5645b47..484d0d56710 100644 --- a/keras/src/backend/mlx/linalg.py +++ b/keras/src/backend/mlx/linalg.py @@ -7,7 +7,7 @@ def det(a): - # TODO: Swap to mlx.linalg.det when it's implemented + # TODO: Swap to mlx.linalg.det when supported a = jnp.array(a) output = jnp.linalg.det(a) return mx.array(output) diff --git a/keras/src/backend/mlx/numpy.py b/keras/src/backend/mlx/numpy.py index da13e61fbba..6a7beaae19d 100644 --- a/keras/src/backend/mlx/numpy.py +++ b/keras/src/backend/mlx/numpy.py @@ -420,7 +420,7 @@ def dot(x, y): x = x.reshape(x.shape[:-1] + (x.shape[-1],) + (1,) * (ndimy - 2)) r = x @ y return r - + def empty(shape, dtype=None): dtype = to_mlx_dtype(dtype or config.floatx()) @@ -856,6 +856,7 @@ def quantile(x, q, axis=None, method="linear", keepdims=False): result = mx.array(result) return result + def ravel(x): x = convert_to_tensor(x) return x.reshape(-1) @@ -884,7 +885,8 @@ def repeat(x, repeats, axis=None): if repeats.size != x.shape[axis]: raise ValueError( - f"repeats must have same length as axis: got {repeats.size} vs {x.shape[axis]}" + "repeats must have same length as axis: " + f"got {repeats.size} vs {x.shape[axis]}" ) indices = mx.concatenate([mx.full(r, i) for i, r in enumerate(repeats)]) @@ -982,8 +984,7 @@ def tensordot(x1, x2, axes=2): return mx.tensordot(x1, x2, axes) raise ValueError( - "`axes` must be an integer or sequence " - f"Received: axes={axes}" + "`axes` must be an integer or sequence " f"Received: axes={axes}" ) @@ -1196,7 +1197,7 @@ def select(condlist, choicelist, default=0): def slogdet(x): - # TODO: Swap to mlx.linalg.slogdet when supported (or use LU factorization and determinant) + # TODO: Swap to mlx.linalg.slogdet when supported (or with determinant) x = convert_to_tensor(x) x = jnp.array(x) output = jnp.linalg.slogdet(x)