From 0bcfa702234ab23abcd86f90df0e5b0aedcd1ba3 Mon Sep 17 00:00:00 2001
From: Sami Virpioja
Date: Wed, 20 Mar 2024 12:49:57 +0200
Subject: [PATCH] add initializing from pretrained BERT

---
 src/swag_transformers/swag_bert/__init__.py | 19 +++++++++++-
 tests/test_swag_bert.py                     | 32 +++++++++++++++++++--
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/src/swag_transformers/swag_bert/__init__.py b/src/swag_transformers/swag_bert/__init__.py
index 3dd177f..0fb2a1e 100644
--- a/src/swag_transformers/swag_bert/__init__.py
+++ b/src/swag_transformers/swag_bert/__init__.py
@@ -1,10 +1,12 @@
 """PyTorch SWAG wrapper for BERT"""
 
 import logging
+from typing import Union
 
 import torch
 
-from transformers import PreTrainedModel, PretrainedConfig, BertConfig, BertModel, BertForSequenceClassification
+from transformers import PreTrainedModel, PretrainedConfig, BertConfig, BertModel, \
+    BertPreTrainedModel, BertForSequenceClassification
 
 from swag.posteriors.swag import SWAG
 
@@ -31,6 +33,13 @@ def __init__(
         self.var_clamp = var_clamp
         self.internal_model_config = internal_config.to_dict()
 
+    @classmethod
+    def from_config(cls, base_config: BertConfig, **kwargs):
+        """Initialize from existing BertConfig"""
+        config = cls(**kwargs)
+        config.internal_model_config = base_config.to_dict()
+        return config
+
 
 class SwagBertPreTrainedModel(PreTrainedModel):
 
@@ -49,6 +58,14 @@ def __init__(self, config):
             device='cpu'  # FIXME: how to deal with device
         )
 
+    @classmethod
+    def from_base(cls, base_model: BertPreTrainedModel, **kwargs):
+        """Initialize from existing BertPreTrainedModel"""
+        config = SwagBertConfig.from_config(base_model.config, **kwargs)
+        swag_model = cls(config)
+        swag_model.model.base = base_model
+        return swag_model
+
     def _init_weights(self, module):
         # FIXME: What should be here?
         raise NotImplementedError("TODO")

diff --git a/tests/test_swag_bert.py b/tests/test_swag_bert.py
index 1e548a1..f35bc25 100644
--- a/tests/test_swag_bert.py
+++ b/tests/test_swag_bert.py
@@ -3,13 +3,16 @@
 
 import torch
 
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoTokenizer
 
-from swag_transformers.swag_bert import *
+from swag_transformers.swag_bert import SwagBertConfig, SwagBertModel, SwagBertPreTrainedModel, \
+    SwagBertForSequenceClassification
 
 
 class TestSwagBert(unittest.TestCase):
 
+    pretrained_model_name = 'prajjwal1/bert-tiny'
+
     def test_untrained(self):
         hidden_size = 240
         config = SwagBertConfig(no_cov_mat=False, hidden_size=hidden_size)
@@ -18,7 +21,10 @@ def test_untrained(self):
         logging.debug(swag_model)
         swag_model = SwagBertModel(config)
         logging.debug(swag_model)
-        out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
+        with self.assertLogs(level='WARNING') as cm:
+            out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
+        # Warning from using forward before sampling parameters
+        self.assertTrue(any(msg.startswith('WARNING') for msg in cm.output))
         logging.debug(out)
         swag_model.model.sample()
         out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
@@ -38,7 +44,27 @@ def test_untrained_classifier(self):
         logging.debug(out)
         self.assertEqual(out.logits.shape, (1, num_labels))
 
+    def test_pretrained_bert_tiny(self):
+        model = AutoModel.from_pretrained(self.pretrained_model_name)
+        hidden_size = model.config.hidden_size
+        config = SwagBertConfig.from_config(model.config, no_cov_mat=False)
+        logging.debug(config)
+        swag_model = SwagBertModel.from_base(model)
+        logging.debug(swag_model)
+        swag_model.model.sample()
+        out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
+        logging.debug(out)
+        self.assertEqual(out.last_hidden_state.shape, (1, 2, hidden_size))
+    def test_pretrained_bert_tiny_classifier(self):
+        num_labels = 4
+        model = AutoModelForSequenceClassification.from_pretrained(self.pretrained_model_name, num_labels=num_labels)
+        swag_model = SwagBertForSequenceClassification.from_base(model)
+        logging.debug(swag_model)
+        swag_model.model.sample()
+        out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
+        logging.debug(out)
+        self.assertEqual(out.logits.shape, (1, num_labels))
 
 if __name__ == "__main__":
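
Not part of the diff: a minimal usage sketch of the from_base() entry point this patch adds. It only mirrors the calls exercised in tests/test_swag_bert.py above (the prajjwal1/bert-tiny checkpoint, num_labels=4, and the toy input_ids), so treat it as illustrative rather than as documented API.

# Illustrative sketch, not part of the patch: wrap a pretrained classifier with SWAG
# via the new SwagBertForSequenceClassification.from_base() classmethod.
import torch
from transformers import AutoModelForSequenceClassification

from swag_transformers.swag_bert import SwagBertForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-tiny', num_labels=4)
swag_model = SwagBertForSequenceClassification.from_base(base)

# Sample parameters from the (here still untrained) SWAG posterior before forward();
# the updated test_untrained case asserts that forward() without sample() logs a warning.
swag_model.model.sample()
out = swag_model.forward(input_ids=torch.tensor([[3, 14]]))
print(out.logits.shape)  # torch.Size([1, 4]), matching the test assertion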