diff --git a/simple_einet/layers/distributions/binomial.py b/simple_einet/layers/distributions/binomial.py index 96343ae..97a9af3 100644 --- a/simple_einet/layers/distributions/binomial.py +++ b/simple_einet/layers/distributions/binomial.py @@ -46,7 +46,9 @@ def __init__( self.total_count = check_valid(total_count, int, lower_bound=1) # Create binomial parameters as unnormalized log probabilities - self.logits = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)) + + p = 0.5 + (torch.rand(1, num_channels, num_features, num_leaves, num_repetitions) - 0.5) * 0.2 + self.logits = nn.Parameter(probs_to_logits(p, is_binary=True)) def _get_base_distribution(self, ctx: SamplingContext = None): # Cast logits to probabilities diff --git a/simple_einet/layers/distributions/categorical.py b/simple_einet/layers/distributions/categorical.py index 5f758ba..33e6b9d 100644 --- a/simple_einet/layers/distributions/categorical.py +++ b/simple_einet/layers/distributions/categorical.py @@ -1,4 +1,5 @@ import torch +from torch.distributions.utils import probs_to_logits from torch import distributions as dist from torch import nn from torch.nn import functional as F @@ -27,7 +28,8 @@ def __init__(self, num_features: int, num_channels: int, num_leaves: int, num_re super().__init__(num_features, num_channels, num_leaves, num_repetitions) # Create logits - self.logits = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions, num_bins)) + p = 0.5 + (torch.rand(1, num_channels, num_features, num_leaves, num_repetitions, num_bins) - 0.5) * 0.2 + self.logits = nn.Parameter(probs_to_logits(p)) def _get_base_distribution(self, ctx: SamplingContext = None): # Use sigmoid to ensure, that probs are in valid range