Add base classes to inherit for wrappers #2

Merged
8 commits merged on Jun 26, 2024
2 changes: 2 additions & 0 deletions .flake8
@@ -0,0 +1,2 @@
[flake8]
max-line-length = 127
43 changes: 43 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,43 @@
name: Run Python Tests
on:
  push:
    branches:
      - main
    tags:
      - '*'
  pull_request:
    branches:
      - main
  # schedule:
  #   # Run on Tuesdays at 5:59
  #   - cron: '59 5 * * 2'
  workflow_dispatch:
jobs:
  build-n-test:
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m ensurepip --upgrade
          python -m pip install --upgrade setuptools
          python -m pip install --upgrade pip
          python -m pip install flake8
          python -m pip install .[test]
      - name: Lint with flake8
        run: |
          # stop the build if there are Python syntax errors or undefined names
          python -m flake8 --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings
          python -m flake8 --count --exit-zero --statistics
      - name: Run tests with pytest
        run: python -m pytest
34 changes: 32 additions & 2 deletions README.md
@@ -17,7 +17,7 @@ convenience tools in the `transformers` library (e.g. `Pipeline` and

See also [examples](./examples).

### Fine-tuning
### Fine-tuning with SWAG

BERT model, sequence classification task:

@@ -28,9 +28,39 @@ BERT model, sequence classification task:
5. Train the model (`trainer.train()`)
6. Store the complete model using `swag_model.save_pretrained(path)`

Note that `trainer.save_model(path)` will save only the base model without the distribution parameters from SWAG.

For collecting the SWAG parameters, two possible schedules are supported:

* After the end of each training epoch (default, `collect_steps = 0` for `SwagUpdateCallback`)
* After every N training steps (set `collect_steps > 0` for `SwagUpdateCallback`)
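
A rough end-to-end sketch of these steps is shown below. The model name, toy dataset, and training arguments are placeholders, and `SwagUpdateCallback` is assumed to take the SWAG model as its argument:

```python
# Hedged sketch of the fine-tuning workflow above; the model name, toy data,
# and TrainingArguments values are placeholders, not part of this repository.
import transformers
from swag_transformers.swag_bert import SwagBertForSequenceClassification
from swag_transformers.trainer_utils import SwagUpdateCallback

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
base_model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=3)
swag_model = SwagBertForSequenceClassification.from_base(base_model)

# Placeholder toy dataset: two sentences with integer labels
enc = tokenizer(["first example sentence", "second example sentence"],
                padding=True, truncation=True)
train_dataset = [
    {"input_ids": ids, "attention_mask": mask, "labels": label}
    for ids, mask, label in zip(enc["input_ids"], enc["attention_mask"], [0, 1])
]

trainer = transformers.Trainer(
    model=base_model,  # the Trainer fine-tunes the base model as usual
    args=transformers.TrainingArguments(output_dir="out", num_train_epochs=1),
    train_dataset=train_dataset,
    callbacks=[SwagUpdateCallback(swag_model)],  # collects the SWAG parameters
)
trainer.train()
swag_model.save_pretrained("swag-bert-model")  # stores the SWAG distribution too
```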

### Sampling model parameters

After `swag_model` is trained or fine-tuned as described above,
`swag_model.sample_parameters()` should be called to sample new model
parameters. After that, `swag_model.forward()` can be used to predict
new output from classifiers and `swag_model.generate()` to generate
new output from generative LMs. In order to get a proper distribution
of outputs, `sample_parameters()` needs to be called each time before
`forward()` or `generate()`. For classifiers, the `SampleLogitsMixin`
class provides the convenience method `get_logits()` that samples the
parameters and makes a new prediction `num_predictions` times, and
returns the logit values in a tensor.
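
A minimal sketch of this loop, assuming `swag_model` and `tokenizer` from the fine-tuning sketch above and a placeholder input:

```python
# Hedged sketch: sample new parameters before each prediction.
import torch

inputs = tokenizer("an example input", return_tensors="pt")

with torch.no_grad():
    for _ in range(5):
        swag_model.sample_parameters()   # draw new weights from the SWAG posterior
        out = swag_model(**inputs)       # forward pass with the sampled weights

    # For classifiers, SampleLogitsMixin wraps the same loop as a convenience:
    logits = swag_model.get_logits(**inputs, num_predictions=5)
    # logits has shape [batch_size, num_predictions, output_size]
```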

### Currently supported models

* BERT
* BERT (bidirectional encoder)
* `BertPreTrainedModel` -> `SwagBertPreTrainedModel`
* `BertModel` -> `SwagBertModel`
* `BertLMHeadModel` -> `SwagBertLMHeadModel`
* `BertForSequenceClassification` -> `SwagBertForSequenceClassification`
* BART (bidirectional encoder + causal decoder)
* `BartPreTrainedModel` -> `SwagBartPreTrainedModel`
* `BartModel` -> `SwagBartModel`
* `BartForConditionalGeneration` -> `SwagBartForConditionalGeneration`
* `BartForSequenceClassification` -> `SwagBartForSequenceClassification`
* MarianMT (bidirectional encoder + causal decoder)
* `MarianPreTrainedModel` -> `SwagMarianPreTrainedModel`
* `MarianModel` -> `SwagMarianModel`
* `MarianMTModel` -> `SwagMarianMTModel`
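
As an illustration of the mapping above, a wrapper is created from the corresponding base model with `from_base`; the module path and checkpoint name below are assumptions, not taken from this repository.

```python
# Hypothetical illustration of wrapping a supported model; the module path
# `swag_transformers.swag_marian` and the checkpoint name are assumptions.
import transformers
from swag_transformers.swag_marian import SwagMarianMTModel

base = transformers.MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-de-nl")
swag_marian = SwagMarianMTModel.from_base(base)  # SWAG wrapper around the MarianMT model
```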
5 changes: 2 additions & 3 deletions examples/bert_snli.py
@@ -1,8 +1,6 @@
import argparse
import collections
import logging
import os
import sys

import torch
import transformers
@@ -35,7 +33,8 @@ def main():
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = transformers.AutoTokenizer.from_pretrained(args.base_model, cache_dir=args.model_cache_dir)
model = transformers.AutoModelForSequenceClassification.from_pretrained(args.base_model, num_labels=3, cache_dir=args.model_cache_dir)
model = transformers.AutoModelForSequenceClassification.from_pretrained(
args.base_model, num_labels=3, cache_dir=args.model_cache_dir)
model.to(device)
swag_model = SwagBertForSequenceClassification.from_base(model)
swag_model.to(device)
4 changes: 0 additions & 4 deletions examples/load_pretrained.py
@@ -1,13 +1,9 @@
import argparse
import collections
import logging
import os
import sys

import transformers

from swag_transformers.swag_bert import SwagBertConfig, SwagBertForSequenceClassification
from swag_transformers.trainer_utils import SwagUpdateCallback


def main():
8 changes: 1 addition & 7 deletions examples/marian_mt.py
@@ -1,8 +1,6 @@
import argparse
import collections
import logging
import os
import sys

import torch
import transformers
@@ -50,11 +48,7 @@ def tokenize_function(example):
inputs = [pair['de'] for pair in example['translation']]
targets = [pair['nl'] for pair in example['translation']]
model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=max_target_length, truncation=True)

labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs

2 changes: 1 addition & 1 deletion setup.py
@@ -23,7 +23,7 @@
"swa_gaussian>=0.1.6"
],
extras_require={
"test": ["datasets", "nose", "sentencepiece"]
"test": ["datasets", "pytest", "sentencepiece"]
},
classifiers=[
"Development Status :: 3 - Alpha",
src/swag_transformers/base.py
@@ -1,26 +1,32 @@
"""PyTorch SWAG wrapper for BERT"""
"""SWAG wrapper base classes"""

import copy
import functools
import logging
from typing import Union
from typing import Type

import torch

from transformers import PreTrainedModel, PretrainedConfig, BertConfig, BertLMHeadModel, BertModel, \
BertPreTrainedModel, BertForSequenceClassification
from transformers import PreTrainedModel, PretrainedConfig

from swag.posteriors.swag import SWAG


logger = logging.getLogger(__name__)


class SwagBertConfig(PretrainedConfig):
"""Config for BERT model averaging with SWAG"""
class SwagConfig(PretrainedConfig):
"""Base configuration class for SWAG models

model_type = 'swag_bert'
internal_config_class = BertConfig
For using this class, inherit it and define the following class
attributes:

- model_type: string
- internal_config_class: class inherited from PretrainedConfig

"""

internal_config_class: Type[PretrainedConfig] = PretrainedConfig

def __init__(
self,
@@ -41,26 +47,34 @@ def __init__(self,
self.var_clamp = var_clamp

@classmethod
def from_config(cls, base_config: BertConfig, **kwargs):
"""Initialize from existing BertConfig"""
def from_config(cls, base_config: PretrainedConfig, **kwargs):
"""Initialize from existing PretrainedConfig"""
config = cls(**kwargs)
config.internal_model_config = base_config.to_dict()
return config

def update_internal_config(self, base_config: BertConfig):
def update_internal_config(self, base_config: PretrainedConfig):
"""Update internal config from base_config"""
self.internal_model_config = base_config.to_dict()
# Copy some things to the top level
if base_config.problem_type is not None:
self.problem_type = base_config.problem_type


class SwagBertPreTrainedModel(PreTrainedModel):
"""Pretrained SWAG BERT model"""
class SwagPreTrainedModel(PreTrainedModel):
"""Base class for SWAG models wrapping PreTrainedModel

For using this class, inherit it and define the following class
attributes:

- base_model_prefix: string
- config_class: class inherited from PretrainedConfig
- internal_model_class: class inherited from PreTrainedModel

"""

config_class = SwagBertConfig
base_model_prefix = 'swag_bert'
internal_model_class = BertModel
config_class: Type[SwagConfig] = SwagConfig
internal_model_class: Type[PreTrainedModel] = PreTrainedModel

def __init__(self, config):
super().__init__(config)
@@ -87,9 +101,22 @@ def new_base_model(cls, *args, **kwargs):
def _init_weights(self, module):
self.swag.base._init_weights(module)

def sample_parameters(self):
"""Sample new model parameters"""
self.swag.sample()

class SwagBertModel(SwagBertPreTrainedModel):
"""SWAG BERT model"""

class SwagModel(SwagPreTrainedModel):
"""Base class for SWAG models

For using this class, inherit it and define the following class
attributes:

- base_model_prefix: string
- config_class: class inherited from PretrainedConfig
- internal_model_class: class inherited from PreTrainedModel

"""

def __init__(self, config, base_model=None):
super().__init__(config)
@@ -100,6 +127,8 @@ def __init__(self, config, base_model=None):
max_num_models=config.max_num_models,
var_clamp=config.var_clamp
)
self.prepare_inputs_for_generation = self.swag.base.prepare_inputs_for_generation
self.generate = self.swag.base.generate

@staticmethod
def _base_model_copy(model, *args, **kwargs):
@@ -111,20 +140,29 @@ def _base_model_copy(model, *args, **kwargs):
return model

@classmethod
def from_base(cls, base_model: BertPreTrainedModel, **kwargs):
"""Initialize from existing BertPreTrainedModel"""
config = SwagBertConfig.from_config(base_model.config, **kwargs)
def from_base(cls, base_model: PreTrainedModel, **kwargs):
"""Initialize from existing PreTrainedModel"""
config = cls.config_class.from_config(base_model.config, **kwargs)
swag_model = cls(config, base_model=base_model)
return swag_model

def forward(self, *args, **kwargs):
"""Call forward pass from the base model"""
return self.swag.forward(*args, **kwargs)

@classmethod
def can_generate(cls) -> bool:
return cls.internal_model_class.can_generate()

class SwagBertForSequenceClassification(SwagBertModel):
"""SWAG BERT model for sequence classification"""
def prepare_inputs_for_generation(self, *args, **kwargs):
return self.swag.base.prepare_inputs_for_generation(*args, **kwargs)

def generate(self, *args, **kwargs):
return self.swag.base.generate(*args, **kwargs)

internal_model_class = BertForSequenceClassification

class SampleLogitsMixin:
"""Mixin class for classification models providing get_logits() method using SWAG"""

def get_logits(
self, *args, num_predictions=None, scale=1.0, cov=True, block=False, **kwargs
@@ -142,17 +180,8 @@ def get_logits(
logits = []
for _ in range(num_predictions):
if sample:
self.swag.sample(scale=scale, cov=cov, block=block)
self.sample_parameters(scale=scale, cov=cov, block=block)
out = self.forward(*args, **kwargs)
logits.append(out.logits)
logits = torch.permute(torch.stack(logits), (1, 0, 2)) # [batch_size, num_predictions, output_size]
return logits


class SwagBertLMHeadModel(SwagBertModel):
"""SWAG BERT model with LM head"""

internal_model_class = BertLMHeadModel

def prepare_inputs_for_generation(self, *args, **kwargs):
return self.swag.base.prepare_inputs_for_generation(*args, **kwargs)
49 changes: 49 additions & 0 deletions src/swag_transformers/swag_bart.py
@@ -0,0 +1,49 @@
"""PyTorch SWAG wrapper for BART"""

import logging

from transformers import BartConfig, BartModel, BartPreTrainedModel, BartForConditionalGeneration, \
BartForSequenceClassification

from .base import SwagConfig, SwagPreTrainedModel, SwagModel, SampleLogitsMixin


logger = logging.getLogger(__name__)


MODEL_TYPE = 'swag_bart'


class SwagBartConfig(SwagConfig):
"""Config for BART model averaging with SWAG"""

model_type = MODEL_TYPE
internal_config_class = BartConfig


class SwagBartPreTrainedModel(SwagPreTrainedModel):
"""Pretrained SWAG BART model"""

config_class = SwagBartConfig
base_model_prefix = MODEL_TYPE
internal_model_class = BartPreTrainedModel


class SwagBartModel(SwagModel):
"""SWAG BART model"""

config_class = SwagBartConfig
base_model_prefix = MODEL_TYPE
internal_model_class = BartModel


class SwagBartForConditionalGeneration(SwagBartModel):
    """SWAG BART model for conditional generation"""

    internal_model_class = BartForConditionalGeneration


class SwagBartForSequenceClassification(SampleLogitsMixin, SwagBartModel):
    """SWAG BART model for sequence classification"""

    internal_model_class = BartForSequenceClassification