diff --git a/pyha_analyzer/dataset.py b/pyha_analyzer/dataset.py index 54f1d26..d9397e0 100644 --- a/pyha_analyzer/dataset.py +++ b/pyha_analyzer/dataset.py @@ -256,7 +256,7 @@ def to_image(self, audio): # Sigmoid to get 0 to 1 scaling (0.5 becomes mean) mel = torch.sigmoid(mel) - return mel.unsqueeze(0) #torch.stack([mel, mel, mel]) + return torch.stack([mel, mel, mel]) def __getitem__(self, index): #-> Any: """ Takes an index and returns tuple of spectrogram image with corresponding label