From 593aad8b9234ec48a50c7057666151a35387d2b6 Mon Sep 17 00:00:00 2001 From: Steven Braun Date: Tue, 12 Mar 2024 07:23:47 +0100 Subject: [PATCH] Fix bernoulli dist --- simple_einet/data.py | 5 +++++ simple_einet/layers/distributions/bernoulli.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/simple_einet/data.py b/simple_einet/data.py index 2a5d1f1..8cfc60c 100644 --- a/simple_einet/data.py +++ b/simple_einet/data.py @@ -26,6 +26,7 @@ ) from simple_einet.layers.distributions.binomial import Binomial +from simple_einet.layers.distributions.bernoulli import Bernoulli from simple_einet.layers.distributions.categorical import Categorical from simple_einet.layers.distributions.multivariate_normal import MultivariateNormal from simple_einet.layers.distributions.normal import Normal, RatNormal @@ -525,6 +526,7 @@ class Dist(str, Enum): NORMAL_RAT = "normal_rat" BINOMIAL = "binomial" CATEGORICAL = "categorical" + BERNOULLI = "bernoulli" def get_distribution(dist: Dist, cfg): @@ -554,6 +556,9 @@ def get_distribution(dist: Dist, cfg): elif dist == Dist.MULTIVARIATE_NORMAL: leaf_type = MultivariateNormal leaf_kwargs = {"cardinality": cfg.multivariate_cardinality} + elif dist == Dist.BERNOULLI: + leaf_type = Bernoulli + leaf_kwargs = {} else: raise ValueError(f"Unknown distribution ({dist}).") return leaf_kwargs, leaf_type diff --git a/simple_einet/layers/distributions/bernoulli.py b/simple_einet/layers/distributions/bernoulli.py index 7b52e05..27d4104 100644 --- a/simple_einet/layers/distributions/bernoulli.py +++ b/simple_einet/layers/distributions/bernoulli.py @@ -1,4 +1,5 @@ import torch +from simple_einet.sampling_utils import SamplingContext from torch import distributions as dist from torch import nn @@ -26,6 +27,6 @@ def __init__(self, num_features: int, num_channels: int, num_leaves: int, num_re # Create bernoulli parameters self.probs = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)) - def _get_base_distribution(self): + def _get_base_distribution(self, ctx: SamplingContext = None): # Use sigmoid to ensure, that probs are in valid range return dist.Bernoulli(probs=torch.sigmoid(self.probs))