From f66447e608c9b3597fa436bfef4f1c21bb0062f2 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Thu, 6 Feb 2025 17:13:30 -0500 Subject: [PATCH] Simplify create_diagonal_array --- pyproject.toml | 3 ++- tjax/_src/math_tools.py | 11 +++-------- uv.lock | 14 ++++++++++++++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 05ae2d09..603aef56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,12 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + "array-api-extra>=0.6.0", "array_api_compat>=1.10", - "rich>=13.7", "jax>=0.4.27", "numpy>=1.25", "optax>=0.2.3", + "rich>=13.7", "typing_extensions>=4.11", ] diff --git a/tjax/_src/math_tools.py b/tjax/_src/math_tools.py index 1e39f5be..e69ad58e 100644 --- a/tjax/_src/math_tools.py +++ b/tjax/_src/math_tools.py @@ -2,7 +2,7 @@ from typing import Literal, TypeVar, overload -import jax +import array_api_extra as xpx import numpy as np from array_api_compat import get_namespace @@ -168,13 +168,8 @@ def create_diagonal_array(m: T) -> T: xp = get_namespace(m) pre = m.shape[:-1] n = m.shape[-1] - s = (*m.shape, n) retval = xp.zeros((*pre, n ** 2), dtype=m.dtype) for index in np.ndindex(*pre): target_index = (*index, slice(None, None, n + 1)) - source_values = m[*index, :] # type: ignore[arg-type] - if isinstance(retval, jax.Array): - retval = retval.at[target_index].set(source_values) - else: - retval[target_index] = source_values - return xp.reshape(retval, s) + xpx.at(retval)[target_index].set(m[*index, :]) + return xp.reshape(retval, (*m.shape, n)) diff --git a/uv.lock b/uv.lock index 3c43c932..00050c76 100644 --- a/uv.lock +++ b/uv.lock @@ -80,6 +80,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/72/76/633dffbd77631525921ab8d8867e33abd8bdb4ac64bfabd41e88ea910940/array_api_compat-1.10.0-py3-none-any.whl", hash = "sha256:d9066981fbc730174861b4394f38e27928827cbf7ed5becd8b1263b507c58864", size = 50427 }, ] +[[package]] +name = "array-api-extra" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "array-api-compat" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/38/bb6ec91d5a1dc23f17d733b769054c62f2c2c735029ba9e06ed7dad47c70/array_api_extra-0.6.0.tar.gz", hash = "sha256:392e6ad645a08d774e3148d04612c0e4725f79c6c6dd3f5c3dae4dd060f81b32", size = 110251 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/20/6fc693dcdebdea8b8c75afe83263bba12d2af0847ff20a04e9424890d2e9/array_api_extra-0.6.0-py3-none-any.whl", hash = "sha256:a4a3954358ee382f84d18ff3912809f35ee9b3823dec6e8bb34116b54fc5ccb9", size = 21649 }, +] + [[package]] name = "arrow" version = "1.3.0" @@ -2068,6 +2080,7 @@ version = "1.1.3" source = { editable = "." } dependencies = [ { name = "array-api-compat" }, + { name = "array-api-extra" }, { name = "jax" }, { name = "numpy" }, { name = "optax" }, @@ -2093,6 +2106,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "array-api-compat", specifier = ">=1.10" }, + { name = "array-api-extra", specifier = ">=0.6.0" }, { name = "jax", specifier = ">=0.4.27" }, { name = "numpy", specifier = ">=1.25" }, { name = "optax", specifier = ">=0.2.3" },