diff --git a/simple_einet/layers/factorized_leaf.py b/simple_einet/layers/factorized_leaf.py index da49d54..8db2a59 100644 --- a/simple_einet/layers/factorized_leaf.py +++ b/simple_einet/layers/factorized_leaf.py @@ -41,15 +41,25 @@ def __init__( self.num_features_out = num_features_out # Size of the factorized groups of RVs - cardinality = int(np.round(self.num_features / self.num_features_out)) + cardinality = int(np.floor(self.num_features / self.num_features_out)) + + # Construct equal group sizes, such that (sum(group_sizes) == num_features) and the are num_features_out groups + group_sizes = np.ones(self.num_features_out, dtype=int) * cardinality + rest = self.num_features - cardinality * self.num_features_out + for i in range(rest): + group_sizes[i] += 1 + np.random.shuffle(group_sizes) # Construct mapping of scopes from in_features -> out_features scopes = torch.zeros(num_features, self.num_features_out, num_repetitions) for r in range(num_repetitions): idxs = torch.randperm(n=self.num_features) + offset = 0 for o in range(num_features_out): - low = o * cardinality - high = (o + 1) * cardinality + group_size = group_sizes[o] + low = offset + high = offset + group_size + offset = high if o == num_features_out - 1: high = self.num_features scopes[idxs[low:high], o, r] = 1