diff --git a/src/senselab/audio/tasks/preprocessing/preprocessing.py b/src/senselab/audio/tasks/preprocessing/preprocessing.py
index c9d6f41f..801298e9 100644
--- a/src/senselab/audio/tasks/preprocessing/preprocessing.py
+++ b/src/senselab/audio/tasks/preprocessing/preprocessing.py
@@ -246,7 +246,7 @@ def concatenate_audios(audios: List[Audio]) -> Audio:
         if audio.waveform.shape[0] != num_channels:
             raise ValueError("All audios must have the same number of channels (mono or stereo) to concatenate.")
 
-    concatenated_waveform = torch.cat([audio.waveform for audio in audios], dim=1)
+    concatenated_waveform = torch.cat([audio.waveform.cpu() for audio in audios], dim=1)
 
     # TODO: do we want to concatenate metadata? TBD
diff --git a/src/senselab/audio/tasks/speech_enhancement/speechbrain.py b/src/senselab/audio/tasks/speech_enhancement/speechbrain.py
index cab822c6..6b9b2570 100644
--- a/src/senselab/audio/tasks/speech_enhancement/speechbrain.py
+++ b/src/senselab/audio/tasks/speech_enhancement/speechbrain.py
@@ -17,6 +17,7 @@ class SpeechBrainEnhancer:
     """A factory for managing SpeechBrain enhancement pipelines."""
 
     MAX_DURATION_SECONDS = 60  # Maximum duration per segment in seconds
+    MIN_LENGTH = 16  # kernel size for speechbrain/sepformer-wham16k-enhancement
     _models: Dict[str, Union[separator, enhance_model]] = {}
 
     @classmethod
@@ -105,16 +106,21 @@ def enhance_audios_with_speechbrain(
 
             enhanced_segments = []
             for segment in segments:
-                if isinstance(enhancer, enhance_model):
-                    enhanced_waveform = enhancer.enhance_batch(segment.waveform, lengths=torch.tensor([1.0]))
+                if segment.waveform.shape[-1] < cls.MIN_LENGTH:
+                    print(f"Skipping segment with length {segment.waveform.shape[-1]}")
+                    # Append it as it is
+                    enhanced_segments.append(segment)
                 else:
-                    enhanced_waveform = enhancer.separate_batch(segment.waveform)
+                    if isinstance(enhancer, enhance_model):
+                        enhanced_waveform = enhancer.enhance_batch(segment.waveform, lengths=torch.tensor([1.0]))
+                    else:
+                        enhanced_waveform = enhancer.separate_batch(segment.waveform)
 
-                enhanced_segments.append(
-                    Audio(waveform=enhanced_waveform.reshape(1, -1), sampling_rate=segment.sampling_rate)
-                )
-                # TODO: decide what to do with metadata
+                    enhanced_segments.append(
+                        Audio(waveform=enhanced_waveform.reshape(1, -1), sampling_rate=segment.sampling_rate)
+                    )
+                    # TODO: decide what to do with metadata
 
             enhanced_audio = concatenate_audios(enhanced_segments)
             enhanced_audio.metadata = audio.metadata
             enhanced_audios.append(enhanced_audio)