add initializing from pretrained BERT
svirpioj committed Mar 20, 2024
1 parent 72b50df commit 0bcfa70
Showing 2 changed files with 47 additions and 4 deletions.
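
A minimal usage sketch of the new initialization path, mirroring tests/test_swag_bert.py in this commit (the model name and label count come from the tests; the printed shape is the value the tests assert):

import torch
from transformers import AutoModelForSequenceClassification
from swag_transformers.swag_bert import SwagBertForSequenceClassification

# Load a standard pretrained BERT classifier and wrap it with the SWAG model.
base_model = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-tiny', num_labels=4)
swag_model = SwagBertForSequenceClassification.from_base(base_model)

# Sample weights from the SWAG posterior before running a forward pass.
swag_model.model.sample()
out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
print(out.logits.shape)  # torch.Size([1, 4])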
19 changes: 18 additions & 1 deletion src/swag_transformers/swag_bert/__init__.py
@@ -1,10 +1,12 @@
"""PyTorch SWAG wrapper for BERT"""

import logging
from typing import Union

import torch

from transformers import PreTrainedModel, PretrainedConfig, BertConfig, BertModel, BertForSequenceClassification
from transformers import PreTrainedModel, PretrainedConfig, BertConfig, BertModel, \
BertPreTrainedModel, BertForSequenceClassification

from swag.posteriors.swag import SWAG

@@ -31,6 +33,13 @@ def __init__(
self.var_clamp = var_clamp
self.internal_model_config = internal_config.to_dict()

@classmethod
def from_config(cls, base_config: BertConfig, **kwargs):
"""Initialize from existing BertConfig"""
config = cls(**kwargs)
config.internal_model_config = base_config.to_dict()
return config


class SwagBertPreTrainedModel(PreTrainedModel):

@@ -49,6 +58,14 @@ def __init__(self, config):
device='cpu' # FIXME: how to deal with device
)

@classmethod
def from_base(cls, base_model: BertPreTrainedModel, **kwargs):
"""Initialize from existing BertPreTrainedModel"""
config = SwagBertConfig.from_config(base_model.config, **kwargs)
swag_model = cls(config)
swag_model.model.base = base_model
return swag_model

def _init_weights(self, module):
# FIXME: What should be here?
raise NotImplementedError("TODO")
32 changes: 29 additions & 3 deletions tests/test_swag_bert.py
@@ -3,13 +3,16 @@

import torch

from transformers import AutoConfig
from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoTokenizer

from swag_transformers.swag_bert import *
from swag_transformers.swag_bert import SwagBertConfig, SwagBertModel, SwagBertPreTrainedModel, \
SwagBertForSequenceClassification


class TestSwagBert(unittest.TestCase):

pretrained_model_name = 'prajjwal1/bert-tiny'

def test_untrained(self):
hidden_size = 240
config = SwagBertConfig(no_cov_mat=False, hidden_size=hidden_size)
@@ -18,7 +21,10 @@ def test_untrained(self):
logging.debug(swag_model)
swag_model = SwagBertModel(config)
logging.debug(swag_model)
out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
with self.assertLogs(level='WARNING') as cm:
out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
# Warning from using forward before sampling parameters
self.assertTrue(any(msg.startswith('WARNING') for msg in cm.output))
logging.debug(out)
swag_model.model.sample()
out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
@@ -38,7 +44,27 @@ def test_untrained_classifier(self):
logging.debug(out)
self.assertEqual(out.logits.shape, (1, num_labels))

def test_pretrained_bert_tiny(self):
model = AutoModel.from_pretrained(self.pretrained_model_name)
hidden_size = model.config.hidden_size
config = SwagBertConfig.from_config(model.config, no_cov_mat=False)
logging.debug(config)
swag_model = SwagBertModel.from_base(model)
logging.debug(swag_model)
swag_model.model.sample()
out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
logging.debug(out)
self.assertEqual(out.last_hidden_state.shape, (1, 2, hidden_size))

def test_pretrained_bert_tiny_classifier(self):
num_labels = 4
model = AutoModelForSequenceClassification.from_pretrained(self.pretrained_model_name, num_labels=num_labels)
swag_model = SwagBertForSequenceClassification.from_base(model)
logging.debug(swag_model)
swag_model.model.sample()
out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
logging.debug(out)
self.assertEqual(out.logits.shape, (1, num_labels))


if __name__ == "__main__":