Skip to content

Commit

Permalink
take value for cov from the model config in get_logits
Browse files Browse the repository at this point in the history
  • Loading branch information
svirpioj committed Sep 4, 2024
1 parent c5726b1 commit 3f518fd
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/swag_transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,15 @@ class SampleLogitsMixin:
"""Mixin class for classification models providing get_logits() method using SWAG"""

def get_logits(
self, *args, num_predictions=None, scale=1.0, cov=True, block=False, **kwargs
self, *args, num_predictions=None, scale=1.0, cov=None, block=False, **kwargs
):
"""Sample model parameters num_predictions times and get logits for the input
Results in a tensor of size batch_size x num_predictions x output_size.
"""
if cov is None:
cov = not self.config.no_cov_mat
if num_predictions is None:
sample = False
num_predictions = 1
Expand Down

0 comments on commit 3f518fd

Please sign in to comment.