matrix) and `scale=0.5` (recommended). For SWA, you should use the
following substitutions (a minimal import sketch follows the list):
* `MarianPreTrainedModel` -> `SwagMarianPreTrainedModel`
* `MarianModel` -> `SwagMarianModel`
* `MarianMTModel` -> `SwagMarianMTModel`
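
In practice the substitution is a drop-in import change. A minimal
sketch, assuming the Marian wrappers live in a
`swag_transformers.swag_marian` module (the module path is an
assumption, not confirmed here):

```python
# Module path is an assumption; adjust it to wherever the SWAG Marian
# classes are defined in your installation of swag_transformers.
from swag_transformers.swag_marian import SwagMarianMTModel

# Use SwagMarianMTModel wherever transformers.MarianMTModel appeared
# before, e.g. when loading a saved SWAG checkpoint.
```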

### Registering classes for transformers Auto methods

Because the SWAG variants are not part of the Hugging Face libraries,
you need to register them for the `AutoModelForXXX` methods to
work. For example, for `SwagBertForSequenceClassification` you need:

```python
import transformers

# Assumed import location for the SWAG classes; adjust if they live
# in a different module of swag_transformers.
from swag_transformers.swag_bert import SwagBertConfig, SwagBertForSequenceClassification

transformers.AutoConfig.register("swag_bert", SwagBertConfig)
transformers.AutoModelForSequenceClassification.register(SwagBertConfig, SwagBertForSequenceClassification)
```

Now you can load a saved model with:

```python
# path is the directory the SWAG model was saved in (e.g. with save_pretrained)
model = transformers.AutoModelForSequenceClassification.from_pretrained(path)
```
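
Note that the registration is in-memory only: it has to be run again
in each Python session before calling `from_pretrained`, as the class
mapping is not stored with the saved model.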

### Wrapping custom models

Any model class built on the `transformers` library can be wrapped in
the same way. For example, given `MyModel` with its configuration
class `MyModelConfig`, you can define the SWAG variants as follows:

```python
from swag_transformers.base import SwagConfig, SwagModel

# MyModel and MyModelConfig are placeholders for your own model and
# configuration classes built on the transformers library.
MODEL_TYPE = 'swag_mymodel'


class SwagMyModelConfig(SwagConfig):

    model_type = MODEL_TYPE
    internal_config_class = MyModelConfig


class SwagMyModel(SwagModel):

    base_model_prefix = MODEL_TYPE
    config_class = SwagMyModelConfig
    internal_model_class = MyModel
```
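
To make the `Auto` methods work for the new wrapper as well, register
it in the same way as shown above. A minimal sketch, assuming the
generic `AutoModel` class fits your model (pick the task-specific
`AutoModelForXXX` class otherwise):

```python
import transformers

# MyModel / MyModelConfig remain hypothetical placeholders for your
# own classes; MODEL_TYPE, SwagMyModelConfig and SwagMyModel are the
# definitions from the previous example.
transformers.AutoConfig.register(MODEL_TYPE, SwagMyModelConfig)
transformers.AutoModel.register(SwagMyModelConfig, SwagMyModel)

# After registration, a saved model can be loaded with the Auto method:
model = transformers.AutoModel.from_pretrained(path)
```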
