diff --git a/src/string/ac-automaton.hpp b/src/string/ac-automaton.hpp index d2a010c..4ed4538 100644 --- a/src/string/ac-automaton.hpp +++ b/src/string/ac-automaton.hpp @@ -4,8 +4,10 @@ #include /** Modified from: * https://github.com/kth-competitive-programming/kactl/blob/master/content/strings/AhoCorasick.h - * Try to handdle duplicated patterns beforehand, otherwise change 'end' to - * vector; empty patterns are not allowed. Time: construction takes $O(26N)$, + * If there's no duplicated patterns, just call the constructor, otherwise handle it beforehand + * by yourself, or use the return value of insert + * empty patterns are not allowed. + * Time: construction takes $O(26N)$, * where $N =$ sum of length of patterns. find(x) is $O(N)$, where N = length of * x. findAll is $O(N+M)$ where M is number of occurrence of all pattern (up to N*sqrt(N)) */ @@ -14,8 +16,7 @@ struct AhoCorasick { struct Node { // back: failure link, points to longest suffix that is in the trie. // end: longest pattern that ends here, is -1 if no patten ends here. - // nmatches: number of (patterns that is a suffix of current - // node)/(duplicated patterns), depends on needs. + // nmatches: number of patterns that is a suffix of current node // output: output link, points to the longest pattern that is a suffix // of current node int back, end = -1, nmatches = 0, output = -1; @@ -33,7 +34,9 @@ struct AhoCorasick { build(); } - void insert(const std::string &s, int j) { // j: id of string s + // returns -1 if there's no duplicated pattern already in the trie + // returns the id of the duplicated pattern otherwise + int insert(const std::string &s, int j) { // j: id of string s assert(!s.empty()); int n = 0; for (char c : s) { @@ -43,13 +46,19 @@ struct AhoCorasick { } n = N[n].next[c - first]; } + if (N[n].end != -1) { + return N[n].end; + } N[n].end = j; N[n].nmatches++; + return -1; } void build() { + // adds a dummy node so the root node can be correctly handled N[0].back = (int)N.size(); N.emplace_back(0); + std::queue q; q.push(0); while (!q.empty()) { @@ -64,14 +73,13 @@ struct AhoCorasick { // if prev is an end node, then set output to prev node, // otherwise set to output link of prev node N[v].output = N[fail].end == -1 ? N[fail].output : fail; - // if we don't want to distinguish info of patterns that is - // a suffix of current node, we can add info to the next - // node like this: nxt.nmatches+=N[pnx].nmatches; + N[v].nmatches += N[fail].nmatches; q.push(v); } } } } + // for each position, finds the longest pattern that ends here std::vector find(const std::string &text) { int len = (int)text.size(); @@ -83,14 +91,17 @@ struct AhoCorasick { } return res; } - // for each position, finds the all that ends here + + // for each position, finds all patterns that ends here std::vector> find_all(const std::string &text) { int len = (int)text.size(); std::vector> res(len); int n = 0; for (int i = 0; i < len; i++) { n = N[n].next[text[i] - first]; - res[i].push_back(N[n].end); + if (N[n].end != -1) { + res[i].push_back(N[n].end); + } for (int ind = N[n].output; ind != -1; ind = N[ind].output) { assert(N[ind].end != -1); res[i].push_back(N[ind].end); @@ -99,8 +110,9 @@ struct AhoCorasick { return res; } - std::vector find_cnt(const std::string& text, int n) { - std::vector cnt(n); + // finds the number of occurrence of each pattern + std::vector find_cnt(const std::string& text, int num_of_patterns) { + std::vector cnt(num_of_patterns); int p = 0; for (auto c : text) { p = N[p].next[c - first];