diff --git a/src/tasknet/tasks.py b/src/tasknet/tasks.py index 2b7e6cc..d54edea 100755 --- a/src/tasknet/tasks.py +++ b/src/tasknet/tasks.py @@ -13,7 +13,7 @@ from frozendict import frozendict as fdict import funcy as fc import evaluate -from dataclasses import dataclass +from dataclasses import dataclass, field import re from transformers.tokenization_utils_base import PreTrainedTokenizerBase import inspect @@ -199,7 +199,7 @@ def __call__(self, features): class MultipleChoice(Classification): task_type = "MultipleChoice" num_labels:int = 2 - data_collator:...= DataCollatorForMultipleChoice() + data_collator:...= field(default_factory=DataCollatorForMultipleChoice) choices: ... = tuple() s1: str = "inputs" @@ -345,7 +345,7 @@ def check(self): @dataclass class Seq2SeqLM(Task): task_type='Seq2SeqLM' - data_collator:...=DataCollatorForSeq2Seq(None) + data_collator:... = field(default_factory=lambda: DataCollatorForSeq2Seq(None)) s1:str='' s2:str='' metric:...=evaluate.load("bleu")