Skip to content

Commit

Permalink
collator
Browse files Browse the repository at this point in the history
  • Loading branch information
sileod committed Feb 27, 2023
1 parent 4b26ff5 commit 5431ecb
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/tasknet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,12 @@ def __init__(self, tasks):
def __call__(
self, features: List[Union[InputDataClass, Dict]]
) -> Dict[str, torch.Tensor]:
task_index = features[0]["task"].flatten()[0].item()
try:
task_index = features[0]["task"].flatten()[0].item()
except:
print("features:",features)
task_index = features[-1]["task"].flatten()[0].item()

features = [{k:v for k,v in x.items() if k!='task'} for x in features]
collated = self.tasks[task_index].data_collator.__call__(features)
collated['task']=torch.tensor([task_index])
Expand Down

0 comments on commit 5431ecb

Please sign in to comment.