From 39f7731383f0fef08b29a8f952a0ae9e5597f300 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Klara=20R=C3=B6hrl?= Date: Sat, 5 Jun 2021 10:16:00 +0200 Subject: [PATCH] Do not convert last states to dask DataFrame. (#134) --- docs/source/changes.rst | 2 ++ src/sid/simulate.py | 30 +----------------------------- tests/test_plotting.py | 2 +- tests/test_rapid_tests.py | 4 ++-- tests/test_seasonality.py | 2 +- tests/test_simulate.py | 10 +++++----- tests/test_time.py | 2 +- 7 files changed, 13 insertions(+), 39 deletions(-) diff --git a/docs/source/changes.rst b/docs/source/changes.rst index 7db6fe76..89ed2e44 100644 --- a/docs/source/changes.rst +++ b/docs/source/changes.rst @@ -13,6 +13,8 @@ all releases are available on `Anaconda.org - :gh:`131` moves the parsing of the virus strain infectiousness factor to the simulation. - :gh:`132` sets initialized countdowns to -9,999. +- :gh:`134` changes that the last states are returned as a ``pandas.DataFrame`` and not + as a ``dask.dataframe``. 0.0.9 - 2021-05-28 diff --git a/src/sid/simulate.py b/src/sid/simulate.py index e42df6be..3e7120d9 100644 --- a/src/sid/simulate.py +++ b/src/sid/simulate.py @@ -644,8 +644,7 @@ def _simulate( time_series = _prepare_time_series(path, columns_to_keep, states) results["time_series"] = time_series if return_last_states: - last_states = _prepare_last_states(path, states) - results["last_states"] = last_states + results["last_states"] = states if period_outputs: results["period_outputs"] = evaluated_period_outputs @@ -1067,33 +1066,6 @@ def _prepare_time_series(output_directory, columns_to_keep, last_states): return time_series -def _prepare_last_states(output_directory, last_states): - """Prepare the last_states for the simulation results. - - Args: - output_directory (pathlib.Path): Path to output directory. - columns_to_keep (list): List of variables which should be kept. - last_states (pandas.DataFrame): The states from the last period. - - Returns: - dask.dataframe: The DataFrame with the last states - - - """ - categoricals = { - column: last_states[column].cat.categories.shape[0] - for column in last_states.select_dtypes("category").columns - } - - last_states.to_parquet(output_directory / "last_states" / "last_states.parquet") - last_states = dd.read_parquet( - output_directory / "last_states" / "last_states.parquet", - categories=categoricals, - engine="fastparquet", - ) - return last_states - - def _process_saved_columns( saved_columns: Union[None, Dict[str, Union[bool, str, List[str]]]], initial_state_columns: List[str], diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 040e0893..79e3299b 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -65,7 +65,7 @@ def test_plot_infection_rates_by_contact_models(params, initial_states, tmp_path result = simulate(params) time_series = result["time_series"].compute() - last_states = result["last_states"].compute() + last_states = result["last_states"] for df in [time_series, last_states]: assert isinstance(df, pd.DataFrame) diff --git a/tests/test_rapid_tests.py b/tests/test_rapid_tests.py index 79a44d62..5ed26830 100644 --- a/tests/test_rapid_tests.py +++ b/tests/test_rapid_tests.py @@ -35,7 +35,7 @@ def test_simulate_rapid_tests(params, initial_states, tmp_path): result = simulate(params) time_series = result["time_series"].compute() - last_states = result["last_states"].compute() + last_states = result["last_states"] for df in [time_series, last_states]: assert isinstance(df, pd.DataFrame) @@ -77,7 +77,7 @@ def test_simulate_rapid_tests_with_reaction_models(params, initial_states, tmp_p result = simulate(params) time_series = result["time_series"].compute() - last_states = result["last_states"].compute() + last_states = result["last_states"] for df in [time_series, last_states]: assert isinstance(df, pd.DataFrame) diff --git a/tests/test_seasonality.py b/tests/test_seasonality.py index cd052250..a8b9e52c 100644 --- a/tests/test_seasonality.py +++ b/tests/test_seasonality.py @@ -26,7 +26,7 @@ def test_simulate_a_simple_model(params, initial_states, tmp_path): result = simulate(params) time_series = result["time_series"].compute() - last_states = result["last_states"].compute() + last_states = result["last_states"] for df in [time_series, last_states]: assert isinstance(df, pd.DataFrame) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index fcf4bbd3..e7bfffa5 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -31,7 +31,7 @@ def test_simulate_a_simple_model(params, initial_states, tmp_path): result = simulate(params) time_series = result["time_series"].compute() - last_states = result["last_states"].compute() + last_states = result["last_states"] for df in [time_series, last_states]: assert isinstance(df, pd.DataFrame) @@ -55,7 +55,7 @@ def test_resume_a_simulation(params, initial_states, tmp_path): result = simulate(params) time_series = result["time_series"].compute() - last_states = result["last_states"].compute() + last_states = result["last_states"] for df in [time_series, last_states]: assert isinstance(df, pd.DataFrame) @@ -77,7 +77,7 @@ def test_resume_a_simulation(params, initial_states, tmp_path): resumed_result = resumed_simulate(params) resumed_time_series = resumed_result["time_series"].compute() - resumed_last_states = resumed_result["last_states"].compute() + resumed_last_states = resumed_result["last_states"] for df in [resumed_time_series, resumed_last_states]: assert isinstance(df, pd.DataFrame) @@ -110,7 +110,7 @@ def test_simulate_a_simple_model_without_assort_by(params, initial_states, tmp_p result = simulate(params) time_series = result["time_series"].compute() - last_states = result["last_states"].compute() + last_states = result["last_states"] for df in [time_series, last_states]: assert isinstance(df, pd.DataFrame) @@ -352,7 +352,7 @@ def test_skipping_factorization_of_assort_by_variable_works( result = simulate(params) time_series = result["time_series"].compute() - last_states = result["last_states"].compute() + last_states = result["last_states"] assert "group_codes_households" not in time_series assert "group_codes_households" not in last_states diff --git a/tests/test_time.py b/tests/test_time.py index 5058f730..9880d0db 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -100,7 +100,7 @@ def test_replace_date_with_period_in_simulation(params, initial_states, tmp_path result = simulate(params) time_series = result["time_series"].compute() - last_states = result["last_states"].compute() + last_states = result["last_states"] for df in [time_series, last_states]: assert isinstance(df, pd.DataFrame)