Skip to content

Commit 3f518fd

Browse files
committed
take value for cov from the model config in get_logits
1 parent c5726b1 commit 3f518fd

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/swag_transformers/base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,15 @@ class SampleLogitsMixin:
165165
"""Mixin class for classification models providing get_logits() method using SWAG"""
166166

167167
def get_logits(
168-
self, *args, num_predictions=None, scale=1.0, cov=True, block=False, **kwargs
168+
self, *args, num_predictions=None, scale=1.0, cov=None, block=False, **kwargs
169169
):
170170
"""Sample model parameters num_predictions times and get logits for the input
171171
172172
Results in a tensor of size batch_size x num_predictions x output_size.
173173
174174
"""
175+
if cov is None:
176+
cov = not self.config.no_cov_mat
175177
if num_predictions is None:
176178
sample = False
177179
num_predictions = 1

0 commit comments

Comments
 (0)