Skip to content

Commit

Permalink
remove extra imports
Browse files Browse the repository at this point in the history
  • Loading branch information
svirpioj committed Jun 26, 2024
1 parent 071cd1a commit ed3c831
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 21 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[flake8]
max-line-length = 127
5 changes: 2 additions & 3 deletions examples/bert_snli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import argparse
import collections
import logging
import os
import sys

import torch
import transformers
Expand Down Expand Up @@ -35,7 +33,8 @@ def main():
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = transformers.AutoTokenizer.from_pretrained(args.base_model, cache_dir=args.model_cache_dir)
model = transformers.AutoModelForSequenceClassification.from_pretrained(args.base_model, num_labels=3, cache_dir=args.model_cache_dir)
model = transformers.AutoModelForSequenceClassification.from_pretrained(
args.base_model, num_labels=3, cache_dir=args.model_cache_dir)
model.to(device)
swag_model = SwagBertForSequenceClassification.from_base(model)
swag_model.to(device)
Expand Down
4 changes: 0 additions & 4 deletions examples/load_pretrained.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import argparse
import collections
import logging
import os
import sys

import transformers

from swag_transformers.swag_bert import SwagBertConfig, SwagBertForSequenceClassification
from swag_transformers.trainer_utils import SwagUpdateCallback


def main():
Expand Down
2 changes: 0 additions & 2 deletions examples/marian_mt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import argparse
import collections
import logging
import os
import sys

import torch
import transformers
Expand Down
3 changes: 1 addition & 2 deletions src/swag_transformers/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""SWAG wrapper base classes"""

from abc import ABCMeta, abstractmethod
import copy
import functools
import logging
from typing import Union, Type
from typing import Type

import torch

Expand Down
1 change: 0 additions & 1 deletion src/swag_transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging

from swag.posteriors.swag import SWAG
from transformers import TrainerCallback


Expand Down
9 changes: 2 additions & 7 deletions tests/test_swag_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@
import unittest
import tempfile

import numpy as np
import torch

from datasets import Dataset, DatasetDict
from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification, \
AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments, BartForConditionalGeneration
from transformers import AutoTokenizer, BartForConditionalGeneration

from swag_transformers.swag_bart import SwagBartConfig, SwagBartModel, SwagBartPreTrainedModel, \
SwagBartForSequenceClassification, SwagBartForConditionalGeneration
from swag_transformers.trainer_utils import SwagUpdateCallback
SwagBartForConditionalGeneration


class TestSwagBart(unittest.TestCase):
Expand Down Expand Up @@ -77,7 +73,6 @@ def test_pretrained_bart_generative(self):
self.assertEqual(base_out, out)



if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
unittest.main()
3 changes: 1 addition & 2 deletions tests/test_swag_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import unittest
import tempfile

import numpy as np
import torch

from datasets import Dataset, DatasetDict
from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoModelWithLMHead, \
from transformers import AutoModel, AutoModelForSequenceClassification, \
AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments

from swag_transformers.swag_bert import SwagBertConfig, SwagBertLMHeadModel, SwagBertModel, SwagBertPreTrainedModel, \
Expand Down

0 comments on commit ed3c831

Please sign in to comment.