diff --git a/src/tasknet/tasks.py b/src/tasknet/tasks.py index cfa84b1..1a70f0a 100755 --- a/src/tasknet/tasks.py +++ b/src/tasknet/tasks.py @@ -102,7 +102,7 @@ def __post_init__(self): else: self.num_labels=max(fc.flatten(self.dataset['train'][self.y]))+1 - if type(self.dataset['train'][self.y][0])==list: + if type(self.dataset['train'][self.y][0])==list and self.task_type=="SequenceClassification": self.problem_type="multi_label_classification" if set(fc.flatten(self.dataset['train'][self.y]))!={0,1}: def one_hot(x): diff --git a/src/tasknet/utils.py b/src/tasknet/utils.py index 7de6337..68c0bdc 100755 --- a/src/tasknet/utils.py +++ b/src/tasknet/utils.py @@ -2,6 +2,14 @@ from easydict import EasyDict as edict import copy import functools +from tqdm.auto import tqdm + +class Shutup_tqdm: + def __enter__(self): + tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=True) + def __exit__(self, exc_type, exc_value, exc_traceback): + tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=False) + def train_validation_test_split(dataset, train_ratio=0.8, val_test_ratio=0.5, seed=0):