Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Formally addresses implementing the observed hospitalizations module #39

Merged
merged 35 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d8cca42
Copy doc from wastewater model on Hosp Admin to class
gvegayon Mar 21, 2024
80ed623
Rename metaclasses to metaclass (#36)
Mar 22, 2024
b00687c
Create typos.yaml (#42)
natemcintosh Mar 22, 2024
a9554e0
34 port notebooks under modeldocs to quarto (#35)
Mar 25, 2024
db7e6dc
Extra mathematical description of discrete delay distributions (#44)
SamuelBrand1 Mar 26, 2024
dee97dd
Escaping tau
gvegayon Mar 22, 2024
2bd532c
Adding deterministic obs and process to the equation
gvegayon Mar 27, 2024
3109000
Correcting conflicts
gvegayon Mar 27, 2024
a0e2d4d
Cleaning the quarto documents and working on the getting started diagram
gvegayon Mar 27, 2024
fe16c18
Flexible IHR (now RandomVariable)
gvegayon Mar 27, 2024
ec4fc35
Adding weekday and phosp effect to latent hosp
gvegayon Mar 27, 2024
cb6bf98
Adding back figures
gvegayon Mar 27, 2024
a568921
Adding a test for deterministic/stochastic weekday effect
gvegayon Mar 27, 2024
e6a63b4
Typo
gvegayon Mar 28, 2024
2a7151e
Correcting tests (class name) and improving readme a bit
gvegayon Mar 28, 2024
a8865fd
Adding deterministic module (midway, expected to fail) [skip ci]
gvegayon Mar 28, 2024
d8bc5c1
Refactoring I0 and gen_int (expected to fail) [skip ci]
gvegayon Mar 28, 2024
24a515e
gen_int and I0 now are directly passed to the models
gvegayon Mar 29, 2024
517fce1
In latent hosp, change inf_hosp_int to inform_hosp (clearer name)
gvegayon Apr 1, 2024
813eba8
Adding missing figures (pyrenew demo was not compiling)
gvegayon Apr 1, 2024
ddac288
Renaming inform_hosp
gvegayon Apr 1, 2024
d0851f9
Removing defaults for hosp rate
gvegayon Apr 1, 2024
0acbea5
Changing language (initial infections) + adding section to getting-st…
gvegayon Apr 2, 2024
1b86565
Update model/src/pyrenew/latent/hospitaladmissions.py
gvegayon Apr 2, 2024
33d9ed0
Addressing comments on default priors and varnames
gvegayon Apr 2, 2024
39c54c4
Rt is not default now for basic model
gvegayon Apr 2, 2024
25fa66a
Commas and title
gvegayon Apr 3, 2024
4bc34e7
Update model/src/pyrenew/latent/hospitaladmissions.py
gvegayon Apr 3, 2024
134f786
Renaming hosp reporting variable in latent var
gvegayon Apr 3, 2024
b8a6374
Renaming hosp report
gvegayon Apr 3, 2024
dd37e00
Resolving conflicts
gvegayon Apr 3, 2024
af7de70
Update model/src/pyrenew/latent/hospitaladmissions.py
gvegayon Apr 3, 2024
580a09c
Final renaming of vars in tests
gvegayon Apr 3, 2024
1c2ba0a
Merge branch 'main' into 16-observation-model-for-hospital-signals
dylanhmorris Apr 3, 2024
90627e6
Different vector for hosp_report_prob_dist in tests
gvegayon Apr 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions model/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@ test:
docs: docs/pyrenew_demo.md docs/getting-started.md

docs/pyrenew_demo.md: docs/pyrenew_demo.qmd
quarto render docs/pyrenew_demo.qmd
poetry run quarto render docs/pyrenew_demo.qmd

docs/getting-started.md: docs/getting-started.qmd
quarto render docs/getting-started.qmd
poetry run quarto render docs/getting-started.qmd

.PHONY: install test docs
clean:
rm -rf docs/*_files/
rm -f docs/getting-started.ipynb
rm -f docs/pyrenew_demo.ipynb

.PHONY: install test docs clean
9 changes: 6 additions & 3 deletions model/README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# PyRenew
# PyRenew: A Package for Bayesian Renewal Modeling with JAX and Numpyro.

A package for Bayesian renewal modeling with JAX and Numpyro.
`pyrenew` is a flexible tool for simulating and statistical inference of epidemiological models, emphasizing renewal models. Built on top of the [`numpyro`](https://num.pyro.ai/) Python library, `pyrenew` provides core components for model building, including pre-defined models for processing various types of observational processes.

## Installation

Install via pip with

```bash
pip install git+https://github.com/cdcent/cfa-pyrenew.git
```

## Demo
The `docs` folder contains a Jupyter notebook with an interactive demo to get you started. It simulates observed hospitalizations using a simple renewal process model and then fits to it using a No-U-Turn Sampler.

The [`docs`](docs) folder contains quarto documents to get you started. It simulates observed hospitalizations using a simple renewal process model and then fits it using a No-U-Turn Sampler.
1 change: 1 addition & 0 deletions model/docs/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
!*png
*_files/libs
289 changes: 182 additions & 107 deletions model/docs/getting-started.md
Original file line number Diff line number Diff line change
@@ -1,127 +1,142 @@
# Getting started with pyrenew


This document illustrates two features of `pyrenew`: (a) the set of
included `RandomVariable`s, and (b) model composition.
`pyrenew` is a flexible tool for simulating and making statistical
inference of epidemiological models, emphasizing renewal models. Built
on `numpyro`, `pyrenew` provides core components for model building and
pre-defined models for processing various observational processes. This
document illustrates how `pyrenew` can be used to build a basic renewal
model.

## The fundamentals

`pyrenew`’s core components are the metaclasses `RandomVariable` and
`Model`. From the package’s perspective, a `RandomVariable` is a
quantity models can sample and estimate, **including deterministic
quantities**. Mainly, sampling from a `RandomVariable` involves calling
the `sample()` method. The benefit of this design is the definition of
the sample function can be arbitrary, allowing the user to either sample
from a distribution using `numpyro.sample()`, compute fixed quantities
(like a mechanistic equation), or return a fixed value (like a
pre-computed PMF.) For instance, we may be interested in estimating a
PMF, in which case a `RandomVariable` sampling function may roughly be
defined as:

## Hospitalizations model

`pyrenew` has five main components:

- Utility and math functions,
- The `processes` sub-module,
- The `observations` sub-module,
- The `latent` sub-module, and
- The `models` sub-module

All three of `process`, `observation`, and `latent` contain classes that
inherit from the meta class `RandomVariable`. The classes under `model`
inherit from the meta class `Model`. The following diagram illustrates
the composition the model `pyrenew.models.HospitalizationsModel`:

``` mermaid
flowchart TB

subgraph randprocmod["Processes module"]
direction TB
simprw["SimpleRandomWalkProcess"]
rtrw["RtRandomWalkProcess"]
end

subgraph latentmod["Latent module"]
direction TB
hosp_latent["Hospitalizations"]
inf_latent["Infections"]
end

subgraph obsmod["Observations module"]
direction TB
pois["PoissonObservation"]
nb["NegativeBinomialObservation"]
end

subgraph models["Models module"]
direction TB
basic["RtInfectionsRenewalModel"]
hosp["HospitalizationsModel"]
end

rp(("RandomVariable")) --> |Inherited by| randprocmod
rp -->|Inherited by| latentmod
rp -->|Inherited by| obsmod


model(("Model")) -->|Inherited by| models
``` python
class MyRandVar(RandomVariable):
def sample(...) -> ArrayLike:
return numpyro.sample(...)
```

simprw -->|Composes| rtrw
rtrw -->|Composes| basic
inf_latent -->|Composes| basic
basic -->|Composes| hosp
Whereas, in some other cases, we may instead use a fixed quantity for
that variable (like a pre-computed PMF), where the `RandomVariable`’s
sample function could be defined like:

``` python
class MyRandVar(RandomVariable):
def sample(...) -> ArrayLike:
return jax.numpy.array([0.2, 0.7, 0.1])
```

obsmod -->|Composes|models
hosp_latent -->|Composes| hosp
This way, when a `Model` samples from `MyRandVar`, it could be either
adding random variables to be estimated (first case) or just retrieving
some quantity needed for other calculations (second case.)

%% Metaclasses
classDef Metaclass color:black,fill:white
class rp,model Metaclass
The `Model` metaclass provides basic functionality for estimating and
simulation. Like `RandomVariable`, the `Model` metaclass has a
`sample()` method that defines the model structure. Ultimately, models
can be nested (or inherited), providing a straightforward way to add
layers of complexity.

%% Random process
classDef Randproc fill:purple,color:white
class rtrw,simprw Randproc
## ‘Hello world’ model

%% Models
classDef Models fill:teal,color:white
class basic,hosp Models
```

We start by loading the needed components to build a basic renewal
model:
This section will show the steps to build a simple renewal model
featuring a latent infection process, a random walk Rt process, and an
observation process for the reported infections. We start by loading the
needed components to build a basic renewal model:

``` python
import jax.numpy as jnp
import numpy as np
import numpyro as npro
import numpyro.distributions as dist
from pyrenew.process import RtRandomWalkProcess
from pyrenew.latent import Infections
from pyrenew.latent import Infections, Infections0
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
```

/mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

In the basic renewal model we can define three components: Rt, latent
infections, and observed infections.
The basic renewal model defines five components: generation interval,
initial infections, Rt, latent infections, and observed infections. In
this example, the generation interval is not estimated but passed as a
deterministic instance of `RandomVariable`. Here is the code to
initialize the five components:

``` python
latent_infections = Infections(
gen_int=jnp.array([0.25, 0.25, 0.25, 0.25]),
)
# (1) The generation interval (deterministic)
gen_int = DeterministicPMF(
(jnp.array([0.25, 0.25, 0.25, 0.25]),),
)

observed_infections = PoissonObservation(
rate_varname='latent',
counts_varname='observed_infections',
)
# (2) Initial infections (inferred with a prior)
I0 = Infections0(I0_dist=dist.LogNormal(0, 1))

# (3) The random process for Rt
rt_proc = RtRandomWalkProcess()

# (4) Latent infection process (which will use 1 and 2)
latent_infections = Infections()

# (5) The observed infections process (with mean at the latent infections)
observed_infections = PoissonObservation(
rate_varname = 'latent',
counts_varname = 'observed_infections',
)
```

With observation process for the latent infections, we can build the
basic renewal model, and generate a sample calling the `sample()`
method:
With these five pieces, we can build the basic renewal model:

``` python
model1 = RtInfectionsRenewalModel(
Rt_process=rt_proc,
latent_infections=latent_infections,
observed_infections=observed_infections,
gen_int = gen_int,
I0 = I0,
Rt_process = rt_proc,
latent_infections = latent_infections,
observed_infections = observed_infections,
)
```

The following diagram summarizes how the modules interact via
composition; notably, `gen_int`, `I0`, `rt_proc`, `latent_infections`,
and `observed_infections` are instances of `RandomVariable`, which means
these can be easily replaced to generate a different version of
`RtInfectionsRenewalModel`:

``` mermaid
flowchart TB
genint["(1) gen_int\n(DetermnisticPMF)"]
i0["(2) I0\n(Infections0)"]
rt["(3) rt_proc\n(RtRandomWalkProcess)"]
inf["(4) latent_infections\n(Infections)"]
obs["(5) observed_infections\n(PoissonObservation)"]

model1["model1\n(RtInfectionsRenewalModel)"]

i0-->|Composes|model1
genint-->|Composes|model1
rt-->|Composes|model1
obs-->|Composes|model1
inf-->|Composes|model1
```

Using `numpyro`, we can simulate data using the `sample()` member
function of `RtInfectionsRenewalModel`:

``` python
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 60)):
sim_data = model1.sample(constants=dict(n_timepoints=30))
with npro.handlers.seed(rng_seed = np.random.randint(1, 60)):
sim_data = model1.sample(constants = dict(n_timepoints=30))

sim_data
```
Expand All @@ -131,14 +146,14 @@ sim_data
1.271196 , 1.3189521, 1.3054799, 1.3165426, 1.291952 , 1.3026639,
1.2619467, 1.2852622, 1.3121517, 1.2888998, 1.2641873, 1.2580931,
1.2545817, 1.3092988, 1.2488269, 1.2397509, 1.2071848, 1.2334517,
1.21868 ], dtype=float32), latent=Array([ 3.7023427, 4.850682 , 6.4314823, 8.26245 , 6.9874763,
7.940377 , 9.171101 , 10.051114 , 10.633459 , 11.729475 ,
12.559867 , 13.422887 , 15.364211 , 17.50132 , 19.206314 ,
21.556652 , 23.78112 , 26.719398 , 28.792412 , 32.40454 ,
36.641006 , 40.135487 , 43.60607 , 48.055103 , 52.829704 ,
60.43277 , 63.97854 , 69.82776 , 74.564415 , 82.88904 ,
88.73811 ], dtype=float32), observed=Array([ 4, 3, 6, 5, 7, 7, 10, 11, 6, 9, 7, 13, 16, 19, 20, 27, 23,
31, 28, 30, 43, 42, 55, 57, 44, 52, 64, 52, 77, 85, 94], dtype=int32))
1.21868 ], dtype=float32), latent=Array([ 2.3215084, 3.0415602, 4.0327816, 5.180868 , 4.381411 ,
4.978916 , 5.750626 , 6.3024273, 6.66758 , 7.354823 ,
7.8755097, 8.416656 , 9.63394 , 10.973988 , 12.043082 ,
13.516833 , 14.911659 , 16.75407 , 18.053928 , 20.318869 ,
22.975292 , 25.166464 , 27.34265 , 30.13236 , 33.126217 ,
37.89362 , 40.11695 , 43.784634 , 46.754696 , 51.974545 ,
55.642136 ], dtype=float32), observed=Array([ 1, 2, 3, 5, 4, 4, 7, 4, 8, 4, 7, 3, 8, 12, 13, 18, 14,
20, 17, 18, 28, 27, 36, 37, 26, 31, 40, 27, 48, 54, 60], dtype=int32))

The `sample()` method of the `RtInfectionsRenewalModel` returns a list
composed of the `Rt` and `infections` sequences.
Expand All @@ -162,11 +177,12 @@ plt.tight_layout()
plt.show()
```

<img
src="getting-started_files/figure-commonmark/basic-fig-output-1.png"
id="basic-fig" />
![Rt and
Infections](getting-started_files/figure-commonmark/basic-fig-output-1.png)

Let’s see how the estimation would go
To fit the model, we can use the `run()` method of the model
`RtInfectionsRenewalModel`; an inherited method from the metaclass
`Model`:

``` python
import jax
Expand All @@ -183,7 +199,8 @@ model1.run(
)
```

Now, let’s investigate the output
Now, let’s investigate the output, particularly the posterior
distribution of the Rt estimates:

``` python
import polars as pl
Expand All @@ -202,6 +219,64 @@ ax.set_yticks([0.5, 1, 2])
ax.set_yscale("log")
```

<img
src="getting-started_files/figure-commonmark/output-rt-output-1.png"
id="output-rt" />
![Rt posterior
distribution](getting-started_files/figure-commonmark/output-rt-output-1.png)

## Architecture of pyrenew

`pyrenew` leverages `numpyro`’s flexibility to build models via
composition. As a principle, most objects in `pyrenew` can be treated as
random variables we can sample. At the top-level `pyrenew` has two
metaclass from which most objects inherit: `RandomVariable` and `Model`.
From them, the following four sub-modules arise:

- The `process` sub-module,
- The `deterministic` sub-module,
- The `observation` sub-module,
- The `latent` sub-module, and
- The `models` sub-module

The first four are collections of instances of `RandomVariable`, and the
last is a collection of instances of `Model`. The following diagram
shows a detailed view of how meta classes, modules, and classes interact
to create the `RtInfectionsRenewalModel` instantiated in the previous
section:

``` mermaid
flowchart LR
rand((RandomVariable\nmetaclass))
models((Model\nmetaclass))

subgraph observations[Observations module]
obs["observed_infections\n(PoissonObservation)"]
end

subgraph latent[Latent module]
inf["latent_infections\n(Infections)"]
i0["I0\n(Infections0)"]
end

subgraph process[Process module]
rt["rt_proc\n(RtRandomWalkProcess)"]
end

subgraph deterministic[Deterministic module]
detpmf["gen_int\n(DeterministicPMF)"]
end

subgraph model[Model module]
model1["model1\n(RtInfectionsRenewalModel)"]
end

rand-->|Inherited by|observations
rand-->|Inherited by|process
rand-->|Inherited by|latent
rand-->|Inherited by|deterministic
models-->|Inherited by|model

detpmf-->|Composes|model1
i0-->|Composes|model1
rt-->|Composes|model1
obs-->|Composes|model1
inf-->|Composes|model1
```
Loading