From 3f518fd20a92cd251029cd34750feadcb3c84f91 Mon Sep 17 00:00:00 2001 From: Sami Virpioja Date: Wed, 4 Sep 2024 13:35:44 +0300 Subject: [PATCH] take value for cov from the model config in get_logits --- src/swag_transformers/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/swag_transformers/base.py b/src/swag_transformers/base.py index d4c3115..375cc84 100644 --- a/src/swag_transformers/base.py +++ b/src/swag_transformers/base.py @@ -165,13 +165,15 @@ class SampleLogitsMixin: """Mixin class for classification models providing get_logits() method using SWAG""" def get_logits( - self, *args, num_predictions=None, scale=1.0, cov=True, block=False, **kwargs + self, *args, num_predictions=None, scale=1.0, cov=None, block=False, **kwargs ): """Sample model parameters num_predictions times and get logits for the input Results in a tensor of size batch_size x num_predictions x output_size. """ + if cov is None: + cov = not self.config.no_cov_mat if num_predictions is None: sample = False num_predictions = 1