Question about optimisation using simulate_terminal_values #452
Hi! Thanks for dropping by :) What exactly is your question? If it is about using ProbDiffEq for NODEs: Reverse-mode differentiation of For example, here is an example notebook that does something similar to your example but with two differences:
Does this help? What do you think?
Thanks for the quick response. I did see that example, but as I understand it, it requires you to have data along the path. If you have a dataset with only a set of input locations and terminal locations, I couldn't see how to use that example. I am interested in a probabilistic solution to such a setup.
Ah, I see! Essentially, one would only replace To adapt the NODE example notebook (including the loss function I mentioned above), replace the loss_fn with something like the logposterior_fn from the sampling example (i.e. the BlackJAX example, which deals with terminal-value data):
In general, the sampling example might be useful to look at if you deal with terminal value data. Does this help?
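For terminal-value data, a log-posterior only compares the final solver state with the observed end point. The following is a minimal sketch in plain JAX, not ProbDiffEq's or BlackJAX's API: the fixed-step Euler solver, the linear `vector_field`, the prior and noise scales, and all function names are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def vector_field(y, theta):
    # Hypothetical stand-in for a neural ODE: dy/dt = theta * y.
    return theta * y


def solve_terminal(theta, y0, t1=1.0, num_steps=100):
    # Fixed-step Euler solve that returns only the state at t1.
    dt = t1 / num_steps

    def step(y, _):
        return y + dt * vector_field(y, theta), None

    y_terminal, _ = jax.lax.scan(step, y0, xs=None, length=num_steps)
    return y_terminal


def logposterior_fn(theta, y0, data, obs_std=0.1, prior_std=1.0):
    # Gaussian prior on theta plus Gaussian likelihood of the terminal data.
    logprior = norm.logpdf(theta, loc=0.0, scale=prior_std)
    y_terminal = solve_terminal(theta, y0)
    loglik = jnp.sum(norm.logpdf(data, loc=y_terminal, scale=obs_std))
    return logprior + loglik


y0 = jnp.array([1.0, 2.0])
data = solve_terminal(0.5, y0)  # synthetic terminal observations
value, grad = jax.value_and_grad(logposterior_fn)(0.3, y0, data)
```

A function of this shape could be handed to a sampler or a gradient-based optimiser; no data along the path is needed, only `y0` and `data`.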
Great, thank you very much for your help. My misunderstanding was that the fixed-grid methods were for use exclusively on datasets in which you have trajectory data, not just the terminal values. FYI, Diffrax makes use of special loops (which are defined in Equinox), including ones that efficiently handle adaptive solves and allow for reverse-mode differentiation. I would have thought you could probably build off those for your use case. I will give it a try and see how I get on with my actual use case.
Awesome, glad to hear that! If you run into more problems/misunderstandings, don't hesitate to ask more questions.
Yes, I am aware of the bounded while loops and see how such functionality could be helpful. Feel free to close this issue if your original question is resolved; if not, let me know.
Sorry if this is a stupid question. Suppose we have n points in m-dimensional space (e.g. 10 points in 2d) for which we know the initial and final positions; let's call them X and Y. After looking at this example https://pnkraemer.github.io/probdiffeq/benchmarks/pleiades/external/ , am I right in thinking that initial_values is a flattened version of X, i.e. a tuple whose first element is an array of shape (nm,), e.g. (20,)? And that we then reshape back to (10,2) in the function f that handles the vector field? And in terms of
here, does data refer to the flattened version of Y, e.g. shape (20,)?
Are you referring to matrix-valued differential equations? I.e. d/dt M(t) = f(M(t)), where M(t) is a matrix, not a vector? In this case, I'd say you're right; rewriting this equation as a vector-valued (i.e. flattened) version seems to make sense. Instead of a (10,2)-shaped equation, one would solve a (20,)-shaped equation, and all derived quantities (e.g. Does that help?
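The flattening described above can be sketched as follows. This is an illustration in plain JAX, not ProbDiffEq code: the rotation field in `vector_field_matrix`, the wrapper `vector_field_flat`, and the example array `X` are hypothetical names and values chosen for the sketch.

```python
import jax.numpy as jnp

N_POINTS, DIM = 10, 2


def vector_field_matrix(points, t):
    # Hypothetical flow acting on (N_POINTS, DIM) positions: rotation about the origin.
    x, y = points[:, 0], points[:, 1]
    return jnp.stack([-y, x], axis=1)


def vector_field_flat(u_flat, t):
    # The solver sees a (N_POINTS * DIM,) state; reshape only inside the vector field.
    points = u_flat.reshape(N_POINTS, DIM)
    return vector_field_matrix(points, t).reshape(-1)


X = jnp.arange(N_POINTS * DIM, dtype=jnp.float32).reshape(N_POINTS, DIM)
u0 = X.reshape(-1)               # flattened initial positions, shape (20,)
du = vector_field_flat(u0, 0.0)  # flattened time derivative, shape (20,)
```

Terminal observations Y would be flattened in the same row-major order, so that entry k of the flattened data corresponds to entry k of the flattened state.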
I am trying to learn a vector flow field, as defined by a NODE, which models the advection of a set of points, given we know their start and end locations in 2d.
I see. I think that, for the moment, "flattening the equation" is the best way forward. Since we're kind of drifting away from the original question (about simulate_terminal_values), I will close this issue for now. Please reopen if the original question has not been answered yet! Let's move the discussion about matrix-valued equations to #457 :) And please feel invited to open more issues if you run into more problems!
I apologise in advance as I may have misunderstood something obvious, as I haven't used probabilistic ODE solvers before and am coming from using Diffrax.
If one wants to use simulate_terminal_values when training a NODE, this isn't going to be possible because of the lax.while_loop in the _advance_ivp_solution_adaptively method: lax.while_loop doesn't support reverse-mode differentiation (see, e.g., the silly minimal example shown below).
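Independently of ProbDiffEq, the limitation can be reproduced in plain JAX. The sketch below is an illustrative assumption (it is not the snippet from this post and calls no ProbDiffEq function): a toy Euler solver written once with `lax.while_loop` and once with `lax.scan` on a fixed grid, with a hypothetical linear `vector_field` standing in for a NODE.

```python
import jax
import jax.numpy as jnp
from jax import lax


def vector_field(y, theta):
    # Hypothetical stand-in for a neural ODE: dy/dt = theta * y.
    return theta * y


def terminal_while(theta, y0, t1=1.0, dt=0.01):
    # Euler steps until t reaches t1; the loop length is decided at run time,
    # which is what an adaptive solver needs.
    def cond(state):
        t, _ = state
        return t < t1

    def body(state):
        t, y = state
        return t + dt, y + dt * vector_field(y, theta)

    _, y = lax.while_loop(cond, body, (jnp.zeros(()), y0))
    return y


def terminal_scan(theta, y0, t1=1.0, num_steps=100):
    # Same Euler scheme on a fixed grid; the loop length is static.
    dt = t1 / num_steps

    def step(y, _):
        return y + dt * vector_field(y, theta), None

    y, _ = lax.scan(step, y0, xs=None, length=num_steps)
    return y


def make_loss(solver, y0, target):
    # Squared error between the simulated and the observed terminal value.
    return lambda theta: jnp.sum((solver(theta, y0) - target) ** 2)


y0 = jnp.array([1.0, 2.0])
target = jnp.array([1.5, 3.0])

# Reverse-mode gradient through the fixed-grid solver works.
grad_scan = jax.grad(make_loss(terminal_scan, y0, target))(0.3)

try:
    jax.grad(make_loss(terminal_while, y0, target))(0.3)
    while_loop_grad_failed = False
except Exception as err:  # JAX refuses to transpose the while_loop
    while_loop_grad_failed = True
    print(type(err).__name__, err)
```

The forward pass of `terminal_while` runs fine; only the backward pass is rejected, which is why a fixed-grid solve is a workable substitute when gradients with respect to the parameters are needed.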