From c98eba829ad8ff9ff36496d67849828762357290 Mon Sep 17 00:00:00 2001 From: Kevin Dalton Date: Thu, 25 Apr 2024 14:31:09 -0400 Subject: [PATCH] fmt --- src/rs_distributions/distributions/rice.py | 22 +++++++++----------- tests/distributions/test_rice.py | 24 ++++++++++++---------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/rs_distributions/distributions/rice.py b/src/rs_distributions/distributions/rice.py index f03db07..73adbd0 100644 --- a/src/rs_distributions/distributions/rice.py +++ b/src/rs_distributions/distributions/rice.py @@ -1,4 +1,3 @@ -import numpy as np import torch import math from torch import distributions as dist @@ -75,10 +74,11 @@ def log_prob(self, value): self._validate_sample(value) nu, sigma = self.nu, self.sigma x = value - log_prob = \ - torch.log(x) - 2.*torch.log(sigma) - \ - 0.5 * torch.square((x-nu)/sigma) + torch.log( - torch.special.i0e(nu * x / (sigma*sigma)) + log_prob = ( + torch.log(x) + - 2.0 * torch.log(sigma) + - 0.5 * torch.square((x - nu) / sigma) + + torch.log(torch.special.i0e(nu * x / (sigma * sigma))) ) return log_prob @@ -115,7 +115,7 @@ def mean(self): nu = self.nu x = -0.5 * torch.square(nu / sigma) - L = (1. - x) * torch.special.i0e(-0.5*x) - x * torch.special.i1e(-0.5 * x) + L = (1.0 - x) * torch.special.i0e(-0.5 * x) - x * torch.special.i1e(-0.5 * x) mean = sigma * math.sqrt(math.pi / 2.0) * L return mean @@ -127,10 +127,8 @@ def variance(self): Returns: Tensor: The variance of the distribution """ - nu,sigma = self.nu,self.sigma - n2 = nu * nu - sigma = self.sigma - return 2*sigma*sigma + nu*nu - torch.square(self.mean) + nu, sigma = self.nu, self.sigma + return 2 * sigma * sigma + nu * nu - torch.square(self.mean) def cdf(self, value): """ @@ -156,8 +154,8 @@ def _grad_z(self, samples): z = samples nu, sigma = self.nu, self.sigma ab = z * nu / (sigma * sigma) - dnu = torch.special.i1e(ab) / torch.special.i0e(ab) #== i1(ab)/i0(ab) - dsigma = (z - nu * dnu)/sigma + dnu = torch.special.i1e(ab) / torch.special.i0e(ab) # == i1(ab)/i0(ab) + dsigma = (z - nu * dnu) / sigma return dnu, dsigma def pdf(self, value): diff --git a/tests/distributions/test_rice.py b/tests/distributions/test_rice.py index c6e75ad..60192c8 100644 --- a/tests/distributions/test_rice.py +++ b/tests/distributions/test_rice.py @@ -22,7 +22,10 @@ def test_rice_execution(test_float_broadcasting, batch_shape, sample_shape): q.log_prob(z) torch.autograd.grad(z.sum(), params) -def test_rice_against_scipy(dtype='float32', snr_cutoff=10., log_min=-12, log_max=2, rtol=1e-5): + +def test_rice_against_scipy( + dtype="float32", snr_cutoff=10.0, log_min=-12, log_max=2, rtol=1e-5 +): """ Test the following attributes of Rice against scipy.stats.rice - mean @@ -36,13 +39,13 @@ def test_rice_against_scipy(dtype='float32', snr_cutoff=10., log_min=-12, log_ma """ log_min, log_max = -12, 2 nu = sigma = np.logspace(log_min, log_max, log_max - log_min + 1, dtype=dtype) - nu,sigma = np.meshgrid(nu, sigma) - nu,sigma = nu.flatten(),sigma.flatten() + nu, sigma = np.meshgrid(nu, sigma) + nu, sigma = nu.flatten(), sigma.flatten() idx = nu / sigma < snr_cutoff - nu,sigma = nu[idx],sigma[idx] + nu, sigma = nu[idx], sigma[idx] q = Rice( - torch.as_tensor(nu), + torch.as_tensor(nu), torch.as_tensor(sigma), ) @@ -55,16 +58,15 @@ def test_rice_against_scipy(dtype='float32', snr_cutoff=10., log_min=-12, log_ma assert np.allclose(stddev, result, rtol=rtol) z = np.linspace( - np.maximum(mean - 3.*stddev, 0.), - mean + 3.*stddev, + np.maximum(mean - 3.0 * stddev, 0.0), + mean + 3.0 * stddev, 10, ) - log_prob = rice.logpdf(z, nu/sigma, scale=sigma) + log_prob = rice.logpdf(z, nu / sigma, scale=sigma) result = q.log_prob(torch.as_tensor(z)).detach().numpy() assert np.allclose(log_prob, result, rtol=rtol) - - pdf = rice.pdf(z, nu/sigma, scale=sigma) + + pdf = rice.pdf(z, nu / sigma, scale=sigma) result = q.pdf(torch.as_tensor(z)).detach().numpy() assert np.allclose(pdf, result, rtol=rtol) -