From 263a534844ac4700eed3b43e98eb25d587b049e2 Mon Sep 17 00:00:00 2001 From: Marvin Pfoertner Date: Fri, 28 Aug 2020 12:44:58 +0200 Subject: [PATCH] Bugfixes --- src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py | 4 ++-- tests/test_diffeq/test_odesolver.py | 10 +++++----- tests/{test_core => test_random_variables}/__init__.py | 0 .../test_random_variables/test_dirac.py | 1 - .../test_random_variables/test_normal.py | 0 .../test_random_variables/test_random_variable.py | 0 6 files changed, 7 insertions(+), 8 deletions(-) rename tests/{test_core => test_random_variables}/__init__.py (100%) rename tests/{test_core => }/test_random_variables/test_dirac.py (96%) rename tests/{test_core => }/test_random_variables/test_normal.py (100%) rename tests/{test_core => }/test_random_variables/test_random_variable.py (100%) diff --git a/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py b/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py index 26813145d7..da4c1ebfa3 100644 --- a/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py +++ b/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py @@ -49,7 +49,7 @@ def step(self, t, t_new, current_rv, **kwargs): t_new, pred_rv, zero_data, **kwargs ) errorest, self.sigma_squared_current = self._estimate_error( - filt_rv.mean(), crosscov, meas_cov, meas_mean + filt_rv.mean, crosscov, meas_cov, meas_mean ) return filt_rv, errorest @@ -73,7 +73,7 @@ def postprocess(self, times, rvs): def _rescale(self, rvs): """Rescales covariances according to estimate sigma squared value.""" - rvs = [Normal(rv.mean(), self.sigma_squared_global * rv.cov()) for rv in rvs] + rvs = [Normal(rv.mean, self.sigma_squared_global * rv.cov) for rv in rvs] return rvs def _odesmooth(self, ode_solution, **kwargs): diff --git a/tests/test_diffeq/test_odesolver.py b/tests/test_diffeq/test_odesolver.py index 237126975b..a5c005cc6f 100644 --- a/tests/test_diffeq/test_odesolver.py +++ b/tests/test_diffeq/test_odesolver.py @@ -3,7 +3,7 @@ import numpy as np from probnum.diffeq import logistic, ODESolver, ConstantSteps -from probnum.prob import RandomVariable, Dirac +from probnum.random_variables import Dirac class MockODESolver(ODESolver): @@ -14,10 +14,10 @@ def initialise(self): def step(self, start, stop, current): h = stop - start - x = current.mean() + x = current.mean xnew = x + h * self.ivp(start, x) return ( - RandomVariable(Dirac(xnew)), + Dirac(xnew), np.nan, ) # return nan as error estimate to ensure that it is not used @@ -29,7 +29,7 @@ class ODESolverTestCase(unittest.TestCase): """ def setUp(self): - y0 = RandomVariable(distribution=Dirac(0.3)) + y0 = Dirac(0.3) ivp = logistic([0, 4], initrv=y0) self.solver = MockODESolver(ivp) self.step = 0.2 @@ -42,7 +42,7 @@ def test_solve(self): # quick check that the result is sensible self.assertAlmostEqual(odesol.t[-1], self.solver.ivp.tmax) - self.assertAlmostEqual(odesol.y[-1].mean(), 1.0, places=2) + self.assertAlmostEqual(odesol.y[-1].mean, 1.0, places=2) if __name__ == "__main__": diff --git a/tests/test_core/__init__.py b/tests/test_random_variables/__init__.py similarity index 100% rename from tests/test_core/__init__.py rename to tests/test_random_variables/__init__.py diff --git a/tests/test_core/test_random_variables/test_dirac.py b/tests/test_random_variables/test_dirac.py similarity index 96% rename from tests/test_core/test_random_variables/test_dirac.py rename to tests/test_random_variables/test_dirac.py index 51ac994ada..bcc1579797 100644 --- a/tests/test_core/test_random_variables/test_dirac.py +++ b/tests/test_random_variables/test_dirac.py @@ -4,7 +4,6 @@ import numpy as np from probnum import random_variables as rvs -from probnum import utils as _utils class TestDirac(unittest.TestCase): diff --git a/tests/test_core/test_random_variables/test_normal.py b/tests/test_random_variables/test_normal.py similarity index 100% rename from tests/test_core/test_random_variables/test_normal.py rename to tests/test_random_variables/test_normal.py diff --git a/tests/test_core/test_random_variables/test_random_variable.py b/tests/test_random_variables/test_random_variable.py similarity index 100% rename from tests/test_core/test_random_variables/test_random_variable.py rename to tests/test_random_variables/test_random_variable.py