From b2eda52d9f1d9e8fc6c27d27068646e26136a7ac Mon Sep 17 00:00:00 2001 From: sileod Date: Sat, 24 Dec 2022 16:04:03 +0100 Subject: [PATCH] label sharing --- src/tasknet/models.py | 23 +++++++++++++++++------ src/tasknet/tasks.py | 8 +++++--- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/tasknet/models.py b/src/tasknet/models.py index 1f48cb5..d98f4fd 100755 --- a/src/tasknet/models.py +++ b/src/tasknet/models.py @@ -35,12 +35,12 @@ def forward(self, x): x[:, 0, :] = x[:, 0, :] + self.cls return x - class Model(transformers.PreTrainedModel): def __init__(self, tasks, args, warm_start=None): super().__init__(transformers.PretrainedConfig()) self.shared_encoder = warm_start mc_model = None + self.models={} task_models_list = [] for i, task in enumerate(tasks): model_type = eval(f"AutoModelFor{task.task_type}") @@ -50,11 +50,16 @@ def __init__(self, tasks, args, warm_start=None): model = model_type.from_pretrained(args.model_name, **nl) - if task.task_type=='MultipleChoice': - if not mc_model: - mc_model=model - else: - self.shallow_copy(mc_model.classifier, model.classifier) + if task.task_type=='MultipleChoice': + key="mc" + else: + labels = getattr(task.dataset['train'].features[task.y],"names",None) + key=(tuple(labels) if labels else None) + + if key and key not in self.models: + self.models[key] = model + if key and key in self.models: + self.shallow_copy(self.models[key].classifier, model.classifier) model.auto = getattr(model, self.get_encoder_attr_name(model)) @@ -378,3 +383,9 @@ def preprocess_tasks(self, tasks, tokenizer): ) task.processed_features=features_dict[task] #added return features_dict + + +def Model_Trainer(tasks, args): + model = Model(tasks, args) + trainer = Trainer(model, tasks, args) + return model, trainer \ No newline at end of file diff --git a/src/tasknet/tasks.py b/src/tasknet/tasks.py index 57da8ad..9f6a03f 100755 --- a/src/tasknet/tasks.py +++ b/src/tasknet/tasks.py @@ -88,15 +88,17 @@ def __post_init__(self): elif hasattr(target,'num_classes'): self.num_labels=target.num_classes else: - self.num_labels=len(set(fc.flatten(self.dataset['train'][self.y]))) + self.num_labels=max(fc.flatten(self.dataset['train'][self.y]))+1 if type(self.dataset['train'][self.y][0])==list: self.problem_type="multi_label_classification" if set(fc.flatten(self.dataset['train'][self.y]))!={0,1}: def one_hot(x): - x['labels'] = [float(i in x[self.y]) for i in range(self.num_labels)] + x[self.y] = [float(i in x[self.y]) for i in range(self.num_labels)] return x self.dataset=self.dataset.map(one_hot) + + self.num_labels=len(self.dataset['train'][self.y][0]) self.dataset=self.dataset.cast_column(self.y, ds.Sequence(feature=ds.Value(dtype='float64'))) def check(self): @@ -120,7 +122,7 @@ def compute_metrics(self, eval_pred): metric = load_metric("super_glue", "cb") predictions = np.argmax(predictions, axis=1) - elif self.problem_type=='multi_label_classification': + elif getattr(self,"problem_type", None)=='multi_label_classification': metric=evaluate.load('f1','multilabel', average='macro') labels=labels.astype(int) predictions = (expit(predictions)>0.5).astype(int)