diff --git a/README.md b/README.md
index 8199363..66aa69a 100644
--- a/README.md
+++ b/README.md
@@ -40,48 +40,60 @@ pip install -e .
 
 ```python
 import torch
-from simple_einet.distributions import RatNormal
+from simple_einet.layers.distributions.normal import Normal
 from simple_einet.einet import Einet
 from simple_einet.einet import EinetConfig
 
-torch.manual_seed(0)
 
-# Input dimensions
-in_features = 4
-batchsize = 5
-out_features = 3
+if __name__ == "__main__":
+    torch.manual_seed(0)
 
-# Create input sample
-x = torch.randn(batchsize, in_features)
+    # Input dimensions
+    in_features = 4
+    batchsize = 5
 
-# Construct Einet
-einet = Einet(EinetConfig(num_features=in_features, depth=2, num_sums=2, num_channels=1, num_leaves=3, num_repetitions=3, num_classes=out_features, dropout=0.0, leaf_type=Normal))
+    # Create input sample
+    x = torch.randn(batchsize, in_features)
 
-# Compute log-likelihoods
-lls = einet(x)
-print(f"lls.shape: {lls.shape}")
-print(f"lls: \n{lls}")
+    # Construct Einet
+    cfg = EinetConfig(
+        num_features=in_features,
+        depth=2,
+        num_sums=2,
+        num_channels=1,
+        num_leaves=3,
+        num_repetitions=3,
+        num_classes=1,
+        dropout=0.0,
+        leaf_type=Normal,
+    )
+    einet = Einet(cfg)
 
-# Optimize Einet parameters (weights and leaf params)
-optim = torch.optim.Adam(einet.parameters(), lr=0.001)
+    # Compute log-likelihoods
+    lls = einet(x)
+    print(f"lls.shape: {lls.shape}")
+    print(f"lls: \n{lls}")
 
-for _ in range(1000):
-    optim.zero_grad()
+    # Optimize Einet parameters (weights and leaf params)
+    optim = torch.optim.Adam(einet.parameters(), lr=0.001)
 
-    # Forward pass: compute log-likelihoods
-    lls = einet(x)
+    for _ in range(1000):
+        optim.zero_grad()
+
+        # Forward pass: compute log-likelihoods
+        lls = einet(x)
 
-    # Backprop negative log-likelihood loss
-    nlls = -1 * lls.sum()
-    nlls.backward()
+        # Backprop negative log-likelihood loss
+        nlls = -1 * lls.sum()
+        nlls.backward()
 
-    # Update weights
-    optim.step()
+        # Update weights
+        optim.step()
 
-# Construct samples
-samples = einet.sample(2)
-print(f"samples.shape: {samples.shape}")
-print(f"samples: \n{samples}")
+    # Construct samples
+    samples = einet.sample(2)
+    print(f"samples.shape: {samples.shape}")
+    print(f"samples: \n{samples}")
 ```
 
 ## Citing EinsumNetworks
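
Note: the pre-change snippet built a multi-class model (`num_classes=out_features` with 3 classes), which the new snippet drops in favor of `num_classes=1`. Below is a minimal sketch of that multi-class setup restated against the updated import path, using only the `EinetConfig`/`Einet` calls that appear in the diff; the `(batchsize, num_classes)` output shape and the softmax step for class posteriors are assumptions based on standard probabilistic-circuit usage, not confirmed by this diff.

```python
import torch

# Updated import path from the diff above.
from simple_einet.layers.distributions.normal import Normal
from simple_einet.einet import Einet
from simple_einet.einet import EinetConfig

if __name__ == "__main__":
    torch.manual_seed(0)

    # Same dimensions as the example; 3 classes as in the pre-change snippet.
    in_features = 4
    batchsize = 5
    out_features = 3

    x = torch.randn(batchsize, in_features)

    cfg = EinetConfig(
        num_features=in_features,
        depth=2,
        num_sums=2,
        num_channels=1,
        num_leaves=3,
        num_repetitions=3,
        num_classes=out_features,  # one root per class instead of a single root
        dropout=0.0,
        leaf_type=Normal,
    )
    einet = Einet(cfg)

    # Per-class log-likelihoods; the (batchsize, out_features) shape is an
    # assumption inferred from the example's shape prints, not from docs.
    lls = einet(x)
    print(f"lls.shape: {lls.shape}")

    # Class posteriors under a uniform class prior (standard usage for
    # class-conditional circuits; assumed, not taken from this repository).
    posteriors = lls.log_softmax(dim=1).exp()
    print(f"posteriors: \n{posteriors}")
```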