Skip to content

Commit

Permalink
Do not convert last states to dask DataFrame. (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
roecla authored Jun 5, 2021
1 parent 035cb0e commit 39f7731
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 39 deletions.
2 changes: 2 additions & 0 deletions docs/source/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ all releases are available on `Anaconda.org
- :gh:`131` moves the parsing of the virus strain infectiousness factor to the
simulation.
- :gh:`132` sets initialized countdowns to -9,999.
- :gh:`134` returns the last states as a ``pandas.DataFrame`` instead of a
``dask.dataframe``.


0.0.9 - 2021-05-28
Expand Down
30 changes: 1 addition & 29 deletions src/sid/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,7 @@ def _simulate(
time_series = _prepare_time_series(path, columns_to_keep, states)
results["time_series"] = time_series
if return_last_states:
last_states = _prepare_last_states(path, states)
results["last_states"] = last_states
results["last_states"] = states
if period_outputs:
results["period_outputs"] = evaluated_period_outputs

Expand Down Expand Up @@ -1067,33 +1066,6 @@ def _prepare_time_series(output_directory, columns_to_keep, last_states):
return time_series


def _prepare_last_states(output_directory, last_states):
    """Write the last states to a parquet file and read them back with dask.

    Before writing, the number of categories of every categorical column is
    collected; this mapping is passed to the parquet reader as ``categories``.

    Args:
        output_directory (pathlib.Path): Path to output directory. The states are
            stored in ``last_states/last_states.parquet`` below this directory.
        last_states (pandas.DataFrame): The states from the last period.

    Returns:
        dask.dataframe: The DataFrame with the last states.

    """
    # Map each categorical column to its number of categories.
    n_categories_by_column = {}
    for name in last_states.select_dtypes("category").columns:
        n_categories_by_column[name] = last_states[name].cat.categories.shape[0]

    # The same file is written and then re-opened, so build the path once.
    parquet_path = output_directory / "last_states" / "last_states.parquet"
    last_states.to_parquet(parquet_path)

    return dd.read_parquet(
        parquet_path,
        categories=n_categories_by_column,
        engine="fastparquet",
    )


def _process_saved_columns(
saved_columns: Union[None, Dict[str, Union[bool, str, List[str]]]],
initial_state_columns: List[str],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_plot_infection_rates_by_contact_models(params, initial_states, tmp_path
result = simulate(params)

time_series = result["time_series"].compute()
last_states = result["last_states"].compute()
last_states = result["last_states"]

for df in [time_series, last_states]:
assert isinstance(df, pd.DataFrame)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_rapid_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_simulate_rapid_tests(params, initial_states, tmp_path):
result = simulate(params)

time_series = result["time_series"].compute()
last_states = result["last_states"].compute()
last_states = result["last_states"]

for df in [time_series, last_states]:
assert isinstance(df, pd.DataFrame)
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_simulate_rapid_tests_with_reaction_models(params, initial_states, tmp_p
result = simulate(params)

time_series = result["time_series"].compute()
last_states = result["last_states"].compute()
last_states = result["last_states"]

for df in [time_series, last_states]:
assert isinstance(df, pd.DataFrame)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_seasonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_simulate_a_simple_model(params, initial_states, tmp_path):
result = simulate(params)

time_series = result["time_series"].compute()
last_states = result["last_states"].compute()
last_states = result["last_states"]

for df in [time_series, last_states]:
assert isinstance(df, pd.DataFrame)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_simulate_a_simple_model(params, initial_states, tmp_path):
result = simulate(params)

time_series = result["time_series"].compute()
last_states = result["last_states"].compute()
last_states = result["last_states"]

for df in [time_series, last_states]:
assert isinstance(df, pd.DataFrame)
Expand All @@ -55,7 +55,7 @@ def test_resume_a_simulation(params, initial_states, tmp_path):
result = simulate(params)

time_series = result["time_series"].compute()
last_states = result["last_states"].compute()
last_states = result["last_states"]

for df in [time_series, last_states]:
assert isinstance(df, pd.DataFrame)
Expand All @@ -77,7 +77,7 @@ def test_resume_a_simulation(params, initial_states, tmp_path):
resumed_result = resumed_simulate(params)

resumed_time_series = resumed_result["time_series"].compute()
resumed_last_states = resumed_result["last_states"].compute()
resumed_last_states = resumed_result["last_states"]

for df in [resumed_time_series, resumed_last_states]:
assert isinstance(df, pd.DataFrame)
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_simulate_a_simple_model_without_assort_by(params, initial_states, tmp_p
result = simulate(params)

time_series = result["time_series"].compute()
last_states = result["last_states"].compute()
last_states = result["last_states"]

for df in [time_series, last_states]:
assert isinstance(df, pd.DataFrame)
Expand Down Expand Up @@ -352,7 +352,7 @@ def test_skipping_factorization_of_assort_by_variable_works(
result = simulate(params)

time_series = result["time_series"].compute()
last_states = result["last_states"].compute()
last_states = result["last_states"]

assert "group_codes_households" not in time_series
assert "group_codes_households" not in last_states
Expand Down
2 changes: 1 addition & 1 deletion tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_replace_date_with_period_in_simulation(params, initial_states, tmp_path
result = simulate(params)

time_series = result["time_series"].compute()
last_states = result["last_states"].compute()
last_states = result["last_states"]

for df in [time_series, last_states]:
assert isinstance(df, pd.DataFrame)
Expand Down

0 comments on commit 39f7731

Please sign in to comment.