Support for ODE simulation with diffrax #2
Hi Johanna, thanks for your interest! Yes, I think that adding support for ODE simulation with diffrax would be great, and I'm happy to integrate a PR on that! However, I'm not sure whether we can and should discard the step function: for some models there are parameter updates made by assignment rules after each integration step, so we need to check whether this can be integrated within diffrax (I'm not familiar with it). I also think it's nice for users to be able to intervene on the system dynamics through time. Ideally, we should support both solving one step at a time and solving for an array of timesteps. Regarding your second point, I'm not sure I understand; could you specify what you have in mind?
Hi! Nice, I'm glad to hear it!
These can usually be integrated into the ODE, no? I am working with a model that also contains algebraic relations for some parameters; it looks something like this:

```python
import equinox as eqx
import jax.numpy as jnp


class ODESystem(eqx.Module):
    def __call__(self, time, states, args):
        x1, x2 = states
        p1, p2 = args
        pa = p1 * (x1 - p2) / x2  # Some algebraic relation that depends on the system state
        d_x1 = pa * x1  # Regular ODE system
        d_x2 = -p2 * x2 + p1 * x1
        d_state = d_x1, d_x2
        return jnp.array(d_state)
```

That looks like something that can be integrated into a Serializer to me.
I have time-varying inputs, which I represent with a `dfx.LinearInterpolation`:

```python
import diffrax as dfx
import jax.numpy as jnp

ts = jnp.array([0.0, 1.0, 1.0 + 0.001, 2.0])  # A step function
ys = jnp.array([1.0, 1.0, 2.0, 2.0])
u = dfx.LinearInterpolation(ts=ts, ys=ys)


def i1ffl(time, states, args):
    x, y = states
    params, u = args  # the input interpolation is passed in through args
    d_x = u.evaluate(time) - x
    d_y = u.evaluate(time) / x - y
    d_state = d_x, d_y
    return jnp.array(d_state)
```

I've been working with diffrax for a while, and it truly is a joy to use. It also supports interactively stepping through a solve, similar to your step function, but you could use any one of its 19 ODE solvers to do it. A new library for nonlinear optimization in the same ecosystem (and by the same author), optimistix, already interfaces with optax (and also supports interactively stepping through a solve). Both diffrax and optimistix are built on top of equinox, which you already use to register models as PyTrees. It seems to me that there are some synergies here!
Great, yes, it looks like we should then be able to include it easily, akin to your example! I agree that having the 19 ODE solvers of diffrax would be a great addition, and would indeed enable users to couple it with optimistix or the steady state event. Do you want to start the PR? I think that simply rewriting the Then, to check whether this is functioning as expected, you can try to reproduce examples from https://developmentalsystems.org/sbmltoodejax/tutorials/biomodels_curation.html. Regarding the potential use of What do you think?
Hello, I have been using diffrax for a bit now and also just came across your library. I am interested in fitting SBML models to data with JAX and would like to connect diffrax with SBMLtoodejax. What is the status of this? I would be interested in helping.
Hi @mayalenE and @DylanEsguerra, this had completely escaped my notice - I had not realised that you had replied to me, how rude of me! Now, looking at the code: I agree that adapting the model step should be quite straightforward, but I would prefer to be able to export to a callable that I can import directly into diffrax as an Have a nice evening :)
Hi @johannahaffner and @DylanEsguerra, thanks for your messages! Yes, definitely, let's work on that. I've been able to reproduce some results with diffrax. The only thing we might lose is that the ODETerm from diffrax cannot return anything other than dy. If we generate something like the code below, which integrates the ratefunc and the assignment func from SBML models into the ODE term, it would be nice to also save the evolution of w, but I don't see an option for that in diffrax at the moment:

```python
import equinox as eqx


class ODESystem(eqx.Module):
    def __call__(self, t, y, args):
        w, c = args
        # Assignment func in SBML
        w = w.at[0].set(1.0 * ((c[1]/1.0) - (y[2]/1.0)))
        w = w.at[1].set(1.0 * ((y[3]/1.0) + (y[4]/1.0)))
        w = w.at[2].set(1.0 * ((c[2]/1.0) - (y[0]/1.0)))
        w = w.at[3].set(1.0 * ((c[0]/1.0) - (y[1]/1.0)))
        dy = ratefunc(y, t, w, c)  # ratefunc: the model's generated rate function
        return dy
```
Also, does either of you know whether, in SBML, parameters updated by an assignment rule (w here) can depend on previous w values?
Hello @mayalenE, I am pretty new to SBML, so I'm not aware whether w ever depends on previous values, but I have forked your repository and integrated 8 diffrax solvers into the ModelStep function. It seems to be working well and can handle returning ws, as shown by a test on BioModel 240. The solvers I added are the ERK methods, since I have had trouble stepping through solves with the IRK methods.
I'm going to try to add the IRK methods before making a PR, but my current version is in a feature branch, and I would greatly appreciate both of you checking it out! Here is the part I added
Hi Dylan, nice, thanks! My question was whether we could also generate an ODETerm that calls both the ratefunc and the assignmentfunc, so users can directly perform a diffeqsolve on it with any combination of solver/stepsize_controller/adjoint, as proposed by Johanna. Regarding your PR, I've made a branch for this issue (https://github.com/flowersteam/sbmltoodejax/tree/support-for-ode-simulation-with-diffrax), so it would be great if you could open the PR against it. After a quick look at your
Hi Dylan, yes, looks good to me! If that veers too far away from what you are building and using, I completely understand!
I'm not sure what you have in mind here. As explained here, I separated the rate function from the assignment function because one handles the evolution of the variables y by solving an ODE, whereas the other directly sets the values of the parameters w through an equation. I don't see a way to include w in the y vector, but maybe I'm not understanding what you are suggesting.
My bad, and maybe I do not understand well enough what you mean by
In which models does this
I believe

```python
ws = jax.vmap(lambda t, y: assignment_func(y, w0, c, t))(ts, ys)
```

Shown in the context of a BioModels example:

```python
import jax

jax.config.update("jax_platform_name", "cpu")

import diffrax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from sbmltoodejax.utils import load_biomodel

# Load and simulate model (this generates jax_model.py, imported below)
load_biomodel(167)
from jax_model import RateofSpeciesChange, AssignmentRule, y0, w0, t0, c


def ode_func(t, y, args):
    w, c = args
    rate_of_change = RateofSpeciesChange()
    assignment_rule = AssignmentRule()
    # Update w using the assignment rule
    w = assignment_rule(y, w, c, t)
    # Calculate the rate of change
    dy_dt = rate_of_change(y, t, w, c)
    return dy_dt


t1 = 5000
deltaT = 0.1
n_steps = int(t1 / deltaT)

# Create diffrax solver
term = diffrax.ODETerm(ode_func)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, n_steps))

# Solve the ODE system (diffeqsolve returns a single Solution object)
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=deltaT,
    y0=y0,
    args=(w0, c),
    saveat=saveat,
    max_steps=500000,
)

# Extract results, recovering w from the saved states
ts = sol.ts
ys = sol.ys
ws = jax.vmap(lambda t, y: AssignmentRule()(y, w0, c, t))(ts, ys)

# Plot time course simulation as in the original paper
fig, ax = plt.subplots(3, 1, figsize=(6, 10))
ax[0].plot(ts, ws[:, 0]/14.625, color="red", label="statKinase_sol")
ax[0].legend()
ax[1].plot(ts, ys[:, 1]/14.625, color="red", label="PstatDimer_sol")
ax[1].legend()
ax[2].plot(ts, ys[:, 4]/14.625, color="red", label="stat_sol")
ax[2].legend()
for i in range(3):
    ax[i].set_xlim(left=0)
    if i < 2:
        ax[i].set_ylim(bottom=0)
    else:
        ax[i].set_ylim(top=1)
plt.show()
```
Alright, I see. So it is option 4., and the evolving state of Given that we can also import the functions computing the rate of species change and the assignment rules from the imported jax model, I will probably just use that for now. If you're doing something similar, I suggest defining the functions outside of `ode_func` instead of instantiating the classes on each call, which would create a new object for every step. I expect that this would decrease compilation time significantly in larger projects.

```python
from jax_model import RateofSpeciesChange, AssignmentRule

# Instantiate the generated modules once, outside the vector field
rate_of_change = RateofSpeciesChange()
assignment_rule = AssignmentRule()


def ode_func(t, y, args):
    w0, c = args
    w = assignment_rule(y, w0, c, t)
    dy = rate_of_change(y, t, w, c)
    return dy
```

Thank you for taking the time to explain this!
On another note: I see that some of the parameters are hard-coded into the
And one more: make this an SBMLTOODEJAX example! It was super helpful :)
Hi, just found your library! Very exciting.
Were you planning to add support for ODE simulation with diffrax?
I use diffrax to simulate my models and optimistix to fit them, and more generally work in a hierarchical modeling context where I compose a larger model out of many smaller equinox modules.
I see a lot of similarities here - and I think it would make sense to serialize to an ODE system that can be used as a `diffrax.ODETerm`, which requires the signature to be `t, y, args`. The step function used in your example would then not be needed.
I'd also like to take advantage of `diffrax.LinearInterpolation`, which I use to define my experimental inputs. I believe that input/condition handling is currently out of scope for vanilla SBML and only supported in PETab, right?
Let me know if this is on your To-Do list already - or if you'd be happy to take a PR on this.