Skip to content

Commit

Permalink
Update test_all_distrs.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Aug 26, 2023
1 parent e49d3fc commit 3d95ed1
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions skpro/distributions/tests/test_all_distrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,14 @@ def _has_capability(distr, method):
class TestAllDistributions(PackageConfig, DistributionFixtureGenerator, QuickTester):
"""Module level tests for all sktime parameter fitters."""

def test_sample(self, object_instance):
@pytest.mark.parametrize("shuffled", [False, True])
def test_sample(self, object_instance, shuffled):
"""Test sample expected return."""
d = object_instance

if shuffled:
d = _shuffle_distr(d)

res = d.sample()

assert d.shape == res.shape
Expand All @@ -76,36 +80,50 @@ def test_sample(self, object_instance):
assert (res_panel.index == dummy_panel.index).all()
assert (res_panel.columns == dummy_panel.columns).all()

@pytest.mark.parametrize("shuffled", [False, True])
@pytest.mark.parametrize("method", METHODS_SCALAR, ids=METHODS_SCALAR)
def test_methods_scalar(self, object_instance, method):
def test_methods_scalar(self, object_instance, method, shuffled):
"""Test expected return of scalar methods."""
if not _has_capability(object_instance, method):
return None

d = object_instance
if shuffled:
d = _shuffle_distr(d)

res = getattr(object_instance, method)()

_check_output_format(res, d, method)

@pytest.mark.parametrize("shuffled", [False, True])
@pytest.mark.parametrize("method", METHODS_X, ids=METHODS_X)
def test_methods_x(self, object_instance, method):
def test_methods_x(self, object_instance, method, shuffled):
"""Test expected return of methods that take sample-like argument."""
if not _has_capability(object_instance, method):
return None

d = object_instance

if shuffled:
d = _shuffle_distr(d)

x = d.sample()
res = getattr(object_instance, method)(x)

_check_output_format(res, d, method)

@pytest.mark.parametrize("shuffled", [False, True])
@pytest.mark.parametrize("method", METHODS_P, ids=METHODS_P)
def test_methods_p(self, object_instance, method):
def test_methods_p(self, object_instance, method, shuffled):
"""Test expected return of methods that take percentage-like argument."""
if not _has_capability(object_instance, method):
return None

d = object_instance

if shuffled:
d = _shuffle_distr(d)

np_unif = np.random.uniform(size=d.shape)
p = pd.DataFrame(np_unif, index=d.index, columns=d.columns)
res = getattr(object_instance, method)(p)
Expand Down Expand Up @@ -146,3 +164,9 @@ def _check_output_format(res, dist, method):

if method in METHODS_SCALAR_POS or method in METHODS_X_POS:
assert (res >= 0).all().all()


def _shuffle_distr(d):
"""Shuffle distribution row index."""
shuffled_index = d.index.sample(frac=1)
return d.loc[shuffled_index]

0 comments on commit 3d95ed1

Please sign in to comment.