Skip to content

Commit

Permalink
Remove iterability from the Solution object because it is mathematica…
Browse files Browse the repository at this point in the history
…lly ill-defined (#800)
  • Loading branch information
pnkraemer authored Oct 28, 2024
1 parent c5260d3 commit 1c75221
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 155 deletions.
36 changes: 19 additions & 17 deletions docs/examples_parameter_estimation/physics_enhanced_regression_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,12 @@ def vf(y, *, t): # noqa: ARG001


# +
def plot_solution(sol, *, ax, marker=".", **plotting_kwargs):
def plot_solution(t, u, *, ax, marker=".", **plotting_kwargs):
"""Plot the IVP solution."""
for d in [0, 1]:
ax.plot(sol.t, sol.u[0][:, d], marker="None", **plotting_kwargs)
ax.plot(sol.t[0], sol.u[0][0, d], marker=marker, **plotting_kwargs)
ax.plot(sol.t[-1], sol.u[0][-1, d], marker=marker, **plotting_kwargs)
ax.plot(t, u[:, d], marker="None", **plotting_kwargs)
ax.plot(t[0], u[0, d], marker=marker, **plotting_kwargs)
ax.plot(t[-1], u[-1, d], marker=marker, **plotting_kwargs)
return ax


Expand All @@ -194,8 +194,7 @@ def solve_fixed(theta, *, ts):
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver.initial_condition()
sol = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm)
return sol[-1]
return ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm)


@jax.jit
Expand Down Expand Up @@ -229,12 +228,12 @@ def solve_adaptive(theta, *, save_at):
data_kwargs = {"alpha": 0.5, "color": "gray"}
ax.annotate("Data", (13.0, 30.0), **data_kwargs)
sol = solve_save_at(theta_true)
ax = plot_solution(sol, ax=ax, **data_kwargs)
ax = plot_solution(sol.t, sol.u[0], ax=ax, **data_kwargs)

guess_kwargs = {"color": "C3"}
ax.annotate("Initial guess", (7.5, 20.0), **guess_kwargs)
sol = solve_save_at(theta_guess)
ax = plot_solution(sol, ax=ax, **guess_kwargs)
ax = plot_solution(sol.t, sol.u[0], ax=ax, **guess_kwargs)
plt.show()
# -

Expand All @@ -251,9 +250,10 @@ def solve_adaptive(theta, *, save_at):
@jax.jit
def logposterior_fn(theta, *, data, ts, obs_stdev=0.1):
"""Evaluate the logposterior-function of the data."""
y_T = solve_fixed(theta, ts=ts)
solution = solve_fixed(theta, ts=ts)
y_T = jax.tree.map(lambda s: s[-1], solution.posterior)
logpdf_data = stats.log_marginal_likelihood_terminal_values(
data, standard_deviation=obs_stdev, posterior=y_T.posterior, ssm=y_T.ssm
data, standard_deviation=obs_stdev, posterior=y_T, ssm=solution.ssm
)
logpdf_prior = jax.scipy.stats.multivariate_normal.logpdf(theta, mean=mean, cov=cov)
return logpdf_data + logpdf_prior
Expand All @@ -263,7 +263,7 @@ def logposterior_fn(theta, *, data, ts, obs_stdev=0.1):


ts = jnp.linspace(t0, t1, endpoint=True, num=100)
data = solve_fixed(theta_true, ts=ts).u[0]
data = solve_fixed(theta_true, ts=ts).u[0][-1]

log_M = functools.partial(logposterior_fn, data=data, ts=ts)
# -
Expand Down Expand Up @@ -330,18 +330,20 @@ def one_step(state, rng_key):

sample_kwargs = {"color": "C0"}
ax.annotate("Samples", (2.75, 31.0), **sample_kwargs)
for sol in solution_samples:
ax = plot_solution(sol, ax=ax, linewidth=0.1, alpha=0.75, **sample_kwargs)
for ts, us in zip(solution_samples.t, solution_samples.u[0]):
ax = plot_solution(ts, us, ax=ax, linewidth=0.1, alpha=0.75, **sample_kwargs)

data_kwargs = {"color": "gray"}
ax.annotate("Data", (18.25, 40.0), **data_kwargs)
sol = solve_save_at(theta_true)
ax = plot_solution(sol, ax=ax, linewidth=4, alpha=0.5, **data_kwargs)
ax = plot_solution(sol.t, sol.u[0], ax=ax, linewidth=4, alpha=0.5, **data_kwargs)

guess_kwargs = {"color": "gray"}
ax.annotate("Initial guess", (6.0, 12.0), **guess_kwargs)
sol = solve_save_at(theta_guess)
ax = plot_solution(sol, ax=ax, linestyle="dashed", alpha=0.75, **guess_kwargs)
ax = plot_solution(
sol.t, sol.u[0], ax=ax, linestyle="dashed", alpha=0.75, **guess_kwargs
)
plt.show()
# -

Expand All @@ -361,8 +363,8 @@ def one_step(state, rng_key):
# the sampler covers the entire region of interest.

# +
xlim = 17, jnp.amax(states.position[:, 0]) + 0.5
ylim = 17, jnp.amax(states.position[:, 1]) + 0.5
xlim = 14, jnp.amax(states.position[:, 0]) + 0.5
ylim = 14, jnp.amax(states.position[:, 1]) + 0.5

xs = jnp.linspace(*xlim, endpoint=True, num=300)
ys = jnp.linspace(*ylim, endpoint=True, num=300)
Expand Down
62 changes: 11 additions & 51 deletions probdiffeq/ivpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from probdiffeq import stats
from probdiffeq.backend import (
containers,
control_flow,
functools,
linalg,
Expand All @@ -10,62 +11,21 @@
warnings,
)
from probdiffeq.backend import numpy as np
from probdiffeq.backend.typing import Any, Array


@containers.dataclass
class _Solution:
"""Estimated initial value problem solution."""

def __init__(self, t, u, u_std, output_scale, marginals, posterior, num_steps, ssm):
"""Construct a solution object."""
self.t = t
self.u = u
self.u_std = u_std
self.output_scale = output_scale
self.marginals = marginals # todo: marginals are replaced by "u" and "u_std"
self.posterior = posterior
self.num_steps = num_steps
self.ssm = ssm

def __repr__(self):
"""Evaluate a string-representation of the solution object."""
return (
f"{self.__class__.__name__}("
f"t={self.t},"
f"u={self.u},"
f"output_scale={self.output_scale},"
f"marginals={self.marginals},"
f"posterior={self.posterior},"
f"num_steps={self.num_steps},"
")"
)

def __len__(self):
"""Evaluate the length of a solution."""
if np.ndim(self.t) < 1:
msg = "Solution object not batched :("
raise ValueError(msg)
return self.t.shape[0]

def __getitem__(self, item):
"""Access a single item of the solution."""
if np.ndim(self.t) < 1:
msg = "Solution object not batched :("
raise ValueError(msg)

if np.ndim(self.t) == 1 and item != -1:
msg = "Access to non-terminal states is not available."
raise ValueError(msg)

return tree_util.tree_map(lambda s: s[item, ...], self)

def __iter__(self):
"""Iterate through the solution."""
if np.ndim(self.t) <= 1:
msg = "Solution object not batched :("
raise ValueError(msg)

for i in range(self.t.shape[0]):
yield self[i]
t: Array
u: Array
u_std: Array
output_scale: Array
marginals: Any
posterior: Any
num_steps: Array
ssm: Any

@staticmethod
def register_pytree_node():
Expand Down
88 changes: 1 addition & 87 deletions tests/test_ivpsolve/test_solution_object.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Tests for interaction with the solution object."""

from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.backend import containers, functools, ode, testing
from probdiffeq.backend import numpy as np
from probdiffeq.backend import containers, ode, testing
from probdiffeq.backend.typing import Array


Expand Down Expand Up @@ -37,91 +36,6 @@ def fixture_approximate_solution(fact):
def test_u_inherits_data_structure(approximate_solution):
assert isinstance(approximate_solution.u, Taylor)

solution_t1 = approximate_solution[-1]
assert isinstance(solution_t1.u, Taylor)


def test_u_std_inherits_data_structure(approximate_solution):
assert isinstance(approximate_solution.u_std, Taylor)

solution_t1 = approximate_solution[-1]
assert isinstance(solution_t1.u_std, Taylor)


def test_getitem_possible_for_terminal_values(approximate_solution):
solution_t1 = approximate_solution[-1]
assert isinstance(solution_t1, type(approximate_solution))


@testing.parametrize("item", [-2, 0, slice(1, -1, 1)])
def test_getitem_impossible_for_nonterminal_values(approximate_solution, item):
with testing.raises(ValueError, match="non-terminal"):
_ = approximate_solution[item]


@testing.parametrize("item", [-1, -2, 0, slice(1, -1, 1)])
def test_getitem_impossible_at_single_time_for_any_item(approximate_solution, item):
# Allowed slicing:
# solution_t1 is not batched now, so further slicing should be impossible
solution_t1 = approximate_solution[-1]

with testing.raises(ValueError, match="not batched"):
_ = solution_t1[item]


def test_iter_impossible(approximate_solution):
with testing.raises(ValueError, match="not batched"):
for _ in approximate_solution:
pass


@testing.fixture(name="approximate_solution_batched")
@testing.parametrize("fact", ["dense", "blockdiag", "isotropic"])
def fixture_approximate_solution_batched(fact):
vf, (u0,), (t0, t1) = ode.ivp_lotka_volterra()

# Generate a solver
save_at = np.linspace(t0, t1, endpoint=True, num=4)

def solve(init):
tcoeffs = (init, vf(init, t=None))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact)

ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)

initcond = solver.initial_condition()
return ivpsolve.solve_adaptive_save_at(
vf,
initcond,
save_at=save_at,
adaptive_solver=adaptive_solver,
dt0=0.1,
ssm=ssm,
)

u0_batched = u0[None, ...]
solve = functools.vmap(solve)
return solve(u0_batched)


def test_batched_getitem_possible(approximate_solution_batched):
solution_type = type(approximate_solution_batched)
for idx in (0,):
approximate_solution = approximate_solution_batched[idx]
assert isinstance(approximate_solution, solution_type)
assert np.allclose(approximate_solution.t, approximate_solution_batched.t[idx])

for u1, u2 in zip(approximate_solution.u, approximate_solution_batched.u[idx]):
assert np.allclose(u1, u2)


def test_batched_iter_possible(approximate_solution_batched):
solution_type = type(approximate_solution_batched)
for idx, approximate_solution in enumerate(approximate_solution_batched):
assert isinstance(approximate_solution, solution_type)
assert np.allclose(approximate_solution.t, approximate_solution_batched.t[idx])
for u1, u2 in zip(approximate_solution.u, approximate_solution_batched.u[idx]):
assert np.allclose(u1, u2)

0 comments on commit 1c75221

Please sign in to comment.