Skip to content

Commit

Permalink
Simplify create_diagonal_array
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Feb 6, 2025
1 parent c5e13c3 commit f66447e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ classifiers = [
"Typing :: Typed",
]
dependencies = [
"array-api-extra>=0.6.0",
"array_api_compat>=1.10",
"rich>=13.7",
"jax>=0.4.27",
"numpy>=1.25",
"optax>=0.2.3",
"rich>=13.7",
"typing_extensions>=4.11",
]

Expand Down
11 changes: 3 additions & 8 deletions tjax/_src/math_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Literal, TypeVar, overload

import jax
import array_api_extra as xpx
import numpy as np
from array_api_compat import get_namespace

Expand Down Expand Up @@ -168,13 +168,8 @@ def create_diagonal_array(m: T) -> T:
xp = get_namespace(m)
pre = m.shape[:-1]
n = m.shape[-1]
s = (*m.shape, n)
retval = xp.zeros((*pre, n ** 2), dtype=m.dtype)
for index in np.ndindex(*pre):
target_index = (*index, slice(None, None, n + 1))
source_values = m[*index, :] # type: ignore[arg-type]
if isinstance(retval, jax.Array):
retval = retval.at[target_index].set(source_values)
else:
retval[target_index] = source_values
return xp.reshape(retval, s)
xpx.at(retval)[target_index].set(m[*index, :])
return xp.reshape(retval, (*m.shape, n))
14 changes: 14 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f66447e

Please sign in to comment.