From 6369c21b485be476f016a636a8be9af7675a9b1d Mon Sep 17 00:00:00 2001
From: sileod
Date: Thu, 2 Feb 2023 16:43:57 +0100
Subject: [PATCH] better save and test-eval

---
 src/tasknet/models.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/tasknet/models.py b/src/tasknet/models.py
index 39e6fa2..88156f1 100755
--- a/src/tasknet/models.py
+++ b/src/tasknet/models.py
@@ -271,8 +271,8 @@ class default:
             task: dataset["test"] for task, dataset in self.processed_tasks.items()
         }
-        # We revents trainer from automatically evaluating on each dataset:
-        # transformerS.Trainer recognizes eval_dataset instances of "dict"
+        # We prevent the trainer from automatically evaluating on each dataset:
+        # transformers.Trainer recognizes eval_dataset instances of "dict"
         # But we use a custom "evaluate" function so that we can use different metrics for each task
         self.eval_dataset = MappingProxyType(self.eval_dataset)
         self.fix_callback()
 
@@ -308,7 +308,7 @@ def write_line(other, values):
             other.inner_table[0] = columns
             other.inner_table.append([values.get(c, np.nan) for c in columns])
 
-    def evaluate(self, **kwargs):
+    def evaluate(self, metric_key_prefix="eval", **kwargs):
         try:
             i=[i for (i,c) in enumerate(self.callback_handler.callbacks) if 'NotebookProgress' in str(c)][0]
             self.callback_handler.callbacks[i].training_tracker.write_line = fc.partial(
@@ -321,12 +321,13 @@ def evaluate(self, **kwargs):
             self.compute_metrics = task.compute_metrics
             output = transformers.Trainer.evaluate(
                 self,
-                eval_dataset=dict([fc.nth(i, self.eval_dataset.items())]),
+                eval_dataset=dict([fc.nth(i, (self.eval_dataset if metric_key_prefix == "eval" else self.test_dataset).items())]),
+                metric_key_prefix=metric_key_prefix,
             )
             if "Accuracy" not in output:
                 output["Accuracy"] = np.nan
             outputs += [output]
-        return fc.join(outputs)
+        return fc.join(outputs) if metric_key_prefix != "test" else outputs
 
     def task_batch_size(self,task_name):
         if hasattr(task_name, 'num_choices'):
@@ -391,6 +392,9 @@ def get_test_dataloader(self, test_dataset=None):
             }
         )
 
+    def save_model(self, output_dir, **kwargs):
+        self.model.factorize().save_pretrained(output_dir)
+
     def preprocess_tasks(self, tasks, tokenizer):
         features_dict = {}
 
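
Usage sketch for reviewers (not part of the patch itself): a minimal example of the new test-eval and save paths, based only on what the diff above shows. It assumes `trainer` is an already-constructed tasknet Trainer with its tasks preprocessed; the output path is hypothetical.

    # Validation: per-task metrics are merged into a single dict
    # (the default metric_key_prefix is "eval").
    eval_metrics = trainer.evaluate()

    # Test: metric_key_prefix="test" switches evaluation to self.test_dataset
    # and returns a list with one output dict per task instead of a merged dict.
    test_outputs = trainer.evaluate(metric_key_prefix="test")

    # Saving: save_model() now factorizes the multi-task model before
    # calling save_pretrained() on the result.
    trainer.save_model("output_dir")  # hypothetical path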