diff --git a/data_utils.py b/data_utils.py index cf7f442..263dc7e 100644 --- a/data_utils.py +++ b/data_utils.py @@ -148,7 +148,7 @@ def __call__(self, batch): wav_padded[i, :, :wav.size(1)] = wav wav_lengths[i] = wav.size(1) - emo[:] = row[3] + emo[i, :] = row[3] if self.return_ids: return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, ids_sorted_decreasing