Fix for bug in TextGrid creation (#6)
mmcauliffe authored Sep 17, 2023
1 parent 01f55d2 commit 5840eea
Showing 3 changed files with 82 additions and 37 deletions.
111 changes: 77 additions & 34 deletions extensions/fstext/fstext.cpp
@@ -1875,6 +1875,35 @@ void init_fstext(py::module &_m) {
py::arg("max_states") = -1,
py::arg("use_log") = false);

m.def("fst_determinize_star",
[](
py::object fst,
float delta = kDelta,
int max_states = -1,
bool use_log = false
){
auto pywrapfst_mod = py::module_::import("pywrapfst");
auto ptr = reinterpret_cast<VectorFstStruct*>(fst.ptr());
auto mf = ptr->__pyx_base._mfst->GetMutableFst<StdArc>();
VectorFst<StdArc> vf(*mf);
py::gil_scoped_release gil_release;
bool debug_location = false;

ArcSort(&vf, ILabelCompare<StdArc>()); // improves speed.
if (use_log) {
DeterminizeStarInLog(&vf, delta, &debug_location, max_states);
return vf;
} else {
VectorFst<StdArc> det_fst;
DeterminizeStar(vf, &det_fst, delta, &debug_location, max_states);
return det_fst;
}
},
py::arg("fst"),
py::arg("delta") = kDelta,
py::arg("max_states") = -1,
py::arg("use_log") = false);

m.def("fst_is_stochastic",
[](
Fst<StdArc> *fst,
@@ -1892,44 +1921,26 @@ void init_fstext(py::module &_m) {
py::arg("delta") = 0.01,
py::arg("test_in_log") = true);

m.def("fst_make_context_fst",
m.def("fst_is_stochastic",
[](
Fst<StdArc> *fst,
int32 subseq_sym,
std::vector<kaldi::int32> phone_syms,
std::vector<int32> disambig_in,
int32 context_width = 3, int32 central_position = 1
py::object fst,
float delta = 0.01,
bool test_in_log = true
){
py::gil_scoped_release gil_release;

StdVectorFst loop_fst;
loop_fst.AddState(); // Add state zero.
loop_fst.SetStart(0);
loop_fst.SetFinal(0, TropicalWeight::One());
for (size_t i = 0; i < phone_syms.size(); i++) {
int32 sym = phone_syms[i];
loop_fst.AddArc(0, StdArc(sym, sym, TropicalWeight::One(), 0));
}

std::vector<std::vector<int32> > ilabels;
VectorFst<StdArc> context_fst;
ComposeContext(disambig_in, context_width, central_position,
&loop_fst, &context_fst, &ilabels, true);

std::vector<int32> disambig_out;
for (size_t i = 0; i < ilabels.size(); i++)
if (ilabels[i].size() == 1 && ilabels[i][0] <= 0)
disambig_out.push_back(static_cast<int32>(i));
auto pywrapfst_mod = py::module_::import("pywrapfst");
auto ptr = reinterpret_cast<VectorFstStruct*>(fst.ptr());
auto mf = ptr->__pyx_base._mfst->GetMutableFst<StdArc>();
VectorFst<StdArc> vf(*mf);
bool ans;
StdArc::Weight min, max;
if (test_in_log) ans = IsStochasticFstInLog(vf, delta, &min, &max);
else ans = IsStochasticFst(vf, delta, &min, &max);
return py::make_tuple(ans, min, max);

py::gil_scoped_acquire acquire;
return py::make_tuple(context_fst, disambig_out);
},
},
py::arg("fst"),
py::arg("subseq_sym"),
py::arg("phone_syms"),
py::arg("disambig_in"),
py::arg("context_width") = 3,
py::arg("central_position") = 1);
py::arg("delta") = 0.01,
py::arg("test_in_log") = true);

m.def("fst_make_context_syms",
[](
@@ -1965,6 +1976,22 @@ void init_fstext(py::module &_m) {
py::arg("fst"),
py::arg("delta") = kDelta);

m.def("fst_minimize_encoded",
[](
py::object fst,
float delta = kDelta
){
auto pywrapfst_mod = py::module_::import("pywrapfst");
auto ptr = reinterpret_cast<VectorFstStruct*>(fst.ptr());
auto mf = ptr->__pyx_base._mfst->GetMutableFst<StdArc>();
VectorFst<StdArc> vf(*mf);
py::gil_scoped_release gil_release;

MinimizeEncoded(&vf, delta);
},
py::arg("fst"),
py::arg("delta") = kDelta);

m.def("fst_phi_compose",
[](
VectorFst<StdArc> *fst1,
@@ -1997,6 +2024,22 @@
py::arg("fst"),
py::arg("delta") = kDelta);

m.def("fst_push_special",
[](
py::object fst,
BaseFloat delta = kDelta
){
auto pywrapfst_mod = py::module_::import("pywrapfst");
auto ptr = reinterpret_cast<VectorFstStruct*>(fst.ptr());
auto mf = ptr->__pyx_base._mfst->GetMutableFst<StdArc>();
VectorFst<StdArc> vf(*mf);
py::gil_scoped_release gil_release;

PushSpecial(&vf, delta);
},
py::arg("fst"),
py::arg("delta") = kDelta);

m.def("fst_arc_sort",
[](
VectorFst<StdArc> *fst,
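Taken together, the py::object overloads added in this file let the Python layer drive the determinize/minimize/push sequence on VectorFst objects and check the result with fst_is_stochastic, which returns an (is_stochastic, min_weight, max_weight) tuple. A hedged end-to-end sketch, following the sequence used in training_graphs.py below with an added stochasticity check (import path assumed):

# Hedged sketch of calling the new bindings from Python; the import path is
# an assumption and lg_fst stands in for a VectorFst built earlier.
from _kalpy.fstext import (  # assumed module path
    fst_determinize_star,
    fst_is_stochastic,
    fst_minimize_encoded,
    fst_push_special,
)

lg_fst = fst_determinize_star(lg_fst, use_log=True)  # returns a new FST
fst_minimize_encoded(lg_fst)
fst_push_special(lg_fst)
is_stochastic, min_w, max_w = fst_is_stochastic(lg_fst, delta=0.01, test_in_log=True)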
2 changes: 1 addition & 1 deletion kalpy/decoder/training_graphs.py
@@ -252,7 +252,7 @@ def compile_fst(self, transcript: str) -> typing.Optional[VectorFst]:
lg_fst = pynini.determinize(lg_fst, nstate=state_threshold, weight=weight_threshold)
lg_fst = VectorFst.from_pynini(lg_fst)

fst_determinize_star(lg_fst, use_log=True)
lg_fst = fst_determinize_star(lg_fst, use_log=True)
fst_minimize_encoded(lg_fst)
fst_push_special(lg_fst)
clg_fst, disambig_out, ilabels = fst_compose_context(
6 changes: 4 additions & 2 deletions kalpy/gmm/data.py
@@ -144,9 +144,9 @@ def to_tg_interval(self, file_duration=None) -> Interval:
end = round(self.end, 6)
begin = round(self.begin, 6)
if file_duration is not None and end > file_duration:
end = round(file_duration, 6)
end = file_duration
assert begin < end
return Interval(round(self.begin, 6), end, self.label)
return Interval(begin, end, self.label)


@dataclassy.dataclass
@@ -193,6 +193,8 @@ def export_textgrid(
def to_textgrid_tiers(
self, file_duration: float = None
) -> typing.Tuple[tgio.IntervalTier, tgio.IntervalTier]:
if file_duration is not None:
file_duration = round(file_duration, 6)
word_tier = tgio.IntervalTier("words", [], minT=0.0, maxT=file_duration)
phone_tier = tgio.IntervalTier("phones", [], minT=0.0, maxT=file_duration)
for w in self.word_intervals:
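The interval fix above rounds file_duration once in to_textgrid_tiers, clamps an interval's end to that pre-rounded value, and returns the same begin and end that the begin < end assertion checked, instead of re-rounding file_duration and self.begin per interval. A standalone sketch of that clamping logic (not the library code itself; the rounding and names follow the diff):

# Hedged sketch of the clamping behavior shown in kalpy/gmm/data.py.
def clamp_to_file_duration(begin, end, file_duration=None):
    end = round(end, 6)
    begin = round(begin, 6)
    # file_duration is expected to be pre-rounded to 6 decimals by the caller.
    if file_duration is not None and end > file_duration:
        end = file_duration
    assert begin < end
    return begin, end

# An interval overrunning a 12.3456781 s file is snapped to the rounded
# duration 12.345678 rather than rounded past it.
print(clamp_to_file_duration(12.30, 12.35, round(12.3456781, 6)))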
