diff --git a/evaluate_oov.sh b/evaluate_oov.sh new file mode 100755 index 000000000..f0466f305 --- /dev/null +++ b/evaluate_oov.sh @@ -0,0 +1,28 @@ +#!/bin/bash +set -e + + +if [ $# -eq 0 ]; then + echo "\n This script compares the performance of a given AM on both OOV and non-OOV testing sets with the use of an external scorer." + echo "\n It works on the data prepared by oov_lm_prep.sh" + echo -e "\n Usage: \n $0 \n" + exit 1 +fi + +am=$1 +scorer=tmp/lm/kenlm.scorer +nj=$(nproc) + +mkdir -p tmp/results + +echo "Evaluating Using Scorer" + +echo "Case (1): Evaluating on OOV testing set." +python -m coqui_stt_training.evaluate --test_files tmp/oov_corpus.csv \ + --test_output_file tmp/results/oov_results.json --scorer_path $scorer \ + --checkpoint_dir $am --test_batch_size $nj + +echo "Case (2): Evaluating on original testing set." +python -m coqui_stt_training.evaluate --test_files tmp/scorer_corpus.csv \ + --test_output_file tmp/results/samples.json --scorer_path $scorer \ + --checkpoint_dir $am --test_batch_size $nj diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 179ec467f..5c7d0df7a 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -19,13 +19,12 @@ namespace flt = fl::lib::text; -int -DecoderState::init(const Alphabet& alphabet, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - std::shared_ptr ext_scorer, - std::unordered_map hot_words) +int DecoderState::init(const Alphabet &alphabet, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + std::shared_ptr ext_scorer, + std::unordered_map hot_words) { // assign special ids abs_time_step_ = 0; @@ -57,10 +56,9 @@ DecoderState::init(const Alphabet& alphabet, return 0; } -void -DecoderState::next(const double *probs, - int time_dim, - int class_dim) +void DecoderState::next(const double *probs, + int time_dim, + int class_dim) { // prefix 
search over time for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++abs_time_step_) { @@ -73,7 +71,6 @@ DecoderState::next(const double *probs, if (prob[blank_id_] < 0.999) { start_expanding_ = true; } - // If not expanding yet, just continue to next timestep. if (!start_expanding_) { continue; @@ -92,14 +89,12 @@ DecoderState::next(const double *probs, std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta); full_beam = (num_prefixes == beam_size_); } - std::vector> log_prob_idx = get_pruned_log_probs(prob, class_dim, cutoff_prob_, cutoff_top_n_); // loop over class dim for (size_t index = 0; index < log_prob_idx.size(); index++) { auto c = log_prob_idx[index].first; auto log_prob_c = log_prob_idx[index].second; - for (size_t i = 0; i < prefixes_.size() && i < beam_size_; ++i) { auto prefix = prefixes_[i]; if (full_beam && log_prob_c + prefix->score < min_cutoff) { @@ -141,7 +136,10 @@ DecoderState::next(const double *probs, } // get new prefix - auto prefix_new = prefix->get_path_trie(c, log_prob_c); + const bool is_scoring_boundary = + ext_scorer_ && ext_scorer_->is_scoring_boundary(prefix, c); + auto prefix_new = prefix->get_path_trie(c, log_prob_c, is_scoring_boundary); + bool is_oov = prefix_new->first_time_oov; if (prefix_new != nullptr) { // compute probability of current path @@ -155,6 +153,9 @@ DecoderState::next(const double *probs, } if (ext_scorer_) { + if (is_oov) { + log_p += -10; // ext_scorer_->get_unk_log_cond_prob(); + } // skip scoring the space in word based LMs PathTrie* prefix_to_score; if (ext_scorer_->is_utf8_mode()) { @@ -164,7 +165,7 @@ DecoderState::next(const double *probs, } // language model scoring - if (ext_scorer_->is_scoring_boundary(prefix_to_score, c)) { + if (is_scoring_boundary) { float score = 0.0; std::vector ngram; ngram = ext_scorer_->make_ngram(prefix_to_score); @@ -176,7 +177,7 @@ DecoderState::next(const double *probs, // that matches a word in the hot-words list for (std::string word : 
ngram) { iter = hot_words_.find(word); - if ( iter != hot_words_.end() ) { + if (iter != hot_words_.end()) { // increase the log_cond_prob(prefix|LM) hot_boost += iter->second; } @@ -184,7 +185,20 @@ DecoderState::next(const double *probs, } bool bos = ngram.size() < ext_scorer_->get_max_order(); - score = ( ext_scorer_->get_log_cond_prob(ngram, bos) + hot_boost ) * ext_scorer_->alpha; + float raw_score = ext_scorer_->get_log_cond_prob(ngram, bos); + score = (raw_score + hot_boost) * ext_scorer_->alpha; + #ifdef DEBUG + if (abs_time_step_ > 40 && abs_time_step_ <= 40) + { + printf("[%03d], p_i = %02d, ngram = ", abs_time_step_, i); + for (const std::string &word : ngram) + { + printf("%s ", word.c_str()); + } + printf("= %.2f (scaled = %.2f). prefix change: %.2f --> %.2f\n", raw_score, score, log_p, log_p + score + ext_scorer_->beta); + } + #endif + log_p += score; log_p += ext_scorer_->beta; } @@ -199,9 +213,23 @@ DecoderState::next(const double *probs, } prefix_new->log_prob_nb_cur = log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - } // end of loop over alphabet + } + #ifdef DEBUG + if (abs_time_step_ > 40 && abs_time_step_ <= 40 && c != 0) + { + + std::vector ngram; + ngram = ext_scorer_->make_ngram(prefix_new); + printf("[%03d], c= %02d, p_i = %02d, scoring ngram = ", abs_time_step_, c, i); + for (const std::string &word : ngram) + { + printf("%s ", word.c_str()); + } + printf(", new_score= %.3f\n", log_sum_exp(prefix_new->log_prob_b_cur, prefix_new->log_prob_nb_cur)); + } + #endif + } // end of loop over prefix + } // end of loop over alphabet // update log probs prefixes_.clear(); @@ -226,18 +254,18 @@ DecoderState::next(const double *probs, std::vector DecoderState::decode(size_t num_results) const { - std::vector prefixes_copy = prefixes_; - std::unordered_map scores; - for (PathTrie* prefix : prefixes_copy) { + std::vector prefixes_copy = prefixes_; + std::unordered_map scores; + for (PathTrie *prefix : prefixes_copy) + 
{ scores[prefix] = prefix->score; } // score the last word of each prefix that doesn't end with space if (ext_scorer_) { for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) { - PathTrie* prefix = prefixes_copy[i]; - PathTrie* prefix_boundary = ext_scorer_->is_utf8_mode() ? prefix : prefix->parent; - if (prefix_boundary && !ext_scorer_->is_scoring_boundary(prefix_boundary, prefix->character)) { + PathTrie *prefix = prefixes_copy[i]; + if (prefix->parent && !ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) { float score = 0.0; std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix); bool bos = ngram.size() < ext_scorer_->get_max_order(); @@ -261,7 +289,7 @@ DecoderState::decode(size_t num_results) const for (size_t i = 0; i < num_returned; ++i) { Output output; prefixes_copy[i]->get_path_vec(output.tokens); - output.timesteps = get_history(prefixes_copy[i]->timesteps, &timestep_tree_root_); + output.timesteps = get_history(prefixes_copy[i]->timesteps, &timestep_tree_root_); assert(output.tokens.size() == output.timesteps.size()); output.confidence = scores[prefixes_copy[i]]; outputs.push_back(output); @@ -272,18 +300,18 @@ DecoderState::decode(size_t num_results) const int FlashlightDecoderState::init( - const Alphabet& alphabet, - size_t beam_size, - double beam_threshold, - size_t cutoff_top_n, - std::shared_ptr<Scorer> ext_scorer, - FlashlightDecoderState::LMTokenType token_type, - flt::Dictionary lm_tokens, - FlashlightDecoderState::DecoderType decoder_type, - double silence_score, - bool merge_with_log_add, - FlashlightDecoderState::CriterionType criterion_type, - std::vector<float> transitions) + const Alphabet &alphabet, + size_t beam_size, + double beam_threshold, + size_t cutoff_top_n, + std::shared_ptr<Scorer> ext_scorer, + FlashlightDecoderState::LMTokenType token_type, + flt::Dictionary lm_tokens, + FlashlightDecoderState::DecoderType decoder_type, + double silence_score, + bool merge_with_log_add, + FlashlightDecoderState::CriterionType criterion_type, + 
std::vector transitions) { // Lexicon-free decoder must use single-token based LM if (decoder_type == LexiconFree) { @@ -299,17 +327,20 @@ FlashlightDecoderState::init( // Convert our criterion type to Flashlight type flt::CriterionType flt_criterion; - switch (criterion_type) { - case ASG: flt_criterion = flt::CriterionType::ASG; break; - case CTC: flt_criterion = flt::CriterionType::CTC; break; - case S2S: flt_criterion = flt::CriterionType::S2S; break; - default: assert(false); + switch (criterion_type){ + case ASG: flt_criterion = flt::CriterionType::ASG; + break; + case CTC: flt_criterion = flt::CriterionType::CTC; + break; + case S2S: flt_criterion = flt::CriterionType::S2S; + break; + default: assert(false); } // Build Trie std::shared_ptr trie = nullptr; auto startState = ext_scorer->start(false); - if (token_type == Aggregate || decoder_type == LexiconBased) { + if (token_type == Aggregate || decoder_type == LexiconBased) { trie = std::make_shared(lm_tokens.indexSize(), alphabet.GetSpaceLabel()); for (int i = 0; i < lm_tokens.entrySize(); ++i) { const std::string entry = lm_tokens.getEntry(i); @@ -333,16 +364,18 @@ FlashlightDecoderState::init( // Query unknown token score int unknown_word_index = lm_tokens.getIndex(""); float unknown_score = -std::numeric_limits::infinity(); - if (token_type == Aggregate) { + if (token_type == Aggregate) + { std::tie(std::ignore, unknown_score) = - ext_scorer->score(startState, unknown_word_index); + ext_scorer->score(startState, unknown_word_index); } // Make sure conversions from uint to int below don't trip us assert(beam_size < INT_MAX); assert(cutoff_top_n < INT_MAX); - if (decoder_type == LexiconBased) { + if (decoder_type == LexiconBased) + { flt::LexiconDecoderOptions opts; opts.beamSize = static_cast(beam_size); opts.beamSizeToken = static_cast(cutoff_top_n); @@ -354,16 +387,17 @@ FlashlightDecoderState::init( opts.logAdd = merge_with_log_add; opts.criterionType = flt_criterion; decoder_impl_.reset(new 
flt::LexiconDecoder( - opts, - trie, - ext_scorer, - alphabet.GetSpaceLabel(), // silence index - alphabet.GetSize(), // blank index - unknown_word_index, - transitions, - token_type == Single) - ); - } else { + opts, + trie, + ext_scorer, + alphabet.GetSpaceLabel(), // silence index + alphabet.GetSize(), // blank index + unknown_word_index, + transitions, + token_type == Single)); + } + else + { flt::LexiconFreeDecoderOptions opts; opts.beamSize = static_cast(beam_size); opts.beamSizeToken = static_cast(cutoff_top_n); @@ -373,12 +407,11 @@ FlashlightDecoderState::init( opts.logAdd = merge_with_log_add; opts.criterionType = flt_criterion; decoder_impl_.reset(new flt::LexiconFreeDecoder( - opts, - ext_scorer, - alphabet.GetSpaceLabel(), // silence index - alphabet.GetSize(), // blank index - transitions) - ); + opts, + ext_scorer, + alphabet.GetSpaceLabel(), // silence index + alphabet.GetSize(), // blank index + transitions)); } // Init decoder for stream @@ -387,11 +420,10 @@ FlashlightDecoderState::init( return 0; } -void -FlashlightDecoderState::next( - const double *probs, - int time_dim, - int class_dim) +void FlashlightDecoderState::next( + const double *probs, + int time_dim, + int class_dim) { std::vector probs_f(probs, probs + (time_dim * class_dim) + 1); decoder_impl_->decodeStep(probs_f.data(), time_dim, class_dim); @@ -402,8 +434,10 @@ FlashlightDecoderState::intermediate(bool prune) { flt::DecodeResult result = decoder_impl_->getBestHypothesis(); std::vector valid_words; - for (int w : result.words) { - if (w != -1) { + for (int w : result.words) + { + if (w != -1) + { valid_words.push_back(w); } } @@ -413,7 +447,8 @@ FlashlightDecoderState::intermediate(bool prune) ret.language_model_score = result.lmScore; ret.words = lm_tokens_.mapIndicesToEntries(valid_words); // how does this interact with token-based decoding ret.tokens = result.tokens; - if (prune) { + if (prune) + { decoder_impl_->prune(); } return ret; @@ -425,10 +460,13 @@ 
FlashlightDecoderState::decode(size_t num_results) decoder_impl_->decodeEnd(); std::vector flt_results = decoder_impl_->getAllFinalHypothesis(); std::vector ret; - for (auto result : flt_results) { + for (auto result : flt_results) + { std::vector valid_words; - for (int w : result.words) { - if (w != -1) { + for (int w : result.words) + { + if (w != -1) + { valid_words.push_back(w); } } @@ -456,7 +494,7 @@ std::vector ctc_beam_search_decoder( std::unordered_map hot_words, size_t num_results) { - VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model."); + VALID_CHECK_EQ(alphabet.GetSize() + 1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model."); DecoderState state; state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer, hot_words); state.next(probs, time_dim, class_dim); @@ -469,7 +507,7 @@ ctc_beam_search_decoder_batch( int batch_size, int time_dim, int class_dim, - const int* seq_lengths, + const int *seq_lengths, int seq_lengths_size, const Alphabet &alphabet, size_t beam_size, @@ -487,9 +525,10 @@ ctc_beam_search_decoder_batch( // enqueue the tasks of decoding std::vector>> res; - for (size_t i = 0; i < batch_size; ++i) { + for (size_t i = 0; i < batch_size; ++i) + { res.emplace_back(pool.enqueue(ctc_beam_search_decoder, - &probs[i*time_dim*class_dim], + &probs[i * time_dim * class_dim], seq_lengths[i], class_dim, alphabet, @@ -503,7 +542,8 @@ ctc_beam_search_decoder_batch( // get decoding results std::vector> batch_results; - for (size_t i = 0; i < batch_size; ++i) { + for (size_t i = 0; i < batch_size; ++i) + { batch_results.emplace_back(res[i].get()); } return batch_results; @@ -511,16 +551,16 @@ ctc_beam_search_decoder_batch( std::vector 
flashlight_beam_search_decoder( - const double* probs, + const double *probs, int time_dim, int class_dim, - const Alphabet& alphabet, + const Alphabet &alphabet, size_t beam_size, double beam_threshold, size_t cutoff_top_n, std::shared_ptr ext_scorer, FlashlightDecoderState::LMTokenType token_type, - const std::vector& lm_tokens, + const std::vector &lm_tokens, FlashlightDecoderState::DecoderType decoder_type, double silence_score, bool merge_with_log_add, @@ -528,25 +568,26 @@ flashlight_beam_search_decoder( std::vector transitions, size_t num_results) { - VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model."); + VALID_CHECK_EQ(alphabet.GetSize() + 1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model."); flt::Dictionary tokens_dict; - for (auto str : lm_tokens) { + for (auto str : lm_tokens) + { tokens_dict.addEntry(str); } FlashlightDecoderState state; state.init( - alphabet, - beam_size, - beam_threshold, - cutoff_top_n, - ext_scorer, - token_type, - tokens_dict, - decoder_type, - silence_score, - merge_with_log_add, - criterion_type, - transitions); + alphabet, + beam_size, + beam_threshold, + cutoff_top_n, + ext_scorer, + token_type, + tokens_dict, + decoder_type, + silence_score, + merge_with_log_add, + criterion_type, + transitions); state.next(probs, time_dim, class_dim); return state.decode(num_results); } @@ -557,15 +598,15 @@ flashlight_beam_search_decoder_batch( int batch_size, int time_dim, int class_dim, - const int* seq_lengths, + const int *seq_lengths, int seq_lengths_size, - const Alphabet& alphabet, + const Alphabet &alphabet, size_t beam_size, double beam_threshold, size_t cutoff_top_n, std::shared_ptr ext_scorer, 
FlashlightDecoderState::LMTokenType token_type, - const std::vector& lm_tokens, + const std::vector &lm_tokens, FlashlightDecoderState::DecoderType decoder_type, double silence_score, bool merge_with_log_add, @@ -581,9 +622,10 @@ flashlight_beam_search_decoder_batch( // enqueue the tasks of decoding std::vector>> res; - for (size_t i = 0; i < batch_size; ++i) { + for (size_t i = 0; i < batch_size; ++i) + { res.emplace_back(pool.enqueue(flashlight_beam_search_decoder, - &probs[i*time_dim*class_dim], + &probs[i * time_dim * class_dim], seq_lengths[i], class_dim, alphabet, @@ -603,7 +645,8 @@ flashlight_beam_search_decoder_batch( // get decoding results std::vector> batch_results; - for (size_t i = 0; i < batch_size; ++i) { + for (size_t i = 0; i < batch_size; ++i) + { batch_results.emplace_back(res[i].get()); } diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h index 2176565ca..2c6d62882 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.h +++ b/native_client/ctcdecode/ctc_beam_search_decoder.h @@ -27,6 +27,12 @@ class DecoderState TimestepTreeNode timestep_tree_root_{nullptr, 0}; std::unordered_map hot_words_; +#ifdef DEBUG + void drawdot(const char* format, ...) const; +#else + void drawdot(...) 
const {}; +#endif + public: DecoderState() = default; ~DecoderState() = default; diff --git a/native_client/ctcdecode/path_trie.cpp b/native_client/ctcdecode/path_trie.cpp index e68b1ca79..206af5be0 100644 --- a/native_client/ctcdecode/path_trie.cpp +++ b/native_client/ctcdecode/path_trie.cpp @@ -8,14 +8,16 @@ #include "decoder_utils.h" -PathTrie::PathTrie() { +PathTrie::PathTrie() +{ log_prob_b_prev = -NUM_FLT_INF; log_prob_nb_prev = -NUM_FLT_INF; log_prob_b_cur = -NUM_FLT_INF; log_prob_nb_cur = -NUM_FLT_INF; log_prob_c = -NUM_FLT_INF; score = -NUM_FLT_INF; - + oov_word = false; + first_time_oov = false; ROOT_ = -1; character = ROOT_; exists_ = true; @@ -28,21 +30,28 @@ PathTrie::PathTrie() { matcher_ = nullptr; } -PathTrie::~PathTrie() { - for (auto child : children_) { +PathTrie::~PathTrie() +{ + for (auto child : children_) + { delete child.second; } } -PathTrie* PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, bool reset) { +PathTrie *PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, bool reset, bool is_scoring_boundary) +{ auto child = children_.begin(); - for (; child != children_.end(); ++child) { - if (child->first == new_char) { + for (; child != children_.end(); ++child) + { + if (child->first == new_char) + { break; } } - if (child != children_.end()) { - if (!child->second->exists_) { + if (child != children_.end()) + { + if (!child->second->exists_) + { child->second->exists_ = true; child->second->log_prob_b_prev = -NUM_FLT_INF; child->second->log_prob_nb_prev = -NUM_FLT_INF; @@ -50,46 +59,82 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, b child->second->log_prob_nb_cur = -NUM_FLT_INF; } return child->second; - } else { - if (has_dictionary_) { - matcher_->SetState(dictionary_state_); - bool found = matcher_->Find(new_char + 1); - if (!found) { - // Adding this character causes word outside dictionary - auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = 
dictionary_->Final(dictionary_state_); - bool is_final = (final_weight != FSTZERO); - if (is_final && reset) { - dictionary_state_ = dictionary_->Start(); + } + else + { + if (has_dictionary_) + { + if (!first_time_oov) + { + matcher_->SetState(dictionary_state_); + bool found = matcher_->Find(new_char + 1); + if (!found) + { + PathTrie *new_path = new PathTrie; + new_path->character = new_char; + new_path->parent = this; + new_path->log_prob_c = cur_log_prob_c; + new_path->oov_word = true; + new_path->first_time_oov = true; + children_.push_back(std::make_pair(new_char, new_path)); + return new_path; } - return nullptr; - } else { - PathTrie* new_path = new PathTrie; + else + { + PathTrie *new_path = new PathTrie; + new_path->character = new_char; + new_path->parent = this; + new_path->dictionary_ = dictionary_; + new_path->has_dictionary_ = true; + new_path->matcher_ = matcher_; + new_path->log_prob_c = cur_log_prob_c; + new_path->oov_word = false; + + // set spell checker state + // check to see if next state is final + auto FSTZERO = fst::TropicalWeight::Zero(); + auto final_weight = dictionary_->Final(dictionary_state_); + if (found) + { + final_weight = dictionary_->Final(matcher_->Value().nextstate); + } + bool is_final = (final_weight != FSTZERO); + if (is_final && reset) + { + // restart spell checker at the start state + new_path->dictionary_state_ = dictionary_->Start(); + } + else + { + // go to next state + new_path->dictionary_state_ = matcher_->Value().nextstate; + } + + children_.push_back(std::make_pair(new_char, new_path)); + return new_path; + } + } + else + { + PathTrie *new_path = new PathTrie; new_path->character = new_char; new_path->parent = this; - new_path->dictionary_ = dictionary_; - new_path->has_dictionary_ = true; - new_path->matcher_ = matcher_; new_path->log_prob_c = cur_log_prob_c; - - // set spell checker state - // check to see if next state is final - auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = 
dictionary_->Final(matcher_->Value().nextstate); - bool is_final = (final_weight != FSTZERO); - if (is_final && reset) { + new_path->oov_word = true; + new_path->first_time_oov = false; + if (is_scoring_boundary) + { // restart spell checker at the start state new_path->dictionary_state_ = dictionary_->Start(); - } else { - // go to next state - new_path->dictionary_state_ = matcher_->Value().nextstate; + new_path->oov_word = false; } - children_.push_back(std::make_pair(new_char, new_path)); return new_path; } - } else { - PathTrie* new_path = new PathTrie; + } + else + { + PathTrie *new_path = new PathTrie; new_path->character = new_char; new_path->parent = this; new_path->log_prob_c = cur_log_prob_c; @@ -99,27 +144,32 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, b } } -void PathTrie::get_path_vec(std::vector& output) { +void PathTrie::get_path_vec(std::vector &output) +{ // Recursive call: recurse back until stop condition, then append data in // correct order as we walk back down the stack in the lines below. - if (parent != nullptr) { + if (parent != nullptr) + { parent->get_path_vec(output); } - if (character != ROOT_) { + if (character != ROOT_) + { output.push_back(character); } } -PathTrie* PathTrie::get_prev_grapheme(std::vector& output, - const Alphabet& alphabet) +PathTrie *PathTrie::get_prev_grapheme(std::vector &output, + const Alphabet &alphabet) { - PathTrie* stop = this; - if (character == ROOT_) { + PathTrie *stop = this; + if (character == ROOT_) + { return stop; } // Recursive call: recurse back until stop condition, then append data in // correct order as we walk back down the stack in the lines below. 
- if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) { + if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) + { stop = parent->get_prev_grapheme(output, alphabet); } output.push_back(character); @@ -127,42 +177,49 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector& output, } int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte, - const Alphabet& alphabet) + const Alphabet &alphabet) { - if (byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) { + if (byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) + { *first_byte = (unsigned char)character + 1; return 1; } - if (parent != nullptr && parent->character != ROOT_) { + if (parent != nullptr && parent->character != ROOT_) + { return 1 + parent->distance_to_codepoint_boundary(first_byte, alphabet); } assert(false); // unreachable return 0; } -PathTrie* PathTrie::get_prev_word(std::vector& output, - const Alphabet& alphabet) +PathTrie *PathTrie::get_prev_word(std::vector &output, + const Alphabet &alphabet) { - PathTrie* stop = this; - if (character == alphabet.GetSpaceLabel() || character == ROOT_) { + PathTrie *stop = this; + if (character == alphabet.GetSpaceLabel() || character == ROOT_) + { return stop; } // Recursive call: recurse back until stop condition, then append data in // correct order as we walk back down the stack in the lines below. 
- if (parent != nullptr) { + if (parent != nullptr) + { stop = parent->get_prev_word(output, alphabet); } output.push_back(character); return stop; } -void PathTrie::iterate_to_vec(std::vector& output) { +void PathTrie::iterate_to_vec(std::vector &output) +{ // previous_timesteps might point to ancestors' timesteps // therefore, children must be uptaded first - for (auto child : children_) { + for (auto child : children_) + { child.second->iterate_to_vec(output); } - if (exists_) { + if (exists_) + { log_prob_b_prev = log_prob_b_cur; log_prob_nb_prev = log_prob_nb_cur; @@ -171,16 +228,20 @@ void PathTrie::iterate_to_vec(std::vector& output) { score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); - if (previous_timesteps != nullptr) { + if (previous_timesteps != nullptr) + { timesteps = nullptr; - for (auto const& child : previous_timesteps->children) { - if (child->data == new_timestep) { - timesteps = child.get(); - break; + for (auto const &child : previous_timesteps->children) + { + if (child->data == new_timestep) + { + timesteps = child.get(); + break; } } - if (timesteps == nullptr) { - timesteps = add_child(previous_timesteps, new_timestep); + if (timesteps == nullptr) + { + timesteps = add_child(previous_timesteps, new_timestep); } } previous_timesteps = nullptr; @@ -189,18 +250,23 @@ void PathTrie::iterate_to_vec(std::vector& output) { } } -void PathTrie::remove() { +void PathTrie::remove() +{ exists_ = false; - if (children_.size() == 0) { - for (auto child = parent->children_.begin(); child != parent->children_.end(); ++child) { - if (child->first == character) { + if (children_.size() == 0) + { + for (auto child = parent->children_.begin(); child != parent->children_.end(); ++child) + { + if (child->first == character) + { parent->children_.erase(child); break; } } - if (parent->children_.size() == 0 && !parent->exists_) { + if (parent->children_.size() == 0 && !parent->exists_) + { parent->remove(); } @@ -208,37 +274,45 @@ void PathTrie::remove() { } 
} -void PathTrie::set_dictionary(std::shared_ptr dictionary) { +void PathTrie::set_dictionary(std::shared_ptr dictionary) +{ dictionary_ = dictionary; dictionary_state_ = dictionary_->Start(); has_dictionary_ = true; } -void PathTrie::set_matcher(std::shared_ptr> matcher) { +void PathTrie::set_matcher(std::shared_ptr> matcher) +{ matcher_ = matcher; } #ifdef DEBUG -void PathTrie::vec(std::vector& out) { - if (parent != nullptr) { +void PathTrie::vec(std::vector &out) +{ + if (parent != nullptr) + { parent->vec(out); } out.push_back(this); } -void PathTrie::print(const Alphabet& a) { - std::vector chain; +void PathTrie::print(const Alphabet &a) +{ + std::vector chain; vec(chain); std::string tr; printf("characters:\t "); - for (PathTrie* el : chain) { + for (PathTrie *el : chain) + { printf("%X ", (unsigned char)(el->character)); - if (el->character != ROOT_) { + if (el->character != ROOT_) + { tr.append(a.DecodeSingle(el->character)); } } printf("\ntimesteps:\t "); - for (unsigned int timestep : get_history(timesteps)) { + for (unsigned int timestep : get_history(timesteps)) + { printf("%d ", timestep); } printf("\n"); diff --git a/native_client/ctcdecode/path_trie.h b/native_client/ctcdecode/path_trie.h index 255c18973..512150ba9 100644 --- a/native_client/ctcdecode/path_trie.h +++ b/native_client/ctcdecode/path_trie.h @@ -49,7 +49,7 @@ class PathTrie { ~PathTrie(); // get new prefix after appending new char - PathTrie* get_path_trie(unsigned int new_char, float log_prob_c, bool reset = true); + PathTrie* get_path_trie(unsigned int new_char, float log_prob_c, bool reset = true, bool is_scoring_boundary = false); // get the prefix data in correct time order from root to current node void get_path_vec(std::vector& output); @@ -92,7 +92,8 @@ class PathTrie { float approx_ctc; unsigned int character; TimestepTreeNode* timesteps = nullptr; - + bool oov_word; + bool first_time_oov; // timestep temporary storage for each decoding step. 
TimestepTreeNode* previous_timesteps = nullptr; unsigned int new_timestep; diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index 34ad90fb2..e9c425940 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -209,7 +209,9 @@ bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label) return false; } unsigned char first_byte; - int distance_to_boundary = prefix->distance_to_codepoint_boundary(&first_byte, alphabet_); + // The distance from new prefix (ie after adding new_label) to the first + // byte of the grapheme is the distance from current prefix plus one (new_label). + int distance_to_first_byte = prefix->distance_to_codepoint_boundary(&first_byte, alphabet_) + 1; int needed_bytes; if ((first_byte >> 3) == 0x1E) { needed_bytes = 4; @@ -223,7 +225,7 @@ bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label) assert(false); // invalid byte sequence. should be unreachable, disallowed by vocabulary/trie return false; } - return distance_to_boundary == needed_bytes; + return distance_to_first_byte == needed_bytes; } else { return new_label == SPACE_ID_; } @@ -255,12 +257,6 @@ double Scorer::get_log_cond_prob(const std::vector::const_iterator& double cond_prob = 0.0; for (auto it = begin; it != end; ++it) { lm::WordIndex word_index = vocab.Index(*it); - - // encounter OOV - if (word_index == lm::kUNK) { - return OOV_SCORE; - } - cond_prob = language_model_->BaseScore(in_state, word_index, out_state); std::swap(in_state, out_state); } @@ -273,6 +269,26 @@ double Scorer::get_log_cond_prob(const std::vector::const_iterator& return cond_prob/NUM_FLT_LOGE; } +double Scorer::get_unk_log_cond_prob() +{ + const auto& vocab = language_model_->BaseVocabulary(); + lm::ngram::State state_vec[2]; + lm::ngram::State *in_state = &state_vec[0]; + lm::ngram::State *out_state = &state_vec[1]; + + + language_model_->BeginSentenceWrite(in_state); + + double cond_prob = 0.0; + lm::WordIndex 
word_index = lm::kUNK; + cond_prob = language_model_->BaseScore(in_state, word_index, out_state); + std::swap(in_state, out_state); + cond_prob = language_model_->BaseScore(in_state, vocab.EndSentence(), out_state); + + // return loge prob + return cond_prob/NUM_FLT_LOGE; +} + void Scorer::reset_params(float alpha, float beta) { this->alpha = alpha; diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index eaf789dbf..36f3dab71 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -51,6 +51,8 @@ class Scorer : public fl::lib::text::LM { bool bos = false, bool eos = false); + double get_unk_log_cond_prob(); + // return the max order size_t get_max_order() const { return max_order_; } diff --git a/oov_lm_prep.sh b/oov_lm_prep.sh new file mode 100755 index 000000000..bcec1caff --- /dev/null +++ b/oov_lm_prep.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +set -e + +if [ $# -eq 0 ]; then + echo -e "\n This script prepares a controlled testing environment for OOV handling." + echo -e "\n Usage: \n $0 \n" + echo -e "\n Ex: $0 data.csv checkpoint-dir/ 0.1 \n" + exit 1 +fi + +data=$1 +am=$2 +percent=${3:-0.1} + +if awk -v x=$percent -v y=1 'BEGIN { exit (x >= y) ? 0 : 1 }'; then + echo " Error: OOV percentage must be less than one." 
+ exit 0 +fi + +mkdir -p tmp/ +mkdir -p tmp/lm + +# Data preparation: split the vocab into 10% (that'd later represent OOVs) +# and the remaining 90% to compose a corpus for LM generation +echo "Preparing Data for Language Model Generation" + +# Extract corpus vocabulary (unique words) +xsv select transcript $data | awk -F, '$3!="" && NR>1;{print $0}' > tmp/data.txt +sed 's/ /\n/g' tmp/data.txt | sort | uniq -c | sort -nr > tmp/vocab.txt + +# Pick the least frequent 10% words to build OOV set +oov_count=$(wc -l tmp/vocab.txt | awk -v p="$percent" '{print int($0*p)}') +tail -$oov_count tmp/vocab.txt | awk '{print $2}'> tmp/oov_words +grep -wFf tmp/oov_words tmp/data.txt > tmp/oov_sents + +# Exclude OOVs from the text corpus +grep -vf tmp/oov_sents tmp/data.txt > tmp/scorer_corpus.txt +gzip -c tmp/scorer_corpus.txt > tmp/scorer_corpus.txt.gz +grep -vf tmp/oov_sents $data > tmp/scorer_corpus.csv + +# Prepare OOV CSV for testing purposes (to assess improvements on it) +grep -wFf tmp/oov_sents tmp/data.txt > tmp/oov_corpus.txt +echo "wav_filename,wav_filesize,transcript" > tmp/oov_corpus.csv +grep -wFf tmp/oov_sents $data >> tmp/oov_corpus.csv + +# Generate LM +python3 data/lm/generate_lm.py --input_txt tmp/scorer_corpus.txt.gz \ + --output_dir tmp/lm --top_k 500000 --kenlm_bins kenlm/build/bin \ + --arpa_order 3 --max_arpa_memory "85%" --arpa_prune "0|0|1" \ + --binary_a_bits 255 --binary_q_bits 8 --binary_type trie --discount_fallback + +./native_client/generate_scorer_package --alphabet $am/alphabet.txt \ + --lm tmp/lm/lm.binary --vocab tmp/lm/vocab-500000.txt \ + --package tmp/lm/kenlm.scorer --default_alpha 0.931289039105002 \ + --default_beta 1.1834137581510284 + +echo "Done!"