Skip to content

Commit

Permalink
Fix mixing layer when sampling/mpe with cls>1
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Dec 5, 2023
1 parent e8628bc commit 051e606
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
28 changes: 22 additions & 6 deletions simple_einet/einet.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,15 +357,28 @@ def sample(
torch.Tensor: Samples generated according to the distribution specified by the SPN.
"""
assert class_index is None or evidence is None, "Cannot provide both, evidence and class indices."
class_is_given = class_index is not None
evidence_is_given = evidence is not None
is_multiclass = self.config.num_classes > 1

assert not (class_is_given and evidence_is_given), "Cannot provide both, evidence and class indices."
assert (
num_samples is None or evidence is None
num_samples is None or not evidence_is_given
), "Cannot provide both, number of samples to generate (num_samples) and evidence."
assert ((class_index is not None) and (self.config.num_classes > 1)) or (
(class_index is None) and (self.config.num_classes == 1)
)

# Check if evidence contains nans
if num_samples is not None:
assert num_samples > 0, "Number of samples must be > 0."

# if not is_mpe:
# assert ((class_index is not None) and (self.config.num_classes > 1)) or (
# (class_index is None) and (self.config.num_classes == 1)
# ), "Class index must be given if the number of classes is > 1 or must be none if the number of classes is 1."

if class_is_given:
assert (
self.config.num_classes > 1
), f"Class indices are only supported when the number of classes for this model is > 1."

if evidence is not None:
# Set n to the number of samples in the evidence
num_samples = evidence.shape[0]
Expand Down Expand Up @@ -423,6 +436,9 @@ def sample(
indices.requires_grad_(True) # Enable gradients

ctx.indices_out = indices
else:
# Sample class
ctx = self._class_sampling_root.sample(ctx=ctx)

# Save parent indices that were sampled from the sampling root
if self.config.num_repetitions > 1:
Expand Down
12 changes: 6 additions & 6 deletions simple_einet/layers/mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def _sample_from_weights(self, ctx: SamplingContext, log_weights: Tensor):
def _condition_weights_on_evidence(self, ctx: SamplingContext, log_weights: Tensor):
lls = self._input_cache["in"]

# Index repetition
# Index lls at correct repetitions
if ctx.is_differentiable:
r_idxs = ctx.indices_repetition.view(ctx.num_samples, 1, 1, 1)
lls = index_one_hot(lls, index=r_idxs, dim=2)
p_idxs = ctx.indices_out.view(ctx.num_samples, 1, self.num_sums_out, 1)
lls = index_one_hot(lls, index=p_idxs, dim=2)
else:
r_idxs = ctx.indices_repetition[..., None, None, None]
r_idxs = r_idxs.expand(-1, 1, self.num_sums_out, self.num_sums_in)
lls = lls.gather(dim=2, index=r_idxs).squeeze(2)
p_idxs = ctx.indices_out[..., None, None]
p_idxs = p_idxs.expand(-1, 1, 1, self.num_sums_in)
lls = lls.gather(dim=2, index=p_idxs).squeeze(2)

log_prior = log_weights
log_posterior = log_prior + lls
Expand Down

0 comments on commit 051e606

Please sign in to comment.