Skip to content

Commit

Permalink
allow cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
sileod committed Jan 20, 2023
1 parent f1ee5b2 commit 4f8c594
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/tasknet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def __init__(self, tasks, args, warm_start=None):
self.shared_encoder = warm_start
self.models={}
self.task_names = [t.name for t in tasks]

task_models_list = []
for i, task in progress(list(enumerate(tasks))):
model_type = eval(f"AutoModelFor{task.task_type}")
Expand Down Expand Up @@ -98,9 +97,10 @@ def __init__(self, tasks, args, warm_start=None):

self.task_models_list = nn.ModuleList(task_models_list)

device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
self.Z = nn.parameter.Parameter(
torch.zeros(len(tasks),
self.shared_encoder.config.hidden_size, device=torch.cuda.current_device()),
self.shared_encoder.config.hidden_size, device=device),
requires_grad=len(tasks)>1
)

Expand Down

0 comments on commit 4f8c594

Please sign in to comment.