From 1523ce8a222409ea9f132db609a16958b7a69e91 Mon Sep 17 00:00:00 2001 From: Steven Braun Date: Thu, 2 Nov 2023 16:41:24 +0100 Subject: [PATCH] leaves: Fix missing self object --- simple_einet/layers/distributions/multivariate_normal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/simple_einet/layers/distributions/multivariate_normal.py b/simple_einet/layers/distributions/multivariate_normal.py index 124bfac..a7e7bd3 100644 --- a/simple_einet/layers/distributions/multivariate_normal.py +++ b/simple_einet/layers/distributions/multivariate_normal.py @@ -66,8 +66,8 @@ def scale_tril(self): def _get_base_distribution(self, ctx: SamplingContext = None, marginalized_scopes = None): # View means and scale_tril - means = self.means.view(self._num_dists, cardinality) - scale_tril = self.scale_tril.view(self._num_dists, cardinality, cardinality) + means = self.means.view(self._num_dists, self.cardinality) + scale_tril = self.scale_tril.view(self._num_dists, self.cardinality, self.cardinality) mv = CustomMultivariateNormalDist(