Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Dalton committed Apr 25, 2024
1 parent a7f5c96 commit c98eba8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
22 changes: 10 additions & 12 deletions src/rs_distributions/distributions/rice.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import numpy as np
import torch
import math
from torch import distributions as dist
Expand Down Expand Up @@ -75,10 +74,11 @@ def log_prob(self, value):
self._validate_sample(value)
nu, sigma = self.nu, self.sigma
x = value
log_prob = \
torch.log(x) - 2.*torch.log(sigma) - \
0.5 * torch.square((x-nu)/sigma) + torch.log(
torch.special.i0e(nu * x / (sigma*sigma))
log_prob = (
torch.log(x)
- 2.0 * torch.log(sigma)
- 0.5 * torch.square((x - nu) / sigma)
+ torch.log(torch.special.i0e(nu * x / (sigma * sigma)))
)
return log_prob

Expand Down Expand Up @@ -115,7 +115,7 @@ def mean(self):
nu = self.nu

x = -0.5 * torch.square(nu / sigma)
L = (1. - x) * torch.special.i0e(-0.5*x) - x * torch.special.i1e(-0.5 * x)
L = (1.0 - x) * torch.special.i0e(-0.5 * x) - x * torch.special.i1e(-0.5 * x)
mean = sigma * math.sqrt(math.pi / 2.0) * L
return mean

Expand All @@ -127,10 +127,8 @@ def variance(self):
Returns:
Tensor: The variance of the distribution
"""
nu,sigma = self.nu,self.sigma
n2 = nu * nu
sigma = self.sigma
return 2*sigma*sigma + nu*nu - torch.square(self.mean)
nu, sigma = self.nu, self.sigma
return 2 * sigma * sigma + nu * nu - torch.square(self.mean)

def cdf(self, value):
"""
Expand All @@ -156,8 +154,8 @@ def _grad_z(self, samples):
z = samples
nu, sigma = self.nu, self.sigma
ab = z * nu / (sigma * sigma)
dnu = torch.special.i1e(ab) / torch.special.i0e(ab) #== i1(ab)/i0(ab)
dsigma = (z - nu * dnu)/sigma
dnu = torch.special.i1e(ab) / torch.special.i0e(ab) # == i1(ab)/i0(ab)
dsigma = (z - nu * dnu) / sigma
return dnu, dsigma

def pdf(self, value):
Expand Down
24 changes: 13 additions & 11 deletions tests/distributions/test_rice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def test_rice_execution(test_float_broadcasting, batch_shape, sample_shape):
q.log_prob(z)
torch.autograd.grad(z.sum(), params)

def test_rice_against_scipy(dtype='float32', snr_cutoff=10., log_min=-12, log_max=2, rtol=1e-5):

def test_rice_against_scipy(
dtype="float32", snr_cutoff=10.0, log_min=-12, log_max=2, rtol=1e-5
):
"""
Test the following attributes of Rice against scipy.stats.rice
- mean
Expand All @@ -36,13 +39,13 @@ def test_rice_against_scipy(dtype='float32', snr_cutoff=10., log_min=-12, log_ma
"""
log_min, log_max = -12, 2
nu = sigma = np.logspace(log_min, log_max, log_max - log_min + 1, dtype=dtype)
nu,sigma = np.meshgrid(nu, sigma)
nu,sigma = nu.flatten(),sigma.flatten()
nu, sigma = np.meshgrid(nu, sigma)
nu, sigma = nu.flatten(), sigma.flatten()
idx = nu / sigma < snr_cutoff
nu,sigma = nu[idx],sigma[idx]
nu, sigma = nu[idx], sigma[idx]

q = Rice(
torch.as_tensor(nu),
torch.as_tensor(nu),
torch.as_tensor(sigma),
)

Expand All @@ -55,16 +58,15 @@ def test_rice_against_scipy(dtype='float32', snr_cutoff=10., log_min=-12, log_ma
assert np.allclose(stddev, result, rtol=rtol)

z = np.linspace(
np.maximum(mean - 3.*stddev, 0.),
mean + 3.*stddev,
np.maximum(mean - 3.0 * stddev, 0.0),
mean + 3.0 * stddev,
10,
)

log_prob = rice.logpdf(z, nu/sigma, scale=sigma)
log_prob = rice.logpdf(z, nu / sigma, scale=sigma)
result = q.log_prob(torch.as_tensor(z)).detach().numpy()
assert np.allclose(log_prob, result, rtol=rtol)
pdf = rice.pdf(z, nu/sigma, scale=sigma)

pdf = rice.pdf(z, nu / sigma, scale=sigma)
result = q.pdf(torch.as_tensor(z)).detach().numpy()
assert np.allclose(pdf, result, rtol=rtol)

0 comments on commit c98eba8

Please sign in to comment.