Skip to content

Commit

Permalink
Enhance the plot for the channels of infections. (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe authored May 16, 2021
1 parent c6a185d commit a52606e
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 23 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
rev: v4.0.1
hooks:
- id: check-added-large-files
args: ['--maxkb=500']
Expand Down Expand Up @@ -32,7 +32,7 @@ repos:
- id: text-unicode-replacement-char
exclude_types: [jupyter]
- repo: https://github.com/asottile/pyupgrade
rev: v2.15.0
rev: v2.16.0
hooks:
- id: pyupgrade
args: [--py36-plus]
Expand Down Expand Up @@ -76,7 +76,7 @@ repos:
Pygments,
]
- repo: https://github.com/nbQA-dev/nbQA
rev: 0.8.0
rev: 0.8.1
hooks:
- id: nbqa-black
- id: nbqa-pyupgrade
Expand All @@ -86,7 +86,7 @@ repos:
hooks:
- id: doc8
- repo: https://github.com/econchick/interrogate
rev: 1.3.2
rev: 1.4.0
hooks:
- id: interrogate
args: [-v, --fail-under=78, src]
Expand Down
5 changes: 4 additions & 1 deletion docs/source/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ all releases are available on `Anaconda.org
0.0.8 - 2021-05-13
------------------

- :gh:`124` fixes a bug in the function which reports the channel of infections by
- :gh:`125` fixes a bug in the function which reports the channel of infections by
contacts.
- :gh:`126` enhances the plot for the channel of infections. The displayed numbers can
be shares among all infected, among all individuals or seven days incidences per
100,000 people.


0.0.7 - 2021-05-12
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9a14485a-befd-4461-937f-26e411a06689",
"id": "43203479-d674-4af7-a5a1-2cb9973a7007",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -125,18 +125,17 @@
" seed=144,\n",
")\n",
"\n",
"result = simulate(params)\n",
"\n",
"heatmap = plot_infection_rates_by_contact_models(result[\"time_series\"])"
"result = simulate(params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef9a38f9-9db9-4ffe-bbc0-a25c1455984e",
"id": "dcd2748e-8a06-4c78-b8e0-b95c32352c80",
"metadata": {},
"outputs": [],
"source": [
"heatmap = plot_infection_rates_by_contact_models(result[\"time_series\"], unit=\"share\")\n",
"heatmap"
]
}
Expand Down
178 changes: 165 additions & 13 deletions src/sid/plotting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import itertools
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

import dask.dataframe as dd
import holoviews as hv
import numpy as np
import pandas as pd
from bokeh.models import HoverTool
from sid.colors import get_colors
Expand Down Expand Up @@ -171,11 +176,54 @@ def _create_y_ticks_and_labels(df):
"ylabel": "Contact Model",
"invert_yaxis": True,
"colorbar": True,
"cmap": "YlOrBr",
}


def plot_infection_rates_by_contact_models(df_or_time_series, fig_kwargs=None):
"""Plot infection rates by contact models."""
def plot_infection_rates_by_contact_models(
df_or_time_series: Union[pd.DataFrame, dd.core.DataFrame],
show_reported_cases: bool = False,
unit: str = "share",
fig_kwargs: Optional[Dict[str, Any]] = None,
) -> hv.HeatMap:
"""Plot infection rates by contact models.
Parameters
----------
df_or_time_series : Union[pandas.DataFrame, dask.dataframe.core.DataFrame]
The input can be one of the following two.
1. It is a :class:`dask.dataframe.core.DataFrame` which holds the time series
from a simulation.
2. It can be a :class:`pandas.DataFrame` which is created with
:func:`prepare_data_for_infection_rates_by_contact_models`. It allows to
compute the data for various simulations with different seeds and use the
average over all seeds.
show_reported_cases : bool, optional
A boolean to select between reported or real cases of infections. Reported cases
are identified via testing mechanisms.
unit : str
The arguments specifies the unit shown in the figure.
- ``"share"`` means that daily units represent the share of infection caused
by a contact model among all infections on the same day.
- ``"population_share"`` means that daily units represent the share of
infection caused by a contact model among all people on the same day.
- ``"incidence"`` means that the daily units represent incidence levels per
100,000 individuals.
fig_kwargs : Optional[Dict[str, Any]], optional
Additional keyword arguments which are passed to ``heatmap.opts`` to style the
plot. The keyword arguments overwrite or extend the default arguments.
Returns
-------
heatmap : hv.HeatMap
The heatmap object.
"""
fig_kwargs = (
DEFAULT_IR_PER_CM_KWARGS
if fig_kwargs is None
Expand All @@ -185,9 +233,11 @@ def plot_infection_rates_by_contact_models(df_or_time_series, fig_kwargs=None):
if _is_data_prepared_for_heatmap(df_or_time_series):
df = df_or_time_series
else:
df = prepare_data_for_infection_rates_by_contact_models(df_or_time_series)
df = prepare_data_for_infection_rates_by_contact_models(
df_or_time_series, show_reported_cases, unit
)

hv.extension("bokeh")
hv.extension("bokeh", logo=False)

heatmap = hv.HeatMap(df)
plot = heatmap.opts(**fig_kwargs)
Expand All @@ -206,8 +256,38 @@ def _is_data_prepared_for_heatmap(df):
)


def prepare_data_for_infection_rates_by_contact_models(time_series):
"""Prepare the data for the heatmap plot."""
def prepare_data_for_infection_rates_by_contact_models(
time_series: dd.core.DataFrame,
show_reported_cases: bool = False, # noqa: U100
unit: str = "share",
) -> pd.DataFrame:
"""Prepare the data for the heatmap plot.
Parameters
----------
time_series : dask.dataframe.core.DataFrame
The time series of a simulation.
show_reported_cases : bool, optional
A boolean to select between reported or real cases of infections. Reported cases
are identified via testing mechanisms.
unit : str
The arguments specifies the unit shown in the figure.
- ``"share"`` means that daily units represent the share of infection caused
by a contact model among all infections on the same day.
- ``"population_share"`` means that daily units represent the share of
infection caused by a contact model among all people on the same day.
- ``"incidence"`` means that the daily units represent incidence levels per
100,000 individuals.
Returns
-------
time_series : pandas.DataFrame
The time series with the prepared data for the plot.
"""
if isinstance(time_series, pd.DataFrame):
time_series = dd.from_pandas(time_series, npartitions=1)
elif not isinstance(time_series, dd.core.DataFrame):
Expand All @@ -216,20 +296,92 @@ def prepare_data_for_infection_rates_by_contact_models(time_series):
if "channel_infected_by_contact" not in time_series:
raise ValueError(ERROR_MISSING_CHANNEL)

time_series = (
if show_reported_cases:
time_series = _adjust_channel_infected_by_contact_to_new_known_cases(
time_series
)

counts = (
time_series[["date", "channel_infected_by_contact"]]
.groupby(["date", "channel_infected_by_contact"])
.size()
.reset_index()
.rename(columns={0: "n"})
.assign(
)

if unit == "share":
out = counts.query(
"channel_infected_by_contact != 'not_infected_by_contact'"
).assign(
share=lambda x: x["n"]
/ x.groupby("date")["n"].transform("sum", meta=("n", "f8")),
)
.drop(columns="n")
.query("channel_infected_by_contact != 'not_infected_by_contact'")

elif unit == "population_share":
out = counts.assign(
share=lambda x: x["n"]
/ x.groupby("date")["n"].transform("sum", meta=("n", "f8")),
).query("channel_infected_by_contact != 'not_infected_by_contact'")

elif unit == "incidence":
out = counts.query(
"channel_infected_by_contact != 'not_infected_by_contact'"
).assign(share=lambda x: x["n"] * 7 / 100_000)

else:
raise ValueError(
"'unit' should be one of 'share', 'population_share' or 'incidence'"
)

out = out.drop(columns="n").compute()

return out


def _adjust_channel_infected_by_contact_to_new_known_cases(df):
"""Adjust channel of infections by contacts to new known cases.
Channel of infections are recorded on the date an individual got infected which is
not the same date an individual is tested positive with a PCR test.
This function adjusts ``"channel_infected_by_contact"`` such that the infection
channel is shifted to the date when an individual is tested positive.
"""
channel_of_infection_by_contact = _find_channel_of_infection_for_individuals(df)
df = _patch_channel_infected_by_contact(df, channel_of_infection_by_contact)
return df


def _find_channel_of_infection_for_individuals(df):
"""Find the channel of infected by contact for each individual."""
df["channel_infected_by_contact"] = df["channel_infected_by_contact"].cat.as_known()

df["channel_infected_by_contact"] = df[
"channel_infected_by_contact"
].cat.remove_categories(["not_infected_by_contact"])

df = df.dropna(subset=["channel_infected_by_contact"])

df = df["channel_infected_by_contact"].compute()

return df


def _patch_channel_infected_by_contact(df, s):
"""Patch channel of infections by contact to only show channels for known cases."""
df = df.drop(columns="channel_infected_by_contact")

df = df.merge(s.to_frame(name="channel_infected_by_contact"), how="left")

df["channel_infected_by_contact"] = df["channel_infected_by_contact"].mask(
~df["new_known_case"], np.nan
)
if isinstance(time_series, dd.core.DataFrame):
time_series = time_series.compute()

return time_series
df["channel_infected_by_contact"] = (
df["channel_infected_by_contact"]
.cat.add_categories("not_infected_by_contact")
.fillna("not_infected_by_contact")
)

return df

0 comments on commit a52606e

Please sign in to comment.