From a200f1e7b683c7f21a9c33a3e7474556a36859d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Aug 2023 21:07:50 +0100 Subject: [PATCH] [ENH] add test for subsetting distributions (#43) Adds a test to the distribution suite tests ensuring that they subset appropriately via `loc` and `iloc`, and satisfy common interface expectations on the resulting object and index. Also fixes a bug in subsetting of `Empirical` that was highlighted by these tests. --- skpro/distributions/empirical.py | 34 +++++++++++++++++++ skpro/distributions/tests/test_all_distrs.py | 35 ++++++++++++++++++++ skpro/utils/index.py | 21 ++++++++++++ 3 files changed, 90 insertions(+) create mode 100644 skpro/utils/index.py diff --git a/skpro/distributions/empirical.py b/skpro/distributions/empirical.py index 20ae58a1..45718541 100644 --- a/skpro/distributions/empirical.py +++ b/skpro/distributions/empirical.py @@ -128,6 +128,40 @@ def _apply_per_ix(self, func, params, x=None): res.loc[ix, col] = func(spl=spl_t, weights=weights_t, x=x_t, **params) return res.convert_dtypes() + def _iloc(self, rowidx=None, colidx=None): + + index = self.index + columns = self.columns + weights = self.weights + + spl_subset = self.spl + + if rowidx is not None: + rowidx_loc = index[rowidx] + # subset multiindex to rowidx by last level + spl_subset = self.spl.loc[(slice(None), rowidx_loc), :] + if weights is not None: + weights_subset = weights.loc[(slice(None), rowidx_loc)] + else: + weights_subset = None + subs_rowidx = index[rowidx] + else: + subs_rowidx = index + + if colidx is not None: + spl_subset = spl_subset.iloc[:, colidx] + subs_colidx = columns[colidx] + else: + subs_colidx = columns + + return Empirical( + spl_subset, + weights=weights_subset, + time_indep=self.time_indep, + index=subs_rowidx, + columns=subs_colidx, + ) + def energy(self, x=None): r"""Energy of self, w.r.t. self or a constant frame x. diff --git a/skpro/distributions/tests/test_all_distrs.py b/skpro/distributions/tests/test_all_distrs.py index 1c19258f..7aa427a4 100644 --- a/skpro/distributions/tests/test_all_distrs.py +++ b/skpro/distributions/tests/test_all_distrs.py @@ -13,6 +13,7 @@ from skpro.datatypes import check_is_mtype from skpro.distributions.base import BaseDistribution from skpro.tests.test_all_estimators import PackageConfig +from skpro.utils.index import random_ss_ix class DistributionFixtureGenerator(BaseFixtureGenerator): @@ -132,6 +133,40 @@ def _check_quantile_output(obj, q): res = d.quantile(q) _check_quantile_output(res, q) + @pytest.mark.parametrize("subset_row", [True, False]) + @pytest.mark.parametrize("subset_col", [True, False]) + def test_subsetting(self, object_instance, subset_row, subset_col): + """Test subsetting of distribution.""" + d = object_instance + + if subset_row: + ix_loc = random_ss_ix(d.index, 3) + ix_iloc = d.index.get_indexer(ix_loc) + else: + ix_loc = d.index + ix_iloc = pd.RangeIndex(len(d.index)) + + if subset_col: + iy_loc = random_ss_ix(d.columns, 1) + iy_iloc = d.columns.get_indexer(iy_loc) + else: + iy_loc = d.columns + iy_iloc = pd.RangeIndex(len(d.columns)) + + res_loc = d.loc[ix_loc, iy_loc] + + assert isinstance(res_loc, type(d)) + assert res_loc.shape == (len(ix_loc), len(iy_loc)) + assert (res_loc.index == ix_loc).all() + assert (res_loc.columns == iy_loc).all() + + res_iloc = d.iloc[ix_iloc, iy_iloc] + + assert isinstance(res_iloc, type(d)) + assert res_iloc.shape == (len(ix_iloc), len(iy_iloc)) + assert (res_iloc.index == ix_loc).all() + assert (res_iloc.columns == iy_loc).all() + def _check_output_format(res, dist, method): """Check output format expectations for BaseDistribution tests.""" diff --git a/skpro/utils/index.py b/skpro/utils/index.py new file mode 100644 index 00000000..5496dfd3 --- /dev/null +++ b/skpro/utils/index.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +"""Utility functions for working with indices.""" + +import numpy as np + + +def random_ss_ix(ix, size, replace=True): + """Randomly uniformly sample indices from a list of indices. + + Parameters + ---------- + ix : pd.Index or subsettable iterable via getitem + list of indices to sample from + size : int + number of indices to sample + replace : bool, default=True + whether to sample with replacement + """ + a = range(len(ix)) + ixs = ix[np.random.choice(a, size=size, replace=replace)] + return ixs