Skip to content

Commit

Permalink
simplify rice and add tests against scipy
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Dalton committed Apr 25, 2024
1 parent 44e0557 commit a7f5c96
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 50 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ path = "src/rs_distributions/__about__.py"
dependencies = [
"coverage[toml]>=6.5",
"pytest",
"scipy",
]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
Expand Down
66 changes: 16 additions & 50 deletions src/rs_distributions/distributions/rice.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,6 @@ def __init__(self, nu, sigma, validate_args=None):
super().__init__(batch_shape, validate_args=validate_args)
self._irsample = RiceIRSample().apply

def _log_bessel_i0(self, x):
return torch.log(torch.special.i0e(x)) + torch.abs(x)

def _log_bessel_i1(self, x):
return torch.log(torch.special.i1e(x)) + torch.abs(x)

def _laguerre_half(self, x):
return (1.0 - x) * torch.exp(
x / 2.0 + self._log_bessel_i0(-0.5 * x)
) - x * torch.exp(x / 2.0 + self._log_bessel_i1(-0.5 * x))

def log_prob(self, value):
"""
Compute the log-probability of the given values under the Rice distribution
Expand All @@ -82,20 +71,15 @@ def log_prob(self, value):
Returns:
Tensor: The log-probabilities of the given values
"""
nu, sigma = self.nu, self.sigma

if self._validate_args:
self._validate_sample(value)

log_sigma = torch.log(sigma)
nu, sigma = self.nu, self.sigma
x = value
log_x = torch.log(value)
log_nu = torch.log(nu)
i0_arg = torch.exp(log_x + log_nu - 2.0 * log_sigma)

log_prob = log_x - 2.0 * log_sigma - 0.5 * (x * x + nu * nu) / (sigma * sigma)
log_prob += self._log_bessel_i0(i0_arg)

log_prob = \
torch.log(x) - 2.*torch.log(sigma) - \
0.5 * torch.square((x-nu)/sigma) + torch.log(
torch.special.i0e(nu * x / (sigma*sigma))
)
return log_prob

def sample(self, sample_shape=torch.Size()):
Expand Down Expand Up @@ -129,11 +113,10 @@ def mean(self):
"""
sigma = self.sigma
nu = self.nu
mean = (
sigma
* math.sqrt(math.pi / 2.0)
* self._laguerre_half(-0.5 * (nu / sigma) ** 2)
)

x = -0.5 * torch.square(nu / sigma)
L = (1. - x) * torch.special.i0e(-0.5*x) - x * torch.special.i1e(-0.5 * x)
mean = sigma * math.sqrt(math.pi / 2.0) * L
return mean

@property
Expand All @@ -144,17 +127,10 @@ def variance(self):
Returns:
Tensor: The variance of the distribution
"""
nu,sigma = self.nu,self.sigma
n2 = nu * nu
sigma = self.sigma
nu = self.nu
variance = (
2 * sigma**2.0
+ nu**2.0
- 0.5
* np.pi
* sigma**2.0
* self._laguerre_half(-0.5 * (nu / sigma) ** 2) ** 2.0
)
return variance
return 2*sigma*sigma + nu*nu - torch.square(self.mean)

def cdf(self, value):
"""
Expand All @@ -179,19 +155,9 @@ def _grad_z(self, samples):
"""
z = samples
nu, sigma = self.nu, self.sigma

log_z, log_nu, log_sigma = torch.log(z), torch.log(nu), torch.log(sigma)
log_a = log_nu - log_sigma
log_b = log_z - log_sigma
ab = torch.exp(log_a + log_b) # <-- argument of bessel functions
log_i0 = self._log_bessel_i0(ab)
log_i1 = self._log_bessel_i1(ab)

dnu = torch.exp(log_i1 - log_i0)
dsigma = -(
torch.exp(log_nu - log_sigma + log_i1 - log_i0)
- torch.exp(log_z - log_sigma)
)
ab = z * nu / (sigma * sigma)
dnu = torch.special.i1e(ab) / torch.special.i0e(ab) #== i1(ab)/i0(ab)
dsigma = (z - nu * dnu)/sigma
return dnu, dsigma

def pdf(self, value):
Expand Down
49 changes: 49 additions & 0 deletions tests/distributions/test_rice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
from rs_distributions.distributions.rice import Rice
from scipy.stats import rice
import torch
import numpy as np


@pytest.mark.parametrize("test_float_broadcasting", [False, True])
Expand All @@ -19,3 +21,50 @@ def test_rice_execution(test_float_broadcasting, batch_shape, sample_shape):
q.pdf(z)
q.log_prob(z)
torch.autograd.grad(z.sum(), params)

def test_rice_against_scipy(dtype='float32', snr_cutoff=10., log_min=-12, log_max=2, rtol=1e-5):
"""
Test the following attributes of Rice against scipy.stats.rice
- mean
Args:
dtype (string) : float dtype to determine the precision of the test
snr_cutoff (float) : do not test combinations with nu/sigma values above this threshold
log_min (int) : 10**log_min will be the minimum value tested for nu and sigma
log_max (int) : 10**log_max will be the maximum value tested for nu and sigma
rtol (float) : the relative tolerance for equivalency in tests
"""
log_min, log_max = -12, 2
nu = sigma = np.logspace(log_min, log_max, log_max - log_min + 1, dtype=dtype)
nu,sigma = np.meshgrid(nu, sigma)
nu,sigma = nu.flatten(),sigma.flatten()
idx = nu / sigma < snr_cutoff
nu,sigma = nu[idx],sigma[idx]

q = Rice(
torch.as_tensor(nu),
torch.as_tensor(sigma),
)

mean = rice.mean(nu / sigma, scale=sigma).astype(dtype)
result = q.mean.detach().numpy()
assert np.allclose(mean, result, rtol=rtol)

stddev = rice.std(nu / sigma, scale=sigma).astype(dtype)
result = q.stddev.detach().numpy()
assert np.allclose(stddev, result, rtol=rtol)

z = np.linspace(
np.maximum(mean - 3.*stddev, 0.),
mean + 3.*stddev,
10,
)

log_prob = rice.logpdf(z, nu/sigma, scale=sigma)
result = q.log_prob(torch.as_tensor(z)).detach().numpy()
assert np.allclose(log_prob, result, rtol=rtol)

pdf = rice.pdf(z, nu/sigma, scale=sigma)
result = q.pdf(torch.as_tensor(z)).detach().numpy()
assert np.allclose(pdf, result, rtol=rtol)

0 comments on commit a7f5c96

Please sign in to comment.