Add Rice distribution #4

Merged · 16 commits · Apr 26, 2024
29 changes: 25 additions & 4 deletions .github/workflows/test.yml
@@ -18,22 +18,43 @@ jobs:
  run:
    name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }}
    runs-on: ${{ matrix.os }}

    # skip the build if the commit message contains 'skip ci'
    if: "!contains(github.event.head_commit.message, 'skip ci')"

    strategy:
      fail-fast: false
      matrix:
-        os: [ubuntu-latest, windows-latest, macos-latest]
-        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+        python-version:
+          - "3.8"
+          - "3.9"
+          - "3.10"
+          - "3.11"
+          - "3.12"
+        os:
+          - "windows-latest"
+          - "macos-latest"
+          - "ubuntu-latest"
+        # Based on peps #3763 (https://github.com/python/peps/pull/3763) via
+        # setup-python #850 (https://github.com/actions/setup-python/issues/850)
+        exclude:
+          - { python-version: "3.8", os: "macos-latest" }
+          - { python-version: "3.9", os: "macos-latest" }
+        include:
+          - { python-version: "3.8", os: "macos-13" }
+          - { python-version: "3.9", os: "macos-13" }

    steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Hatch
        run: |
          pip install --upgrade pip
          pip install --upgrade hatch
+          pip install hatch-mkdocs  # This is a workaround for https://github.com/pypa/hatch/issues/1379
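(Context for the matrix change above: GitHub's macos-latest image now runs on Apple Silicon, where actions/setup-python provides no CPython 3.8 or 3.9 builds, so those two versions are routed to the Intel-based macos-13 image via the exclude/include pairs; see the linked setup-python issue #850.)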
1 change: 1 addition & 0 deletions pyproject.toml
@@ -43,6 +43,7 @@ path = "src/rs_distributions/__about__.py"
dependencies = [
  "coverage[toml]>=6.5",
  "pytest",
+  "scipy",
]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
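Side note on the dependency addition above: since `[tool.hatch.envs.default.scripts]` defines `test = "pytest {args:tests}"`, the suite (including the new scipy-based comparisons) runs with the standard Hatch invocation `hatch run test`.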
2 changes: 2 additions & 0 deletions src/rs_distributions/distributions/__init__.py
@@ -1,5 +1,7 @@
from .folded_normal import FoldedNormal
+from .rice import Rice

__all__ = [
    "FoldedNormal",
+    "Rice",
]
3 changes: 1 addition & 2 deletions src/rs_distributions/distributions/folded_normal.py
@@ -35,8 +35,7 @@ class FoldedNormal(dist.Distribution):
    support = torch.distributions.constraints.nonnegative

    def __init__(self, loc, scale, validate_args=None):
-        self.loc = torch.as_tensor(loc)
-        self.scale = torch.as_tensor(scale)
+        self.loc, self.scale = torch.distributions.utils.broadcast_all(loc, scale)
        batch_shape = self.loc.shape
        super().__init__(batch_shape, validate_args=validate_args)
        self._irsample = NormalIRSample().apply
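A quick illustration of what `broadcast_all` buys over plain `torch.as_tensor` — a standalone sketch against stock PyTorch, not code from this PR:

```python
import torch
from torch.distributions.utils import broadcast_all

# broadcast_all promotes Python numbers to tensors and broadcasts every
# argument to a common shape, so mixed float/Tensor parameters are fine.
loc, scale = broadcast_all(1.0, torch.ones(3))
print(loc.shape, scale.shape)  # torch.Size([3]) torch.Size([3])

# With torch.as_tensor alone, loc would stay 0-dimensional and the
# distribution's batch_shape (taken from loc.shape) would come out wrong.
```

This is also what the new `test_float_broadcasting` cases below exercise.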
179 changes: 179 additions & 0 deletions src/rs_distributions/distributions/rice.py
@@ -0,0 +1,179 @@
import torch
import math
from torch import distributions as dist


class RiceIRSample(torch.autograd.Function):
    @staticmethod
    def forward(ctx, nu, sigma, samples, dnu, dsigma):
        # Pass the samples through unchanged and stash the precomputed
        # implicit gradients for the backward pass.
        ctx.save_for_backward(dnu, dsigma)
        return samples

    @staticmethod
    def backward(ctx, grad_output):
        grad_nu, grad_sigma = ctx.saved_tensors
        # Chain rule through the samples; the samples, dnu, and dsigma
        # inputs themselves receive no gradient.
        return grad_output * grad_nu, grad_output * grad_sigma, None, None, None


class Rice(dist.Distribution):
    """
    The Rice distribution is useful for modeling acentric structure factor amplitudes in
    X-ray crystallography. It is the amplitude distribution corresponding to a bivariate
    normal in the complex plane.

    ```
    x ~ MVN([ν, 0], σ²I)
    y = sqrt(x[0] * x[0] + x[1] * x[1])
    ```

    The parameters ν and σ are the location and standard deviation of an isotropic,
    bivariate normal. If x is drawn from the normal with location [ν, 0] and covariance
    ```
    | σ²  0 |
    | 0  σ² |
    ```
    the distribution of amplitudes, `y = sqrt(x[0]**2 + x[1]**2)`, follows a Rician distribution.

    Args:
        nu (float or Tensor): location parameter of the underlying bivariate normal
        sigma (float or Tensor): standard deviation of the underlying bivariate normal (must be positive)
        validate_args (bool, optional): Whether to validate the arguments of the distribution.
            Default is None.
    """

    arg_constraints = {
        "nu": dist.constraints.nonnegative,
        "sigma": dist.constraints.positive,
    }
    support = torch.distributions.constraints.nonnegative

    def __init__(self, nu, sigma, validate_args=None):
        self.nu, self.sigma = torch.distributions.utils.broadcast_all(nu, sigma)
        batch_shape = self.nu.shape
        super().__init__(batch_shape, validate_args=validate_args)
        self._irsample = RiceIRSample().apply

    def log_prob(self, value):
        """
        Compute the log-probability of the given values under the Rice distribution

        ```
        Rice(x | nu, sigma) =
            x * sigma**-2 * exp(-0.5 * (x**2 + nu**2) * sigma**-2) * I0(x * nu * sigma**-2)
        ```

        Args:
            value (Tensor): The values at which to evaluate the log-probability

        Returns:
            Tensor: The log-probabilities of the given values
        """
        if self._validate_args:
            self._validate_sample(value)
        nu, sigma = self.nu, self.sigma
        x = value
        # Numerically stable form: with i0e(z) = exp(-z) * I0(z) for z >= 0,
        # log I0(z) = z + log i0e(z); absorbing the +z term into the exponent
        # turns -(x**2 + nu**2) / (2 sigma**2) into -(x - nu)**2 / (2 sigma**2).
        log_prob = (
            torch.log(x)
            - 2.0 * torch.log(sigma)
            - 0.5 * torch.square((x - nu) / sigma)
            + torch.log(torch.special.i0e(nu * x / (sigma * sigma)))
        )
        return log_prob

    def sample(self, sample_shape=torch.Size()):
        """
        Generate random samples from the Rice distribution

        Args:
            sample_shape (torch.Size, optional): The shape of the samples to generate.
                Default is an empty shape

        Returns:
            Tensor: The generated random samples
        """
        shape = self._extended_shape(sample_shape)
        nu, sigma = self.nu, self.sigma
        nu = nu.expand(shape)
        sigma = sigma.expand(shape)
        with torch.no_grad():
            A = torch.normal(nu, sigma)
            B = torch.normal(torch.zeros_like(nu), sigma)
            z = torch.sqrt(A * A + B * B)
            return z

    @property
    def mean(self):
        """
        Compute the mean of the Rice distribution

        Returns:
            Tensor: The mean of the distribution.
        """
        sigma = self.sigma
        nu = self.nu

        # E[z] = sigma * sqrt(pi / 2) * L_{1/2}(-nu**2 / (2 sigma**2)), where
        # L_{1/2} is the Laguerre polynomial of order 1/2. The exponentially
        # scaled Bessel functions i0e/i1e fold in its exp(x / 2) prefactor.
        x = -0.5 * torch.square(nu / sigma)
        L = (1.0 - x) * torch.special.i0e(-0.5 * x) - x * torch.special.i1e(-0.5 * x)
        mean = sigma * math.sqrt(math.pi / 2.0) * L
        return mean

    @property
    def variance(self):
        """
        Compute the variance of the Rice distribution

        Returns:
            Tensor: The variance of the distribution
        """
        nu, sigma = self.nu, self.sigma
        return 2 * sigma * sigma + nu * nu - torch.square(self.mean)

    def cdf(self, value):
        """
        Args:
            value (Tensor): The values at which to evaluate the CDF

        Returns:
            Tensor: The CDF values at the given values
        """
        raise NotImplementedError("The CDF is not implemented")

    def _grad_z(self, samples):
        """
        Return the gradients of samples from this distribution with respect
        to the distribution's parameters, computed by implicit differentiation

        Args:
            samples (Tensor): samples from this distribution

        Returns:
            dnu: gradient with respect to the location parameter, nu
            dsigma: gradient with respect to the underlying normal's scale parameter, sigma
        """
        z = samples
        nu, sigma = self.nu, self.sigma
        ab = z * nu / (sigma * sigma)
        # The exponentially scaled Bessel ratio i1e / i0e equals I1 / I0.
        dnu = torch.special.i1e(ab) / torch.special.i0e(ab)
        dsigma = (z - nu * dnu) / sigma
        return dnu, dsigma

    def pdf(self, value):
        return torch.exp(self.log_prob(value))

    def rsample(self, sample_shape=torch.Size()):
        """
        Generate differentiable random samples from the Rice distribution.
        Gradients are implemented using implicit reparameterization
        (https://arxiv.org/abs/1805.08498).

        Args:
            sample_shape (torch.Size, optional): The shape of the samples to generate.
                Default is an empty shape

        Returns:
            Tensor: The generated random samples
        """
        samples = self.sample(sample_shape)
        dnu, dsigma = self._grad_z(samples)
        samples.requires_grad_(True)
        return self._irsample(self.nu, self.sigma, samples, dnu, dsigma)
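To show what the implicit reparameterization enables, here is a hedged usage sketch (not part of this PR): fitting Rice parameters by moment matching, with gradients flowing through `rsample`. The `softplus` constraint handling is an illustrative choice, not part of the library's API:

```python
import torch
from rs_distributions.distributions import Rice

target = Rice(torch.tensor(2.0), torch.tensor(0.5))
raw = torch.zeros(2, requires_grad=True)  # unconstrained parameters
opt = torch.optim.Adam([raw], lr=0.05)

for _ in range(500):
    opt.zero_grad()
    nu, sigma = torch.nn.functional.softplus(raw)  # keep both positive
    q = Rice(nu, sigma)
    z = q.rsample((256,))  # gradients flow through these samples
    loss = (z.mean() - target.mean) ** 2 + (z.std() - target.stddev) ** 2
    loss.backward()
    opt.step()
```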
7 changes: 5 additions & 2 deletions tests/distributions/test_folded_normal.py
@@ -3,11 +3,14 @@
import torch


-@pytest.mark.parametrize("batch_shape", [1, 10])
+@pytest.mark.parametrize("test_float_broadcasting", [False, True])
+@pytest.mark.parametrize("batch_shape", [3, 10])
@pytest.mark.parametrize("sample_shape", [(), (10,)])
-def test_folded_normal_execution(batch_shape, sample_shape):
+def test_folded_normal_execution(test_float_broadcasting, batch_shape, sample_shape):
    params = torch.ones((2, batch_shape), requires_grad=True)
    loc, scale = params
+    if test_float_broadcasting:
+        loc = 1.0
    q = FoldedNormal(loc, scale)
    z = q.rsample(sample_shape)
    q.mean
72 changes: 72 additions & 0 deletions tests/distributions/test_rice.py
@@ -0,0 +1,72 @@
import pytest
from rs_distributions.distributions.rice import Rice
from scipy.stats import rice
import torch
import numpy as np


@pytest.mark.parametrize("test_float_broadcasting", [False, True])
@pytest.mark.parametrize("batch_shape", [3, 10])
@pytest.mark.parametrize("sample_shape", [(), (10,)])
def test_rice_execution(test_float_broadcasting, batch_shape, sample_shape):
    params = torch.ones((2, batch_shape), requires_grad=True)
    loc, scale = params
    if test_float_broadcasting:
        loc = 1.0
    q = Rice(loc, scale)
    z = q.rsample(sample_shape)
    q.mean
    q.variance
    # q.cdf(z)  # <-- no cdf implementation
    q.pdf(z)
    q.log_prob(z)
    torch.autograd.grad(z.sum(), params)


def test_rice_against_scipy(
    dtype="float32", snr_cutoff=10.0, log_min=-12, log_max=2, rtol=1e-5
):
    """
    Test the following attributes of Rice against scipy.stats.rice:
    mean, stddev, log_prob, and pdf.

    Args:
        dtype (string): float dtype determining the precision of the test
        snr_cutoff (float): do not test combinations with nu/sigma above this threshold
        log_min (int): 10**log_min is the minimum value tested for nu and sigma
        log_max (int): 10**log_max is the maximum value tested for nu and sigma
        rtol (float): the relative tolerance for equivalency in tests
    """
    nu = sigma = np.logspace(log_min, log_max, log_max - log_min + 1, dtype=dtype)
    nu, sigma = np.meshgrid(nu, sigma)
    nu, sigma = nu.flatten(), sigma.flatten()
    idx = nu / sigma < snr_cutoff
    nu, sigma = nu[idx], sigma[idx]

    q = Rice(
        torch.as_tensor(nu),
        torch.as_tensor(sigma),
    )

    mean = rice.mean(nu / sigma, scale=sigma).astype(dtype)
    result = q.mean.detach().numpy()
    assert np.allclose(mean, result, rtol=rtol)

    stddev = rice.std(nu / sigma, scale=sigma).astype(dtype)
    result = q.stddev.detach().numpy()
    assert np.allclose(stddev, result, rtol=rtol)

    z = np.linspace(
        np.maximum(mean - 3.0 * stddev, 0.0),
        mean + 3.0 * stddev,
        10,
    )

    log_prob = rice.logpdf(z, nu / sigma, scale=sigma)
    result = q.log_prob(torch.as_tensor(z)).detach().numpy()
    assert np.allclose(log_prob, result, rtol=rtol)

    pdf = rice.pdf(z, nu / sigma, scale=sigma)
    result = q.pdf(torch.as_tensor(z)).detach().numpy()
    assert np.allclose(pdf, result, rtol=rtol)
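One additional check that would complement the scipy comparisons (a sketch under the same import assumptions, not part of this PR): verifying the implicit gradients against a finite difference of the Monte Carlo mean, using a fixed seed so both evaluations share the same underlying normals:

```python
import torch
from rs_distributions.distributions.rice import Rice


def mc_mean(nu, n=200_000, seed=0):
    # Fixed seed => common random numbers across evaluations.
    torch.manual_seed(seed)
    q = Rice(torch.tensor(nu, dtype=torch.float64), torch.tensor(1.0, dtype=torch.float64))
    return q.sample((n,)).mean()


# Analytic gradient through rsample's implicit reparameterization.
nu = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)
torch.manual_seed(0)
q = Rice(nu, torch.tensor(1.0, dtype=torch.float64))
z = q.rsample((200_000,))
(analytic,) = torch.autograd.grad(z.mean(), nu)

# Central finite difference of the same Monte Carlo estimate.
eps = 1e-4
numeric = (mc_mean(2.0 + eps) - mc_mean(2.0 - eps)) / (2 * eps)
assert torch.isclose(analytic, numeric, rtol=1e-2)
```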