This repository provides functionality for Stochastic Weight Averaging-Gaussian (SWAG) training of Transformer models. The implementation builds on two libraries:
- transformers (maintained by Hugging Face)
- swa_gaussian (maintained by the Language Technology Research Group at the University of Helsinki)
The goal is an implementation that works directly with the convenience tools of the `transformers` library (e.g. `Pipeline` and `Trainer`) as well as with the `evaluator` from the related `evaluate` library. See also the examples.
BERT model, sequence classification task:

- Load a pretrained BERT model: `base_model = AutoModelForSequenceClassification.from_pretrained(name_or_path)`
- Initialize the SWAG model: `swag_model = SwagBertForSequenceClassification.from_base(base_model)`
- Initialize the SWAG callback object: `swag_callback = SwagUpdateCallback(swag_model)`
- Initialize `transformers.Trainer` with `base_model` as the model and `swag_callback` among the callbacks.
- Train the model (`trainer.train()`).
- Store the complete model using `swag_model.save_pretrained(path)`.

Note that `trainer.save_model(path)` saves only the base model, without the distribution parameters from SWAG. The sketch below walks through these steps.
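A minimal end-to-end sketch of the steps above. The checkpoint name, toy dataset, and training arguments are placeholders, and the `swag_transformers` import paths are assumptions; only the SWAG-specific calls (`from_base`, `SwagUpdateCallback`, `save_pretrained`) follow the steps listed above.

```python
from datasets import Dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)
# Import paths below are assumptions; adjust them to wherever these classes live in the package.
from swag_transformers.swag_bert import SwagBertForSequenceClassification
from swag_transformers.trainer_utils import SwagUpdateCallback

name_or_path = "bert-base-uncased"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(name_or_path)

# Tiny toy dataset just to make the sketch self-contained
data = Dataset.from_dict({"text": ["good", "bad"], "label": [1, 0]})
data = data.map(lambda x: tokenizer(x["text"], truncation=True,
                                    padding="max_length", max_length=16))

# 1. Load the pretrained base model
base_model = AutoModelForSequenceClassification.from_pretrained(name_or_path, num_labels=2)
# 2. Wrap it in the SWAG model
swag_model = SwagBertForSequenceClassification.from_base(base_model)
# 3. Create the callback that collects SWAG parameters during training
swag_callback = SwagUpdateCallback(swag_model)

# 4. Train the *base* model with the callback attached
trainer = Trainer(
    model=base_model,
    args=TrainingArguments(output_dir="out", num_train_epochs=1, report_to="none"),
    train_dataset=data,
    callbacks=[swag_callback],
)
trainer.train()

# 5. Save the complete SWAG model (trainer.save_model would drop the SWAG parameters)
swag_model.save_pretrained("path/to/swag-bert-model")
```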
For collecting the SWAG parameters, two schedules are supported:

- after the end of each training epoch (the default, `collect_steps = 0` for `SwagUpdateCallback`)
- after each N training steps (set `collect_steps > 0` for `SwagUpdateCallback`; see the sketch below)
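For example, a sketch of the second schedule, assuming `collect_steps` is passed as a keyword argument to `SwagUpdateCallback`:

```python
# Collect SWAG parameters every 100 training steps instead of once per epoch
# (collect_steps=0, the default, collects at the end of each training epoch).
swag_callback = SwagUpdateCallback(swag_model, collect_steps=100)
```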
After `swag_model` has been trained or fine-tuned as described above, `swag_model.sample_parameters()` should be called to sample new model parameters. After that, `swag_model.forward()` can be used to predict new output from classifiers and `swag_model.generate()` to generate new output from generative LMs. In order to get a proper distribution of outputs, `sample_parameters()` needs to be called each time before `forward()` or `generate()`. For classifiers, the `SampleLogitsMixin` class provides the convenience method `get_logits()`, which samples the parameters, makes a new prediction `num_predictions` times, and returns the logit values in a tensor. Both approaches are sketched below.
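A sketch of prediction with sampled parameters, assuming a trained classification `swag_model` as above; the tokenizer checkpoint and the number of samples are placeholders, and the exact `get_logits()` keyword arguments are an assumption.

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
inputs = tokenizer("An example input", return_tensors="pt")

# Manual sampling: draw new weights before every forward pass
sampled_logits = []
for _ in range(10):
    swag_model.sample_parameters()
    with torch.no_grad():
        sampled_logits.append(swag_model(**inputs).logits)
sampled_logits = torch.stack(sampled_logits)  # shape: (10, batch_size, num_labels)

# Convenience method from SampleLogitsMixin for classifiers
# (the keyword name num_predictions is assumed here)
logits = swag_model.get_logits(**inputs, num_predictions=10)
```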
The following base classes currently have SWAG counterparts:

- BERT (bidirectional encoder)
  - `BertPreTrainedModel` -> `SwagBertPreTrainedModel`
  - `BertModel` -> `SwagBertModel`
  - `BertLMHeadModel` -> `SwagBertLMHeadModel`
  - `BertForSequenceClassification` -> `SwagBertForSequenceClassification`
- BART (bidirectional encoder + causal decoder)
  - `BartPreTrainedModel` -> `SwagBartPreTrainedModel`
  - `BartModel` -> `SwagBartModel`
  - `BartForConditionalGeneration` -> `SwagBartForConditionalGeneration`
  - `BartForSequenceClassification` -> `SwagBartForSequenceClassification`
- MarianMT (bidirectional encoder + causal decoder)
  - `MarianPreTrainedModel` -> `SwagMarianPreTrainedModel`
  - `MarianModel` -> `SwagMarianModel`
  - `MarianMTModel` -> `SwagMarianMTModel`