def dct_np(x: np.ndarray, norm=None) -> np.ndarray:
    """Compute the type-II discrete cosine transform along the last axis.

    NumPy-only stand-in for ``scipy.fftpack.dct(x, type=2, axis=-1, norm=norm)``.
    The transform is evaluated with a single length-N FFT of a reordered
    copy of the input (even-indexed samples followed by the odd-indexed
    samples in reverse order), followed by a quarter-sample phase rotation.

    Args:
        x: Array-like with at least one dimension. The transform is applied
            independently to every vector along the last axis, so 1-D, 2-D
            and higher-rank inputs are all accepted.
        norm: ``"ortho"`` applies the orthonormal scaling
            (``sqrt(1 / (4N))`` for coefficient 0 and ``sqrt(1 / (2N))`` for
            the others, relative to the unnormalised transform). Any other
            value, including ``None``, returns the unnormalised transform
            ``y[k] = 2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N))``.

    Returns:
        A new array with the same shape as ``x``. Floating-point inputs keep
        their dtype (half precision is promoted to single precision); every
        other input dtype yields ``float64``. The input is never modified.

    Raises:
        ValueError: If the last axis of ``x`` is empty.
    """
    x = np.asarray(x)
    n = x.shape[-1]
    if n == 0:
        raise ValueError("dct_np requires a non-empty last axis")

    # Pick the result dtype up front so it does not depend on what precision
    # np.fft.fft happens to use internally (that differs across NumPy versions).
    if np.issubdtype(x.dtype, np.floating):
        out_dtype = np.result_type(x.dtype, np.float32)
    else:
        out_dtype = np.float64

    # Reorder as [x0, x2, x4, ..., x5, x3, x1]; the FFT of this sequence,
    # rotated by exp(-i*pi*k/(2N)), has the DCT-II as its real part.
    reordered = np.concatenate([x[..., ::2], x[..., 1::2][..., ::-1]], axis=-1)
    spectrum = np.fft.fft(reordered, axis=-1)

    # Angles are always built in float64 so the rotation is not degraded
    # by a low-precision input dtype.
    theta = np.arange(n) * (np.pi / (2 * n))

    # Re(spectrum * exp(-i*theta)) = Re*cos(theta) + Im*sin(theta); the
    # factor 2 gives the conventional unnormalised DCT-II scaling.
    out = 2.0 * (spectrum.real * np.cos(theta) + spectrum.imag * np.sin(theta))

    if norm == "ortho":
        out[..., 0] /= 2.0 * np.sqrt(n)
        out[..., 1:] /= np.sqrt(2.0 * n)

    return out.astype(out_dtype, copy=False)


def compute_spectrum(frames, n):
    """Return the magnitude of the ``n``-point real FFT of ``frames``.

    ``np.fft.rfft`` is called with its default axis, so the transform runs
    along the last axis of ``frames``; the result holds the absolute value
    of each complex bin.
    """
    complex_spec = np.fft.rfft(frames, n)
    return np.absolute(complex_spec)