Skip to content

Commit

Permalink
Internal: Simplify the probabilistic-solver implementation (#794)
Browse files Browse the repository at this point in the history
* Simplify the probabilistic-solver implementation

* Fix a test?
  • Loading branch information
pnkraemer authored Oct 25, 2024
1 parent 3b7e235 commit d17c76d
Showing 1 changed file with 75 additions and 91 deletions.
166 changes: 75 additions & 91 deletions probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,25 +787,6 @@ def extract(state, /):
return _Calibration(init=init, update=update, extract=extract)


@containers.dataclass
class _Solver:
"""IVP solver."""

name: str
requires_rescaling: bool
error_contraction_rate: int
is_suitable_for_save_at: int
is_suitable_for_save_every_step: int

initial_condition: Callable
init: Callable
step: Callable
extract: Callable
interpolate: Callable
interpolate_at_t1: Callable
offgrid_marginals: Callable


class _SolverState(containers.NamedTuple):
"""Solver state."""

Expand All @@ -817,13 +798,78 @@ def t(self):
return self.strategy.t


@containers.dataclass
class _ProbabilisticSolver:
name: str
calibration: _Calibration
step_implementation: Callable
strategy: _Strategy
requires_rescaling: bool

@property
def offgrid_marginals(self):
return self.strategy.offgrid_marginals

@property
def error_contraction_rate(self):
return self.strategy.num_derivatives + 1

@property
def is_suitable_for_save_at(self):
return self.strategy.is_suitable_for_save_at

@property
def is_suitable_for_save_every_step(self):
return self.strategy.is_suitable_for_save_every_step

def init(self, t, initial_condition) -> _SolverState:
posterior, output_scale = initial_condition
state_strategy = self.strategy.init(t, posterior)
calib_state = self.calibration.init(output_scale)
return _SolverState(strategy=state_strategy, output_scale=calib_state)

def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState:
return self.step_implementation(
state, vector_field=vector_field, dt=dt, calibration=self.calibration
)

def extract(self, state: _SolverState, /):
t, posterior = self.strategy.extract(state.strategy)
_output_scale_prior, output_scale = self.calibration.extract(state.output_scale)
return t, (posterior, output_scale)

def interpolate(
self, t, *, interp_from: _SolverState, interp_to: _SolverState
) -> _InterpRes:
output_scale, _ = self.calibration.extract(interp_to.output_scale)
interp = self.strategy.case_interpolate(
t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale
)
prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes:
x = self.strategy.case_interpolate_at_t1(interp_to.strategy)

prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(x.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def initial_condition(self):
"""Construct an initial condition."""
posterior = self.strategy.initial_condition()
return posterior, self.strategy.prior.output_scale


def solver_mle(strategy, *, ssm):
"""Create a solver that calibrates the output scale via maximum-likelihood.
Warning: needs to be combined with a call to stats.calibrate()
after solving if the MLE-calibration shall be *used*.
"""
name = f"<MLE-solver with {strategy}>"

def step_mle(state, /, *, dt, vector_field, calibration):
output_scale_prior, _calibrated = calibration.extract(state.output_scale)
Expand All @@ -843,18 +889,17 @@ def step_mle(state, /, *, dt, vector_field, calibration):
state = _SolverState(strategy=state_strategy, output_scale=output_scale)
return dt * error, state

return _solver_calibrated(
return _ProbabilisticSolver(
name="Probabilistic solver with MLE calibration",
calibration=_calibration_running_mean(ssm=ssm),
impl_step=step_mle,
step_implementation=step_mle,
strategy=strategy,
name=name,
requires_rescaling=True,
)


def solver_dynamic(strategy, *, ssm):
"""Create a solver that calibrates the output scale dynamically."""
name = f"<Dynamic solver with {strategy}>"

def step_dynamic(state, /, *, dt, vector_field, calibration):
error, observed, state_strategy = strategy.begin(
Expand All @@ -870,11 +915,11 @@ def step_dynamic(state, /, *, dt, vector_field, calibration):
state = _SolverState(strategy=state_strategy, output_scale=output_scale)
return dt * error, state

return _solver_calibrated(
return _ProbabilisticSolver(
strategy=strategy,
calibration=_calibration_most_recent(ssm=ssm),
name=name,
impl_step=step_dynamic,
name="Dynamic probabilistic solver",
step_implementation=step_dynamic,
requires_rescaling=False,
)

Expand All @@ -895,71 +940,10 @@ def step(state: _SolverState, *, vector_field, dt, calibration):
state = _SolverState(strategy=state_strategy, output_scale=state.output_scale)
return dt * error, state

name = f"<Uncalibrated solver with {strategy}>"
return _solver_calibrated(
return _ProbabilisticSolver(
strategy=strategy,
calibration=_calibration_none(),
impl_step=step,
name=name,
step_implementation=step,
name="Probabilistic solver",
requires_rescaling=False,
)


def _solver_calibrated(
*, calibration, impl_step, strategy, name, requires_rescaling
) -> _Solver:
def init(t, initial_condition) -> _SolverState:
posterior, output_scale = initial_condition
state_strategy = strategy.init(t, posterior)
calib_state = calibration.init(output_scale)
return _SolverState(strategy=state_strategy, output_scale=calib_state)

def step(state: _SolverState, *, vector_field, dt) -> _SolverState:
return impl_step(
state, vector_field=vector_field, dt=dt, calibration=calibration
)

def extract(state: _SolverState, /):
t, posterior = strategy.extract(state.strategy)
_output_scale_prior, output_scale = calibration.extract(state.output_scale)
return t, (posterior, output_scale)

def interpolate(
t, *, interp_from: _SolverState, interp_to: _SolverState
) -> _InterpRes:
output_scale, _ = calibration.extract(interp_to.output_scale)
interp = strategy.case_interpolate(
t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale
)
prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def interpolate_at_t1(*, interp_from, interp_to) -> _InterpRes:
x = strategy.case_interpolate_at_t1(interp_to.strategy)

prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(x.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def initial_condition():
"""Construct an initial condition."""
posterior = strategy.initial_condition()
return posterior, strategy.prior.output_scale

return _Solver(
error_contraction_rate=strategy.num_derivatives + 1,
is_suitable_for_save_at=strategy.is_suitable_for_save_at,
is_suitable_for_save_every_step=strategy.is_suitable_for_save_every_step,
name=name,
requires_rescaling=requires_rescaling,
initial_condition=initial_condition,
init=init,
step=step,
extract=extract,
interpolate=interpolate,
interpolate_at_t1=interpolate_at_t1,
offgrid_marginals=strategy.offgrid_marginals,
)

0 comments on commit d17c76d

Please sign in to comment.