From 950b591f79e8a4299934e1f8f6937be2f9b1d198 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Wed, 14 Apr 2021 23:11:57 +0100 Subject: [PATCH] Add sort and remove duplication to statistical_inefficiency (#119) * fix #118 * set default to false * Update subsampling.py * bump coverage * update test --- CHANGES | 2 + src/alchemlyb/preprocessing/subsampling.py | 53 +++++++++++++++++--- src/alchemlyb/tests/test_preprocessing.py | 56 ++++++++++++++++++++++ 3 files changed, 105 insertions(+), 6 deletions(-) diff --git a/CHANGES b/CHANGES index 3c8e1de6..8265ed48 100644 --- a/CHANGES +++ b/CHANGES @@ -20,6 +20,8 @@ The rules for this file: * 0.?.? Enhancements + - Allow automatic sorting and duplication removal during subsampling + (issue #118, PR #119). - Allow statistical_inefficiency to work on multiindex series. (issue #116, PR #117) - Allow the overlap matrix of the MBAR estimator to be plotted. (issue #73, diff --git a/src/alchemlyb/preprocessing/subsampling.py b/src/alchemlyb/preprocessing/subsampling.py index 199ef880..53129f21 100644 --- a/src/alchemlyb/preprocessing/subsampling.py +++ b/src/alchemlyb/preprocessing/subsampling.py @@ -58,7 +58,7 @@ def slicing(df, lower=None, upper=None, step=None, force=False): def statistical_inefficiency(df, series=None, lower=None, upper=None, step=None, - conservative=True): + conservative=True, drop_duplicates=False, sort=False): """Subsample a DataFrame based on the calculated statistical inefficiency of a timeseries. @@ -83,6 +83,10 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None, step=None, intervals (the default). ``False`` will sample at non-uniform intervals to closely match the (fractional) statistical_inefficieny, as implemented in :func:`pymbar.timeseries.subsampleCorrelatedData`. + drop_duplicates : bool + Drop the duplicated rows based on time. + sort : bool + Sort the DataFrame based on the time column. 
Returns ------- @@ -120,13 +124,50 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None, step=None, """ if _check_multiple_times(df): - raise KeyError("Duplicate time values found; statistical inefficiency " - "only works on a single, contiguous, " - "and sorted timeseries.") + if drop_duplicates: + if isinstance(df, pd.Series): + # remove the duplicate based on time + drop_duplicates_series = df.reset_index('time', name='').\ + drop_duplicates('time') + # Reset the time index + lambda_names = ['time',] + lambda_names.extend(drop_duplicates_series.index.names) + df = drop_duplicates_series.set_index('time', append=True).\ + reorder_levels(lambda_names) + else: + # remove the duplicate based on time + drop_duplicates_df = df.reset_index('time').drop_duplicates('time') + # Reset the time index + lambda_names = ['time',] + lambda_names.extend(drop_duplicates_df.index.names) + df = drop_duplicates_df.set_index('time', append=True).\ + reorder_levels(lambda_names) + + # Do the same thing with the series + if series is not None: + # remove the duplicate based on time + drop_duplicates_series = series.reset_index('time', name='').\ + drop_duplicates('time') + # Reset the time index + lambda_names = ['time',] + lambda_names.extend(drop_duplicates_series.index.names) + series = drop_duplicates_series.set_index('time', append=True).\ + reorder_levels(lambda_names) + + else: + raise KeyError("Duplicate time values found; statistical inefficiency " + "only works on a single, contiguous, " + "and sorted timeseries.") if not _check_sorted(df): - raise KeyError("Statistical inefficiency only works as expected if " - "values are sorted by time, increasing.") + if sort: + df = df.sort_index(level='time') + + if series is not None: + series = series.sort_index(level='time') + else: + raise KeyError("Statistical inefficiency only works as expected if " + "values are sorted by time, increasing.") if series is not None: series = slicing(series, lower=lower, 
upper=upper, step=step) diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index 7c183d28..4725c5d0 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -78,6 +78,61 @@ def test_multiindex_duplicated(self, gmx_ABFE): gmx_ABFE.sum(axis=1)) assert len(subsample) == 501 + def test_sort_off(self, gmx_ABFE): + unsorted = pd.concat([gmx_ABFE[-500:], gmx_ABFE[:500]]) + with pytest.raises(KeyError): + statistical_inefficiency(unsorted, + unsorted.sum(axis=1), + sort=False) + + def test_sort_on(self, gmx_ABFE): + unsorted = pd.concat([gmx_ABFE[-500:], gmx_ABFE[:500]]) + subsample = statistical_inefficiency(unsorted, + unsorted.sum(axis=1), + sort=True) + assert subsample.reset_index(0)['time'].is_monotonic_increasing + + def test_sort_on_noseries(self, gmx_ABFE): + unsorted = pd.concat([gmx_ABFE[-500:], gmx_ABFE[:500]]) + subsample = statistical_inefficiency(unsorted, + None, + sort=True) + assert subsample.reset_index(0)['time'].is_monotonic_increasing + + def test_duplication_off(self, gmx_ABFE): + duplicated = pd.concat([gmx_ABFE, gmx_ABFE]) + with pytest.raises(KeyError): + statistical_inefficiency(duplicated, + duplicated.sum(axis=1), + drop_duplicates=False) + + def test_duplication_on_dataframe(self, gmx_ABFE): + duplicated = pd.concat([gmx_ABFE, gmx_ABFE]) + subsample = statistical_inefficiency(duplicated, + duplicated.sum(axis=1), + drop_duplicates=True) + assert len(subsample) < 1000 + + def test_duplication_on_dataframe_noseries(self, gmx_ABFE): + duplicated = pd.concat([gmx_ABFE, gmx_ABFE]) + subsample = statistical_inefficiency(duplicated, + None, + drop_duplicates=True) + assert len(subsample) == 1001 + + def test_duplication_on_series(self, gmx_ABFE): + duplicated = pd.concat([gmx_ABFE, gmx_ABFE]) + subsample = statistical_inefficiency(duplicated.sum(axis=1), + duplicated.sum(axis=1), + drop_duplicates=True) + assert len(subsample) < 1000 + + def 
test_duplication_on_series_noseries(self, gmx_ABFE): + duplicated = pd.concat([gmx_ABFE, gmx_ABFE]) + subsample = statistical_inefficiency(duplicated.sum(axis=1), + None, + drop_duplicates=True) + assert len(subsample) == 1001 class CorrelatedPreprocessors: @@ -135,3 +190,4 @@ class TestEquilibriumDetection(TestSlicing, CorrelatedPreprocessors): def slicer(self, *args, **kwargs): return equilibrium_detection(*args, **kwargs) +