Stay in logit space for conditional binomial
braun-steven committed Jun 17, 2024
1 parent 29fb588 commit 88bb239
Showing 1 changed file with 14 additions and 14 deletions.
simple_einet/layers/distributions/binomial.py (28 changes: 14 additions & 14 deletions)
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
 from torch import distributions as dist
+from torch.distributions.utils import probs_to_logits, logits_to_probs
 from torch import nn
 
 from simple_einet.layers.distributions.abstract_leaf import (
@@ -167,12 +168,11 @@ def __init__(
         self.cond_fn = cond_fn
         self.cond_idxs = cond_idxs
 
-        self.probs_conditioned_base = nn.Parameter(
-            0.5 + torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) * 0.1
-        )
-        self.probs_unconditioned = nn.Parameter(
-            0.5 + torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) * 0.1
-        )
+        p = 0.5 + (torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) - 0.5) * 0.2
+        self.logits_conditioned_base = nn.Parameter(probs_to_logits(p, is_binary=True))
+
+        p = 0.5 + (torch.rand(1, num_channels, num_features // 2, num_leaves, num_repetitions) - 0.5) * 0.2
+        self.logits_unconditioned = nn.Parameter(probs_to_logits(p, is_binary=True))
 
     def get_conditioned_distribution(self, x_cond: torch.Tensor):
         """
@@ -190,22 +190,22 @@ def get_conditioned_distribution(self, x_cond: torch.Tensor):
         x_cond_shape = x_cond.shape
 
         # Get conditioned parameters
-        probs_cond = self.cond_fn(x_cond.view(-1, x_cond.shape[1], hw, hw))
-        probs_cond = probs_cond.view(
+        logits_cond = self.cond_fn(x_cond.view(-1, x_cond.shape[1], hw, hw))
+        logits_cond = logits_cond.view(
             x_cond_shape[0],
             x_cond_shape[1],
             self.num_leaves,
             self.num_repetitions,
             hw * hw,
         )
-        probs_cond = probs_cond.permute(0, 1, 4, 2, 3)
+        logits_cond = logits_cond.permute(0, 1, 4, 2, 3)
 
-        # Add conditioned parameters to default parameters
-        probs_cond = self.probs_conditioned_base + probs_cond
+        # Add conditioned parameters as "correction" to default parameters
+        logits_cond = self.logits_conditioned_base + logits_cond
 
-        probs_unc = self.probs_unconditioned.expand(x_cond.shape[0], -1, -1, -1, -1)
-        probs = torch.cat((probs_cond, probs_unc), dim=2)
-        d = dist.Binomial(self.total_count, logits=probs)
+        logits_unc = self.logits_unconditioned.expand(x_cond.shape[0], -1, -1, -1, -1)
+        logits = torch.cat((logits_cond, logits_unc), dim=2)
+        d = dist.Binomial(self.total_count, logits=logits)
         return d
 
     def forward(self, x, marginalized_scopes: List[int]):
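Not part of the commit: a minimal sketch of the reasoning behind staying in logit space. The names `p_base` and `correction` and the toy shape are illustrative assumptions; only `probs_to_logits`, `logits_to_probs`, and `dist.Binomial` come from the diff above.

# Minimal sketch (illustrative only, not part of the commit): additive corrections
# in logit space always yield a valid Binomial parameterisation.
import torch
from torch import distributions as dist
from torch.distributions.utils import probs_to_logits, logits_to_probs

# Base parameter initialised near p = 0.5 and stored as a logit, mirroring the
# commit's probs_to_logits(p, is_binary=True) initialisation.
p_base = 0.5 + (torch.rand(3) - 0.5) * 0.2             # probabilities in [0.4, 0.6]
logit_base = probs_to_logits(p_base, is_binary=True)   # log(p / (1 - p)), near 0

# A conditioning network can output corrections of arbitrary magnitude.
correction = torch.tensor([-5.0, 0.0, 5.0])

# In logit space the sum is always valid: the sigmoid maps any real number into (0, 1).
logits = logit_base + correction
probs = logits_to_probs(logits, is_binary=True)         # equivalent to torch.sigmoid(logits)
d = dist.Binomial(total_count=255, logits=logits)       # well-defined for all real logits

# The same correction applied in probability space can leave [0, 1] (e.g. -4.5 or 5.5),
# which is what made the previous probability-space parameterisation fragile.
invalid = p_base + correction

This mirrors the diff: the removed code stored probabilities yet constructed the distribution with `logits=probs`, mixing the two scales, while the new code initialises the base parameters via `probs_to_logits(p, is_binary=True)` and adds the conditioner's output as a logit-space correction before calling `dist.Binomial(self.total_count, logits=logits)`.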
