Skip to content

Commit

Permalink
fix push to hub
Browse files Browse the repository at this point in the history
  • Loading branch information
sileod committed Mar 14, 2023
1 parent 2c67fdd commit ba37176
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/tasknet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Adapter(transformers.PreTrainedModel):
config_class = transformers.PretrainedConfig
def __init__(self, config, classifiers=None, Z=None):
super().__init__(config)
self.Z= torch.nn.Embedding(config.hidden_size, len(config.classifiers_size)) if Z==None else Z
self.Z= torch.nn.Embedding(len(config.classifiers_size),config.hidden_size).weight if Z==None else Z
self.classifiers=torch.nn.ModuleList(
[torch.nn.Linear(config.hidden_size,size) for size in config.classifiers_size]
) if classifiers==None else classifiers
Expand All @@ -48,6 +48,8 @@ def adapt_model_to_task(self, model, task_name):
task_index=self.config['tasks'].index(task_name)
last_linear(model).weight = last_linear(model.classifiers[task_index]).weight
return model
def _init_weights(*args):
pass


class ConditionalLayerNorm(torch.nn.Module):
Expand Down Expand Up @@ -218,21 +220,20 @@ def forward(self, task, **kwargs):

def factorize(self, base_index=0, tasks=[],labels=[]):
m_i = self.task_models_list[base_index]
m_i.Z = self.Z
m_i.classifiers = torch.nn.ModuleList([a.classifier for a in self.task_models_list])
classifiers = torch.nn.ModuleList([a.classifier for a in self.task_models_list])
if hasattr(m_i,'auto'):
del m_i.auto

id2label=dict(enumerate(labels))
label2id = {str(v):str(k) for k,v in id2label.items()}
label2id = {str(v):k for k,v in id2label.items()}

m_i.config = m_i.config.from_dict(
{**m_i.config.to_dict(),
'classifiers_size': [c.out_features for c in m_i.classifiers],
'classifiers_size': [c.out_features for c in classifiers],
'tasks': (tasks if tasks else self.task_names),
'label2id':label2id,'id2label':id2label
})
adapter=Adapter(m_i.config)
adapter=Adapter(m_i.config, classifiers, self.Z)
return remove_cls(m_i), adapter


Expand Down

0 comments on commit ba37176

Please sign in to comment.