diff --git a/simple_einet/layers/distributions/binomial.py b/simple_einet/layers/distributions/binomial.py index 1d50da3..a6bf24f 100644 --- a/simple_einet/layers/distributions/binomial.py +++ b/simple_einet/layers/distributions/binomial.py @@ -3,6 +3,7 @@ import numpy as np import torch from torch import distributions as dist +from torch.distributions.utils import probs_to_logits, logits_to_probs from torch import nn from simple_einet.layers.distributions.abstract_leaf import ( @@ -167,12 +168,11 @@ def __init__( self.cond_fn = cond_fn self.cond_idxs = cond_idxs - self.probs_conditioned_base = nn.Parameter( - 0.5 + torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) * 0.1 - ) - self.probs_unconditioned = nn.Parameter( - 0.5 + torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) * 0.1 - ) + p = 0.5 + (torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) - 0.5) * 0.2 + self.logits_conditioned_base = nn.Parameter(probs_to_logits(p, is_binary=True)) + + p = 0.5 + (torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) - 0.5) * 0.2 + self.logits_unconditioned = nn.Parameter(probs_to_logits(p, is_binary=True)) def get_conditioned_distribution(self, x_cond: torch.Tensor): """ @@ -190,22 +190,22 @@ def get_conditioned_distribution(self, x_cond: torch.Tensor): x_cond_shape = x_cond.shape # Get conditioned parameters - probs_cond = self.cond_fn(x_cond.view(-1, x_cond.shape[1], hw, hw)) - probs_cond = probs_cond.view( + logits_cond = self.cond_fn(x_cond.view(-1, x_cond.shape[1], hw, hw)) + logits_cond = logits_cond.view( x_cond_shape[0], x_cond_shape[1], self.num_leaves, self.num_repetitions, hw * hw, ) - probs_cond = probs_cond.permute(0, 1, 4, 2, 3) + logits_cond = logits_cond.permute(0, 1, 4, 2, 3) - # Add conditioned parameters to default parameters - probs_cond = self.probs_conditioned_base + probs_cond + # Add conditioned parameters as "correction" to default parameters + logits_cond = self.logits_conditioned_base + logits_cond - probs_unc = self.probs_unconditioned.expand(x_cond.shape[0], -1, -1, -1, -1) - probs = torch.cat((probs_cond, probs_unc), dim=2) - d = dist.Binomial(self.total_count, logits=probs) + logits_unc = self.logits_unconditioned.expand(x_cond.shape[0], -1, -1, -1, -1) + logits = torch.cat((logits_cond, logits_unc), dim=2) + d = dist.Binomial(self.total_count, logits=logits) return d def forward(self, x, marginalized_scopes: List[int]):