diff --git a/src/tasknet/models.py b/src/tasknet/models.py
index a21f6da..53b7a2d 100755
--- a/src/tasknet/models.py
+++ b/src/tasknet/models.py
@@ -26,13 +26,12 @@
 import magicattr
 import gc
 import random
+from tqdm.auto import tqdm
 
 
 def progress(l):
-    try:
-        from tqdm.auto import tqdm
-        assert len(l)>8
+    if len(l)>8:
         return tqdm(l)
-    except:
+    else:
         return l
 
@@ -43,7 +42,7 @@ def __init__(self, Zi, drop_probability=0.0):
         self.drop_probability=drop_probability
     def forward(self, x):
         if random.random()>self.drop_probability:
-            x[:, 0, :] = x[:, 0, :] + self.cls
+            x[:, 0, :] = x[:, 0, :] + self.cls.to(x.device)
         return x
 
 class WandbTaskCallback(transformers.integrations.WandbCallback):
@@ -64,7 +63,7 @@ def __init__(self, tasks, args, warm_start=None):
             self.shared_encoder = warm_start
         self.models={}
         task_models_list = []
-        for i, task in progress(enumerate(tasks)):
+        for i, task in progress(list(enumerate(tasks))):
             model_type = eval(f"AutoModelFor{task.task_type}")
             nl = {a: getattr(task, a) for a in ('num_labels','problem_type')
                 if hasattr(task, a)
@@ -98,7 +97,7 @@
 
         self.Z = nn.parameter.Parameter(
             torch.zeros(len(tasks),
-            self.shared_encoder.config.hidden_size, device="cuda"),
+            self.shared_encoder.config.hidden_size, device=torch.cuda.current_device()),
             requires_grad=len(tasks)>1
         )
 
@@ -135,6 +134,16 @@ def forward(self, task, **kwargs):
         y = self.task_models_list[task_index](**kwargs)
         return y
 
+    def factorize(self, base_index=0, tasks=[]):
+        m_i = self.task_models_list[base_index]
+        m_i.Z = self.Z
+        m_i.classifiers = torch.nn.ModuleList([a.classifier for a in self.task_models_list])
+        m_i.config = m_i.config.from_dict(
+            {**m_i.config.to_dict(),
+             'classifiers_size': [tuple(c.weight.shape) for c in m_i.classifiers],
+             'tasks':tasks
+            })
+        return m_i
 
 class NLPDataCollator:
     def __init__(self, tasks):
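
Usage sketch (not part of the patch): how the new factorize() helper might be called after multi-task training. Only Model(tasks, args) and factorize(base_index, tasks) come from the code above; the variable names, the task-name list, and the output path are placeholders, and saving/reading the config relies on the standard Hugging Face save_pretrained / PretrainedConfig behaviour.

    # Assumed setup: `tasks` and `args` are the same objects the Model was built from.
    model = Model(tasks, args)          # multi-task model with a shared encoder
    # ... train as usual ...

    # Collapse the multi-task model into a single exportable sub-model:
    single = model.factorize(base_index=0, tasks=[str(t) for t in tasks])  # task names are illustrative
    single.save_pretrained("multitask-checkpoint")  # placeholder path

    # factorize() stores the extra metadata on the config:
    print(single.config.classifiers_size)  # one weight shape per task head (e.g. (out_features, in_features) for Linear heads)
    print(single.config.tasks)             # whatever was passed as `tasks`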