diff --git a/simple_einet/layers/distributions/binomial.py b/simple_einet/layers/distributions/binomial.py index a6bf24f..96343ae 100644 --- a/simple_einet/layers/distributions/binomial.py +++ b/simple_einet/layers/distributions/binomial.py @@ -49,12 +49,12 @@ def __init__( self.logits = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions)) def _get_base_distribution(self, ctx: SamplingContext = None): - # Use sigmoid to ensure, that probs are in valid range - probs = self.logits.sigmoid() + # Cast logits to probabilities if ctx is not None and ctx.is_differentiable: + probs = logits_to_probs(self.logits, is_binary=True) return DifferentiableBinomial(probs=probs, total_count=self.total_count) else: - return dist.Binomial(probs=probs, total_count=self.total_count) + return dist.Binomial(logits=self.logits, total_count=self.total_count) class DifferentiableBinomial: