From 627ee921a8402d5ea8b37cdb5fb87c0422786d23 Mon Sep 17 00:00:00 2001
From: Sami Virpioja
Date: Wed, 22 Jan 2025 11:22:10 +0200
Subject: [PATCH] update README

---
 README.md | 41 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/README.md b/README.md
index 399938f..475113c 100644
--- a/README.md
+++ b/README.md
@@ -124,3 +124,44 @@ matrix) and `scale=0.5` (recommended). For SWA, you should use
* `MarianPreTrainedModel` -> `SwagMarianPreTrainedModel`
* `MarianModel` -> `SwagMarianModel`
* `MarianMTModel` -> `SwagMarianMTModel`

### Registering classes for transformers Auto methods

As the SWAG variants are not part of the Hugging Face `transformers`
library, you need to register them for the `AutoModelForXXX` methods
to work. For example, for `SwagBertForSequenceClassification` you
need to have:

```python
import transformers

transformers.AutoConfig.register("swag_bert", SwagBertConfig)
transformers.AutoModelForSequenceClassification.register(
    SwagBertConfig, SwagBertForSequenceClassification)
```

Now you can load a saved model with:

```python
model = transformers.AutoModelForSequenceClassification.from_pretrained(path)
```

### Wrapping custom models

Any model class built on the `transformers` library can be wrapped in
the same way. For example, given `MyModel` and its configuration
class `MyModelConfig`, you can define the SWAG variants as follows:

```python
from swag_transformers.base import SwagConfig, SwagModel

MODEL_TYPE = 'swag_mymodel'

class SwagMyModelConfig(SwagConfig):
    """SWAG configuration class wrapping MyModelConfig"""

    model_type = MODEL_TYPE
    internal_config_class = MyModelConfig

class SwagMyModel(SwagModel):
    """SWAG model class wrapping MyModel"""

    base_model_prefix = MODEL_TYPE
    config_class = SwagMyModelConfig
    internal_model_class = MyModel
```
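The registration step from the previous section applies to custom
wrappers as well. A minimal sketch, assuming the `SwagMyModel` and
`SwagMyModelConfig` definitions above; substitute the `Auto` class
that matches your model's task (the generic `AutoModel` is used here
only for illustration):

```python
import transformers

# Register the config class under its model_type string, then map the
# config class to the model class for the chosen Auto method.
transformers.AutoConfig.register(MODEL_TYPE, SwagMyModelConfig)
transformers.AutoModel.register(SwagMyModelConfig, SwagMyModel)
```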
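Once registered, saved checkpoints round-trip through the standard
`transformers` methods. A hypothetical example, assuming `model` is a
`SwagMyModel` instance inheriting the usual `save_pretrained`
behaviour from its `transformers` base class; the directory name is a
placeholder:

```python
# The saved config records model_type "swag_mymodel", which AutoModel
# resolves to SwagMyModel via the registration above.
model.save_pretrained("path/to/swag-mymodel")
loaded = transformers.AutoModel.from_pretrained("path/to/swag-mymodel")
```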