add SwagBertForSequenceClassification
svirpioj committed Mar 20, 2024 · 1 parent 05795a8 · commit 72b50df
Showing 2 changed files with 43 additions and 23 deletions.
21 changes: 11 additions & 10 deletions src/swag_transformers/swag_bert/__init__.py
@@ -1,11 +1,10 @@
"""PyTorch SWAG wrapper for BERT"""

import inspect
import logging

import torch

from transformers import PreTrainedModel, PretrainedConfig, BertConfig, BertModel
from transformers import PreTrainedModel, PretrainedConfig, BertConfig, BertModel, BertForSequenceClassification

from swag.posteriors.swag import SWAG

@@ -25,14 +24,8 @@ def __init__(
         var_clamp: float = 1e-30,
         **kwargs
     ):
-        internal_params = list(inspect.signature(self.internal_config_class).parameters)
-        internal_kwargs = {
-            name: kwargs.pop(name)
-            for name in internal_params
-            if name in kwargs
-        }
-        super().__init__(**kwargs)
-        internal_config = self.internal_config_class(**internal_kwargs)
+        super().__init__()
+        internal_config = self.internal_config_class(**kwargs)
         self.no_cov_mat = no_cov_mat
         self.max_num_models = max_num_models
         self.var_clamp = var_clamp
@@ -65,3 +58,11 @@ class SwagBertModel(SwagBertPreTrainedModel):

     def forward(self, *args, **kwargs):
         return self.model.forward(*args, **kwargs)
+
+
+class SwagBertForSequenceClassification(SwagBertPreTrainedModel):
+
+    internal_model_class = BertForSequenceClassification
+
+    def forward(self, *args, **kwargs):
+        return self.model.forward(*args, **kwargs)
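The new wrapper follows the same pattern as SwagBertModel: it only swaps in BertForSequenceClassification as the internal model class and delegates forward to the wrapped SWAG instance. A minimal usage sketch, mirroring the untrained setup in the new test below (note that with the simplified __init__, BERT-specific kwargs such as hidden_size and num_labels are now forwarded wholesale to the internal BertConfig):

    import torch

    from swag_transformers.swag_bert import SwagBertConfig, SwagBertForSequenceClassification

    # Untrained toy model built straight from the config, as in the tests
    config = SwagBertConfig(no_cov_mat=False, hidden_size=240, num_labels=3)
    model = SwagBertForSequenceClassification(config)
    model.model.sample()  # draw one set of weights from the SWAG posterior
    out = model.forward(input_ids=torch.tensor([[3, 14]]))
    print(out.logits.shape)  # torch.Size([1, 3]), i.e. (batch, num_labels)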
45 changes: 32 additions & 13 deletions tests/test_swag_bert.py
@@ -1,25 +1,44 @@

+import logging
 import unittest

 import torch

-from swag_transformers.swag_bert import SwagBertConfig, SwagBertPreTrainedModel, SwagBertModel
+from transformers import AutoConfig
+
+from swag_transformers.swag_bert import *


 class TestSwagBert(unittest.TestCase):

     def test_untrained(self):
-        config = SwagBertConfig(no_cov_mat=False, hidden_size=240)
-        print(config)
-        model = SwagBertPreTrainedModel(config)
-        print(model)
-        model = SwagBertModel(config)
-        print(model)
-        out = model.forward(input_ids=torch.tensor([[3, 14]]))
-        print(out)
-        model.model.sample()
-        out = model.forward(input_ids=torch.tensor([[3, 14]]))
-        print(out)
+        hidden_size = 240
+        config = SwagBertConfig(no_cov_mat=False, hidden_size=hidden_size)
+        logging.debug(config)
+        swag_model = SwagBertPreTrainedModel(config)
+        logging.debug(swag_model)
+        swag_model = SwagBertModel(config)
+        logging.debug(swag_model)
+        out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
+        logging.debug(out)
+        swag_model.model.sample()
+        out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
+        logging.debug(out)
+        self.assertEqual(out.last_hidden_state.shape, (1, 2, hidden_size))
+
+    def test_untrained_classifier(self):
+        hidden_size = 240
+        num_labels = 3
+        config = SwagBertConfig(no_cov_mat=False, hidden_size=hidden_size, num_labels=num_labels)
+        logging.debug(config)
+        swag_model = SwagBertForSequenceClassification(config)
+        swag_model.model.sample()
+        logging.debug(swag_model)
+        logging.debug(swag_model.model.base.config)
+        out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
+        logging.debug(out)
+        self.assertEqual(out.logits.shape, (1, num_labels))


 if __name__ == "__main__":
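The new test can be exercised on its own with something like python -m unittest tests.test_swag_bert -v from the repository root (assuming the package and its swag dependency are importable).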
