Reverse Mode Differentiation #1
I have not tested the code with autodiff. Would you have a minimal example that I could try? If there is some easy way to get this to work, I would definitely be up for updating the code to support it.
Solving a simple NODE like this. There are plenty of examples provided with Diffrax and ProbDiffeq that require similar optimisation, for which I presume your library could be a drop-in replacement.
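For reference, a minimal sketch of the kind of example meant here, using Diffrax's `diffeqsolve` and differentiating a scalar loss with `jax.grad`; the linear vector field and the data are placeholders rather than an actual model:

```python
import jax
import jax.numpy as jnp
import diffrax

def vector_field(t, y, params):
    # Simple parameterised linear drift; stands in for a neural network.
    return params["A"] @ y

def solve(params, y0, ts):
    term = diffrax.ODETerm(vector_field)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=0.01,
        y0=y0,
        args=params,
        saveat=diffrax.SaveAt(ts=ts),
    )
    return sol.ys

def loss(params, y0, ts, targets):
    # Mean squared error between the solved trajectory and observed data.
    return jnp.mean((solve(params, y0, ts) - targets) ** 2)

params = {"A": jnp.array([[-0.1, 1.0], [-1.0, -0.1]])}
y0 = jnp.array([1.0, 0.0])
ts = jnp.linspace(0.0, 1.0, 20)
targets = jnp.zeros((20, 2))  # placeholder data

# This is the call that requires reverse-mode differentiability of the solver.
grads = jax.grad(loss)(params, y0, ts, targets)
```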
Thanks for the example! So this seems to be a fundamental issue with `jax.lax.while_loop`, which does not support reverse-mode differentiation. Maybe the best option for you to test the parallel-in-time solvers in your specific use case would be to fork the repository and replace the while loop with Equinox's. Or, if you see some other way to add reverse-mode support here, let me know and we can try to figure out how to get it implemented.
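A rough sketch of the kind of replacement being suggested, assuming the signature of the undocumented `equinox.internal.while_loop` (which may change between Equinox versions); this is not the repository's actual code:

```python
import jax
import jax.lax as lax
import equinox.internal as eqxi  # undocumented internals; API may change between versions

def cond_fun(carry):
    i, _ = carry
    return i < 10

def body_fun(carry):
    i, x = carry
    return i + 1, x * 0.9 + 1.0

def run_lax(x0):
    # jax.lax.while_loop does not support reverse-mode differentiation.
    _, x = lax.while_loop(cond_fun, body_fun, (0, x0))
    return x

def run_eqx(x0):
    # Checkpointed while loop: reverse-mode differentiable, at the cost of
    # requiring an upper bound on the number of iterations (assumed signature).
    _, x = eqxi.while_loop(cond_fun, body_fun, (0, x0), max_steps=16, kind="checkpointed")
    return x

# jax.grad(run_lax)(1.0)  # raises: while_loop is not reverse-mode differentiable
print(jax.grad(run_eqx)(1.0))
```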
Another related suggestion would be to look into using https://github.com/wilson-labs/cola. It is a trivial drop-in replacement for base JAX operators, and it enables highly efficient linear ops via multiple dispatch and lazy evaluation. It improves speed and memory efficiency significantly.
I agree with your points! Supporting autodiff would definitely be good for a probabilistic solver library. But this repository is first and foremost meant to support our publication on the matter and to make our experiments reproducible; it is not meant as a user-facing library with lots of features. That job is better done by libraries like ProbDiffeq.
I would like to test your library within an existing framework that I have. This involves learning the parameters of a GP which controls a vector flow field and thus I need reverse-mode differentiability in order to optimise a loss function to learn these parameters.
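To make the requirement concrete, here is a rough sketch of the kind of computation involved; the kernel, inducing points, and Euler integration are placeholders for illustration only, not code from this repository or my framework:

```python
import jax
import jax.numpy as jnp

def rbf_kernel(x1, x2, lengthscale, variance):
    # Squared-exponential kernel between two sets of points.
    d2 = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * d2 / lengthscale**2)

def vector_field(y, params, inducing_points, inducing_values):
    # GP posterior mean evaluated at the current state defines the drift.
    K_zz = rbf_kernel(inducing_points, inducing_points, params["lengthscale"], params["variance"])
    K_yz = rbf_kernel(y[None, :], inducing_points, params["lengthscale"], params["variance"])
    jitter = 1e-6 * jnp.eye(K_zz.shape[0])
    return (K_yz @ jnp.linalg.solve(K_zz + jitter, inducing_values))[0]

def loss(params, y0, inducing_points, inducing_values, targets, dt=0.01, n_steps=100):
    # Placeholder Euler integration; in practice this would be the ODE solver
    # whose reverse-mode differentiability is at issue.
    y = y0
    for _ in range(n_steps):
        y = y + dt * vector_field(y, params, inducing_points, inducing_values)
    return jnp.sum((y - targets) ** 2)

# Gradients of the loss with respect to the GP (kernel) parameters.
grads = jax.grad(loss)(
    {"lengthscale": 1.0, "variance": 1.0},
    jnp.array([0.5, -0.5]),
    jnp.zeros((8, 2)),
    jnp.zeros((8, 2)),
    jnp.array([1.0, 0.0]),
)
```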
I don't know if you are aware, but there are undocumented functions within the Equinox library for JAX which handle this:
https://github.com/patrick-kidger/equinox/tree/main/equinox/internal/_loop