[Test] improve tf distill ut efficiency (#512)

* improve tf distill ut efficiency

Signed-off-by: Guo, Heng <[email protected]>
n1ck-guo authored Oct 23, 2023
1 parent 6599bdc commit 70ff034
Showing 2 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions tests/test_tf_autodistillation.py
@@ -29,12 +29,12 @@ def setUpClass(self):
         self.strategy = tf.distribute.MultiWorkerMirroredStrategy()
         set_seed(42)
         self.model = TFAutoModelForSequenceClassification.from_pretrained(
-            'distilbert-base-uncased')
+            'hf-internal-testing/tiny-random-DistilBertModel')
         self.teacher_model = TFAutoModelForSequenceClassification.from_pretrained(
-            'distilbert-base-uncased-finetuned-sst-2-english')
+            'hf-internal-testing/tiny-random-DistilBertForSequenceClassification')

         raw_datasets = load_dataset("glue", "sst2")["validation"]
-        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+        self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-DistilBertForSequenceClassification")
         non_label_column_names = [
             name for name in raw_datasets.column_names if name != "label"
         ]
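The change above is the whole point of the commit: swap full pretrained checkpoints for tiny, randomly initialized ones from the hf-internal-testing org, so test setup avoids large downloads and slow forward passes while still exercising the same code paths. A minimal standalone sketch of the pattern; the model IDs come from the diff, but the smoke-test scaffolding around them is illustrative, not the repository's actual test code:

    from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

    # Tiny random checkpoints keep DistilBERT's architecture and API surface,
    # so distillation code paths run at a fraction of the usual cost.
    student = TFAutoModelForSequenceClassification.from_pretrained(
        'hf-internal-testing/tiny-random-DistilBertModel')
    teacher = TFAutoModelForSequenceClassification.from_pretrained(
        'hf-internal-testing/tiny-random-DistilBertForSequenceClassification')
    tokenizer = AutoTokenizer.from_pretrained(
        'hf-internal-testing/tiny-random-DistilBertForSequenceClassification')

    # One forward pass as a smoke test: the weights are random, so only the
    # output shape (batch, num_labels) is meaningful, not the values.
    batch = tokenizer(["unit tests should be fast"], return_tensors="tf")
    print(student(**batch).logits.shape)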
10 changes: 5 additions & 5 deletions tests/test_tf_distillation.py
@@ -16,12 +16,12 @@ class TestDistillation(unittest.TestCase):
     def setUpClass(self):
         set_seed(42)
         self.model = TFAutoModelForSequenceClassification.from_pretrained(
-            'distilbert-base-uncased')
+            'hf-internal-testing/tiny-random-DistilBertModel')
         self.teacher_model = TFAutoModelForSequenceClassification.from_pretrained(
-            'distilbert-base-uncased-finetuned-sst-2-english')
+            'hf-internal-testing/tiny-random-DistilBertForSequenceClassification')

         raw_datasets = load_dataset("glue", "sst2")["validation"]
-        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+        self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-DistilBertForSequenceClassification")
         non_label_column_names = [
             name for name in raw_datasets.column_names if name != "label"
         ]
@@ -107,13 +107,13 @@ def eval_func(model):
             eval_func=eval_func,
             train_func=self.optimizer.build_train_func
         )
-        distilled_model = self.optimizer.distill(
+        distilled_model2 = self.optimizer.distill(
             distillation_config=distillation_conf,
             teacher_model=self.teacher_model,
             eval_func=None,
             train_func=None
         )
-        # distilled_weight = copy.deepcopy(distilled_model.model.classifier.get_weights())
+        self.assertEqual(distilled_model.signatures['serving_default'].output_shapes['Identity'], distilled_model2.signatures['serving_default'].output_shapes['Identity'])


if __name__ == "__main__":
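The new assertion replaces a commented-out weight comparison: after running distill() twice (once with an explicit eval/train function, once with the defaults), the test now checks that both results export serving signatures with the same output shape. A hedged sketch of that check, meant to be read in the context of the test above, with distilled_model and distilled_model2 standing in for the two distill() results:

    # Both distill() results behave like TF SavedModels, exposing concrete
    # functions under `signatures`; per the diff above, the default export is
    # keyed 'serving_default' and its output tensor is named 'Identity'.
    def serving_output_shape(model):
        return model.signatures['serving_default'].output_shapes['Identity']

    # Equal output shapes confirm the two code paths produce structurally
    # equivalent exported models, without comparing (random) weight values.
    assert serving_output_shape(distilled_model) == serving_output_shape(distilled_model2)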
