Skip to content

Commit

Permalink
Bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Aug 28, 2020
1 parent 4ce2dde commit 263a534
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def step(self, t, t_new, current_rv, **kwargs):
t_new, pred_rv, zero_data, **kwargs
)
errorest, self.sigma_squared_current = self._estimate_error(
filt_rv.mean(), crosscov, meas_cov, meas_mean
filt_rv.mean, crosscov, meas_cov, meas_mean
)
return filt_rv, errorest

Expand All @@ -73,7 +73,7 @@ def postprocess(self, times, rvs):

def _rescale(self, rvs):
"""Rescales covariances according to estimate sigma squared value."""
rvs = [Normal(rv.mean(), self.sigma_squared_global * rv.cov()) for rv in rvs]
rvs = [Normal(rv.mean, self.sigma_squared_global * rv.cov) for rv in rvs]
return rvs

def _odesmooth(self, ode_solution, **kwargs):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_diffeq/test_odesolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from probnum.diffeq import logistic, ODESolver, ConstantSteps
from probnum.prob import RandomVariable, Dirac
from probnum.random_variables import Dirac


class MockODESolver(ODESolver):
Expand All @@ -14,10 +14,10 @@ def initialise(self):

def step(self, start, stop, current):
h = stop - start
x = current.mean()
x = current.mean
xnew = x + h * self.ivp(start, x)
return (
RandomVariable(Dirac(xnew)),
Dirac(xnew),
np.nan,
) # return nan as error estimate to ensure that it is not used

Expand All @@ -29,7 +29,7 @@ class ODESolverTestCase(unittest.TestCase):
"""

def setUp(self):
y0 = RandomVariable(distribution=Dirac(0.3))
y0 = Dirac(0.3)
ivp = logistic([0, 4], initrv=y0)
self.solver = MockODESolver(ivp)
self.step = 0.2
Expand All @@ -42,7 +42,7 @@ def test_solve(self):

# quick check that the result is sensible
self.assertAlmostEqual(odesol.t[-1], self.solver.ivp.tmax)
self.assertAlmostEqual(odesol.y[-1].mean(), 1.0, places=2)
self.assertAlmostEqual(odesol.y[-1].mean, 1.0, places=2)


if __name__ == "__main__":
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np

from probnum import random_variables as rvs
from probnum import utils as _utils


class TestDirac(unittest.TestCase):
Expand Down
File renamed without changes.

0 comments on commit 263a534

Please sign in to comment.