From 1518d996d47a66ee24587dee99c7dbaeacca90f3 Mon Sep 17 00:00:00 2001
From: Steven Braun
Date: Mon, 17 Jun 2024 09:37:03 +0200
Subject: [PATCH] Use torch native logits_to_probs/probs_to_logits

---
 simple_einet/layers/distributions/binomial.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/simple_einet/layers/distributions/binomial.py b/simple_einet/layers/distributions/binomial.py
index a6bf24f..96343ae 100644
--- a/simple_einet/layers/distributions/binomial.py
+++ b/simple_einet/layers/distributions/binomial.py
@@ -49,12 +49,12 @@ def __init__(
         self.logits = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))
 
     def _get_base_distribution(self, ctx: SamplingContext = None):
-        # Use sigmoid to ensure, that probs are in valid range
-        probs = self.logits.sigmoid()
+        # Cast logits to probabilities
         if ctx is not None and ctx.is_differentiable:
+            probs = logits_to_probs(self.logits, is_binary=True)
             return DifferentiableBinomial(probs=probs, total_count=self.total_count)
         else:
-            return dist.Binomial(probs=probs, total_count=self.total_count)
+            return dist.Binomial(logits=self.logits, total_count=self.total_count)
 
 
 class DifferentiableBinomial:
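
Note (not part of the patch): a minimal sketch, assuming a recent PyTorch, showing that the native conversion helpers match the sigmoid call they replace and that dist.Binomial accepts the logits parameterization directly, so the non-differentiable branch no longer needs an eager sigmoid. The tensor shape below is arbitrary and purely illustrative, standing in for (1, num_channels, num_features, num_leaves, num_repetitions).

import torch
import torch.distributions as dist
from torch.distributions.utils import logits_to_probs, probs_to_logits

# Arbitrary stand-in for the leaf parameter shape.
logits = torch.randn(1, 3, 4, 5, 2)

# logits_to_probs(..., is_binary=True) is exactly the sigmoid used before.
probs = logits_to_probs(logits, is_binary=True)
assert torch.allclose(probs, logits.sigmoid())

# Passing logits directly lets Binomial convert lazily; both forms agree.
d_logits = dist.Binomial(logits=logits, total_count=10)
d_probs = dist.Binomial(probs=probs, total_count=10)
assert torch.allclose(d_logits.mean, d_probs.mean)

# Round trip back to logits (up to float32 precision).
assert torch.allclose(probs_to_logits(probs, is_binary=True), logits, atol=1e-4)

When constructed from logits, Binomial evaluates log_prob directly in the logit parameterization and only converts to probabilities when they are actually requested, which is the practical benefit of this change.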