
Reverse Mode Differentiation #1

Open
adam-hartshorne opened this issue Nov 23, 2023 · 6 comments

@adam-hartshorne

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

I would like to test your library within an existing framework that I have. This involves learning the parameters of a GP that controls a vector flow field, so I need reverse-mode differentiability in order to optimise a loss function over those parameters.

I don't know if you are aware, but there are undocumented functions within the Equinox library for JAX which handle this:

https://github.com/patrick-kidger/equinox/tree/main/equinox/internal/_loop
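
For context, here is a minimal sketch of the restriction (a toy example, not my actual code): jax.grad fails on a lax.while_loop with a data-dependent stopping condition, whereas the same computation written with lax.scan over a fixed number of steps is reverse-differentiable.

import jax
import jax.numpy as jnp
from jax import lax

def f_while(x):
    # Dynamic-length loop: iterate until the running value exceeds 10.
    def cond(val):
        return val < 10.0
    def body(val):
        return val * x
    return lax.while_loop(cond, body, jnp.asarray(1.0))

def f_scan(x):
    # Same computation with a fixed number of steps: reverse-mode works here.
    def step(val, _):
        return val * x, None
    out, _ = lax.scan(step, jnp.asarray(1.0), None, length=20)
    return out

# jax.grad(f_while)(1.5)  # raises the ValueError quoted above
print(jax.grad(f_scan)(1.5))  # fine: d/dx of x**20 at x = 1.5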

@nathanaelbosch
Owner

I have not tested the code with autodiff. Would you have a minimal example that I could try? If there is some easy way to get this to work, I would definitely be up for updating the code to support this.

@adam-hartshorne
Author

adam-hartshorne commented Nov 24, 2023

Solving a simple neural ODE (NODE) like this (obviously this isn't using the probabilistic ODE solver properly, e.g. by optimising the NLL).

There are plenty of examples provided with Diffrax and ProbDiffEq that require similar optimisation, for which I presume your library could be a drop-in replacement.

import jax
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jrandom
import equinox as eqx
import matplotlib.pyplot as plt
import optax
from pof.solver import solve, sequential_eks_solve  # solvers from this repository

# Target data: y(t) = sin(5*pi*t) sampled at 100 points on [0, 1].
ts = jnp.linspace(0, 1.0, num=100)
ys = jnp.sin(5 * jnp.pi * ts)

# Vector field dy/dt = f(t, y), parameterised by an MLP.
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, t, y):
        return self.mlp(y)

# Neural ODE: integrate the learned vector field with the probabilistic solver.
class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, y0):

        ts_par = jnp.linspace(0, 1.0, 100)
        # Parallel-in-time probabilistic solve; ys_par holds the posterior mean and covariance.
        ys_par, info_par = solve(f=self.func, y0=y0, ts=ts_par, order=3, init="constant")
        mean, cov = ys_par
        return mean


data_size = 1
steps = 500
print_every = 10
width_size = 64
depth = 2
lr = 1e-3

key = jrandom.PRNGKey(42)
data_key, model_key, loader_key = jrandom.split(key, 3)
model = NeuralODE(data_size, width_size, depth, key=model_key)

@eqx.filter_value_and_grad
def grad_loss(model, x0):
    # MSE between the target data and the predicted posterior mean;
    # reverse-mode differentiation through solve() is what raises the ValueError above.
    y_pred = model(x0)
    return jnp.mean((ys - y_pred) ** 2)

@eqx.filter_jit
def make_step(x0, model, opt_state):
    # One optimisation step: compute loss and gradients, then apply the optax update.
    loss, grads = grad_loss(model, x0)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

optim = optax.adabelief(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
x0 = jnp.array([0.0])

for step in range(steps):
    loss, model, opt_state = make_step(x0, model, opt_state)
    if (step % print_every) == 0 or step == steps - 1:
        print(f"Step: {step}, Loss: {loss}")

@nathanaelbosch
Owner

Thanks for the example! So this seems to be a fundamental issue with lax.while_loop. Right now I don't see an easy way to get rid of it while still doing what the code is supposed to do. I would also prefer not to build on internal, undocumented functionality from Equinox: it is not covered by semantic versioning, so it might change with each new release and thereby break this repository.

Maybe the best option for you to test the parallel-in-time solvers in your specific use case would be to fork the repository and replace the while loop with Equinox's. Or if you see some other way to add reverse-diff support here, let me know and we can try to figure out how to get it implemented.
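
For anyone trying that in a fork, here is a rough sketch of what the swap might look like on the toy loop from above. equinox.internal.while_loop is undocumented internal API, so the keyword arguments here (max_steps, kind) are assumptions based on how Diffrax uses it; check the Equinox source for the actual signature before relying on it.

import equinox.internal as eqxi
import jax
import jax.numpy as jnp

def f(x):
    # Same dynamic-length loop as the lax.while_loop example above, but routed
    # through Equinox's internal checkpointed while loop, which is intended to
    # support reverse-mode autodiff.
    def cond(val):
        return val < 10.0
    def body(val):
        return val * x
    return eqxi.while_loop(
        cond,
        body,
        jnp.asarray(1.0),
        max_steps=256,        # assumed: an upper bound on the number of iterations
        kind="checkpointed",  # assumed: enables reverse-mode via checkpointing
    )

print(jax.grad(f)(1.5))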

@adam-hartshorne
Author

adam-hartshorne commented Nov 25, 2023

  1. The undocumented loop functions in Equinox are quite mature, as they form the basis of the widely used Diffrax ODE solver library, so they are unlikely to change massively; they are designed precisely to replace lax.while_loop.

  2. Have you looked at how ProbDiffEq achieves this? That library supports autodiff while solving probabilistic ODEs.

  3. I think that for wider adoption you will need reverse-diff support, as Neural ODEs are widely used in diffusion models, normalising flows, etc.

@adam-hartshorne
Author

adam-hartshorne commented Nov 25, 2023

Another related suggestion would be to look into using https://github.com/wilson-labs/cola

It is a near drop-in replacement for base JAX linear operators, but it enables highly efficient linear algebra via multiple dispatch and lazy evaluation. It improves speed and memory efficiency significantly.

@nathanaelbosch
Owner

I agree with your points! Supporting autodiff would definitely be good for a probabilistic solver library. But this repository is first and foremost meant to support our publication on the matter and to make our experiments reproducible; it is not meant as a user-facing library with lots of features. That job is better done by libraries like probdiffeq in JAX, or my ProbNumDiffEq.jl package in Julia, both of which are actively maintained, well tested, and documented. I hope to make the parallel-in-time functionality available in ProbNumDiffEq.jl in the future, but I cannot give an ETA on this.
