# Recovered from a whitespace-collapsed git patch:
#   commit  358f6f6ed327ec50ce9dd515edb82a9983aeda4b
#   From:   Damien Sileo
#   Date:   Fri, 30 Jun 2023 11:08:30 +0200
#   Subject: [PATCH] added AutoTask
# The patch appends this function at the end of src/tasknet/tasks.py.


def AutoTask(dataset, **kwargs):
    """Build the task wrapper that matches the columns of *dataset*.

    Args:
        dataset: Either a tasksource task id (``str``), which is loaded with
            ``tasksource.load_task``, or an already-loaded mapping of splits
            whose ``"train"`` entry exposes a ``features`` attribute
            (presumably a Hugging Face ``DatasetDict`` -- confirm with callers).
        **kwargs: Forwarded unchanged to the selected task constructor.

    Returns:
        A ``Classification``, ``MultipleChoice`` or ``TokenClassification``
        instance, tried in that order of precedence.

    Raises:
        ImportError: *dataset* is a string and ``tasksource`` cannot be
            imported.
        ValueError: ``tasksource.load_task`` failed for the given id, or the
            train features match none of the supported layouts.
    """
    if isinstance(dataset, str):
        try:
            # Optional dependency: only required when a task id is given.
            import tasksource
        except ImportError as err:
            raise ImportError('To use this feature, use a valid tasksource id and pip install tasksource') from err
        try:
            dataset = tasksource.load_task(dataset)
        except Exception as err:
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still
            # propagate; the original failure is kept as __cause__.
            raise ValueError('pick an id from https://github.com/sileod/tasksource/blob/main/tasks.md or write your own preprocessing') from err

    features = dataset['train'].features
    if 'sentence1' in features:
        return Classification(dataset, **kwargs)
    # Matched against the string form of the features, so any occurrence of
    # 'choice' (in a column name or a nested field) selects MultipleChoice.
    if 'choice' in str(features):
        return MultipleChoice(dataset, **kwargs)
    if 'tokens' in features:
        return TokenClassification(dataset, **kwargs)

    # Previously the function fell through and returned None here, which only
    # failed later at the call site; fail early with an actionable message.
    raise ValueError(
        "AutoTask could not infer a task type from the train features "
        f"({features!r}); expected a 'sentence1' column, a 'choice' field "
        "or a 'tokens' column"
    )