diff --git a/simple_einet/layers/distributions/normal.py b/simple_einet/layers/distributions/normal.py index 28946e5..ff34b3b 100644 --- a/simple_einet/layers/distributions/normal.py +++ b/simple_einet/layers/distributions/normal.py @@ -6,6 +6,7 @@ from simple_einet.layers.distributions.abstract_leaf import AbstractLeaf from simple_einet.sampling_utils import SamplingContext +from simple_einet.type_checks import check_valid class Normal(AbstractLeaf): @@ -70,6 +71,21 @@ def __init__( min_mean (float, optional): The minimum value for the mean. Defaults to None. max_mean (float, optional): The maximum value for the mean. Defaults to None. """ + super().__init__(num_features, num_channels, num_leaves, num_repetitions) + # Create gaussian means and stds + self.means = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)) + + if min_sigma is not None and max_sigma is not None: + # Init from normal + self.stds = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)) + else: + # Init uniform between 0 and 1 + self.stds = nn.Parameter(torch.rand(1, num_channels, num_features, num_leaves, num_repetitions)) + + self.min_sigma = check_valid(min_sigma, float, 0.0, max_sigma) + self.max_sigma = check_valid(max_sigma, float, min_sigma) + self.min_mean = check_valid(min_mean, float, upper_bound=max_mean, allow_none=True) + self.max_mean = check_valid(max_mean, float, min_mean, allow_none=True) def _get_base_distribution(self, ctx: SamplingContext = None) -> "CustomNormal": if self.min_sigma < self.max_sigma: