diff --git a/data.dvc b/data.dvc index b3eb5b6..1c30a79 100644 --- a/data.dvc +++ b/data.dvc @@ -1,6 +1,6 @@ outs: -- md5: 3a55ac5aedd31587cd9b43fd9b2280d6.dir - size: 3806114 - nfiles: 7 +- md5: 255799c6a8913d73679631d546a9dd88.dir + nfiles: 13 hash: md5 path: data + size: 19936558 diff --git a/notebooks/data/dataset_translate.ipynb b/notebooks/data/dataset_translate.ipynb new file mode 100644 index 0000000..82052fd --- /dev/null +++ b/notebooks/data/dataset_translate.ipynb @@ -0,0 +1,207 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Dataset Preprocessing: Translation" + ], + "metadata": { + "collapsed": false + }, + "id": "84afd0d97d0e373c" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "from getpass import getpass\n", + "import multiprocessing\n", + "\n", + "import pandas as pd\n", + "import openai\n", + "from tqdm import tqdm\n", + "\n", + "from src.data.translate_openai import translate_openai" + ], + "metadata": { + "collapsed": false + }, + "id": "ba690766e1bcc074" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "openai.api_key = getpass()" + ], + "metadata": { + "collapsed": false + }, + "id": "e8b67bc54845db96" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "data_folder = '../../data/raw/unarxive_citrec'\n", + "interim_folder = '../../data/interim'\n", + "processed_folder = '../../data/processed'\n", + "splits = ['train', 'valid', 'test']" + ], + "metadata": { + "collapsed": false + }, + "id": "63b7dd3ab388b7c8" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "lock = multiprocessing.Lock()\n", + "\n", + "\n", + "# Define the worker function\n", + "def worker_function(worker_id, task_queue):\n", + " while True:\n", + " task = task_queue.get() # Get a task from the queue\n", + " if task is None:\n", + " break # Exit the loop when None is received as a task\n", + " s, i, split_name = task\n", + " print(f\"Worker {worker_id} is processing: {s} at {i}\")\n", + " try:\n", + " tmp = translate_openai(s)\n", + " lock.acquire()\n", + " out = open(os.path.join(interim_folder, split_name + '.out'), 'a', encoding='utf-8')\n", + " out.write(f\"\\n###{i}@@ {tmp}\")\n", + " out.flush()\n", + " out.close()\n", + " lock.release()\n", + " # print(tmp)\n", + " except Exception as e:\n", + " print(f\"Fail!! pos={i}\")\n", + " print(e)" + ], + "metadata": { + "collapsed": false + }, + "id": "514d22eab7752d6c" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "num_workers = 8\n", + "\n", + "# Create and start worker processes\n", + "workers = []\n", + "task_queue = multiprocessing.Queue()\n", + "for worker_id in range(num_workers):\n", + " worker = multiprocessing.Process(target=worker_function, args=(worker_id, task_queue))\n", + " workers.append(worker)\n", + " worker.start()\n", + "\n", + "for split_name, st in zip(splits, [1094, 0, 0]):\n", + " arr = json.load(open(os.path.join(data_folder, split_name + '.json'), 'r', encoding='utf-8'))\n", + " print(f\"Working on {split_name}...\")\n", + "\n", + " for i in range(st, len(arr)):\n", + " s = arr[i]\n", + " task_queue.put((s, i, split_name))\n", + "\n", + "# Add None to the queue for each worker to signal them to exit\n", + "for _ in range(num_workers):\n", + " task_queue.put(None)\n", + "\n", + "# Wait for all workers to finish\n", + "for worker in workers:\n", + " worker.join()\n", + "\n", + "print(\"All workers have finished\")" + ], + "metadata": { + "collapsed": false + }, + "id": "cfb7d0e6eef90340" + }, + { + "cell_type": "markdown", + "source": [ + "## Dataset Preprocessing: Build" + ], + "metadata": { + "collapsed": false + }, + "id": "753a47ea274289b2" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "for split_name in splits:\n", + " src = json.load(open(os.path.join(data_folder, split_name + '.json'), 'r', encoding='utf-8'))\n", + " with open(os.path.join(interim_folder, split_name + '.out'), 'r', encoding='utf-8') as f:\n", + " text = '\\n'.join(f.readlines())\n", + " lines = text.split('\\n###')\n", + " targets = [None for _ in range(len(src))]\n", + " for line in lines[1:]:\n", + " pos = line.find('@@ ')\n", + " idx = int(line[:pos])\n", + " aft = line[pos + 3:]\n", + " targets[idx] = aft.strip()\n", + " df = pd.DataFrame({'source': src, 'target': targets})\n", + "\n", + " # fix missing values by trying to translate again\n", + " print(f\"Fixing missing values in {split_name}...\")\n", + " # save the indices of missing values\n", + " idx = df.loc[df['target'].isna()].index\n", + " print(list(idx))\n", + " df.loc[df['target'].isna(), 'target'] = [translate_openai(s) for s in tqdm(df.loc[df['target'].isna(), 'source'])]\n", + "\n", + " df.to_csv(os.path.join(processed_folder, split_name + '.csv'), index=False)" + ], + "metadata": { + "collapsed": false + }, + "id": "ad2c459cb00b2897" + }, + { + "cell_type": "markdown", + "source": [ + "Results have been saved to `data/processed` folder and uploaded to huggingface datasets as [unarxive-en2ru](https://huggingface.co/datasets/waleko/unarxive-en2ru)." + ], + "metadata": { + "collapsed": false + }, + "id": "3bf0923602f361f9" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/requirements.txt b/requirements.txt index d4f7d11..adcabd5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,6 @@ coverage awscli flake8 python-dotenv>=0.5.1 +openai +tqdm +pandas diff --git a/src/data/translate_openai.py b/src/data/translate_openai.py new file mode 100644 index 0000000..cd6cdd3 --- /dev/null +++ b/src/data/translate_openai.py @@ -0,0 +1,24 @@ +from typing import List + +import openai + +default_prompt = """Translate the provided scientific texts from English to Russian. Before translation, please make sure to clean up any messy or illegible symbols within the texts. Ensure that +there are no additional explanations or content in the translations. The outputs must contain precisely the same +information as the cleaned-up inputs.""" + + +def translate_openai(string: str, system_prompt=default_prompt): + """ + Translate a string from English to Russian using OpenAI's GPT-3 API + :param string: String to translate + :param system_prompt: The prompt to use for the system + :return: A list of translated strings + """ + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": string}, + ] + ) + return completion.choices[0].message.content