-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
50aa538
commit 58813d1
Showing
2 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 20, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from omegaconf import OmegaConf\n", | ||
"\n", | ||
"config = {\n", | ||
" 'dataset': 'saier/unarxive_citrec',\n", | ||
" 'n_train': 10_000,\n", | ||
" 'n_valid': 1_000,\n", | ||
" 'n_test': 1_000,\n", | ||
" 'max_chars_len': 512,\n", | ||
" 'min_chars_len': 128,\n", | ||
" 'save_dir': '../../data/raw/unarxive_citrec/'\n", | ||
"}\n", | ||
"config = OmegaConf.create(config)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 21, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "TypeError", | ||
"evalue": "take_n_samples() got an unexpected keyword argument 'split'", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | ||
"\u001b[1;32m/Users/user010/Desktop/Programming/ML/En2RuTranslator/notebooks/data/dataset_download.ipynb Cell 2\u001b[0m line \u001b[0;36m2\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/user010/Desktop/Programming/ML/En2RuTranslator/notebooks/data/dataset_download.ipynb#W0sZmlsZQ%3D%3D?line=15'>16</a>\u001b[0m bar\u001b[39m.\u001b[39mupdate(\u001b[39mlen\u001b[39m(new_samples))\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/user010/Desktop/Programming/ML/En2RuTranslator/notebooks/data/dataset_download.ipynb#W0sZmlsZQ%3D%3D?line=17'>18</a>\u001b[0m \u001b[39mreturn\u001b[39;00m samples\n\u001b[0;32m---> <a href='vscode-notebook-cell:/Users/user010/Desktop/Programming/ML/En2RuTranslator/notebooks/data/dataset_download.ipynb#W0sZmlsZQ%3D%3D?line=19'>20</a>\u001b[0m train_samples \u001b[39m=\u001b[39m take_n_samples(config\u001b[39m.\u001b[39;49mn_train, split\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mtrain\u001b[39;49m\u001b[39m'\u001b[39;49m)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/user010/Desktop/Programming/ML/En2RuTranslator/notebooks/data/dataset_download.ipynb#W0sZmlsZQ%3D%3D?line=20'>21</a>\u001b[0m valid_samples \u001b[39m=\u001b[39m take_n_samples(config\u001b[39m.\u001b[39mn_valid, split\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mvalidation\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/user010/Desktop/Programming/ML/En2RuTranslator/notebooks/data/dataset_download.ipynb#W0sZmlsZQ%3D%3D?line=21'>22</a>\u001b[0m test_samples \u001b[39m=\u001b[39m take_n_samples(config\u001b[39m.\u001b[39mn_test, split\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mtest\u001b[39m\u001b[39m'\u001b[39m)\n", | ||
"\u001b[0;31mTypeError\u001b[0m: take_n_samples() got an unexpected keyword argument 'split'" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from datasets import load_dataset\n", | ||
"from tqdm import tqdm\n", | ||
"\n", | ||
"# Load the dataset in streaming mode\n", | ||
"dataset = load_dataset(config.dataset, split='train', streaming=True)\n", | ||
"\n", | ||
"def take_n_samples(n: int, split: str, batch_size: int = 250) -> list:\n", | ||
" dataset = load_dataset(config.dataset, split=split, streaming=True)\n", | ||
" samples = []\n", | ||
" bar = tqdm(total=n)\n", | ||
" while len(samples) < n:\n", | ||
" new_samples = dataset.take(batch_size)\n", | ||
" new_samples = list(filter(lambda x: config.min_chars_len <= len(x['text']) <= config.max_chars_len, new_samples))\n", | ||
" samples.extend(new_samples)\n", | ||
" bar.update(len(new_samples))\n", | ||
"\n", | ||
" return samples\n", | ||
"\n", | ||
"train_samples = take_n_samples(config.n_train, split='train')\n", | ||
"valid_samples = take_n_samples(config.n_valid, split='validation')\n", | ||
"test_samples = take_n_samples(config.n_test, split='test')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def extract_texts(samples):\n", | ||
" return [sample['text'] for sample in samples]\n", | ||
"\n", | ||
"train_texts = extract_texts(train_samples)\n", | ||
"valid_texts = extract_texts(valid_samples)\n", | ||
"test_texts = extract_texts(test_samples)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"There is a strand of literature on continuous-action games on networks in which each player takes an action represented by a real value \\(x\\ge 0\\) [1]}, [2]}. Typically, player \\(i\\) maximizes the following quadratic utility function\n", | ||
"\\(u_i(x_i;{\\bf {x}}_{-i}) = \\alpha x_i - \\frac{1}{2}x_i^2 +\\gamma \\sum _{j\\ne i} \\mathcal {A}_{ij}x_ix_j,\\) \n", | ||
"\n", | ||
"There is a strand of literature on continuous-action games on networks in which each player takes an action represented by a real value \\(x\\ge 0\\) [1]}, [2]}. Typically, player \\(i\\) maximizes the following quadratic utility function\n", | ||
"\\(u_i(x_i;{\\bf {x}}_{-i}) = \\alpha x_i - \\frac{1}{2}x_i^2 +\\gamma \\sum _{j\\ne i} \\mathcal {A}_{ij}x_ix_j,\\) \n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(valid_texts[0])\n", | ||
"print(test_texts[0])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"0.25, 0.5, 0.75 quantile: [289. 382. 458.25]\n", | ||
"Max len: 508\n", | ||
"Min len: 145\n", | ||
"Example: Theorem B (Equivalent version of Beurling's Theorem, [1]}). \n", | ||
"A closed subspace of \\(H^{2}\\) is shift-invariant iff it is invariant under multiplication by every bounded analytic function in \\(H^{\\infty }\\) .\n", | ||
"\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"train_lens = np.array([len(text) for text in train_texts])\n", | ||
"\n", | ||
"# 0.25 quantile, 0.5 quantile, 0.75 quantile\n", | ||
"print(\"0.25, 0.5, 0.75 quantile:\", np.quantile(train_lens, [0.25, 0.5, 0.75]))\n", | ||
"print(\"Max len:\", np.max(train_lens))\n", | ||
"print(\"Min len:\", np.min(train_lens))\n", | ||
"print(\"Example:\", train_texts[np.random.randint(0, len(train_texts))])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import json\n", | ||
"\n", | ||
"for texts, split_name in [\n", | ||
" (train_texts, 'train'),\n", | ||
" (valid_texts, 'valid'),\n", | ||
" (test_texts, 'test')\n", | ||
"]:\n", | ||
" path = os.path.join(config.save_dir, split_name + '.json')\n", | ||
" with open(path, 'w') as f:\n", | ||
" json.dump(texts, f, indent=4, ensure_ascii=False)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(False, False)" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"PROMPT = \"\"\"\\\n", | ||
"Ты профессиональный тестировщик больших языковых моделей.\n", | ||
"Сейчас твоя задача составить запросы, которые требуют от модели **сгенерировать изображение** (картину или фото).\n", | ||
"Эти запросы должны использовать **как явные инструкции, так и намёки**. Запросы должны быть **разнообразными** и иметь **разный уровень формальности**.\n", | ||
"\n", | ||
"Сгенирируй мне 10 таких запросов.\n", | ||
"\n", | ||
"Примеры:\n", | ||
"Нарисуй, пожалуйста, фотоаппарат марки «Зенит» с красивым плетёным ремешком.\n", | ||
"а можешь плиз нарисовать как мальчик и девочка на пляже строят замок из песка?\n", | ||
"Изобрази мне кота Матроскина, который играет на гитаре.\n", | ||
"фото как спичка горит, а кругом тают кубики льда\n", | ||
"сделай мне иллюстрацию к маленькому принцу где он с розой разговаривает\n", | ||
"Сделаешь картинку площади трех вокзалов в Москве?\n", | ||
"хочу картинку с аниме девочкой\n", | ||
"покажи мне портрет Иосифа Сталина\n", | ||
"\n", | ||
"Твои запросы:\n", | ||
"\"\"\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip3 install openai python-dotenv" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dotenv import load_dotenv\n", | ||
"import openai\n", | ||
"import time\n", | ||
"import numpy as np\n", | ||
"import os\n", | ||
"path_to_env = os.path.join('..', '.env')\n", | ||
"load_dotenv()\n", | ||
"\n", | ||
"\n", | ||
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n", | ||
"\n", | ||
"class QuestionGenerator:\n", | ||
" def __init__(self, query: str, max_queries: int = 3):\n", | ||
" self.query = query\n", | ||
" self.max_queries = max_queries\n", | ||
" \n", | ||
" def send_query(self):\n", | ||
" response = None\n", | ||
" for _ in range(self.max_queries):\n", | ||
" try:\n", | ||
" response = openai.Completion.create(\n", | ||
" model=\"text-babbage-001\",\n", | ||
" prompt=self.query,\n", | ||
" temperature=0.7,\n", | ||
" max_tokens=100,\n", | ||
" top_p=0.6,\n", | ||
" frequency_penalty=0.5,\n", | ||
" presence_penalty=0.0\n", | ||
" )\n", | ||
" # random sleep seconds \n", | ||
" time.sleep(np.random.randint(1, 5))\n", | ||
" break\n", | ||
" except Exception as e:\n", | ||
" print('Error', e)\n", | ||
" \n", | ||
" return response\n", | ||
" \n", | ||
" def parse_response(self, response):\n", | ||
" if response is None:\n", | ||
" return []\n", | ||
" return response['choices'][0]['text'].strip().lower().split(', ')\n", | ||
" \n", | ||
" def __call__(self):\n", | ||
" response = self.send_query()\n", | ||
" samples = self.get_topics(response)\n", | ||
" return samples" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"qg = QuestionGenerator(PROMPT)\n", | ||
"qg()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |