Merge pull request #250 from sensein/fabiocat93-patch-5
Fixing enhancement with segment duration < kernel size
fabiocat93 authored Feb 3, 2025
2 parents fb1b538 + 6248eb4 commit 06a511d
Showing 2 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/senselab/audio/tasks/preprocessing/preprocessing.py
@@ -246,7 +246,7 @@ def concatenate_audios(audios: List[Audio]) -> Audio:
         if audio.waveform.shape[0] != num_channels:
             raise ValueError("All audios must have the same number of channels (mono or stereo) to concatenate.")

-    concatenated_waveform = torch.cat([audio.waveform for audio in audios], dim=1)
+    concatenated_waveform = torch.cat([audio.waveform.cpu() for audio in audios], dim=1)

     # TODO: do we want to concatenate metadata? TBD

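A note on the preprocessing change: adding .cpu() before torch.cat most likely guards against device mismatches, since segments enhanced on a GPU and segments passed through untouched (see the MIN_LENGTH skip in the next file) can end up on different devices, and torch.cat refuses to mix them. A minimal sketch of the failure mode and the fix, with illustrative tensors that are not taken from the repo:

import torch

# Two waveform tensors of shape (channels, samples); one may have been produced on a GPU.
waveforms = [torch.randn(1, 16000), torch.randn(1, 8000)]
if torch.cuda.is_available():
    waveforms[0] = waveforms[0].cuda()  # simulate a segment that was enhanced on the GPU

# torch.cat([w for w in waveforms], dim=1) raises a RuntimeError when devices are mixed;
# moving every waveform to the CPU first, as the diff does, makes the concatenation safe.
concatenated = torch.cat([w.cpu() for w in waveforms], dim=1)
print(concatenated.shape)  # torch.Size([1, 24000])
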
20 changes: 13 additions & 7 deletions src/senselab/audio/tasks/speech_enhancement/speechbrain.py
@@ -17,6 +17,7 @@ class SpeechBrainEnhancer:
"""A factory for managing SpeechBrain enhancement pipelines."""

MAX_DURATION_SECONDS = 60 # Maximum duration per segment in seconds
MIN_LENGTH = 16 # kernel size for speechbrain/sepformer-wham16k-enhancement
_models: Dict[str, Union[separator, enhance_model]] = {}

@classmethod
@@ -105,16 +106,21 @@ def enhance_audios_with_speechbrain(
             enhanced_segments = []

             for segment in segments:
-                if isinstance(enhancer, enhance_model):
-                    enhanced_waveform = enhancer.enhance_batch(segment.waveform, lengths=torch.tensor([1.0]))
+                if segment.waveform.shape[-1] < cls.MIN_LENGTH:
+                    print(f"Skipping segment with length {segment.waveform.shape[-1]}")
+                    # Append it as it is
+                    enhanced_segments.append(segment)
                 else:
-                    enhanced_waveform = enhancer.separate_batch(segment.waveform)
+                    if isinstance(enhancer, enhance_model):
+                        enhanced_waveform = enhancer.enhance_batch(segment.waveform, lengths=torch.tensor([1.0]))
+                    else:
+                        enhanced_waveform = enhancer.separate_batch(segment.waveform)

-                enhanced_segments.append(
-                    Audio(waveform=enhanced_waveform.reshape(1, -1), sampling_rate=segment.sampling_rate)
-                )
+                    # TODO: decide what to do with metadata
+                    enhanced_segments.append(
+                        Audio(waveform=enhanced_waveform.reshape(1, -1), sampling_rate=segment.sampling_rate)
+                    )

-            # TODO: decide what to do with metadata
             enhanced_audio = concatenate_audios(enhanced_segments)
             enhanced_audio.metadata = audio.metadata
             enhanced_audios.append(enhanced_audio)
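For context on the MIN_LENGTH guard above: the comment in the diff identifies 16 as the kernel size of speechbrain/sepformer-wham16k-enhancement, and a 1-D convolution without padding cannot process an input shorter than its kernel, so feeding such a short segment to the enhancer would fail; the new code appends those segments unchanged instead. A hedged sketch of the underlying constraint, using a plain torch Conv1d as a stand-in for the model's encoder (not the actual SpeechBrain module):

import torch
import torch.nn as nn

MIN_LENGTH = 16  # mirrors the kernel-size constant introduced above
conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=MIN_LENGTH)

short_segment = torch.randn(1, 1, 10)   # 10 samples: shorter than the kernel
valid_segment = torch.randn(1, 1, 160)  # long enough to convolve

try:
    conv(short_segment)
except RuntimeError as err:
    print(f"Segment shorter than the kernel fails: {err}")

print(conv(valid_segment).shape)  # torch.Size([1, 1, 145])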

0 comments on commit 06a511d

Please sign in to comment.