Skip to content

Commit

Permalink
Translator App
Browse files Browse the repository at this point in the history
  • Loading branch information
RodionfromHSE committed Nov 19, 2023
1 parent f6d16df commit 5f846a0
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 1 deletion.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ eval_bleu: 67.197
```
(BLEU is suspiciously high)

## Translation App

**Synonyms Searcher**\
The simple version is based on a `word2vec`-style embedding model, namely `fasttext` ([link](https://fasttext.cc/docs/en/crawl-vectors.html)). We chose fasttext because it handles out-of-vocabulary words.




Expand Down
3 changes: 2 additions & 1 deletion conf/model/opus_distilled_en_ru.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
name: opus-distilled-en-ru
model_and_tokenizer_name: "under-tree/transformer-en-ru"
output_dir: ${root}/models/${.name}/finetuned
output_dir: ${root}/models/${.name}/finetuned
type: seq2seq
3 changes: 3 additions & 0 deletions src/prod/features/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .attention_extractor import BaseAttentionExtractor, choose_attention_extractor
from .synonym_searcher import BaseSynonymSearcher, choose_synonym_searcher
from .translator import BaseTranslator, choose_translator
59 changes: 59 additions & 0 deletions src/prod/features/attention_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from abc import ABC, abstractmethod
import typing as tp
from typing import Any
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from omegaconf import DictConfig
import numpy as np

class BaseAttentionExtractor(ABC):
    """
    Common interface for objects that produce token-to-token attention scores
    """

    def __init__(self, model: AutoModelForSeq2SeqLM):
        """
        Store the model that attention will be read from
        :@param model: model to extract attention from
        """
        super().__init__()
        self.model = model

    @abstractmethod
    def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]:
        """
        Compute, for every token, its attention scores against all tokens
        :@param tokens: list of tokens to extract attention for
        :@return: one list of attention scores per token
        """
        ...

class RandomAttentionExtractor(BaseAttentionExtractor):
    """
    Stand-in extractor that fabricates attention weights; for testing only
    """

    def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]:
        """
        Build a random square attention matrix with each row scaled to sum to one
        :@param tokens: list of tokens to extract attention for
        :@return: len(tokens) lists of len(tokens) random scores each
        """
        size = len(tokens)
        raw = np.random.rand(size, size)
        # Divide every row by its own total so the row reads as a distribution.
        row_totals = raw.sum(axis=1, keepdims=True)
        return (raw / row_totals).tolist()

class Seq2SeqAttentionExtractor(BaseAttentionExtractor):
    """
    Placeholder for an extractor that reads attention from a seq2seq model.
    Not implemented yet: the abstract `__call__` of the base class is not overridden here.
    """
    pass


def choose_attention_extractor(cfg: DictConfig) -> BaseAttentionExtractor:
    """
    Choose attention extractor based on config
    :@param cfg: attention-extractor config; its `type` field selects the
        implementation ("random" or "seq2seq")
    :@return: attention extractor
    :@raise ValueError: if `cfg.type` is not a known extractor type
    """
    extractor_type = cfg.type
    if extractor_type == "random":
        # The random extractor never reads its model, so none is supplied.
        return RandomAttentionExtractor(None)
    if extractor_type == "seq2seq":
        # NOTE(review): the config is passed where the base class expects a
        # model - confirm once Seq2SeqAttentionExtractor is implemented.
        return Seq2SeqAttentionExtractor(cfg)
    # `cfg` is already the extractor section, so the offending value is
    # `cfg.type`; looking up `cfg.attention_extractor.type` here would fail
    # before the intended ValueError could be raised.
    raise ValueError(f"Unknown attention extractor type: {extractor_type}")
55 changes: 55 additions & 0 deletions src/prod/features/synonym_searcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
import typing as tp
import fasttext
from omegaconf import DictConfig

class BaseSynonymSearcher(ABC):
    """
    Common interface for objects that look up synonyms for a list of tokens
    """

    @abstractmethod
    def __call__(
        self,
        tokens: tp.List[str],
        *args,
        return_scores: bool = False,
        n_synonyms: int = 5,
        **kwargs,
    ) -> tp.List[tp.List[str]] | tp.List[tp.List[tp.Tuple[str, float]]]:
        """
        Look up synonyms for every token in the list
        :@param tokens: list of tokens to find synonyms for
        :@param return_scores: whether to return scores for synonyms
        :@param n_synonyms: number of synonyms to return for each token
        :@return: per-token list of synonyms, or of synonyms with scores
        """
        ...


class FastTextSynonymSearcher(BaseSynonymSearcher):
    """
    Synonym searcher backed by nearest-neighbour lookup in a fastText model
    """

    def __init__(self, model_path: str):
        """
        :@param model_path: path to fasttext model
        """
        super().__init__()
        self.model = fasttext.load_model(model_path)

    def __call__(
        self,
        tokens: tp.List[str],
        *args,
        return_scores: bool = False,
        n_synonyms: int = 5,
        **kwargs,
    ) -> tp.List[tp.List[str]] | tp.List[tp.List[tp.Tuple[str, float]]]:
        """
        Search for synonyms for each token in the list
        :@param tokens: list of tokens to find synonyms for
        :@param return_scores: whether to return scores for synonyms
        :@param n_synonyms: number of synonyms to return for each token
        :@return: for each token, a list of synonyms, or a list of
            (synonym, score) tuples when `return_scores` is set
        """
        synonyms = []
        for token in tokens:
            # The model yields (score, word) pairs; query once per token.
            neighbors = self.model.get_nearest_neighbors(token, k=n_synonyms)
            if return_scores:
                # Flip each pair to (word, score) so the result matches the
                # declared Tuple[str, float] element type.
                synonyms.append([(word, score) for score, word in neighbors])
            else:
                synonyms.append([word for _, word in neighbors])
        return synonyms

def choose_synonym_searcher(cfg: DictConfig) -> BaseSynonymSearcher:
    """
    Build the synonym searcher selected by the config
    :@param cfg: config to choose synonym searcher from; `type` picks the
        implementation and `path` is handed to the fasttext searcher
    :@return: synonym searcher
    """
    # Only "fasttext" is supported; reject anything else up front.
    if cfg.type != "fasttext":
        raise ValueError(f"Unknown synonym searcher type: {cfg.type}")
    return FastTextSynonymSearcher(cfg.path)
93 changes: 93 additions & 0 deletions src/prod/features/translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
import typing as tp
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from omegaconf import DictConfig

# One translation result: field name -> a string or a list of strings.
translation_dict = tp.Dict[str, tp.Union[str, tp.List[str]]]

class BaseTranslator(ABC):
    """
    Common interface for translators built around a model and a tokenizer
    """

    def __init__(self, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer):
        """
        Keep references to the model and tokenizer used for translation
        :@param model: model to translate with
        :@param tokenizer: tokenizer to use
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer

    @abstractmethod
    def call_one(self, text: str, *args, **kwargs) -> translation_dict:
        """
        Translate a single text
        :@param text: text to translate
        :@return: dict with translation, input and output tokens
        {
            'translation': str,
            'input_tokens': list[str]
            'output_tokens': list[str]
        }
        """
        ...

    def __call__(self, texts: str | tp.List[str], *args, **kwargs) -> translation_dict | tp.List[translation_dict]:
        """
        Translate one text or a list of texts via `call_one`
        :@param texts: text or list of texts to translate
        :@return: a single dict for a string input, otherwise a list with one
            dict per text, each holding translation, input and output tokens
        {
            'translation': str,
            'input_tokens': list[str]
            'output_tokens': list[str]
        }
        """
        # A bare string is one text; any other input is iterated item by item.
        if not isinstance(texts, str):
            return [self.call_one(item, *args, **kwargs) for item in texts]
        return self.call_one(texts, *args, **kwargs)


class Seq2SeqTranslator(BaseTranslator):
    """
    Translator that loads a pretrained seq2seq model and its tokenizer by name
    """

    def __init__(self, cfg: DictConfig):
        """
        Load model and tokenizer from the name given in the config
        :@param cfg: config with model and tokenizer paths; an optional
            `gen_params` entry is forwarded to generation
        """
        checkpoint = cfg.model_and_tokenizer_name
        super().__init__(
            AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
            AutoTokenizer.from_pretrained(checkpoint),
        )
        self.gen_params = cfg.get("gen_params", {})

    def get_tokens(self, text: str, is_target: bool) -> tp.List[str]:
        """
        Split text into tokens with the translator's tokenizer
        :@param text: text to tokenize
        :@param is_target: whether text is target or source
        :@return: list of tokens
        """
        if not is_target:
            return self.tokenizer.tokenize(text)
        # Target-side text is tokenized inside the tokenizer's target mode.
        with self.tokenizer.as_target_tokenizer():
            return self.tokenizer.tokenize(text)

    def call_one(self, text: str, *args, **kwargs) -> translation_dict:
        """
        Translate one text and tokenize both the source and the translation
        :@param text: text to translate
        :@return: dict with 'translation', 'input_tokens' and 'output_tokens'
        """
        source_ids = self.tokenizer.encode(text, return_tensors="pt")
        generated = self.model.generate(source_ids, **self.gen_params)
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        # Only the first decoded sequence is used.
        translated_text = decoded[0]

        return {
            "translation": translated_text,
            "input_tokens": self.get_tokens(text, is_target=False),
            "output_tokens": self.get_tokens(translated_text, is_target=True),
        }


def choose_translator(cfg: DictConfig) -> BaseTranslator:
    """
    Build the translator selected by the config
    :@param cfg: translator config; `type` picks the implementation
    :@return: translator
    """
    # Only "seq2seq" is supported; reject anything else up front.
    if cfg.type != "seq2seq":
        raise ValueError(f"Unknown translator type: {cfg.type}")
    return Seq2SeqTranslator(cfg)
70 changes: 70 additions & 0 deletions src/prod/translator_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import sys

# Append this file's own directory to the import search path.
# NOTE(review): presumably so the `features` import below resolves regardless
# of the working directory the app is launched from - confirm.
cur_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(cur_dir)

from features import BaseAttentionExtractor, BaseSynonymSearcher, BaseTranslator
from features import choose_attention_extractor, choose_synonym_searcher, choose_translator
from omegaconf import DictConfig
import typing as tp
from nltk.tokenize import word_tokenize

# Result of translating one text: string keys mapped to values of mixed types.
translation_type = tp.Dict[str, tp.Any]

class TranslatorApp:
    """
    Pipeline that translates text and adds word lists, synonyms and attention scores
    """

    def __init__(self, cfg: DictConfig):
        """
        Build every pipeline component from its own section of the config
        :@param cfg: config with `translator`, `src_synonym_searcher`,
            `dest_synonym_searcher` and `attention_extractor` sections
        """
        self.cfg = cfg
        self.translator = choose_translator(cfg.translator)
        self.src_synonym_searcher = choose_synonym_searcher(cfg.src_synonym_searcher)
        self.dest_synonym_searcher = choose_synonym_searcher(cfg.dest_synonym_searcher)
        self.attention_extractor = choose_attention_extractor(cfg.attention_extractor)

    def _get_words(self, text: str) -> tp.List[str]:
        """
        Lowercase the text and split it into words
        :@param text: text to split
        :@return: list of words
        """
        return word_tokenize(text.lower())

    def call_one(self, text: str) -> translation_type:
        """
        Translate one text and collect all derived information
        :@param text: text to translate
        :@return: dict
        {
            'translation': str,
            'input_tokens': list[str]
            'input_words': list[str]
            'output_tokens': list[str]
            'output_words': list[str]
            'src_synonyms': list[list[str]]
            'dest_synonyms': list[list[str]]
            'attention': list[list[float]]
        }
        """
        result = self.translator(text)
        src_tokens = result["input_tokens"]
        dest_tokens = result["output_tokens"]
        src_words = self._get_words(text)
        translated_text = result["translation"]
        dest_words = self._get_words(translated_text)
        # Synonyms are looked up on words; attention is computed on the
        # translator's input tokens.
        return {
            "translation": translated_text,
            "input_tokens": src_tokens,
            "input_words": src_words,
            "output_tokens": dest_tokens,
            "output_words": dest_words,
            "src_synonyms": self.src_synonym_searcher(src_words),
            "dest_synonyms": self.dest_synonym_searcher(dest_words),
            "attention": self.attention_extractor(src_tokens),
        }

    def __call__(self, texts: str | tp.List[str]) -> translation_type | tp.List[translation_type]:
        """
        Translate one text or a list of texts via `call_one`
        :@param texts: text or list of texts to translate
        :@return: a single dict for a string input, otherwise a list of dicts
        """
        if not isinstance(texts, str):
            return [self.call_one(item) for item in texts]
        return self.call_one(texts)

0 comments on commit 5f846a0

Please sign in to comment.