Support for ODE simulation with diffrax #2

Open
johannahaffner opened this issue Jun 21, 2024 · 17 comments

Comments

@johannahaffner

Hi, just found your library! Very exciting.

Were you planning to add support for ODE simulation with diffrax?
I use diffrax to simulate my models and optimistix to fit them, and more generally work in a hierarchical modeling context where I compose a larger model out of many smaller equinox modules.

I see a lot of similarities here - and I think it would make sense to serialize to an ODE system that can be used as a diffrax.ODETerm, which requires the signature to be t, y, args. The step function used in your example would then not be needed.

I'd also like to take advantage of diffrax.LinearInterpolation, I use it to define my experimental inputs.
I believe that input/condition handling is currently out-of-scope for vanilla SBML and only supported in PETab, right?

Let me know if this is on your To-Do list already - or if you'd be happy to take a PR on this.

@mayalenE
Contributor

Hi Johanna, thanks for your interest!

Yes, I think that adding support for ODE simulation with diffrax would be great, and I'd be happy to integrate a PR on that! However, I'm not sure whether we can and should discard the step function: for some models there are parameter updates made by assignment rules after each integration step, so we need to check whether this can be integrated within diffrax (I'm not familiar with it), and I think it's nice for users to be able to intervene on the system dynamics through time. Ideally we should support both solving one step at a time and solving for an array of timesteps.

Regarding your second point I'm not sure I understand, could you specify what you have in mind?

@johannahaffner
Author

johannahaffner commented Jun 22, 2024

Hi! Nice, I'm glad to hear it!

for some models they are parameter updates made by assignment rules after each integration step so we need to check whether this can be integrated within diffrax

These can usually be integrated into the ODE, no? I am working with a model that does contain algebraic relations for some parameters, too, it looks something like this

import equinox as eqx
import jax.numpy as jnp

class ODESystem(eqx.Module):

    def __call__(self, time, states, args):
        x1, x2 = states
        p1, p2 = args

        pa = p1 * (x1 - p2) / x2    # Some algebraic relation that depends on the system state

        d_x1 = pa * x1    # Regular ODE system
        d_x2 = - p2 * x2 + p1 * x1

        d_state = d_x1, d_x2
        return jnp.array(d_state)

That looks like something that can be integrated into a Serializer to me.

Regarding your second point I'm not sure I understand, could you specify what you have in mind?

I have time-varying inputs, which I represent with a dfx.LinearInterpolation and then evaluate inside of the ODE system, like this

import diffrax as dfx
import jax.numpy as jnp

ts = jnp.array([0.0, 1.0, 1.0 + 0.001, 2.0])  # A step function (steep linear ramp between t=1 and t=1.001)
ys = jnp.array([1.0, 1.0, 2.0, 2.0])

u = dfx.LinearInterpolation(ts=ts, ys=ys)

def i1ffl(time, states, args):
    x, y = states
    params, u = args

    d_x = u.evaluate(time) - x
    d_y = u.evaluate(time) / x - y

    d_state = d_x, d_y
    return jnp.array(d_state)

I've been working with diffrax for a while and it truly is a joy to use. It also supports interactively stepping through a solve, similar to your step function, but you could use any one of its 19 ODE solvers to do it. A new library for nonlinear optimization in that same ecosystem (and by the same author), optimistix, already interfaces with optax (and also supports interactively stepping through a solve).

Both diffrax and optimistix are built on top of equinox, which you already use to register models as PyTrees.
Diffrax also has support for steady-state and custom events and is written to be easily extendable.

It seems to me that there are some synergies here!

@mayalenE
Contributor

Great, yes, it looks like we should then be able to include it easily, akin to your example! I agree that having the 19 ODE solvers of diffrax would be a great addition, and it would indeed enable users to couple it with optimistix or the steady-state event.

Do you want to start the PR? I think that simply rewriting the ModelStep function (at the end of https://github.com/flowersteam/sbmltoodejax/blob/main/sbmltoodejax/modulegeneration.py) to use interactive stepping through a diffrax solve should do the job.

Then to check whether this is functioning as expected you can try to reproduce examples from https://developmentalsystems.org/sbmltoodejax/tutorials/biomodels_curation.html.

Regarding the potential use of LinearInterpolation, optimistix, SteadyStateEvent I think we should keep them outside of the main library (which is intended to stay minimal) but we could definitely add more tutorials on how to combine them with sbmltoodejax.

What do you think?

@DylanEsguerra

Hello,

I have been using diffrax for a bit now and also just came across your library. I am interested in fitting SBML models to data with JAX and am interested in connecting diffrax with SBMLtoodejax. What is the status of this? I would be interested in helping.

@johannahaffner
Author

Hi @mayalenE and @DylanEsguerra,

this has completely escaped my notice - I had not realised that you had replied to me, how rude of me!

Now, looking at the code: I agree that adapting the model step should be quite straightforward, but I would prefer to be able to export to a callable that I can import directly into diffrax as an ODETerm. For this, the model needs to be of the form Callable[[t, y, args], y], returning dy/dt with the same structure as y. So the most straightforward thing would probably be to adapt the serialiser.
@mayalenE, do you think we could modularise GenerateModel a bit, to re-use some of its functionality while writing to a different output file?

Have a nice evening :)

@mayalenE
Contributor

mayalenE commented Oct 26, 2024

Hi @johannahaffner and @DylanEsguerra, thanks for your messages!

Yes, definitely, let's work on that! I've been able to reproduce some results with diffrax. The only thing we might lose is that the ODETerm from diffrax cannot return anything other than dy.

If we generate something like the code below, integrating the ratefunc and assignment func from SBML models into the ODE term, it would be nice to be able to save the evolution of w as well, but I don't see that as an option in diffrax at the moment:

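import equinox as eqx

class ODESystem(eqx.Module):

    def __call__(self, t, y, args):
        w, c = args

        # Assignment func in SBML: w is recomputed from the current state y and constants c
        w = w.at[0].set(1.0 * ((c[1]/1.0) - (y[2]/1.0)))
        w = w.at[1].set(1.0 * ((y[3]/1.0) + (y[4]/1.0)))
        w = w.at[2].set(1.0 * ((c[2]/1.0) - (y[0]/1.0)))
        w = w.at[3].set(1.0 * ((c[0]/1.0) - (y[1]/1.0)))

        # Rate func in SBML (ratefunc is the generated rate function, defined elsewhere)
        dy = ratefunc(y, t, w, c)

        # Only dy is returned, so the evolution of w is not saved by the solve
        return dy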

@mayalenE mayalenE reopened this Oct 26, 2024
@mayalenE
Contributor

Also, does either of you know whether it can happen in SBML that parameters updated by an assignment rule (w here) depend on previous w values?

@DylanEsguerra

Hello @mayalenE, I am pretty new to SBML, so I'm not aware of whether w ever depends on previous values, but I have forked your repository and integrated 8 diffrax solvers into the ModelStep function. It seems to be working well and can handle returning ws, as shown by a test on BioModel 240. The solvers I added are the ERK methods, since I have had trouble stepping through solves with the IRK methods.

@DylanEsguerra

I'm going to try to add the IRK methods before making a PR, but my current version is in a feature branch and I would greatly appreciate both of you checking it out!

Here is the part I added

outputFile.write("\t\telse: # diffrax\n")
outputFile.write("\t\t\tterm = ODETerm(lambda t, y, args: self.ratefunc(y, t, *args))\n")
outputFile.write("\t\t\ttprev, tnext = t, t + deltaT\n")
outputFile.write("\t\t\tstate = self.solver.init(term, tprev, tnext, y, (w, c))\n")
outputFile.write("\t\t\ty_new, _, _, _, _ = self.solver.step(term, tprev, tnext, y, (w, c), state, made_jump=False)\n")
outputFile.write("\t\tt_new = t + deltaT\n")
outputFile.write("\t\tw_new = self.assignmentfunc(y_new, w, c, t_new)\n")
outputFile.write("\t\treturn y_new, w_new, c, t_new\n\n")

@mayalenE
Contributor

Hi Dylan, nice thanks!
Iteratively stepping through the diffrax solver would indeed make it possible to handle the evolution of w and y separately (as in the current sbmltoodejax version) and to work in cases where w depends on previous values (though I'm not sure such cases exist, tbh).

My question was whether we could also generate an ODETerm that calls both the ratefunc and the assignmentfunc, so that users can directly perform a diffeqsolve on it with any combination of solver/stepsize_controller/adjoint, as proposed by Johanna.

Regarding your PR, I've made a branch for this issue (https://github.com/flowersteam/sbmltoodejax/tree/support-for-ode-simulation-with-diffrax), so it would be great if you could open the PR against that branch. After a quick look at your modulegeneration_3.py code, I think you could use an f-string, e.g. outputFile.write(f"\t\t\t\tself.solver = {diffrax_solver}()\n"), to avoid handling all the cases by hand, but otherwise it looks good to me!
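A minimal sketch of the f-string suggestion: a single templated write replaces one hand-written branch per solver. write_solver_setup and SUPPORTED_SOLVERS are hypothetical names, the whitelist is an assumption (a handful of diffrax's explicit solver class names), and outputFile can be any writable text file object (an io.StringIO here).

```python
import io

# Hypothetical whitelist of diffrax explicit solver class names
SUPPORTED_SOLVERS = (
    "Euler", "Heun", "Midpoint", "Ralston", "Bosh3", "Tsit5", "Dopri5", "Dopri8",
)


def write_solver_setup(outputFile, diffrax_solver):
    """Write the generated line that instantiates the chosen diffrax solver."""
    if diffrax_solver not in SUPPORTED_SOLVERS:
        raise ValueError(f"Unsupported diffrax solver: {diffrax_solver!r}")
    # One f-string covers every solver instead of an if/elif chain
    outputFile.write(f"\t\t\t\tself.solver = {diffrax_solver}()\n")


buf = io.StringIO()
write_solver_setup(buf, "Tsit5")
print(repr(buf.getvalue()))
```

Validating the name before writing it keeps a typo from producing a generated module that only fails at import time.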

@johannahaffner
Author

Hi Dylan,

yes, looks good to me!
@mayalenE is it necessary to keep the rate function and assignment function like that? I know that other SBML tools unpack these completely and just return a vector of parameters and variables, all written into a (temp) file.
Something like that would be my preferred solution, since the ODE models are just a small part of larger models I am building, and I handle attributes such as tolerances etc. elsewhere.

If that veers too far away from what you are building and using, I completely understand!

@mayalenE
Contributor

I'm not sure what you have in mind here. As explained here, I separated the rate function from the assignment function because one handles the evolution of the variables y by solving an ODE, whereas the other directly sets the value of the parameters w through an equation.

I don't see a way to include w in the y vector, but maybe I'm not understanding what you are suggesting here.
One solution, as proposed above, would be to perform the equations of the assignment function directly in the ODE term, but this means that (i) it will not work if w depends on previous values and (ii) we will not be able to return w.

@johannahaffner
Author

My bad, and maybe I do not understand well enough what you mean by $w$ and how it enters into the model.

  • Is this a differential-algebraic equation? If yes, then support for these is planned in diffrax; DAE solvers evolve the algebraic constraints alongside the ODE.
  • If $w$ is simply a parameter of, or the observable in, the observation model, then it does not matter for the ODE and a map can be written to perform $y_{states}(t) \rightarrow y_{observed}(t)$ on some general time series $y$, after the ODE has already been solved.
  • If $w$ is something else, for example an experimental input (such as a step function), then I think it would be more natively handled as a diffrax.LinearInterpolation.
  • If $w$ is a factor that is necessary to compute the evolving state, then it can simply be integrated into the formulation of the ODE system without returning it, such that we compute $w = f(y)$, then $d_y = g(y, w)$, and return $d_y$, where $w$ is essentially just computed for convenience.

In which models does this $w$ play a role?

@DylanEsguerra

DylanEsguerra commented Oct 31, 2024

I believe w is a time-dependent parameter defined by the assignment rule, often using a piecewise function in SBML. To extract its value from diffrax, you can map the assignment rule over your solution trajectory:

ws = jax.vmap(lambda t, y: assignment_func(y, w0, c, t))(ts, ys)

Shown in the context of a BioModels example:

import jax
jax.config.update("jax_platform_name", "cpu")
import diffrax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from sbmltoodejax.utils import load_biomodel

# Load and simulate model
load_biomodel(167)

from jax_model import RateofSpeciesChange, AssignmentRule, y0, w0, t0, c

def ode_func(t, y, args):
    w, c = args
    rate_of_change = RateofSpeciesChange()
    assignment_rule = AssignmentRule()
    
    # Update w using the assignment rule
    w = assignment_rule(y, w, c, t)
    
    # Calculate the rate of change
    dy_dt = rate_of_change(y, t, w, c)
    
    return dy_dt

t1 = 5000
deltaT = 0.1
n_steps = int(t1 / deltaT)

# Create diffrax solver
term = diffrax.ODETerm(ode_func)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, n_steps))

# Solve the ODE system (diffeqsolve returns a single Solution object)
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=deltaT,
    y0=y0,
    args=(w0, c),
    saveat=saveat,
    max_steps=500000
)

# Extract results
ts = sol.ts
ys = sol.ys

ws = jax.vmap(lambda t, y: AssignmentRule()(y, w0, c, t))(ts, ys)

# Plot time course simulation as in the original paper
fig, ax = plt.subplots(3, 1, figsize=(6, 10))

ax[0].plot(ts, ws[:,0]/14.625, color="red", label="statKinase_sol")
ax[0].legend()

ax[1].plot(ts, ys[:,1]/14.625, color="red", label="PstatDimer_sol")
ax[1].legend()

ax[2].plot(ts, ys[:,4]/14.625, color="red", label="stat_sol")
ax[2].legend()

for i in range(3):
    ax[i].set_xlim(left=0)
    if i < 2:
        ax[i].set_ylim(bottom=0)
    else:
        ax[i].set_ylim(top=1)

plt.show()

@johannahaffner
Author

johannahaffner commented Nov 1, 2024

Alright, I see. So it is option 4., and the evolving state of $w$ does not actually need to be returned, because it is entirely dependent on $y$.

Given that we can also import the functions computing the rate of species change and the assignment rules from the imported jax model, I will probably just use that for now. If you're doing something similar, I suggest instantiating these classes once, outside of ode_func, rather than inside it, which creates new objects on every call. I expect that this would decrease compilation time significantly in larger projects.
Like so:

from jax_model import RateofSpeciesChange, AssignmentRule

rate_of_change = RateofSpeciesChange()
assignment_rule = AssignmentRule()

def ode_func(t, y, args):
    w0, c = args
    w = assignment_rule(...)
    dy = rate_of_change(...)
    return dy

Thank you for taking the time to explain this!

@johannahaffner
Author

On another note: I see that some of the parameters are hard-coded into the RateofSpeciesChange and AssignmentRule classes. Is that a common choice when translating SBML models? This is most likely an area I would have to customise for my context.

@johannahaffner
Author

And one more: make this an SBMLTOODEJAX example! It was super helpful :)


3 participants