Skip to content

Commit

Permalink
Use torch native logits_to_probs/probs_to_logits
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Jun 17, 2024
1 parent 88bb239 commit 1518d99
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions simple_einet/layers/distributions/binomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ def __init__(
self.logits = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))

def _get_base_distribution(self, ctx: SamplingContext = None):
# Use sigmoid to ensure, that probs are in valid range
probs = self.logits.sigmoid()
# Cast logits to probabilities
if ctx is not None and ctx.is_differentiable:
probs = logits_to_probs(self.logits, is_binary=True)
return DifferentiableBinomial(probs=probs, total_count=self.total_count)
else:
return dist.Binomial(probs=probs, total_count=self.total_count)
return dist.Binomial(logits=self.logits, total_count=self.total_count)


class DifferentiableBinomial:
Expand Down

0 comments on commit 1518d99

Please sign in to comment.