diff --git a/.dvc/config b/.dvc/config
index e4688e6..4434bab 100644
--- a/.dvc/config
+++ b/.dvc/config
@@ -2,4 +2,4 @@
 autostage = true
 remote = storage
 ['remote "storage"']
-    url = gdrive://1XzdLLDSWCRT57Kj9ZqYlfTk_0phVu6Fz
+    url = gdrive://19-PaarPhbUW27F4XXpLXS1SBu0Dvzch1
diff --git a/.gitignore b/.gitignore
index 0c96245..50ec321 100644
--- a/.gitignore
+++ b/.gitignore
@@ -88,3 +88,4 @@ target/
 # Mypy cache
 .mypy_cache/
 /data
+/models
diff --git a/README.md b/README.md
index c735243..6bc970d 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@ Project Organization
 ├── LICENSE
 ├── Makefile           <- Makefile with commands like `make data` or `make train`
 ├── README.md          <- The top-level README for developers using this project.
+├── conf               <- Configuration files
 ├── data
 │   ├── external       <- Data from third party sources.
 │   ├── interim        <- Intermediate data that has been transformed.
@@ -54,4 +55,74 @@ Project Organization
 --------
 
+## Installation
+
+```bash
+pip3 install -r requirements.txt
+```
+
+## Dataset
+
+#### Training dataset
+The training dataset is based on `saier/unarxive_citrec` [hf](https://huggingface.co/datasets/saier/unarxive_citrec).
+
+*Details*:
+```yaml
+Train size: 9082
+Valid size: 702
+Test size: 568
+```
+
+All samples are between `128` and `512` characters long (TO-DO: characters -> tokens)\
+More in `notebooks/data/dataset_download.ipynb`
+
+After collecting the dataset, we translated the samples from English to Russian using the OpenAI API.\
+Details in `notebooks/data/dataset_translate.ipynb`
+
+#### Dataset for model comparison (EvalDataset)
+This dataset is based on `turkic_xwmt`, `subset=ru-en`, `split=test` [hf](https://huggingface.co/datasets/turkic_xwmt).
+
+Dataset size: 1000
+
+## Model comparison
+
+Model comparison is based on the BLEU score between each model's translations and the reference translations produced with the OpenAI API.
+
+**Models**:\
+transformer-en-ru: `Helsinki-NLP/opus-mt-en-ru` [hf](https://huggingface.co/Helsinki-NLP/opus-mt-en-ru)\
+nnlb-1.3B-distilled: `facebook/nllb-200-distilled-1.3B` [hf](https://huggingface.co/facebook/nllb-200-distilled-1.3B)
+
+**Results**:
+```yaml
+transformer-en-ru BLEU: 2.58
+nnlb-1.3B-distilled BLEU: 2.55
+```
+
+Even though the difference is not statistically significant, the transformer-en-ru model was chosen because it is faster and smaller.\
+Details in `src/finetune/eval_bleu.py`
+
+## Model finetuning
+
+Simple seq2seq fine-tuning of the transformer-en-ru model.\
+Details in `notebooks/finetune/finetune.ipynb`.\
+Model on [hf](https://huggingface.co/under-tree/transformer-en-ru)
+
+**Fine-tuned model results**:
+```yaml
+eval_loss: 0.656
+eval_bleu: 67.197
+```
+(BLEU is suspiciously high)
+
+## Translation App
+
+**Synonyms Searcher**\
+The simple version is based on a `word2vec`-style embedding model, namely `fasttext` ([link](https://fasttext.cc/docs/en/crawl-vectors.html)). We chose fasttext because it handles out-of-vocabulary words.
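+
+A minimal sketch of the nearest-neighbour lookup the synonyms searcher performs (assuming the reduced 100-dim `cc.en.100.bin` vectors are downloaded to `models/embs/`, matching `conf/model/fasttext_en.yaml`):
+
+```python
+import fasttext
+
+# reduced 100-dim English crawl vectors from fasttext.cc
+model = fasttext.load_model("models/embs/cc.en.100.bin")
+
+# (similarity, word) pairs; subword n-grams give sensible
+# neighbours even for out-of-vocabulary tokens
+print(model.get_nearest_neighbors("translator", k=5))
+```
+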

Project based on the cookiecutter data science project template. #cookiecutterdatascience

diff --git a/conf/.gitignore b/conf/.gitignore new file mode 100644 index 0000000..5b6b072 --- /dev/null +++ b/conf/.gitignore @@ -0,0 +1 @@ +config.yaml diff --git a/conf/config.yaml b/conf/config.yaml new file mode 100644 index 0000000..acfa789 --- /dev/null +++ b/conf/config.yaml @@ -0,0 +1,9 @@ +defaults: + - _self_ + - dataset: null + - model: null + - params: null + - setup: null + + +root: /Users/user010/Desktop/Programming/ML/En2RuTranslator diff --git a/conf/dataset/model_eval.yaml b/conf/dataset/model_eval.yaml new file mode 100644 index 0000000..567dd1c --- /dev/null +++ b/conf/dataset/model_eval.yaml @@ -0,0 +1,7 @@ +path: ${root}/data/processed/model_eval_results.csv + +cols: # cols to be used when calculating BLEU + reference: target + candidates: + - transformer-en-ru + - nnlb-1.3B-distilled \ No newline at end of file diff --git a/conf/dataset/model_eval_raw.yaml b/conf/dataset/model_eval_raw.yaml new file mode 100644 index 0000000..3caec78 --- /dev/null +++ b/conf/dataset/model_eval_raw.yaml @@ -0,0 +1 @@ +path: ${root}/data/processed/model_eval.csv \ No newline at end of file diff --git a/conf/dataset/unarxive.yaml b/conf/dataset/unarxive.yaml new file mode 100644 index 0000000..ddd2bb3 --- /dev/null +++ b/conf/dataset/unarxive.yaml @@ -0,0 +1 @@ +path: "waleko/unarXive-en2ru" \ No newline at end of file diff --git a/conf/model/fasttext_en.yaml b/conf/model/fasttext_en.yaml new file mode 100644 index 0000000..adc6df4 --- /dev/null +++ b/conf/model/fasttext_en.yaml @@ -0,0 +1,2 @@ +path: ${root}/models/embs/cc.en.100.bin +type: fasttext \ No newline at end of file diff --git a/conf/model/fasttext_ru.yaml b/conf/model/fasttext_ru.yaml new file mode 100644 index 0000000..2c816f9 --- /dev/null +++ b/conf/model/fasttext_ru.yaml @@ -0,0 +1,2 @@ +path: ${root}/models/embs/cc.ru.100.bin +type: fasttext \ No newline at end of file diff --git a/conf/model/nnlb_1.3B.yaml b/conf/model/nnlb_1.3B.yaml new file mode 100644 index 0000000..91c8c6d --- /dev/null +++ b/conf/model/nnlb_1.3B.yaml @@ -0,0 +1,2 @@ +name: nnlb-1.3B-distilled +model_and_tokenizer_name: facebook/nllb-200-distilled-1.3B \ No newline at end of file diff --git a/conf/model/opus_distilled_en_ru.yaml b/conf/model/opus_distilled_en_ru.yaml new file mode 100644 index 0000000..8a6e2a8 --- /dev/null +++ b/conf/model/opus_distilled_en_ru.yaml @@ -0,0 +1,4 @@ +name: opus-distilled-en-ru +model_and_tokenizer_name: "under-tree/transformer-en-ru" +output_dir: ${root}/models/${.name}/finetuned +type: seq2seq \ No newline at end of file diff --git a/conf/model/opus_en_ru.yaml b/conf/model/opus_en_ru.yaml new file mode 100644 index 0000000..5f6cdea --- /dev/null +++ b/conf/model/opus_en_ru.yaml @@ -0,0 +1,2 @@ +name: opus-en-ru +model_and_tokenizer_name: Helsinki-NLP/opus-mt-en-ru diff --git a/conf/model/random_attention_extractor.yaml b/conf/model/random_attention_extractor.yaml new file mode 100644 index 0000000..010e927 --- /dev/null +++ b/conf/model/random_attention_extractor.yaml @@ -0,0 +1 @@ +type: "random" \ No newline at end of file diff --git a/conf/notebooks/finetune/candidates_inference.yaml b/conf/notebooks/finetune/candidates_inference.yaml new file mode 100644 index 0000000..1ba38b7 --- /dev/null +++ b/conf/notebooks/finetune/candidates_inference.yaml @@ -0,0 +1,13 @@ +root: ??? 
+ +nnlb_model: + name: nnlb-1.3B-distilled + model_and_tokenizer_name: facebook/nllb-200-distilled-1.3B + +mt_model: + name: transformer-en-ru + model_and_tokenizer_name: Helsinki-NLP/opus-mt-en-ru + +inference_dataset_path: ${root}/data/processed/model_eval.csv +results_path: ${root}/data/processed/model_eval_results.csv + \ No newline at end of file diff --git a/models/.gitkeep b/conf/notebooks/finetune/finetune.yaml similarity index 100% rename from models/.gitkeep rename to conf/notebooks/finetune/finetune.yaml diff --git a/conf/notebooks/finetune/model_eval.yaml b/conf/notebooks/finetune/model_eval.yaml new file mode 100644 index 0000000..4157507 --- /dev/null +++ b/conf/notebooks/finetune/model_eval.yaml @@ -0,0 +1,7 @@ +root: ??? +load_dataset_params: + path: 'turkic_xwmt' + name: 'ru-en' + split: 'test' +save_path: '${root}/data/processed/model_eval.csv' + diff --git a/conf/params/finetune.yaml b/conf/params/finetune.yaml new file mode 100644 index 0000000..ac3e392 --- /dev/null +++ b/conf/params/finetune.yaml @@ -0,0 +1,15 @@ +batch_size: 16 +max_length: 512 +train_args: + evaluation_strategy: epoch + learning_rate: 2e-5 + per_device_train_batch_size: ${..batch_size} + per_device_eval_batch_size: ${..batch_size} + weight_decay: 0.01 + save_total_limit: 3 + num_train_epochs: 4 + predict_with_generate: true + +wandb_args: + report_to: wandb + run_name: finetune \ No newline at end of file diff --git a/conf/setup/all_models_example.yaml b/conf/setup/all_models_example.yaml new file mode 100644 index 0000000..4b72b91 --- /dev/null +++ b/conf/setup/all_models_example.yaml @@ -0,0 +1,6 @@ +# @package _global_ + +defaults: + - /model@model1: opus_en_ru + - /model@model2: opus_distilled_en_ru + - override /dataset: unarxive \ No newline at end of file diff --git a/conf/setup/finetune.yaml b/conf/setup/finetune.yaml new file mode 100644 index 0000000..9f69d00 --- /dev/null +++ b/conf/setup/finetune.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +defaults: + - /model@pretrained: opus_en_ru + - /model@finetuned: opus_distilled_en_ru + - override /dataset: unarxive + - override /params: finetune \ No newline at end of file diff --git a/conf/setup/inference.yaml b/conf/setup/inference.yaml new file mode 100644 index 0000000..627fd2f --- /dev/null +++ b/conf/setup/inference.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +defaults: + - /model@opus_model: opus_en_ru + - /model@nnlb_model: nnlb_1.3B + - /dataset@inference_dataset: model_eval_raw + - /dataset@result_dataset: model_eval \ No newline at end of file diff --git a/conf/setup/prod.yaml b/conf/setup/prod.yaml new file mode 100644 index 0000000..988c68b --- /dev/null +++ b/conf/setup/prod.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /model@dest_synonym_searcher: fasttext_ru + - /model@src_synonym_searcher: fasttext_en + - /model@translator: opus_distilled_en_ru + - /model@attention_extractor: random_attention_extractor + - override /params: finetune \ No newline at end of file diff --git a/custom_utils/__init__.py b/custom_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom_utils/config_handler.py b/custom_utils/config_handler.py new file mode 100644 index 0000000..6e01d21 --- /dev/null +++ b/custom_utils/config_handler.py @@ -0,0 +1,25 @@ +from omegaconf import OmegaConf +import json +import typing as tp +import os +from hydra import initialize_config_dir, compose + +__ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +__CONFIG_DIR = os.path.join(__ROOT_DIR, "conf") + +def 
read_config(config_dir: str = __CONFIG_DIR, overrides: tp.Optional[tp.List[str]] = None) -> OmegaConf:
+    """
+    :@param config_dir: path to config directory
+    :@param overrides: list of overrides (e.g. ["dataset=model_eval"])
+    :@return: OmegaConf object with all interpolations resolved
+    """
+    config_dir = os.path.abspath(config_dir)
+    with initialize_config_dir(config_dir=config_dir, version_base=None):
+        cfg = compose(config_name="config", overrides=overrides)
+        cfg = OmegaConf.create(OmegaConf.to_yaml(cfg, resolve=True))
+    return cfg
+
+def pprint_config(cfg: OmegaConf) -> None:
+    "Pretty print config"
+    print(json.dumps(OmegaConf.to_container(cfg), indent=2))
diff --git a/data.dvc b/data.dvc
index 1c30a79..e52ba8a 100644
--- a/data.dvc
+++ b/data.dvc
@@ -1,6 +1,6 @@
 outs:
-- md5: 255799c6a8913d73679631d546a9dd88.dir
-  nfiles: 13
+- md5: a04b7051c1e5067e29c68137c321dae1.dir
+  nfiles: 15
   hash: md5
   path: data
-  size: 19936558
+  size: 21297660
diff --git a/models.dvc b/models.dvc
new file mode 100644
index 0000000..cee5780
--- /dev/null
+++ b/models.dvc
@@ -0,0 +1,6 @@
+outs:
+- md5: c32a6fc3220dba5ae7628692d397c852.dir
+  size: 4892930090
+  nfiles: 2
+  hash: md5
+  path: models
diff --git a/notebooks/data/dataset_generation.ipynb b/notebooks/data/dataset_generation.ipynb
deleted file mode 100644
index 5b1124b..0000000
--- a/notebooks/data/dataset_generation.ipynb
+++ /dev/null
@@ -1,136 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(False, False)"
-      ]
-     },
-     "execution_count": 3,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "PROMPT = \"\"\"\\\n",
-    "Ты профессиональный тестировщик больших языковых моделей.\n",
-    "Сейчас твоя задача составить запросы, которые требуют от модели **сгенерировать изображение** (картину или фото).\n",
-    "Эти запросы должны использовать **как явные инструкции, так и намёки**. 
Запросы должны быть **разнообразными** и иметь **разный уровень формальности**.\n", - "\n", - "Сгенирируй мне 10 таких запросов.\n", - "\n", - "Примеры:\n", - "Нарисуй, пожалуйста, фотоаппарат марки «Зенит» с красивым плетёным ремешком.\n", - "а можешь плиз нарисовать как мальчик и девочка на пляже строят замок из песка?\n", - "Изобрази мне кота Матроскина, который играет на гитаре.\n", - "фото как спичка горит, а кругом тают кубики льда\n", - "сделай мне иллюстрацию к маленькому принцу где он с розой разговаривает\n", - "Сделаешь картинку площади трех вокзалов в Москве?\n", - "хочу картинку с аниме девочкой\n", - "покажи мне портрет Иосифа Сталина\n", - "\n", - "Твои запросы:\n", - "\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip3 install openai python-dotenv" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from dotenv import load_dotenv\n", - "import openai\n", - "import time\n", - "import numpy as np\n", - "import os\n", - "path_to_env = os.path.join('..', '.env')\n", - "load_dotenv()\n", - "\n", - "\n", - "openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n", - "\n", - "class QuestionGenerator:\n", - " def __init__(self, query: str, max_queries: int = 3):\n", - " self.query = query\n", - " self.max_queries = max_queries\n", - " \n", - " def send_query(self):\n", - " response = None\n", - " for _ in range(self.max_queries):\n", - " try:\n", - " response = openai.Completion.create(\n", - " model=\"text-babbage-001\",\n", - " prompt=self.query,\n", - " temperature=0.7,\n", - " max_tokens=100,\n", - " top_p=0.6,\n", - " frequency_penalty=0.5,\n", - " presence_penalty=0.0\n", - " )\n", - " # random sleep seconds \n", - " time.sleep(np.random.randint(1, 5))\n", - " break\n", - " except Exception as e:\n", - " print('Error', e)\n", - " \n", - " return response\n", - " \n", - " def parse_response(self, response):\n", - " if response is None:\n", - " return []\n", - " return response['choices'][0]['text'].strip().lower().split(', ')\n", - " \n", - " def __call__(self):\n", - " response = self.send_query()\n", - " samples = self.get_topics(response)\n", - " return samples" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "qg = QuestionGenerator(PROMPT)\n", - "qg()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/finetune/candidates_inference.ipynb b/notebooks/finetune/candidates_inference.ipynb new file mode 100644 index 0000000..31e6112 --- /dev/null +++ b/notebooks/finetune/candidates_inference.ipynb @@ -0,0 +1,1584 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. 
To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/Users/user010/Desktop/Programming/ML/En2RuTranslator\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "\n", + "root_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))\n", + "print(root_dir)\n", + "assert os.path.exists(root_dir), f'Could not find root directory at {root_dir}'\n", + "sys.path.insert(0, root_dir)\n", + "\n", + "from custom_utils.config_handler import read_config, pprint_config" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"root\": \"/Users/user010/Desktop/Programming/ML/En2RuTranslator\",\n", + " \"opus_model\": {\n", + " \"name\": \"opus-en-ru\",\n", + " \"model_and_tokenizer_name\": \"Helsinki-NLP/opus-mt-en-ru\"\n", + " },\n", + " \"nnlb_model\": {\n", + " \"name\": \"nnlb-1.3B-distilled\",\n", + " \"model_and_tokenizer_name\": \"facebook/nllb-200-distilled-1.3B\"\n", + " },\n", + " \"inference_dataset\": {\n", + " \"path\": \"/Users/user010/Desktop/Programming/ML/En2RuTranslator/data/processed/model_eval.csv\"\n", + " },\n", + " \"result_dataset\": {\n", + " \"path\": \"/Users/user010/Desktop/Programming/ML/En2RuTranslator/data/processed/model_eval_results.csv\",\n", + " \"cols\": {\n", + " \"reference\": \"target\",\n", + " \"candidates\": [\n", + " \"transformer-en-ru\",\n", + " \"nnlb-1.3B-distilled\"\n", + " ]\n", + " }\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "overrides = [\"setup=inference\"]\n", + "cfg = read_config(overrides=overrides)\n", + "pprint_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n", + "sample_texts = [\"Hey, how are you?\", \"My name is John Smith, I live in the United States of America.\",\n", + " \"I love NLP and Transformers!\"]\n", + "\n", + "def get_translations(model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer, sample_texts: list,\n", + " special_gen_params: dict = None) -> list:\n", + " special_gen_params = special_gen_params or {}\n", + " print(\"Tokenizing...\")\n", + " inputs = tokenizer(sample_texts, return_tensors=\"pt\", padding=True, truncation=True, max_length=600)\n", + " print(\"Generating...\")\n", + " translated_tokens = model.generate(\n", + " **inputs,\n", + " **special_gen_params,\n", + " max_length=600,\n", + " early_stopping=True\n", + " )\n", + " print(\"Decoding...\")\n", + " translated_texts = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)\n", + " return translated_texts\n", + "\n", + "def print_translations(source: list[str], target: list[str]):\n", + " assert len(source) == len(target), \"Source and target lists must be of same length\"\n", + " for src, tgt in zip(source, target):\n", + " print(f\"Source: {src}\")\n", + " print(f\"Target: {tgt}\")\n", + " print()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "model_and_tokenizer_name = cfg.nnlb_model.model_and_tokenizer_name\n", + "nnlb_tokenizer = AutoTokenizer.from_pretrained(model_and_tokenizer_name)\n", + "nnlb_model = AutoModelForSeq2SeqLM.from_pretrained(model_and_tokenizer_name)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/user010/Desktop/Programming/ML/En2RuTranslator/venv/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:399: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Source: Hey, how are you?\n", + "Target: Привет, как дела?\n", + "\n", + "Source: My name is John Smith, I live in the United States of America.\n", + "Target: Меня зовут Джон Смит, я живу в Соединенных Штатах Америки.\n", + "\n", + "Source: I love NLP and Transformers!\n", + "Target: Я люблю НЛП и Трансформеров!\n", + "\n" + ] + } + ], + "source": [ + "translations = get_translations(nnlb_model, nnlb_tokenizer, sample_texts, \n", + " special_gen_params={\"forced_bos_token_id\": nnlb_tokenizer.lang_code_to_id[\"rus_Cyrl\"]})\n", + "print_translations(sample_texts, translations)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/user010/Desktop/Programming/ML/En2RuTranslator/venv/lib/python3.11/site-packages/transformers/models/marian/tokenization_marian.py:197: UserWarning: Recommended: pip install sacremoses.\n", + " warnings.warn(\"Recommended: pip install sacremoses.\")\n" + ] + } + ], + "source": [ + "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n", + "\n", + "model_and_tokenizer_name = cfg.opus_model.model_and_tokenizer_name\n", + "opus_tokenizer = AutoTokenizer.from_pretrained(model_and_tokenizer_name)\n", + "opus_model = AutoModelForSeq2SeqLM.from_pretrained(model_and_tokenizer_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenizing...\n", + "Generating...\n", + "Decoding...\n", + "Source: Hey, how are you?\n", + "Target: Привет, как дела?\n", + "\n", + "Source: My name is John Smith, I live in the United States of America.\n", + "Target: Меня зовут Джон Смит, я живу в Соединенных Штатах Америки.\n", + "\n", + "Source: I love NLP and Transformers!\n", + "Target: Я люблю NLP и Transformers!\n", + "\n" + ] + } + ], + "source": [ + "translations = get_translations(opus_model, opus_tokenizer, sample_texts)\n", + "print_translations(sample_texts, translations)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sourcetarget
0The intention would also be to infiltrate terr...Террористов также обучают методам проникновени...
1Officials say that as the latest information a...По последним данным представителей власти , в ...
2While the Balakot camp was reactivated by the ...Джаиш-е-Мухаммад возобновили работу террористи...
3The incident in which Pakistan used drones to ...В качестве яркого примера новой стратегии паки...
4Officials tell OneIndia that terror groups wou...Представители власти рассказали порталу OneInd...
\n", + "
" + ], + "text/plain": [ + " source \\\n", + "0 The intention would also be to infiltrate terr... \n", + "1 Officials say that as the latest information a... \n", + "2 While the Balakot camp was reactivated by the ... \n", + "3 The incident in which Pakistan used drones to ... \n", + "4 Officials tell OneIndia that terror groups wou... \n", + "\n", + " target \n", + "0 Террористов также обучают методам проникновени... \n", + "1 По последним данным представителей власти , в ... \n", + "2 Джаиш-е-Мухаммад возобновили работу террористи... \n", + "3 В качестве яркого примера новой стратегии паки... \n", + "4 Представители власти рассказали порталу OneInd... " + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "inference_ds = pd.read_csv(cfg.inference_dataset.path)\n", + "inference_ds.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + "batch_size = 32\n", + "num_batches = len(inference_ds) // batch_size + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/32 [00:001` or unset `early_stopping`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 3%|▎ | 1/32 [00:33<17:07, 33.14s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 6%|▋ | 2/32 [01:01<15:06, 30.21s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 9%|▉ | 3/32 [01:35<15:33, 32.17s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 12%|█▎ | 4/32 [02:11<15:43, 33.68s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 16%|█▌ | 5/32 [02:45<15:10, 33.71s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 19%|█▉ | 6/32 [03:19<14:40, 33.87s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 22%|██▏ | 7/32 [03:55<14:20, 34.41s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 25%|██▌ | 8/32 [04:21<12:40, 31.68s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 28%|██▊ | 9/32 [04:45<11:16, 29.42s/it]" 
+ ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 31%|███▏ | 10/32 [05:16<10:57, 29.91s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 34%|███▍ | 11/32 [05:38<09:39, 27.60s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 38%|███▊ | 12/32 [06:05<09:08, 27.42s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 41%|████ | 13/32 [06:37<09:05, 28.72s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 44%|████▍ | 14/32 [07:33<11:07, 37.06s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 47%|████▋ | 15/32 [07:59<09:30, 33.58s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 16/32 [08:40<09:32, 35.78s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 53%|█████▎ | 17/32 [09:18<09:08, 36.58s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 56%|█████▋ | 18/32 [09:54<08:28, 36.35s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 59%|█████▉ | 19/32 [10:20<07:09, 33.07s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 62%|██████▎ | 20/32 [10:41<05:55, 29.63s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 66%|██████▌ | 21/32 [11:07<05:12, 28.43s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 69%|██████▉ | 22/32 [11:41<05:02, 30.23s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 72%|███████▏ | 23/32 [12:26<05:12, 34.67s/it]" + ] + }, + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 75%|███████▌ | 24/32 [13:55<06:47, 50.96s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 78%|███████▊ | 25/32 [14:24<05:10, 44.41s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 81%|████████▏ | 26/32 [14:59<04:09, 41.57s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 84%|████████▍ | 27/32 [15:38<03:23, 40.65s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 88%|████████▊ | 28/32 [16:10<02:32, 38.10s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 91%|█████████ | 29/32 [16:48<01:54, 38.08s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 94%|█████████▍| 30/32 [17:16<01:09, 34.99s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 97%|█████████▋| 31/32 [17:52<00:35, 35.49s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n", + "Tokenizing...\n", + "Generating...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 32/32 [18:12<00:00, 34.14s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Decoding...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# nnlb\n", + "translations = []\n", + "for i in tqdm(range(num_batches)):\n", + " batch = inference_ds.iloc[i*batch_size:(i+1)*batch_size][\"source\"].tolist()\n", + " batch_translations = get_translations(nnlb_model, nnlb_tokenizer, batch,\n", + " special_gen_params={\"forced_bos_token_id\": nnlb_tokenizer.lang_code_to_id[\"rus_Cyrl\"]})\n", + " translations.extend(batch_translations)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "inference_ds[cfg.nnlb_model.name] = translations" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/32 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sourcetargettransformer-en-runnlb-1.3B-distilled
0The intention would also be to infiltrate terr...Террористов также обучают методам проникновени...Кроме того, намерение состоит в том, чтобы про...Имеется в виду также проникновение террористов...
1Officials say that as the latest information a...По последним данным представителей власти , в ...Официальные лица говорят, что в качестве самой...Официальные лица говорят, что по последней инф...
2While the Balakot camp was reactivated by the ...Джаиш-е-Мухаммад возобновили работу террористи...В то время как лагерь в Балакоте был восстанов...В то время как лагерь Балакота был активирован...
3The incident in which Pakistan used drones to ...В качестве яркого примера новой стратегии паки...Инцидент, в ходе которого Пакистан использовал...Инцидент, когда Пакистан использовал беспилотн...
4Officials tell OneIndia that terror groups wou...Представители власти рассказали порталу OneInd...Официальные лица сообщают одной Индии, что тер...Официальные лица говорят OneIndia, что террори...
\n", + "" + ], + "text/plain": [ + " source \\\n", + "0 The intention would also be to infiltrate terr... \n", + "1 Officials say that as the latest information a... \n", + "2 While the Balakot camp was reactivated by the ... \n", + "3 The incident in which Pakistan used drones to ... \n", + "4 Officials tell OneIndia that terror groups wou... \n", + "\n", + " target \\\n", + "0 Террористов также обучают методам проникновени... \n", + "1 По последним данным представителей власти , в ... \n", + "2 Джаиш-е-Мухаммад возобновили работу террористи... \n", + "3 В качестве яркого примера новой стратегии паки... \n", + "4 Представители власти рассказали порталу OneInd... \n", + "\n", + " transformer-en-ru \\\n", + "0 Кроме того, намерение состоит в том, чтобы про... \n", + "1 Официальные лица говорят, что в качестве самой... \n", + "2 В то время как лагерь в Балакоте был восстанов... \n", + "3 Инцидент, в ходе которого Пакистан использовал... \n", + "4 Официальные лица сообщают одной Индии, что тер... \n", + "\n", + " nnlb-1.3B-distilled \n", + "0 Имеется в виду также проникновение террористов... \n", + "1 Официальные лица говорят, что по последней инф... \n", + "2 В то время как лагерь Балакота был активирован... \n", + "3 Инцидент, когда Пакистан использовал беспилотн... \n", + "4 Официальные лица говорят OneIndia, что террори... " + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inference_ds.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "inference_ds.to_csv(cfg.result_dataset.path, index=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/finetune/finetune.ipynb b/notebooks/finetune/finetune.ipynb new file mode 100644 index 0000000..97bbf2c --- /dev/null +++ b/notebooks/finetune/finetune.ipynb @@ -0,0 +1,1465 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.14.5)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.0.dev0)\n", + "Requirement already satisfied: sacrebleu in /usr/local/lib/python3.10/dist-packages (1.5.0)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.0a0+32f93b1)\n", + "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n", + "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (0.17.3)\n", + "Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (0.15.12)\n", + "Collecting hydra-core\n", + " Downloading 
hydra_core-1.3.2-py3-none-any.whl (154 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m154.5/154.5 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.1)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (11.0.0)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.4)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.1)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", + "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.64.1)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.12.2)\n", + "Requirement already satisfied: fsspec<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<2023.9.0,>=2023.1.0->datasets) (2023.6.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.4)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.8.8)\n", + "Requirement already satisfied: tokenizers<0.15,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.14.1)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.0)\n", + "Requirement already satisfied: portalocker in /usr/local/lib/python3.10/dist-packages (from sacrebleu) (2.8.2)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.7.1)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (2.6.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n", + "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (4.24.3)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.6)\n", + "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.1.18)\n", + "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (5.9.4)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.32.0)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (0.4.0)\n", + "Requirement 
already satisfied: pathtools in /usr/local/lib/python3.10/dist-packages (from wandb) (0.1.2)\n", + "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb) (1.3.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (68.2.2)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n", + "Requirement already satisfied: omegaconf<2.4,>=2.2 in /usr/local/lib/python3.10/dist-packages (from hydra-core) (2.3.0)\n", + "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.10/dist-packages (from hydra-core) (4.9.3)\n", + "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", + "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.2.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (1.26.16)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n", + "Installing collected packages: hydra-core\n", + "Successfully installed hydra-core-1.3.2\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install datasets transformers sacrebleu torch sentencepiece \"transformers[sentencepiece]\" huggingface_hub wandb hydra-core" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/jovyan/rodion/other/trans\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "\n", + "root_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))\n", + "print(root_dir)\n", + "assert os.path.exists(root_dir), f'Could not find root directory at {root_dir}'\n", + "sys.path.insert(0, root_dir)\n", + "\n", + "from custom_utils.config_handler import read_config, pprint_config" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"root\": \"/home/jovyan/rodion/other/trans\",\n", + " \"dataset\": {\n", + " \"path\": \"waleko/unarXive-en2ru\"\n", + " },\n", + " \"params\": {\n", + " \"batch_size\": 16,\n", + " \"max_length\": 512,\n", + " \"train_args\": {\n", + " \"evaluation_strategy\": \"epoch\",\n", + " \"learning_rate\": 2e-05,\n", + " \"per_device_train_batch_size\": 16,\n", + " \"per_device_eval_batch_size\": 16,\n", + " \"weight_decay\": 0.01,\n", + " \"save_total_limit\": 3,\n", + " \"num_train_epochs\": 4,\n", + " \"predict_with_generate\": true\n", + " },\n", + " \"wandb_args\": {\n", + " \"report_to\": \"wandb\",\n", + " \"run_name\": \"finetune\"\n", + " }\n", + " },\n", + " \"pretrained\": {\n", + " \"name\": \"opus-en-ru\",\n", + " \"model_and_tokenizer_name\": \"Helsinki-NLP/opus-mt-en-ru\"\n", + " },\n", + " \"finetuned\": {\n", + " \"name\": \"opus-distilled-en-ru\",\n", + " \"model_and_tokenizer_name\": \"under-tree/transformer-en-ru\",\n", + " \"output_dir\": \"/home/jovyan/rodion/other/trans/models/opus-distilled-en-ru/finetuned\"\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "overrides = [\"setup=finetune\"]\n", + "cfg = read_config(overrides=overrides)\n", + "pprint_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "688ebf26fce74939bd55e43ab219eafb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
the W&B docs." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "wandb version 0.16.0 is available! To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.15.12" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/jovyan/rodion/other/trans/notebooks/finetune/wandb/run-20231117_130518-fgt5hc0j" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run finetune to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/wide-learning/huggingface" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/wide-learning/huggingface/runs/fgt5hc0j" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [2272/2272 20:54, Epoch 4/4]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation LossBleuGen Len
10.9412000.65388467.599700127.287700
20.6977000.60877869.201200127.783500
30.6312000.58938369.907900127.373200
40.5876000.58422770.363300126.859000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=2272, training_loss=0.6974154055958063, metrics={'train_runtime': 1218.4183, 'train_samples_per_second': 29.816, 'train_steps_per_second': 1.865, 'total_flos': 2196400063905792.0, 'train_loss': 0.6974154055958063, 'epoch': 4.0})" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [44/44 06:58]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Perplexity: 1.7936033980595856\n" + ] + } + ], + "source": [ + "print(\"Perplexity:\", np.exp(trainer.evaluate()[\"eval_loss\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"eval_loss\": 0.6561718583106995,\n", + " \"eval_bleu\": 67.197,\n", + " \"eval_gen_len\": 126.7394,\n", + " \"eval_runtime\": 194.0045,\n", + " \"eval_samples_per_second\": 2.928,\n", + " \"eval_steps_per_second\": 0.186,\n", + " \"epoch\": 4.0\n", + "}\n" + ] + } + ], + "source": [ + "import json\n", + "test_score = trainer.evaluate(tokenized_datasets[\"test\"])\n", + "\n", + "print(json.dumps(test_score, indent=4))" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Waiting for W&B process to finish... (success)." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:\n",
+       "\n",
+       "eval/bleu ▂▅▇██▁\n",
+       "eval/gen_len ▅█▅▂▂▁\n",
+       "eval/loss █▃▂▁▁█\n",
+       "eval/runtime ██▇▇▇▁\n",
+       "eval/samples_per_second ▁▁▆███\n",
+       "eval/steps_per_second ▁▁▅▆▆█\n",
+       "train/epoch ▁▁▃▄▅▆▇████\n",
+       "train/global_step ▁▁▃▄▅▆▇████\n",
+       "train/learning_rate █▆▃▁\n",
+       "train/loss █▃▂▁\n",
+       "train/total_flos ▁\n",
+       "train/train_loss ▁\n",
+       "train/train_runtime ▁\n",
+       "train/train_samples_per_second ▁\n",
+       "train/train_steps_per_second ▁\n",
+       "\n",
+       "Run summary:\n",
+       "\n",
+       "eval/bleu 67.197\n",
+       "eval/gen_len 126.7394\n",
+       "eval/loss 0.65617\n",
+       "eval/runtime 194.0045\n",
+       "eval/samples_per_second 2.928\n",
+       "eval/steps_per_second 0.186\n",
+       "train/epoch 4.0\n",
+       "train/global_step 2272\n",
+       "train/learning_rate 0.0\n",
+       "train/loss 0.5876\n",
+       "train/total_flos 2196400063905792.0\n",
+       "train/train_loss 0.69742\n",
+       "train/train_runtime 1218.4183\n",
+       "train/train_samples_per_second 29.816\n",
+       "train/train_steps_per_second 1.865
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run finetune at: https://wandb.ai/wide-learning/huggingface/runs/fgt5hc0j
View job at https://wandb.ai/wide-learning/huggingface/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjExNjYyNzQ4OA==/version_details/v0
Synced 5 W&B file(s), 0 media file(s), 11 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20231117_130518-fgt5hc0j/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "wandb.finish()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "for dirname, _, filenames in os.walk(model_name):\n", + " for filename in filenames:\n", + " print(os.path.join(dirname, filename))" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "44e967b87fa3423a80769bccdb3972e9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading (…)okenizer_config.json: 0%| | 0.00/816 [00:00=0.5.1 openai tqdm pandas + +# RodionfromHSE +accelerate==0.24.1 +aiohttp==3.8.6 +aiosignal==1.3.1 +antlr4-python3-runtime==4.9.3 +appnope==0.1.3 +asttokens==2.4.0 +async-timeout==4.0.3 +attrs==23.1.0 +backcall==0.2.0 +certifi==2023.7.22 +charset-normalizer==3.3.1 +colorama==0.4.6 +comm==0.1.4 +datasets==2.14.6 +debugpy==1.8.0 +decorator==5.1.1 +dill==0.3.7 +executing==2.0.0 +filelock==3.12.4 +frozenlist==1.4.0 +fsspec==2023.10.0 +huggingface-hub==0.17.3 +hydra-core==1.3.2 +idna==3.4 +ipykernel==6.26.0 +ipython==8.16.1 +ipywidgets==8.1.1 +jedi==0.19.1 +Jinja2==3.1.2 +jupyter_client==8.5.0 +jupyter_core==5.4.0 +jupyterlab-widgets==3.0.9 +lxml==4.9.3 +MarkupSafe==2.1.3 +matplotlib-inline==0.1.6 +mpmath==1.3.0 +multidict==6.0.4 +multiprocess==0.70.15 +nest-asyncio==1.5.8 +networkx==3.2.1 +numpy==1.26.1 +omegaconf==2.3.0 +packaging==23.2 +pandas==2.1.1 +parso==0.8.3 +pexpect==4.8.0 +pickleshare==0.7.5 +platformdirs==3.11.0 +portalocker==2.8.2 +prompt-toolkit==3.0.39 +protobuf==4.25.0 +psutil==5.9.6 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyarrow==13.0.0 +Pygments==2.16.1 +python-dateutil==2.8.2 +pytz==2023.3.post1 +PyYAML==6.0.1 +pyzmq==25.1.1 +regex==2023.10.3 +requests==2.31.0 +sacrebleu==2.3.2 +safetensors==0.4.0 +sentencepiece==0.1.99 +six==1.16.0 +stack-data==0.6.3 +sympy==1.12 +tabulate==0.9.0 +tokenizers==0.14.1 +torch==2.1.0 +tornado==6.3.3 +tqdm==4.66.1 +traitlets==5.12.0 +transformers==4.34.1 +typing_extensions==4.8.0 +tzdata==2023.3 +urllib3==2.0.7 +wcwidth==0.2.8 +widgetsnbextension==4.0.9 +xxhash==3.4.1 +yarl==1.9.2 +zstandard==0.21.0 diff --git a/src/finetune/eval_bleu.py b/src/finetune/eval_bleu.py new file mode 100644 index 0000000..b358625 --- /dev/null +++ b/src/finetune/eval_bleu.py @@ -0,0 +1,33 @@ +import os +import sys + +root_dir = os.path.abspath(os.path.join(__file__, '../../..')) +sys.path.append(root_dir) + +import pandas as pd +import sacrebleu +from omegaconf import OmegaConf +from custom_utils.config_handler import read_config, pprint_config + + +def main() -> None: + cfg: OmegaConf = read_config(overrides=["dataset=model_eval"]) + pprint_config(cfg) + + # Load data + data = pd.read_csv(cfg.dataset.path) + + # compute blue for each candidate + cols = cfg.dataset.cols + reference_col = cols.reference + data[reference_col] = data[reference_col].apply(lambda x: [x]) + for candidate_col in cols.candidates: + bleu = data[[reference_col, candidate_col]].apply( + lambda x: sacrebleu.corpus_bleu(x[reference_col], x[candidate_col]).score, + axis=1 + ) + data[f"{candidate_col}_bleu"] = bleu + 
print(f"{candidate_col} BLEU: {bleu.mean():.2f}") + +if __name__ == "__main__": + main() diff --git a/src/prod/features/__init__.py b/src/prod/features/__init__.py new file mode 100644 index 0000000..f14d04c --- /dev/null +++ b/src/prod/features/__init__.py @@ -0,0 +1,3 @@ +from .attention_extractor import BaseAttentionExtractor, choose_attention_extractor +from .synonym_searcher import BaseSynonymSearcher, choose_synonym_searcher +from .translator import BaseTranslator, choose_translator diff --git a/src/prod/features/attention_extractor.py b/src/prod/features/attention_extractor.py new file mode 100644 index 0000000..1a29eaf --- /dev/null +++ b/src/prod/features/attention_extractor.py @@ -0,0 +1,59 @@ +from abc import ABC, abstractmethod +import typing as tp +from typing import Any +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from omegaconf import DictConfig +import numpy as np + +class BaseAttentionExtractor(ABC): + def __init__(self, model: AutoModelForSeq2SeqLM): + """ + :@param model: model to extract attention from + """ + super().__init__() + self.model = model + + @abstractmethod + def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]: + """ + Extract attention scores for each token in relation to all other tokens + + :@param tokens: list of tokens to extract attention for + :@return: list of attention scores for each token + """ + pass + +class RandomAttentionExtractor(BaseAttentionExtractor): + """ + Just for testing purposes + """ + def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]: + """ + Extract random attention scores for each token in relation to all other tokens + + :@param tokens: list of tokens to extract attention for + :@return: list of attention scores for each token + """ + n_tokens = len(tokens) + unnormalized_scores = np.random.rand(n_tokens, n_tokens) + normalized_scores = unnormalized_scores / unnormalized_scores.sum(axis=1, keepdims=True) + return normalized_scores.tolist() + +class Seq2SeqAttentionExtractor(BaseAttentionExtractor): + pass + + +def choose_attention_extractor(cfg: DictConfig) -> BaseAttentionExtractor: + """ + Choose attention extractor based on config + + :@param cfg: config to choose attention extractor from + :@return: attention extractor + """ + if cfg.type == "random": + model = None + return RandomAttentionExtractor(model) + if cfg.type == "seq2seq": + return Seq2SeqAttentionExtractor(cfg) + else: + raise ValueError(f"Unknown attention extractor type: {cfg.attention_extractor.type}") \ No newline at end of file diff --git a/src/prod/features/synonym_searcher.py b/src/prod/features/synonym_searcher.py new file mode 100644 index 0000000..742c8da --- /dev/null +++ b/src/prod/features/synonym_searcher.py @@ -0,0 +1,55 @@ +from abc import ABC, abstractmethod +import typing as tp +import fasttext +from omegaconf import DictConfig + +class BaseSynonymSearcher(ABC): + @abstractmethod + def __call__(self, tokens: tp.List[str], *args, return_scores: bool = False, + n_synonyms: int = 5, **kwargs) -> tp.List[tp.List[str]] | tp.List[tp.List[tp.Tuple[str, float]]]: + """ + Search for synonyms for each token in the list + :@param tokens: list of tokens to find synonyms for + :@param return_scores: whether to return scores for synonyms + :@param n_synonyms: number of synonyms to return for each token + :@return: list of synonyms for each token or list of synonyms and scores + """ + pass + + +class FastTextSynonymSearcher(BaseSynonymSearcher): + def __init__(self, model_path: str): + """ + :@param 
diff --git a/src/prod/features/__init__.py b/src/prod/features/__init__.py
new file mode 100644
index 0000000..f14d04c
--- /dev/null
+++ b/src/prod/features/__init__.py
@@ -0,0 +1,3 @@
+from .attention_extractor import BaseAttentionExtractor, choose_attention_extractor
+from .synonym_searcher import BaseSynonymSearcher, choose_synonym_searcher
+from .translator import BaseTranslator, choose_translator
diff --git a/src/prod/features/attention_extractor.py b/src/prod/features/attention_extractor.py
new file mode 100644
index 0000000..1a29eaf
--- /dev/null
+++ b/src/prod/features/attention_extractor.py
@@ -0,0 +1,59 @@
+from abc import ABC, abstractmethod
+import typing as tp
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from omegaconf import DictConfig
+import numpy as np
+
+class BaseAttentionExtractor(ABC):
+    def __init__(self, model: AutoModelForSeq2SeqLM):
+        """
+        :@param model: model to extract attention from
+        """
+        super().__init__()
+        self.model = model
+
+    @abstractmethod
+    def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]:
+        """
+        Extract attention scores for each token in relation to all other tokens
+
+        :@param tokens: list of tokens to extract attention for
+        :@return: list of attention scores for each token
+        """
+        pass
+
+class RandomAttentionExtractor(BaseAttentionExtractor):
+    """
+    Just for testing purposes
+    """
+    def __call__(self, tokens: tp.List[str]) -> tp.List[tp.List[float]]:
+        """
+        Extract random attention scores for each token in relation to all other tokens
+
+        :@param tokens: list of tokens to extract attention for
+        :@return: list of attention scores for each token
+        """
+        n_tokens = len(tokens)
+        unnormalized_scores = np.random.rand(n_tokens, n_tokens)
+        normalized_scores = unnormalized_scores / unnormalized_scores.sum(axis=1, keepdims=True)
+        return normalized_scores.tolist()
+
+class Seq2SeqAttentionExtractor(BaseAttentionExtractor):
+    # Placeholder: real attention extraction is not implemented yet
+    pass
+
+
+def choose_attention_extractor(cfg: DictConfig) -> BaseAttentionExtractor:
+    """
+    Choose attention extractor based on config
+
+    :@param cfg: config to choose attention extractor from
+    :@return: attention extractor
+    """
+    if cfg.type == "random":
+        model = None
+        return RandomAttentionExtractor(model)
+    elif cfg.type == "seq2seq":
+        return Seq2SeqAttentionExtractor(cfg)
+    else:
+        raise ValueError(f"Unknown attention extractor type: {cfg.type}")
\ No newline at end of file
diff --git a/src/prod/features/synonym_searcher.py b/src/prod/features/synonym_searcher.py
new file mode 100644
index 0000000..742c8da
--- /dev/null
+++ b/src/prod/features/synonym_searcher.py
@@ -0,0 +1,56 @@
+from abc import ABC, abstractmethod
+import typing as tp
+import fasttext
+from omegaconf import DictConfig
+
+class BaseSynonymSearcher(ABC):
+    @abstractmethod
+    def __call__(self, tokens: tp.List[str], *args, return_scores: bool = False,
+                 n_synonyms: int = 5, **kwargs) -> tp.List[tp.List[str]] | tp.List[tp.List[tp.Tuple[str, float]]]:
+        """
+        Search for synonyms for each token in the list
+        :@param tokens: list of tokens to find synonyms for
+        :@param return_scores: whether to return scores for synonyms
+        :@param n_synonyms: number of synonyms to return for each token
+        :@return: list of synonyms for each token or list of synonyms and scores
+        """
+        pass
+
+
+class FastTextSynonymSearcher(BaseSynonymSearcher):
+    def __init__(self, model_path: str):
+        """
+        :@param model_path: path to fasttext model
+        """
+        super().__init__()
+        self.model = fasttext.load_model(model_path)
+
+    def __call__(self, tokens: tp.List[str], *args, return_scores: bool = False,
+                 n_synonyms: int = 5, **kwargs) -> tp.List[tp.List[str]] | tp.List[tp.List[tp.Tuple[str, float]]]:
+        """
+        Search for synonyms for each token in the list
+        :@param tokens: list of tokens to find synonyms for
+        :@param return_scores: whether to return scores for synonyms
+        :@param n_synonyms: number of synonyms to return for each token
+        :@return: list of synonyms for each token or list of synonyms and scores
+        """
+        synonyms = []
+        for token in tokens:
+            # get_nearest_neighbors returns (score, word) pairs, best first
+            if return_scores:
+                synonyms.append(self.model.get_nearest_neighbors(token, k=n_synonyms))
+            else:
+                synonyms.append([synonym for _, synonym in self.model.get_nearest_neighbors(token, k=n_synonyms)])
+        return synonyms
+
+def choose_synonym_searcher(cfg: DictConfig) -> BaseSynonymSearcher:
+    """
+    Choose synonym searcher based on config
+
+    :@param cfg: config to choose synonym searcher from
+    :@return: synonym searcher
+    """
+    if cfg.type == "fasttext":
+        return FastTextSynonymSearcher(cfg.path)
+    else:
+        raise ValueError(f"Unknown synonym searcher type: {cfg.type}")
\ No newline at end of file
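For context on the searcher above: fasttext's `get_nearest_neighbors` returns `(similarity, word)` pairs sorted best first, which is why the no-scores branch unpacks `for _, synonym in ...`. A quick sketch; the model path is a placeholder, and any pretrained `.bin` with subword vectors behaves the same:

```python
import fasttext

model = fasttext.load_model("/path/to/fasttext.en.bin")  # placeholder path
# (similarity, word) tuples, e.g. [(0.71, "translations"), (0.68, "translating"), ...]
print(model.get_nearest_neighbors("translation", k=3))
# Subword n-grams still yield neighbors for out-of-vocabulary tokens:
print(model.get_nearest_neighbors("untranslatableish", k=3))
```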
diff --git a/src/prod/features/translator.py b/src/prod/features/translator.py
new file mode 100644
index 0000000..f8aabf0
--- /dev/null
+++ b/src/prod/features/translator.py
@@ -0,0 +1,93 @@
+from abc import ABC, abstractmethod
+import typing as tp
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from omegaconf import DictConfig
+
+translation_dict = tp.Dict[str, str | tp.List[str]]
+
+class BaseTranslator(ABC):
+    def __init__(self, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer):
+        """
+        :@param model: model to translate with
+        :@param tokenizer: tokenizer to use
+        """
+        super().__init__()
+        self.model = model
+        self.tokenizer = tokenizer
+
+    @abstractmethod
+    def call_one(self, text: str, *args, **kwargs) -> translation_dict:
+        """
+        Translate one text
+
+        :@param text: text to translate
+        :@return: dict with translation, input and output tokens
+        {
+            'translation': str,
+            'input_tokens': list[str]
+            'output_tokens': list[str]
+        }
+        """
+        pass
+
+    def __call__(self, texts: str | tp.List[str], *args, **kwargs) -> translation_dict | tp.List[translation_dict]:
+        """
+        Translate texts
+
+        :@param texts: text or list of texts to translate
+        :@return: dict or list of dicts with translation, input and output tokens for each text
+        {
+            'translation': str,
+            'input_tokens': list[str]
+            'output_tokens': list[str]
+        }
+        """
+        if isinstance(texts, str):
+            return self.call_one(texts, *args, **kwargs)
+        return [self.call_one(text, *args, **kwargs) for text in texts]
+
+
+class Seq2SeqTranslator(BaseTranslator):
+    def __init__(self, cfg: DictConfig):
+        """
+        :@param cfg: config with model and tokenizer paths
+        """
+        name = cfg.model_and_tokenizer_name
+        model = AutoModelForSeq2SeqLM.from_pretrained(name)
+        tokenizer = AutoTokenizer.from_pretrained(name)
+        super().__init__(model, tokenizer)
+        self.gen_params = cfg.get("gen_params", {})
+
+    def get_tokens(self, text: str, is_target: bool) -> tp.List[str]:
+        """
+        Tokenize text with tokenizer
+
+        :@param text: text to tokenize
+        :@param is_target: whether text is target or source
+        :@return: list of tokens
+        """
+        if is_target:
+            with self.tokenizer.as_target_tokenizer():
+                return self.tokenizer.tokenize(text)
+        return self.tokenizer.tokenize(text)
+
+    def call_one(self, text: str, *args, **kwargs) -> translation_dict:
+        input_idx = self.tokenizer.encode(text, return_tensors="pt")
+        translation = self.model.generate(input_idx, **self.gen_params)
+        translation = self.tokenizer.batch_decode(translation, skip_special_tokens=True)[0]
+
+        input_tokens = self.get_tokens(text, is_target=False)
+        output_tokens = self.get_tokens(translation, is_target=True)
+
+        return {
+            "translation": translation,
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens
+        }
+
+
+def choose_translator(cfg: DictConfig) -> BaseTranslator:
+    if cfg.type == "seq2seq":
+        return Seq2SeqTranslator(cfg)
+    else:
+        raise ValueError(f"Unknown translator type: {cfg.type}")
\ No newline at end of file
diff --git a/src/prod/translator_app.py b/src/prod/translator_app.py
new file mode 100644
index 0000000..7971cde
--- /dev/null
+++ b/src/prod/translator_app.py
@@ -0,0 +1,70 @@
+import os
+import sys
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(cur_dir)
+
+from features import BaseAttentionExtractor, BaseSynonymSearcher, BaseTranslator
+from features import choose_attention_extractor, choose_synonym_searcher, choose_translator
+from omegaconf import DictConfig
+import typing as tp
+from nltk.tokenize import word_tokenize  # requires the NLTK "punkt" tokenizer data
+
+translation_type = tp.Dict[str, tp.Any]
+
+class TranslatorApp:
+    def __init__(self, cfg: DictConfig):
+        self.cfg = cfg
+        self.translator = choose_translator(cfg.translator)
+        self.src_synonym_searcher = choose_synonym_searcher(cfg.src_synonym_searcher)
+        self.dest_synonym_searcher = choose_synonym_searcher(cfg.dest_synonym_searcher)
+        self.attention_extractor = choose_attention_extractor(cfg.attention_extractor)
+
+    def _get_words(self, text: str) -> tp.List[str]:
+        """
+        Split text into words
+        :@param text: text to split
+        :@return: list of words
+        """
+        text = text.lower()
+        return word_tokenize(text)
+
+    def call_one(self, text: str) -> translation_type:
+        """
+        Return translation for one text
+        :@param text: text to translate
+        :@return: dict
+        {
+            'translation': str,
+            'input_tokens': list[str]
+            'input_words': list[str]
+            'output_tokens': list[str]
+            'output_words': list[str]
+            'src_synonyms': list[list[str]]
+            'dest_synonyms': list[list[str]]
+            'attention': list[list[float]]
+        }
+        """
+        translation = self.translator(text)
+        input_tokens = translation["input_tokens"]
+        output_tokens = translation["output_tokens"]
+        input_words = self._get_words(text)
+        output_words = self._get_words(translation["translation"])
+        src_synonyms = self.src_synonym_searcher(input_words)
+        dest_synonyms = self.dest_synonym_searcher(output_words)
+        attention = self.attention_extractor(input_tokens)
+        return {
+            "translation": translation["translation"],
+            "input_tokens": input_tokens,
+            "input_words": input_words,
+            "output_tokens": output_tokens,
+            "output_words": output_words,
+            "src_synonyms": src_synonyms,
+            "dest_synonyms": dest_synonyms,
+            "attention": attention
+        }
+
+    def __call__(self, texts: str | tp.List[str]) -> translation_type | tp.List[translation_type]:
+        if isinstance(texts, str):
+            return self.call_one(texts)
+        return [self.call_one(text) for text in texts]
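To tie the pieces together, a minimal usage sketch for `TranslatorApp`. The config keys (`translator`, `src_synonym_searcher`, `dest_synonym_searcher`, `attention_extractor`) and the `type` values are the ones the factory functions dispatch on; the concrete fasttext paths are illustrative assumptions, and NLTK's `punkt` data must be downloaded once beforehand:

```python
import nltk
from omegaconf import OmegaConf

nltk.download("punkt", quiet=True)  # word_tokenize dependency

from src.prod.translator_app import TranslatorApp  # assumes the repo root is on sys.path

cfg = OmegaConf.create({
    "translator": {
        "type": "seq2seq",
        "model_and_tokenizer_name": "Helsinki-NLP/opus-mt-en-ru",
    },
    # placeholder fasttext paths; download pretrained vectors from fasttext.cc first
    "src_synonym_searcher": {"type": "fasttext", "path": "/path/to/fasttext.en.bin"},
    "dest_synonym_searcher": {"type": "fasttext", "path": "/path/to/fasttext.ru.bin"},
    # the random extractor is the stub shipped above, handy for smoke tests
    "attention_extractor": {"type": "random"},
})

app = TranslatorApp(cfg)
result = app("Attention is all you need.")
print(result["translation"])
print(result["src_synonyms"][0][:3])  # synonyms for the first source word
```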