From 5431ecb39ce7f3bef305bb20d1bfbc716009a922 Mon Sep 17 00:00:00 2001 From: sileod Date: Mon, 27 Feb 2023 09:19:37 +0100 Subject: [PATCH] collator --- src/tasknet/models.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/tasknet/models.py b/src/tasknet/models.py index e07a786..0cdc8f9 100755 --- a/src/tasknet/models.py +++ b/src/tasknet/models.py @@ -213,7 +213,12 @@ def __init__(self, tasks): def __call__( self, features: List[Union[InputDataClass, Dict]] ) -> Dict[str, torch.Tensor]: - task_index = features[0]["task"].flatten()[0].item() + try: + task_index = features[0]["task"].flatten()[0].item() + except: + print("features:",features) + task_index = features[-1]["task"].flatten()[0].item() + features = [{k:v for k,v in x.items() if k!='task'} for x in features] collated = self.tasks[task_index].data_collator.__call__(features) collated['task']=torch.tensor([task_index])