Question about optimisation using simulate_terminal_values #452

Closed
adam-hartshorne opened this issue Mar 12, 2023 · 9 comments

@adam-hartshorne

I apologise in advance if I have misunderstood something obvious; I haven't used probabilistic ODE solvers before and am coming from Diffrax.

If one wants to use simulate_terminal_values to train a NODE, this does not seem to be possible: _advance_ivp_solution_adaptively uses lax.while_loop, and lax.while_loop does not support reverse-mode differentiation. Here is a (silly) minimal example:

```python
import jax
import jax.numpy as jnp
import optax
from diffeqzoo import backend, ivps
from probdiffeq import solution_routines, solvers
from probdiffeq.implementations import recipes
from probdiffeq.strategies import filters

# Use JAX as the array backend of the diffeqzoo problem library.
backend.select("jax")
f, u0, (t0, t1), f_args = ivps.neural_ode_mlp(layer_sizes=(2, 20, 1))
yt = 5.0  # terminal-value target


@jax.jit
def vf(y, *, t, p):
    # Wrap the diffeqzoo MLP vector field in the signature ProbDiffEq expects.
    return f(y, t, *p)


# Filtering strategy with an isotropic zeroth-order Taylor (TS0) implementation.
strategy = filters.Filter(recipes.IsoTS0.from_params(num_derivatives=4))
solver = solvers.DynamicSolver(strategy)

optim = optax.adam(learning_rate=1e-2)
p = f_args
state = optim.init(p)


def loss_fn(p):
    # Adaptive solve from t0 to t1; internally relies on lax.while_loop.
    ekf0sol = solution_routines.simulate_terminal_values(
        vf,
        initial_values=(u0,),
        t0=0.0,
        t1=1.0,
        solver=solver,
        parameters=p,
    )
    return jnp.mean(jnp.square(ekf0sol.u - yt))


@jax.jit
def update(params, opt_state):
    # Reverse-mode differentiation of loss_fn fails here (lax.while_loop).
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optim.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state


num_iterations = 100  # placeholder number of optimisation steps
for i in range(num_iterations):
    p, state = update(p, state)
```
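The underlying limitation can be reproduced without ProbDiffEq. Below is a minimal sketch in plain JAX; the toy ODE dy/dt = θy and the explicit Euler steps are hypothetical stand-ins for the neural vector field and the probabilistic solver. Reverse mode (`jax.grad`) fails through `lax.while_loop`, whose trip count is only known at run time (as with adaptive steps), but works through `lax.scan` over a static number of steps. Forward mode (`jax.jacfwd`) works through both.

```python
import jax
import jax.numpy as jnp
from jax import lax


def vector_field(y, theta):
    # Toy stand-in for a neural vector field: dy/dt = theta * y.
    return theta * y


def terminal_value_while(theta, t1=1.0, dt=0.01):
    # Explicit Euler steps inside lax.while_loop: the number of steps is only
    # known at run time, as it would be for an adaptive solver.
    def cond(state):
        t, _ = state
        return t < t1

    def body(state):
        t, y = state
        return t + dt, y + dt * vector_field(y, theta)

    init = (jnp.zeros_like(theta), jnp.ones_like(theta))
    _, y1 = lax.while_loop(cond, body, init)
    return y1


def terminal_value_scan(theta, t1=1.0, num_steps=100):
    # The same scheme on a fixed grid: lax.scan with a static number of steps.
    dt = t1 / num_steps

    def step(y, _):
        return y + dt * vector_field(y, theta), None

    y1, _ = lax.scan(step, jnp.ones_like(theta), xs=None, length=num_steps)
    return y1


theta = 0.5
print(jax.grad(terminal_value_scan)(theta))     # reverse mode, fixed grid: works
print(jax.jacfwd(terminal_value_while)(theta))  # forward mode, while_loop: works
try:
    jax.grad(terminal_value_while)(theta)       # reverse mode, while_loop: fails
except ValueError as err:
    print("reverse mode failed:", err)
```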
@pnkraemer
Owner

Hi! Thanks for dropping by :)

What exactly is your question? If it is about using ProbDiffEq for NODEs:

Reverse-mode differentiation of simulate_terminal_values (akin to e.g. diffrax.BacksolveAdjoint) is a work in progress. For the time being, we must use fixed time-steps instead.

For instance, there is an example notebook on neural ODEs that does something similar to your example, but with two differences:

  • It uses fixed time steps (for reverse-mode differentiability)
  • It replaces jnp.mean(jnp.square(ekf0sol.u - yt)) with a probabilistic equivalent, which takes the statistical nature of the probabilistic solution into account (something a non-probabilistic solver cannot do).
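To make the second point concrete, here is a hedged sketch of the two kinds of loss. It assumes, hypothetically, that the solver reports a Gaussian marginal with mean `u_mean` and standard deviation `u_std` at the observation time; this is not ProbDiffEq's API. The Gaussian negative log-likelihood weighs each residual by the combined solver and observation variance, whereas the mean squared error ignores the solver's uncertainty entirely.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def mse_loss(u_mean, y_obs):
    # Point-estimate loss: uses only the mean of the solution.
    return jnp.mean(jnp.square(u_mean - y_obs))


def gaussian_nll(u_mean, u_std, y_obs, obs_std=0.1):
    # Probabilistic loss: solution ~ N(u_mean, u_std^2), plus independent
    # observation noise, so y_obs ~ N(u_mean, u_std^2 + obs_std^2).
    total_std = jnp.sqrt(jnp.square(u_std) + obs_std**2)
    return -jnp.sum(norm.logpdf(y_obs, loc=u_mean, scale=total_std))


u_mean = jnp.array([4.8, 5.1])  # hypothetical terminal means
u_std = jnp.array([0.05, 0.2])  # hypothetical terminal standard deviations
y_obs = jnp.array([5.0, 5.0])   # observed terminal values
print(mse_loss(u_mean, y_obs), gaussian_nll(u_mean, u_std, y_obs))
```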

Does this help? What do you think?

@adam-hartshorne
Author

adam-hartshorne commented Mar 13, 2023

Thanks for the quick response.

I did see that example, but as I understand it, it requires data along the path. If your dataset only contains a set of initial (input) locations and the corresponding terminal locations, I couldn't see how to use that example.

I am interested in a probabilistic solution to such a setup.

@pnkraemer
Owner

Ah, I see!

Essentially, one would only replace ekf0sol.u with ekf0sol.u[-1] in your example.
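A plain-JAX sketch of this idea, with a hypothetical fixed-grid Euler solver `solve_on_grid` standing in for `solve_fixed_grid`: the solver returns the solution at every grid point, the loss only uses the terminal value `u[-1]`, and gradient descent recovers the parameter from terminal-value data alone. The toy ODE dy/dt = θy, the learning rate, and the synthetic data are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def solve_on_grid(theta, y0, ts):
    # Hypothetical fixed-grid explicit Euler solver for dy/dt = theta * y.
    # Returns the whole trajectory, shape (len(ts), *y0.shape).
    def step(y, t_pair):
        t_prev, t_next = t_pair
        y_next = y + (t_next - t_prev) * theta * y
        return y_next, y_next

    _, ys = lax.scan(step, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys], axis=0)


def loss_fn(theta, y0, ts, y_terminal):
    u = solve_on_grid(theta, y0, ts)
    # Only terminal-value data: compare the last grid point, u[-1].
    return jnp.mean(jnp.square(u[-1] - y_terminal))


ts = jnp.linspace(0.0, 1.0, 101)
y0 = jnp.array([1.0, 2.0])
y_terminal = y0 * jnp.exp(0.7)  # synthetic data from the "true" theta = 0.7

theta = jnp.array(0.0)
loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))
for _ in range(200):
    loss, g = loss_and_grad(theta, y0, ts, y_terminal)
    theta = theta - 0.02 * g
print(theta, loss)
```

The recovered value is close to, but not exactly, 0.7, because the Euler discretisation itself is slightly biased.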

To adapt the NODE example notebook (including the loss function I mentioned above), replace the loss_fn with something like the logposterior_fn from the sampling example (i.e. the BlackJAX example, which deals with terminal-value data):

```python
@jax.jit
def logposterior_fn(theta, *, data, ts, solver, obs_stdev=0.1):
    y_T = solve_fixed(theta, ts=ts, solver=solver)
    marginals, _ = y_T.posterior.condition_on_qoi_observation(
        data, observation_std=obs_stdev
    )
    return marginals.logpdf(data)  # removed prior PDF from notebook


# Fixed steps for reverse-mode differentiability:
@jax.jit
def solve_fixed(theta, *, ts, solver):
    sol = solution_routines.solve_fixed_grid(
        vf, initial_values=(theta,), grid=ts, solver=solver
    )
    return sol[-1]
```

In general, the sampling example might be useful to look at if you deal with terminal-value data.
That said, the differences between the two kinds of data are comparatively small.

Does this help?

@adam-hartshorne
Author

Great, thank you very much for your help. My misunderstanding was that the fixed grid methods were for use exclusively on datasets in which you have trajectory data, not just the terminal values.

FYI, Diffrax makes use of special loops (which are defined in Equinox) including ones to efficiently handle adaptive solves, and allow for reverse mode differentiation. I would have thought you could probably build off those for your use case.

I will give it a try and see how I get on with my actual use case.

@pnkraemer
Owner

Awesome, glad to hear that! If you run into more problems/misunderstandings, don't hesitate to ask more questions.

> FYI, Diffrax makes use of special loops (which are defined in Equinox) including ones to efficiently handle adaptive solves, and allow for reverse mode differentiation. I would have thought you could probably build off those for your use case.

Yes, I am aware of the bounded while loops and see how such functionality could be helpful.
I made a note about a potential path forward #453; if you're keen on this extension, let's continue discussing there :)
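For readers unfamiliar with the idea: a bounded while loop runs a `lax.scan` over a fixed maximum number of iterations and freezes the carry once the loop condition becomes false, so the trip count is static and reverse mode works. The sketch below is a naive version of that idea in plain JAX, not Equinox's implementation; `bounded_while_loop`, `terminal_value`, and `max_steps` are hypothetical names, and the toy Euler loop is an assumption for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    # Hypothetical helper: run exactly `max_steps` iterations with lax.scan and
    # freeze the carry once cond_fun is False. The static trip count makes
    # reverse mode work, at the price of always doing `max_steps` iterations
    # and storing every intermediate carry for the backward pass.
    # Caveat: body_fun is still evaluated after termination, so it should not
    # produce NaNs or infs there (they can leak into gradients via jnp.where).
    def step(val, _):
        keep_going = cond_fun(val)
        new_val = body_fun(val)
        val = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_val, val
        )
        return val, None

    final_val, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final_val


def terminal_value(theta, t1=1.0, dt=0.01):
    # Euler steps for dy/dt = theta * y until t >= t1, written like a while_loop.
    def cond(state):
        t, _ = state
        return t < t1

    def body(state):
        t, y = state
        return t + dt, y + dt * theta * y

    init = (jnp.zeros_like(theta), jnp.ones_like(theta))
    _, y1 = bounded_while_loop(cond, body, init, max_steps=200)
    return y1


print(jax.grad(terminal_value)(0.5))
```

If `max_steps` is chosen too small, this sketch silently stops early instead of raising; more careful implementations (e.g. with checkpointing) also aim to reduce the memory cost of storing every iteration.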

Feel free to close this issue if your original question is resolved; if not, let me know.

@adam-hartshorne
Author

Sorry if this is a stupid question. Consider a problem with n points in m-dimensional space (e.g. 10 points in 2D) for which we know their initial and final positions; let's call them X and Y.

After looking at this example (https://pnkraemer.github.io/probdiffeq/benchmarks/pleiades/external/), am I right in thinking that initial_values is a flattened version of X, i.e. a tuple whose first element is an array of shape (nm,), e.g. (20,)? And that we then reshape back to (10, 2) inside the vector field f?

And in terms of

```python
marginals, _ = y_T.posterior.condition_on_qoi_observation(data, observation_std=obs_stdev)
return marginals.logpdf(data)  # removed prior PDF from notebook
```

does data refer to the flattened version of Y, e.g. with shape (20,)?

@pnkraemer
Owner

Are you referring to matrix-valued differential equations? I.e. d/dt M(t) = f(M(t)), where M(t) is a matrix, not a vector?

In this case, I'd say you're right; rewriting this equation as a vector-valued (i.e. flattened) version seems to make sense. Instead of a (10,2)-shaped equation, one would solve a (20,)-shaped equation, and all derived quantities (e.g. data in your example) would be reshaped accordingly.
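A minimal sketch of the flattening in plain JAX. It assumes a hypothetical matrix-valued vector field (a rotation of each point, chosen only for illustration) on a (10, 2) array: the solver only ever sees arrays of shape (20,), and the vector field reshapes internally. `X` and `Y` are placeholder arrays standing in for the known initial and terminal positions. The initial values and the data must be flattened with the same (row-major) ordering, so that entry i of the solution lines up with entry i of `data`.

```python
import jax.numpy as jnp

num_points, dim = 10, 2


def matrix_vector_field(M, t):
    # Hypothetical matrix-valued right-hand side dM/dt = f(M, t), M: (10, 2).
    rotation = jnp.array([[0.0, -1.0], [1.0, 0.0]])
    return M @ rotation.T


def flat_vector_field(y, t):
    # Solver-facing version: takes and returns arrays of shape (20,).
    M = y.reshape(num_points, dim)
    return matrix_vector_field(M, t).reshape(-1)


X = jnp.arange(num_points * dim, dtype=jnp.float32).reshape(num_points, dim)
Y = X + 1.0               # placeholder terminal positions, shape (10, 2)
u0 = X.reshape(-1)        # e.g. initial_values=(u0,), shape (20,)
data = Y.reshape(-1)      # terminal-value data, also shape (20,)
```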

Does that help?

@adam-hartshorne
Author

I am trying to learn a vector (flow) field, defined by a NODE, that models the advection of a set of points in 2D, given that we know their start and end locations.
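One way to write such a vector field in flattened form, sketched under assumptions: a single (hypothetical) tiny MLP maps a 2D position to a 2D velocity, is shared by all points, and is applied to each point with `jax.vmap`. The function mirrors the `vf(y, *, t, p)` signature from the first example; the field is time-independent here, so `t` is unused. `init_mlp`, `velocity`, and `flat_advection_field` are illustrative names, not part of ProbDiffEq or diffeqzoo.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes=(2, 20, 2)):
    # Hypothetical tiny MLP: 2D position -> 2D velocity.
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def velocity(params, x):
    # x: a single point, shape (2,); returns its velocity, shape (2,).
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def flat_advection_field(y, *, t, p):
    # y: flattened positions, shape (2 * num_points,). Apply the same velocity
    # field to every point, then flatten back for the solver.
    points = y.reshape(-1, 2)
    return jax.vmap(velocity, in_axes=(None, 0))(p, points).reshape(-1)


params = init_mlp(jax.random.PRNGKey(0))
y = jnp.linspace(-1.0, 1.0, 20)  # 10 placeholder points in 2D, flattened
print(flat_advection_field(y, t=0.0, p=params).shape)
```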

@pnkraemer
Owner

I see. I think that, for the moment, "flattening the equation" is the best way forward.
I noted a potential extension of ProbDiffEq to matrix-valued equations in #457.

Since we're kind of drifting away from the original question (about simulate_terminal_values), I will close this issue for now. Please reopen if the original question has not been answered yet!

Let's move the discussion about matrix-valued equations to #457 :) And please feel invited to open more issues if you run into more problems!
