Add real-time processing for FRCRN_SE_16K #26

Open · wants to merge 1 commit into main

4 changes: 2 additions & 2 deletions clearvoice/config/inference/FRCRN_SE_16K.yaml
@@ -15,6 +15,6 @@ decode_window: 1 #one-pass decoding length
 #
 # FFT parameters
 win_type: 'hanning'
-win_len: 640
-win_inc: 320
+win_len: 320
+win_inc: 160
 fft_len: 640
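
Note on the parameter change: at 16 kHz, halving win_len and win_inc shortens the analysis window from 40 ms to 20 ms and the frame hop from 20 ms to 10 ms, which lowers per-frame latency for real-time use. fft_len stays at 640, so the 320-sample window is presumably zero-padded to the FFT size, as is typical when n_fft exceeds win_length. A quick sanity check of those numbers (plain Python, no project code involved):

```python
SAMPLE_RATE = 16000  # FRCRN_SE_16K operates on 16 kHz audio

def frame_timing(win_len: int, win_inc: int, sr: int = SAMPLE_RATE):
    """Return (window_ms, hop_ms) for the given STFT parameters."""
    return win_len / sr * 1000, win_inc / sr * 1000

print(frame_timing(640, 320))  # old config -> (40.0, 20.0) ms window/hop
print(frame_timing(320, 160))  # new config -> (20.0, 10.0) ms window/hop
```
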
14 changes: 14 additions & 0 deletions clearvoice/demo.py
@@ -65,3 +65,17 @@
 
     #2nd calling method: process video files listed in .scp file, and write outputs to 'path_to_output_videos_tse_scp/'
     myClearVoice(input_path='samples/scp/video_samples.scp', online_write=True, output_path='samples/path_to_output_videos_tse_scp')
+
+##-----Demo Six: use FRCRN_SE_16K model for real-time processing -----------------
+if False:
+    myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K'])
+
+    ##1st calling method: process an input waveform in real-time and return output waveform, then write to output_FRCRN_SE_16K_realtime.wav
+    output_wav = myClearVoice(input_path='samples/input_realtime.wav', online_write=False)
+    myClearVoice.write(output_wav, output_path='samples/output_FRCRN_SE_16K_realtime.wav')
+
+    ##2nd calling method: process all wav files in 'path_to_input_wavs_realtime/' in real-time and write outputs to 'path_to_output_wavs_realtime'
+    myClearVoice(input_path='samples/path_to_input_wavs_realtime', online_write=True, output_path='samples/path_to_output_wavs_realtime')
+
+    ##3rd calling method: process wav files listed in .scp file in real-time, and write outputs to 'path_to_output_wavs_realtime_scp/'
+    myClearVoice(input_path='samples/scp/audio_samples_realtime.scp', online_write=True, output_path='samples/path_to_output_wavs_realtime_scp')
28 changes: 28 additions & 0 deletions clearvoice/demo_with_more_comments.py
@@ -57,3 +57,31 @@
     # - online_write (bool): Set to True to enable saving the enhanced output during processing
     # - output_path (str): Path to the directory to save the enhanced output files
     myClearVoice(input_path='samples/scp/audio_samples.scp', online_write=True, output_path='samples/path_to_output_wavs_scp')
+
+## ---------------- Demo Three: Real-Time Processing -----------------------
+if False:  # This block demonstrates how to use the FRCRN_SE_16K model for real-time speech enhancement
+    # Initialize ClearVoice for the task of speech enhancement using the FRCRN_SE_16K model
+    myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K'])
+
+    # 1st calling method:
+    # Process an input waveform in real-time and return the enhanced output waveform
+    # - input_path (str): Path to the input noisy audio file (input_realtime.wav)
+    # - output_wav (dict or ndarray): The enhanced output waveform
+    output_wav = myClearVoice(input_path='samples/input_realtime.wav', online_write=False)
+    # Write the processed waveform to an output file
+    # - output_path (str): Path to save the enhanced audio file (output_FRCRN_SE_16K_realtime.wav)
+    myClearVoice.write(output_wav, output_path='samples/output_FRCRN_SE_16K_realtime.wav')
+
+    # 2nd calling method:
+    # Process and write audio files directly in real-time
+    # - input_path (str): Path to the directory of input noisy audio files
+    # - online_write (bool): Set to True to enable saving the enhanced audio directly to files during processing
+    # - output_path (str): Path to the directory to save the enhanced output files
+    myClearVoice(input_path='samples/path_to_input_wavs_realtime', online_write=True, output_path='samples/path_to_output_wavs_realtime')
+
+    # 3rd calling method:
+    # Use an .scp file to specify input audio paths for real-time processing
+    # - input_path (str): Path to a .scp file listing multiple audio file paths
+    # - online_write (bool): Set to True to enable saving the enhanced audio directly to files during processing
+    # - output_path (str): Path to the directory to save the enhanced output files
+    myClearVoice(input_path='samples/scp/audio_samples_realtime.scp', online_write=True, output_path='samples/path_to_output_wavs_realtime_scp')
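
For context on what these demo calls exercise: per the frcrn.py changes below, the real-time path dispatches to real_time_process, which returns [est_spec, est_wav, est_mask]. Below is a minimal sketch of a chunked streaming loop built on that interface. The `model` object (an initialized, checkpoint-loaded FRCRN module) and the chunk size are assumptions of this sketch, not part of the PR, and a production streamer would also need to carry STFT overlap state across chunk boundaries, which this naive version ignores:

```python
import torch

SR = 16000       # FRCRN_SE_16K sample rate
CHUNK = SR // 2  # process 0.5 s at a time (assumed chunk size, not from the PR)

def stream_enhance(model, waveform: torch.Tensor) -> torch.Tensor:
    """Enhance a [1, num_samples] waveform chunk by chunk (naive, stateless)."""
    enhanced = []
    with torch.no_grad():
        for start in range(0, waveform.shape[-1], CHUNK):
            chunk = waveform[:, start:start + CHUNK]  # last chunk may be shorter
            # forward(..., real_time=True) returns [est_spec, est_wav, est_mask]
            _, est_wav, _ = model(chunk, real_time=True)
            enhanced.append(est_wav)
    return torch.cat(enhanced, dim=-1)
```
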
61 changes: 58 additions & 3 deletions clearvoice/models/frcrn_se/frcrn.py
@@ -71,18 +71,35 @@ def __init__(self, args):
             win_type=args.win_type
         )
 
-    def forward(self, x):
+    def forward(self, x, real_time=False):
         """
         Forward pass of the model.
 
         Args:
             x (torch.Tensor): Input tensor representing audio signals.
+            real_time (bool): Flag to indicate real-time processing.
 
         Returns:
             torch.Tensor: Processed output tensor after applying the model.
         """
-        output = self.model(x)
-        return output[1][0] # Return estimated waveform
+        if real_time:
+            return self.real_time_process(x)
+        else:
+            output = self.model(x)
+            return output[1][0] # Return estimated waveform
 
+    def real_time_process(self, x):
+        """
+        Real-time processing method for the FRCRN model.
+
+        Args:
+            x (torch.Tensor): Input tensor representing audio signals.
+
+        Returns:
+            torch.Tensor: Processed output tensor after applying the model in real-time.
+        """
+        output = self.model.real_time_process(x)
+        return output
+
 
 class DCCRN(nn.Module):
@@ -249,3 +266,41 @@ def get_params(self, weight_decay=0.0):
         }]
         return params
 
+    def real_time_process(self, inputs):
+        """
+        Real-time processing method for the DCCRN model.
+
+        Args:
+            inputs (torch.Tensor): Input tensor representing audio signals.
+
+        Returns:
+            torch.Tensor: Processed output tensor after applying the model in real-time.
+        """
+        out_list = []
+        # Compute the complex spectrogram using STFT
+        cmp_spec = self.stft(inputs) # [B, D*2, T]
+        cmp_spec = torch.unsqueeze(cmp_spec, 1) # [B, 1, D*2, T]
+
+        # Split into real and imaginary parts
+        cmp_spec = torch.cat([
+            cmp_spec[:, :, :self.feat_dim, :], # Real part
+            cmp_spec[:, :, self.feat_dim:, :], # Imaginary part
+        ], 1) # [B, 2, D, T]
+
+        cmp_spec = torch.unsqueeze(cmp_spec, 4) # [B, 2, D, T, 1]
+        cmp_spec = torch.transpose(cmp_spec, 1, 4) # [B, 1, D, T, 2]
+
+        # Pass through the UNet to estimate masks
+        unet1_out = self.unet(cmp_spec) # First UNet output
+        cmp_mask1 = torch.tanh(unet1_out) # First mask
+
+        unet2_out = self.unet2(unet1_out) # Second UNet output
+        cmp_mask2 = torch.tanh(unet2_out) # Second mask
+        cmp_mask2 = cmp_mask2 + cmp_mask1 # Combine masks
+
+        # Apply the estimated mask to the complex spectrogram
+        est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)
+        out_list.append(est_spec)
+        out_list.append(est_wav)
+        out_list.append(est_mask)
+        return out_list
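
The reshaping in real_time_process rearranges a stacked real/imag spectrogram [B, D*2, T] into the [B, 1, D, T, 2] complex layout the UNet expects. A standalone sketch of just that tensor plumbing, with feat_dim assumed to be fft_len // 2 + 1 = 321 (the PR does not show how feat_dim is set):

```python
import torch

B, feat_dim, T = 1, 321, 10  # feat_dim assumed: fft_len // 2 + 1 = 321

# Stand-in for self.stft(inputs): real and imaginary halves stacked along dim 1
cmp_spec = torch.randn(B, feat_dim * 2, T)  # [B, D*2, T]
cmp_spec = torch.unsqueeze(cmp_spec, 1)     # [B, 1, D*2, T]
cmp_spec = torch.cat([
    cmp_spec[:, :, :feat_dim, :],           # real part
    cmp_spec[:, :, feat_dim:, :],           # imaginary part
], 1)                                       # [B, 2, D, T]
cmp_spec = torch.unsqueeze(cmp_spec, 4)     # [B, 2, D, T, 1]
cmp_spec = torch.transpose(cmp_spec, 1, 4)  # [B, 1, D, T, 2]

assert cmp_spec.shape == (B, 1, feat_dim, T, 2)
```
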