diff --git a/src/tasknet/models.py b/src/tasknet/models.py index c72edd9..9a120a9 100755 --- a/src/tasknet/models.py +++ b/src/tasknet/models.py @@ -63,7 +63,6 @@ def __init__(self, tasks, args, warm_start=None): self.shared_encoder = warm_start self.models={} self.task_names = [t.name for t in tasks] - task_models_list = [] for i, task in progress(list(enumerate(tasks))): model_type = eval(f"AutoModelFor{task.task_type}") @@ -98,9 +97,10 @@ def __init__(self, tasks, args, warm_start=None): self.task_models_list = nn.ModuleList(task_models_list) + device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" self.Z = nn.parameter.Parameter( torch.zeros(len(tasks), - self.shared_encoder.config.hidden_size, device=torch.cuda.current_device()), + self.shared_encoder.config.hidden_size, device=device), requires_grad=len(tasks)>1 )