diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 126d608..be6b7b2 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -10,3 +10,10 @@ sphinx: conda: environment: rtd_environment.yml + +# This part is necessary otherwise the project is not built +python: + version: 3.9 + install: + - method: pip + path: . diff --git a/extensions/ivector/ivector.cpp b/extensions/ivector/ivector.cpp index 6fc97c6..e0d5785 100644 --- a/extensions/ivector/ivector.cpp +++ b/extensions/ivector/ivector.cpp @@ -499,6 +499,29 @@ void pybind_ivector_extractor(py::module &m) { py::arg("opts"), py::arg("extractor"), py::call_guard()) + .def(py::pickle( + [](const PyClass &p) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + std::ostringstream os; + bool binary = true; + p.Write(os, binary); + return py::make_tuple( + py::bytes(os.str())); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 1) + throw std::runtime_error("Invalid state!"); + + /* Create a new C++ instance */ + PyClass *p = new PyClass(); + + /* Assign any additional state */ + std::istringstream str(t[0].cast()); + p->Read(str, true); + + return p; + } + )) .def("update", []( PyClass &stats, IvectorExtractor &extractor, @@ -930,6 +953,29 @@ void pybind_plda(py::module &m) { }, py::arg("utterance_ivector"), py::arg("transformed_enrolled_ivectors")) + .def(py::pickle( + [](const PyClass &p) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + std::ostringstream os; + bool binary = true; + p.Write(os, binary); + return py::make_tuple( + py::bytes(os.str())); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 1) + throw std::runtime_error("Invalid state!"); + + /* Create a new C++ instance */ + PyClass *p = new PyClass(); + + /* Assign any additional state */ + std::istringstream str(t[0].cast()); + p->Read(str, true); + + return p; + } + )) .def("TransformIvector", py::overload_cast &, @@ -1363,6 +1409,22 @@ void init_ivector(py::module &_m) { py::arg("normalize") = true, py::arg("scaleup") = true); + m.def("ivector_normalize_length", + []( + Vector* ivector, + bool normalize = true, + bool scaleup = true + ) { + py::gil_scoped_release gil_release; + double norm = ivector->Norm(2.0); + double ratio = norm / sqrt(ivector->Dim()); + if (!scaleup) ratio = norm; + if (normalize) ivector->Scale(1.0 / ratio); + }, + py::arg("ivector"), + py::arg("normalize") = true, + py::arg("scaleup") = true); + m.def("ivector_subtract_mean", []( std::vector*> &ivectors diff --git a/extensions/transform/transform.cpp b/extensions/transform/transform.cpp index f8b7771..0cdac16 100644 --- a/extensions/transform/transform.cpp +++ b/extensions/transform/transform.cpp @@ -159,6 +159,25 @@ void pybind_cmvn(py::module &m) { py::arg("uttlist"), py::arg("feat_reader")); + m.def("apply_cmvn", + []( + const Matrix &feats, + const Matrix &cmvn_stats, + bool reverse = false, + bool norm_vars = false + ){ + py::gil_scoped_release release; + Matrix feat_out(feats); + if (reverse) { + ApplyCmvnReverse(cmvn_stats, norm_vars, &feat_out); + } else { + ApplyCmvn(cmvn_stats, norm_vars, &feat_out); + } + + return feat_out; + }, + py::arg("feats"), py::arg("cmvn_stats"), py::arg("reverse") = false, py::arg("norm_vars") = false); + m.def("ApplyCmvn", &ApplyCmvn, "Apply cepstral mean and variance normalization to a matrix of features. " @@ -380,6 +399,8 @@ void pybind_fmllr_diag_gmm(py::module &m) { py::arg("feats")) .def("accumulate_from_alignment", [](PyClass& spk_stats, + const TransitionModel &alignment_trans_model, + const AmDiagGmm &alignment_am_gmm, const TransitionModel &trans_model, const AmDiagGmm &am_gmm, const Matrix &feats, @@ -391,49 +412,51 @@ void pybind_fmllr_diag_gmm(py::module &m) { bool two_models = false ){ py::gil_scoped_release gil_release; - Posterior pdf_post; - Posterior post; + Posterior posterior; + + AlignmentToPosterior(ali, &posterior); - AlignmentToPosterior(ali, &post); if (distributed) - WeightSilencePostDistributed(trans_model, silence_set, - silence_scale, &post); + WeightSilencePostDistributed(alignment_trans_model, silence_set, + silence_scale, &posterior); else - WeightSilencePost(trans_model, silence_set, - silence_scale, &post); - ConvertPosteriorToPdfs(trans_model, post, &pdf_post); + WeightSilencePost(alignment_trans_model, silence_set, + silence_scale, &posterior); + + Posterior pdf_posterior; + ConvertPosteriorToPdfs(alignment_trans_model, posterior, &pdf_posterior); if (!two_models){ - for (size_t i = 0; i < pdf_post.size(); i++) { - for (size_t j = 0; j < pdf_post[i].size(); j++) { - int32 pdf_id = pdf_post[i][j].first; - spk_stats.AccumulateForGmm(am_gmm.GetPdf(pdf_id), + for (size_t i = 0; i < pdf_posterior.size(); i++) { + for (size_t j = 0; j < pdf_posterior[i].size(); j++) { + int32 pdf_id = pdf_posterior[i][j].first; + spk_stats.AccumulateForGmm(alignment_am_gmm.GetPdf(pdf_id), feats.Row(i), - pdf_post[i][j].second); + pdf_posterior[i][j].second); } } } else{ - GaussPost gpost(pdf_post.size()); + GaussPost gpost(posterior.size()); BaseFloat tot_like_this_file = 0.0, tot_weight = 0.0; - for (size_t i = 0; i < pdf_post.size(); i++) { - gpost[i].reserve(pdf_post[i].size()); - for (size_t j = 0; j < pdf_post[i].size(); j++) { - int32 pdf_id = pdf_post[i][j].first; - BaseFloat weight = pdf_post[i][j].second; - const DiagGmm &gmm = am_gmm.GetPdf(pdf_id); + for (size_t i = 0; i < posterior.size(); i++) { + gpost[i].reserve(pdf_posterior[i].size()); + for (size_t j = 0; j < pdf_posterior[i].size(); j++) { + int32 pdf_id = pdf_posterior[i][j].first; + BaseFloat weight = pdf_posterior[i][j].second; + const DiagGmm &gmm = alignment_am_gmm.GetPdf(pdf_id); Vector this_post_vec; BaseFloat like = gmm.ComponentPosteriors(feats.Row(i), &this_post_vec); this_post_vec.Scale(weight); if (rand_prune > 0.0) - for (int32 k = 0; k < this_post_vec.Dim(); k++) - this_post_vec(k) = RandPrune(this_post_vec(k), - rand_prune); + for (int32 k = 0; k < this_post_vec.Dim(); k++) + this_post_vec(k) = RandPrune(this_post_vec(k), + rand_prune); if (!this_post_vec.IsZero()) - gpost[i].push_back(std::make_pair(pdf_id, this_post_vec)); + gpost[i].push_back(std::make_pair(pdf_id, this_post_vec)); tot_like_this_file += like * weight; tot_weight += weight; } @@ -450,6 +473,8 @@ void pybind_fmllr_diag_gmm(py::module &m) { } } }, + py::arg("alignment_trans_model"), + py::arg("alignment_am_gmm"), py::arg("trans_model"), py::arg("am_gmm"), py::arg("feats"), @@ -461,6 +486,8 @@ void pybind_fmllr_diag_gmm(py::module &m) { py::arg("two_models") = false) .def("accumulate_from_lattice", [](PyClass* spk_stats, + const TransitionModel &alignment_trans_model, + const AmDiagGmm &alignment_am_gmm, const TransitionModel &trans_model, const AmDiagGmm &am_gmm, const Matrix &feats, @@ -490,13 +517,13 @@ void pybind_fmllr_diag_gmm(py::module &m) { Posterior post; double lat_like = LatticeForwardBackward(lat, &post); if (distributed) - WeightSilencePostDistributed(trans_model, silence_set, + WeightSilencePostDistributed(alignment_trans_model, silence_set, silence_scale, &post); else - WeightSilencePost(trans_model, silence_set, + WeightSilencePost(alignment_trans_model, silence_set, silence_scale, &post); Posterior pdf_post; - ConvertPosteriorToPdfs(trans_model, post, &pdf_post); + ConvertPosteriorToPdfs(alignment_trans_model, post, &pdf_post); if (!two_models){ for (size_t i = 0; i < post.size(); i++) { for (size_t j = 0; j < pdf_post[i].size(); j++) { @@ -517,7 +544,7 @@ void pybind_fmllr_diag_gmm(py::module &m) { for (size_t j = 0; j < pdf_post[i].size(); j++) { int32 pdf_id = pdf_post[i][j].first; BaseFloat weight = pdf_post[i][j].second; - const DiagGmm &gmm = am_gmm.GetPdf(pdf_id); + const DiagGmm &gmm = alignment_am_gmm.GetPdf(pdf_id); Vector this_post_vec; BaseFloat like = gmm.ComponentPosteriors(feats.Row(i), &this_post_vec); @@ -542,6 +569,8 @@ void pybind_fmllr_diag_gmm(py::module &m) { } } }, + py::arg("alignment_trans_model"), + py::arg("alignment_am_gmm"), py::arg("trans_model"), py::arg("am_gmm"), py::arg("feats"), @@ -567,13 +596,12 @@ void pybind_fmllr_diag_gmm(py::module &m) { .def("compute_transform", [](PyClass& f, const AmDiagGmm &am_gmm, const FmllrOptions &fmllr_opts){ - py::gil_scoped_release gil_release; BaseFloat impr, tot_t; Matrix transform(am_gmm.Dim(), am_gmm.Dim()+1); { transform.SetUnit(); f.Update(fmllr_opts, &transform, &impr, &tot_t); - return transform; + return py::make_tuple(transform, impr, tot_t); } }, py::arg("am_gmm"), diff --git a/kalpy/data.py b/kalpy/data.py index eb41f41..8acdda2 100644 --- a/kalpy/data.py +++ b/kalpy/data.py @@ -33,7 +33,9 @@ class Segment: channel: typing.Optional[int] = 0 def load_audio(self): - duration = self.end - self.begin + duration = None + if self.end is not None and self.begin is not None: + duration = self.end - self.begin y, _ = librosa.load( self.file_path, sr=16000, diff --git a/kalpy/feat/cmvn.py b/kalpy/feat/cmvn.py index aa9ac10..f6a2dc5 100644 --- a/kalpy/feat/cmvn.py +++ b/kalpy/feat/cmvn.py @@ -90,13 +90,26 @@ def compute_cmvn_for_export( Returns ------- - :class:`_kalpy.matrix.FloatMatrixBase` + :class:`_kalpy.matrix.DoubleMatrix` Feature matrix for the segment """ - cmvn, num_done, num_error = transform.calculate_cmvn(utterance_list, feature_reader) + if False: + cmvn_stats = DoubleMatrix() + is_init = False + num_done = 0 + num_error = 0 + for utt in utterance_list: + print(utt) + feats = feature_reader.Value(utt) + if not is_init: + transform.InitCmvnStats(feats.NumCols(), cmvn_stats) + is_init = True + transform.AccCmvnStats(feats, None, cmvn_stats) + num_done += 1 + cmvn_stats, num_done, num_error = transform.calculate_cmvn(utterance_list, feature_reader) self.num_done += num_done self.num_error += num_error - return cmvn + return cmvn_stats def export_cmvn( self, diff --git a/kalpy/feat/data.py b/kalpy/feat/data.py index e910cf9..899706f 100644 --- a/kalpy/feat/data.py +++ b/kalpy/feat/data.py @@ -6,7 +6,7 @@ import typing from _kalpy import feat, transform -from _kalpy.matrix import FloatMatrix +from _kalpy.matrix import DoubleMatrix, FloatMatrix from _kalpy.util import ( RandomAccessBaseDoubleMatrixReader, RandomAccessBaseFloatMatrixReader, @@ -35,15 +35,19 @@ def __init__( sliding_cmvn_window: int = 300, sliding_cmvn_center_window: bool = True, double: bool = False, + callback: typing.Callable = None, ): self.cmvn_reader = None self.transform_reader = None self.vad_reader = None if not os.path.exists(file_name): raise OSError(f"Specified file does not exist: {file_name}") + self.file_name = str(file_name) self.archive = MatrixArchive(file_name, double=double) self.utt2spk = utt2spk + self.double = double self.subsample_n = subsample_n + self.callback = callback self.use_sliding_cmvn = use_sliding_cmvn self.cmvn_norm_vars = cmvn_norm_vars @@ -58,12 +62,12 @@ def __init__( self.splice_frames = splice_frames self.use_deltas = deltas self.use_splices = splices - self.cmvn_file_name = cmvn_file_name + self.cmvn_read_specifier = None if cmvn_file_name: - cmvn_read_specifier = generate_read_specifier(cmvn_file_name) - self.cmvn_reader = RandomAccessBaseDoubleMatrixReader(cmvn_read_specifier) + self.cmvn_read_specifier = generate_read_specifier(cmvn_file_name) + self.cmvn_reader = RandomAccessBaseDoubleMatrixReader(self.cmvn_read_specifier) - self.lda_mat_file_name = lda_mat_file_name + self.lda_mat_file_name = None self.lda_mat = None if lda_mat_file_name: self.use_splices = True @@ -71,14 +75,17 @@ def __init__( self.lda_mat_file_name = str(lda_mat_file_name) self.lda_mat = read_kaldi_object(FloatMatrix, self.lda_mat_file_name) self.transform_file_name = transform_file_name + self.transform_read_specifier = None if transform_file_name: - transform_read_specifier = generate_read_specifier(transform_file_name) - self.transform_reader = RandomAccessBaseFloatMatrixReader(transform_read_specifier) + self.transform_read_specifier = generate_read_specifier(transform_file_name) + self.transform_reader = RandomAccessBaseFloatMatrixReader( + self.transform_read_specifier + ) - self.vad_file_name = vad_file_name + self.vad_read_specifier = None if vad_file_name: - vad_read_specifier = generate_read_specifier(vad_file_name) - self.vad_reader = RandomAccessBaseFloatVectorReader(vad_read_specifier) + self.vad_read_specifier = generate_read_specifier(vad_file_name) + self.vad_reader = RandomAccessBaseFloatVectorReader(self.vad_read_specifier) self.current_speaker = None self.trans = None self.cmvn_stats = None @@ -110,23 +117,27 @@ def __iter__(self) -> typing.Generator[typing.Tuple[str, FloatMatrix]]: speaker = self.utt2spk[utt] else: speaker = None - # Apply CMVN - if self.cmvn_file_name and speaker is not None: - if self.current_speaker != speaker: + if self.current_speaker != speaker: + if self.cmvn_reader and speaker is not None: if not self.cmvn_reader.HasKey(speaker): raise Exception( - f"Could not find key {speaker} in {self.cmvn_file_name}" + f"Could not find key {speaker} in {self.cmvn_read_specifier}" ) self.cmvn_stats = self.cmvn_reader.Value(speaker) - if self.transform_reader is not None and self.transform_reader.HasKey( - speaker - ): + if self.transform_reader is not None: + if self.transform_reader.HasKey(speaker): self.trans = self.transform_reader.Value(speaker) - self.current_speaker = speaker - if self.cmvn_reverse: - transform.ApplyCmvnReverse(self.cmvn_stats, self.cmvn_norm_vars, feats) - else: - transform.ApplyCmvn(self.cmvn_stats, self.cmvn_norm_vars, feats) + else: + self.trans = None + self.current_speaker = speaker + # Apply CMVN + if self.cmvn_stats is not None: + feats = transform.apply_cmvn( + feats, + self.cmvn_stats, + reverse=self.cmvn_reverse, + norm_vars=self.cmvn_norm_vars, + ) elif self.use_sliding_cmvn: feats = feat.sliding_window_cmn(self.sliding_cmvn_options, feats) @@ -163,19 +174,22 @@ def __getitem__(self, item: str) -> FloatMatrix: speaker = self.utt2spk[item] else: speaker = None - # Apply CMVN - if self.cmvn_reader is not None and speaker is not None: - if self.current_speaker != speaker: + if self.current_speaker != speaker: + if self.cmvn_reader and speaker is not None: if not self.cmvn_reader.HasKey(speaker): - raise Exception(f"Could not find key {speaker} in {self.cmvn_file_name}") + raise Exception(f"Could not find key {speaker} in {self.cmvn_read_specifier}") self.cmvn_stats = self.cmvn_reader.Value(speaker) - if self.transform_reader is not None and self.transform_reader.HasKey(speaker): + if self.transform_reader is not None: + if self.transform_reader.HasKey(speaker): self.trans = self.transform_reader.Value(speaker) - self.current_speaker = speaker - if self.cmvn_reverse: - transform.ApplyCmvnReverse(self.cmvn_stats, self.cmvn_norm_vars, feats) - else: - transform.ApplyCmvn(self.cmvn_stats, self.cmvn_norm_vars, feats) + else: + self.trans = None + self.current_speaker = speaker + # Apply CMVN + if self.cmvn_stats is not None: + feats = transform.apply_cmvn( + feats, self.cmvn_stats, reverse=self.cmvn_reverse, norm_vars=self.cmvn_norm_vars + ) elif self.use_sliding_cmvn: feats = feat.sliding_window_cmn(self.sliding_cmvn_options, feats) diff --git a/kalpy/feat/fmllr.py b/kalpy/feat/fmllr.py index 7c6be40..7f1a621 100644 --- a/kalpy/feat/fmllr.py +++ b/kalpy/feat/fmllr.py @@ -6,7 +6,7 @@ import threading import typing -from _kalpy import transform +from _kalpy import gmm, hmm, transform from _kalpy.util import BaseFloatMatrixWriter, ConstIntegerSet from kalpy.data import KaldiMapping, MatrixArchive from kalpy.feat.data import FeatureArchive @@ -23,10 +23,10 @@ class FmllrComputer: def __init__( self, + alignment_acoustic_model_path: typing.Union[pathlib.Path, str], acoustic_model_path: typing.Union[pathlib.Path, str], silence_phones: typing.List[int], spk2utt: KaldiMapping = None, - two_models: bool = True, weight_distribute: bool = False, fmllr_update_type: str = "full", silence_weight: float = 0.0, @@ -36,22 +36,34 @@ def __init__( thread_lock: typing.Optional[threading.Lock] = None, ): self.acoustic_model_path = acoustic_model_path + self.alignment_acoustic_model_path = alignment_acoustic_model_path self.transition_model, self.acoustic_model = read_gmm_model(self.acoustic_model_path) + self.two_models = self.alignment_acoustic_model_path != self.acoustic_model_path + if self.two_models: + self.alignment_transition_model, self.alignment_acoustic_model = read_gmm_model( + self.alignment_acoustic_model_path + ) + else: + self.alignment_transition_model, self.alignment_acoustic_model = ( + self.transition_model, + self.acoustic_model, + ) self.spk2utt = spk2utt self.silence_weight = silence_weight self.acoustic_scale = acoustic_scale - self.two_models = two_models self.silence_phones = silence_phones self.weight_distribute = weight_distribute self.fmllr_update_type = fmllr_update_type self.fmllr_min_count = fmllr_min_count self.fmllr_num_iters = fmllr_num_iters self.thread_lock = thread_lock + self.callback_frequency = 100 def compute_fmllr( self, feature_archive: FeatureArchive, alignment_archive: typing.Union[AlignmentArchive, LatticeArchive], + callback: typing.Callable = None, ): fmllr_options = transform.FmllrOptions() fmllr_options.update_type = self.fmllr_update_type @@ -67,19 +79,32 @@ def compute_fmllr( silence_set = ConstIntegerSet(self.silence_phones) if self.spk2utt is not None: for spk, utt_list in self.spk2utt.items(): - spk_stats = transform.FmllrDiagGmmAccs(am_dim, fmllr_options) + spk_stats = transform.FmllrDiagGmmAccs(am_dim) logger.info(f"Processing speaker {spk}...") for utterance_id in utt_list: try: alignment = alignment_archive[utterance_id] except KeyError: - logger.info(f"Skipping {utterance_id} due to missing lattice.") + logger.info(f"Skipping {utterance_id} due to missing alignment.") + num_skipped += 1 continue if use_alignment: alignment = alignment_archive[utterance_id].alignment - feats = feature_archive[utterance_id] + try: + feats = feature_archive[utterance_id] + except KeyError: + logger.info(f"Skipping {utterance_id} due to missing features.") + num_skipped += 1 + continue + if feats.NumRows() == 0: + logger.warning(f"Skipping {utterance_id} due to zero-length features") + num_skipped += 1 + continue + num_done += 1 if use_alignment: spk_stats.accumulate_from_alignment( + self.alignment_transition_model, + self.alignment_acoustic_model, self.transition_model, self.acoustic_model, feats, @@ -92,6 +117,8 @@ def compute_fmllr( ) else: spk_stats.accumulate_from_lattice( + self.alignment_transition_model, + self.alignment_acoustic_model, self.transition_model, self.acoustic_model, feats, @@ -103,16 +130,24 @@ def compute_fmllr( distributed=self.weight_distribute, two_models=self.two_models, ) + if callback is not None and num_done % self.callback_frequency == 0: + callback(self.callback_frequency) if self.thread_lock is not None: self.thread_lock.acquire() - trans = transform.compute_fmllr_transform( - spk_stats, self.acoustic_model.Dim(), fmllr_options + trans, impr, spk_tot_t = spk_stats.compute_transform( + self.acoustic_model, fmllr_options ) if self.thread_lock is not None: self.thread_lock.release() - num_done += 1 - - yield spk, trans + if spk_tot_t: + logger.debug( + f"For speaker {spk}, auxf-impr from fMLLR is {impr/spk_tot_t}, over {spk_tot_t} frames." + ) + tot_impr += impr + tot_t += spk_tot_t + yield spk, trans + else: + logger.debug(f"Skipping speaker {spk} due to no data") else: for utterance_id, feats in feature_archive: @@ -159,8 +194,12 @@ def compute_fmllr( num_done += 1 yield utterance_id, trans - logger.info(f"Done {num_done} speakers.") - logger.info(f"Skipped {num_skipped} speakers.") + if callback is not None and num_done % self.callback_frequency == 0: + callback(self.callback_frequency) + if callback is not None and num_done % self.callback_frequency: + callback(num_done % self.callback_frequency) + logger.info(f"Done {num_done} utterances.") + logger.info(f"Skipped {num_skipped} utterances.") if tot_t: logger.info( f"Overall fMLLR auxf impr per frame is {tot_impr / tot_t} over {tot_t} frames." @@ -181,13 +220,13 @@ def export_transforms( prev_reader = None if previous_transform_archive is not None: prev_reader = previous_transform_archive.random_reader - for speaker, trans in self.compute_fmllr(feature_archive, alignment_archive): - if callback: - callback(speaker) + for speaker, trans in self.compute_fmllr( + feature_archive, alignment_archive, callback=callback + ): if previous_transform_archive is not None: if prev_reader.HasKey(speaker): prev_trans = prev_reader.Value(speaker) - new_trans = transform.compose_transforms(prev_trans, trans, True) + new_trans = transform.compose_transforms(trans, prev_trans, True) trans = new_trans writer.Write(str(speaker), trans) finally: diff --git a/kalpy/feat/lda.py b/kalpy/feat/lda.py index 2ca9c86..7d0cd69 100644 --- a/kalpy/feat/lda.py +++ b/kalpy/feat/lda.py @@ -41,10 +41,14 @@ def accumulate_stats( silence_weight = 0.0 silence_set = ConstIntegerSet(self.silence_phones) num_done = 0 - for alignment in alignment_archive: - feats = feature_archive[alignment.utterance_id] + for utterance_id, feats in feature_archive: if feats.NumRows() == 0: - logger.warning(f"Skipping {alignment.utterance_id} due to zero-length features") + logger.warning(f"Skipping {utterance_id} due to zero-length features") + continue + try: + alignment = alignment_archive[utterance_id] + except KeyError: + logger.warning(f"Skipping {utterance_id} due to missing alignment") continue if self.lda.Dim() == 0: self.lda.Init(self.transition_model.NumPdfs(), feats.NumCols()) @@ -99,10 +103,14 @@ def accumulate_stats( num_done = 0 tot_like = 0.0 tot_t = 0.0 - for alignment in alignment_archive: - feats = feature_archive[alignment.utterance_id] + for utterance_id, feats in feature_archive: if feats.NumRows() == 0: - logger.warning(f"Skipping {alignment.utterance_id} due to zero-length features") + logger.warning(f"Skipping {utterance_id} due to zero-length features") + continue + try: + alignment = alignment_archive[utterance_id] + except KeyError: + logger.warning(f"Skipping {utterance_id} due to missing alignment") continue if callback: callback(alignment.utterance_id) diff --git a/kalpy/feat/mfcc.py b/kalpy/feat/mfcc.py index b75b291..cea854e 100644 --- a/kalpy/feat/mfcc.py +++ b/kalpy/feat/mfcc.py @@ -88,9 +88,9 @@ def __init__( window_type: str = "povey", round_to_power_of_two: bool = True, blackman_coeff: float = 0.42, - snip_edges: bool = True, + snip_edges: bool = False, max_feature_vectors: int = -1, - num_mel_bins: int = 25, + num_mel_bins: int = 23, low_frequency: float = 20, high_frequency: float = 7800, vtln_low: float = 100, @@ -129,6 +129,35 @@ def __init__( self.raw_energy = raw_energy self.htk_compatibility = htk_compatibility + @property + def parameters(self): + return { + "sample_frequency": self.sample_frequency, + "frame_length": self.frame_length, + "frame_shift": self._frame_shift, + "dither": self.dither, + "preemphasis_coefficient": self.preemphasis_coefficient, + "remove_dc_offset": self.remove_dc_offset, + "window_type": self.window_type, + "round_to_power_of_two": self.round_to_power_of_two, + "blackman_coeff": self.blackman_coeff, + "snip_edges": self.snip_edges, + "max_feature_vectors": self.max_feature_vectors, + "num_mel_bins": self.num_mel_bins, + "low_frequency": self.low_frequency, + "high_frequency": self.high_frequency, + "vtln_low": self.vtln_low, + "vtln_high": self.vtln_high, + "num_coefficients": self.num_coefficients, + "use_energy": self.use_energy, + "energy_floor": self.energy_floor, + "raw_energy": self.raw_energy, + "cepstral_lifter": self.cepstral_lifter, + "htk_compatibility": self.htk_compatibility, + "allow_downsample": self.allow_downsample, + "allow_upsample": self.allow_upsample, + } + @property def frame_shift(self): return round(self._frame_shift / 1000, 3) @@ -207,24 +236,11 @@ def compute_mfccs_for_export( Feature matrix for the segment """ if isinstance(segment, Segment): - duration = None - if segment.end is not None and segment.begin is not None: - duration = segment.end - segment.begin - wave, sr = librosa.load( - segment.file_path, - sr=16000, - offset=segment.begin, - duration=duration, - mono=False, - ) - wave = np.round(wave * 32768) - if len(wave.shape) == 2: - channel = 0 if segment.channel is None else segment.channel - wave = wave[channel, :] + wave = segment.kaldi_wave else: wave = segment if isinstance(wave, np.ndarray) and np.max(wave) < 1.0: - wave = np.round(wave * 32768) + wave = wave * 32768 mfccs = self.mfcc_obj.compute(wave) if compress: diff --git a/kalpy/feat/pitch.py b/kalpy/feat/pitch.py index a38acc1..3d1046c 100644 --- a/kalpy/feat/pitch.py +++ b/kalpy/feat/pitch.py @@ -133,7 +133,7 @@ def __init__( frames_per_chunk: int = 0, simulate_first_pass_online: bool = False, recompute_frame: int = 500, - snip_edges: bool = True, + snip_edges: bool = False, pitch_scale: float = 2.0, pov_scale: float = 2.0, pov_offset: float = 0.0, @@ -183,6 +183,42 @@ def __init__( self.process_opts.add_delta_pitch = add_delta_pitch self.process_opts.add_raw_log_pitch = add_raw_log_pitch + @property + def parameters(self): + return { + "sample_frequency": self.extraction_opts.samp_freq, + "frame_length": self.extraction_opts.frame_length_ms, + "frame_shift": self.extraction_opts.frame_shift_ms, + "min_f0": self.extraction_opts.min_f0, + "max_f0": self.extraction_opts.max_f0, + "soft_min_f0": self.extraction_opts.soft_min_f0, + "penalty_factor": self.extraction_opts.penalty_factor, + "lowpass_cutoff": self.extraction_opts.lowpass_cutoff, + "resample_frequency": self.extraction_opts.resample_freq, + "delta_pitch": self.extraction_opts.delta_pitch, + "nccf_ballast": self.extraction_opts.nccf_ballast, + "lowpass_filter_width": self.extraction_opts.lowpass_filter_width, + "upsample_filter_width": self.extraction_opts.upsample_filter_width, + "max_frames_latency": self.extraction_opts.max_frames_latency, + "frames_per_chunk": self.extraction_opts.frames_per_chunk, + "simulate_first_pass_online": self.extraction_opts.simulate_first_pass_online, + "recompute_frame": self.extraction_opts.recompute_frame, + "snip_edges": self.extraction_opts.snip_edges, + "pitch_scale": self.process_opts.pitch_scale, + "pov_scale": self.process_opts.pov_scale, + "pov_offset": self.process_opts.pov_offset, + "delta_pitch_scale": self.process_opts.delta_pitch_scale, + "delta_pitch_noise_stddev": self.process_opts.delta_pitch_noise_stddev, + "normalization_left_context": self.process_opts.normalization_left_context, + "normalization_right_context": self.process_opts.normalization_right_context, + "delta_window": self.process_opts.delta_window, + "delay": self.process_opts.delay, + "add_pov_feature": self.process_opts.add_pov_feature, + "add_normalized_log_pitch": self.process_opts.add_normalized_log_pitch, + "add_delta_pitch": self.process_opts.add_delta_pitch, + "add_raw_log_pitch": self.process_opts.add_raw_log_pitch, + } + def compute_pitch( self, segment: Segment, diff --git a/kalpy/fstext/lexicon.py b/kalpy/fstext/lexicon.py index 03ab7f0..bb806cf 100644 --- a/kalpy/fstext/lexicon.py +++ b/kalpy/fstext/lexicon.py @@ -192,7 +192,6 @@ def __init__( self._align_lexicon = None self.word_begin_label = word_begin_label self.word_end_label = word_end_label - self.lock = threading.Lock() self.start_state = None self.loop_state = None self.silence_state = None @@ -576,202 +575,199 @@ def add_pronunciation( ): if (pronunciation.orthography, pronunciation.pronunciation) in self._cached_pronunciations: return - with self.lock: - phones = pronunciation.pronunciation.split() - if self.position_dependent_phones: - if len(phones) == 1: - phones[0] += "_S" - else: - phones[0] += "_B" - phones[-1] += "_E" - for i in range(1, len(phones) - 1): - phones[i] += "_I" - new_phones = ", ".join(sorted({x for x in phones if not self.phone_table.member(x)})) - if new_phones: - raise Exception( - f"The pronunciation '{pronunciation}' had the following phones not in the symbol table: {new_phones}" - ) - pron = " ".join(phones) - fst = pynini.accep(pron, token_type=self.phone_table) - if phonological_rule_fst: - fst = pynini.compose(phonological_rule_fst, fst) - fst.rmepsilon() - self._cached_pronunciations.add( - (pronunciation.orthography, pronunciation.pronunciation) - ) - if not self.word_table.member(pronunciation.orthography): - self.word_table.add_symbol(pronunciation.orthography) - word_symbol = self.word_table.find(pronunciation.orthography) - word_eps_symbol = self.word_table.find("") - phone_eps_symbol = self.phone_table.find("") - silence_before_cost = ( - -math.log(pronunciation.silence_before_correction) - if pronunciation.silence_before_correction - else 0.0 - ) - non_silence_before_cost = ( - -math.log(pronunciation.non_silence_before_correction) - if pronunciation.non_silence_before_correction - else 0.0 - ) - silence_following_cost = ( - -math.log(pronunciation.silence_after_probability) - if pronunciation.silence_after_probability - else self.base_silence_following_cost - ) - non_silence_following_cost = ( - -math.log(1 - pronunciation.silence_after_probability) - if pronunciation.silence_after_probability - else self.base_non_silence_following_cost + phones = pronunciation.pronunciation.split() + if self.position_dependent_phones: + if len(phones) == 1: + phones[0] += "_S" + else: + phones[0] += "_B" + phones[-1] += "_E" + for i in range(1, len(phones) - 1): + phones[i] += "_I" + new_phones = ", ".join(sorted({x for x in phones if not self.phone_table.member(x)})) + if new_phones: + raise Exception( + f"The pronunciation '{pronunciation}' had the following phones not in the symbol table: {new_phones}" ) - probability = pronunciation.probability - if probability is None: - probability = 1 - elif probability < 0.01: - probability = 0.01 # Dithering to ensure low probability entries - pron_cost = abs(math.log(probability)) - start_index = self._fst.num_states() - 1 - align_start_index = self._align_fst.num_states() - num_new_states = fst.num_states() - 1 - self._fst.add_states(num_new_states) - self._align_fst.add_states(num_new_states + 2) - - # FST arcs - for state in fst.states(): - for arc in fst.arcs(state): - if state == fst.start(): - # No silence before the pronunciation - self._fst.add_arc( - self.non_silence_state, - pywrapfst.Arc( - arc.ilabel, - word_symbol, - pywrapfst.Weight( - self._fst.weight_type(), pron_cost + non_silence_before_cost - ), - arc.nextstate + start_index, - ), - ) - # Silence before the pronunciation - self._fst.add_arc( - self.silence_state, - pywrapfst.Arc( - arc.ilabel, - word_symbol, - pywrapfst.Weight( - self._fst.weight_type(), pron_cost + silence_before_cost - ), - arc.nextstate + start_index, - ), - ) - - # No silence before the pronunciation - self._align_fst.add_arc( - self.non_silence_state, - pywrapfst.Arc( - self.phone_table.find(self.word_begin_label), - word_symbol, - pywrapfst.Weight( - self._fst.weight_type(), pron_cost + non_silence_before_cost - ), - arc.nextstate + align_start_index - 1, + pron = " ".join(phones) + fst = pynini.accep(pron, token_type=self.phone_table) + if phonological_rule_fst: + fst = pynini.compose(phonological_rule_fst, fst) + fst.rmepsilon() + self._cached_pronunciations.add((pronunciation.orthography, pronunciation.pronunciation)) + if not self.word_table.member(pronunciation.orthography): + self.word_table.add_symbol(pronunciation.orthography) + word_symbol = self.word_table.find(pronunciation.orthography) + word_eps_symbol = self.word_table.find("") + phone_eps_symbol = self.phone_table.find("") + silence_before_cost = ( + -math.log(pronunciation.silence_before_correction) + if pronunciation.silence_before_correction + else 0.0 + ) + non_silence_before_cost = ( + -math.log(pronunciation.non_silence_before_correction) + if pronunciation.non_silence_before_correction + else 0.0 + ) + silence_following_cost = ( + -math.log(pronunciation.silence_after_probability) + if pronunciation.silence_after_probability + else self.base_silence_following_cost + ) + non_silence_following_cost = ( + -math.log(1 - pronunciation.silence_after_probability) + if pronunciation.silence_after_probability + else self.base_non_silence_following_cost + ) + probability = pronunciation.probability + if probability is None: + probability = 1 + elif probability < 0.01: + probability = 0.01 # Dithering to ensure low probability entries + pron_cost = abs(math.log(probability)) + start_index = self._fst.num_states() - 1 + align_start_index = self._align_fst.num_states() + num_new_states = fst.num_states() - 1 + self._fst.add_states(num_new_states) + self._align_fst.add_states(num_new_states + 2) + + # FST arcs + for state in fst.states(): + for arc in fst.arcs(state): + if state == fst.start(): + # No silence before the pronunciation + self._fst.add_arc( + self.non_silence_state, + pywrapfst.Arc( + arc.ilabel, + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + non_silence_before_cost ), - ) - # Silence before the pronunciation - self._align_fst.add_arc( - self.silence_state, - pywrapfst.Arc( - self.phone_table.find(self.word_begin_label), - word_symbol, - pywrapfst.Weight( - self._fst.weight_type(), pron_cost + silence_before_cost - ), - arc.nextstate + align_start_index - 1, + arc.nextstate + start_index, + ), + ) + # Silence before the pronunciation + self._fst.add_arc( + self.silence_state, + pywrapfst.Arc( + arc.ilabel, + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + silence_before_cost ), - ) - else: - self._fst.add_arc( - state + start_index, - pywrapfst.Arc( - arc.ilabel, - word_eps_symbol, - arc.weight, - arc.nextstate + start_index, + arc.nextstate + start_index, + ), + ) + + # No silence before the pronunciation + self._align_fst.add_arc( + self.non_silence_state, + pywrapfst.Arc( + self.phone_table.find(self.word_begin_label), + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + non_silence_before_cost ), - ) + arc.nextstate + align_start_index - 1, + ), + ) + # Silence before the pronunciation self._align_fst.add_arc( - state + align_start_index, + self.silence_state, + pywrapfst.Arc( + self.phone_table.find(self.word_begin_label), + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + silence_before_cost + ), + arc.nextstate + align_start_index - 1, + ), + ) + else: + self._fst.add_arc( + state + start_index, pywrapfst.Arc( arc.ilabel, word_eps_symbol, arc.weight, - arc.nextstate + align_start_index, + arc.nextstate + start_index, ), ) - - if self.disambiguation and pronunciation.disambiguation is not None: - self._fst.add_state() - self._fst.add_arc( - num_new_states + start_index, + self._align_fst.add_arc( + state + align_start_index, pywrapfst.Arc( - self.phone_table.find(f"#{pronunciation.disambiguation}"), + arc.ilabel, word_eps_symbol, - pywrapfst.Weight(self._fst.weight_type(), non_silence_following_cost), - num_new_states + start_index + 1, + arc.weight, + arc.nextstate + align_start_index, ), ) - start_index += 1 - # No silence following the pronunciation + if self.disambiguation and pronunciation.disambiguation is not None: + self._fst.add_state() self._fst.add_arc( num_new_states + start_index, pywrapfst.Arc( - self.phone_table.find(self.silence_disambiguation_symbol), + self.phone_table.find(f"#{pronunciation.disambiguation}"), word_eps_symbol, pywrapfst.Weight(self._fst.weight_type(), non_silence_following_cost), - self.non_silence_state, - ), - ) - # Silence following the pronunciation - self._fst.add_arc( - num_new_states + start_index, - pywrapfst.Arc( - self.phone_table.find(self.silence_phone), - word_eps_symbol, - pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), - self.silence_state, - ), - ) - self._align_fst.add_arc( - num_new_states + align_start_index, - pywrapfst.Arc( - self.phone_table.find(self.word_end_label), - word_eps_symbol, - pywrapfst.Weight.one(self._align_fst.weight_type()), - num_new_states + align_start_index + 1, + num_new_states + start_index + 1, ), ) + start_index += 1 - # No silence following the pronunciation - self._align_fst.add_arc( - num_new_states + align_start_index + 1, - pywrapfst.Arc( - phone_eps_symbol, - word_eps_symbol, - pywrapfst.Weight(self._fst.weight_type(), non_silence_following_cost), - self.non_silence_state, - ), - ) - # Silence following the pronunciation - self._align_fst.add_arc( + # No silence following the pronunciation + self._fst.add_arc( + num_new_states + start_index, + pywrapfst.Arc( + self.phone_table.find(self.silence_disambiguation_symbol), + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), non_silence_following_cost), + self.non_silence_state, + ), + ) + # Silence following the pronunciation + self._fst.add_arc( + num_new_states + start_index, + pywrapfst.Arc( + self.phone_table.find(self.silence_phone), + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), + self.silence_state, + ), + ) + self._align_fst.add_arc( + num_new_states + align_start_index, + pywrapfst.Arc( + self.phone_table.find(self.word_end_label), + word_eps_symbol, + pywrapfst.Weight.one(self._align_fst.weight_type()), num_new_states + align_start_index + 1, - pywrapfst.Arc( - self.phone_table.find(self.silence_phone), - word_eps_symbol, - pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), - self.silence_state, - ), - ) + ), + ) + + # No silence following the pronunciation + self._align_fst.add_arc( + num_new_states + align_start_index + 1, + pywrapfst.Arc( + phone_eps_symbol, + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), non_silence_following_cost), + self.non_silence_state, + ), + ) + # Silence following the pronunciation + self._align_fst.add_arc( + num_new_states + align_start_index + 1, + pywrapfst.Arc( + self.phone_table.find(self.silence_phone), + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), + self.silence_state, + ), + ) @property def kaldi_fst(self) -> VectorFst: diff --git a/kalpy/gmm/align.py b/kalpy/gmm/align.py index f44818c..6e5a076 100644 --- a/kalpy/gmm/align.py +++ b/kalpy/gmm/align.py @@ -54,10 +54,7 @@ def __init__( self.retry_beam = 4 * self.beam def boost_silence(self, silence_weight: float, silence_phones: typing.List[int]): - if silence_weight != 1.0: - self.acoustic_model.boost_silence( - self.transition_model, silence_phones, silence_weight - ) + self.acoustic_model.boost_silence(self.transition_model, silence_phones, silence_weight) def align_utterance( self, training_graph: VectorFst, features: FloatMatrix, utterance_id: str = None @@ -83,6 +80,8 @@ def align_utterance( ) if not successful: return None + if retried and utterance_id: + logger.debug(f"Retried {utterance_id}") return Alignment(utterance_id, alignment, words, likelihood, per_frame_log_likelihoods) def align_utterances( @@ -93,20 +92,18 @@ def align_utterances( num_error = 0 total_frames = 0 total_likelihood = 0 - for utterance_id, training_graph in training_graph_archive: - try: - feats = feature_archive[utterance_id] - except KeyError: - logger.warning(f"Skipping {utterance_id} not in feature archive.") - num_error += 1 - continue - + for utterance_id, feats in feature_archive: if feats.NumRows() == 0: logger.warning(f"Skipping {utterance_id} due to zero-length features") - num_error += 1 + continue + try: + training_graph = training_graph_archive[utterance_id] + except KeyError: + logger.warning(f"Skipping {utterance_id} due to missing training graph") continue alignment = self.align_utterance(training_graph, feats, utterance_id) if alignment is None: + yield utterance_id, None num_error += 1 continue yield alignment @@ -143,6 +140,10 @@ def export_alignments( for alignment in self.align_utterances(training_graph_archive, feature_archive): if alignment is None: continue + if isinstance(alignment, tuple): + if callback: + callback(alignment) + continue if callback: callback((alignment.utterance_id, alignment.likelihood)) writer.Write(str(alignment.utterance_id), alignment.alignment) diff --git a/kalpy/gmm/train.py b/kalpy/gmm/train.py index 397865d..82fe82d 100644 --- a/kalpy/gmm/train.py +++ b/kalpy/gmm/train.py @@ -26,7 +26,7 @@ def __init__(self, acoustic_model_path: typing.Union[pathlib.Path, str]): self.gmm_accs = gmm.AccumAmDiagGmm() self.transition_model.InitStats(self.transition_accs) self.gmm_accs.Init(self.acoustic_model, gmm.kGmmAll) - self.num_done = 0 + self.callback_frequency = 100 def accumulate_stats( self, @@ -36,28 +36,35 @@ def accumulate_stats( ): tot_like = 0.0 tot_t = 0.0 - for alignment in alignment_archive: - feats = feature_archive[alignment.utterance_id] + num_done = 0 + for utterance_id, feats in feature_archive: if feats.NumRows() == 0: - logger.warning(f"Skipping {alignment.utterance_id} due to zero-length features") + logger.warning(f"Skipping {utterance_id} due to zero-length features") + continue + try: + alignment = alignment_archive[utterance_id] + except KeyError: + logger.warning(f"Skipping {utterance_id} due to missing alignment") continue - if callback: - callback(alignment.utterance_id) tot_like_this_file = self.gmm_accs.acc_stats( self.acoustic_model, self.transition_model, alignment.alignment, feats ) self.transition_model.acc_stats(alignment.alignment, self.transition_accs) - self.num_done += 1 + num_done += 1 tot_like += tot_like_this_file tot_t += len(alignment.alignment) - if self.num_done % 50 == 0: + if num_done % self.callback_frequency == 0: + if callback: + callback(self.callback_frequency) logger.info( - f"Processed {self.num_done} utterances; for utterance " + f"Processed {num_done} utterances; for utterance " f"{alignment.utterance_id} avg. like is " f"{tot_like_this_file/len(alignment.alignment)} " f"over {len(alignment.alignment)} frames." ) - logger.info(f"Done {self.num_done} files.") + if callback is not None and num_done % self.callback_frequency: + callback(num_done % self.callback_frequency) + logger.info(f"Done {num_done} files.") if tot_t: logger.info( f"Overall avg like per frame (Gaussian only) = {tot_like/tot_t} over {tot_t} frames." @@ -99,6 +106,7 @@ def __init__( if phone_map: self.tree_stats_info.ci_phones = phone_map self.tree_stats = {} + self.callback_frequency = 100 def accumulate_stats( self, @@ -107,13 +115,17 @@ def accumulate_stats( callback: typing.Callable = None, ): num_done = 0 - for alignment in alignment_archive: - feats = feature_archive[alignment.utterance_id] + for utterance_id, feats in feature_archive: if feats.NumRows() == 0: - logger.warning(f"Skipping {alignment.utterance_id} due to zero-length features") + logger.warning(f"Skipping {utterance_id} due to zero-length features") + continue + try: + alignment = alignment_archive[utterance_id] + except KeyError: + logger.warning(f"Skipping {utterance_id} due to missing alignment") continue - if callback: - callback(alignment.utterance_id) + if callback is not None and num_done % self.callback_frequency == 0: + callback(self.callback_frequency) stats = hmm.accumulate_tree_stats( self.transition_model, self.tree_stats_info, alignment.alignment, feats ) @@ -124,6 +136,8 @@ def accumulate_stats( else: self.tree_stats[e].Add(c) num_done += 1 + if callback is not None and num_done % self.callback_frequency: + callback(num_done % self.callback_frequency) logger.info(f"Done {num_done} files.") def export_stats( @@ -148,6 +162,7 @@ def __init__(self, acoustic_model_path: typing.Union[pathlib.Path, str]): self.gmm_accs = gmm.AccumAmDiagGmm() self.transition_model.InitStats(self.transition_accs) self.gmm_accs.Init(self.acoustic_model, gmm.kGmmAll) + self.callback_frequency = 100 def accumulate_stats( self, @@ -158,21 +173,27 @@ def accumulate_stats( ): num_done = 0 tot_like = 0.0 - for alignment in alignment_archive: - first_feats = first_feature_archive[alignment.utterance_id] - second_feats = second_feature_archive[alignment.utterance_id] + for (utterance_id, first_feats), (second_utterance_id, second_feats) in zip( + first_feature_archive, second_feature_archive + ): + assert utterance_id == second_utterance_id if first_feats.NumRows() == 0: logger.warning( - f"Skipping {alignment.utterance_id} due to zero-length features in first archive" + f"Skipping {utterance_id} due to zero-length features in first archive" ) continue if second_feats.NumRows() == 0: logger.warning( - f"Skipping {alignment.utterance_id} due to zero-length features in second archive" + f"Skipping {utterance_id} due to zero-length features in second archive" ) continue - if callback: - callback(alignment.utterance_id) + try: + alignment = alignment_archive[utterance_id] + except KeyError: + logger.warning(f"Skipping {utterance_id} due to missing alignment") + continue + if callback is not None and num_done % self.callback_frequency == 0: + callback(self.callback_frequency) post = hmm.AlignmentToPosterior(alignment.alignment) pdf_post = hmm.convert_posterior_to_pdfs(self.transition_model, post) tot_like_this_file = self.gmm_accs.acc_twofeats( @@ -183,6 +204,8 @@ def accumulate_stats( ) num_done += 1 tot_like += tot_like_this_file + if callback is not None and num_done % self.callback_frequency: + callback(num_done % self.callback_frequency) logger.info(f"Done {num_done} files.") def export_stats( diff --git a/tests/conftest.py b/tests/conftest.py index 00ae05a..6717e60 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,11 @@ +import os import pathlib +import subprocess +from io import BytesIO +import librosa import pytest +import soundfile @pytest.fixture(scope="session") @@ -36,6 +41,324 @@ def temp_dir(test_dir): return p +@pytest.fixture(scope="session") +def reference_dir(test_dir): + p = test_dir.joinpath("kaldi") + return p + + +@pytest.fixture(scope="session") +def reference_mfcc_path(wav_path, reference_dir): + ark_path = reference_dir.joinpath("mfccs.ark") + scp_path = reference_dir.joinpath("mfccs.scp") + + mfcc_proc = subprocess.Popen( + [ + "compute-mfcc-feats", + "--use-energy=false", + "--dither=0", + "--energy-floor=0", + "--num-ceps=13", + "--num-mel-bins=23", + "--cepstral-lifter=22", + "--preemphasis-coefficient=0.97", + "--frame-shift=10", + "--frame-length=25", + "--low-freq=20", + "--high-freq=7800", + "--sample-frequency=16000", + "--allow-downsample=true", + "--allow-upsample=true", + "--snip-edges=false", + "ark,s,cs:-", + f"ark,scp:{ark_path},{scp_path}", + ], + stdin=subprocess.PIPE, + env=os.environ, + ) + + wave, _ = librosa.load( + wav_path, + sr=16000, + offset=0.0, + duration=26.72325, + mono=False, + ) + bio = BytesIO() + soundfile.write(bio, wave, samplerate=16000, format="WAV") + mfcc_proc.stdin.write(f"1-1\t".encode("utf8")) + mfcc_proc.stdin.write(bio.getvalue()) + mfcc_proc.stdin.flush() + mfcc_proc.stdin.close() + mfcc_proc.communicate() + return scp_path + + +@pytest.fixture(scope="session") +def reference_cmvn_path(wav_path, reference_dir, reference_mfcc_path): + ark_path = reference_dir.joinpath("cmvn.ark") + scp_path = reference_dir.joinpath("cmvn.scp") + spk2utt = reference_dir.joinpath("spk2utt.scp") + with open(spk2utt, "w", encoding="utf8") as f: + f.write("1 1-1\n") + subprocess.call( + [ + "compute-cmvn-stats", + f"--spk2utt=ark:{spk2utt}", + f"scp:{reference_mfcc_path}", + f"ark,scp:{ark_path},{scp_path}", + ], + env=os.environ, + ) + return scp_path + + +@pytest.fixture(scope="session") +def reference_final_features_path( + wav_path, reference_dir, reference_mfcc_path, reference_cmvn_path +): + ark_path = reference_dir.joinpath("final_features.ark") + scp_path = reference_dir.joinpath("final_features.scp") + utt2spk = reference_dir.joinpath("utt2spk.scp") + with open(utt2spk, "w", encoding="utf8") as f: + f.write("1-1 1\n") + subprocess.call( + [ + "apply-cmvn", + f"--utt2spk=ark:{utt2spk}", + f"scp:{reference_cmvn_path}", + f"scp:{reference_mfcc_path}", + f"ark,scp:{ark_path},{scp_path}", + ], + env=os.environ, + ) + return scp_path + + +@pytest.fixture(scope="session") +def reference_si_feature_string(reference_final_features_path, sat_lda_mat_path): + return ( + f'ark,s,cs:splice-feats --left-context=3 --right-context=3 scp,s,cs:"{reference_final_features_path}" ark:- ' + f'| transform-feats "{sat_lda_mat_path}" ark:- ark:- |' + ) + + +@pytest.fixture(scope="session") +def reference_sat_feature_string( + reference_dir, reference_final_features_path, sat_lda_mat_path, reference_trans_path +): + utt2spk = reference_dir.joinpath("utt2spk.scp") + return ( + f'ark,s,cs:splice-feats --left-context=3 --right-context=3 scp,s,cs:"{reference_final_features_path}" ark:- ' + f'| transform-feats "{sat_lda_mat_path}" ark:- ark:- | transform-feats --utt2spk=ark:"{utt2spk}" scp:"{reference_trans_path}" ark:- ark:- |' + ) + + +@pytest.fixture(scope="session") +def reference_first_pass_ali_path( + wav_path, + reference_dir, + sat_align_model_path, + sat_temp_dir, + align_options, + reference_si_feature_string, +): + ali_path = reference_dir.joinpath("ali_first_pass.ark") + fst_path = sat_temp_dir.joinpath("fsts.ark") + subprocess.call( + [ + "gmm-align-compiled", + f"--transition-scale={align_options['transition_scale']}", + f"--acoustic-scale={align_options['acoustic_scale']}", + f"--self-loop-scale={align_options['self_loop_scale']}", + f"--beam={align_options['beam']}", + f"--retry-beam={align_options['retry_beam']}", + "--careful=false", + sat_align_model_path, + f"ark,s,cs:{fst_path}", + reference_si_feature_string, + f"ark:{ali_path}", + ], + env=os.environ, + ) + return ali_path + + +@pytest.fixture(scope="session") +def reference_second_pass_ali_path( + wav_path, + reference_dir, + sat_model_path, + sat_temp_dir, + align_options, + reference_sat_feature_string, +): + ali_path = reference_dir.joinpath("ali_second_pass.ark") + fst_path = sat_temp_dir.joinpath("fsts.ark") + subprocess.call( + [ + "gmm-align-compiled", + f"--transition-scale={align_options['transition_scale']}", + f"--acoustic-scale={align_options['acoustic_scale']}", + f"--self-loop-scale={align_options['self_loop_scale']}", + f"--beam={align_options['beam']}", + f"--retry-beam={align_options['retry_beam']}", + "--careful=false", + sat_model_path, + f"ark,s,cs:{fst_path}", + reference_sat_feature_string, + f"ark:{ali_path}", + ], + env=os.environ, + ) + return ali_path + + +@pytest.fixture(scope="session") +def align_options(): + return { + "transition_scale": 1.0, + "acoustic_scale": 0.1, + "self_loop_scale": 0.1, + "beam": 10, + "retry_beam": 40, + } + + +@pytest.fixture(scope="session") +def fmllr_options(): + return {"fmllr_update_type": "full", "silence_weight": 0.0, "silence_csl": "1:2"} + + +@pytest.fixture(scope="session") +def reference_trans_path( + wav_path, + reference_dir, + reference_final_features_path, + reference_si_feature_string, + fmllr_options, + sat_align_model_path, + sat_model_path, +): + ark_path = reference_dir.joinpath("trans.ark") + scp_path = reference_dir.joinpath("trans.scp") + ali_path = reference_dir.joinpath("ali_first_pass.ark") + spk2utt = reference_dir.joinpath("spk2utt.scp") + + post_proc = subprocess.Popen( + ["ali-to-post", f"ark,s,cs:{ali_path}", "ark:-"], + stdout=subprocess.PIPE, + env=os.environ, + ) + + weight_proc = subprocess.Popen( + [ + "weight-silence-post", + "0.0", + fmllr_options["silence_csl"], + sat_align_model_path, + "ark,s,cs:-", + "ark:-", + ], + stdin=post_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + ) + post_gpost_proc = subprocess.Popen( + [ + "gmm-post-to-gpost", + sat_align_model_path, + reference_si_feature_string, + "ark,s,cs:-", + "ark:-", + ], + stdin=weight_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + ) + est_proc = subprocess.Popen( + [ + "gmm-est-fmllr-gpost", + f"--fmllr-update-type={fmllr_options['fmllr_update_type']}", + f"--spk2utt=ark:{spk2utt}", + sat_model_path, + reference_si_feature_string, + "ark,s,cs:-", + f"ark,scp:{ark_path},{scp_path}", + ], + encoding="utf8", + stdin=post_gpost_proc.stdout, + env=os.environ, + ) + est_proc.communicate() + return scp_path + + +@pytest.fixture(scope="session") +def reference_trans_compose_path( + wav_path, + reference_dir, + reference_final_features_path, + reference_sat_feature_string, + fmllr_options, + sat_model_path, + reference_trans_path, + reference_second_pass_ali_path, +): + temp_ark_path = reference_dir.joinpath("trans_second.ark") + temp_scp_path = reference_dir.joinpath("trans_second.scp") + ark_path = reference_dir.joinpath("trans_composed.ark") + scp_path = reference_dir.joinpath("trans_composed.scp") + spk2utt = reference_dir.joinpath("spk2utt.scp") + + post_proc = subprocess.Popen( + ["ali-to-post", f"ark,s,cs:{reference_second_pass_ali_path}", "ark:-"], + stdout=subprocess.PIPE, + env=os.environ, + ) + + weight_proc = subprocess.Popen( + [ + "weight-silence-post", + "0.0", + fmllr_options["silence_csl"], + sat_model_path, + "ark,s,cs:-", + "ark:-", + ], + stdin=post_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + ) + est_proc = subprocess.Popen( + [ + "gmm-est-fmllr", + f"--fmllr-update-type={fmllr_options['fmllr_update_type']}", + f"--spk2utt=ark:{spk2utt}", + sat_model_path, + reference_sat_feature_string, + "ark,s,cs:-", + f"ark,scp:{temp_ark_path},{temp_scp_path}", + ], + encoding="utf8", + stdin=weight_proc.stdout, + env=os.environ, + ) + est_proc.communicate() + compose_proc = subprocess.Popen( + [ + "compose-transforms", + "--b-is-affine=true", + f"scp:{temp_scp_path}", + f"scp:{reference_trans_path}", + f"ark,scp:{ark_path},{scp_path}", + ], + env=os.environ, + ) + compose_proc.communicate() + return scp_path + + @pytest.fixture(scope="session") def mono_temp_dir(temp_dir): p = temp_dir.joinpath("mono") diff --git a/tests/data/kaldi/ali_first_pass.ark b/tests/data/kaldi/ali_first_pass.ark new file mode 100644 index 0000000..aa58a03 Binary files /dev/null and b/tests/data/kaldi/ali_first_pass.ark differ diff --git a/tests/data/kaldi/ali_second_pass.ark b/tests/data/kaldi/ali_second_pass.ark new file mode 100644 index 0000000..7a4accf Binary files /dev/null and b/tests/data/kaldi/ali_second_pass.ark differ diff --git a/tests/data/kaldi/cmvn.ark b/tests/data/kaldi/cmvn.ark new file mode 100644 index 0000000..e57572e Binary files /dev/null and b/tests/data/kaldi/cmvn.ark differ diff --git a/tests/data/kaldi/cmvn.scp b/tests/data/kaldi/cmvn.scp new file mode 100644 index 0000000..429f74e --- /dev/null +++ b/tests/data/kaldi/cmvn.scp @@ -0,0 +1 @@ +1 C:\Users\michael\Documents\Dev\kalpy\tests\data\kaldi\cmvn.ark:2 diff --git a/tests/data/kaldi/feats.1.ark b/tests/data/kaldi/feats.1.ark new file mode 100644 index 0000000..8152f14 Binary files /dev/null and b/tests/data/kaldi/feats.1.ark differ diff --git a/tests/data/kaldi/feats.1.scp b/tests/data/kaldi/feats.1.scp new file mode 100644 index 0000000..a3cb55d --- /dev/null +++ b/tests/data/kaldi/feats.1.scp @@ -0,0 +1 @@ +1-1 C:\Users\michael\Documents\Dev\kalpy\tests\data\temp\sat\mfccs.ark:4 diff --git a/tests/data/kaldi/final_features.1.ark b/tests/data/kaldi/final_features.1.ark new file mode 100644 index 0000000..77c8edc Binary files /dev/null and b/tests/data/kaldi/final_features.1.ark differ diff --git a/tests/data/kaldi/final_features.1.scp b/tests/data/kaldi/final_features.1.scp new file mode 100644 index 0000000..bf5535e --- /dev/null +++ b/tests/data/kaldi/final_features.1.scp @@ -0,0 +1 @@ +1-1 tests/data/kaldi/final_features.1.ark:4 diff --git a/tests/data/kaldi/final_features.ark b/tests/data/kaldi/final_features.ark new file mode 100644 index 0000000..1bde0db Binary files /dev/null and b/tests/data/kaldi/final_features.ark differ diff --git a/tests/data/kaldi/final_features.scp b/tests/data/kaldi/final_features.scp new file mode 100644 index 0000000..92daa7c --- /dev/null +++ b/tests/data/kaldi/final_features.scp @@ -0,0 +1 @@ +1-1 C:\Users\michael\Documents\Dev\kalpy\tests\data\kaldi\final_features.ark:4 diff --git a/tests/data/kaldi/fsts.1.2.ark b/tests/data/kaldi/fsts.1.2.ark new file mode 100644 index 0000000..fba34ee Binary files /dev/null and b/tests/data/kaldi/fsts.1.2.ark differ diff --git a/tests/data/kaldi/mfccs.ark b/tests/data/kaldi/mfccs.ark new file mode 100644 index 0000000..47dffad Binary files /dev/null and b/tests/data/kaldi/mfccs.ark differ diff --git a/tests/data/kaldi/mfccs.scp b/tests/data/kaldi/mfccs.scp new file mode 100644 index 0000000..488c6e7 --- /dev/null +++ b/tests/data/kaldi/mfccs.scp @@ -0,0 +1 @@ +1-1 C:\Users\michael\Documents\Dev\kalpy\tests\data\kaldi\mfccs.ark:4 diff --git a/tests/data/kaldi/spk2utt.scp b/tests/data/kaldi/spk2utt.scp new file mode 100644 index 0000000..d5b290c --- /dev/null +++ b/tests/data/kaldi/spk2utt.scp @@ -0,0 +1 @@ +1 1-1 diff --git a/tests/data/kaldi/trans.ark b/tests/data/kaldi/trans.ark new file mode 100644 index 0000000..482184d Binary files /dev/null and b/tests/data/kaldi/trans.ark differ diff --git a/tests/data/kaldi/trans.scp b/tests/data/kaldi/trans.scp new file mode 100644 index 0000000..96f86ce --- /dev/null +++ b/tests/data/kaldi/trans.scp @@ -0,0 +1 @@ +1 C:\Users\michael\Documents\Dev\kalpy\tests\data\kaldi\trans.ark:2 diff --git a/tests/data/kaldi/trans_composed.ark b/tests/data/kaldi/trans_composed.ark new file mode 100644 index 0000000..d7db4af Binary files /dev/null and b/tests/data/kaldi/trans_composed.ark differ diff --git a/tests/data/kaldi/trans_composed.scp b/tests/data/kaldi/trans_composed.scp new file mode 100644 index 0000000..ef99cf4 --- /dev/null +++ b/tests/data/kaldi/trans_composed.scp @@ -0,0 +1 @@ +1 C:\Users\michael\Documents\Dev\kalpy\tests\data\kaldi\trans_composed.ark:2 diff --git a/tests/data/kaldi/trans_second.ark b/tests/data/kaldi/trans_second.ark new file mode 100644 index 0000000..665b7b8 Binary files /dev/null and b/tests/data/kaldi/trans_second.ark differ diff --git a/tests/data/kaldi/trans_second.scp b/tests/data/kaldi/trans_second.scp new file mode 100644 index 0000000..412d324 --- /dev/null +++ b/tests/data/kaldi/trans_second.scp @@ -0,0 +1 @@ +1 C:\Users\michael\Documents\Dev\kalpy\tests\data\kaldi\trans_second.ark:2 diff --git a/tests/data/kaldi/utt2spk.scp b/tests/data/kaldi/utt2spk.scp new file mode 100644 index 0000000..005f2a6 --- /dev/null +++ b/tests/data/kaldi/utt2spk.scp @@ -0,0 +1 @@ +1-1 1 diff --git a/tests/data/kaldi/wav.scp b/tests/data/kaldi/wav.scp new file mode 100644 index 0000000..8d1591e --- /dev/null +++ b/tests/data/kaldi/wav.scp @@ -0,0 +1 @@ +1-1 tests/data/wav/acoustic_corpus.wav diff --git a/tests/test_align.py b/tests/test_align.py index 03a0ec6..2aa6254 100644 --- a/tests/test_align.py +++ b/tests/test_align.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from kalpy.data import KaldiMapping @@ -16,7 +17,7 @@ def test_align(mono_tree_path, mono_model_path, dictionary_path, mono_temp_dir): cmvn_file_name = mono_temp_dir.joinpath("cmvn.ark") training_graph_archive = FstArchive(mono_temp_dir.joinpath("fsts.ark")) utt2spk = KaldiMapping() - utt2spk["1"] = "1" + utt2spk["1-1"] = "1" feature_archive = FeatureArchive( mono_temp_dir.joinpath("mfccs.ark"), utt2spk=utt2spk, @@ -25,11 +26,11 @@ def test_align(mono_tree_path, mono_model_path, dictionary_path, mono_temp_dir): ) aligner = GmmAligner(mono_model_path, beam=1000, retry_beam=4000) for alignment in aligner.align_utterances(training_graph_archive, feature_archive): - assert alignment.utterance_id == "1" + assert alignment.utterance_id == "1-1" assert len(alignment.alignment) == 2672 assert alignment.per_frame_likelihoods.numpy().shape[0] == 2672 ctm = alignment.generate_ctm(aligner.transition_model, lc.phone_table) - assert len(ctm) == 242 + assert len(ctm) == 243 @pytest.mark.order(3) @@ -40,6 +41,9 @@ def test_align_sat_first_pass( sat_dictionary_path, sat_temp_dir, sat_phones, + reference_dir, + align_options, + reference_first_pass_ali_path, ): lc = LexiconCompiler(position_dependent_phones=False, phones=sat_phones) lc.load_pronunciations(sat_dictionary_path) @@ -50,7 +54,7 @@ def test_align_sat_first_pass( training_graph_archive = FstArchive(sat_temp_dir.joinpath("fsts.ark")) utt2spk = KaldiMapping() textgrid_name = sat_temp_dir.joinpath("first_pass.TextGrid") - utt2spk["1"] = "1" + utt2spk["1-1"] = "1" feature_archive = FeatureArchive( sat_temp_dir.joinpath("mfccs.ark"), utt2spk=utt2spk, @@ -58,8 +62,9 @@ def test_align_sat_first_pass( lda_mat_file_name=sat_lda_mat_path, splices=True, ) - aligner = GmmAligner(sat_align_model_path, beam=10, retry_beam=40) - aligner.boost_silence(20.0, lc.silence_symbols) + + aligner = GmmAligner(sat_align_model_path, **align_options) + aligner.boost_silence(1.0, lc.silence_symbols) aligner.export_alignments( alignments_file_name, training_graph_archive, @@ -68,17 +73,30 @@ def test_align_sat_first_pass( ) assert alignments_file_name.exists() alignment_archive = AlignmentArchive(alignments_file_name, words_file_name=word_file_name) - alignment = alignment_archive["1"] - assert len(alignment.alignment) == 2670 + alignment = alignment_archive["1-1"] + assert len(alignment.alignment) == 2672 intervals = alignment.generate_ctm(aligner.transition_model, lc.phone_table) text = " ".join(lc.word_table.find(x) for x in alignment.words) ctm = lc.phones_to_pronunciations(text, alignment.words, intervals) ctm.export_textgrid(textgrid_name, file_duration=26.72) + reference_alignment_archive = AlignmentArchive(reference_first_pass_ali_path) + reference_alignment = reference_alignment_archive["1-1"] + phone_ctm = alignment.generate_ctm(aligner.transition_model, lc.phone_table) + ref_phone_ctm = reference_alignment.generate_ctm(aligner.transition_model, lc.phone_table) + assert alignment.alignment == reference_alignment.alignment + assert phone_ctm == ref_phone_ctm @pytest.mark.order(5) def test_align_sat_second_pass( - sat_tree_path, sat_model_path, sat_lda_mat_path, sat_dictionary_path, sat_temp_dir, sat_phones + sat_tree_path, + sat_model_path, + sat_lda_mat_path, + sat_dictionary_path, + sat_temp_dir, + sat_phones, + reference_second_pass_ali_path, + align_options, ): lc = LexiconCompiler(position_dependent_phones=False, phones=sat_phones) lc.load_pronunciations(sat_dictionary_path) @@ -89,7 +107,7 @@ def test_align_sat_second_pass( alignments_file_name = sat_temp_dir.joinpath("ali_second_pass.ark") training_graph_archive = FstArchive(sat_temp_dir.joinpath("fsts.ark")) utt2spk = KaldiMapping() - utt2spk["1"] = "1" + utt2spk["1-1"] = "1" feature_archive = FeatureArchive( sat_temp_dir.joinpath("mfccs.ark"), utt2spk=utt2spk, @@ -98,8 +116,8 @@ def test_align_sat_second_pass( transform_file_name=trans_file_name, splices=True, ) - aligner = GmmAligner(sat_model_path, beam=10, retry_beam=40) - aligner.boost_silence(20.0, lc.silence_symbols) + aligner = GmmAligner(sat_model_path, **align_options) + aligner.boost_silence(1.0, lc.silence_symbols) aligner.export_alignments( alignments_file_name, training_graph_archive, @@ -108,9 +126,16 @@ def test_align_sat_second_pass( ) assert alignments_file_name.exists() alignment_archive = AlignmentArchive(alignments_file_name, words_file_name=word_file_name) - alignment = alignment_archive["1"] - assert len(alignment.alignment) == 2670 + alignment = alignment_archive["1-1"] + assert len(alignment.alignment) == 2672 intervals = alignment.generate_ctm(aligner.transition_model, lc.phone_table) text = " ".join(lc.word_table.find(x) for x in alignment.words) ctm = lc.phones_to_pronunciations(text, alignment.words, intervals) ctm.export_textgrid(textgrid_name, file_duration=26.72) + + reference_alignment_archive = AlignmentArchive(reference_second_pass_ali_path) + reference_alignment = reference_alignment_archive["1-1"] + phone_ctm = alignment.generate_ctm(aligner.transition_model, lc.phone_table) + ref_phone_ctm = reference_alignment.generate_ctm(aligner.transition_model, lc.phone_table) + assert alignment.alignment == reference_alignment.alignment + assert phone_ctm == ref_phone_ctm diff --git a/tests/test_decode.py b/tests/test_decode.py index 7d3f559..0c11313 100644 --- a/tests/test_decode.py +++ b/tests/test_decode.py @@ -17,7 +17,7 @@ def test_decode(mono_tree_path, mono_model_path, dictionary_path, mono_temp_dir, gc = DecodeGraphCompiler(mono_model_path, mono_tree_path, lc, arpa_path=lm_path) cmvn_file_name = mono_temp_dir.joinpath("cmvn.ark") utt2spk = KaldiMapping() - utt2spk["1"] = "1" + utt2spk["1-1"] = "1" feature_archive = FeatureArchive( mono_temp_dir.joinpath("mfccs.ark"), utt2spk=utt2spk, @@ -26,7 +26,7 @@ def test_decode(mono_tree_path, mono_model_path, dictionary_path, mono_temp_dir, ) aligner = GmmDecoder(mono_model_path, gc.hclg_fst, beam=1000) for alignment in aligner.decode_utterances(feature_archive): - assert alignment.utterance_id == "1" + assert alignment.utterance_id == "1-1" assert len(alignment.alignment) == 2672 assert alignment.per_frame_likelihoods.numpy().shape[0] == 2672 ctm = alignment.generate_ctm(aligner.transition_model, lc.phone_table) @@ -55,7 +55,7 @@ def test_decode_sat_first_pass( word_file_name = sat_temp_dir.joinpath("words.ark") utt2spk = KaldiMapping() textgrid_name = sat_temp_dir.joinpath("first_pass_decode.TextGrid") - utt2spk["1"] = "1" + utt2spk["1-1"] = "1" feature_archive = FeatureArchive( sat_temp_dir.joinpath("mfccs.ark"), utt2spk=utt2spk, @@ -74,8 +74,8 @@ def test_decode_sat_first_pass( assert lattice_file_name.exists() assert alignment_file_name.exists() alignment_archive = AlignmentArchive(alignment_file_name, words_file_name=word_file_name) - alignment = alignment_archive["1"] - assert len(alignment.alignment) == 2670 + alignment = alignment_archive["1-1"] + assert len(alignment.alignment) == 2672 intervals = alignment.generate_ctm(decoder.transition_model, lc.phone_table) text = " ".join(lc.word_table.find(x) for x in alignment.words) ctm = lc.phones_to_pronunciations(text, alignment.words, intervals) @@ -105,7 +105,7 @@ def test_decode_sat_second_pass( lattice_file_name = sat_temp_dir.joinpath("lat_second_pass.ark") alignment_file_name = sat_temp_dir.joinpath("ali_decode_second_pass.ark") utt2spk = KaldiMapping() - utt2spk["1"] = "1" + utt2spk["1-1"] = "1" feature_archive = FeatureArchive( sat_temp_dir.joinpath("mfccs.ark"), utt2spk=utt2spk, @@ -125,8 +125,8 @@ def test_decode_sat_second_pass( assert lattice_file_name.exists() assert alignment_file_name.exists() alignment_archive = AlignmentArchive(alignment_file_name, words_file_name=word_file_name) - alignment = alignment_archive["1"] - assert len(alignment.alignment) == 2670 + alignment = alignment_archive["1-1"] + assert len(alignment.alignment) == 2672 intervals = alignment.generate_ctm(decoder.transition_model, lc.phone_table) text = " ".join(lc.word_table.find(x) for x in alignment.words) ctm = lc.phones_to_pronunciations(text, alignment.words, intervals) @@ -155,7 +155,7 @@ def test_decode_sat_lm_rescore( lattice_file_name = sat_temp_dir.joinpath("lat_second_pass.ark") lattice_output_file_name = sat_temp_dir.joinpath("lat_second_pass_rescore.ark") utt2spk = KaldiMapping() - utt2spk["1"] = "1" + utt2spk["1-1"] = "1" decoder = LmRescorer(gc.g_fst) lattice_archive = LatticeArchive(lattice_file_name) decoder.export_lattices(lattice_output_file_name, lattice_archive, gc.g_carpa) diff --git a/tests/test_decoder.py b/tests/test_decoder.py index 1d373a9..f2bdd67 100644 --- a/tests/test_decoder.py +++ b/tests/test_decoder.py @@ -29,10 +29,10 @@ def test_training_graphs( assert graph.num_states() > 0 assert graph.start() != pywrapfst.NO_STATE_ID output_file_name = mono_temp_dir.joinpath("fsts.ark") - gc.export_graphs(output_file_name, [("1", acoustic_corpus_text)]) + gc.export_graphs(output_file_name, [("1-1", acoustic_corpus_text)]) assert output_file_name.exists() os.remove(output_file_name) - gc.export_graphs(output_file_name, [("1", acoustic_corpus_text)], write_scp=True) + gc.export_graphs(output_file_name, [("1-1", acoustic_corpus_text)], write_scp=True) assert output_file_name.exists() assert output_file_name.with_suffix(".scp").exists() @@ -60,10 +60,10 @@ def test_training_graphs_sat( graph.write(str(sat_temp_dir.joinpath("LG_debug.fst"))) output_file_name = sat_temp_dir.joinpath("fsts.ark") - gc.export_graphs(output_file_name, [("1", acoustic_corpus_text)]) + gc.export_graphs(output_file_name, [("1-1", acoustic_corpus_text)]) assert output_file_name.exists() os.remove(output_file_name) - gc.export_graphs(output_file_name, [("1", acoustic_corpus_text)], write_scp=True) + gc.export_graphs(output_file_name, [("1-1", acoustic_corpus_text)], write_scp=True) assert output_file_name.exists() assert output_file_name.with_suffix(".scp").exists() diff --git a/tests/test_mfcc.py b/tests/test_mfcc.py index 9477cda..df9ec72 100644 --- a/tests/test_mfcc.py +++ b/tests/test_mfcc.py @@ -1,8 +1,18 @@ import os +import numpy as np import pytest -from kalpy.data import KaldiMapping, Segment +from _kalpy import transform +from _kalpy.feat import RandomAccessWaveReader +from _kalpy.matrix import CompressedMatrix +from _kalpy.util import ( + BaseFloatMatrixWriter, + CompressedMatrixWriter, + RandomAccessBaseDoubleMatrixReader, + RandomAccessBaseFloatMatrixReader, +) +from kalpy.data import KaldiMapping, MatrixArchive, Segment from kalpy.decoder.training_graphs import TrainingGraphCompiler from kalpy.feat.cmvn import CmvnComputer from kalpy.feat.data import FeatureArchive @@ -11,6 +21,19 @@ from kalpy.fstext.lexicon import LexiconCompiler from kalpy.gmm.align import GmmAligner from kalpy.gmm.data import AlignmentArchive, LatticeArchive +from kalpy.utils import generate_read_specifier, generate_write_specifier + + +@pytest.mark.order(1) +def test_wave(wav_path, reference_dir): + ref_wav_scp = reference_dir.joinpath("wav.scp") + wav_rspecifier = generate_read_specifier(ref_wav_scp) + wave_reader = RandomAccessWaveReader(wav_rspecifier) + kaldi_wave = wave_reader.Value("1-1").Data().numpy()[0, :] + segment = Segment(wav_path) + kalpy_wave = segment.kaldi_wave + + np.testing.assert_allclose(kaldi_wave, kalpy_wave) @pytest.mark.order(1) @@ -28,17 +51,17 @@ def test_generate_mfcc(wav_path): def test_export_mfcc(wav_path, mono_temp_dir): output_file_name = mono_temp_dir.joinpath("mfccs.ark") feature_generator = MfccComputer(snip_edges=False) - segments = {"1": Segment(wav_path)} + segments = {"1-1": Segment(wav_path)} feature_generator.export_feats(output_file_name, segments.items()) assert output_file_name.exists() archive = FeatureArchive(output_file_name) for utt, mfccs in archive: mfccs = mfccs.numpy() - assert utt == "1" + assert utt == "1-1" assert mfccs.shape[0] == 2672 assert mfccs.shape[1] == 13 - mfccs = archive["1"].numpy() + mfccs = archive["1-1"].numpy() print(mfccs) print(mfccs.shape) assert mfccs.shape[0] == 2672 @@ -46,17 +69,17 @@ def test_export_mfcc(wav_path, mono_temp_dir): archive.close() os.remove(output_file_name) feature_generator.export_feats( - output_file_name, segments.items(), write_scp=True, compress=True + output_file_name, segments.items(), write_scp=True, compress=False ) assert output_file_name.with_suffix(".scp").exists() archive = FeatureArchive(output_file_name.with_suffix(".scp")) for utt, mfccs in archive: mfccs = mfccs.numpy() - assert utt == "1" + assert utt == "1-1" assert mfccs.shape[0] == 2672 assert mfccs.shape[1] == 13 - mfccs = archive["1"].numpy() + mfccs = archive["1-1"].numpy() assert mfccs.shape[0] == 2672 assert mfccs.shape[1] == 13 @@ -67,70 +90,124 @@ def test_cmvn(mono_temp_dir): output_file_name = mono_temp_dir.joinpath("cmvn.ark") feature_generator = CmvnComputer(online=False) spk2utt = KaldiMapping() - spk2utt["1"] = ["1"] + spk2utt["1"] = ["1-1"] feature_archive = FeatureArchive(feature_file_name) feature_generator.export_cmvn(output_file_name, feature_archive, spk2utt) assert output_file_name.exists() @pytest.mark.order(1) -def test_export_mfcc_sat(wav_path, sat_temp_dir): +def test_export_mfcc_sat(wav_path, sat_temp_dir, reference_dir, reference_mfcc_path): output_file_name = sat_temp_dir.joinpath("mfccs.ark") - feature_generator = MfccComputer(snip_edges=True) - segments = {"1": Segment(wav_path)} + feature_generator = MfccComputer(snip_edges=False, dither=0, use_energy=False) + segments = {"1-1": Segment(wav_path, begin=0.0, end=26.72325)} feature_generator.export_feats(output_file_name, segments.items()) assert output_file_name.exists() archive = FeatureArchive(output_file_name) for utt, mfccs in archive: mfccs = mfccs.numpy() - assert utt == "1" - assert mfccs.shape[0] == 2670 + assert utt == "1-1" + assert mfccs.shape[0] == 2672 assert mfccs.shape[1] == 13 - mfccs = archive["1"].numpy() + mfccs = archive["1-1"].numpy() print(mfccs) print(mfccs.shape) - assert mfccs.shape[0] == 2670 + assert mfccs.shape[0] == 2672 assert mfccs.shape[1] == 13 archive.close() os.remove(output_file_name) feature_generator.export_feats( - output_file_name, segments.items(), write_scp=True, compress=True + output_file_name, segments.items(), write_scp=True, compress=False ) assert output_file_name.with_suffix(".scp").exists() archive = FeatureArchive(output_file_name.with_suffix(".scp")) + ref_archive = FeatureArchive(reference_mfcc_path) for utt, mfccs in archive: mfccs = mfccs.numpy() - assert utt == "1" - assert mfccs.shape[0] == 2670 + assert utt == "1-1" + assert mfccs.shape[0] == 2672 assert mfccs.shape[1] == 13 - mfccs = archive["1"].numpy() - assert mfccs.shape[0] == 2670 + ref_mfccs = ref_archive["1-1"] + np.testing.assert_allclose(mfccs, ref_mfccs.numpy()) + mfccs = archive["1-1"].numpy() + assert mfccs.shape[0] == 2672 assert mfccs.shape[1] == 13 @pytest.mark.order(2) -def test_cmvn_sat(sat_temp_dir): +def test_cmvn_sat( + sat_temp_dir, + reference_dir, + reference_mfcc_path, + reference_cmvn_path, + reference_final_features_path, +): feature_file_name = sat_temp_dir.joinpath("mfccs.ark") output_file_name = sat_temp_dir.joinpath("cmvn.ark") feature_generator = CmvnComputer(online=False) spk2utt = KaldiMapping() - spk2utt["1"] = ["1"] + spk2utt["1"] = ["1-1"] + utt2spk = KaldiMapping() + utt2spk["1-1"] = "1" feature_archive = FeatureArchive(feature_file_name) - feature_generator.export_cmvn(output_file_name, feature_archive, spk2utt) + feature_generator.export_cmvn(output_file_name, feature_archive, spk2utt, write_scp=True) assert output_file_name.exists() + cmvn_read_specifier = generate_read_specifier(output_file_name) + cmvn_reader = RandomAccessBaseDoubleMatrixReader(cmvn_read_specifier) + ref_cmvn_read_specifier = generate_read_specifier(reference_cmvn_path) + ref_cmvn_reader = RandomAccessBaseDoubleMatrixReader(ref_cmvn_read_specifier) + cmvn = cmvn_reader.Value("1") + ref_cmvn = ref_cmvn_reader.Value("1") + np.testing.assert_allclose(cmvn.numpy(), ref_cmvn.numpy()) + final_archive = FeatureArchive( + feature_file_name, + utt2spk=utt2spk, + cmvn_file_name=output_file_name, + ) + feat_archive = FeatureArchive( + feature_file_name, + ) + ref_final_features_archive = FeatureArchive(reference_final_features_path) + feats = feat_archive["1-1"] + cmvned_feats = transform.apply_cmvn(feats, cmvn) + ref_cmvned_feats = ref_final_features_archive["1-1"] + np.testing.assert_allclose(cmvned_feats.numpy(), ref_cmvned_feats.numpy()) + temp_ark_path = sat_temp_dir.joinpath("final_features.ark") + temp_scp_path = sat_temp_dir.joinpath("final_features.scp") + write_specifier = generate_write_specifier(temp_ark_path, write_scp=True) + feature_writer = BaseFloatMatrixWriter(write_specifier) + np.testing.assert_allclose( + final_archive["1-1"].numpy(), ref_final_features_archive["1-1"].numpy() + ) + for utt_id, mfccs in final_archive: + feature_writer.Write(utt_id, mfccs) + feature_writer.Close() + feature_archive = FeatureArchive(temp_scp_path) + np.testing.assert_allclose( + feature_archive["1-1"].numpy(), ref_final_features_archive["1-1"].numpy() + ) + cmvn_reader.Close() + ref_cmvn_reader.Close() @pytest.mark.order(4) def test_fmllr_sat( - sat_tree_path, sat_model_path, sat_lda_mat_path, sat_dictionary_path, sat_temp_dir, sat_phones + sat_tree_path, + sat_model_path, + sat_align_model_path, + sat_lda_mat_path, + sat_dictionary_path, + sat_temp_dir, + sat_phones, + reference_trans_path, ): lc = LexiconCompiler(position_dependent_phones=False, phones=sat_phones) lc.load_pronunciations(sat_dictionary_path) utt2spk = KaldiMapping() - utt2spk["1"] = "1" + utt2spk["1-1"] = "1" cmvn_file_name = sat_temp_dir.joinpath("cmvn.ark") feature_archive = FeatureArchive( sat_temp_dir.joinpath("mfccs.ark"), @@ -142,22 +219,75 @@ def test_fmllr_sat( alignment_archive = AlignmentArchive(sat_temp_dir.joinpath("ali.ark")) output_file_name = sat_temp_dir.joinpath("trans.ark") spk2utt = KaldiMapping() - spk2utt["1"] = ["1"] + spk2utt["1"] = ["1-1"] + fmllr_computer = FmllrComputer( + sat_align_model_path, sat_model_path, silence_phones=lc.silence_symbols, spk2utt=spk2utt + ) + fmllr_computer.export_transforms(output_file_name, feature_archive, alignment_archive) + assert output_file_name.exists() + feature_archive.close() + alignment_archive.close() + trans_read_specifier = generate_read_specifier(output_file_name) + trans_reader = RandomAccessBaseFloatMatrixReader(trans_read_specifier) + ref_trans_read_specifier = generate_read_specifier(reference_trans_path) + ref_trans_reader = RandomAccessBaseFloatMatrixReader(ref_trans_read_specifier) + trans = trans_reader.Value("1") + ref_trans = ref_trans_reader.Value("1") + np.testing.assert_allclose(trans.numpy(), ref_trans.numpy()) + + +@pytest.mark.order(4) +def test_fmllr_sat_no_two_model( + sat_tree_path, + sat_align_model_path, + sat_lda_mat_path, + sat_dictionary_path, + sat_temp_dir, + sat_phones, + reference_trans_path, +): + lc = LexiconCompiler(position_dependent_phones=False, phones=sat_phones) + lc.load_pronunciations(sat_dictionary_path) + utt2spk = KaldiMapping() + utt2spk["1-1"] = "1" + cmvn_file_name = sat_temp_dir.joinpath("cmvn.ark") + feature_archive = FeatureArchive( + sat_temp_dir.joinpath("mfccs.ark"), + utt2spk=utt2spk, + cmvn_file_name=cmvn_file_name, + lda_mat_file_name=sat_lda_mat_path, + splices=True, + ) + alignment_archive = AlignmentArchive(sat_temp_dir.joinpath("ali.ark")) + output_file_name = sat_temp_dir.joinpath("trans_no_two_model.ark") + spk2utt = KaldiMapping() + spk2utt["1"] = ["1-1"] fmllr_computer = FmllrComputer( - sat_model_path, silence_phones=lc.silence_symbols, spk2utt=spk2utt + sat_align_model_path, + sat_align_model_path, + silence_phones=lc.silence_symbols, + spk2utt=spk2utt, ) fmllr_computer.export_transforms(output_file_name, feature_archive, alignment_archive) assert output_file_name.exists() + feature_archive.close() + alignment_archive.close() @pytest.mark.order(4) def test_fmllr_decode_sat( - sat_tree_path, sat_model_path, sat_lda_mat_path, sat_dictionary_path, sat_temp_dir, sat_phones + sat_tree_path, + sat_model_path, + sat_align_model_path, + sat_lda_mat_path, + sat_dictionary_path, + sat_temp_dir, + sat_phones, ): lc = LexiconCompiler(position_dependent_phones=False, phones=sat_phones) lc.load_pronunciations(sat_dictionary_path) utt2spk = KaldiMapping() - utt2spk["1"] = "1" + utt2spk["1-1"] = "1" cmvn_file_name = sat_temp_dir.joinpath("cmvn.ark") feature_archive = FeatureArchive( sat_temp_dir.joinpath("mfccs.ark"), @@ -169,9 +299,68 @@ def test_fmllr_decode_sat( alignment_archive = LatticeArchive(sat_temp_dir.joinpath("lat.ark"), determinized=False) output_file_name = sat_temp_dir.joinpath("trans_decode.ark") spk2utt = KaldiMapping() - spk2utt["1"] = ["1"] + spk2utt["1"] = ["1-1"] fmllr_computer = FmllrComputer( - sat_model_path, silence_phones=lc.silence_symbols, spk2utt=spk2utt + sat_align_model_path, sat_model_path, silence_phones=lc.silence_symbols, spk2utt=spk2utt ) fmllr_computer.export_transforms(output_file_name, feature_archive, alignment_archive) assert output_file_name.exists() + + +@pytest.mark.order(6) +def test_fmllr_compose( + sat_tree_path, + sat_model_path, + sat_lda_mat_path, + sat_dictionary_path, + sat_temp_dir, + sat_phones, + reference_trans_path, + reference_trans_compose_path, +): + lc = LexiconCompiler(position_dependent_phones=False, phones=sat_phones) + lc.load_pronunciations(sat_dictionary_path) + utt2spk = KaldiMapping() + utt2spk["1-1"] = "1" + cmvn_file_name = sat_temp_dir.joinpath("cmvn.ark") + feature_archive = FeatureArchive( + sat_temp_dir.joinpath("mfccs.ark"), + utt2spk=utt2spk, + cmvn_file_name=cmvn_file_name, + lda_mat_file_name=sat_lda_mat_path, + transform_file_name=reference_trans_path, + splices=True, + ) + previous_transform_archive = MatrixArchive(reference_trans_path) + alignment_archive = AlignmentArchive(sat_temp_dir.joinpath("ali_second_pass.ark")) + output_file_name = sat_temp_dir.joinpath("trans_second_pass.ark") + spk2utt = KaldiMapping() + spk2utt["1"] = ["1-1"] + fmllr_computer = FmllrComputer( + sat_model_path, sat_model_path, silence_phones=lc.silence_symbols, spk2utt=spk2utt + ) + fmllr_computer.export_transforms( + output_file_name, + feature_archive, + alignment_archive, + previous_transform_archive=previous_transform_archive, + write_scp=True, + ) + assert output_file_name.exists() + feature_archive.close() + alignment_archive.close() + trans_read_specifier = generate_read_specifier(output_file_name) + trans_reader = RandomAccessBaseFloatMatrixReader(trans_read_specifier) + ref_trans_composed_read_specifier = generate_read_specifier(reference_trans_compose_path) + ref_trans_composed_reader = RandomAccessBaseFloatMatrixReader( + ref_trans_composed_read_specifier + ) + ref_trans_read_specifier = generate_read_specifier(reference_trans_path) + ref_trans_reader = RandomAccessBaseFloatMatrixReader(ref_trans_read_specifier) + trans = trans_reader.Value("1") + ref_trans = ref_trans_reader.Value("1") + ref_trans_composed = ref_trans_composed_reader.Value("1") + np.testing.assert_allclose(trans.numpy(), ref_trans_composed.numpy()) + np.testing.assert_raises( + AssertionError, np.testing.assert_allclose, trans.numpy(), ref_trans.numpy() + ) diff --git a/tests/test_pitch.py b/tests/test_pitch.py index 3efed45..fb4af1d 100644 --- a/tests/test_pitch.py +++ b/tests/test_pitch.py @@ -18,33 +18,33 @@ def test_generate_pitch(wav_path): def test_export_pitch(wav_path, temp_dir): output_file_name = temp_dir.joinpath("mfccs.ark") feature_generator = PitchComputer(snip_edges=False) - segments = {"1": Segment(wav_path, 1, 2)} + segments = {"1-1": Segment(wav_path)} feature_generator.export_feats(output_file_name, segments.items()) assert output_file_name.exists() archive = FeatureArchive(output_file_name) for utt, pitch in archive: pitch = pitch.numpy() - assert utt == "1" - assert pitch.shape[0] == 100 + assert utt == "1-1" + assert pitch.shape[0] == 2672 assert pitch.shape[1] == 3 - pitch = archive["1"].numpy() - assert pitch.shape[0] == 100 + pitch = archive["1-1"].numpy() + assert pitch.shape[0] == 2672 assert pitch.shape[1] == 3 archive.close() os.remove(output_file_name) feature_generator.export_feats( - output_file_name, segments.items(), write_scp=True, compress=True + output_file_name, segments.items(), write_scp=True, compress=False ) assert output_file_name.with_suffix(".scp").exists() archive = FeatureArchive(output_file_name.with_suffix(".scp")) for utt, pitch in archive: pitch = pitch.numpy() - assert utt == "1" - assert pitch.shape[0] == 100 + assert utt == "1-1" + assert pitch.shape[0] == 2672 assert pitch.shape[1] == 3 - pitch = archive["1"].numpy() - assert pitch.shape[0] == 100 + pitch = archive["1-1"].numpy() + assert pitch.shape[0] == 2672 assert pitch.shape[1] == 3