Skip to content

Commit

Permalink
Small fixes for estimation of christmas models (#62)
Browse files Browse the repository at this point in the history
Co-authored-by: Klara Roehrl <[email protected]>
Co-authored-by: Tobias Raabe <[email protected]>
  • Loading branch information
3 people authored Jan 4, 2021
1 parent 6e5a9d3 commit 440669e
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 125 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ black = 1
pyupgrade = 1

[tool.nbqa.addopts]
isort = ["--treat-comment-as-code", "# %%"]
isort = ["--treat-comment-as-code", "# %%", "--profile=black"]
pyupgrade = ["--py36-plus"]
2 changes: 1 addition & 1 deletion src/sid/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
DTYPE_INDEX = np.uint32
DTYPE_INFECTED = np.bool_
DTYPE_INFECTION_COUNTER = np.uint16
DTYPE_N_CONTACTS = np.uint32
DTYPE_N_CONTACTS = np.uint16
DTYPE_SID_PERIOD = np.int16

INDEX_NAMES = ["category", "subcategory", "name"]
Expand Down
168 changes: 60 additions & 108 deletions src/sid/contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def calculate_infections_by_contacts(
[params.loc[("infection_prob", cm, cm), "value"] for cm in indexers]
)

reduced_contacts = _reduce_contacts_with_infection_probs(
contacts = _reduce_contacts_with_infection_probs(
contacts, is_recurrent, infect_probs, next(seed)
)

Expand All @@ -123,18 +123,14 @@ def calculate_infections_by_contacts(
for ind in indexers.values():
indexers_list.append(ind)

np.random.seed(next(seed))

loop_order = _get_shuffled_loop_entries(len(states), len(indexers), next(seed))

(
infected,
infection_counter,
immune,
missed,
was_infected_by,
) = _calculate_infections_by_contacts_numba(
reduced_contacts,
contacts,
infectious,
immune,
group_codes,
Expand All @@ -143,7 +139,6 @@ def calculate_infections_by_contacts(
infect_probs,
next(seed),
is_recurrent,
loop_order,
)

infected = pd.Series(infected, index=states.index)
Expand All @@ -165,47 +160,6 @@ def calculate_infections_by_contacts(
return infected, n_has_additionally_infected, missed_contacts, was_infected_by


@nb.njit
def _get_shuffled_loop_entries(n_states, n_contact_models, seed, n_model_orders=1000):
"""Create an array of loop entries.
We save the loop entries of the following loop in shuffled order:
.. code-block:: python
for i in range(n_states):
for j in range(n_contact_models):
pass
The loop entries are stored in an array with ``n_states * n_contact_model`` rows
and two columns. The first column contains the i, the second the j elements.
Achieving complete randomness would require us to first store all loop entries
in an array and then shuffle it. However, this would be very slow. Instead
we loop over states in random order and cycle through previously.
"""
np.random.seed(seed)
res = np.empty((n_states * n_contact_models, 2), dtype=np.int64)
shuffled_state_indices = np.random.choice(n_states, replace=False, size=n_states)

# create random permutations of the model orders
model_orders = np.zeros((n_model_orders, n_contact_models))
for m in range(n_model_orders):
model_orders[m] = np.random.choice(
n_contact_models, replace=False, size=n_contact_models
)

counter = 0

for i in shuffled_state_indices:
for j in model_orders[i % n_model_orders]:
res[counter, 0] = i
res[counter, 1] = j
counter += 1
return res


@nb.njit
def _reduce_contacts_with_infection_probs(contacts, is_recurrent, probs, seed):
"""Reduce the number of contacts stochastically.
Expand All @@ -227,7 +181,7 @@ def _reduce_contacts_with_infection_probs(contacts, is_recurrent, probs, seed):
"""

reduced_contacts = contacts.copy()
contacts = contacts.copy()
np.random.seed(seed)
n_obs, n_contacts = contacts.shape
for i in range(n_obs):
Expand All @@ -237,8 +191,8 @@ def _reduce_contacts_with_infection_probs(contacts, is_recurrent, probs, seed):
for _ in range(contacts[i, j]):
if boolean_choice(probs[j]):
success += 1
reduced_contacts[i, j] = success
return reduced_contacts
contacts[i, j] = success
return contacts


@nb.njit
Expand All @@ -252,7 +206,6 @@ def _calculate_infections_by_contacts_numba(
infection_probs,
seed,
is_recurrent,
loop_order,
):
"""Match people, draw if they get infected and record who infected whom.
Expand All @@ -275,8 +228,6 @@ def _calculate_infections_by_contacts_numba(
probability of infection for each contact model.
seed (int): Seed value to control randomness.
is_recurrent (numpy.ndarray): Boolean array of length n_contact_models.
loop_orrder (numpy.ndarray): 2d numpy array with two columns. The first column
indicates an individual. The second indicates a contact model.
Returns:
(tuple) Tuple containing
Expand All @@ -297,66 +248,67 @@ def _calculate_infections_by_contacts_numba(
groups_list = [np.arange(len(gp)) for gp in group_cdfs]
was_infected_by = np.full(len(contacts), -1, dtype=np.int16)

n_obs, n_contact_models = contacts.shape
# Loop over all individual-contact_model combinations
for k in range(len(loop_order)):
i, cm = loop_order[k]
if is_recurrent[cm]:
# We only check if i infects someone else from his group. Whether
# he is infected by some j is only checked, when the main loop arrives at j.
# This allows us to skip completely if i is not infectious or has no
# contacts under contact model cm.
group_i = group_codes[i, cm]
if group_i >= 0 and infectious[i] and contacts[i, cm] > 0:
others = indexers_list[cm][group_i]
# extract infection probability into a variable for faster access
prob = infection_probs[cm]
for j in others:
# the case i == j is skipped by the next if condition because it
# never happens that i is infectious but not immune
if not immune[j] and contacts[j, cm] > 0:
is_infection = boolean_choice(prob)
if is_infection:
for i in range(n_obs):
for cm in range(n_contact_models):
if is_recurrent[cm]:
# We only check if i infects someone else from his group. Whether
# he is infected by some j is only checked, when the main loop arrives
# at j. This allows us to skip completely if i is not infectious or has
# no contacts under contact model cm.
group_i = group_codes[i, cm]
if group_i >= 0 and infectious[i] and contacts[i, cm] > 0:
others = indexers_list[cm][group_i]
# extract infection probability into a variable for faster access
prob = infection_probs[cm]
for j in others:
# the case i == j is skipped by the next if condition because it
# never happens that i is infectious but not immune
if not immune[j] and contacts[j, cm] > 0:
is_infection = boolean_choice(prob)
if is_infection:
infection_counter[i] += 1
infected[j] = 1
immune[j] = True
was_infected_by[j] = cm

else:
# get the probabilities for meeting another group which depend on the
# individual's group.
group_i = group_codes[i, cm]
group_i_cdf = group_cdfs[cm][group_i]

# Loop over each contact the individual has, sample the contact's group
# and compute the sum of possible contacts in this group.
n_contacts = contacts[i, cm]
for _ in range(n_contacts):
contact_takes_place = True
group_j = choose_other_group(groups_list[cm], cdf=group_i_cdf)
choice_indices = indexers_list[cm][group_j]
contacts_j = contacts[choice_indices, cm]

j = choose_other_individual(choice_indices, weights=contacts_j)

if j < 0 or j == i:
contact_takes_place = False

# If a contact takes place, find out if one individual got infected.
if contact_takes_place:
contacts[i, cm] -= 1
contacts[j, cm] -= 1

if infectious[i] and not immune[j]:
infection_counter[i] += 1
infected[j] = 1
immune[j] = True
was_infected_by[j] = cm

else:
# get the probabilities for meeting another group which depend on the
# individual's group.
group_i = group_codes[i, cm]
group_i_cdf = group_cdfs[cm][group_i]

# Loop over each contact the individual has, sample the contact's group and
# compute the sum of possible contacts in this group.
n_contacts = contacts[i, cm]
for _ in range(n_contacts):
contact_takes_place = True
group_j = choose_other_group(groups_list[cm], cdf=group_i_cdf)
choice_indices = indexers_list[cm][group_j]
contacts_j = contacts[choice_indices, cm]

j = choose_other_individual(choice_indices, weights=contacts_j)

if j < 0 or j == i:
contact_takes_place = False

# If a contact takes place, find out if one individual got infected.
if contact_takes_place:
contacts[i, cm] -= 1
contacts[j, cm] -= 1

if infectious[i] and not immune[j]:
infection_counter[i] += 1
infected[j] = 1
immune[j] = True
was_infected_by[j] = cm

elif infectious[j] and not immune[i]:
infection_counter[j] += 1
infected[i] = 1
immune[i] = True
was_infected_by[i] = cm
elif infectious[j] and not immune[i]:
infection_counter[j] += 1
infected[i] = 1
immune[i] = True
was_infected_by[i] = cm

missed = contacts

Expand Down
8 changes: 7 additions & 1 deletion src/sid/covid_epi_params.csv
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ cd_needs_icu_false,all,1,0.22
cd_needs_icu_false,all,15,0.30
cd_needs_icu_false,all,25,0.30
cd_needs_icu_false,all,45,0.18
cd_received_test_result_true,all,2,1
cd_received_test_result_true,all,2,0.1
cd_received_test_result_true,all,3,0.1
cd_received_test_result_true,all,4,0.1
cd_received_test_result_true,all,5,0.1
cd_received_test_result_true,all,6,0.1
cd_received_test_result_true,all,7,0.2
cd_received_test_result_true,all,8,0.3
cd_knows_immune_false,all,-1,1
cd_knows_infectious_false,all,-1,1
cd_ever_infected,all,-1,1
4 changes: 2 additions & 2 deletions src/sid/msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def _msm(
empirical_moments = copy.deepcopy(empirical_moments)

df = simulate(params)
df = df.compute()

simulated_moments = {name: func(df) for name, func in calc_moments.items()}

Expand All @@ -141,7 +140,8 @@ def _msm(
flat_empirical_moments = _flatten_index(empirical_moments)
flat_simulated_moments = _flatten_index(simulated_moments)

moment_errors = flat_empirical_moments - flat_simulated_moments
# Order is important to manfred.
moment_errors = flat_simulated_moments - flat_empirical_moments

# Return moment errors as indexed DataFrame or calculate weighted square product of
# moment errors depending on return_scalar.
Expand Down
5 changes: 4 additions & 1 deletion src/sid/parse_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def parse_share_known_cases(share_known_cases, duration, burn_in_periods):
)
# Extend series with burn-in periods and if shares for burn-in periods do not
# exist, backfill NaNs.
share_known_cases = share_known_cases.reindex(extended_index).backfill()
# .backfill() is only available in pandas which is not supported by estimagic.
share_known_cases = share_known_cases.reindex(extended_index).fillna(
method="bfill"
)

elif share_known_cases is None:
share_known_cases = pd.Series(index=extended_index, data=0)
Expand Down
1 change: 1 addition & 0 deletions src/sid/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def factorize_assortative_variables(states, assort_by, is_recurrent):
if is_recurrent:
assort_by_series = states[assort_by[0]].astype(int).replace({-1: pd.NA})
group_codes, group_codes_values = pd.factorize(assort_by_series)
group_codes = group_codes.astype(DTYPE_GROUP_CODE)
elif assort_by:
assort_by_series = [states[col].to_numpy() for col in assort_by]
group_codes, group_codes_values = pd.factorize(
Expand Down
11 changes: 0 additions & 11 deletions tests/test_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def test_calculate_infections_numba_with_single_group(num_regression, seed):
indexer,
infection_prob,
is_recurrent,
loop_order,
) = _sample_data_for_calculate_infections_numba(n_individuals=100, seed=seed)

(
Expand All @@ -70,7 +69,6 @@ def test_calculate_infections_numba_with_single_group(num_regression, seed):
infection_prob,
seed,
is_recurrent,
loop_order,
)

num_regression.check(
Expand Down Expand Up @@ -143,8 +141,6 @@ def _sample_data_for_calculate_infections_numba(

is_recurrent = np.array([False])

loop_order = np.array(list(itertools.product(range(n_individuals), range(1))))

return (
contacts.reshape(-1, 1),
infectious,
Expand All @@ -154,7 +150,6 @@ def _sample_data_for_calculate_infections_numba(
indexers_list,
infection_prob,
is_recurrent,
loop_order,
)


Expand Down Expand Up @@ -332,12 +327,6 @@ def fake_choose_j(a, weights):

m.setattr("sid.contacts.choose_other_individual", fake_choose_j)

@njit
def fix_loop_order(x, replace, size):
return NumbaList(range(x))

m.setattr("sid.contacts.np.random.choice", fix_loop_order)


def test_calculate_infections_only_non_recurrent(
setup_households_w_one_infection, monkeypatch
Expand Down

0 comments on commit 440669e

Please sign in to comment.