diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..0fdb3d0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +max-line-length = 127 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..4899c06 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,43 @@ +name: Run Python Tests +on: + push: + branches: + - main + tags: + - '*' + pull_request: + branches: + - main + # schedule: + # # Run on Tuesdays at 5:59 + # - cron: '59 5 * * 2' + workflow_dispatch: +jobs: + build-n-test: + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m ensurepip --upgrade + python -m pip install --upgrade setuptools + python -m pip install --upgrade pip + python -m pip install flake8 + python -m pip install .[test] + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + python -m flake8 --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings + python -m flake8 --count --exit-zero --statistics + - name: Run tests with pytest + run: python -m pytest diff --git a/README.md b/README.md index 472b3b8..216ec61 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ convenience tools in the `transformers` library (e.g. `Pipeline` and See also [examples](./examples). -### Fine-tuning +### Fine-tuning with SWAG BERT model, sequence classification task: @@ -28,9 +28,39 @@ BERT model, sequence classification task: 5. Train the model (`trainer.train()`) 6. Store the complete model using `swag_model.save_pretrained(path)` +Note that `trainer.save_model(path)` will save only the base model without the distribution parameters from SWAG. + +For collecting the SWAG parameters, two possible schedules are supported: + +* After the end of each training epoch (default, `collect_steps = 0` for `SwagUpdateCallback`) +* After each N training steps (set `collect_steps > 0` for `SwagUpdateCallback`) + +### Sampling model parameters + +After `swag_model` is trained or fine-tuned as described above, +`swag_model.sample_parameters()` should be called to sample new model +parameters. After that, `swag_model.forward()` can be used to predict +new output from classifiers and `swag_model.generate()` to generate +new output from generative LMs. In order to get a proper distribution +of outputs, `sample_parameters()` needs to be called each time before +`forward()` or `generate()`. For classifiers, the `SampleLogitsMixin` +class provides the convenience method `get_logits()` that samples the +parameters and makes a new prediction `num_predictions` times, and +returns the logit values in a tensor. 
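+
+Putting the pieces above together, here is a minimal sketch of the
+workflow. The tiny inline dataset, the `prajjwal1/bert-tiny` checkpoint,
+the `no_cov_mat=False` option and the `SwagUpdateCallback(swag_model)`
+call are only illustrative assumptions; see [examples](./examples) for
+complete scripts.
+
+```python
+import transformers
+from datasets import Dataset
+
+from swag_transformers.swag_bert import SwagBertForSequenceClassification
+from swag_transformers.trainer_utils import SwagUpdateCallback
+
+base_name = "prajjwal1/bert-tiny"  # any BERT checkpoint works here
+tokenizer = transformers.AutoTokenizer.from_pretrained(base_name)
+model = transformers.AutoModelForSequenceClassification.from_pretrained(base_name, num_labels=2)
+# no_cov_mat=False keeps the low-rank covariance estimate, so that sampling
+# with cov=True (the get_logits() default) is possible
+swag_model = SwagBertForSequenceClassification.from_base(model, no_cov_mat=False)
+
+# Toy dataset; replace with your own tokenized dataset
+data = Dataset.from_dict({"text": ["Hello world", "Just some swaggering"], "label": [0, 1]})
+data = data.map(lambda batch: tokenizer(batch["text"], truncation=True), batched=True)
+
+trainer = transformers.Trainer(
+    model=model,
+    args=transformers.TrainingArguments(output_dir="swag_output", num_train_epochs=3),
+    train_dataset=data,
+    tokenizer=tokenizer,
+    # collect the SWAG parameters after each epoch (set collect_steps > 0 to collect every N steps)
+    callbacks=[SwagUpdateCallback(swag_model)],
+)
+trainer.train()
+swag_model.save_pretrained("swag_model")  # stores the SWAG distribution parameters as well
+
+# Sample new parameters before each prediction to get a distribution of outputs
+tokens = tokenizer(["Have a good day"], return_tensors="pt")
+swag_model.sample_parameters()
+output = swag_model(**tokens)
+
+# For classifiers, get_logits() repeats the sample + predict loop in one call;
+# for generative models, call sample_parameters() before each generate() call instead
+logits = swag_model.get_logits(**tokens, num_predictions=10)  # [batch_size, num_predictions, num_labels]
+```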
+ ### Currently supported models -* BERT +* BERT (bidirectional encoder) * `BertPreTrainedModel` -> `SwagBertPreTrainedModel` * `BertModel` -> `SwagBertModel` + * `BertLMHeadModel` -> `SwagBertLMHeadModel` * `BertForSequenceClassification` -> `SwagBertForSequenceClassification` +* BART (bidirectional encoder + causal decoder) + * `BartPreTrainedModel` -> `SwagBartPreTrainedModel` + * `BartModel` -> `SwagBartModel` + * `BartForConditionalGeneration` -> `SwagBartForConditionalGeneration` + * `BartForSequenceClassification` -> `SwagBartForSequenceClassification` +* MarianMT (bidirectional encoder + causal decoder) + * `MarianPreTrainedModel` -> `SwagMarianPreTrainedModel` + * `MarianModel` -> `SwagMarianModel` + * `MarianMTModel` -> `SwagMarianMTModel` diff --git a/examples/bert_snli.py b/examples/bert_snli.py index 3719919..bb84b86 100644 --- a/examples/bert_snli.py +++ b/examples/bert_snli.py @@ -1,8 +1,6 @@ import argparse -import collections import logging import os -import sys import torch import transformers @@ -35,7 +33,8 @@ def main(): device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = transformers.AutoTokenizer.from_pretrained(args.base_model, cache_dir=args.model_cache_dir) - model = transformers.AutoModelForSequenceClassification.from_pretrained(args.base_model, num_labels=3, cache_dir=args.model_cache_dir) + model = transformers.AutoModelForSequenceClassification.from_pretrained( + args.base_model, num_labels=3, cache_dir=args.model_cache_dir) model.to(device) swag_model = SwagBertForSequenceClassification.from_base(model) swag_model.to(device) diff --git a/examples/load_pretrained.py b/examples/load_pretrained.py index 8f2626e..aed1dbc 100644 --- a/examples/load_pretrained.py +++ b/examples/load_pretrained.py @@ -1,13 +1,9 @@ import argparse -import collections import logging -import os -import sys import transformers from swag_transformers.swag_bert import SwagBertConfig, SwagBertForSequenceClassification -from swag_transformers.trainer_utils import SwagUpdateCallback def main(): diff --git a/examples/marian_mt.py b/examples/marian_mt.py index 59830fd..687da54 100644 --- a/examples/marian_mt.py +++ b/examples/marian_mt.py @@ -1,8 +1,6 @@ import argparse -import collections import logging import os -import sys import torch import transformers @@ -50,11 +48,7 @@ def tokenize_function(example): inputs = [pair['de'] for pair in example['translation']] targets = [pair['nl'] for pair in example['translation']] model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True) - - # Setup the tokenizer for targets - with tokenizer.as_target_tokenizer(): - labels = tokenizer(targets, max_length=max_target_length, truncation=True) - + labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) model_inputs["labels"] = labels["input_ids"] return model_inputs diff --git a/setup.py b/setup.py index b963d0b..016f617 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ "swa_gaussian>=0.1.6" ], extras_require={ - "test": ["datasets", "nose", "sentencepiece"] + "test": ["datasets", "pytest", "sentencepiece"] }, classifiers=[ "Development Status :: 3 - Alpha", diff --git a/src/swag_transformers/swag_bert/__init__.py b/src/swag_transformers/base.py similarity index 61% rename from src/swag_transformers/swag_bert/__init__.py rename to src/swag_transformers/base.py index e8b963d..ad2a8d3 100644 --- a/src/swag_transformers/swag_bert/__init__.py +++ b/src/swag_transformers/base.py @@ -1,14 +1,13 @@ -"""PyTorch SWAG wrapper for BERT""" +"""SWAG 
wrapper base classes""" import copy import functools import logging -from typing import Union +from typing import Type import torch -from transformers import PreTrainedModel, PretrainedConfig, BertConfig, BertLMHeadModel, BertModel, \ - BertPreTrainedModel, BertForSequenceClassification +from transformers import PreTrainedModel, PretrainedConfig from swag.posteriors.swag import SWAG @@ -16,11 +15,18 @@ logger = logging.getLogger(__name__) -class SwagBertConfig(PretrainedConfig): - """Config for BERT model averaging with SWAG""" +class SwagConfig(PretrainedConfig): + """Base configuration class for SWAG models - model_type = 'swag_bert' - internal_config_class = BertConfig + For using this class, inherit it and define the following class + attributes: + + - model_type: string + - internal_config_class: class inherited from PretrainedConfig + + """ + + internal_config_class: Type[PretrainedConfig] = PretrainedConfig def __init__( self, @@ -41,13 +47,13 @@ def __init__( self.var_clamp = var_clamp @classmethod - def from_config(cls, base_config: BertConfig, **kwargs): - """Initialize from existing BertConfig""" + def from_config(cls, base_config: PretrainedConfig, **kwargs): + """Initialize from existing PretrainedConfig""" config = cls(**kwargs) config.internal_model_config = base_config.to_dict() return config - def update_internal_config(self, base_config: BertConfig): + def update_internal_config(self, base_config: PretrainedConfig): """Update internal config from base_config""" self.internal_model_config = base_config.to_dict() # Copy some things to the top level @@ -55,12 +61,20 @@ def __init__( self.problem_type = base_config.problem_type -class SwagBertPreTrainedModel(PreTrainedModel): - """Pretrained SWAG BERT model""" +class SwagPreTrainedModel(PreTrainedModel): + """Base class for SWAG models wrapping PreTrainedModel + + For using this class, inherit it and define the following class + attributes: + + - base_model_prefix: string + - config_class: class inherited from PretrainedConfig + - internal_model_class: class inherited from PreTrainedModel + + """ - config_class = SwagBertConfig - base_model_prefix = 'swag_bert' - internal_model_class = BertModel + config_class: Type[SwagConfig] = SwagConfig + internal_model_class: Type[PreTrainedModel] = PreTrainedModel def __init__(self, config): super().__init__(config) @@ -87,9 +101,22 @@ def new_base_model(cls, *args, **kwargs): def _init_weights(self, module): self.swag.base._init_weights(module) + def sample_parameters(self, *args, **kwargs): + """Sample new model parameters""" + self.swag.sample(*args, **kwargs) -class SwagBertModel(SwagBertPreTrainedModel): - """SWAG BERT model""" + +class SwagModel(SwagPreTrainedModel): + """Base class for SWAG models + + For using this class, inherit it and define the following class + attributes: + + - base_model_prefix: string + - config_class: class inherited from PretrainedConfig + - internal_model_class: class inherited from PreTrainedModel + + """ def __init__(self, config, base_model=None): super().__init__(config) @@ -100,6 +127,8 @@ def __init__(self, config, base_model=None): max_num_models=config.max_num_models, var_clamp=config.var_clamp ) + self.prepare_inputs_for_generation = self.swag.base.prepare_inputs_for_generation + self.generate = self.swag.base.generate @staticmethod def _base_model_copy(model, *args, **kwargs): @@ -111,20 +140,29 @@ def _base_model_copy(model, *args, **kwargs): return model @classmethod - def from_base(cls, base_model: BertPreTrainedModel, 
**kwargs): - """Initialize from existing BertPreTrainedModel""" - config = SwagBertConfig.from_config(base_model.config, **kwargs) + def from_base(cls, base_model: PreTrainedModel, **kwargs): + """Initialize from existing PreTrainedModel""" + config = cls.config_class.from_config(base_model.config, **kwargs) swag_model = cls(config, base_model=base_model) return swag_model def forward(self, *args, **kwargs): + """Call forward pass from the base model""" return self.swag.forward(*args, **kwargs) + @classmethod + def can_generate(cls) -> bool: + return cls.internal_model_class.can_generate() -class SwagBertForSequenceClassification(SwagBertModel): - """SWAG BERT model for sequence classification""" + def prepare_inputs_for_generation(self, *args, **kwargs): + return self.swag.base.prepare_inputs_for_generation(*args, **kwargs) + + def generate(self, *args, **kwargs): + return self.swag.base.generate(*args, **kwargs) - internal_model_class = BertForSequenceClassification + +class SampleLogitsMixin: + """Mixin class for classification models providing get_logits() method using SWAG""" def get_logits( self, *args, num_predictions=None, scale=1.0, cov=True, block=False, **kwargs @@ -142,17 +180,8 @@ def get_logits( logits = [] for _ in range(num_predictions): if sample: - self.swag.sample(scale=scale, cov=cov, block=block) + self.sample_parameters(scale=scale, cov=cov, block=block) out = self.forward(*args, **kwargs) logits.append(out.logits) logits = torch.permute(torch.stack(logits), (1, 0, 2)) # [batch_size, num_predictions, output_size] return logits - - -class SwagBertLMHeadModel(SwagBertModel): - """SWAG BERT model with LM head""" - - internal_model_class = BertLMHeadModel - - def prepare_inputs_for_generation(self, *args, **kwargs): - return self.swag.base.prepare_inputs_for_generation(*args, **kwargs) diff --git a/src/swag_transformers/swag_bart.py b/src/swag_transformers/swag_bart.py new file mode 100644 index 0000000..3decc55 --- /dev/null +++ b/src/swag_transformers/swag_bart.py @@ -0,0 +1,49 @@ +"""PyTorch SWAG wrapper for BART""" + +import logging + +from transformers import BartConfig, BartModel, BartPreTrainedModel, BartForConditionalGeneration, \ + BartForSequenceClassification + +from .base import SwagConfig, SwagPreTrainedModel, SwagModel, SampleLogitsMixin + + +logger = logging.getLogger(__name__) + + +MODEL_TYPE = 'swag_bart' + + +class SwagBartConfig(SwagConfig): + """Config for BART model averaging with SWAG""" + + model_type = MODEL_TYPE + internal_config_class = BartConfig + + +class SwagBartPreTrainedModel(SwagPreTrainedModel): + """Pretrained SWAG BART model""" + + config_class = SwagBartConfig + base_model_prefix = MODEL_TYPE + internal_model_class = BartPreTrainedModel + + +class SwagBartModel(SwagModel): + """SWAG BART model""" + + config_class = SwagBartConfig + base_model_prefix = MODEL_TYPE + internal_model_class = BartModel + + +class SwagBartForConditionalGeneration(SwagBartModel): + """SWAG BART model for conditional generation""" + + internal_model_class = BartForConditionalGeneration + + +class SwagBartForSequenceClassification(SampleLogitsMixin, SwagBartModel): + """SWAG BART model for sequence classification""" + + internal_model_class = BartForSequenceClassification diff --git a/src/swag_transformers/swag_bert.py b/src/swag_transformers/swag_bert.py new file mode 100644 index 0000000..cc13db3 --- /dev/null +++ b/src/swag_transformers/swag_bert.py @@ -0,0 +1,49 @@ +"""PyTorch SWAG wrapper for BERT""" + +import logging + +from transformers import 
BertConfig, BertLMHeadModel, BertModel, BertPreTrainedModel, \ + BertForSequenceClassification + +from .base import SwagConfig, SwagPreTrainedModel, SwagModel, SampleLogitsMixin + + +logger = logging.getLogger(__name__) + + +MODEL_TYPE = 'swag_bert' + + +class SwagBertConfig(SwagConfig): + """Config for BERT model averaging with SWAG""" + + model_type = MODEL_TYPE + internal_config_class = BertConfig + + +class SwagBertPreTrainedModel(SwagPreTrainedModel): + """Pretrained SWAG BERT model""" + + config_class = SwagBertConfig + base_model_prefix = MODEL_TYPE + internal_model_class = BertPreTrainedModel + + +class SwagBertModel(SwagModel): + """SWAG BERT model""" + + config_class = SwagBertConfig + base_model_prefix = MODEL_TYPE + internal_model_class = BertModel + + +class SwagBertForSequenceClassification(SampleLogitsMixin, SwagBertModel): + """SWAG BERT model for sequence classification""" + + internal_model_class = BertForSequenceClassification + + +class SwagBertLMHeadModel(SwagBertModel): + """SWAG BERT model with LM head""" + + internal_model_class = BertLMHeadModel diff --git a/src/swag_transformers/swag_marian.py b/src/swag_transformers/swag_marian.py new file mode 100644 index 0000000..2a698e3 --- /dev/null +++ b/src/swag_transformers/swag_marian.py @@ -0,0 +1,43 @@ +"""PyTorch SWAG wrapper for Marian models""" + +import logging + +from transformers import MarianConfig, MarianModel, MarianMTModel +from transformers.models.marian import MarianPreTrainedModel + +from .base import SwagConfig, SwagPreTrainedModel, SwagModel + + +logger = logging.getLogger(__name__) + + +MODEL_TYPE = 'swag_marian' + + +class SwagMarianConfig(SwagConfig): + """Config for Marian model averaging with SWAG""" + + model_type = MODEL_TYPE + internal_config_class = MarianConfig + + +class SwagMarianPreTrainedModel(SwagPreTrainedModel): + """Pretrained SWAG Marian model""" + + config_class = SwagMarianConfig + base_model_prefix = MODEL_TYPE + internal_model_class = MarianPreTrainedModel + + +class SwagMarianModel(SwagModel): + """SWAG Marian model""" + + config_class = SwagMarianConfig + base_model_prefix = MODEL_TYPE + internal_model_class = MarianModel + + +class SwagMarianMTModel(SwagMarianModel): + """SWAG MarianMT model""" + + internal_model_class = MarianMTModel diff --git a/src/swag_transformers/swag_marian/__init__.py b/src/swag_transformers/swag_marian/__init__.py deleted file mode 100644 index 6aea7a2..0000000 --- a/src/swag_transformers/swag_marian/__init__.py +++ /dev/null @@ -1,133 +0,0 @@ -"""PyTorch SWAG wrapper for Marian models""" - -import copy -import functools -import logging -from typing import Union - -import torch - -from transformers import PreTrainedModel, PretrainedConfig, MarianConfig, \ - MarianModel, MarianMTModel -from transformers.models.marian import MarianPreTrainedModel - -from swag.posteriors.swag import SWAG - - -logger = logging.getLogger(__name__) - - -class SwagMarianConfig(PretrainedConfig): - """Config for Marian model averaging with SWAG""" - - model_type = 'swag_marian' - internal_config_class = MarianConfig - - def __init__( - self, - internal_model_config: dict = None, - no_cov_mat: bool = True, - max_num_models: int = 20, - var_clamp: float = 1e-30, - **kwargs - ): - super().__init__() - if internal_model_config: - self.internal_model_config = internal_model_config - else: - internal_config = self.internal_config_class(**kwargs) - self.internal_model_config = internal_config.to_dict() - self.no_cov_mat = no_cov_mat - self.max_num_models = max_num_models - 
self.var_clamp = var_clamp - - @classmethod - def from_config(cls, base_config: MarianConfig, **kwargs): - """Initialize from existing MarianConfig""" - config = cls(**kwargs) - config.internal_model_config = base_config.to_dict() - return config - - def update_internal_config(self, base_config: MarianConfig): - """Update internal config from base_config""" - self.internal_model_config = base_config.to_dict() - # Copy some things to the top level - if base_config.problem_type is not None: - self.problem_type = base_config.problem_type - - -class SwagMarianPreTrainedModel(PreTrainedModel): - """Pretrained SWAG Marian model""" - - config_class = SwagMarianConfig - base_model_prefix = 'swag_marian' - internal_model_class = MarianPreTrainedModel - - def __init__(self, config): - super().__init__(config) - self.swag = SWAG( - base=self.new_base_model, - no_cov_mat=config.no_cov_mat, - max_num_models=config.max_num_models, - var_clamp=config.var_clamp, - config=config.internal_config_class(**config.internal_model_config), - ) - self.post_init() - - @classmethod - def new_base_model(cls, *args, **kwargs): - """Return new model of the base class - - Any arguments are passed to the base class constructor. - - """ - model = cls.internal_model_class(*args, **kwargs) - model.tie_weights() - return model - - def _init_weights(self, module): - self.swag.base._init_weights(module) - - -class SwagMarianModel(SwagMarianPreTrainedModel): - """SWAG Marian model""" - - internal_model_class = MarianModel - - def __init__(self, config, base_model=None): - super().__init__(config) - if base_model: - self.swag = SWAG( - base=functools.partial(self._base_model_copy, base_model), - no_cov_mat=config.no_cov_mat, - max_num_models=config.max_num_models, - var_clamp=config.var_clamp - ) - - @staticmethod - def _base_model_copy(model, *args, **kwargs): - """Return deep copy of the model ignoring other arguments""" - # Has to be copied, otherwise SWAG would initialize parameters - # of the original model to zero - model = copy.deepcopy(model) - model.tie_weights() - return model - - @classmethod - def from_base(cls, base_model: MarianPreTrainedModel, **kwargs): - """Initialize from existing MarianPreTrainedModel""" - config = SwagMarianConfig.from_config(base_model.config, **kwargs) - swag_model = cls(config, base_model=base_model) - return swag_model - - def forward(self, *args, **kwargs): - return self.swag.forward(*args, **kwargs) - - -class SwagMarianMTModel(SwagMarianModel): - """SWAG MarianMT model""" - - internal_model_class = MarianMTModel - - def generate(self, *args, **kwargs): - return self.swag.base.generate(*args, **kwargs) diff --git a/src/swag_transformers/trainer_utils.py b/src/swag_transformers/trainer_utils.py index bae41c8..1a57f6c 100644 --- a/src/swag_transformers/trainer_utils.py +++ b/src/swag_transformers/trainer_utils.py @@ -2,7 +2,6 @@ import logging -from swag.posteriors.swag import SWAG from transformers import TrainerCallback diff --git a/tests/test_swag_bart.py b/tests/test_swag_bart.py new file mode 100644 index 0000000..0348656 --- /dev/null +++ b/tests/test_swag_bart.py @@ -0,0 +1,78 @@ +import logging +import unittest +import tempfile + +import torch + +from transformers import AutoTokenizer, BartForConditionalGeneration + +from swag_transformers.swag_bart import SwagBartConfig, SwagBartModel, SwagBartPreTrainedModel, \ + SwagBartForConditionalGeneration + + +class TestSwagBart(unittest.TestCase): + + pretrained_model_name = 'Finnish-NLP/bart-small-finnish' + # pretrained_model_name = 
'sshleifer/bart-tiny-random' + + def test_untrained(self): + hidden_size = 240 + config = SwagBartConfig(no_cov_mat=False, hidden_size=hidden_size) + logging.debug(config) + swag_model = SwagBartPreTrainedModel(config) + swag_model = SwagBartModel(config) + logging.debug(swag_model) + self.assertEqual(swag_model.device.type, 'cpu') + with self.assertLogs(level='WARNING') as cm: + out = swag_model.forward(input_ids=torch.tensor([[3, 14]])) + # Warning from using forward before sampling parameters + self.assertTrue(any(msg.startswith('WARNING') for msg in cm.output)) + logging.debug(out) + swag_model.sample_parameters() + out = swag_model.forward(input_ids=torch.tensor([[3, 14]])) + logging.debug(out) + self.assertEqual(out.last_hidden_state.shape, (1, 2, hidden_size)) + + def test_pretrained_bart_generative(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + model = BartForConditionalGeneration.from_pretrained(self.pretrained_model_name) + model.to(device) + self.assertEqual(model.device.type, device) + swag_model = SwagBartForConditionalGeneration.from_base(model) + swag_model.to(device) + self.assertEqual(swag_model.device.type, device) + tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name) + + swag_model.swag.collect_model(model) + swag_model.sample_parameters() + + # Test forward + base_out = model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]])) + out = swag_model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]])) + self.assertTrue(torch.allclose(base_out.logits, out.logits)) + + # Test generate + example = "I have no BART and I must generate" + batch = tokenizer(example, return_tensors="pt") + base_generated_ids = model.generate(batch["input_ids"]) + base_out = tokenizer.batch_decode(base_generated_ids, skip_special_tokens=True) + generated_ids = swag_model.generate(batch["input_ids"]) + out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + logging.info(base_out) + logging.info(out) + self.assertEqual(base_out, out) + + # Test saving & loading + with tempfile.TemporaryDirectory() as tempdir: + swag_model.save_pretrained(tempdir) + stored_model = SwagBartForConditionalGeneration.from_pretrained(tempdir).to(device) + + generated_ids = stored_model.generate(batch["input_ids"]) + out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + logging.info(out) + self.assertEqual(base_out, out) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + unittest.main() diff --git a/tests/test_swag_bert.py b/tests/test_swag_bert.py index 12ade61..8e73233 100644 --- a/tests/test_swag_bert.py +++ b/tests/test_swag_bert.py @@ -2,11 +2,10 @@ import unittest import tempfile -import numpy as np import torch from datasets import Dataset, DatasetDict -from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoModelWithLMHead, \ +from transformers import AutoModel, AutoModelForSequenceClassification, \ AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments from swag_transformers.swag_bert import SwagBertConfig, SwagBertLMHeadModel, SwagBertModel, SwagBertPreTrainedModel, \ @@ -31,7 +30,7 @@ def test_untrained(self): # Warning from using forward before sampling parameters self.assertTrue(any(msg.startswith('WARNING') for msg in cm.output)) logging.debug(out) - swag_model.swag.sample() + swag_model.sample_parameters() out = swag_model.forward(input_ids=torch.tensor([[3, 14]])) logging.debug(out) 
self.assertEqual(out.last_hidden_state.shape, (1, 2, hidden_size)) @@ -43,7 +42,7 @@ def test_untrained_classifier(self): config = SwagBertConfig(no_cov_mat=False, hidden_size=hidden_size, num_labels=num_labels) logging.debug(config) swag_model = SwagBertForSequenceClassification(config) - swag_model.swag.sample() + swag_model.sample_parameters() logging.debug(swag_model) logging.debug(swag_model.swag.base.config) self.assertEqual(swag_model.device.type, 'cpu') @@ -60,7 +59,7 @@ def test_untrained_lmhead(self): no_cov_mat=False, num_attention_heads=num_attention_heads, hidden_size=hidden_size, vocab_size=vocab_size, is_decoder=True) swag_model = SwagBertLMHeadModel(config) - swag_model.swag.sample() + swag_model.sample_parameters() logging.debug(swag_model.config) logging.debug(swag_model) prep_inputs = swag_model.prepare_inputs_for_generation(input_ids=torch.tensor([[3, 14, 45]])) @@ -78,7 +77,7 @@ def test_pretrained_bert_tiny_base(self): swag_model = SwagBertModel.from_base(model) logging.debug(swag_model) self.assertEqual(swag_model.device.type, 'cpu') - swag_model.swag.sample() + swag_model.sample_parameters() out = swag_model.forward(input_ids=torch.tensor([[3, 14]])) logging.debug(out) self.assertEqual(out.last_hidden_state.shape, (1, 2, hidden_size)) @@ -94,12 +93,13 @@ def test_pretrained_bert_tiny_classifier_test(self): logging.debug(swag_model.swag.device) logging.debug(swag_model.swag.base.device) swag_model.swag.collect_model(model) - swag_model.swag.sample() + swag_model.sample_parameters() out = swag_model.forward(input_ids=torch.tensor([[3, 14]])) logging.debug(out) self.assertEqual(out.logits.shape, (1, num_labels)) - def _data_gen(self): + @staticmethod + def _data_gen(): yield {"text": "Hello world", "label": 0} yield {"text": "Just some swaggering", "label": 1} yield {"text": "Have a good day", "label": 0} @@ -125,7 +125,7 @@ def test_pretrained_bert_tiny_classifier_finetune(self): out_swag = swag_model(**tokens) self.assertEqual(out_swag.logits.shape, (2, num_labels)) self.assertTrue(torch.allclose(out_swag.logits.to('cpu'), torch.zeros(*out_swag.logits.shape))) - swag_model.swag.sample() + swag_model.sample_parameters() out_swag = swag_model(**tokens) self.assertEqual(out_swag.logits.shape, (2, num_labels)) self.assertTrue(torch.allclose(out_swag.logits.to('cpu'), torch.zeros(*out_swag.logits.shape))) @@ -153,7 +153,7 @@ def tokenize_function(example): ) trainer.train() self.assertEqual(swag_model.swag.n_models, train_epochs) - swag_model.swag.sample() + swag_model.sample_parameters() out_swag = swag_model(**tokens) self.assertEqual(out_swag.logits.shape, (2, num_labels)) diff --git a/tests/test_swag_marian.py b/tests/test_swag_marian.py index 0997fd7..9bc5b00 100644 --- a/tests/test_swag_marian.py +++ b/tests/test_swag_marian.py @@ -37,14 +37,14 @@ def test_untrained(self): logging.debug(swag_model) swag_model = SwagMarianModel(config) logging.debug(swag_model) - swag_model.swag.sample() + swag_model.sample_parameters() logging.debug(swag_model.swag.base.decoder.embed_tokens.weight) out = swag_model.forward(**input_dict) self.assertEqual(out.last_hidden_state.shape, (1, 1, hidden_size)) swag_model = SwagMarianMTModel(config) logging.debug(swag_model) - swag_model.swag.sample() + swag_model.sample_parameters() logging.debug(swag_model.swag.base.model.decoder.embed_tokens.weight) out = swag_model.forward(**input_dict) logging.debug(out.logits.shape) @@ -67,7 +67,7 @@ def test_pretrained_marian_tiny_test(self): bufs_and_params_before = 
set(buf_and_param_names(swag_model)) swag_model.swag.collect_model(model) - swag_model.swag.sample() + swag_model.sample_parameters() bufs_and_params_after = set(buf_and_param_names(swag_model)) self.assertEqual(bufs_and_params_before, bufs_and_params_after) @@ -106,7 +106,8 @@ def test_pretrained_marian_tiny_test(self): self.assertGreater(len(output), 0) self.assertEqual(base_output, output) - def _data_gen(self): + @staticmethod + def _data_gen(): yield {"source": "India and Japan prime ministers meet in Tokyo", "target": "Die Premierminister Indiens und Japans trafen sich in Tokio."} yield {"source": "High on the agenda are plans for greater nuclear co-operation.", @@ -136,11 +137,7 @@ def tokenize_function(example): inputs = example['source'] targets = example['target'] model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True) - - # Setup the tokenizer for targets - with tokenizer.as_target_tokenizer(): - labels = tokenizer(targets, max_length=max_target_length, truncation=True) - + labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) model_inputs["labels"] = labels["input_ids"] return model_inputs @@ -171,7 +168,7 @@ def tokenize_function(example): trainer.train() logging.info("N models: %s", swag_model.swag.n_models.item()) # self.assertEqual(swag_model.swag.n_models, train_epochs) - swag_model.swag.sample() + swag_model.sample_parameters() sample_text = "India and Japan prime ministers meet in Tokyo" batch = tokenizer([sample_text], return_tensors="pt") generated_ids = model.generate(**batch, max_new_tokens=10) @@ -205,4 +202,4 @@ def tokenize_function(example): self.assertTrue(torch.allclose(orig_embed, loaded_embed)) self.assertTrue(torch.allclose(loaded_embed, loaded_enc)) self.assertTrue(torch.allclose(loaded_embed, loaded_head)) - stored_model.swag.sample() + stored_model.sample_parameters()