Skip to content

Commit

Permalink
leaves: Fix RatNormal missing initializer
Browse files Browse the repository at this point in the history
See also: #4
  • Loading branch information
braun-steven committed Oct 30, 2023
1 parent 3b2114c commit adb3b14
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions simple_einet/layers/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from simple_einet.layers.distributions.abstract_leaf import AbstractLeaf
from simple_einet.sampling_utils import SamplingContext
from simple_einet.type_checks import check_valid


class Normal(AbstractLeaf):
Expand Down Expand Up @@ -70,6 +71,21 @@ def __init__(
min_mean (float, optional): The minimum value for the mean. Defaults to None.
max_mean (float, optional): The maximum value for the mean. Defaults to None.
"""
super().__init__(num_features, num_channels, num_leaves, num_repetitions)
# Create gaussian means and stds
self.means = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))

if min_sigma is not None and max_sigma is not None:
# Init from normal
self.stds = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))
else:
# Init uniform between 0 and 1
self.stds = nn.Parameter(torch.rand(1, num_channels, num_features, num_leaves, num_repetitions))

self.min_sigma = check_valid(min_sigma, float, 0.0, max_sigma)
self.max_sigma = check_valid(max_sigma, float, min_sigma)
self.min_mean = check_valid(min_mean, float, upper_bound=max_mean, allow_none=True)
self.max_mean = check_valid(max_mean, float, min_mean, allow_none=True)

def _get_base_distribution(self, ctx: SamplingContext = None) -> "CustomNormal":
if self.min_sigma < self.max_sigma:
Expand Down

0 comments on commit adb3b14

Please sign in to comment.