diff --git a/src/tasknet/models.py b/src/tasknet/models.py index fc02371..21f0031 100755 --- a/src/tasknet/models.py +++ b/src/tasknet/models.py @@ -219,7 +219,7 @@ def forward(self, task, **kwargs): return y def factorize(self, base_index=0, tasks=[],labels=[]): - m_i = self.task_models_list[base_index] + m_i = copy.deepcopy(self.task_models_list[base_index]) classifiers = torch.nn.ModuleList([a.classifier for a in self.task_models_list]) if hasattr(m_i,'auto'): del m_i.auto