Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
acsweet committed Jan 27, 2025
1 parent fbf7553 commit b8488f7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion keras/src/backend/mlx/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def det(a):
# TODO: Swap to mlx.linalg.det when it's implemented
# TODO: Swap to mlx.linalg.det when supported
a = jnp.array(a)
output = jnp.linalg.det(a)
return mx.array(output)
Expand Down
11 changes: 6 additions & 5 deletions keras/src/backend/mlx/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def dot(x, y):
x = x.reshape(x.shape[:-1] + (x.shape[-1],) + (1,) * (ndimy - 2))
r = x @ y
return r


def empty(shape, dtype=None):
dtype = to_mlx_dtype(dtype or config.floatx())
Expand Down Expand Up @@ -856,6 +856,7 @@ def quantile(x, q, axis=None, method="linear", keepdims=False):
result = mx.array(result)
return result


def ravel(x):
x = convert_to_tensor(x)
return x.reshape(-1)
Expand Down Expand Up @@ -884,7 +885,8 @@ def repeat(x, repeats, axis=None):

if repeats.size != x.shape[axis]:
raise ValueError(
f"repeats must have same length as axis: got {repeats.size} vs {x.shape[axis]}"
"repeats must have same length as axis: "
f"got {repeats.size} vs {x.shape[axis]}"
)

indices = mx.concatenate([mx.full(r, i) for i, r in enumerate(repeats)])
Expand Down Expand Up @@ -982,8 +984,7 @@ def tensordot(x1, x2, axes=2):
return mx.tensordot(x1, x2, axes)

raise ValueError(
"`axes` must be an integer or sequence "
f"Received: axes={axes}"
"`axes` must be an integer or sequence " f"Received: axes={axes}"
)


Expand Down Expand Up @@ -1196,7 +1197,7 @@ def select(condlist, choicelist, default=0):


def slogdet(x):
# TODO: Swap to mlx.linalg.slogdet when supported (or use LU factorization and determinant)
# TODO: Swap to mlx.linalg.slogdet when supported (or with determinant)
x = convert_to_tensor(x)
x = jnp.array(x)
output = jnp.linalg.slogdet(x)
Expand Down

0 comments on commit b8488f7

Please sign in to comment.