diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index a8851e4d..f6845e8e 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -787,25 +787,6 @@ def extract(state, /): return _Calibration(init=init, update=update, extract=extract) -@containers.dataclass -class _Solver: - """IVP solver.""" - - name: str - requires_rescaling: bool - error_contraction_rate: int - is_suitable_for_save_at: int - is_suitable_for_save_every_step: int - - initial_condition: Callable - init: Callable - step: Callable - extract: Callable - interpolate: Callable - interpolate_at_t1: Callable - offgrid_marginals: Callable - - class _SolverState(containers.NamedTuple): """Solver state.""" @@ -817,13 +798,78 @@ def t(self): return self.strategy.t +@containers.dataclass +class _ProbabilisticSolver: + name: str + calibration: _Calibration + step_implementation: Callable + strategy: _Strategy + requires_rescaling: bool + + @property + def offgrid_marginals(self): + return self.strategy.offgrid_marginals + + @property + def error_contraction_rate(self): + return self.strategy.num_derivatives + 1 + + @property + def is_suitable_for_save_at(self): + return self.strategy.is_suitable_for_save_at + + @property + def is_suitable_for_save_every_step(self): + return self.strategy.is_suitable_for_save_every_step + + def init(self, t, initial_condition) -> _SolverState: + posterior, output_scale = initial_condition + state_strategy = self.strategy.init(t, posterior) + calib_state = self.calibration.init(output_scale) + return _SolverState(strategy=state_strategy, output_scale=calib_state) + + def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: + return self.step_implementation( + state, vector_field=vector_field, dt=dt, calibration=self.calibration + ) + + def extract(self, state: _SolverState, /): + t, posterior = self.strategy.extract(state.strategy) + _output_scale_prior, output_scale = self.calibration.extract(state.output_scale) + return t, (posterior, output_scale) + + def interpolate( + self, t, *, interp_from: _SolverState, interp_to: _SolverState + ) -> _InterpRes: + output_scale, _ = self.calibration.extract(interp_to.output_scale) + interp = self.strategy.case_interpolate( + t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale + ) + prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale) + sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale) + acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale) + return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) + + def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: + x = self.strategy.case_interpolate_at_t1(interp_to.strategy) + + prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale) + sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale) + acc = _SolverState(x.step_from, output_scale=interp_to.output_scale) + return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) + + def initial_condition(self): + """Construct an initial condition.""" + posterior = self.strategy.initial_condition() + return posterior, self.strategy.prior.output_scale + + def solver_mle(strategy, *, ssm): """Create a solver that calibrates the output scale via maximum-likelihood. Warning: needs to be combined with a call to stats.calibrate() after solving if the MLE-calibration shall be *used*. """ - name = f"" def step_mle(state, /, *, dt, vector_field, calibration): output_scale_prior, _calibrated = calibration.extract(state.output_scale) @@ -843,18 +889,17 @@ def step_mle(state, /, *, dt, vector_field, calibration): state = _SolverState(strategy=state_strategy, output_scale=output_scale) return dt * error, state - return _solver_calibrated( + return _ProbabilisticSolver( + name="Probabilistic solver with MLE calibration", calibration=_calibration_running_mean(ssm=ssm), - impl_step=step_mle, + step_implementation=step_mle, strategy=strategy, - name=name, requires_rescaling=True, ) def solver_dynamic(strategy, *, ssm): """Create a solver that calibrates the output scale dynamically.""" - name = f"" def step_dynamic(state, /, *, dt, vector_field, calibration): error, observed, state_strategy = strategy.begin( @@ -870,11 +915,11 @@ def step_dynamic(state, /, *, dt, vector_field, calibration): state = _SolverState(strategy=state_strategy, output_scale=output_scale) return dt * error, state - return _solver_calibrated( + return _ProbabilisticSolver( strategy=strategy, calibration=_calibration_most_recent(ssm=ssm), - name=name, - impl_step=step_dynamic, + name="Dynamic probabilistic solver", + step_implementation=step_dynamic, requires_rescaling=False, ) @@ -895,71 +940,10 @@ def step(state: _SolverState, *, vector_field, dt, calibration): state = _SolverState(strategy=state_strategy, output_scale=state.output_scale) return dt * error, state - name = f"" - return _solver_calibrated( + return _ProbabilisticSolver( strategy=strategy, calibration=_calibration_none(), - impl_step=step, - name=name, + step_implementation=step, + name="Probabilistic solver", requires_rescaling=False, ) - - -def _solver_calibrated( - *, calibration, impl_step, strategy, name, requires_rescaling -) -> _Solver: - def init(t, initial_condition) -> _SolverState: - posterior, output_scale = initial_condition - state_strategy = strategy.init(t, posterior) - calib_state = calibration.init(output_scale) - return _SolverState(strategy=state_strategy, output_scale=calib_state) - - def step(state: _SolverState, *, vector_field, dt) -> _SolverState: - return impl_step( - state, vector_field=vector_field, dt=dt, calibration=calibration - ) - - def extract(state: _SolverState, /): - t, posterior = strategy.extract(state.strategy) - _output_scale_prior, output_scale = calibration.extract(state.output_scale) - return t, (posterior, output_scale) - - def interpolate( - t, *, interp_from: _SolverState, interp_to: _SolverState - ) -> _InterpRes: - output_scale, _ = calibration.extract(interp_to.output_scale) - interp = strategy.case_interpolate( - t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale - ) - prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale) - sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale) - acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale) - return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) - - def interpolate_at_t1(*, interp_from, interp_to) -> _InterpRes: - x = strategy.case_interpolate_at_t1(interp_to.strategy) - - prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale) - sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale) - acc = _SolverState(x.step_from, output_scale=interp_to.output_scale) - return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) - - def initial_condition(): - """Construct an initial condition.""" - posterior = strategy.initial_condition() - return posterior, strategy.prior.output_scale - - return _Solver( - error_contraction_rate=strategy.num_derivatives + 1, - is_suitable_for_save_at=strategy.is_suitable_for_save_at, - is_suitable_for_save_every_step=strategy.is_suitable_for_save_every_step, - name=name, - requires_rescaling=requires_rescaling, - initial_condition=initial_condition, - init=init, - step=step, - extract=extract, - interpolate=interpolate, - interpolate_at_t1=interpolate_at_t1, - offgrid_marginals=strategy.offgrid_marginals, - )