Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add non-negative least squares (NNLS) solver. #1155

Merged
merged 1 commit into from
Feb 16, 2025

Conversation

carlosgmartin
Copy link
Contributor

Fixes #1152.

@carlosgmartin
Copy link
Contributor Author

@rdyro How does this look?

@carlosgmartin carlosgmartin force-pushed the nnls branch 2 times, most recently from a503ed5 to fe9fc6b Compare February 4, 2025 19:02
@rdyro
Copy link
Collaborator

rdyro commented Feb 5, 2025

@carlosgmartin This looks great, I left one comment!

@carlosgmartin
Copy link
Contributor Author

@rdyro Where is this comment?

optax/_src/linear_algebra.py Outdated Show resolved Hide resolved
@rdyro
Copy link
Collaborator

rdyro commented Feb 5, 2025

@rdyro Where is this comment?

Oops, should be up now!

@carlosgmartin carlosgmartin force-pushed the nnls branch 2 times, most recently from e654baf to e5b6d02 Compare February 5, 2025 21:18
@rdyro
Copy link
Collaborator

rdyro commented Feb 5, 2025

# We use lstsq with a pre-computed AtA to reduce computation time.
s = jnp.linalg.lstsq(AtA * p[:, None] * p[None, :], Atb * p)[0]

If speed is a consideration, perhaps we should use the jnp.linlag.lsqt(A * p, b) directly, letting XLA optimize as necessary?

I'd be curious if we could introduce a once-factorized cholesky version of this algorithm where we repeatedly apply cho_solve to a masked L or U factor of A^T A?

@carlosgmartin @fabianp ?

@carlosgmartin
Copy link
Contributor Author

@rdyro Edited.

optax/_src/linear_algebra.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@rdyro rdyro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this PR!

@rdyro
Copy link
Collaborator

rdyro commented Feb 11, 2025

I ran the tests internally and I can't pass the tolerance checks for this algorithm for larger dimensions (e.g., n=24 d=64). Could you take a look?

@carlosgmartin
Copy link
Contributor Author

Possibly related: jax.numpy.linalg.lstsq seems to be buggy.

@rdyro
Copy link
Collaborator

rdyro commented Feb 13, 2025

NaNs propagate so we'd see NaNs in the output, but the algorithm seems to not be converging instead.

@carlosgmartin carlosgmartin changed the title Add non-negative least squares solver. Add non-negative least squares (NNLS) solver. Feb 13, 2025
@carlosgmartin
Copy link
Contributor Author

carlosgmartin commented Feb 13, 2025

True. I also tried the JAX_DEBUG_NANS=True flag.

The issue might be that the arithmetic operations of the algorithm accumulate numerical error, and thus the inequality comparisons (which are discontinuous) yield incorrect results. I'm taking a closer look.

@carlosgmartin
Copy link
Contributor Author

carlosgmartin commented Feb 13, 2025

Due to the above difficulties, I switched to the fast projected gradient (FPG) algorithm of Polyak 2015. A clear presentation of the algorithm can be found here.

It has some advantages described in the paper's introduction, including not solving a linear system on each iteration (only doing matrix-vector multiplications instead).

@carlosgmartin carlosgmartin force-pushed the nnls branch 3 times, most recently from 9369da3 to 1fe7954 Compare February 13, 2025 23:29
@rdyro
Copy link
Collaborator

rdyro commented Feb 13, 2025

Nice! Maybe we can explore also adding an ADMM version in a future PR https://stanford.edu/class/ee364b/lectures/admm_slides.pdf since it's also promising a $\mathcal{O}(1/k)$ convergence (requires resolves, but with only a single factorization at the beginning).

optax/_src/linear_algebra.py Outdated Show resolved Hide resolved
@carlosgmartin carlosgmartin force-pushed the nnls branch 5 times, most recently from bcd3972 to 257c8ce Compare February 14, 2025 00:11
@carlosgmartin
Copy link
Contributor Author

@rdyro Done.

optax/_src/linear_algebra.py Outdated Show resolved Hide resolved
@carlosgmartin carlosgmartin force-pushed the nnls branch 3 times, most recently from c99ad6a to fa9b5ed Compare February 14, 2025 04:51
@rdyro
Copy link
Collaborator

rdyro commented Feb 14, 2025

@rdyro Done.

Nice, this PR looks really good! Thank you!

@carlosgmartin carlosgmartin force-pushed the nnls branch 3 times, most recently from 35e80c9 to 5493d58 Compare February 15, 2025 21:45
@carlosgmartin
Copy link
Contributor Author

@rdyro Made some minor changes to handle zero matrices without producing nans.

@copybara-service copybara-service bot merged commit db42abd into google-deepmind:main Feb 16, 2025
11 checks passed
@carlosgmartin carlosgmartin deleted the nnls branch February 16, 2025 19:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add nnls (non-negative least squares)
3 participants