Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal: Simplify the probabilistic-solver implementation #794

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 75 additions & 91 deletions probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,25 +787,6 @@ def extract(state, /):
return _Calibration(init=init, update=update, extract=extract)


@containers.dataclass
class _Solver:
"""IVP solver."""

name: str
requires_rescaling: bool
error_contraction_rate: int
is_suitable_for_save_at: int
is_suitable_for_save_every_step: int

initial_condition: Callable
init: Callable
step: Callable
extract: Callable
interpolate: Callable
interpolate_at_t1: Callable
offgrid_marginals: Callable


class _SolverState(containers.NamedTuple):
"""Solver state."""

Expand All @@ -817,13 +798,78 @@ def t(self):
return self.strategy.t


@containers.dataclass
class _ProbabilisticSolver:
name: str
calibration: _Calibration
step_implementation: Callable
strategy: _Strategy
requires_rescaling: bool

@property
def offgrid_marginals(self):
return self.strategy.offgrid_marginals

@property
def error_contraction_rate(self):
return self.strategy.num_derivatives + 1

@property
def is_suitable_for_save_at(self):
return self.strategy.is_suitable_for_save_at

@property
def is_suitable_for_save_every_step(self):
return self.strategy.is_suitable_for_save_every_step

def init(self, t, initial_condition) -> _SolverState:
posterior, output_scale = initial_condition
state_strategy = self.strategy.init(t, posterior)
calib_state = self.calibration.init(output_scale)
return _SolverState(strategy=state_strategy, output_scale=calib_state)

def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState:
return self.step_implementation(
state, vector_field=vector_field, dt=dt, calibration=self.calibration
)

def extract(self, state: _SolverState, /):
t, posterior = self.strategy.extract(state.strategy)
_output_scale_prior, output_scale = self.calibration.extract(state.output_scale)
return t, (posterior, output_scale)

def interpolate(
self, t, *, interp_from: _SolverState, interp_to: _SolverState
) -> _InterpRes:
output_scale, _ = self.calibration.extract(interp_to.output_scale)
interp = self.strategy.case_interpolate(
t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale
)
prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes:
x = self.strategy.case_interpolate_at_t1(interp_to.strategy)

prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(x.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def initial_condition(self):
"""Construct an initial condition."""
posterior = self.strategy.initial_condition()
return posterior, self.strategy.prior.output_scale


def solver_mle(strategy, *, ssm):
"""Create a solver that calibrates the output scale via maximum-likelihood.

Warning: needs to be combined with a call to stats.calibrate()
after solving if the MLE-calibration shall be *used*.
"""
name = f"<MLE-solver with {strategy}>"

def step_mle(state, /, *, dt, vector_field, calibration):
output_scale_prior, _calibrated = calibration.extract(state.output_scale)
Expand All @@ -843,18 +889,17 @@ def step_mle(state, /, *, dt, vector_field, calibration):
state = _SolverState(strategy=state_strategy, output_scale=output_scale)
return dt * error, state

return _solver_calibrated(
return _ProbabilisticSolver(
name="Probabilistic solver with MLE calibration",
calibration=_calibration_running_mean(ssm=ssm),
impl_step=step_mle,
step_implementation=step_mle,
strategy=strategy,
name=name,
requires_rescaling=True,
)


def solver_dynamic(strategy, *, ssm):
"""Create a solver that calibrates the output scale dynamically."""
name = f"<Dynamic solver with {strategy}>"

def step_dynamic(state, /, *, dt, vector_field, calibration):
error, observed, state_strategy = strategy.begin(
Expand All @@ -870,11 +915,11 @@ def step_dynamic(state, /, *, dt, vector_field, calibration):
state = _SolverState(strategy=state_strategy, output_scale=output_scale)
return dt * error, state

return _solver_calibrated(
return _ProbabilisticSolver(
strategy=strategy,
calibration=_calibration_most_recent(ssm=ssm),
name=name,
impl_step=step_dynamic,
name="Dynamic probabilistic solver",
step_implementation=step_dynamic,
requires_rescaling=False,
)

Expand All @@ -895,71 +940,10 @@ def step(state: _SolverState, *, vector_field, dt, calibration):
state = _SolverState(strategy=state_strategy, output_scale=state.output_scale)
return dt * error, state

name = f"<Uncalibrated solver with {strategy}>"
return _solver_calibrated(
return _ProbabilisticSolver(
strategy=strategy,
calibration=_calibration_none(),
impl_step=step,
name=name,
step_implementation=step,
name="Probabilistic solver",
requires_rescaling=False,
)


def _solver_calibrated(
*, calibration, impl_step, strategy, name, requires_rescaling
) -> _Solver:
def init(t, initial_condition) -> _SolverState:
posterior, output_scale = initial_condition
state_strategy = strategy.init(t, posterior)
calib_state = calibration.init(output_scale)
return _SolverState(strategy=state_strategy, output_scale=calib_state)

def step(state: _SolverState, *, vector_field, dt) -> _SolverState:
return impl_step(
state, vector_field=vector_field, dt=dt, calibration=calibration
)

def extract(state: _SolverState, /):
t, posterior = strategy.extract(state.strategy)
_output_scale_prior, output_scale = calibration.extract(state.output_scale)
return t, (posterior, output_scale)

def interpolate(
t, *, interp_from: _SolverState, interp_to: _SolverState
) -> _InterpRes:
output_scale, _ = calibration.extract(interp_to.output_scale)
interp = strategy.case_interpolate(
t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale
)
prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def interpolate_at_t1(*, interp_from, interp_to) -> _InterpRes:
x = strategy.case_interpolate_at_t1(interp_to.strategy)

prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale)
sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale)
acc = _SolverState(x.step_from, output_scale=interp_to.output_scale)
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)

def initial_condition():
"""Construct an initial condition."""
posterior = strategy.initial_condition()
return posterior, strategy.prior.output_scale

return _Solver(
error_contraction_rate=strategy.num_derivatives + 1,
is_suitable_for_save_at=strategy.is_suitable_for_save_at,
is_suitable_for_save_every_step=strategy.is_suitable_for_save_every_step,
name=name,
requires_rescaling=requires_rescaling,
initial_condition=initial_condition,
init=init,
step=step,
extract=extract,
interpolate=interpolate,
interpolate_at_t1=interpolate_at_t1,
offgrid_marginals=strategy.offgrid_marginals,
)
Loading