Skip to content

Commit

Permalink
more multiplechoice flexibility
Browse files Browse the repository at this point in the history
  • Loading branch information
sileod committed Jan 19, 2023
1 parent 6a5ca4b commit f1ee5b2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/tasknet/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __post_init__(self):
else:
self.num_labels=max(fc.flatten(self.dataset['train'][self.y]))+1

if type(self.dataset['train'][self.y][0])==list:
if type(self.dataset['train'][self.y][0])==list and self.task_type=="SequenceClassification":
self.problem_type="multi_label_classification"
if set(fc.flatten(self.dataset['train'][self.y]))!={0,1}:
def one_hot(x):
Expand Down
8 changes: 8 additions & 0 deletions src/tasknet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
from easydict import EasyDict as edict
import copy
import functools
from tqdm.auto import tqdm

class Shutup_tqdm:
def __enter__(self):
tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=True)
def __exit__(self, exc_type, exc_value, exc_traceback):
tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=False)



def train_validation_test_split(dataset, train_ratio=0.8, val_test_ratio=0.5, seed=0):
Expand Down

0 comments on commit f1ee5b2

Please sign in to comment.