From 1c7522164b27742dbbc0bd087562f49550c881e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Mon, 28 Oct 2024 09:04:53 +0100 Subject: [PATCH] Remove iterability from the Solution object because it is mathematically ill-defined (#800) --- .../physics_enhanced_regression_2.py | 36 ++++---- probdiffeq/ivpsolve.py | 62 +++---------- tests/test_ivpsolve/test_solution_object.py | 88 +------------------ 3 files changed, 31 insertions(+), 155 deletions(-) diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py index 213981d1..d781eec0 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py @@ -172,12 +172,12 @@ def vf(y, *, t): # noqa: ARG001 # + -def plot_solution(sol, *, ax, marker=".", **plotting_kwargs): +def plot_solution(t, u, *, ax, marker=".", **plotting_kwargs): """Plot the IVP solution.""" for d in [0, 1]: - ax.plot(sol.t, sol.u[0][:, d], marker="None", **plotting_kwargs) - ax.plot(sol.t[0], sol.u[0][0, d], marker=marker, **plotting_kwargs) - ax.plot(sol.t[-1], sol.u[0][-1, d], marker=marker, **plotting_kwargs) + ax.plot(t, u[:, d], marker="None", **plotting_kwargs) + ax.plot(t[0], u[0, d], marker=marker, **plotting_kwargs) + ax.plot(t[-1], u[-1, d], marker=marker, **plotting_kwargs) return ax @@ -194,8 +194,7 @@ def solve_fixed(theta, *, ts): strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() - sol = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm) - return sol[-1] + return ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm) @jax.jit @@ -229,12 +228,12 @@ def solve_adaptive(theta, *, save_at): data_kwargs = {"alpha": 0.5, "color": "gray"} ax.annotate("Data", (13.0, 30.0), **data_kwargs) sol = solve_save_at(theta_true) -ax = plot_solution(sol, ax=ax, **data_kwargs) +ax = plot_solution(sol.t, sol.u[0], ax=ax, **data_kwargs) guess_kwargs = {"color": "C3"} ax.annotate("Initial guess", (7.5, 20.0), **guess_kwargs) sol = solve_save_at(theta_guess) -ax = plot_solution(sol, ax=ax, **guess_kwargs) +ax = plot_solution(sol.t, sol.u[0], ax=ax, **guess_kwargs) plt.show() # - @@ -251,9 +250,10 @@ def solve_adaptive(theta, *, save_at): @jax.jit def logposterior_fn(theta, *, data, ts, obs_stdev=0.1): """Evaluate the logposterior-function of the data.""" - y_T = solve_fixed(theta, ts=ts) + solution = solve_fixed(theta, ts=ts) + y_T = jax.tree.map(lambda s: s[-1], solution.posterior) logpdf_data = stats.log_marginal_likelihood_terminal_values( - data, standard_deviation=obs_stdev, posterior=y_T.posterior, ssm=y_T.ssm + data, standard_deviation=obs_stdev, posterior=y_T, ssm=solution.ssm ) logpdf_prior = jax.scipy.stats.multivariate_normal.logpdf(theta, mean=mean, cov=cov) return logpdf_data + logpdf_prior @@ -263,7 +263,7 @@ def logposterior_fn(theta, *, data, ts, obs_stdev=0.1): ts = jnp.linspace(t0, t1, endpoint=True, num=100) -data = solve_fixed(theta_true, ts=ts).u[0] +data = solve_fixed(theta_true, ts=ts).u[0][-1] log_M = functools.partial(logposterior_fn, data=data, ts=ts) # - @@ -330,18 +330,20 @@ def one_step(state, rng_key): sample_kwargs = {"color": "C0"} ax.annotate("Samples", (2.75, 31.0), **sample_kwargs) -for sol in solution_samples: - ax = plot_solution(sol, ax=ax, linewidth=0.1, alpha=0.75, **sample_kwargs) +for ts, us in zip(solution_samples.t, solution_samples.u[0]): + ax = plot_solution(ts, us, ax=ax, linewidth=0.1, alpha=0.75, **sample_kwargs) data_kwargs = {"color": "gray"} ax.annotate("Data", (18.25, 40.0), **data_kwargs) sol = solve_save_at(theta_true) -ax = plot_solution(sol, ax=ax, linewidth=4, alpha=0.5, **data_kwargs) +ax = plot_solution(sol.t, sol.u[0], ax=ax, linewidth=4, alpha=0.5, **data_kwargs) guess_kwargs = {"color": "gray"} ax.annotate("Initial guess", (6.0, 12.0), **guess_kwargs) sol = solve_save_at(theta_guess) -ax = plot_solution(sol, ax=ax, linestyle="dashed", alpha=0.75, **guess_kwargs) +ax = plot_solution( + sol.t, sol.u[0], ax=ax, linestyle="dashed", alpha=0.75, **guess_kwargs +) plt.show() # - @@ -361,8 +363,8 @@ def one_step(state, rng_key): # the sampler covers the entire region of interest. # + -xlim = 17, jnp.amax(states.position[:, 0]) + 0.5 -ylim = 17, jnp.amax(states.position[:, 1]) + 0.5 +xlim = 14, jnp.amax(states.position[:, 0]) + 0.5 +ylim = 14, jnp.amax(states.position[:, 1]) + 0.5 xs = jnp.linspace(*xlim, endpoint=True, num=300) ys = jnp.linspace(*ylim, endpoint=True, num=300) diff --git a/probdiffeq/ivpsolve.py b/probdiffeq/ivpsolve.py index d860f0f0..b080ef1c 100644 --- a/probdiffeq/ivpsolve.py +++ b/probdiffeq/ivpsolve.py @@ -2,6 +2,7 @@ from probdiffeq import stats from probdiffeq.backend import ( + containers, control_flow, functools, linalg, @@ -10,62 +11,21 @@ warnings, ) from probdiffeq.backend import numpy as np +from probdiffeq.backend.typing import Any, Array +@containers.dataclass class _Solution: """Estimated initial value problem solution.""" - def __init__(self, t, u, u_std, output_scale, marginals, posterior, num_steps, ssm): - """Construct a solution object.""" - self.t = t - self.u = u - self.u_std = u_std - self.output_scale = output_scale - self.marginals = marginals # todo: marginals are replaced by "u" and "u_std" - self.posterior = posterior - self.num_steps = num_steps - self.ssm = ssm - - def __repr__(self): - """Evaluate a string-representation of the solution object.""" - return ( - f"{self.__class__.__name__}(" - f"t={self.t}," - f"u={self.u}," - f"output_scale={self.output_scale}," - f"marginals={self.marginals}," - f"posterior={self.posterior}," - f"num_steps={self.num_steps}," - ")" - ) - - def __len__(self): - """Evaluate the length of a solution.""" - if np.ndim(self.t) < 1: - msg = "Solution object not batched :(" - raise ValueError(msg) - return self.t.shape[0] - - def __getitem__(self, item): - """Access a single item of the solution.""" - if np.ndim(self.t) < 1: - msg = "Solution object not batched :(" - raise ValueError(msg) - - if np.ndim(self.t) == 1 and item != -1: - msg = "Access to non-terminal states is not available." - raise ValueError(msg) - - return tree_util.tree_map(lambda s: s[item, ...], self) - - def __iter__(self): - """Iterate through the solution.""" - if np.ndim(self.t) <= 1: - msg = "Solution object not batched :(" - raise ValueError(msg) - - for i in range(self.t.shape[0]): - yield self[i] + t: Array + u: Array + u_std: Array + output_scale: Array + marginals: Any + posterior: Any + num_steps: Array + ssm: Any @staticmethod def register_pytree_node(): diff --git a/tests/test_ivpsolve/test_solution_object.py b/tests/test_ivpsolve/test_solution_object.py index b1b1aa1b..2ff3c1a5 100644 --- a/tests/test_ivpsolve/test_solution_object.py +++ b/tests/test_ivpsolve/test_solution_object.py @@ -1,8 +1,7 @@ """Tests for interaction with the solution object.""" from probdiffeq import ivpsolve, ivpsolvers, taylor -from probdiffeq.backend import containers, functools, ode, testing -from probdiffeq.backend import numpy as np +from probdiffeq.backend import containers, ode, testing from probdiffeq.backend.typing import Array @@ -37,91 +36,6 @@ def fixture_approximate_solution(fact): def test_u_inherits_data_structure(approximate_solution): assert isinstance(approximate_solution.u, Taylor) - solution_t1 = approximate_solution[-1] - assert isinstance(solution_t1.u, Taylor) - def test_u_std_inherits_data_structure(approximate_solution): assert isinstance(approximate_solution.u_std, Taylor) - - solution_t1 = approximate_solution[-1] - assert isinstance(solution_t1.u_std, Taylor) - - -def test_getitem_possible_for_terminal_values(approximate_solution): - solution_t1 = approximate_solution[-1] - assert isinstance(solution_t1, type(approximate_solution)) - - -@testing.parametrize("item", [-2, 0, slice(1, -1, 1)]) -def test_getitem_impossible_for_nonterminal_values(approximate_solution, item): - with testing.raises(ValueError, match="non-terminal"): - _ = approximate_solution[item] - - -@testing.parametrize("item", [-1, -2, 0, slice(1, -1, 1)]) -def test_getitem_impossible_at_single_time_for_any_item(approximate_solution, item): - # Allowed slicing: - # solution_t1 is not batched now, so further slicing should be impossible - solution_t1 = approximate_solution[-1] - - with testing.raises(ValueError, match="not batched"): - _ = solution_t1[item] - - -def test_iter_impossible(approximate_solution): - with testing.raises(ValueError, match="not batched"): - for _ in approximate_solution: - pass - - -@testing.fixture(name="approximate_solution_batched") -@testing.parametrize("fact", ["dense", "blockdiag", "isotropic"]) -def fixture_approximate_solution_batched(fact): - vf, (u0,), (t0, t1) = ode.ivp_lotka_volterra() - - # Generate a solver - save_at = np.linspace(t0, t1, endpoint=True, num=4) - - def solve(init): - tcoeffs = (init, vf(init, t=None)) - ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) - - ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) - adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) - - initcond = solver.initial_condition() - return ivpsolve.solve_adaptive_save_at( - vf, - initcond, - save_at=save_at, - adaptive_solver=adaptive_solver, - dt0=0.1, - ssm=ssm, - ) - - u0_batched = u0[None, ...] - solve = functools.vmap(solve) - return solve(u0_batched) - - -def test_batched_getitem_possible(approximate_solution_batched): - solution_type = type(approximate_solution_batched) - for idx in (0,): - approximate_solution = approximate_solution_batched[idx] - assert isinstance(approximate_solution, solution_type) - assert np.allclose(approximate_solution.t, approximate_solution_batched.t[idx]) - - for u1, u2 in zip(approximate_solution.u, approximate_solution_batched.u[idx]): - assert np.allclose(u1, u2) - - -def test_batched_iter_possible(approximate_solution_batched): - solution_type = type(approximate_solution_batched) - for idx, approximate_solution in enumerate(approximate_solution_batched): - assert isinstance(approximate_solution, solution_type) - assert np.allclose(approximate_solution.t, approximate_solution_batched.t[idx]) - for u1, u2 in zip(approximate_solution.u, approximate_solution_batched.u[idx]): - assert np.allclose(u1, u2)