From cdeb881abeb599e3a912aa3897cabb78522efca8 Mon Sep 17 00:00:00 2001 From: zhaotw <504924839@qq.com> Date: Thu, 9 May 2024 10:42:19 +0800 Subject: [PATCH] fix: pad in collate func when training use full song(or no segment), need to pad labels to the max length in the batch --- src/allin1/training/data/datasets/collate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/allin1/training/data/datasets/collate.py b/src/allin1/training/data/datasets/collate.py index 64569cb..01e8d83 100644 --- a/src/allin1/training/data/datasets/collate.py +++ b/src/allin1/training/data/datasets/collate.py @@ -21,7 +21,7 @@ def collate_fn(raw_batch): 'true_beat', 'true_downbeat', 'true_section', 'true_function', 'widen_true_beat', 'widen_true_downbeat', 'widen_true_section', ]: - data[key] = value[:max_T] + data[key] = np.pad(value, (0, max_T - value.shape[0]), 'constant') elif key in ['spec']: T = raw_data[key].shape[1] spec = raw_data[key]