
Commit 3b2114c
docs: Fix readme example
braun-steven committed Oct 30, 2023
1 parent 5b76d1f commit 3b2114c
Showing 1 changed file with 41 additions and 29 deletions.
README.md (41 additions & 29 deletions)

````diff
@@ -40,48 +40,60 @@ pip install -e .
 
 ```python
 import torch
-from simple_einet.distributions import RatNormal
+from simple_einet.layers.distributions.normal import Normal
 from simple_einet.einet import Einet
 from simple_einet.einet import EinetConfig
 
-torch.manual_seed(0)
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
 
-# Input dimensions
-in_features = 4
-batchsize = 5
-out_features = 3
+    # Input dimensions
+    in_features = 4
+    batchsize = 5
 
-# Create input sample
-x = torch.randn(batchsize, in_features)
+    # Create input sample
+    x = torch.randn(batchsize, in_features)
 
-# Construct Einet
-einet = Einet(EinetConfig(num_features=in_features, depth=2, num_sums=2, num_channels=1, num_leaves=3, num_repetitions=3, num_classes=out_features, dropout=0.0, leaf_type=Normal))
+    # Construct Einet
+    cfg = EinetConfig(
+        num_features=in_features,
+        depth=2,
+        num_sums=2,
+        num_channels=1,
+        num_leaves=3,
+        num_repetitions=3,
+        num_classes=1,
+        dropout=0.0,
+        leaf_type=Normal,
+    )
+    einet = Einet(cfg)
 
-# Compute log-likelihoods
-lls = einet(x)
-print(f"lls.shape: {lls.shape}")
-print(f"lls: \n{lls}")
+    # Compute log-likelihoods
+    lls = einet(x)
+    print(f"lls.shape: {lls.shape}")
+    print(f"lls: \n{lls}")
 
-# Optimize Einet parameters (weights and leaf params)
-optim = torch.optim.Adam(einet.parameters(), lr=0.001)
+    # Optimize Einet parameters (weights and leaf params)
+    optim = torch.optim.Adam(einet.parameters(), lr=0.001)
 
-for _ in range(1000):
-    optim.zero_grad()
+    for _ in range(1000):
+        optim.zero_grad()
 
-    # Forward pass: compute log-likelihoods
-    lls = einet(x)
+        # Forward pass: compute log-likelihoods
+        lls = einet(x)
 
-    # Backprop negative log-likelihood loss
-    nlls = -1 * lls.sum()
-    nlls.backward()
+        # Backprop negative log-likelihood loss
+        nlls = -1 * lls.sum()
+        nlls.backward()
 
-    # Update weights
-    optim.step()
+        # Update weights
+        optim.step()
 
-# Construct samples
-samples = einet.sample(2)
-print(f"samples.shape: {samples.shape}")
-print(f"samples: \n{samples}")
+    # Construct samples
+    samples = einet.sample(2)
+    print(f"samples.shape: {samples.shape}")
+    print(f"samples: \n{samples}")
 ```
 
 ## Citing EinsumNetworks
````
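The fixed example trains the Einet generatively with `num_classes=1`. As a hypothetical extension that is not part of this commit, and assuming the forward pass returns per-class log-likelihoods of shape `(batch, num_classes)` as the `lls.shape` printout suggests, the same API could also be trained discriminatively by setting `num_classes > 1` and applying a cross-entropy loss to the class posteriors. A minimal sketch under that assumption (`y` is made-up toy data):

```python
# Hypothetical sketch, not part of this commit: the same Einet API used as a
# classifier. Assumes einet(x) returns per-class log-likelihoods of shape
# (batch, num_classes); the labels y are toy data.
import torch
import torch.nn.functional as F

from simple_einet.layers.distributions.normal import Normal
from simple_einet.einet import Einet, EinetConfig

if __name__ == "__main__":
    torch.manual_seed(0)

    cfg = EinetConfig(
        num_features=4,
        depth=2,
        num_sums=2,
        num_channels=1,
        num_leaves=3,
        num_repetitions=3,
        num_classes=3,  # one root (class-conditional density) per class
        dropout=0.0,
        leaf_type=Normal,
    )
    einet = Einet(cfg)

    x = torch.randn(5, 4)          # toy inputs
    y = torch.randint(0, 3, (5,))  # toy class labels

    optim = torch.optim.Adam(einet.parameters(), lr=0.001)
    for _ in range(1000):
        optim.zero_grad()
        lls = einet(x)  # assumed shape: (5, 3)
        # Normalize the per-class log-likelihoods over classes to get class
        # posteriors, then apply the negative log-likelihood loss on them.
        loss = F.nll_loss(F.log_softmax(lls, dim=-1), y)
        loss.backward()
        optim.step()
```

Under a uniform class prior, `log_softmax` over the class dimension turns the class-conditional log-likelihoods p(x|y) into class posteriors p(y|x), so this objective reduces to standard cross-entropy.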
