From 44e055764ca277d99a69f1fa207533d0f81c3bf7 Mon Sep 17 00:00:00 2001
From: Kevin Dalton
Date: Tue, 23 Apr 2024 15:09:29 -0400
Subject: [PATCH] simplify gradients

---
 src/rs_distributions/distributions/rice.py | 44 ++++++++--------
 1 file changed, 15 insertions(+), 29 deletions(-)

diff --git a/src/rs_distributions/distributions/rice.py b/src/rs_distributions/distributions/rice.py
index 61f1f2e..3387d2e 100644
--- a/src/rs_distributions/distributions/rice.py
+++ b/src/rs_distributions/distributions/rice.py
@@ -6,10 +6,8 @@ class RiceIRSample(torch.autograd.Function):
 
 
     @staticmethod
-    def forward(ctx, nu, sigma, samples, dnu, dsigma, dz):
-        grad_nu = -dnu / dz
-        grad_sigma = -dsigma / dz
-        ctx.save_for_backward(grad_nu, grad_sigma)
+    def forward(ctx, nu, sigma, samples, dnu, dsigma):
+        ctx.save_for_backward(dnu, dsigma)
         return samples
 
     @staticmethod
@@ -18,7 +16,7 @@ def backward(ctx, grad_output):
             grad_nu,
             grad_sigma,
         ) = ctx.saved_tensors
-        return grad_output * grad_nu, grad_output * grad_sigma, None, None, None, None
+        return grad_output * grad_nu, grad_output * grad_sigma, None, None, None
 
 
 class Rice(dist.Distribution):
@@ -168,9 +166,9 @@ def cdf(self, value):
         """
         raise NotImplementedError("The CDF is not implemented")
 
-    def grad_cdf(self, samples):
+    def _grad_z(self, samples):
         """
-        Return the gradient of the CDF
+        Return the gradient of samples from this distribution
 
         Args:
             samples (Tensor): samples from this distribution
@@ -181,32 +179,20 @@ def grad_cdf(self, samples):
         """
         z = samples
         nu, sigma = self.nu, self.sigma
+
         log_z, log_nu, log_sigma = torch.log(z), torch.log(nu), torch.log(sigma)
         log_a = log_nu - log_sigma
         log_b = log_z - log_sigma
         ab = torch.exp(log_a + log_b)  # <-- argument of bessel functions
+        log_i0 = self._log_bessel_i0(ab)
+        log_i1 = self._log_bessel_i1(ab)
 
-        # dQ = b*exp(-0.5*(a*a + b*b)) (shared term)
-        # log_dQ = log(b) -0.5*(a*a + b*b)
-        # da = -dQ * I_1(a*b)
-        # -log_da = log(dQ) + log_I1(a*b)
-        # da = dQ * I_1(a*b)
-        # log_db = log_dQ + log_I0(a*b)
-        log_dQ = log_b - 0.5 * (torch.exp(2.0 * log_a) + torch.exp(2.0 * log_b))
-        log_da = log_dQ + self._log_bessel_i1(ab)
-        log_db = log_dQ + self._log_bessel_i0(ab)
-
-        dz = torch.exp(log_db - log_sigma)
-        dnu = -torch.exp(log_da - log_sigma)
-
-        # Remember the sign of log_a is -1:
-        # dsigma = -da * nu * sigma**-2 - db * nu * sigma**-2
-        #        = -log_a_sign * exp(da + log_a - 2*log_sigma) - db * nu * sigma**-2
-        #        = exp(da + log_a - 2*log_sigma) - db * nu * sigma**-2
-        dsigma = torch.exp(log_da + log_nu - 2 * log_sigma) - torch.exp(
-            log_db + log_z - 2 * log_sigma
+        dnu = torch.exp(log_i1 - log_i0)
+        dsigma = -(
+            torch.exp(log_nu - log_sigma + log_i1 - log_i0)
+            - torch.exp(log_z - log_sigma)
         )
-        return dnu, dsigma, dz
+        return dnu, dsigma
 
     def pdf(self, value):
         return torch.exp(self.log_prob(value))
@@ -224,6 +210,6 @@ def rsample(self, sample_shape=torch.Size()):
            Tensor: The generated random samples
         """
         samples = self.sample(sample_shape)
-        dnu, dsigma, dz = self.grad_cdf(samples)
+        dnu, dsigma = self._grad_z(samples)
         samples.requires_grad_(True)
-        return self._irsample(self.nu, self.sigma, samples, dnu, dsigma, dz)
+        return self._irsample(self.nu, self.sigma, samples, dnu, dsigma)
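Note (reviewer aside, not part of the commit): the simplified quantities are the
implicit-reparameterization gradients of the sample itself, dz/dtheta =
-(dF/dtheta)/(dF/dz), for the Rice CDF F(z; nu, sigma) = 1 - Q1(nu/sigma, z/sigma).
Working the Marcum-Q derivatives through, dz/dnu = I1(ab)/I0(ab) and
dz/dsigma = z/sigma - (nu/sigma) * I1(ab)/I0(ab), with ab = nu*z/sigma**2, which
is what _grad_z now returns directly; the old -dnu/dz division in forward() is
folded in, so RiceIRSample only stores the two ratios. Below is a minimal
finite-difference sketch of that identity, assuming scipy's Rice
parameterization (shape b = nu/sigma, scale = sigma); nu, sigma, u, and eps are
arbitrary test values, not anything from the patch.

from scipy.stats import rice
from scipy.special import i0e, i1e

nu, sigma, u, eps = 1.3, 0.7, 0.42, 1e-6

# z = F^{-1}(u; nu, sigma); holding u fixed mimics implicit reparameterization
z = rice.ppf(u, nu / sigma, scale=sigma)

# Closed forms matching the patch; the i1e/i0e ratio cancels the exponential
# scaling, mirroring the log-space subtraction log_i1 - log_i0 in _grad_z
ab = nu * z / sigma**2
ratio = i1e(ab) / i0e(ab)                   # I1(ab) / I0(ab)
dnu = ratio                                 # dz/dnu
dsigma = z / sigma - (nu / sigma) * ratio   # dz/dsigma

# Finite-difference reference: re-invert the CDF at perturbed parameters
dnu_fd = (rice.ppf(u, (nu + eps) / sigma, scale=sigma) - z) / eps
dsigma_fd = (rice.ppf(u, nu / (sigma + eps), scale=sigma + eps) - z) / eps

print(dnu, dnu_fd)        # should agree to ~1e-5
print(dsigma, dsigma_fd)  # should agree to ~1e-5

The same comparison can be run against Rice._grad_z by passing torch tensors;
note the sign convention now follows dz/dtheta (dnu is positive here), unlike
the pre-patch grad_cdf outputs, which carried the CDF-partial signs.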