Include PyTorch code for Implicit Reparameterization
minhuanli committed Nov 23, 2023
1 parent 1b89a86 commit 8e451b8
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions _posts/2023-09-12-ImplicitReparameterizationTrick.md
@@ -10,6 +10,10 @@ comments: true

Deriving gradients through stochastic operations is a persistent headache in tasks related to Bayesian inference and training generative models. The reparameterization trick has come to our rescue in many cases involving continuous random variables, such as the Gaussian distribution. However, many distributions that lack a location-scale parameterization or a tractable inverse cumulative distribution function, such as truncated, mixture, von Mises, or Dirichlet distributions, cannot be handled with classic reparameterization gradients. The authors propose an alternative, the implicit reparameterization trick, which **provides unbiased estimators for continuous distributions with numerically tractable CDFs**.


<i class='contrast'>Update @ Nov 23, 2023</i>: Attached PyTorch code demonstrating implicit reparameterization with a customized gradient.


<i class='contrast'>Reference</i>

Figurnov, Mikhail, Shakir Mohamed, and Andriy Mnih. "Implicit reparameterization gradients." Advances in Neural Information Processing Systems 31 (2018).
@@ -104,6 +108,59 @@
$$
$$
</p>

<i class='contrast'>PyTorch Implementation Example</i>

The following code implements implicit reparameterization for a Gaussian distribution. It serves purely as a framework example, since the Gaussian already has a well-established explicit reparameterization. To apply implicit reparameterization sampling to another distribution, you need three ingredients: a cumulative distribution function (CDF) that is differentiable with respect to the distribution's parameters, a way to draw samples from the distribution, and its probability density function (PDF).
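
As a quick recap of the identity the code below relies on (univariate case): the CDF maps a sample z ~ q(z; θ) to uniform noise u = F(z; θ), and differentiating this relation with u held fixed gives

<p>
$$
\frac{\mathrm{d}z}{\mathrm{d}\theta} = -\frac{\partial F(z;\theta)/\partial \theta}{q(z;\theta)},
$$
</p>

which is exactly the `-dFdmu/q` and `-dFdsig/q` terms computed in `forward` below.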

```python
import torch


class NormalIRSample(torch.autograd.Function):
    """Implicit reparameterization sampling for a Normal distribution."""

    @staticmethod
    def forward(ctx, loc, scale, samples, dFdmu, dFdsig, q):
        # Implicit gradients: dz/dtheta = -(dF/dtheta) / q(z)
        dzdmu = -dFdmu / q
        dzdsig = -dFdsig / q
        ctx.save_for_backward(dzdmu, dzdsig)
        return samples

    @staticmethod
    def backward(ctx, grad_output):
        dzdmu, dzdsig = ctx.saved_tensors
        # Gradients w.r.t. loc and scale; the remaining inputs need no gradient
        return grad_output * dzdmu, grad_output * dzdsig, None, None, None, None


class IRNormal(torch.distributions.Normal):
    def __init__(self, *args, **kwargs):
        super(IRNormal, self).__init__(*args, **kwargs)
        self._irsample = NormalIRSample.apply

    def pdf(self, value):
        return torch.exp(self.log_prob(value))

    def irsample(self, sample_shape=torch.Size()):
        samples = self.sample(sample_shape)  # draw without tracking gradients
        F = self.cdf(samples)                # differentiable w.r.t. loc and scale
        q = self.pdf(samples)
        dFdmu = torch.autograd.grad(F, self.loc, retain_graph=True)[0]
        dFdsig = torch.autograd.grad(F, self.scale, retain_graph=True)[0]
        samples.requires_grad_(True)
        return self._irsample(self.loc, self.scale, samples, dFdmu, dFdsig, q)
```

And it works as follows:

```shell
>>> mu = torch.tensor(1.0, requires_grad=True)
>>> sig = torch.tensor(2.0, requires_grad=True)
>>> dista = IRNormal(mu, sig)
>>> z = dista.irsample()
>>> z
tensor(1.9856, grad_fn=<NormalIRSampleBackward>)
>>> z.backward()
>>> mu.grad
tensor(1.0000)
>>> sig.grad
tensor(0.4928)
```
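
As a quick sanity check (a sketch reusing the `IRNormal` class above, not part of the original post), the implicit gradients can be compared against the explicit reparameterization z = mu + sig * eps applied to the same draw:

```python
import torch

mu = torch.tensor(1.0, requires_grad=True)
sig = torch.tensor(2.0, requires_grad=True)
dist = IRNormal(mu, sig)

# Gradients from implicit reparameterization
z = dist.irsample()
z.backward()
print(mu.grad, sig.grad)   # expect 1.0 and (z - mu) / sig

# Explicit reparameterization of the same draw: z = mu + sig * eps
mu.grad, sig.grad = None, None
eps = (z.detach() - mu.detach()) / sig.detach()
z_explicit = mu + sig * eps
z_explicit.backward()
print(mu.grad, sig.grad)   # should match the implicit gradients above
```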


### <i class='contrast'>Accuracy and speed of reparameterization gradient estimators</i>

In the paper, the authors compare the implicit reparameterization estimator with two alternatives; automatic differentiation with implicit reparameterization achieves the lowest error and the highest speed.
