Skip to content

Commit

Permalink
Fix bernoulli dist
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Mar 12, 2024
1 parent 93728b3 commit 593aad8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions simple_einet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

from simple_einet.layers.distributions.binomial import Binomial
from simple_einet.layers.distributions.bernoulli import Bernoulli
from simple_einet.layers.distributions.categorical import Categorical
from simple_einet.layers.distributions.multivariate_normal import MultivariateNormal
from simple_einet.layers.distributions.normal import Normal, RatNormal
Expand Down Expand Up @@ -525,6 +526,7 @@ class Dist(str, Enum):
NORMAL_RAT = "normal_rat"
BINOMIAL = "binomial"
CATEGORICAL = "categorical"
BERNOULLI = "bernoulli"


def get_distribution(dist: Dist, cfg):
Expand Down Expand Up @@ -554,6 +556,9 @@ def get_distribution(dist: Dist, cfg):
elif dist == Dist.MULTIVARIATE_NORMAL:
leaf_type = MultivariateNormal
leaf_kwargs = {"cardinality": cfg.multivariate_cardinality}
elif dist == Dist.BERNOULLI:
leaf_type = Bernoulli
leaf_kwargs = {}
else:
raise ValueError(f"Unknown distribution ({dist}).")
return leaf_kwargs, leaf_type
3 changes: 2 additions & 1 deletion simple_einet/layers/distributions/bernoulli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from simple_einet.sampling_utils import SamplingContext
from torch import distributions as dist
from torch import nn

Expand Down Expand Up @@ -26,6 +27,6 @@ def __init__(self, num_features: int, num_channels: int, num_leaves: int, num_re
# Create bernoulli parameters
self.probs = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))

def _get_base_distribution(self):
def _get_base_distribution(self, ctx: SamplingContext = None):
# Use sigmoid to ensure, that probs are in valid range
return dist.Bernoulli(probs=torch.sigmoid(self.probs))

0 comments on commit 593aad8

Please sign in to comment.