Skip to content

Commit

Permalink
Fix bug in inferring which contact models caused which infection. (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe authored May 13, 2021
1 parent 0681911 commit c6a185d
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ requirements:

- bokeh
- dask
- fastparquet
- fastparquet !=0.6.1
- holoviews
- numba >=0.48
- pandas >=1
Expand Down
2 changes: 1 addition & 1 deletion docs/rtd_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies:
# Package dependencies.
- bokeh
- dask
- fastparquet
- fastparquet !=0.6.1
- holoviews
- numba >=0.48
- pandas >=1
Expand Down
7 changes: 7 additions & 0 deletions docs/source/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ all releases are available on `Anaconda.org
<https://anaconda.org/covid-19-impact-lab/sid>`_.


0.0.8 - 2021-05-13
------------------

- :gh:`124` fixes a bug in the function which reports the channel of infections by
contacts.


0.0.7 - 2021-05-12
------------------

Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies:
# Package dependencies
- bokeh
- dask
- fastparquet
- fastparquet !=0.6.1
- holoviews
- numba >=0.48
- pandas >=1
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ packages = find:
install_requires =
bokeh
dask
fastparquet
fastparquet!=0.6.1
holoviews
matplotlib
numba>=0.48
Expand Down
18 changes: 9 additions & 9 deletions src/sid/contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,12 +714,11 @@ def _consolidate_reason_of_infection(
)
was_infected_by = np.full(n_individuals, -1)
contact_model_to_code = {c: i for i, c in enumerate(contact_models)}
recurrent_models, random_models = separate_contact_model_names(contact_models)

if was_infected_by_random is not None:
random_pos_to_code = {
i: contact_model_to_code[c]
for i, c in enumerate(contact_models)
if not contact_models[c]["is_recurrent"]
i: contact_model_to_code[c] for i, c in enumerate(random_models)
}

mask = was_infected_by_random >= 0
Expand All @@ -729,9 +728,7 @@ def _consolidate_reason_of_infection(

if was_infected_by_recurrent is not None:
recurrent_pos_to_code = {
i: contact_model_to_code[c]
for i, c in enumerate(contact_models)
if contact_models[c]["is_recurrent"]
i: contact_model_to_code[c] for i, c in enumerate(recurrent_models)
}

mask = was_infected_by_recurrent >= 0
Expand All @@ -749,7 +746,10 @@ def _consolidate_reason_of_infection(

def _numpy_replace(x: np.ndarray, replace_to: Dict[Any, Any]):
"""Replace values in a NumPy array with a dictionary."""
sort_idx = np.argsort(list(replace_to))
idx = np.searchsorted(list(replace_to), x, sorter=sort_idx)
out = np.asarray(list(replace_to.values()))[sort_idx][idx]
uniques = np.unique(x)
full_replace_to = {i: replace_to.get(i, i) for i in uniques}

sort_idx = np.argsort(list(full_replace_to))
idx = np.searchsorted(list(full_replace_to), x, sorter=sort_idx)
out = np.asarray(list(full_replace_to.values()))[sort_idx][idx]
return out
38 changes: 38 additions & 0 deletions tests/test_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import pandas as pd
import pytest
from sid.contacts import _consolidate_reason_of_infection
from sid.contacts import _numpy_replace
from sid.contacts import calculate_infections_by_contacts
from sid.contacts import create_group_indexer

Expand Down Expand Up @@ -234,3 +236,39 @@ def test_calculate_infections_only_non_recurrent(households_w_one_infected):
exp_infection_counter
)
assert not np.any(calc_missed_contacts)


@pytest.mark.unit
def test_consolidate_reason_of_infection():
was_infected_by_recurrent = np.array([0, 1, 1, -1, -1, -1, 0, -1])
was_infected_by_random = np.array([-1, -1, -1, 0, 0, 1, 0, -1])

contact_models = {
"a": {"is_recurrent": True},
"b": {"is_recurrent": True},
"c": {"is_recurrent": False},
"d": {"is_recurrent": False},
}

result = _consolidate_reason_of_infection(
was_infected_by_recurrent, was_infected_by_random, contact_models
)

expected = pd.Series(
pd.Categorical(
["a", "b", "b", "c", "c", "d", "a", "not_infected_by_contact"],
categories=["not_infected_by_contact", "a", "b", "c", "d"],
)
)

assert result.equals(expected)


@pytest.mark.unit
def test_numpy_replace():
x = np.arange(6)
replace_to = {4: 6, 5: 7}

result = _numpy_replace(x, replace_to)

assert (result == np.array([0, 1, 2, 3, 6, 7])).all()
1 change: 1 addition & 0 deletions tests/test_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_get_diag_weighting_matrix(empirical_moments, weights, expected):
assert np.all(result == expected)


@pytest.mark.integration
def test_get_diag_weighting_matrix_with_scalar_weights():
emp_moms = {0: pd.Series([1, 2]), 1: pd.Series([2, 3, 4])}
weights = {0: 0.3, 1: 0.7}
Expand Down
1 change: 1 addition & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
}


@pytest.mark.unit
def test_plot_policy_gantt_chart():
plot_policy_gantt_chart(POLICIES_FOR_GANTT_CHART, effects=True)

Expand Down
1 change: 1 addition & 0 deletions tests/test_seasonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_prepare_seasonality_factor(
assert result.equals(expected)


@pytest.mark.unit
def test_prepare_seasonality_factor_with_dataframe_return():
def _model(params, dates, seed):
df = pd.DataFrame(index=dates)
Expand Down

0 comments on commit c6a185d

Please sign in to comment.