From 051e606e06cf9708366a4ca3526257dc32824367 Mon Sep 17 00:00:00 2001 From: Steven Braun Date: Tue, 5 Dec 2023 09:28:40 +0100 Subject: [PATCH] Fix mixing layer when sampling/mpe with cls>1 --- simple_einet/einet.py | 28 ++++++++++++++++++++++------ simple_einet/layers/mixing.py | 12 ++++++------ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/simple_einet/einet.py b/simple_einet/einet.py index a4f502d..53ab358 100644 --- a/simple_einet/einet.py +++ b/simple_einet/einet.py @@ -357,15 +357,28 @@ def sample( torch.Tensor: Samples generated according to the distribution specified by the SPN. """ - assert class_index is None or evidence is None, "Cannot provide both, evidence and class indices." + class_is_given = class_index is not None + evidence_is_given = evidence is not None + is_multiclass = self.config.num_classes > 1 + + assert not (class_is_given and evidence_is_given), "Cannot provide both, evidence and class indices." assert ( - num_samples is None or evidence is None + num_samples is None or not evidence_is_given ), "Cannot provide both, number of samples to generate (num_samples) and evidence." - assert ((class_index is not None) and (self.config.num_classes > 1)) or ( - (class_index is None) and (self.config.num_classes == 1) - ) - # Check if evidence contains nans + if num_samples is not None: + assert num_samples > 0, "Number of samples must be > 0." + + # if not is_mpe: + # assert ((class_index is not None) and (self.config.num_classes > 1)) or ( + # (class_index is None) and (self.config.num_classes == 1) + # ), "Class index must be given if the number of classes is > 1 or must be none if the number of classes is 1." + + if class_is_given: + assert ( + self.config.num_classes > 1 + ), f"Class indices are only supported when the number of classes for this model is > 1." + if evidence is not None: # Set n to the number of samples in the evidence num_samples = evidence.shape[0] @@ -423,6 +436,9 @@ def sample( indices.requires_grad_(True) # Enable gradients ctx.indices_out = indices + else: + # Sample class + ctx = self._class_sampling_root.sample(ctx=ctx) # Save parent indices that were sampled from the sampling root if self.config.num_repetitions > 1: diff --git a/simple_einet/layers/mixing.py b/simple_einet/layers/mixing.py index 8d44e4d..ac244a2 100644 --- a/simple_einet/layers/mixing.py +++ b/simple_einet/layers/mixing.py @@ -81,14 +81,14 @@ def _sample_from_weights(self, ctx: SamplingContext, log_weights: Tensor): def _condition_weights_on_evidence(self, ctx: SamplingContext, log_weights: Tensor): lls = self._input_cache["in"] - # Index repetition + # Index lls at correct repetitions if ctx.is_differentiable: - r_idxs = ctx.indices_repetition.view(ctx.num_samples, 1, 1, 1) - lls = index_one_hot(lls, index=r_idxs, dim=2) + p_idxs = ctx.indices_out.view(ctx.num_samples, 1, self.num_sums_out, 1) + lls = index_one_hot(lls, index=p_idxs, dim=2) else: - r_idxs = ctx.indices_repetition[..., None, None, None] - r_idxs = r_idxs.expand(-1, 1, self.num_sums_out, self.num_sums_in) - lls = lls.gather(dim=2, index=r_idxs).squeeze(2) + p_idxs = ctx.indices_out[..., None, None] + p_idxs = p_idxs.expand(-1, 1, 1, self.num_sums_in) + lls = lls.gather(dim=2, index=p_idxs).squeeze(2) log_prior = log_weights log_posterior = log_prior + lls