From 5840eeaaec74ae7fe80aabd2105da10b44f5cd21 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sun, 17 Sep 2023 16:23:08 -0700 Subject: [PATCH] Fix for bug in TextGrid creation (#6) --- extensions/fstext/fstext.cpp | 111 +++++++++++++++++++++---------- kalpy/decoder/training_graphs.py | 2 +- kalpy/gmm/data.py | 6 +- 3 files changed, 82 insertions(+), 37 deletions(-) diff --git a/extensions/fstext/fstext.cpp b/extensions/fstext/fstext.cpp index c112aa0..3a0d3de 100644 --- a/extensions/fstext/fstext.cpp +++ b/extensions/fstext/fstext.cpp @@ -1875,6 +1875,35 @@ void init_fstext(py::module &_m) { py::arg("max_states") = -1, py::arg("use_log") = false); + m.def("fst_determinize_star", + []( + py::object fst, + float delta = kDelta, + int max_states = -1, + bool use_log = false + ){ + auto pywrapfst_mod = py::module_::import("pywrapfst"); + auto ptr = reinterpret_cast(fst.ptr()); + auto mf = ptr->__pyx_base._mfst->GetMutableFst(); + VectorFst vf(*mf); + py::gil_scoped_release gil_release; + bool debug_location = false; + + ArcSort(&vf, ILabelCompare()); // improves speed. + if (use_log) { + DeterminizeStarInLog(&vf, delta, &debug_location, max_states); + return vf; + } else { + VectorFst det_fst; + DeterminizeStar(vf, &det_fst, delta, &debug_location, max_states); + return det_fst; + } + }, + py::arg("fst"), + py::arg("delta") = kDelta, + py::arg("max_states") = -1, + py::arg("use_log") = false); + m.def("fst_is_stochastic", []( Fst *fst, @@ -1892,44 +1921,26 @@ void init_fstext(py::module &_m) { py::arg("delta") = 0.01, py::arg("test_in_log") = true); - m.def("fst_make_context_fst", + m.def("fst_is_stochastic", []( - Fst *fst, - int32 subseq_sym, - std::vector phone_syms, - std::vector disambig_in, - int32 context_width = 3, int32 central_position = 1 + py::object fst, + float delta = 0.01, + bool test_in_log = true ){ - py::gil_scoped_release gil_release; - - StdVectorFst loop_fst; - loop_fst.AddState(); // Add state zero. - loop_fst.SetStart(0); - loop_fst.SetFinal(0, TropicalWeight::One()); - for (size_t i = 0; i < phone_syms.size(); i++) { - int32 sym = phone_syms[i]; - loop_fst.AddArc(0, StdArc(sym, sym, TropicalWeight::One(), 0)); - } - - std::vector > ilabels; - VectorFst context_fst; - ComposeContext(disambig_in, context_width, central_position, - &loop_fst, &context_fst, &ilabels, true); - - std::vector disambig_out; - for (size_t i = 0; i < ilabels.size(); i++) - if (ilabels[i].size() == 1 && ilabels[i][0] <= 0) - disambig_out.push_back(static_cast(i)); + auto pywrapfst_mod = py::module_::import("pywrapfst"); + auto ptr = reinterpret_cast(fst.ptr()); + auto mf = ptr->__pyx_base._mfst->GetMutableFst(); + VectorFst vf(*mf); + bool ans; + StdArc::Weight min, max; + if (test_in_log) ans = IsStochasticFstInLog(vf, delta, &min, &max); + else ans = IsStochasticFst(vf, delta, &min, &max); + return py::make_tuple(ans, min, max); - py::gil_scoped_acquire acquire; - return py::make_tuple(context_fst, disambig_out); - }, + }, py::arg("fst"), - py::arg("subseq_sym"), - py::arg("phone_syms"), - py::arg("disambig_in"), - py::arg("context_width") = 3, - py::arg("central_position") = 1); + py::arg("delta") = 0.01, + py::arg("test_in_log") = true); m.def("fst_make_context_syms", []( @@ -1965,6 +1976,22 @@ void init_fstext(py::module &_m) { py::arg("fst"), py::arg("delta") = kDelta); + m.def("fst_minimize_encoded", + []( + py::object fst, + float delta = kDelta + ){ + auto pywrapfst_mod = py::module_::import("pywrapfst"); + auto ptr = reinterpret_cast(fst.ptr()); + auto mf = ptr->__pyx_base._mfst->GetMutableFst(); + VectorFst vf(*mf); + py::gil_scoped_release gil_release; + + MinimizeEncoded(&vf, delta); + }, + py::arg("fst"), + py::arg("delta") = kDelta); + m.def("fst_phi_compose", []( VectorFst *fst1, @@ -1997,6 +2024,22 @@ void init_fstext(py::module &_m) { py::arg("fst"), py::arg("delta") = kDelta); + m.def("fst_push_special", + []( + py::object fst, + BaseFloat delta = kDelta + ){ + auto pywrapfst_mod = py::module_::import("pywrapfst"); + auto ptr = reinterpret_cast(fst.ptr()); + auto mf = ptr->__pyx_base._mfst->GetMutableFst(); + VectorFst vf(*mf); + py::gil_scoped_release gil_release; + + PushSpecial(&vf, delta); + }, + py::arg("fst"), + py::arg("delta") = kDelta); + m.def("fst_arc_sort", []( VectorFst *fst, diff --git a/kalpy/decoder/training_graphs.py b/kalpy/decoder/training_graphs.py index 54388d3..30e3dfc 100644 --- a/kalpy/decoder/training_graphs.py +++ b/kalpy/decoder/training_graphs.py @@ -252,7 +252,7 @@ def compile_fst(self, transcript: str) -> typing.Optional[VectorFst]: lg_fst = pynini.determinize(lg_fst, nstate=state_threshold, weight=weight_threshold) lg_fst = VectorFst.from_pynini(lg_fst) - fst_determinize_star(lg_fst, use_log=True) + lg_fst = fst_determinize_star(lg_fst, use_log=True) fst_minimize_encoded(lg_fst) fst_push_special(lg_fst) clg_fst, disambig_out, ilabels = fst_compose_context( diff --git a/kalpy/gmm/data.py b/kalpy/gmm/data.py index 8f7afe6..30ad833 100644 --- a/kalpy/gmm/data.py +++ b/kalpy/gmm/data.py @@ -144,9 +144,9 @@ def to_tg_interval(self, file_duration=None) -> Interval: end = round(self.end, 6) begin = round(self.begin, 6) if file_duration is not None and end > file_duration: - end = round(file_duration, 6) + end = file_duration assert begin < end - return Interval(round(self.begin, 6), end, self.label) + return Interval(begin, end, self.label) @dataclassy.dataclass @@ -193,6 +193,8 @@ def export_textgrid( def to_textgrid_tiers( self, file_duration: float = None ) -> typing.Tuple[tgio.IntervalTier, tgio.IntervalTier]: + if file_duration is not None: + file_duration = round(file_duration, 6) word_tier = tgio.IntervalTier("words", [], minT=0.0, maxT=file_duration) phone_tier = tgio.IntervalTier("phones", [], minT=0.0, maxT=file_duration) for w in self.word_intervals: