Add sort and remove duplication to statistical_inefficiency (#119)
* fix #118
* set default to false
* Update subsampling.py
* bump coverage
* update test
xiki-tempula authored Apr 14, 2021
1 parent 0b31fb2 · commit 950b591
Showing 3 changed files with 105 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGES
@@ -20,6 +20,8 @@ The rules for this file:
 * 0.?.?
 
 Enhancements
+  - Allow automatic sorting and duplicate removal during subsampling
+    (issue #118, PR #119).
   - Allow statistical_inefficiency to work on multiindex series. (issue #116,
     PR #117)
   - Allow the overlap matrix of the MBAR estimator to be plotted. (issue #73,
53 changes: 47 additions & 6 deletions src/alchemlyb/preprocessing/subsampling.py
@@ -58,7 +58,7 @@ def slicing(df, lower=None, upper=None, step=None, force=False):
 
 
 def statistical_inefficiency(df, series=None, lower=None, upper=None, step=None,
-                             conservative=True):
+                             conservative=True, drop_duplicates=False, sort=False):
     """Subsample a DataFrame based on the calculated statistical inefficiency
     of a timeseries.
@@ -83,6 +83,10 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None, step=None,
         intervals (the default). ``False`` will sample at non-uniform intervals to
         closely match the (fractional) statistical_inefficiency, as implemented
         in :func:`pymbar.timeseries.subsampleCorrelatedData`.
+    drop_duplicates : bool
+        Drop duplicated rows, based on the time index.
+    sort : bool
+        Sort the DataFrame by the time index.
 
     Returns
     -------
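For orientation, here is a minimal usage sketch of the two new keywords, assuming a GROMACS u_nk DataFrame; the file name and temperature are placeholders and are not part of this commit:

    from alchemlyb.parsing.gmx import extract_u_nk
    from alchemlyb.preprocessing.subsampling import statistical_inefficiency

    # u_nk carries a (time, lambdas) MultiIndex; the summed series is the
    # timeseries whose statistical inefficiency g is estimated.
    u_nk = extract_u_nk('dhdl.xvg', T=300)
    series = u_nk.sum(axis=1)

    # Duplicate time values are dropped and rows sorted by time before
    # subsampling; both keywords default to False, which preserves the old
    # raise-on-bad-input behaviour.
    subsampled = statistical_inefficiency(u_nk, series,
                                          drop_duplicates=True, sort=True)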
@@ -120,13 +124,50 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None, step=None,
     """
     if _check_multiple_times(df):
-        raise KeyError("Duplicate time values found; statistical inefficiency "
-                       "only works on a single, contiguous, "
-                       "and sorted timeseries.")
+        if drop_duplicates:
+            if isinstance(df, pd.Series):
+                # remove duplicates based on the time index
+                drop_duplicates_series = df.reset_index('time', name='').\
+                    drop_duplicates('time')
+                # Reset the time index
+                lambda_names = ['time',]
+                lambda_names.extend(drop_duplicates_series.index.names)
+                df = drop_duplicates_series.set_index('time', append=True).\
+                    reorder_levels(lambda_names)
+            else:
+                # remove duplicates based on the time index
+                drop_duplicates_df = df.reset_index('time').drop_duplicates('time')
+                # Reset the time index
+                lambda_names = ['time',]
+                lambda_names.extend(drop_duplicates_df.index.names)
+                df = drop_duplicates_df.set_index('time', append=True).\
+                    reorder_levels(lambda_names)
+
+            # Do the same thing with the series
+            if series is not None:
+                # remove duplicates based on the time index
+                drop_duplicates_series = series.reset_index('time', name='').\
+                    drop_duplicates('time')
+                # Reset the time index
+                lambda_names = ['time',]
+                lambda_names.extend(drop_duplicates_series.index.names)
+                series = drop_duplicates_series.set_index('time', append=True).\
+                    reorder_levels(lambda_names)
+
+        else:
+            raise KeyError("Duplicate time values found; statistical inefficiency "
+                           "only works on a single, contiguous, "
+                           "and sorted timeseries.")
 
     if not _check_sorted(df):
-        raise KeyError("Statistical inefficiency only works as expected if "
-                       "values are sorted by time, increasing.")
+        if sort:
+            df = df.sort_index(level='time')
+
+            if series is not None:
+                series = series.sort_index(level='time')
+        else:
+            raise KeyError("Statistical inefficiency only works as expected if "
+                           "values are sorted by time, increasing.")
 
     if series is not None:
         series = slicing(series, lower=lower, upper=upper, step=step)
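The implementation above deduplicates by round-tripping the 'time' level through an ordinary column. A self-contained sketch of that idiom on a toy MultiIndex frame (the index names and data values are invented for illustration):

    import pandas as pd

    idx = pd.MultiIndex.from_tuples(
        [(20.0, 0.0), (0.0, 0.0), (10.0, 0.0), (10.0, 0.0)],
        names=['time', 'fep-lambda'])
    df = pd.DataFrame({'fep': [3.0, 1.0, 2.0, 2.0]}, index=idx)

    # Move 'time' out of the index and drop repeated times (first one wins),
    # mirroring the drop_duplicates=True branch.
    flat = df.reset_index('time').drop_duplicates('time')

    # Restore 'time' as the leading index level.
    lambda_names = ['time'] + list(flat.index.names)
    deduped = flat.set_index('time', append=True).reorder_levels(lambda_names)

    # Sorting by the time level mirrors the sort=True branch.
    print(deduped.sort_index(level='time'))  # times 0.0, 10.0, 20.0, once each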
56 changes: 56 additions & 0 deletions src/alchemlyb/tests/test_preprocessing.py
@@ -78,6 +78,61 @@ def test_multiindex_duplicated(self, gmx_ABFE):
                                              gmx_ABFE.sum(axis=1))
         assert len(subsample) == 501
 
+    def test_sort_off(self, gmx_ABFE):
+        unsorted = pd.concat([gmx_ABFE[-500:], gmx_ABFE[:500]])
+        with pytest.raises(KeyError):
+            statistical_inefficiency(unsorted,
+                                     unsorted.sum(axis=1),
+                                     sort=False)
+
+    def test_sort_on(self, gmx_ABFE):
+        unsorted = pd.concat([gmx_ABFE[-500:], gmx_ABFE[:500]])
+        subsample = statistical_inefficiency(unsorted,
+                                             unsorted.sum(axis=1),
+                                             sort=True)
+        assert subsample.reset_index(0)['time'].is_monotonic_increasing
+
+    def test_sort_on_noseries(self, gmx_ABFE):
+        unsorted = pd.concat([gmx_ABFE[-500:], gmx_ABFE[:500]])
+        subsample = statistical_inefficiency(unsorted,
+                                             None,
+                                             sort=True)
+        assert subsample.reset_index(0)['time'].is_monotonic_increasing
+
+    def test_duplication_off(self, gmx_ABFE):
+        duplicated = pd.concat([gmx_ABFE, gmx_ABFE])
+        with pytest.raises(KeyError):
+            statistical_inefficiency(duplicated,
+                                     duplicated.sum(axis=1),
+                                     drop_duplicates=False)
+
+    def test_duplication_on_dataframe(self, gmx_ABFE):
+        duplicated = pd.concat([gmx_ABFE, gmx_ABFE])
+        subsample = statistical_inefficiency(duplicated,
+                                             duplicated.sum(axis=1),
+                                             drop_duplicates=True)
+        assert len(subsample) < 1000
+
+    def test_duplication_on_dataframe_noseries(self, gmx_ABFE):
+        duplicated = pd.concat([gmx_ABFE, gmx_ABFE])
+        subsample = statistical_inefficiency(duplicated,
+                                             None,
+                                             drop_duplicates=True)
+        assert len(subsample) == 1001
+
+    def test_duplication_on_series(self, gmx_ABFE):
+        duplicated = pd.concat([gmx_ABFE, gmx_ABFE])
+        subsample = statistical_inefficiency(duplicated.sum(axis=1),
+                                             duplicated.sum(axis=1),
+                                             drop_duplicates=True)
+        assert len(subsample) < 1000
+
+    def test_duplication_on_series_noseries(self, gmx_ABFE):
+        duplicated = pd.concat([gmx_ABFE, gmx_ABFE])
+        subsample = statistical_inefficiency(duplicated.sum(axis=1),
+                                             None,
+                                             drop_duplicates=True)
+        assert len(subsample) == 1001
+
 class CorrelatedPreprocessors:
 
@@ -135,3 +190,4 @@ class TestEquilibriumDetection(TestSlicing, CorrelatedPreprocessors):
 
     def slicer(self, *args, **kwargs):
         return equilibrium_detection(*args, **kwargs)
+