Commit: label sharing

sileod committed Dec 24, 2022
1 parent ae8012d commit b2eda52
Showing 2 changed files with 22 additions and 9 deletions.
src/tasknet/models.py: 23 changes (17 additions, 6 deletions)
@@ -35,12 +35,12 @@ def forward(self, x):
         x[:, 0, :] = x[:, 0, :] + self.cls
         return x
 
 
 class Model(transformers.PreTrainedModel):
     def __init__(self, tasks, args, warm_start=None):
         super().__init__(transformers.PretrainedConfig())
         self.shared_encoder = warm_start
-        mc_model = None
+        self.models={}
         task_models_list = []
         for i, task in enumerate(tasks):
             model_type = eval(f"AutoModelFor{task.task_type}")
@@ -50,11 +50,16 @@ def __init__(self, tasks, args, warm_start=None):
 
             model = model_type.from_pretrained(args.model_name, **nl)
 
-            if task.task_type=='MultipleChoice':
-                if not mc_model:
-                    mc_model=model
-                else:
-                    self.shallow_copy(mc_model.classifier, model.classifier)
+            if task.task_type=='MultipleChoice':
+                key="mc"
+            else:
+                labels = getattr(task.dataset['train'].features[task.y],"names",None)
+                key=(tuple(labels) if labels else None)
+
+            if key and key not in self.models:
+                self.models[key] = model
+            if key and key in self.models:
+                self.shallow_copy(self.models[key].classifier, model.classifier)
 
             model.auto = getattr(model, self.get_encoder_attr_name(model))
 
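This hunk generalizes head sharing, previously limited to MultipleChoice tasks: every task now gets a key (the string "mc" for multiple choice, otherwise the tuple of its label names), the first model seen with a key is recorded in self.models, and every later model with a matching key gets its classifier tied to that first one (for the first model the copy is a self-copy, effectively a no-op). Below is a minimal sketch of the kind of parameter tying shallow_copy presumably performs; tie_parameters and the 768-by-3 shapes are illustrative, not the repo's implementation.

import torch.nn as nn

def tie_parameters(src: nn.Module, dst: nn.Module):
    # Point dst's parameters at src's tensors so both heads share
    # the same underlying storage.
    for (_, p_src), (_, p_dst) in zip(src.named_parameters(), dst.named_parameters()):
        p_dst.data = p_src.data

head_a = nn.Linear(768, 3)  # e.g. labels ('entailment', 'neutral', 'contradiction')
head_b = nn.Linear(768, 3)  # a second task with the same label names, hence the same key
tie_parameters(head_a, head_b)
assert head_a.weight.data_ptr() == head_b.weight.data_ptr()  # same storage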
@@ -378,3 +383,9 @@ def preprocess_tasks(self, tasks, tokenizer):
             )
             task.processed_features=features_dict[task] #added
         return features_dict
+
+
+def Model_Trainer(tasks, args):
+    model = Model(tasks, args)
+    trainer = Trainer(model, tasks, args)
+    return model, trainer
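The new Model_Trainer helper bundles the two-step Model/Trainer construction into a single call. A usage sketch follows: args.model_name is the one field this diff actually reads; the Classification signature and the learning_rate field are assumptions about the surrounding library, not part of this commit.

import tasknet as tn
from datasets import load_dataset

rte = tn.Classification(  # hypothetical task signature, for illustration only
    dataset=load_dataset("glue", "rte"),
    s1="sentence1", s2="sentence2", y="label",
)

class args:
    model_name = "microsoft/deberta-v3-base"  # read via args.model_name above
    learning_rate = 3e-5  # assumed field

model, trainer = tn.Model_Trainer([rte], args)
trainer.train()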
src/tasknet/tasks.py: 8 changes (5 additions, 3 deletions)
@@ -88,15 +88,17 @@ def __post_init__(self):
         elif hasattr(target,'num_classes'):
             self.num_labels=target.num_classes
         else:
-            self.num_labels=len(set(fc.flatten(self.dataset['train'][self.y])))
+            self.num_labels=max(fc.flatten(self.dataset['train'][self.y]))+1
 
         if type(self.dataset['train'][self.y][0])==list:
             self.problem_type="multi_label_classification"
             if set(fc.flatten(self.dataset['train'][self.y]))!={0,1}:
                 def one_hot(x):
-                    x['labels'] = [float(i in x[self.y]) for i in range(self.num_labels)]
+                    x[self.y] = [float(i in x[self.y]) for i in range(self.num_labels)]
                     return x
                 self.dataset=self.dataset.map(one_hot)
+
+            self.num_labels=len(self.dataset['train'][self.y][0])
             self.dataset=self.dataset.cast_column(self.y, ds.Sequence(feature=ds.Value(dtype='float64')))
 
     def check(self):
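Two effects of this hunk: num_labels is now max(index)+1 rather than the number of distinct values, so unobserved intermediate class indices still get a classifier row; and for multi-label data the one-hot vectors are written back to the task's own label column x[self.y] instead of a hard-coded 'labels' key, after which num_labels is re-derived from the vector length. A standalone toy rendition of the one_hot mapping (the data below is invented for illustration):

num_labels = 4  # in the repo: max over the flattened label indices, plus one
rows = [{"label": [0, 2]}, {"label": [3]}]

def one_hot(x, y="label"):
    # Turn a list of class indices into a float indicator vector of length num_labels.
    x[y] = [float(i in x[y]) for i in range(num_labels)]
    return x

print([one_hot(dict(r)) for r in rows])
# [{'label': [1.0, 0.0, 1.0, 0.0]}, {'label': [0.0, 0.0, 0.0, 1.0]}]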
@@ -120,7 +122,7 @@ def compute_metrics(self, eval_pred):
             metric = load_metric("super_glue", "cb")
             predictions = np.argmax(predictions, axis=1)
 
-        elif self.problem_type=='multi_label_classification':
+        elif getattr(self,"problem_type", None)=='multi_label_classification':
             metric=evaluate.load('f1','multilabel', average='macro')
             labels=labels.astype(int)
             predictions = (expit(predictions)>0.5).astype(int)

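The getattr guard makes compute_metrics safe for tasks that never set problem_type, which in the code above is only assigned inside the multi-label branch of __post_init__. In the multi-label branch itself, logits become binary predictions by thresholding sigmoid probabilities at 0.5; a toy rendition:

import numpy as np
from scipy.special import expit  # scipy's logistic sigmoid, as used in tasks.py

logits = np.array([[2.0, -1.0, 0.3],
                   [-0.5, 1.5, -2.0]])
preds = (expit(logits) > 0.5).astype(int)  # positive logit -> label predicted
print(preds)
# [[1 0 1]
#  [0 1 0]]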