Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Optax-based LBFGS #749

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft

Conversation

aphc14
Copy link

@aphc14 aphc14 commented Oct 26, 2024

This (draft) PR adds the Optax-based LBFGS optimiser. Relates to Blackjax-devs/blackjax #704.

There are a lot of repeated codes between _minimize_lbfgs and optax_lbfgs. I am also seeking feedback or tips on how to reduce the code repetitions.

Minor changes:

  • Changed pathfinder.py to truncate LBFGS history based on convergence point to reduce the number of inverse hessian computations.

Testing:

  • Added a class called TestOptaxLBFGS to tests/optimizers/test_optimizers.py so that the Optax-based LBFGS can be checked for convergence and the last index is correctly checked.

Updated lbfgs to truncate history to reduce inverse hessian computation.
- Introduced a new test class  to verify the functionality and convergence of the  function.
- Implemented a test method  that checks the consistency of the history and convergence of the optimizer.
@junpenglao
Copy link
Member

Feel free to replace the jaxopt lbfgs directly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants