Skip to content

Commit

Permalink
Merge pull request #1 from RodionfromHSE/openai-translate
Browse files Browse the repository at this point in the history
  • Loading branch information
waleko authored Nov 5, 2023
2 parents 95de323 + 8227599 commit 3e7200b
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 3 deletions.
6 changes: 3 additions & 3 deletions data.dvc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
outs:
- md5: 3a55ac5aedd31587cd9b43fd9b2280d6.dir
size: 3806114
nfiles: 7
- md5: 255799c6a8913d73679631d546a9dd88.dir
nfiles: 13
hash: md5
path: data
size: 19936558
207 changes: 207 additions & 0 deletions notebooks/data/dataset_translate.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Dataset Preprocessing: Translation"
],
"metadata": {
"collapsed": false
},
"id": "84afd0d97d0e373c"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import os\n",
"import json\n",
"from getpass import getpass\n",
"import multiprocessing\n",
"\n",
"import pandas as pd\n",
"import openai\n",
"from tqdm import tqdm\n",
"\n",
"from src.data.translate_openai import translate_openai"
],
"metadata": {
"collapsed": false
},
"id": "ba690766e1bcc074"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Read the OpenAI API key interactively so it is never stored in the notebook\n",
"openai.api_key = getpass()"
],
"metadata": {
"collapsed": false
},
"id": "e8b67bc54845db96"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Paths are relative to notebooks/data/ — raw inputs, per-worker interim\n",
"# output files, and the final processed CSVs.\n",
"data_folder = '../../data/raw/unarxive_citrec'\n",
"interim_folder = '../../data/interim'\n",
"processed_folder = '../../data/processed'\n",
"splits = ['train', 'valid', 'test']"
],
"metadata": {
"collapsed": false
},
"id": "63b7dd3ab388b7c8"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"lock = multiprocessing.Lock()\n",
"\n",
"\n",
"# Worker: consume (text, index, split_name) tasks until a None sentinel\n",
"# arrives, translating each text and appending the result to the split's\n",
"# interim output file.\n",
"def worker_function(worker_id, task_queue):\n",
"    while True:\n",
"        task = task_queue.get()  # Blocks until a task is available\n",
"        if task is None:\n",
"            break  # Sentinel value: no more work, exit the loop\n",
"        s, i, split_name = task\n",
"        print(f\"Worker {worker_id} is processing: {s} at {i}\")\n",
"        try:\n",
"            tmp = translate_openai(s)\n",
"            # Context managers release the lock and close the file even if\n",
"            # write() raises; the previous manual acquire()/release() pair\n",
"            # (with an unmanaged file handle) would deadlock every worker\n",
"            # if an exception occurred mid-write.\n",
"            with lock:\n",
"                with open(os.path.join(interim_folder, split_name + '.out'), 'a', encoding='utf-8') as out:\n",
"                    # Record format: '\\n###<index>@@ <translation>' so the\n",
"                    # result can be matched back to its position later.\n",
"                    out.write(f\"\\n###{i}@@ {tmp}\")\n",
"        except Exception as e:\n",
"            # Log and continue: one failed translation must not kill the worker.\n",
"            print(f\"Fail!! pos={i}\")\n",
"            print(e)"
],
"metadata": {
"collapsed": false
},
"id": "514d22eab7752d6c"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"num_workers = 8\n",
"\n",
"# Create and start worker processes\n",
"workers = []\n",
"task_queue = multiprocessing.Queue()\n",
"for worker_id in range(num_workers):\n",
" worker = multiprocessing.Process(target=worker_function, args=(worker_id, task_queue))\n",
" workers.append(worker)\n",
" worker.start()\n",
"\n",
"# NOTE(review): st is a hard-coded resume offset (train restarts at item 1094\n",
"# after an earlier partial run); set all offsets to 0 for a fresh full run.\n",
"for split_name, st in zip(splits, [1094, 0, 0]):\n",
" arr = json.load(open(os.path.join(data_folder, split_name + '.json'), 'r', encoding='utf-8'))\n",
" print(f\"Working on {split_name}...\")\n",
"\n",
" # Enqueue one task per remaining item; workers append results to the\n",
" # split's .out file in the interim folder as they finish.\n",
" for i in range(st, len(arr)):\n",
" s = arr[i]\n",
" task_queue.put((s, i, split_name))\n",
"\n",
"# Add None to the queue for each worker to signal them to exit\n",
"for _ in range(num_workers):\n",
" task_queue.put(None)\n",
"\n",
"# Wait for all workers to finish\n",
"for worker in workers:\n",
" worker.join()\n",
"\n",
"print(\"All workers have finished\")"
],
"metadata": {
"collapsed": false
},
"id": "cfb7d0e6eef90340"
},
{
"cell_type": "markdown",
"source": [
"## Dataset Preprocessing: Build"
],
"metadata": {
"collapsed": false
},
"id": "753a47ea274289b2"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Rebuild each split by pairing the source texts with the translations the\n",
"# workers appended to the interim .out files, then retry any gaps.\n",
"for split_name in splits:\n",
"    src = json.load(open(os.path.join(data_folder, split_name + '.json'), 'r', encoding='utf-8'))\n",
"    with open(os.path.join(interim_folder, split_name + '.out'), 'r', encoding='utf-8') as f:\n",
"        # f.read() keeps the file's own newlines; the previous\n",
"        # '\\n'.join(f.readlines()) doubled every newline, because\n",
"        # readlines() already retains each line's trailing '\\n'.\n",
"        text = f.read()\n",
"    # Records were written as '\\n###<index>@@ <translation>', so every\n",
"    # element after the first is '<index>@@ <translation>'.\n",
"    lines = text.split('\\n###')\n",
"    targets = [None for _ in range(len(src))]\n",
"    for line in lines[1:]:\n",
"        pos = line.find('@@ ')\n",
"        idx = int(line[:pos])\n",
"        aft = line[pos + 3:]\n",
"        targets[idx] = aft.strip()\n",
"    df = pd.DataFrame({'source': src, 'target': targets})\n",
"\n",
"    # Fix missing values (failed/never-run translations) by translating again.\n",
"    print(f\"Fixing missing values in {split_name}...\")\n",
"    # Compute the missing-row mask once instead of three times.\n",
"    missing = df['target'].isna()\n",
"    print(list(df.loc[missing].index))\n",
"    df.loc[missing, 'target'] = [translate_openai(s) for s in tqdm(df.loc[missing, 'source'])]\n",
"\n",
"    df.to_csv(os.path.join(processed_folder, split_name + '.csv'), index=False)"
],
"metadata": {
"collapsed": false
},
"id": "ad2c459cb00b2897"
},
{
"cell_type": "markdown",
"source": [
"Results have been saved to `data/processed` folder and uploaded to huggingface datasets as [unarxive-en2ru](https://huggingface.co/datasets/waleko/unarxive-en2ru)."
],
"metadata": {
"collapsed": false
},
"id": "3bf0923602f361f9"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ coverage
awscli
flake8
python-dotenv>=0.5.1
openai
tqdm
pandas
24 changes: 24 additions & 0 deletions src/data/translate_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import List

import openai

default_prompt = """Translate the provided scientific texts from English to Russian. Before translation, please make sure to clean up any messy or illegible symbols within the texts. Ensure that
there are no additional explanations or content in the translations. The outputs must contain precisely the same
information as the cleaned-up inputs."""


def translate_openai(string: str, system_prompt: str = default_prompt) -> str:
    """
    Translate a string from English to Russian using OpenAI's chat completion API.

    The original docstring claimed this uses "GPT-3" and returns "a list of
    translated strings"; it actually calls the gpt-3.5-turbo chat endpoint and
    returns a single string.

    :param string: English text to translate
    :param system_prompt: System prompt instructing the model how to translate
    :return: The translated text as a single string
    """
    # The system message carries the translation instructions; the user
    # message carries the text to translate.
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": string},
        ]
    )
    return completion.choices[0].message.content

0 comments on commit 3e7200b

Please sign in to comment.