diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 88f4079..5f60a87 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -18,22 +18,43 @@ jobs:
   run:
     name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }}
     runs-on: ${{ matrix.os }}
+
+    # skip build if the commit message contains 'skip ci'
+    if: "!contains(github.event.head_commit.message, 'skip ci')"
+
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, windows-latest, macos-latest]
-        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+        python-version:
+          - "3.8"
+          - "3.9"
+          - "3.10"
+          - "3.11"
+          - "3.12"
+        os:
+          - "windows-latest"
+          - "macos-latest"
+          - "ubuntu-latest"
+        # Based on peps #3763 (https://github.com/python/peps/pull/3763) via
+        # setup-python #850 (https://github.com/actions/setup-python/issues/850)
+        exclude:
+          - { python-version: "3.8", os: "macos-latest" }
+          - { python-version: "3.9", os: "macos-latest" }
+        include:
+          - { python-version: "3.8", os: "macos-13" }
+          - { python-version: "3.9", os: "macos-13" }
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install Hatch
         run: |
+          pip install --upgrade pip
           pip install --upgrade hatch
           pip install hatch-mkdocs
           #This is a workaround for https://github.com/pypa/hatch/issues/1379
diff --git a/pyproject.toml b/pyproject.toml
index 87b076c..896de5d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ path = "src/rs_distributions/__about__.py"
 dependencies = [
   "coverage[toml]>=6.5",
   "pytest",
+  "scipy",
 ]
 [tool.hatch.envs.default.scripts]
 test = "pytest {args:tests}"
diff --git a/src/rs_distributions/distributions/__init__.py b/src/rs_distributions/distributions/__init__.py
index 59222aa..88d1f5d 100644
--- a/src/rs_distributions/distributions/__init__.py
+++ b/src/rs_distributions/distributions/__init__.py
@@ -1,5 +1,7 @@
 from .folded_normal import FoldedNormal
+from .rice import Rice
 
 __all__ = [
     "FoldedNormal",
+    "Rice",
 ]
diff --git a/src/rs_distributions/distributions/folded_normal.py b/src/rs_distributions/distributions/folded_normal.py
index ea3f361..d18fa98 100644
--- a/src/rs_distributions/distributions/folded_normal.py
+++ b/src/rs_distributions/distributions/folded_normal.py
@@ -35,8 +35,7 @@ class FoldedNormal(dist.Distribution):
     support = torch.distributions.constraints.nonnegative
 
     def __init__(self, loc, scale, validate_args=None):
-        self.loc = torch.as_tensor(loc)
-        self.scale = torch.as_tensor(scale)
+        self.loc, self.scale = torch.distributions.utils.broadcast_all(loc, scale)
         batch_shape = self.loc.shape
         super().__init__(batch_shape, validate_args=validate_args)
         self._irsample = NormalIRSample().apply
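The `broadcast_all` swap above is what lets `FoldedNormal` (and the new `Rice` below) accept a plain Python float for one parameter and a tensor for the other. A minimal standalone sketch of what `torch.distributions.utils.broadcast_all` does:

```
import torch
from torch.distributions.utils import broadcast_all

# A Python float and a shape-(3,) tensor are promoted to a common shape and
# dtype; torch.as_tensor(1.0) alone would have left loc as a 0-d tensor.
loc, scale = broadcast_all(1.0, torch.ones(3))
assert loc.shape == scale.shape == torch.Size([3])
```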
diff --git a/src/rs_distributions/distributions/rice.py b/src/rs_distributions/distributions/rice.py
new file mode 100644
index 0000000..73adbd0
--- /dev/null
+++ b/src/rs_distributions/distributions/rice.py
@@ -0,0 +1,179 @@
+import torch
+import math
+from torch import distributions as dist
+
+
+class RiceIRSample(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, nu, sigma, samples, dnu, dsigma):
+        ctx.save_for_backward(dnu, dsigma)
+        return samples
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (
+            grad_nu,
+            grad_sigma,
+        ) = ctx.saved_tensors
+        return grad_output * grad_nu, grad_output * grad_sigma, None, None, None
+
+
+class Rice(dist.Distribution):
+    """
+    The Rice distribution is useful for modeling acentric structure factor amplitudes in
+    X-ray crystallography. It is the amplitude distribution corresponding to a bivariate
+    normal in the complex plane.
+
+    ```
+    x ~ MVN([ν, 0], σ²I)
+    y = sqrt(x[0] * x[0] + x[1] * x[1])
+    ```
+    The parameters ν and σ represent the location and standard deviation of an isotropic, bivariate normal.
+    If x is drawn from the normal with location [ν, 0] and covariance,
+    ```
+    | σ² 0  |
+    | 0  σ² |
+    ```
+    the distribution of amplitudes, `y = sqrt(x * conjugate(x))`, follows a Rician distribution.
+
+    Args:
+        nu (float or Tensor): location parameter of the underlying bivariate normal
+        sigma (float or Tensor): standard deviation of the underlying bivariate normal (must be positive)
+        validate_args (bool, optional): Whether to validate the arguments of the distribution.
+            Default is None.
+    """
+
+    arg_constraints = {
+        "nu": dist.constraints.nonnegative,
+        "sigma": dist.constraints.positive,
+    }
+    support = torch.distributions.constraints.nonnegative
+
+    def __init__(self, nu, sigma, validate_args=None):
+        self.nu, self.sigma = torch.distributions.utils.broadcast_all(nu, sigma)
+        batch_shape = self.nu.shape
+        super().__init__(batch_shape, validate_args=validate_args)
+        self._irsample = RiceIRSample().apply
+
+    def log_prob(self, value):
+        """
+        Compute the log-probability of the given values under the Rice distribution
+
+        ```
+        Rice(x | nu, sigma) = \
+            x * sigma**-2 * exp(-0.5 * (x**2 + nu**2) * sigma**-2) * I_0(x * nu * sigma**-2)
+        ```
+
+        Args:
+            value (Tensor): The values at which to evaluate the log-probability
+
+        Returns:
+            Tensor: The log-probabilities of the given values
+        """
+        if self._validate_args:
+            self._validate_sample(value)
+        nu, sigma = self.nu, self.sigma
+        x = value
+        log_prob = (
+            torch.log(x)
+            - 2.0 * torch.log(sigma)
+            - 0.5 * torch.square((x - nu) / sigma)
+            + torch.log(torch.special.i0e(nu * x / (sigma * sigma)))
+        )
+        return log_prob
+
+    def sample(self, sample_shape=torch.Size()):
+        """
+        Generate random samples from the Rice distribution
+
+        Args:
+            sample_shape (torch.Size, optional): The shape of the samples to generate.
+                Default is an empty shape
+
+        Returns:
+            Tensor: The generated random samples
+        """
+        shape = self._extended_shape(sample_shape)
+        nu, sigma = self.nu, self.sigma
+        nu = nu.expand(shape)
+        sigma = sigma.expand(shape)
+        with torch.no_grad():
+            A = torch.normal(nu, sigma)
+            B = torch.normal(torch.zeros_like(nu), sigma)
+            z = torch.sqrt(A * A + B * B)
+        return z
+
+    @property
+    def mean(self):
+        """
+        Compute the mean of the Rice distribution
+
+        Returns:
+            Tensor: The mean of the distribution
+        """
+        sigma = self.sigma
+        nu = self.nu
+
+        x = -0.5 * torch.square(nu / sigma)
+        L = (1.0 - x) * torch.special.i0e(-0.5 * x) - x * torch.special.i1e(-0.5 * x)
+        mean = sigma * math.sqrt(math.pi / 2.0) * L
+        return mean
+
+    @property
+    def variance(self):
+        """
+        Compute the variance of the Rice distribution
+
+        Returns:
+            Tensor: The variance of the distribution
+        """
+        nu, sigma = self.nu, self.sigma
+        return 2 * sigma * sigma + nu * nu - torch.square(self.mean)
+
+    def cdf(self, value):
+        """
+        Args:
+            value (Tensor): The values at which to evaluate the CDF
+
+        Returns:
+            Tensor: The CDF values at the given values
+        """
+        raise NotImplementedError("The CDF is not implemented")
+
+    def _grad_z(self, samples):
+        """
+        Return the gradient of samples from this distribution
+
+        Args:
+            samples (Tensor): samples from this distribution
+
+        Returns:
+            dnu: gradient with respect to the loc parameter, nu
+            dsigma: gradient with respect to the underlying normal's scale parameter, sigma
+        """
+        z = samples
+        nu, sigma = self.nu, self.sigma
+        ab = z * nu / (sigma * sigma)
+        dnu = torch.special.i1e(ab) / torch.special.i0e(ab)  # == i1(ab) / i0(ab)
+        dsigma = (z - nu * dnu) / sigma
+        return dnu, dsigma
+
+    def pdf(self, value):
+        return torch.exp(self.log_prob(value))
+
+    def rsample(self, sample_shape=torch.Size()):
+        """
+        Generate differentiable random samples from the Rice distribution.
+        Gradients are implemented using implicit reparameterization (https://arxiv.org/abs/1805.08498).
+
+        Args:
+            sample_shape (torch.Size, optional): The shape of the samples to generate.
+                Default is an empty shape
+
+        Returns:
+            Tensor: The generated random samples
+        """
+        samples = self.sample(sample_shape)
+        dnu, dsigma = self._grad_z(samples)
+        samples.requires_grad_(True)
+        return self._irsample(self.nu, self.sigma, samples, dnu, dsigma)
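For context, the closed-form gradients in `_grad_z` follow from implicit differentiation of the CDF (the approach of https://arxiv.org/abs/1805.08498): holding u = F(z; ν, σ) fixed gives dz/dν = -(∂F/∂ν) / f(z), which for the Rice distribution reduces to I₁(zν/σ²) / I₀(zν/σ²). Below is a quick numerical cross-check of the dnu expression against scipy's quantile function; this sketch is illustrative only, and the step size and tolerance are arbitrary choices:

```
import numpy as np
from scipy.special import i0e, i1e
from scipy.stats import rice

nu, sigma, z = 1.5, 0.7, 2.0
ab = z * nu / sigma**2
dnu_closed = i1e(ab) / i0e(ab)  # the dnu term returned by Rice._grad_z

# Differentiate z(nu) at fixed u = F(z; nu, sigma) through the quantile function.
u = rice.cdf(z, nu / sigma, scale=sigma)
eps = 1e-5
z_hi = rice.ppf(u, (nu + eps) / sigma, scale=sigma)
z_lo = rice.ppf(u, (nu - eps) / sigma, scale=sigma)
dnu_numeric = (z_hi - z_lo) / (2 * eps)

assert np.isclose(dnu_closed, dnu_numeric, rtol=1e-3)
```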
+ """ + sigma = self.sigma + nu = self.nu + + x = -0.5 * torch.square(nu / sigma) + L = (1.0 - x) * torch.special.i0e(-0.5 * x) - x * torch.special.i1e(-0.5 * x) + mean = sigma * math.sqrt(math.pi / 2.0) * L + return mean + + @property + def variance(self): + """ + Compute the variance of the Rice distribution + + Returns: + Tensor: The variance of the distribution + """ + nu, sigma = self.nu, self.sigma + return 2 * sigma * sigma + nu * nu - torch.square(self.mean) + + def cdf(self, value): + """ + Args: + value (Tensor): The values at which to evaluate the CDF + + Returns: + Tensor: The CDF values at the given values + """ + raise NotImplementedError("The CDF is not implemented") + + def _grad_z(self, samples): + """ + Return the gradient of samples from this distribution + + Args: + samples (Tensor): samples from this distribution + + Returns: + dnu: gradient with respect to the loc parameter, nu + dsigma: gradient with respect to the underlying normal's scale parameter, sigma + """ + z = samples + nu, sigma = self.nu, self.sigma + ab = z * nu / (sigma * sigma) + dnu = torch.special.i1e(ab) / torch.special.i0e(ab) # == i1(ab)/i0(ab) + dsigma = (z - nu * dnu) / sigma + return dnu, dsigma + + def pdf(self, value): + return torch.exp(self.log_prob(value)) + + def rsample(self, sample_shape=torch.Size()): + """ + Generate differentiable random samples from the Rice distribution. + Gradients are implemented using implicit reparameterization (https://arxiv.org/abs/1805.08498). + + Args: + sample_shape (torch.Size, optional): The shape of the samples to generate. + Default is an empty shape + + Returns: + Tensor: The generated random samples + """ + samples = self.sample(sample_shape) + dnu, dsigma = self._grad_z(samples) + samples.requires_grad_(True) + return self._irsample(self.nu, self.sigma, samples, dnu, dsigma) diff --git a/tests/distributions/test_folded_normal.py b/tests/distributions/test_folded_normal.py index 5fea021..2dad315 100644 --- a/tests/distributions/test_folded_normal.py +++ b/tests/distributions/test_folded_normal.py @@ -3,11 +3,14 @@ import torch -@pytest.mark.parametrize("batch_shape", [1, 10]) +@pytest.mark.parametrize("test_float_broadcasting", [False, True]) +@pytest.mark.parametrize("batch_shape", [3, 10]) @pytest.mark.parametrize("sample_shape", [(), (10,)]) -def test_folded_normal_execution(batch_shape, sample_shape): +def test_folded_normal_execution(test_float_broadcasting, batch_shape, sample_shape): params = torch.ones((2, batch_shape), requires_grad=True) loc, scale = params + if test_float_broadcasting: + loc = 1.0 q = FoldedNormal(loc, scale) z = q.rsample(sample_shape) q.mean diff --git a/tests/distributions/test_rice.py b/tests/distributions/test_rice.py new file mode 100644 index 0000000..60192c8 --- /dev/null +++ b/tests/distributions/test_rice.py @@ -0,0 +1,72 @@ +import pytest +from rs_distributions.distributions.rice import Rice +from scipy.stats import rice +import torch +import numpy as np + + +@pytest.mark.parametrize("test_float_broadcasting", [False, True]) +@pytest.mark.parametrize("batch_shape", [3, 10]) +@pytest.mark.parametrize("sample_shape", [(), (10,)]) +def test_rice_execution(test_float_broadcasting, batch_shape, sample_shape): + params = torch.ones((2, batch_shape), requires_grad=True) + loc, scale = params + if test_float_broadcasting: + loc = 1.0 + q = Rice(loc, scale) + z = q.rsample(sample_shape) + q.mean + q.variance + # q.cdf(z) #<-- no cdf implementation + q.pdf(z) + q.log_prob(z) + torch.autograd.grad(z.sum(), params) + + +def 
diff --git a/tests/distributions/test_rice.py b/tests/distributions/test_rice.py
new file mode 100644
index 0000000..60192c8
--- /dev/null
+++ b/tests/distributions/test_rice.py
@@ -0,0 +1,74 @@
+import pytest
+from rs_distributions.distributions.rice import Rice
+from scipy.stats import rice
+import torch
+import numpy as np
+
+
+@pytest.mark.parametrize("test_float_broadcasting", [False, True])
+@pytest.mark.parametrize("batch_shape", [3, 10])
+@pytest.mark.parametrize("sample_shape", [(), (10,)])
+def test_rice_execution(test_float_broadcasting, batch_shape, sample_shape):
+    params = torch.ones((2, batch_shape), requires_grad=True)
+    loc, scale = params
+    if test_float_broadcasting:
+        loc = 1.0
+    q = Rice(loc, scale)
+    z = q.rsample(sample_shape)
+    q.mean
+    q.variance
+    # q.cdf(z)  # <-- no cdf implementation
+    q.pdf(z)
+    q.log_prob(z)
+    torch.autograd.grad(z.sum(), params)
+
+
+def test_rice_against_scipy(
+    dtype="float32", snr_cutoff=10.0, log_min=-12, log_max=2, rtol=1e-5
+):
+    """
+    Test the following attributes of Rice against scipy.stats.rice
+     - mean
+     - stddev
+     - log_prob
+     - pdf
+
+    Args:
+        dtype (string): float dtype determining the precision of the test
+        snr_cutoff (float): do not test nu/sigma combinations above this threshold
+        log_min (int): 10**log_min is the minimum value tested for nu and sigma
+        log_max (int): 10**log_max is the maximum value tested for nu and sigma
+        rtol (float): the relative tolerance for equivalency in tests
+    """
+    nu = sigma = np.logspace(log_min, log_max, log_max - log_min + 1, dtype=dtype)
+    nu, sigma = np.meshgrid(nu, sigma)
+    nu, sigma = nu.flatten(), sigma.flatten()
+    idx = nu / sigma < snr_cutoff
+    nu, sigma = nu[idx], sigma[idx]
+
+    q = Rice(
+        torch.as_tensor(nu),
+        torch.as_tensor(sigma),
+    )
+
+    mean = rice.mean(nu / sigma, scale=sigma).astype(dtype)
+    result = q.mean.detach().numpy()
+    assert np.allclose(mean, result, rtol=rtol)
+
+    stddev = rice.std(nu / sigma, scale=sigma).astype(dtype)
+    result = q.stddev.detach().numpy()
+    assert np.allclose(stddev, result, rtol=rtol)
+
+    z = np.linspace(
+        np.maximum(mean - 3.0 * stddev, 0.0),
+        mean + 3.0 * stddev,
+        10,
+    )
+
+    log_prob = rice.logpdf(z, nu / sigma, scale=sigma)
+    result = q.log_prob(torch.as_tensor(z)).detach().numpy()
+    assert np.allclose(log_prob, result, rtol=rtol)
+
+    pdf = rice.pdf(z, nu / sigma, scale=sigma)
+    result = q.pdf(torch.as_tensor(z)).detach().numpy()
+    assert np.allclose(pdf, result, rtol=rtol)
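End to end, a hypothetical sketch (not part of this PR) of what the new distribution supports: recovering Rice parameters from samples by gradient-based maximum likelihood, exercising the same `log_prob` path the tests above cover. The `softplus` reparameterization and optimizer settings are illustrative choices:

```
import torch
from rs_distributions.distributions import Rice

torch.manual_seed(0)
data = Rice(torch.tensor(2.0), torch.tensor(0.5)).sample((10_000,))

raw = torch.zeros(2, requires_grad=True)  # unconstrained parameters
opt = torch.optim.Adam([raw], lr=0.05)
for _ in range(300):
    opt.zero_grad()
    nu, sigma = torch.nn.functional.softplus(raw)  # keep nu, sigma positive
    loss = -Rice(nu, sigma).log_prob(data).mean()
    loss.backward()
    opt.step()
print(torch.nn.functional.softplus(raw))  # should approach nu=2.0, sigma=0.5
```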