-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from RodionfromHSE/openai-translate
- Loading branch information
Showing
4 changed files
with
237 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
outs: | ||
- md5: 3a55ac5aedd31587cd9b43fd9b2280d6.dir | ||
size: 3806114 | ||
nfiles: 7 | ||
- md5: 255799c6a8913d73679631d546a9dd88.dir | ||
nfiles: 13 | ||
hash: md5 | ||
path: data | ||
size: 19936558 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Dataset Preprocessing: Translation" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "84afd0d97d0e373c" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import json\n", | ||
"from getpass import getpass\n", | ||
"import multiprocessing\n", | ||
"\n", | ||
"import pandas as pd\n", | ||
"import openai\n", | ||
"from tqdm import tqdm\n", | ||
"\n", | ||
"from src.data.translate_openai import translate_openai" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "ba690766e1bcc074" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"openai.api_key = getpass()" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "e8b67bc54845db96" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"data_folder = '../../data/raw/unarxive_citrec'\n", | ||
"interim_folder = '../../data/interim'\n", | ||
"processed_folder = '../../data/processed'\n", | ||
"splits = ['train', 'valid', 'test']" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "63b7dd3ab388b7c8" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"lock = multiprocessing.Lock()\n", | ||
"\n", | ||
"\n", | ||
"# Define the worker function\n", | ||
"def worker_function(worker_id, task_queue):\n", | ||
" while True:\n", | ||
" task = task_queue.get() # Get a task from the queue\n", | ||
" if task is None:\n", | ||
" break # Exit the loop when None is received as a task\n", | ||
" s, i, split_name = task\n", | ||
" print(f\"Worker {worker_id} is processing: {s} at {i}\")\n", | ||
" try:\n", | ||
" tmp = translate_openai(s)\n", | ||
" lock.acquire()\n", | ||
" out = open(os.path.join(interim_folder, split_name + '.out'), 'a', encoding='utf-8')\n", | ||
" out.write(f\"\\n###{i}@@ {tmp}\")\n", | ||
" out.flush()\n", | ||
" out.close()\n", | ||
" lock.release()\n", | ||
" # print(tmp)\n", | ||
" except Exception as e:\n", | ||
" print(f\"Fail!! pos={i}\")\n", | ||
" print(e)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "514d22eab7752d6c" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"num_workers = 8\n", | ||
"\n", | ||
"# Create and start worker processes\n", | ||
"workers = []\n", | ||
"task_queue = multiprocessing.Queue()\n", | ||
"for worker_id in range(num_workers):\n", | ||
" worker = multiprocessing.Process(target=worker_function, args=(worker_id, task_queue))\n", | ||
" workers.append(worker)\n", | ||
" worker.start()\n", | ||
"\n", | ||
"for split_name, st in zip(splits, [1094, 0, 0]):\n", | ||
" arr = json.load(open(os.path.join(data_folder, split_name + '.json'), 'r', encoding='utf-8'))\n", | ||
" print(f\"Working on {split_name}...\")\n", | ||
"\n", | ||
" for i in range(st, len(arr)):\n", | ||
" s = arr[i]\n", | ||
" task_queue.put((s, i, split_name))\n", | ||
"\n", | ||
"# Add None to the queue for each worker to signal them to exit\n", | ||
"for _ in range(num_workers):\n", | ||
" task_queue.put(None)\n", | ||
"\n", | ||
"# Wait for all workers to finish\n", | ||
"for worker in workers:\n", | ||
" worker.join()\n", | ||
"\n", | ||
"print(\"All workers have finished\")" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "cfb7d0e6eef90340" | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Dataset Preprocessing: Build" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "753a47ea274289b2" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"for split_name in splits:\n", | ||
" src = json.load(open(os.path.join(data_folder, split_name + '.json'), 'r', encoding='utf-8'))\n", | ||
" with open(os.path.join(interim_folder, split_name + '.out'), 'r', encoding='utf-8') as f:\n", | ||
" text = '\\n'.join(f.readlines())\n", | ||
" lines = text.split('\\n###')\n", | ||
" targets = [None for _ in range(len(src))]\n", | ||
" for line in lines[1:]:\n", | ||
" pos = line.find('@@ ')\n", | ||
" idx = int(line[:pos])\n", | ||
" aft = line[pos + 3:]\n", | ||
" targets[idx] = aft.strip()\n", | ||
" df = pd.DataFrame({'source': src, 'target': targets})\n", | ||
"\n", | ||
" # fix missing values by trying to translate again\n", | ||
" print(f\"Fixing missing values in {split_name}...\")\n", | ||
" # save the indices of missing values\n", | ||
" idx = df.loc[df['target'].isna()].index\n", | ||
" print(list(idx))\n", | ||
" df.loc[df['target'].isna(), 'target'] = [translate_openai(s) for s in tqdm(df.loc[df['target'].isna(), 'source'])]\n", | ||
"\n", | ||
" df.to_csv(os.path.join(processed_folder, split_name + '.csv'), index=False)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "ad2c459cb00b2897" | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Results have been saved to `data/processed` folder and uploaded to huggingface datasets as [unarxive-en2ru](https://huggingface.co/datasets/waleko/unarxive-en2ru)." | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "3bf0923602f361f9" | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,6 @@ coverage | |
awscli | ||
flake8 | ||
python-dotenv>=0.5.1 | ||
openai | ||
tqdm | ||
pandas |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from typing import List | ||
|
||
import openai | ||
|
||
default_prompt = """Translate the provided scientific texts from English to Russian. Before translation, please make sure to clean up any messy or illegible symbols within the texts. Ensure that | ||
there are no additional explanations or content in the translations. The outputs must contain precisely the same | ||
information as the cleaned-up inputs.""" | ||
|
||
|
||
def translate_openai(string: str, system_prompt=default_prompt): | ||
""" | ||
Translate a string from English to Russian using OpenAI's GPT-3 API | ||
:param string: String to translate | ||
:param system_prompt: The prompt to use for the system | ||
:return: A list of translated strings | ||
""" | ||
completion = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo", | ||
messages=[ | ||
{"role": "system", "content": system_prompt}, | ||
{"role": "user", "content": string}, | ||
] | ||
) | ||
return completion.choices[0].message.content |