Skip to content

Commit

Permalink
add DistributionModule to README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Dalton committed Apr 16, 2024
1 parent f0cc884 commit ad15933
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

- [Installation](#installation)
- [Distributions](#distributions)
- [Modules](#modules)
- [License](#license)

rs-distributions provides statistical tools which are helpful for structural biologists who wish to model their data using variational inference.
Expand Down Expand Up @@ -71,6 +72,35 @@ for i in range(steps):
```
This example uses the folded normal distribution which is important in X-ray crystallography.

## Modules
Working with PyTorch distributions can be a little verbose.
So in addition to the `torch.distributions` style implementation, we provide `DistributionModule` classes which enable learnable distributions with automatic bijections in less code.
These `DistributionModule` classes are subclasses of `torch.nn.Module`.
They automatically instantiate problem parameters as `TransformedParameter` modules following the constraints in the distribution definition.
In the following example, a `FoldedNormal` `DistributionModule` is instantiated with an initial location and scale and trained to match a target distribution.

```python
from rs_distributions import modules as rsm
import torch

loc_init = 10.
scale_init = 5.

q = rsm.FoldedNormal(loc_init, scale_init)
p = torch.distributions.HalfNormal(1.)

opt = torch.optim.Adam(q.parameters())

steps = 10_000
num_samples = 256
for i in range(steps):
opt.zero_grad()
z = q.rsample((num_samples,))
kl = (q.log_prob(z) - p.log_prob(z)).mean()
kl.backward()
opt.step()
```

## License

`rs-distributions` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.

0 comments on commit ad15933

Please sign in to comment.