Skip to content

Commit

Permalink
Add stable age distribution initialization (#482)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris authored Feb 25, 2025
1 parent bbaf1b9 commit 5c29559
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 87 deletions.
140 changes: 140 additions & 0 deletions pyrenew/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,146 @@
from pyrenew.distutil import validate_discrete_dist_vector


def _positive_ints_like(vec: ArrayLike) -> jnp.ndarray:
"""
Given an array of size n, return the 1D Jax array
``[1, ... n]``.
Parameters
----------
vec: ArrayLike
The template array
Returns
-------
jnp.ndarray
The resulting array ``[1, ..., n]``.
"""
return jnp.arange(1, jnp.size(vec) + 1)


def neg_MGF(r: float, w: ArrayLike) -> float:
"""
Compute the negative moment generating function (MGF)
for a given rate ``r`` and weights ``w``.
Parameters
----------
r: float
The rate parameter.
w: ArrayLike
An array of weights.
Returns
-------
float
The value of the negative MGF evaluated at ``r``
and ``w``.
Notes
-----
For a finite discrete random variable :math:`X` supported on
the first :math:`n` positive integers (:math:`\\{1, 2, ..., n \\}`),
the moment generating function (MGF) :math:`M_+(r)` is defined
as the expected value of :math:`\\exp(rX)`. Similarly, the negative
moment generating function :math:`M_-(r)` is the expected value of
:math:`\\exp(-rX)`. So if we represent the PMF of :math:`X` as a
"weights" vector :math:`w` of length :math:`n`, the negative MGF
:math:`M_-(r)` is given by:
.. math::
M_-(r) = \\sum_{t = 1}^{n} w_i \\exp(-rt)
"""
return jnp.sum(w * jnp.exp(-r * _positive_ints_like(w)))


def neg_MGF_del_r(r: float, w: ArrayLike) -> float:
"""
Compute the value of the partial deriative of
:func:`neg_MGF` with respect to ``r``
evaluated at a particular ``r`` and ``w`` pair.
Parameters
----------
r: float
The rate parameter.
w: ArrayLike
An array of weights.
Returns
-------
float
The value of the partial derivative evaluated at ``r``
and ``w``.
"""
t_vec = _positive_ints_like(w)
return -jnp.sum(w * t_vec * jnp.exp(-r * t_vec))


def r_approx_from_R(R: float, g: ArrayLike, n_newton_steps: int) -> ArrayLike:
"""
Get the approximate asymptotic geometric growth rate ``r``
for a renewal process with a fixed reproduction number ``R``
and discrete generation interval PMF ``g``.
Uses Newton's method with a fixed number of steps.
Parameters
----------
R: float
The reproduction number
g: ArrayLike
The probability mass function of the generation
interval.
n_newton_steps: int
Number of steps to take when performing Newton's method.
Returns
-------
float
The approximate value of ``r``.
Notes
-----
For a fixed value of :math:`\\mathcal{R}`, a renewal process
has an asymptotic geometric growth rate :math:`r` that satisfies
.. math::
M_{-}(r) - \\frac{1}{\\mathcal{R}} = 0
where :math:`M_-(r)` is the negative moment generating function
for a random variable :math:`\\tau` representing the (discrete)
generation interval. See :func:`neg_MGF` for details.
We obtain a value for :math:`r` via approximate numerical solution
of this implicit equation.
We first make an initial guess based on the mean generation interval
:math:`\\bar{\\tau} = \\mathbb{E}(\\tau)`:
.. math::
r \\approx \\frac{\\mathcal{R} - 1}{\\mathcal{R} \\bar{\\tau}}
We then refine this approximation by applying Newton's method for
a fixed number of steps.
"""
mean_gi = jnp.dot(g, _positive_ints_like(g))
init_r = (R - 1) / (R * mean_gi)

def _r_next(r, _) -> tuple[ArrayLike, None]: # numpydoc ignore=GL08
return (
r - ((R * neg_MGF(r, g) - 1) / (R * neg_MGF_del_r(r, g))),
None,
)

result, _ = scan(f=_r_next, init=init_r, xs=None, length=n_newton_steps)
return result


def get_leslie_matrix(
R: float, generation_interval_pmf: ArrayLike
) -> ArrayLike:
Expand Down
38 changes: 0 additions & 38 deletions test/test_leslie_matrix.py

This file was deleted.

142 changes: 142 additions & 0 deletions test/test_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
Unit tests for the pyrenew.math module.
"""

import jax.numpy as jnp
import numpy as np
import pytest
from numpy.random import RandomState
from numpy.testing import (
assert_almost_equal,
assert_array_almost_equal,
assert_array_equal,
)

import pyrenew.math as pmath

rng = RandomState(5)


@pytest.mark.parametrize(
"arr, arr_len",
[
([3, 1, 2], 3),
(np.ones(50), 50),
((jnp.nan * jnp.ones(250)).reshape((50, -1)), 250),
],
)
def test_positive_ints_like(arr, arr_len):
"""
Test the _positive_ints_like helper function.
"""
result = pmath._positive_ints_like(arr)
expected = jnp.arange(1, arr_len + 1)
assert_array_equal(result, expected)


@pytest.mark.parametrize(
"R, G",
[
(5, rng.dirichlet(np.ones(2))),
(0.2, rng.dirichlet(np.ones(50))),
(1, rng.dirichlet(np.ones(10))),
(1.01, rng.dirichlet(np.ones(4))),
(0.99, rng.dirichlet(np.ones(6))),
],
)
def test_r_approx(R, G):
"""
Test that r_approx_from_R gives answers
consistent with those gained from a Leslie
matrix approach.
"""
r_val = pmath.r_approx_from_R(R, G, n_newton_steps=5)
e_val, stable_dist = pmath.get_asymptotic_growth_rate_and_age_dist(R, G)

unnormed = r_val * stable_dist
if r_val != 0:
assert_array_almost_equal(unnormed / np.sum(unnormed), stable_dist)
else:
assert_almost_equal(e_val, 1, decimal=5)


def test_asymptotic_properties():
"""
Check that the calculated
asymptotic growth rate and
age distribution given by
get_asymptotic_growth_rate()
and get_stable_age_distribution()
agree with simulated ones from
just running a process for a
while.
"""
R = 1.2
gi = np.array([0.2, 0.1, 0.2, 0.15, 0.05, 0.025, 0.025, 0.25])
A = pmath.get_leslie_matrix(R, gi)

# check via Leslie matrix multiplication
x = np.array([1, 0, 0, 0, 0, 0, 0, 0])
for i in range(1000):
x_new = A @ x
rat_x = np.sum(x_new) / np.sum(x)
x = x_new

assert_almost_equal(
rat_x, pmath.get_asymptotic_growth_rate(R, gi), decimal=5
)
assert_array_almost_equal(
x / np.sum(x), pmath.get_stable_age_distribution(R, gi)
)

# check via backward-looking convolution
y = np.array([1, 0, 0, 0, 0, 0, 0, 0])
for j in range(1000):
new_pop = np.dot(y, R * gi)
rat_y = new_pop / y[0]
y = np.hstack([new_pop, y[:-1]])
assert_almost_equal(
rat_y, pmath.get_asymptotic_growth_rate(R, gi), decimal=5
)
assert_array_almost_equal(
y / np.sum(x), pmath.get_stable_age_distribution(R, gi)
)


@pytest.mark.parametrize(
"R, gi, expected",
[
(
0.4,
np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
np.array(
[
[0.16, 0.08, 0.08, 0.04, 0.04],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
]
),
),
(
3,
np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
np.array(
[
[1.2, 0.6, 0.6, 0.3, 0.3],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
]
),
),
],
)
def test_get_leslie(R, gi, expected):
"""
Test that get_leslie matrix
returns expected Leslie matrices
"""
assert_array_almost_equal(pmath.get_leslie_matrix(R, gi), expected)
49 changes: 0 additions & 49 deletions test/test_process_asymptotics.py

This file was deleted.

0 comments on commit 5c29559

Please sign in to comment.