better save and test-eval
sileod committed Feb 2, 2023
1 parent 544eb26 commit 6369c21
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/tasknet/models.py
@@ -271,8 +271,8 @@ class default:
task: dataset["test"]
for task, dataset in self.processed_tasks.items()
}
- # We revents trainer from automatically evaluating on each dataset:
- # transformerS.Trainer recognizes eval_dataset instances of "dict"
+ # We prevent trainer from automatically evaluating on each dataset:
+ # transformers.Trainer recognizes eval_dataset instances of "dict"
# But we use a custom "evaluate" function so that we can use different metrics for each task
self.eval_dataset = MappingProxyType(self.eval_dataset)
self.fix_callback()
@@ -308,7 +308,7 @@ def write_line(other, values):
other.inner_table[0] = columns
other.inner_table.append([values.get(c, np.nan) for c in columns])

- def evaluate(self, **kwargs):
+ def evaluate(self, metric_key_prefix="eval", **kwargs):
try:
i=[i for (i,c) in enumerate(self.callback_handler.callbacks) if 'NotebookProgress' in str(c)][0]
self.callback_handler.callbacks[i].training_tracker.write_line = fc.partial(
@@ -321,12 +321,13 @@ def evaluate(self, **kwargs):
self.compute_metrics = task.compute_metrics
output = transformers.Trainer.evaluate(
self,
- eval_dataset=dict([fc.nth(i, self.eval_dataset.items())]),
+ eval_dataset=dict([fc.nth(i, (self.eval_dataset if metric_key_prefix=="eval" else self.test_dataset).items())]),
+ metric_key_prefix=metric_key_prefix
)
if "Accuracy" not in output:
output["Accuracy"] = np.nan
outputs += [output]
- return fc.join(outputs)
+ return fc.join(outputs) if metric_key_prefix!="test" else outputs

def task_batch_size(self,task_name):
if hasattr(task_name, 'num_choices'):
@@ -391,6 +392,9 @@ def get_test_dataloader(self, test_dataset=None):
}
)

+ def save_model(self, output_dir, **kwargs):
+     self.model.factorize().save_pretrained(output_dir)

def preprocess_tasks(self, tasks, tokenizer):

features_dict = {}
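For readers of the diff, a rough usage sketch of what this commit enables. It is illustrative only: the trainer is assumed to have been built and trained through tasknet's usual multi-task workflow, and the checkpoint directory name is hypothetical.

# Assumes an already-constructed and trained tasknet trainer,
# e.g. trainer = tn.Trainer(model, tasks, hparams); trainer.train()

# Default behaviour is unchanged: evaluate() keeps metric_key_prefix="eval",
# walks self.eval_dataset task by task with each task's own compute_metrics,
# and returns the per-task results joined into a single dict.
validation_metrics = trainer.evaluate()

# New in this commit: metric_key_prefix="test" switches the data source to
# self.test_dataset and returns the list of per-task outputs without joining them.
test_metrics = trainer.evaluate(metric_key_prefix="test")

# Also new: save_model() now saves the factorized model, i.e. it calls
# self.model.factorize().save_pretrained(output_dir) on the given directory.
trainer.save_model("multitask-checkpoint")  # directory name is made up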
