`jax.numpy.linalg.lstsq` sometimes incorrectly returns NaNs, for example when `A` is an all-zero matrix:
```python
from jax import numpy as jnp

A = jnp.zeros((2, 2))
b = jnp.zeros(2)
sol = jnp.linalg.lstsq(A, b)
print(sol)
# (Array([nan, nan], dtype=float32), Array([nan], dtype=float32), Array(2, dtype=int32), Array([0., 0.], dtype=float32))
```
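For reference, NumPy's `np.linalg.lstsq` returns the zero solution and rank 0 for the same all-zero inputs. A minimal comparison sketch, assuming NumPy is installed (`rcond=None` selects NumPy's documented default cutoff of machine precision times `max(M, N)`):

```python
import numpy as np

# Same all-zero system as in the JAX repro, in float32 to match JAX's default
A = np.zeros((2, 2), dtype=np.float32)
b = np.zeros(2, dtype=np.float32)

x, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None)
print(x)          # [0. 0.]
print(residuals)  # []  (empty because rank < number of columns)
print(rank)       # 0
print(s)          # [0. 0.]
```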
For comparison, using `jnp.linalg.pinv` yields the correct answer:
```python
# Continues from the snippet above (same A and b)
x = jnp.linalg.pinv(A) @ b
print(x)  # [0. 0.]
```
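As a possible workaround, the minimum-norm least-squares solution can be computed directly from the SVD, treating singular values at or below a cutoff as zero instead of inverting them. `lstsq_svd` is a hypothetical helper, not part of JAX; it assumes a real floating-point `A` of shape (m, n) and a 1-D `b`, and borrows NumPy's default cutoff of `eps * max(m, n)` relative to the largest singular value.

```python
import jax.numpy as jnp


def lstsq_svd(A, b, rcond=None):
    """Hypothetical helper: minimum-norm solution of A @ x ~= b via SVD.

    Assumes real floating-point A with shape (m, n) and 1-D b with shape (m,).
    Returns (x, rank).
    """
    A = jnp.asarray(A)
    b = jnp.asarray(b)
    m, n = A.shape
    if rcond is None:
        # Same default as numpy.linalg.lstsq(rcond=None): eps * max(m, n)
        rcond = jnp.finfo(A.dtype).eps * max(m, n)
    u, s, vt = jnp.linalg.svd(A, full_matrices=False)
    # For an all-zero A, max(s) is 0, so the cutoff is 0 and no value passes
    cutoff = rcond * jnp.max(s)
    mask = s > cutoff
    # Invert only singular values above the cutoff; the inner `where`
    # guarantees the division never sees a zero
    s_inv = jnp.where(mask, 1.0 / jnp.where(mask, s, 1.0), 0.0)
    x = vt.T @ (s_inv * (u.T @ b))
    rank = jnp.sum(mask)
    return x, rank


A = jnp.zeros((2, 2))
b = jnp.zeros(2)
x, rank = lstsq_svd(A, b)  # expected: zeros and rank 0, no NaNs
```

Unlike the result reported above, this sketch also reports rank 0 for the all-zero matrix, since no singular value exceeds the cutoff.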
System info (python version, jaxlib version, accelerator, etc.)

```
jax: 0.5.1.dev20250204+687131e98
jaxlib: 0.5.0
numpy: 1.26.4
python: 3.12.7 (main, Oct 1 2024, 02:05:46) [Clang 15.0.0 (clang-1500.3.9.4)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Carloss-MacBook-Pro-2.local', release='24.3.0', version='Darwin Kernel Version 24.3.0: Thu Jan 2 20:24:23 PST 2025; root:xnu-11215.81.4~3/RELEASE_ARM64_T6031', machine='arm64')
```
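A report like this can be generated with `jax.print_environment_info()`; the exact fields may differ between JAX versions. A minimal sketch, assuming a working JAX installation:

```python
import jax

# Print jax/jaxlib/numpy/python versions, device info, and platform details
jax.print_environment_info()

# With return_string=True, the same report is returned instead of printed
info = jax.print_environment_info(return_string=True)
```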