Add the translation app: translator, synonym searcher and attention extractor.

Summary of the patch:
* README.md: adds a "Translation App" section describing the synonym searcher.
* conf/model/opus_distilled_en_ru.yaml: adds `type: seq2seq`.
* src/prod/features/: new package with the attention extractor, synonym
  searcher and translator, each exposing a `choose_*` factory driven by a
  config `type` field.
* src/prod/translator_app.py: new `TranslatorApp` that wires the three
  features together.

NOTE(review): this patch was recovered from a whitespace-flattened copy.
Python indentation was rebuilt from syntax, and the new-file line counts match
their hunk headers. The blank context lines of the README.md hunk could not be
recovered; their placement below is assumed -- TODO confirm against README.md.

diff --git a/README.md b/README.md
index a507e41..6bc970d 100644
--- a/README.md
+++ b/README.md
@@ -115,6 +115,11 @@ eval_bleu: 67.197
 ```
 (BLEU is suspeciously high)
 
+## Translation App
+
+**Synonyms Searcher**\
+The simple version is based on a `word2vec`-style model, namely `fasttext` ([link](https://fasttext.cc/docs/en/crawl-vectors.html)). We've chosen fasttext because it solves the problem of out-of-vocabulary words.
+
 
 
 
diff --git a/conf/model/opus_distilled_en_ru.yaml b/conf/model/opus_distilled_en_ru.yaml
index dfddef8..8a6e2a8 100644
--- a/conf/model/opus_distilled_en_ru.yaml
+++ b/conf/model/opus_distilled_en_ru.yaml
@@ -1,3 +1,4 @@
 name: opus-distilled-en-ru
 model_and_tokenizer_name: "under-tree/transformer-en-ru"
-output_dir: ${root}/models/${.name}/finetuned
\ No newline at end of file
+output_dir: ${root}/models/${.name}/finetuned
+type: seq2seq
\ No newline at end of file
diff --git a/src/prod/features/__init__.py b/src/prod/features/__init__.py
new file mode 100644
index 0000000..f14d04c
--- /dev/null
+++ b/src/prod/features/__init__.py
@@ -0,0 +1,3 @@
+from .attention_extractor import BaseAttentionExtractor, choose_attention_extractor
+from .synonym_searcher import BaseSynonymSearcher, choose_synonym_searcher
+from .translator import BaseTranslator, choose_translator
diff --git a/src/prod/features/attention_extractor.py b/src/prod/features/attention_extractor.py
new file mode 100644
index 0000000..1a29eaf
--- /dev/null
+++ b/src/prod/features/attention_extractor.py
@@ -0,0 +1,59 @@
+from abc import ABC, abstractmethod
+import typing as tp
+from typing import Any
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from omegaconf import DictConfig
+import numpy as np
+
+class BaseAttentionExtractor(ABC):
+    def __init__(self, model: AutoModelForSeq2SeqLM):
+        """
+        :@param model: model to extract attention from
+        """
+        super().__init__()
+        self.model = model
+
+    @abstractmethod
+    def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]:
+        """
+        Extract attention scores for each token in relation to all other tokens
+
+        :@param tokens: list of tokens to extract attention for
+        :@return: list of attention scores for each token
+        """
+        pass
+
+class RandomAttentionExtractor(BaseAttentionExtractor):
+    """
+    Just for testing purposes
+    """
+    def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]:
+        """
+        Extract random attention scores for each token in relation to all other tokens
+
+        :@param tokens: list of tokens to extract attention for
+        :@return: n_tokens x n_tokens list of random scores, each row normalized to sum to 1
+        """
+        n_tokens = len(tokens)
+        unnormalized_scores = np.random.rand(n_tokens, n_tokens)
+        normalized_scores = unnormalized_scores / unnormalized_scores.sum(axis=1, keepdims=True)
+        return normalized_scores.tolist()
+
+class Seq2SeqAttentionExtractor(BaseAttentionExtractor):
+    pass  # NOTE(review): inherits the abstract __call__, so it cannot be instantiated yet
+
+
+def choose_attention_extractor(cfg: DictConfig) -> BaseAttentionExtractor:
+    """
+    Choose attention extractor based on config
+
+    :@param cfg: attention extractor sub-config; its `type` field selects the implementation
+    :@return: attention extractor
+    """
+    if cfg.type == "random":
+        model = None
+        return RandomAttentionExtractor(model)
+    if cfg.type == "seq2seq":
+        return Seq2SeqAttentionExtractor(cfg)
+    else:
+        raise ValueError(f"Unknown attention extractor type: {cfg.type}")
\ No newline at end of file
diff --git a/src/prod/features/synonym_searcher.py b/src/prod/features/synonym_searcher.py
new file mode 100644
index 0000000..742c8da
--- /dev/null
+++ b/src/prod/features/synonym_searcher.py
@@ -0,0 +1,55 @@
+from abc import ABC, abstractmethod
+import typing as tp
+import fasttext
+from omegaconf import DictConfig
+
+class BaseSynonymSearcher(ABC):
+    @abstractmethod
+    def __call__(self, tokens: tp.List[str], *args, return_scores: bool = False,
+                 n_synonyms: int = 5, **kwargs) -> tp.List[tp.List[str]] | tp.List[tp.List[tp.Tuple[float, str]]]:
+        """
+        Search for synonyms for each token in the list
+        :@param tokens: list of tokens to find synonyms for
+        :@param return_scores: whether to return scores for synonyms
+        :@param n_synonyms: number of synonyms to return for each token
+        :@return: list of synonyms for each token or list of (score, synonym) pairs
+        """
+        pass
+
+
+class FastTextSynonymSearcher(BaseSynonymSearcher):
+    def __init__(self, model_path: str):
+        """
+        :@param model_path: path to fasttext model
+        """
+        super().__init__()
+        self.model = fasttext.load_model(model_path)
+
+    def __call__(self, tokens: tp.List[str], *args, return_scores: bool = False,
+                 n_synonyms: int = 5, **kwargs) -> tp.List[tp.List[str]] | tp.List[tp.List[tp.Tuple[float, str]]]:
+        """
+        Search for synonyms for each token in the list
+        :@param tokens: list of tokens to find synonyms for
+        :@param return_scores: whether to return scores for synonyms
+        :@param n_synonyms: number of synonyms to return for each token
+        :@return: list of synonyms for each token or list of (score, synonym) pairs
+        """
+        synonyms = []
+        for token in tokens:
+            if return_scores:
+                synonyms.append(self.model.get_nearest_neighbors(token, k=n_synonyms))
+            else:
+                synonyms.append([synonym for _, synonym in self.model.get_nearest_neighbors(token, k=n_synonyms)])
+        return synonyms
+
+def choose_synonym_searcher(cfg: DictConfig) -> BaseSynonymSearcher:
+    """
+    Choose synonym searcher based on config
+
+    :@param cfg: config to choose synonym searcher from
+    :@return: synonym searcher
+    """
+    if cfg.type == "fasttext":
+        return FastTextSynonymSearcher(cfg.path)
+    else:
+        raise ValueError(f"Unknown synonym searcher type: {cfg.type}")
\ No newline at end of file
diff --git a/src/prod/features/translator.py b/src/prod/features/translator.py
new file mode 100644
index 0000000..f8aabf0
--- /dev/null
+++ b/src/prod/features/translator.py
@@ -0,0 +1,93 @@
+from abc import ABC, abstractmethod
+import typing as tp
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from omegaconf import DictConfig
+
+translation_dict = tp.Dict[str, str | tp.List[str]]
+
+class BaseTranslator(ABC):
+    def __init__(self, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer):
+        """
+        :@param model: model to translate with
+        :@param tokenizer: tokenizer to use
+        """
+        super().__init__()
+        self.model = model
+        self.tokenizer = tokenizer
+
+    @abstractmethod
+    def call_one(self, text: str, *args, **kwargs) -> translation_dict:
+        """
+        Translate one text
+
+        :@param text: text to translate
+        :@return: dict with translation, input and output tokens
+        {
+            'translation': str,
+            'input_tokens': list[str]
+            'output_tokens': list[str]
+        }
+        """
+        pass
+
+    def __call__(self, texts: str | tp.List[str], *args, **kwargs) -> translation_dict | tp.List[translation_dict]:
+        """
+        Translate texts
+
+        :@param texts: text or list of texts to translate
+        :@return: dict or list of dicts with translation, input and output tokens for each text
+        {
+            'translation': str,
+            'input_tokens': list[str]
+            'output_tokens': list[str]
+        }
+        """
+        if isinstance(texts, str):
+            return self.call_one(texts, *args, **kwargs)
+        return [self.call_one(text, *args, **kwargs) for text in texts]
+
+
+class Seq2SeqTranslator(BaseTranslator):
+    def __init__(self, cfg: DictConfig):
+        """
+        :@param cfg: config with model and tokenizer paths
+        """
+        name = cfg.model_and_tokenizer_name
+        model = AutoModelForSeq2SeqLM.from_pretrained(name)
+        tokenizer = AutoTokenizer.from_pretrained(name)
+        super().__init__(model, tokenizer)
+        self.gen_params = cfg.get("gen_params", {})
+
+    def get_tokens(self, text: str, is_target: bool) -> tp.List[str]:
+        """
+        Tokenize text with tokenizer
+
+        :@param text: text to tokenize
+        :@param is_target: whether text is target or source
+        :@return: list of tokens
+        """
+        if is_target:
+            with self.tokenizer.as_target_tokenizer():
+                return self.tokenizer.tokenize(text)
+        return self.tokenizer.tokenize(text)
+
+    def call_one(self, text: str, *args, **kwargs) -> translation_dict:
+        input_idx = self.tokenizer.encode(text, return_tensors="pt")
+        translation = self.model.generate(input_idx, **self.gen_params)
+        translation = self.tokenizer.batch_decode(translation, skip_special_tokens=True)[0]
+
+        input_tokens = self.get_tokens(text, is_target=False)
+        output_tokens = self.get_tokens(translation, is_target=True)
+
+        return {
+            "translation": translation,
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens
+        }
+
+
+def choose_translator(cfg: DictConfig) -> BaseTranslator:
+    if cfg.type == "seq2seq":
+        return Seq2SeqTranslator(cfg)
+    else:
+        raise ValueError(f"Unknown translator type: {cfg.type}")
\ No newline at end of file
diff --git a/src/prod/translator_app.py b/src/prod/translator_app.py
new file mode 100644
index 0000000..7971cde
--- /dev/null
+++ b/src/prod/translator_app.py
@@ -0,0 +1,70 @@
+import os
+import sys
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(cur_dir)
+
+from features import BaseAttentionExtractor, BaseSynonymSearcher, BaseTranslator
+from features import choose_attention_extractor, choose_synonym_searcher, choose_translator
+from omegaconf import DictConfig
+import typing as tp
+from nltk.tokenize import word_tokenize
+
+translation_type = tp.Dict[str, tp.Any]
+
+class TranslatorApp:
+    def __init__(self, cfg: DictConfig):
+        self.cfg = cfg
+        self.translator = choose_translator(cfg.translator)
+        self.src_synonym_searcher = choose_synonym_searcher(cfg.src_synonym_searcher)
+        self.dest_synonym_searcher = choose_synonym_searcher(cfg.dest_synonym_searcher)
+        self.attention_extractor = choose_attention_extractor(cfg.attention_extractor)
+
+    def _get_words(self, text: str) -> tp.List[str]:
+        """
+        Split text into words
+        :@param text: text to split
+        :@return: list of words
+        """
+        text = text.lower()
+        return word_tokenize(text)
+
+    def call_one(self, text: str) -> translation_type:
+        """
+        Return translation for one text
+        :@param text: text to translate
+        :@return: dict
+        {
+            'translation': str,
+            'input_tokens': list[str]
+            'input_words': list[str]
+            'output_tokens': list[str]
+            'output_words': list[str]
+            'src_synonyms': list[list[str]]
+            'dest_synonyms': list[list[str]]
+            'attention': list[list[float]]
+        }
+        """
+        translation = self.translator(text)
+        input_tokens = translation["input_tokens"]
+        output_tokens = translation["output_tokens"]
+        input_words = self._get_words(text)
+        output_words = self._get_words(translation["translation"])
+        src_synonyms = self.src_synonym_searcher(input_words)
+        dest_synonyms = self.dest_synonym_searcher(output_words)
+        attention = self.attention_extractor(input_tokens)
+        return {
+            "translation": translation["translation"],
+            "input_tokens": input_tokens,
+            "input_words": input_words,
+            "output_tokens": output_tokens,
+            "output_words": output_words,
+            "src_synonyms": src_synonyms,
+            "dest_synonyms": dest_synonyms,
+            "attention": attention
+        }
+
+    def __call__(self, texts: str | tp.List[str]) -> translation_type | tp.List[translation_type]:
+        if isinstance(texts, str):
+            return self.call_one(texts)
+        return [self.call_one(text) for text in texts]