Skip to content

Commit

Permalink
Merge branch 'main' into test-sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Aug 26, 2023
2 parents 3ce3df0 + a200f1e commit 8b68a87
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
34 changes: 34 additions & 0 deletions skpro/distributions/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,40 @@ def _apply_per_ix(self, func, params, x=None):
res.loc[ix, col] = func(spl=spl_t, weights=weights_t, x=x_t, **params)
return res.convert_dtypes()

def _iloc(self, rowidx=None, colidx=None):

index = self.index
columns = self.columns
weights = self.weights

spl_subset = self.spl

if rowidx is not None:
rowidx_loc = index[rowidx]
# subset multiindex to rowidx by last level
spl_subset = self.spl.loc[(slice(None), rowidx_loc), :]
if weights is not None:
weights_subset = weights.loc[(slice(None), rowidx_loc)]
else:
weights_subset = None
subs_rowidx = index[rowidx]
else:
subs_rowidx = index

if colidx is not None:
spl_subset = spl_subset.iloc[:, colidx]
subs_colidx = columns[colidx]
else:
subs_colidx = columns

return Empirical(
spl_subset,
weights=weights_subset,
time_indep=self.time_indep,
index=subs_rowidx,
columns=subs_colidx,
)

def energy(self, x=None):
r"""Energy of self, w.r.t. self or a constant frame x.
Expand Down
35 changes: 35 additions & 0 deletions skpro/distributions/tests/test_all_distrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from skpro.datatypes import check_is_mtype
from skpro.distributions.base import BaseDistribution
from skpro.tests.test_all_estimators import PackageConfig
from skpro.utils.index import random_ss_ix


class DistributionFixtureGenerator(BaseFixtureGenerator):
Expand Down Expand Up @@ -150,6 +151,40 @@ def _check_quantile_output(obj, q):
res = d.quantile(q)
_check_quantile_output(res, q)

@pytest.mark.parametrize("subset_row", [True, False])
@pytest.mark.parametrize("subset_col", [True, False])
def test_subsetting(self, object_instance, subset_row, subset_col):
"""Test subsetting of distribution."""
d = object_instance

if subset_row:
ix_loc = random_ss_ix(d.index, 3)
ix_iloc = d.index.get_indexer(ix_loc)
else:
ix_loc = d.index
ix_iloc = pd.RangeIndex(len(d.index))

if subset_col:
iy_loc = random_ss_ix(d.columns, 1)
iy_iloc = d.columns.get_indexer(iy_loc)
else:
iy_loc = d.columns
iy_iloc = pd.RangeIndex(len(d.columns))

res_loc = d.loc[ix_loc, iy_loc]

assert isinstance(res_loc, type(d))
assert res_loc.shape == (len(ix_loc), len(iy_loc))
assert (res_loc.index == ix_loc).all()
assert (res_loc.columns == iy_loc).all()

res_iloc = d.iloc[ix_iloc, iy_iloc]

assert isinstance(res_iloc, type(d))
assert res_iloc.shape == (len(ix_iloc), len(iy_iloc))
assert (res_iloc.index == ix_loc).all()
assert (res_iloc.columns == iy_loc).all()


def _check_output_format(res, dist, method):
"""Check output format expectations for BaseDistribution tests."""
Expand Down
21 changes: 21 additions & 0 deletions skpro/utils/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""Utility functions for working with indices."""

import numpy as np


def random_ss_ix(ix, size, replace=True):
"""Randomly uniformly sample indices from a list of indices.
Parameters
----------
ix : pd.Index or subsettable iterable via getitem
list of indices to sample from
size : int
number of indices to sample
replace : bool, default=True
whether to sample with replacement
"""
a = range(len(ix))
ixs = ix[np.random.choice(a, size=size, replace=replace)]
return ixs

0 comments on commit 8b68a87

Please sign in to comment.