Commit 44e0557
simplify gradients
Kevin Dalton committed Apr 23, 2024
1 parent 5364733 commit 44e0557
Showing 1 changed file with 15 additions and 29 deletions.
src/rs_distributions/distributions/rice.py (44 changes: 15 additions & 29 deletions)
@@ -6,10 +6,8 @@
 
 class RiceIRSample(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, nu, sigma, samples, dnu, dsigma, dz):
-        grad_nu = -dnu / dz
-        grad_sigma = -dsigma / dz
-        ctx.save_for_backward(grad_nu, grad_sigma)
+    def forward(ctx, nu, sigma, samples, dnu, dsigma):
+        ctx.save_for_backward(dnu, dsigma)
         return samples
 
     @staticmethod
@@ -18,7 +16,7 @@ def backward(ctx, grad_output):
         (
             grad_nu,
             grad_sigma,
         ) = ctx.saved_tensors
-        return grad_output * grad_nu, grad_output * grad_sigma, None, None, None, None
+        return grad_output * grad_nu, grad_output * grad_sigma, None, None, None
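RiceIRSample is a pass-through autograd.Function: forward returns the samples unchanged and stashes the precomputed per-sample gradients, and backward routes the incoming gradient through them. After this commit the division by the CDF derivative no longer happens in forward; it is folded into the dnu and dsigma values handed in. A minimal sketch of the behavior, with made-up numbers standing in for a real sample and its gradients:

```python
import torch

# Illustrative values only; real ones come from Rice.sample()
# and _grad_z() further down in this file.
nu = torch.tensor(2.0, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)
samples = torch.tensor(2.5)
dnu = torch.tensor(0.8)      # stands in for dz/dnu at this sample
dsigma = torch.tensor(0.4)   # stands in for dz/dsigma at this sample

z = RiceIRSample.apply(nu, sigma, samples, dnu, dsigma)
z.backward()
print(nu.grad, sigma.grad)   # tensor(0.8) tensor(0.4)
```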


class Rice(dist.Distribution):
@@ -168,9 +166,9 @@ def cdf(self, value):
         """
         raise NotImplementedError("The CDF is not implemented")
 
-    def grad_cdf(self, samples):
+    def _grad_z(self, samples):
         """
-        Return the gradient of the CDF
+        Return the gradient of samples from this distribution
 
         Args:
             samples (Tensor): samples from this distribution
@@ -181,32 +179,20 @@ def grad_cdf(self, samples):
         """
         z = samples
         nu, sigma = self.nu, self.sigma
 
         log_z, log_nu, log_sigma = torch.log(z), torch.log(nu), torch.log(sigma)
         log_a = log_nu - log_sigma
         log_b = log_z - log_sigma
         ab = torch.exp(log_a + log_b)  # <-- argument of bessel functions
+        log_i0 = self._log_bessel_i0(ab)
+        log_i1 = self._log_bessel_i1(ab)
 
-        # dQ = b * exp(-0.5 * (a*a + b*b))  (term shared by all the partials)
-        # log_dQ = log(b) - 0.5 * (a*a + b*b)
-        # da = -dQ * I_1(a*b), so log(-da) = log_dQ + log_I1(a*b)
-        # db =  dQ * I_0(a*b), so log(db)  = log_dQ + log_I0(a*b)
-        log_dQ = log_b - 0.5 * (torch.exp(2.0 * log_a) + torch.exp(2.0 * log_b))
-        log_da = log_dQ + self._log_bessel_i1(ab)
-        log_db = log_dQ + self._log_bessel_i0(ab)
-
-        dz = torch.exp(log_db - log_sigma)
-        dnu = -torch.exp(log_da - log_sigma)
-
-        # dsigma = -da * nu / sigma**2 - db * z / sigma**2
-        dsigma = torch.exp(log_da + log_nu - 2 * log_sigma) - torch.exp(
-            log_db + log_z - 2 * log_sigma
+        dnu = torch.exp(log_i1 - log_i0)
+        dsigma = -(
+            torch.exp(log_nu - log_sigma + log_i1 - log_i0)
+            - torch.exp(log_z - log_sigma)
         )
-        return dnu, dsigma, dz
+        return dnu, dsigma
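Why the simplification works: all of the CDF partials share the factor the old code carried around as log_dQ and then divided out in RiceIRSample.forward. Writing the Rice CDF through the Marcum Q-function, the implicit-reparameterization gradients are ratios of partials, so the shared factor cancels analytically. A sketch of the algebra, reconstructed from the code rather than taken from the commit:

```latex
% Rice CDF via the Marcum Q-function, with a = \nu/\sigma, b = z/\sigma:
%   F(z; \nu, \sigma) = 1 - Q_1(a, b)
% All three partials share c = b\,e^{-(a^2+b^2)/2}:
%   \partial F/\partial z      =  (c/\sigma)\, I_0(ab)
%   \partial F/\partial \nu    = -(c/\sigma)\, I_1(ab)
%   \partial F/\partial \sigma =  (c/\sigma)\left[(\nu/\sigma) I_1(ab) - (z/\sigma) I_0(ab)\right]
\frac{dz}{d\nu} = -\frac{\partial F/\partial \nu}{\partial F/\partial z}
                = \frac{I_1(ab)}{I_0(ab)},
\qquad
\frac{dz}{d\sigma} = -\frac{\partial F/\partial \sigma}{\partial F/\partial z}
                   = -\left(\frac{\nu}{\sigma}\,\frac{I_1(ab)}{I_0(ab)} - \frac{z}{\sigma}\right)
```

These are exactly the new dnu and dsigma, evaluated in log space through _log_bessel_i0 / _log_bessel_i1 for stability. Since c cancels in both ratios, neither log_dQ nor a separate dz return value is needed any more.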

def pdf(self, value):
return torch.exp(self.log_prob(value))
@@ -224,6 +210,6 @@ def rsample(self, sample_shape=torch.Size()):
             Tensor: The generated random samples
         """
         samples = self.sample(sample_shape)
-        dnu, dsigma, dz = self.grad_cdf(samples)
+        dnu, dsigma = self._grad_z(samples)
         samples.requires_grad_(True)
-        return self._irsample(self.nu, self.sigma, samples, dnu, dsigma, dz)
+        return self._irsample(self.nu, self.sigma, samples, dnu, dsigma)
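End to end, rsample draws non-differentiable samples and attaches the implicit gradients through RiceIRSample, so a downstream loss backpropagates to the distribution parameters. A minimal usage sketch; the Rice(nu, sigma) constructor signature is an assumption based on the attribute names in this file, since the diff does not show it:

```python
import torch
from rs_distributions.distributions.rice import Rice

nu = torch.tensor(2.0, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)
q = Rice(nu, sigma)                  # constructor signature assumed, not shown in the diff

z = q.rsample(torch.Size((4096,)))   # samples carry implicit gradients via RiceIRSample
loss = (z ** 2).mean()               # any scalar function of the samples
loss.backward()                      # gradients reach nu and sigma
print(nu.grad, sigma.grad)
```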
