Skip to content

Commit

Permalink
Add Rice distribution (#4)
Browse files Browse the repository at this point in the history
This PR adds a Rice distribution as well as tests for many of its methods based on scipy.stats. It also restores CI which was broken on macos-latest with python 3.8/3.9.
  • Loading branch information
kmdalton authored Apr 26, 2024
1 parent 3047a85 commit 807b47b
Show file tree
Hide file tree
Showing 7 changed files with 285 additions and 8 deletions.
29 changes: 25 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,43 @@ jobs:
run:
name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }}
runs-on: ${{ matrix.os }}

# skip build of commit contains 'skip ci'
if: "!contains(github.event.head_commit.message, 'skip ci')"

strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
- "3.12"
os:
- "windows-latest"
- "macos-latest"
- "ubuntu-latest"
# Based on peps #3763 (https://github.com/python/peps/pull/3763) via
# setup-python #850 (https://github.com/actions/setup-python/issues/850)
exclude:
- { python-version: "3.8", os: "macos-latest" }
- { python-version: "3.9", os: "macos-latest" }
include:
- { python-version: "3.8", os: "macos-13" }
- { python-version: "3.9", os: "macos-13" }

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install Hatch
run: |
pip install --upgrade pip
pip install --upgrade hatch
pip install hatch-mkdocs #This is a workaround for https://github.com/pypa/hatch/issues/1379
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ path = "src/rs_distributions/__about__.py"
dependencies = [
"coverage[toml]>=6.5",
"pytest",
"scipy",
]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
Expand Down
2 changes: 2 additions & 0 deletions src/rs_distributions/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .folded_normal import FoldedNormal
from .rice import Rice

__all__ = [
"FoldedNormal",
"Rice",
]
3 changes: 1 addition & 2 deletions src/rs_distributions/distributions/folded_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class FoldedNormal(dist.Distribution):
support = torch.distributions.constraints.nonnegative

def __init__(self, loc, scale, validate_args=None):
self.loc = torch.as_tensor(loc)
self.scale = torch.as_tensor(scale)
self.loc, self.scale = torch.distributions.utils.broadcast_all(loc, scale)
batch_shape = self.loc.shape
super().__init__(batch_shape, validate_args=validate_args)
self._irsample = NormalIRSample().apply
Expand Down
179 changes: 179 additions & 0 deletions src/rs_distributions/distributions/rice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import torch
import math
from torch import distributions as dist


class RiceIRSample(torch.autograd.Function):
@staticmethod
def forward(ctx, nu, sigma, samples, dnu, dsigma):
ctx.save_for_backward(dnu, dsigma)
return samples

@staticmethod
def backward(ctx, grad_output):
(
grad_nu,
grad_sigma,
) = ctx.saved_tensors
return grad_output * grad_nu, grad_output * grad_sigma, None, None, None


class Rice(dist.Distribution):
"""
The Rice distribution is useful for modeling acentric structure factor amplitudes in
X-ray crystallography. It is the amplitude distribution corresponding to a bivariate
normal in the complex plane.
```
x ~ MVN([ν, 0], σI)
y = sqrt(x[0] * x[0] + x[1] * x[1])
```
The parameters ν and σ represent the location and standard deviation of an isotropic, bivariate normal.
If x is drawn from the normal with location [ν, 0i] and covariance,
```
| σ 0 |
| 0 σi |
```
the distribution of amplitudes, `y = sqrt(x * conjugate(x))`, follows a Rician distribution.
Args:
nu (float or Tensor): location parameter of the underlying bivariate normal
sigma (float or Tensor): standard deviation of the underlying bivariate normal (must be positive)
validate_args (bool, optional): Whether to validate the arguments of the distribution.
Default is None.
"""

arg_constraints = {
"nu": dist.constraints.nonnegative,
"sigma": dist.constraints.positive,
}
support = torch.distributions.constraints.nonnegative

def __init__(self, nu, sigma, validate_args=None):
self.nu, self.sigma = torch.distributions.utils.broadcast_all(nu, sigma)
batch_shape = self.nu.shape
super().__init__(batch_shape, validate_args=validate_args)
self._irsample = RiceIRSample().apply

def log_prob(self, value):
"""
Compute the log-probability of the given values under the Rice distribution
```
Rice(x | nu, sigma) = \
x * sigma**-2 * exp(-0.5 * (x**2 + nu**2) * sigma ** -2) * I_0(x * nu * sigma **-2)
```
Args:
value (Tensor): The values at which to evaluate the log-probability
Returns:
Tensor: The log-probabilities of the given values
"""
if self._validate_args:
self._validate_sample(value)
nu, sigma = self.nu, self.sigma
x = value
log_prob = (
torch.log(x)
- 2.0 * torch.log(sigma)
- 0.5 * torch.square((x - nu) / sigma)
+ torch.log(torch.special.i0e(nu * x / (sigma * sigma)))
)
return log_prob

def sample(self, sample_shape=torch.Size()):
"""
Generate random samples from the Rice distribution
Args:
sample_shape (torch.Size, optional): The shape of the samples to generate.
Default is an empty shape
Returns:
Tensor: The generated random samples
"""
shape = self._extended_shape(sample_shape)
nu, sigma = self.nu, self.sigma
nu = nu.expand(shape)
sigma = sigma.expand(shape)
with torch.no_grad():
A = torch.normal(nu, sigma)
B = torch.normal(torch.zeros_like(nu), sigma)
z = torch.sqrt(A * A + B * B)
return z

@property
def mean(self):
"""
Compute the mean of the Rice distribution
Returns:
Tensor: The mean of the distribution.
"""
sigma = self.sigma
nu = self.nu

x = -0.5 * torch.square(nu / sigma)
L = (1.0 - x) * torch.special.i0e(-0.5 * x) - x * torch.special.i1e(-0.5 * x)
mean = sigma * math.sqrt(math.pi / 2.0) * L
return mean

@property
def variance(self):
"""
Compute the variance of the Rice distribution
Returns:
Tensor: The variance of the distribution
"""
nu, sigma = self.nu, self.sigma
return 2 * sigma * sigma + nu * nu - torch.square(self.mean)

def cdf(self, value):
"""
Args:
value (Tensor): The values at which to evaluate the CDF
Returns:
Tensor: The CDF values at the given values
"""
raise NotImplementedError("The CDF is not implemented")

def _grad_z(self, samples):
"""
Return the gradient of samples from this distribution
Args:
samples (Tensor): samples from this distribution
Returns:
dnu: gradient with respect to the loc parameter, nu
dsigma: gradient with respect to the underlying normal's scale parameter, sigma
"""
z = samples
nu, sigma = self.nu, self.sigma
ab = z * nu / (sigma * sigma)
dnu = torch.special.i1e(ab) / torch.special.i0e(ab) # == i1(ab)/i0(ab)
dsigma = (z - nu * dnu) / sigma
return dnu, dsigma

def pdf(self, value):
return torch.exp(self.log_prob(value))

def rsample(self, sample_shape=torch.Size()):
"""
Generate differentiable random samples from the Rice distribution.
Gradients are implemented using implicit reparameterization (https://arxiv.org/abs/1805.08498).
Args:
sample_shape (torch.Size, optional): The shape of the samples to generate.
Default is an empty shape
Returns:
Tensor: The generated random samples
"""
samples = self.sample(sample_shape)
dnu, dsigma = self._grad_z(samples)
samples.requires_grad_(True)
return self._irsample(self.nu, self.sigma, samples, dnu, dsigma)
7 changes: 5 additions & 2 deletions tests/distributions/test_folded_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import torch


@pytest.mark.parametrize("batch_shape", [1, 10])
@pytest.mark.parametrize("test_float_broadcasting", [False, True])
@pytest.mark.parametrize("batch_shape", [3, 10])
@pytest.mark.parametrize("sample_shape", [(), (10,)])
def test_folded_normal_execution(batch_shape, sample_shape):
def test_folded_normal_execution(test_float_broadcasting, batch_shape, sample_shape):
params = torch.ones((2, batch_shape), requires_grad=True)
loc, scale = params
if test_float_broadcasting:
loc = 1.0
q = FoldedNormal(loc, scale)
z = q.rsample(sample_shape)
q.mean
Expand Down
72 changes: 72 additions & 0 deletions tests/distributions/test_rice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
from rs_distributions.distributions.rice import Rice
from scipy.stats import rice
import torch
import numpy as np


@pytest.mark.parametrize("test_float_broadcasting", [False, True])
@pytest.mark.parametrize("batch_shape", [3, 10])
@pytest.mark.parametrize("sample_shape", [(), (10,)])
def test_rice_execution(test_float_broadcasting, batch_shape, sample_shape):
params = torch.ones((2, batch_shape), requires_grad=True)
loc, scale = params
if test_float_broadcasting:
loc = 1.0
q = Rice(loc, scale)
z = q.rsample(sample_shape)
q.mean
q.variance
# q.cdf(z) #<-- no cdf implementation
q.pdf(z)
q.log_prob(z)
torch.autograd.grad(z.sum(), params)


def test_rice_against_scipy(
dtype="float32", snr_cutoff=10.0, log_min=-12, log_max=2, rtol=1e-5
):
"""
Test the following attributes of Rice against scipy.stats.rice
- mean
Args:
dtype (string) : float dtype to determine the precision of the test
snr_cutoff (float) : do not test combinations with nu/sigma values above this threshold
log_min (int) : 10**log_min will be the minimum value tested for nu and sigma
log_max (int) : 10**log_max will be the maximum value tested for nu and sigma
rtol (float) : the relative tolerance for equivalency in tests
"""
log_min, log_max = -12, 2
nu = sigma = np.logspace(log_min, log_max, log_max - log_min + 1, dtype=dtype)
nu, sigma = np.meshgrid(nu, sigma)
nu, sigma = nu.flatten(), sigma.flatten()
idx = nu / sigma < snr_cutoff
nu, sigma = nu[idx], sigma[idx]

q = Rice(
torch.as_tensor(nu),
torch.as_tensor(sigma),
)

mean = rice.mean(nu / sigma, scale=sigma).astype(dtype)
result = q.mean.detach().numpy()
assert np.allclose(mean, result, rtol=rtol)

stddev = rice.std(nu / sigma, scale=sigma).astype(dtype)
result = q.stddev.detach().numpy()
assert np.allclose(stddev, result, rtol=rtol)

z = np.linspace(
np.maximum(mean - 3.0 * stddev, 0.0),
mean + 3.0 * stddev,
10,
)

log_prob = rice.logpdf(z, nu / sigma, scale=sigma)
result = q.log_prob(torch.as_tensor(z)).detach().numpy()
assert np.allclose(log_prob, result, rtol=rtol)

pdf = rice.pdf(z, nu / sigma, scale=sigma)
result = q.pdf(torch.as_tensor(z)).detach().numpy()
assert np.allclose(pdf, result, rtol=rtol)

0 comments on commit 807b47b

Please sign in to comment.