From c6a185d822ce04d5325f68e14ce6f1dca3591cd3 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 13 May 2021 09:40:16 +0200 Subject: [PATCH] Fix bug in inferring which contact models caused which infection. (#125) --- .conda/meta.yaml | 2 +- docs/rtd_environment.yml | 2 +- docs/source/changes.rst | 7 +++++++ environment.yml | 2 +- setup.cfg | 2 +- src/sid/contacts.py | 18 +++++++++--------- tests/test_contacts.py | 38 ++++++++++++++++++++++++++++++++++++++ tests/test_msm.py | 1 + tests/test_plotting.py | 1 + tests/test_seasonality.py | 1 + 10 files changed, 61 insertions(+), 13 deletions(-) diff --git a/.conda/meta.yaml b/.conda/meta.yaml index f9f761bb..d3258aca 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -26,7 +26,7 @@ requirements: - bokeh - dask - - fastparquet + - fastparquet !=0.6.1 - holoviews - numba >=0.48 - pandas >=1 diff --git a/docs/rtd_environment.yml b/docs/rtd_environment.yml index abb34c7b..5f414067 100644 --- a/docs/rtd_environment.yml +++ b/docs/rtd_environment.yml @@ -21,7 +21,7 @@ dependencies: # Package dependencies. - bokeh - dask - - fastparquet + - fastparquet !=0.6.1 - holoviews - numba >=0.48 - pandas >=1 diff --git a/docs/source/changes.rst b/docs/source/changes.rst index 907dba34..4c51d373 100644 --- a/docs/source/changes.rst +++ b/docs/source/changes.rst @@ -7,6 +7,13 @@ all releases are available on `Anaconda.org `_. +0.0.8 - 2021-05-13 +------------------ + +- :gh:`124` fixes a bug in the function which reports the channel of infections by + contacts. + + 0.0.7 - 2021-05-12 ------------------ diff --git a/environment.yml b/environment.yml index 198e9a37..f79d86ef 100644 --- a/environment.yml +++ b/environment.yml @@ -19,7 +19,7 @@ dependencies: # Package dependencies - bokeh - dask - - fastparquet + - fastparquet !=0.6.1 - holoviews - numba >=0.48 - pandas >=1 diff --git a/setup.cfg b/setup.cfg index d1c88089..85391e77 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ packages = find: install_requires = bokeh dask - fastparquet + fastparquet!=0.6.1 holoviews matplotlib numba>=0.48 diff --git a/src/sid/contacts.py b/src/sid/contacts.py index 71444758..03fcb847 100644 --- a/src/sid/contacts.py +++ b/src/sid/contacts.py @@ -714,12 +714,11 @@ def _consolidate_reason_of_infection( ) was_infected_by = np.full(n_individuals, -1) contact_model_to_code = {c: i for i, c in enumerate(contact_models)} + recurrent_models, random_models = separate_contact_model_names(contact_models) if was_infected_by_random is not None: random_pos_to_code = { - i: contact_model_to_code[c] - for i, c in enumerate(contact_models) - if not contact_models[c]["is_recurrent"] + i: contact_model_to_code[c] for i, c in enumerate(random_models) } mask = was_infected_by_random >= 0 @@ -729,9 +728,7 @@ def _consolidate_reason_of_infection( if was_infected_by_recurrent is not None: recurrent_pos_to_code = { - i: contact_model_to_code[c] - for i, c in enumerate(contact_models) - if contact_models[c]["is_recurrent"] + i: contact_model_to_code[c] for i, c in enumerate(recurrent_models) } mask = was_infected_by_recurrent >= 0 @@ -749,7 +746,10 @@ def _consolidate_reason_of_infection( def _numpy_replace(x: np.ndarray, replace_to: Dict[Any, Any]): """Replace values in a NumPy array with a dictionary.""" - sort_idx = np.argsort(list(replace_to)) - idx = np.searchsorted(list(replace_to), x, sorter=sort_idx) - out = np.asarray(list(replace_to.values()))[sort_idx][idx] + uniques = np.unique(x) + full_replace_to = {i: replace_to.get(i, i) for i in uniques} + + sort_idx = np.argsort(list(full_replace_to)) + idx = np.searchsorted(list(full_replace_to), x, sorter=sort_idx) + out = np.asarray(list(full_replace_to.values()))[sort_idx][idx] return out diff --git a/tests/test_contacts.py b/tests/test_contacts.py index ed04bae1..290aa2ce 100644 --- a/tests/test_contacts.py +++ b/tests/test_contacts.py @@ -4,6 +4,8 @@ import numpy as np import pandas as pd import pytest +from sid.contacts import _consolidate_reason_of_infection +from sid.contacts import _numpy_replace from sid.contacts import calculate_infections_by_contacts from sid.contacts import create_group_indexer @@ -234,3 +236,39 @@ def test_calculate_infections_only_non_recurrent(households_w_one_infected): exp_infection_counter ) assert not np.any(calc_missed_contacts) + + +@pytest.mark.unit +def test_consolidate_reason_of_infection(): + was_infected_by_recurrent = np.array([0, 1, 1, -1, -1, -1, 0, -1]) + was_infected_by_random = np.array([-1, -1, -1, 0, 0, 1, 0, -1]) + + contact_models = { + "a": {"is_recurrent": True}, + "b": {"is_recurrent": True}, + "c": {"is_recurrent": False}, + "d": {"is_recurrent": False}, + } + + result = _consolidate_reason_of_infection( + was_infected_by_recurrent, was_infected_by_random, contact_models + ) + + expected = pd.Series( + pd.Categorical( + ["a", "b", "b", "c", "c", "d", "a", "not_infected_by_contact"], + categories=["not_infected_by_contact", "a", "b", "c", "d"], + ) + ) + + assert result.equals(expected) + + +@pytest.mark.unit +def test_numpy_replace(): + x = np.arange(6) + replace_to = {4: 6, 5: 7} + + result = _numpy_replace(x, replace_to) + + assert (result == np.array([0, 1, 2, 3, 6, 7])).all() diff --git a/tests/test_msm.py b/tests/test_msm.py index 6f62390c..15dea6b9 100644 --- a/tests/test_msm.py +++ b/tests/test_msm.py @@ -60,6 +60,7 @@ def test_get_diag_weighting_matrix(empirical_moments, weights, expected): assert np.all(result == expected) +@pytest.mark.integration def test_get_diag_weighting_matrix_with_scalar_weights(): emp_moms = {0: pd.Series([1, 2]), 1: pd.Series([2, 3, 4])} weights = {0: 0.3, 1: 0.7} diff --git a/tests/test_plotting.py b/tests/test_plotting.py index b5cf303f..040e0893 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -46,6 +46,7 @@ } +@pytest.mark.unit def test_plot_policy_gantt_chart(): plot_policy_gantt_chart(POLICIES_FOR_GANTT_CHART, effects=True) diff --git a/tests/test_seasonality.py b/tests/test_seasonality.py index e501fdd8..f2ae8b03 100644 --- a/tests/test_seasonality.py +++ b/tests/test_seasonality.py @@ -94,6 +94,7 @@ def test_prepare_seasonality_factor( assert result.equals(expected) +@pytest.mark.unit def test_prepare_seasonality_factor_with_dataframe_return(): def _model(params, dates, seed): df = pd.DataFrame(index=dates)