Skip to content

Commit

Permalink
Merge IVP-solution modules (#679)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer authored Dec 10, 2023
1 parent 41c7d0c commit 422099e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 96 deletions.
88 changes: 0 additions & 88 deletions probdiffeq/_ivpsolve_impl.py

This file was deleted.

87 changes: 79 additions & 8 deletions probdiffeq/ivpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import jax
import jax.numpy as jnp

from probdiffeq import _ivpsolve_impl
from probdiffeq.backend import tree_array_util
from probdiffeq.backend import control_flow, tree_array_util
from probdiffeq.impl import impl
from probdiffeq.solvers import markov

Expand Down Expand Up @@ -104,7 +103,7 @@ def simulate_terminal_values(
) -> Solution:
"""Simulate the terminal values of an initial value problem."""
save_at = jnp.asarray([t1])
(_t, solution_save_at), _, num_steps = _ivpsolve_impl.solve_and_save_at(
(_t, solution_save_at), _, num_steps = _solve_and_save_at(
jax.tree_util.Partial(vector_field),
t0,
initial_condition,
Expand Down Expand Up @@ -150,7 +149,7 @@ def solve_and_save_at(
)
warnings.warn(msg, stacklevel=1)

(_t, solution_save_at), _, num_steps = _ivpsolve_impl.solve_and_save_at(
(_t, solution_save_at), _, num_steps = _solve_and_save_at(
jax.tree_util.Partial(vector_field),
save_at[0],
initial_condition,
Expand All @@ -176,6 +175,46 @@ def solve_and_save_at(
)


def _solve_and_save_at(
vector_field, t, initial_condition, *, save_at, adaptive_solver, dt0
):
advance_func = functools.partial(
_advance_and_interpolate,
vector_field=vector_field,
adaptive_solver=adaptive_solver,
)

state = adaptive_solver.init(t, initial_condition, dt0=dt0, num_steps=0.0)
_, solution = jax.lax.scan(f=advance_func, init=state, xs=save_at, reverse=False)
return solution


def _advance_and_interpolate(state, t_next, *, vector_field, adaptive_solver):
# Advance until accepted.t >= t_next.
# Note: This could already be the case and we may not loop (just interpolate)
def cond_fun(s):
# Terminate the loop if
# the difference from s.t to t_next is smaller than a constant factor
# (which is a "small" multiple of the current machine precision)
# or if s.t > t_next holds.
return s.t + 10 * jnp.finfo(float).eps < t_next

def body_fun(s):
return adaptive_solver.rejection_loop(s, vector_field=vector_field, t1=t_next)

state = control_flow.while_loop(cond_fun, body_fun, init=state)

# Either interpolate (t > t_next) or "finalise" (t == t_next)
state, solution = jax.lax.cond(
state.t > t_next + 10 * jnp.finfo(float).eps,
adaptive_solver.interpolate_and_extract,
lambda s, _t: adaptive_solver.right_corner_and_extract(s),
state,
t_next,
)
return state, solution


def solve_and_save_every_step(
vector_field, initial_condition, t0, t1, adaptive_solver, dt0
) -> Solution:
Expand All @@ -193,14 +232,18 @@ def solve_and_save_every_step(
)
warnings.warn(msg, stacklevel=1)

(t, solution_every_step), _dt, num_steps = _ivpsolve_impl.solve_and_save_every_step(
generator = _solution_generator(
jax.tree_util.Partial(vector_field),
t0,
initial_condition,
t1=t1,
adaptive_solver=adaptive_solver,
dt0=dt0,
)
(t, solution_every_step), _dt, num_steps = tree_array_util.tree_stack(
list(generator)
)

# I think the user expects the initial time-point to be part of the grid
# (Even though t0 is not computed by this function)
t = jnp.concatenate((jnp.atleast_1d(t0), t))
Expand All @@ -222,12 +265,40 @@ def solve_and_save_every_step(
)


def _solution_generator(
vector_field, t, initial_condition, *, dt0, t1, adaptive_solver
):
"""Generate a probabilistic IVP solution iteratively."""
state = adaptive_solver.init(t, initial_condition, dt0=dt0, num_steps=0)

while state.t < t1:
state = adaptive_solver.rejection_loop(state, vector_field=vector_field, t1=t1)

if state.t < t1:
solution = adaptive_solver.extract(state)
yield solution

# Either interpolate (t > t_next) or "finalise" (t == t_next)
if state.t > t1:
_, solution = adaptive_solver.interpolate_and_extract(state, t=t1)
else:
_, solution = adaptive_solver.right_corner_and_extract(state)

yield solution


def solve_fixed_grid(vector_field, initial_condition, grid, solver) -> Solution:
"""Solve an initial value problem on a fixed, pre-determined grid."""
# Compute the solution
_t, (posterior, output_scale) = _ivpsolve_impl.solve_fixed_grid(
jax.tree_util.Partial(vector_field), initial_condition, grid=grid, solver=solver
)

def body_fn(s, dt):
_error, s_new = solver.step(state=s, vector_field=vector_field, dt=dt)
return s_new, s_new

t0 = grid[0]
state0 = solver.init(t0, initial_condition)
_, result_state = jax.lax.scan(f=body_fn, init=state0, xs=jnp.diff(grid))
_t, (posterior, output_scale) = solver.extract(result_state)

# I think the user expects marginals, so we compute them here
posterior_t0, *_ = initial_condition
Expand Down

0 comments on commit 422099e

Please sign in to comment.