Commit 5f846a0 (1 parent: f6d16df). Showing 7 changed files with 287 additions and 1 deletion.
@@ -1,3 +1,4 @@
name: opus-distilled-en-ru
model_and_tokenizer_name: "under-tree/transformer-en-ru"
output_dir: ${root}/models/${.name}/finetuned
type: seq2seq
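For reference, this config is read with OmegaConf, so ${root} resolves against the config root and ${.name} against a sibling key. A minimal sketch of the interpolation (the "root" value below is my assumption, not part of the commit):

from omegaconf import OmegaConf

# Illustrative only: "root" is assumed to be defined elsewhere in the config tree.
cfg = OmegaConf.create({
    "root": "/data",
    "name": "opus-distilled-en-ru",
    "model_and_tokenizer_name": "under-tree/transformer-en-ru",
    "output_dir": "${root}/models/${.name}/finetuned",
    "type": "seq2seq",
})
print(cfg.output_dir)  # /data/models/opus-distilled-en-ru/finetuned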
@@ -0,0 +1,3 @@
from .attention_extractor import BaseAttentionExtractor, choose_attention_extractor
from .synonym_searcher import BaseSynonymSearcher, choose_synonym_searcher
from .translator import BaseTranslator, choose_translator
@@ -0,0 +1,59 @@
from abc import ABC, abstractmethod
import typing as tp

import numpy as np
from omegaconf import DictConfig
from transformers import AutoModelForSeq2SeqLM


class BaseAttentionExtractor(ABC):
    def __init__(self, model: AutoModelForSeq2SeqLM):
        """
        :param model: model to extract attention from
        """
        super().__init__()
        self.model = model

    @abstractmethod
    def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]:
        """
        Extract attention scores for each token in relation to all other tokens
        :param tokens: list of tokens to extract attention for
        :return: list of attention scores for each token
        """
        pass


class RandomAttentionExtractor(BaseAttentionExtractor):
    """
    Returns random attention scores; for testing purposes only.
    """
    def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]:
        """
        Extract random attention scores for each token in relation to all other tokens
        :param tokens: list of tokens to extract attention for
        :return: list of attention scores for each token
        """
        n_tokens = len(tokens)
        unnormalized_scores = np.random.rand(n_tokens, n_tokens)
        # Normalize each row so a token's scores over all tokens sum to 1.
        normalized_scores = unnormalized_scores / unnormalized_scores.sum(axis=1, keepdims=True)
        return normalized_scores.tolist()


class Seq2SeqAttentionExtractor(BaseAttentionExtractor):
    # Stub: attention extraction from a real seq2seq model is not implemented yet.
    pass


def choose_attention_extractor(cfg: DictConfig) -> BaseAttentionExtractor:
    """
    Choose attention extractor based on config
    :param cfg: config to choose attention extractor from
    :return: attention extractor
    """
    if cfg.type == "random":
        # The random extractor does not need a model.
        return RandomAttentionExtractor(model=None)
    elif cfg.type == "seq2seq":
        return Seq2SeqAttentionExtractor(cfg)
    else:
        raise ValueError(f"Unknown attention extractor type: {cfg.type}")
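As a quick usage sketch (my example, not part of the commit): the "random" branch needs no model, so the factory can be exercised standalone:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"type": "random"})
extractor = choose_attention_extractor(cfg)

scores = extractor(["hello", "world", "!"])
print(len(scores))     # 3 rows, one per token
print(sum(scores[0]))  # each row sums to ~1.0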
@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
import typing as tp

import fasttext
from omegaconf import DictConfig


class BaseSynonymSearcher(ABC):
    @abstractmethod
    def __call__(self, tokens: tp.List[str], *args, return_scores: bool = False,
                 n_synonyms: int = 5, **kwargs) -> tp.List[tp.List[str]] | tp.List[tp.List[tp.Tuple[str, float]]]:
        """
        Search for synonyms for each token in the list
        :param tokens: list of tokens to find synonyms for
        :param return_scores: whether to return similarity scores along with the synonyms
        :param n_synonyms: number of synonyms to return for each token
        :return: list of synonyms for each token, or list of (synonym, score) pairs
        """
        pass


class FastTextSynonymSearcher(BaseSynonymSearcher):
    def __init__(self, model_path: str):
        """
        :param model_path: path to fasttext model
        """
        super().__init__()
        self.model = fasttext.load_model(model_path)

    def __call__(self, tokens: tp.List[str], *args, return_scores: bool = False,
                 n_synonyms: int = 5, **kwargs) -> tp.List[tp.List[str]] | tp.List[tp.List[tp.Tuple[str, float]]]:
        """
        Search for synonyms for each token in the list
        :param tokens: list of tokens to find synonyms for
        :param return_scores: whether to return similarity scores along with the synonyms
        :param n_synonyms: number of synonyms to return for each token
        :return: list of synonyms for each token, or list of (synonym, score) pairs
        """
        synonyms = []
        for token in tokens:
            # fasttext yields (score, word) pairs; reorder to (word, score) to match the annotation.
            neighbors = self.model.get_nearest_neighbors(token, k=n_synonyms)
            if return_scores:
                synonyms.append([(word, score) for score, word in neighbors])
            else:
                synonyms.append([word for _, word in neighbors])
        return synonyms


def choose_synonym_searcher(cfg: DictConfig) -> BaseSynonymSearcher:
    """
    Choose synonym searcher based on config
    :param cfg: config to choose synonym searcher from
    :return: synonym searcher
    """
    if cfg.type == "fasttext":
        return FastTextSynonymSearcher(cfg.path)
    else:
        raise ValueError(f"Unknown synonym searcher type: {cfg.type}")
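Usage would look like the sketch below; the model path is a placeholder, and a pretrained fastText binary (for example one of the cc.*.300.bin models from fasttext.cc) must already exist there:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"type": "fasttext", "path": "models/cc.en.300.bin"})  # placeholder path
searcher = choose_synonym_searcher(cfg)

print(searcher(["car"], n_synonyms=3))
# illustrative output: [['cars', 'vehicle', 'automobile']]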
@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
import typing as tp

from omegaconf import DictConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Type alias for the dict a translator returns.
translation_dict = tp.Dict[str, str | tp.List[str]]


class BaseTranslator(ABC):
    def __init__(self, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer):
        """
        :param model: model to translate with
        :param tokenizer: tokenizer to use
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer

    @abstractmethod
    def call_one(self, text: str, *args, **kwargs) -> translation_dict:
        """
        Translate one text
        :param text: text to translate
        :return: dict with translation, input and output tokens
            {
                'translation': str,
                'input_tokens': list[str],
                'output_tokens': list[str]
            }
        """
        pass

    def __call__(self, texts: str | tp.List[str], *args, **kwargs) -> translation_dict | tp.List[translation_dict]:
        """
        Translate texts
        :param texts: text or list of texts to translate
        :return: dict or list of dicts with translation, input and output tokens for each text
            {
                'translation': str,
                'input_tokens': list[str],
                'output_tokens': list[str]
            }
        """
        if isinstance(texts, str):
            return self.call_one(texts, *args, **kwargs)
        return [self.call_one(text, *args, **kwargs) for text in texts]


class Seq2SeqTranslator(BaseTranslator):
    def __init__(self, cfg: DictConfig):
        """
        :param cfg: config with model and tokenizer paths
        """
        name = cfg.model_and_tokenizer_name
        model = AutoModelForSeq2SeqLM.from_pretrained(name)
        tokenizer = AutoTokenizer.from_pretrained(name)
        super().__init__(model, tokenizer)
        self.gen_params = cfg.get("gen_params", {})

    def get_tokens(self, text: str, is_target: bool) -> tp.List[str]:
        """
        Tokenize text with tokenizer
        :param text: text to tokenize
        :param is_target: whether text is target or source
        :return: list of tokens
        """
        if is_target:
            # Target-side text must be tokenized with the target-language tokenizer.
            with self.tokenizer.as_target_tokenizer():
                return self.tokenizer.tokenize(text)
        return self.tokenizer.tokenize(text)

    def call_one(self, text: str, *args, **kwargs) -> translation_dict:
        input_idx = self.tokenizer.encode(text, return_tensors="pt")
        translation = self.model.generate(input_idx, **self.gen_params)
        translation = self.tokenizer.batch_decode(translation, skip_special_tokens=True)[0]

        input_tokens = self.get_tokens(text, is_target=False)
        output_tokens = self.get_tokens(translation, is_target=True)

        return {
            "translation": translation,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens
        }


def choose_translator(cfg: DictConfig) -> BaseTranslator:
    """
    Choose translator based on config
    :param cfg: config to choose translator from
    :return: translator
    """
    if cfg.type == "seq2seq":
        return Seq2SeqTranslator(cfg)
    else:
        raise ValueError(f"Unknown translator type: {cfg.type}")
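A usage sketch for the factory, reusing the model name from the config file in this commit (the model downloads from the Hugging Face hub on first use):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "type": "seq2seq",
    "model_and_tokenizer_name": "under-tree/transformer-en-ru",
})
translator = choose_translator(cfg)

result = translator("Hello, world!")
print(result["translation"])   # Russian translation
print(result["input_tokens"])  # source-side subword tokens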
@@ -0,0 +1,70 @@
import os
import sys

# Make the sibling `features` package importable when this file is run directly.
cur_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(cur_dir)

import typing as tp

from nltk.tokenize import word_tokenize
from omegaconf import DictConfig

from features import BaseAttentionExtractor, BaseSynonymSearcher, BaseTranslator
from features import choose_attention_extractor, choose_synonym_searcher, choose_translator

translation_type = tp.Dict[str, tp.Any]


class TranslatorApp:
    def __init__(self, cfg: DictConfig):
        self.cfg = cfg
        self.translator = choose_translator(cfg.translator)
        self.src_synonym_searcher = choose_synonym_searcher(cfg.src_synonym_searcher)
        self.dest_synonym_searcher = choose_synonym_searcher(cfg.dest_synonym_searcher)
        self.attention_extractor = choose_attention_extractor(cfg.attention_extractor)

    def _get_words(self, text: str) -> tp.List[str]:
        """
        Split text into words
        :param text: text to split
        :return: list of words
        """
        text = text.lower()
        return word_tokenize(text)

    def call_one(self, text: str) -> translation_type:
        """
        Return translation for one text
        :param text: text to translate
        :return: dict
            {
                'translation': str,
                'input_tokens': list[str],
                'input_words': list[str],
                'output_tokens': list[str],
                'output_words': list[str],
                'src_synonyms': list[list[str]],
                'dest_synonyms': list[list[str]],
                'attention': list[list[float]]
            }
        """
        translation = self.translator(text)
        input_tokens = translation["input_tokens"]
        output_tokens = translation["output_tokens"]
        input_words = self._get_words(text)
        output_words = self._get_words(translation["translation"])
        src_synonyms = self.src_synonym_searcher(input_words)
        dest_synonyms = self.dest_synonym_searcher(output_words)
        attention = self.attention_extractor(input_tokens)
        return {
            "translation": translation["translation"],
            "input_tokens": input_tokens,
            "input_words": input_words,
            "output_tokens": output_tokens,
            "output_words": output_words,
            "src_synonyms": src_synonyms,
            "dest_synonyms": dest_synonyms,
            "attention": attention
        }

    def __call__(self, texts: str | tp.List[str]) -> translation_type | tp.List[translation_type]:
        if isinstance(texts, str):
            return self.call_one(texts)
        return [self.call_one(text) for text in texts]
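Putting the pieces together, TranslatorApp expects a config with translator, src_synonym_searcher, dest_synonym_searcher, and attention_extractor nodes. A minimal sketch, with placeholder fastText paths and the test-only random attention extractor; note that word_tokenize needs the NLTK punkt data:

import nltk
from omegaconf import OmegaConf

nltk.download("punkt")  # required by word_tokenize

cfg = OmegaConf.create({
    "translator": {
        "type": "seq2seq",
        "model_and_tokenizer_name": "under-tree/transformer-en-ru",
    },
    "src_synonym_searcher": {"type": "fasttext", "path": "models/cc.en.300.bin"},   # placeholder
    "dest_synonym_searcher": {"type": "fasttext", "path": "models/cc.ru.300.bin"},  # placeholder
    "attention_extractor": {"type": "random"},
})

app = TranslatorApp(cfg)
result = app("I love programming")
print(result["translation"])
print(result["src_synonyms"][0][:3])  # top synonyms for the first source word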