Skip to content

Commit

Permalink
Make the InterpRes into a dataclass to not rely on the ordering of at…
Browse files Browse the repository at this point in the history
…tributes, and rename the fields to improve clarity (#774)
  • Loading branch information
pnkraemer authored Aug 12, 2024
1 parent 730c3cd commit fad9ff6
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 60 deletions.
15 changes: 8 additions & 7 deletions probdiffeq/ivpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,22 +326,23 @@ def extract_at_t1(self, state):
interp = self.solver.interpolate_at_t1(
interp_from=state.interp_from, interp_to=state.step_from
)
accepted, solution, previous = interp
state = _AdaptiveState(accepted, previous, state.control, state.stats)
state = _AdaptiveState(
interp.step_from, interp.interp_from, state.control, state.stats
)

solution_solver = self.solver.extract(solution)
solution_solver = self.solver.extract(interp.interpolated)
solution_control = self.control.extract(state.control)
return state, (solution_solver, solution_control, state.stats)

def extract_after_t1_via_interpolation(self, state, t):
interp = self.solver.interpolate(
t, interp_from=state.interp_from, interp_to=state.step_from
)
state = _AdaptiveState(
interp.step_from, interp.interp_from, state.control, state.stats
)

accepted, solution, previous = interp
state = _AdaptiveState(accepted, previous, state.control, state.stats)

solution_solver = self.solver.extract(solution)
solution_solver = self.solver.extract(interp.interpolated)
solution_control = self.control.extract(state.control)
return state, (solution_solver, solution_control, state.stats)

Expand Down
112 changes: 59 additions & 53 deletions probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,41 @@
from probdiffeq.backend.typing import Any, Array, Generic, TypeVar
from probdiffeq.impl import impl

T = TypeVar("T")
R = TypeVar("R")
S = TypeVar("S")


class _InterpRes(containers.NamedTuple):
# todo: rename to: solution, step_from, interpolate_from?
# in general, this object should not be necessary...
# instead, make all interpolation return a
# (solution, {"step_from": ..., "interp_from": ...}) tuple
accepted: Any
"""The new 'accepted' field.
@containers.dataclass
class _InterpRes(Generic[R]):
step_from: R
"""The new 'step_from' field.
At time `max(t, s1.t)`. Use this as the right-most reference state
At time `max(t, s1.t)`.
Use this as the right-most reference state
in future interpolations, or continue time-stepping from here.
"""

solution: Any
interpolated: R
"""The new 'solution' field.
At time `t`. This is the interpolation result.
"""

previous: Any
"""The new `previous_solution` field.
interp_from: R
"""The new `interp_from` field.
At time `t`. Use this as the right-most reference state
in future interpolations, or continue time-stepping from here.
The difference between `solution` and `previous` emerges in save_at* modes.
One belongs to the just-concluded time interval, and the other belongs to
the to-be-started time interval.
Concretely, this means that one has a unit backward model and the other
remembers how to step back to the previous state.
The difference between `interpolated` and `interp_from` emerges in save_at* modes.
`interpolated` belongs to the just-concluded time interval,
and `interp_from` belongs to the to-be-started time interval.
Concretely, this means that `interp_from` has a unit backward model
and `interpolated` remembers how to step back to the previous target location.
"""


T = TypeVar("T")
R = TypeVar("R")
S = TypeVar("S")


class _ExtrapolationImpl(abc.ABC, Generic[T, R, S]):
"""Extrapolation model interface."""

Expand Down Expand Up @@ -153,7 +150,11 @@ def extract(self, state: _StrategyState, /):
def case_interpolate_at_t1(self, state_t1: _StrategyState) -> _InterpRes:
"""Process the solution in case t=t_n."""
_tmp = self.extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra)
step_from, solution, interp_from = _tmp
step_from, solution, interp_from = (
_tmp.step_from,
_tmp.interpolated,
_tmp.interp_from,
)

def _state(x):
t = state_t1.t
Expand All @@ -170,7 +171,7 @@ def case_interpolate(
) -> _InterpRes:
"""Process the solution in case t>t_n."""
# Interpolate
step_from, solution, interp_from = self.extrapolation.interpolate(
interp = self.extrapolation.interpolate(
state_t0=(s0.hidden, s0.aux_extra),
marginal_t1=s1.hidden,
dt0=t - s0.t,
Expand All @@ -184,10 +185,12 @@ def _state(t_, x):
corr_like = tree_util.tree_map(np.empty_like, s0.aux_corr)
return _StrategyState(t=t_, hidden=x[0], aux_extra=x[1], aux_corr=corr_like)

step_from = _state(s1.t, step_from)
solution = _state(t, solution)
interp_from = _state(t, interp_from)
return _InterpRes(step_from, solution, interp_from)
step_from = _state(s1.t, interp.step_from)
interpolated = _state(t, interp.interpolated)
interp_from = _state(t, interp.interp_from)
return _InterpRes(
step_from=step_from, interpolated=interpolated, interp_from=interp_from
)

def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale):
"""Compute offgrid_marginals."""
Expand All @@ -198,14 +201,15 @@ def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_sca
dt1 = t1 - t
state_t0 = self.init(t0, posterior_t0)

_acc, (marginals, _aux), _prev = self.extrapolation.interpolate(
interp = self.extrapolation.interpolate(
state_t0=(state_t0.hidden, state_t0.aux_extra),
marginal_t1=marginals_t1,
dt0=dt0,
dt1=dt1,
output_scale=output_scale,
)

(marginals, _aux) = interp.interpolated
u = impl.stats.qoi(marginals)
return u, marginals

Expand Down Expand Up @@ -322,7 +326,9 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale):
solution_at_t1 = (marginal_t1, conditional_t1_to_t)

return _InterpRes(
accepted=solution_at_t1, solution=solution_at_t, previous=solution_at_t
step_from=solution_at_t1,
interpolated=solution_at_t,
interp_from=solution_at_t,
)

def _extrapolate(self, state, extra, /, dt, output_scale):
Expand Down Expand Up @@ -444,9 +450,9 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale):

# Return the right combination of marginals and conditionals.
return _InterpRes(
accepted=(marginal_t1, conditional_t1_to_t),
solution=(rv_at_t, extrapolated_t[1]),
previous=previous_new,
step_from=(marginal_t1, conditional_t1_to_t),
interpolated=(rv_at_t, extrapolated_t[1]),
interp_from=previous_new,
)

def _extrapolate(self, state, extra, /, dt, output_scale):
Expand Down Expand Up @@ -524,7 +530,7 @@ def interpolate(self, state_t0, marginal_t1, dt0, dt1, output_scale):
# Consistent state-types in interpolation result.
interp = (hidden, extra)
step_from = (marginal_t1, None)
return _InterpRes(accepted=step_from, solution=interp, previous=interp)
return _InterpRes(step_from=step_from, interpolated=interp, interp_from=interp)

def interpolate_at_t1(self, rv, extra, /):
return _InterpRes((rv, extra), (rv, extra), (rv, extra))
Expand Down Expand Up @@ -999,21 +1005,21 @@ def interpolate(
self, t, *, interp_from: _SolverState, interp_to: _SolverState
) -> _InterpRes:
output_scale, _ = self.calibration.extract(interp_to.output_scale)
acc_p, sol_p, prev_p = self.strategy.case_interpolate(
interp = self.strategy.case_interpolate(
t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale
)
prev = _SolverState(prev_p, output_scale=interp_from.output_scale)
sol = _SolverState(sol_p, output_scale=interp_to.output_scale)
acc = _SolverState(acc_p, output_scale=interp_to.output_scale)
return _InterpRes(accepted=acc, solution=sol, previous=prev)
prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes:
acc_p, sol_p, prev_p = self.strategy.case_interpolate_at_t1(interp_to.strategy)
x = self.strategy.case_interpolate_at_t1(interp_to.strategy)

prev = _SolverState(prev_p, output_scale=interp_from.output_scale)
sol = _SolverState(sol_p, output_scale=interp_to.output_scale)
acc = _SolverState(acc_p, output_scale=interp_to.output_scale)
return _InterpRes(accepted=acc, solution=sol, previous=prev)
prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(x.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)


def _slvr_flatten(solver):
Expand Down Expand Up @@ -1069,24 +1075,24 @@ def extract(self, state: _SolverState, /):
def interpolate(
self, t, *, interp_from: _SolverState, interp_to: _SolverState
) -> _InterpRes:
acc_p, sol_p, prev_p = self.strategy.case_interpolate(
interp = self.strategy.case_interpolate(
t,
s0=interp_from.strategy,
s1=interp_to.strategy,
output_scale=interp_to.output_scale,
)
prev = _SolverState(prev_p, output_scale=interp_from.output_scale)
sol = _SolverState(sol_p, output_scale=interp_to.output_scale)
acc = _SolverState(acc_p, output_scale=interp_to.output_scale)
return _InterpRes(accepted=acc, solution=sol, previous=prev)
prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes:
acc_p, sol_p, prev_p = self.strategy.case_interpolate_at_t1(interp_to.strategy)
interp = self.strategy.case_interpolate_at_t1(interp_to.strategy)

prev = _SolverState(prev_p, output_scale=interp_from.output_scale)
sol = _SolverState(sol_p, output_scale=interp_to.output_scale)
acc = _SolverState(acc_p, output_scale=interp_to.output_scale)
return _InterpRes(accepted=acc, solution=sol, previous=prev)
prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)


def _solver_flatten(slvr):
Expand Down

0 comments on commit fad9ff6

Please sign in to comment.