From e04923f2c772d35001e2e8af53390f5076a5e81a Mon Sep 17 00:00:00 2001 From: Fabio Catania Date: Mon, 3 Feb 2025 22:03:04 +0100 Subject: [PATCH 1/3] Update speechbrain.py Fixing enhancement with segment duration < kernel size --- .../tasks/speech_enhancement/speechbrain.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/senselab/audio/tasks/speech_enhancement/speechbrain.py b/src/senselab/audio/tasks/speech_enhancement/speechbrain.py index cab822c6..2402b92a 100644 --- a/src/senselab/audio/tasks/speech_enhancement/speechbrain.py +++ b/src/senselab/audio/tasks/speech_enhancement/speechbrain.py @@ -17,6 +17,7 @@ class SpeechBrainEnhancer: """A factory for managing SpeechBrain enhancement pipelines.""" MAX_DURATION_SECONDS = 60 # Maximum duration per segment in seconds + MIN_LENGTH = 16 # kernel size for speechbrain/sepformer-wham16k-enhancement _models: Dict[str, Union[separator, enhance_model]] = {} @classmethod @@ -105,16 +106,21 @@ def enhance_audios_with_speechbrain( enhanced_segments = [] for segment in segments: - if isinstance(enhancer, enhance_model): - enhanced_waveform = enhancer.enhance_batch(segment.waveform, lengths=torch.tensor([1.0])) + if segment.waveform.shape[-1] < cls.MIN_LENGTH: + print(f"Skipping segment with length {segment.waveform.shape[-1]}") + # Append it as it is + enhanced_segments.append(segment) else: - enhanced_waveform = enhancer.separate_batch(segment.waveform) - - enhanced_segments.append( - Audio(waveform=enhanced_waveform.reshape(1, -1), sampling_rate=segment.sampling_rate) - ) - # TODO: decide what to do with metadata - + if isinstance(enhancer, enhance_model): + enhanced_waveform = enhancer.enhance_batch(segment.waveform, lengths=torch.tensor([1.0])) + else: + enhanced_waveform = enhancer.separate_batch(segment.waveform) + + enhanced_segments.append( + Audio(waveform=enhanced_waveform.reshape(1, -1), sampling_rate=segment.sampling_rate) + ) + + # TODO: decide what to do with metadata enhanced_audio = concatenate_audios(enhanced_segments) enhanced_audio.metadata = audio.metadata enhanced_audios.append(enhanced_audio) From c5b2c1fff106ebdc25bace6106f613c7bdf42b45 Mon Sep 17 00:00:00 2001 From: Fabio Catania Date: Mon, 3 Feb 2025 22:14:25 +0100 Subject: [PATCH 2/3] Update speechbrain.py Fixing style --- src/senselab/audio/tasks/speech_enhancement/speechbrain.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/senselab/audio/tasks/speech_enhancement/speechbrain.py b/src/senselab/audio/tasks/speech_enhancement/speechbrain.py index 2402b92a..6b9b2570 100644 --- a/src/senselab/audio/tasks/speech_enhancement/speechbrain.py +++ b/src/senselab/audio/tasks/speech_enhancement/speechbrain.py @@ -17,7 +17,7 @@ class SpeechBrainEnhancer: """A factory for managing SpeechBrain enhancement pipelines.""" MAX_DURATION_SECONDS = 60 # Maximum duration per segment in seconds - MIN_LENGTH = 16 # kernel size for speechbrain/sepformer-wham16k-enhancement + MIN_LENGTH = 16 # kernel size for speechbrain/sepformer-wham16k-enhancement _models: Dict[str, Union[separator, enhance_model]] = {} @classmethod @@ -115,11 +115,11 @@ def enhance_audios_with_speechbrain( enhanced_waveform = enhancer.enhance_batch(segment.waveform, lengths=torch.tensor([1.0])) else: enhanced_waveform = enhancer.separate_batch(segment.waveform) - + enhanced_segments.append( Audio(waveform=enhanced_waveform.reshape(1, -1), sampling_rate=segment.sampling_rate) ) - + # TODO: decide what to do with metadata enhanced_audio = concatenate_audios(enhanced_segments) enhanced_audio.metadata = audio.metadata From 6248eb4cf58f71059162eb10afed3df32d1d90b6 Mon Sep 17 00:00:00 2001 From: Fabio Catania Date: Mon, 3 Feb 2025 22:42:40 +0100 Subject: [PATCH 3/3] Update preprocessing.py --- src/senselab/audio/tasks/preprocessing/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/senselab/audio/tasks/preprocessing/preprocessing.py b/src/senselab/audio/tasks/preprocessing/preprocessing.py index c9d6f41f..801298e9 100644 --- a/src/senselab/audio/tasks/preprocessing/preprocessing.py +++ b/src/senselab/audio/tasks/preprocessing/preprocessing.py @@ -246,7 +246,7 @@ def concatenate_audios(audios: List[Audio]) -> Audio: if audio.waveform.shape[0] != num_channels: raise ValueError("All audios must have the same number of channels (mono or stereo) to concatenate.") - concatenated_waveform = torch.cat([audio.waveform for audio in audios], dim=1) + concatenated_waveform = torch.cat([audio.waveform.cpu() for audio in audios], dim=1) # TODO: do we want to concatenate metadata? TBD