Skip to content

Commit

Permalink
export util
Browse files Browse the repository at this point in the history
  • Loading branch information
sileod committed Jan 9, 2023
1 parent a244fc1 commit 6a61b88
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/tasknet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@
import magicattr
import gc
import random
from tqdm.auto import tqdm

def progress(l):
    """Wrap a sized iterable in a tqdm progress bar when it is long enough.

    Short iterables (8 items or fewer) are returned unchanged so trivial
    loops do not spam the console with progress bars.

    Args:
        l: a sized iterable (must support ``len``).

    Returns:
        ``tqdm(l)`` when ``len(l) > 8``, otherwise ``l`` itself.
    """
    if len(l) > 8:
        return tqdm(l)
    else:
        return l


Expand All @@ -43,7 +42,7 @@ def __init__(self, Zi, drop_probability=0.0):
self.drop_probability=drop_probability
def forward(self, x):
    """Add the stored CLS embedding to the first token of each sequence.

    With probability ``self.drop_probability`` the addition is skipped
    (a form of stochastic task-embedding dropout); otherwise ``self.cls``
    is moved to ``x``'s device and added to position 0 along dim 1.

    Args:
        x: tensor of shape (batch, seq_len, hidden) — assumed; TODO confirm.

    Returns:
        The (possibly modified, in place) input tensor ``x``.
    """
    if random.random() > self.drop_probability:
        # .to(x.device) keeps this correct when the module's buffer and the
        # activations live on different devices (e.g. CPU param, GPU input).
        x[:, 0, :] = x[:, 0, :] + self.cls.to(x.device)
    return x

class WandbTaskCallback(transformers.integrations.WandbCallback):
Expand All @@ -64,7 +63,7 @@ def __init__(self, tasks, args, warm_start=None):
self.shared_encoder = warm_start
self.models={}
task_models_list = []
for i, task in progress(enumerate(tasks)):
for i, task in progress(list(enumerate(tasks))):
model_type = eval(f"AutoModelFor{task.task_type}")
nl = {a: getattr(task, a) for a in ('num_labels','problem_type')
if hasattr(task, a)
Expand Down Expand Up @@ -98,7 +97,7 @@ def __init__(self, tasks, args, warm_start=None):

self.Z = nn.parameter.Parameter(
torch.zeros(len(tasks),
self.shared_encoder.config.hidden_size, device="cuda"),
self.shared_encoder.config.hidden_size, device=torch.cuda.current_device()),
requires_grad=len(tasks)>1
)

Expand Down Expand Up @@ -135,6 +134,16 @@ def forward(self, task, **kwargs):
y = self.task_models_list[task_index](**kwargs)
return y

def factorize(self, base_index=0, tasks=None):
    """Collapse the multi-task model into one exportable base model.

    Takes the task model at ``base_index``, attaches the shared ``Z``
    embedding matrix and a ``ModuleList`` of every task's classifier head,
    and records the classifier shapes and task names in its config so the
    model can be reloaded standalone.

    Args:
        base_index: index of the task model to use as the export base.
        tasks: optional list of task names stored in the config; defaults
            to an empty list.

    Returns:
        The selected task model, mutated in place with ``Z``,
        ``classifiers`` and an updated ``config``.
    """
    # None-sentinel instead of a mutable `tasks=[]` default, which would be
    # shared across calls.
    tasks = [] if tasks is None else tasks
    m_i = self.task_models_list[base_index]
    m_i.Z = self.Z
    # Carry every task head along so no classifier is lost in the export.
    m_i.classifiers = torch.nn.ModuleList(
        [a.classifier for a in self.task_models_list]
    )
    m_i.config = m_i.config.from_dict(
        {**m_i.config.to_dict(),
         'classifiers_size': [tuple(c.weight.shape) for c in m_i.classifiers],
         'tasks': tasks}
    )
    return m_i

class NLPDataCollator:
def __init__(self, tasks):
Expand Down

0 comments on commit 6a61b88

Please sign in to comment.