Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example for porting BAAI/bge-small-en-v1.5 to ONNX with O4 Quantization #90

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,5 @@ fooling_around/fast-multilingual-e5-large/sentencepiece.bpe.model
fooling_around/fast-multilingual-e5-large/special_tokens_map.json
fooling_around/fast-multilingual-e5-large/tokenizer_config.json
fooling_around/fast-multilingual-e5-large/tokenizer.json
docs/examples/saved_models/*
docs/examples/local_cache/*
369 changes: 369 additions & 0 deletions docs/examples/To_ONNX.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,369 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7412714c",
"metadata": {},
"source": [
"# Porting to ONNX\n",
"\n",
"This notebook demonstrates how to port models from the Transformers/PyTorch ecosystem to ONNX. It is based on the [Optimum](https://github.com/huggingface/optimum) library.\n",
"\n",
"## Installation\n",
"\n",
"We use [poetry](https://python-poetry.org/docs/cli) to manage dependencies. To install the dependencies, run:\n",
"\n",
"```bash\n",
"poetry install\n",
"```\n",
"\n",
"Optimum releases are not always backward compatible with one another, so do not upgrade to the latest version. Instead, use the versions pinned in `pyproject.toml`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0e9dbcde",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c37e1fda-c7f1-46e7-a5d4-19fa05c36ac1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"from pathlib import Path\n",
"from typing import List, Tuple, Any\n",
"\n",
"import numpy as np\n",
"import time\n",
"from torch import Tensor\n",
"from transformers import AutoTokenizer, AutoModel\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv() # take environment variables from .env.\n",
"from optimum.onnxruntime import AutoOptimizationConfig, ORTModelForFeatureExtraction, ORTOptimizer, ORTModel\n",
"from optimum.pipelines import pipeline\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b1ecf0b6-db81-4da3-b47f-e31460ccfbf1",
"metadata": {},
"outputs": [],
"source": [
"# Load the tokenizer and PyTorch model from HuggingFace Transformers\n",
"model_id = \"BAAI/bge-small-en-v1.5\"\n",
"\n",
"hf_model = AutoModel.from_pretrained(model_id)\n",
"hf_tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"quantize = True\n",
"if quantize:\n",
" repository_id = f\"Qdrant/{model_id.split('/')[1]}-onnx-Q\"\n",
"else:\n",
" repository_id = f\"Qdrant/{model_id.split('/')[1]}-onnx\"\n",
"\n",
"save_dir = f\"local_cache/{repository_id}\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a38f5aed",
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): the prefix convention below (\"query: \" / \"passage: \") comes from\n",
"# multilingual E5-style models, where \"query: \" is also used for non-retrieval tasks.\n",
"# bge-small-en-v1.5 is an English-only model — verify the expected query\n",
"# instruction against its model card before reusing these inputs.\n",
"multilingual_queries = [\n",
"    \"query: how much protein should a female eat\",\n",
"    \"query: 南瓜的家常做法\",\n",
"    \"query: भारत का राष्ट्रीय खेल कौन-सा है?\", # Hindi text\n",
"    \"query: భారత్ దేశంలో రాష్ట్రపతి ఎవరు?\", # Telugu text\n",
"    \"query: இந்தியாவின் தேசிய கோப்பை எது?\", # Tamil text\n",
"    \"query: ಭಾರತದಲ್ಲಿ ರಾಷ್ಟ್ರಪತಿ ಯಾರು?\", # Kannada text\n",
"    \"query: ഇന്ത്യയുടെ രാഷ്ട്രീയ ഗാനം എന്താണ്?\", # Malayalam text\n",
"]\n",
"\n",
"english_texts = [\n",
"    \"India: Where the Taj Mahal meets spicy curry.\",\n",
"    \"Machine Learning: Turning data into knowledge, one algorithm at a time.\",\n",
"    \"Python: The language that makes programming a piece of cake.\",\n",
"    \"fastembed: Accelerating embeddings for lightning-fast similarity search.\",\n",
"    \"Qdrant: The ultimate tool for high-dimensional indexing and search.\",\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9f8c761c",
"metadata": {},
"outputs": [],
"source": [
"def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:\n",
"    \"\"\"Mean-pool token embeddings over the sequence axis, ignoring padded positions.\"\"\"\n",
"    # Zero out hidden states wherever the attention mask is 0 (padding).\n",
"    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)\n",
"    # Sum over tokens, then divide by the count of real (non-padded) tokens per sequence.\n",
"    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]\n",
"\n",
"\n",
"def hf_embed(model_id: str, inputs: List[str]):\n",
"    \"\"\"Embed `inputs` and return L2-normalized vectors as a numpy array.\n",
"\n",
"    NOTE: `model_id` is currently unused — the module-level `hf_tokenizer` and\n",
"    `hf_model` loaded earlier in this notebook are used instead.\n",
"    \"\"\"\n",
"    # Tokenize the input texts\n",
"    batch_dict = hf_tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors=\"pt\")\n",
"\n",
"    outputs = hf_model(**batch_dict)\n",
"    # NOTE(review): mean pooling is applied here — confirm this matches the\n",
"    # pooling recommended on the model card for this checkpoint.\n",
"    embeddings = average_pool(outputs.last_hidden_state, batch_dict[\"attention_mask\"])\n",
"\n",
"    # normalize embeddings\n",
"    embeddings = F.normalize(embeddings, p=2, dim=1)\n",
"    return embeddings.detach().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "69bb4501",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.05485763, 0.08136623, -0.00395789, ..., 0.02512371,\n",
" -0.03349504, -0.0593129 ],\n",
" [ 0.01078518, 0.01582215, 0.04614557, ..., -0.01674951,\n",
" -0.00244641, -0.06179965],\n",
" [-0.06607923, -0.01235531, -0.00689854, ..., 0.10634594,\n",
" 0.12025263, -0.05135345],\n",
" [-0.07568254, 0.00908228, -0.02221818, ..., 0.00177038,\n",
" -0.0325426 , 0.05233581],\n",
" [-0.07008213, 0.02070545, 0.02720274, ..., -0.01158645,\n",
" -0.01457597, 0.01262206]], dtype=float32)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_embed(inputs=english_texts, model_id=model_id)"
]
},
{
"cell_type": "markdown",
"id": "bc9d1594",
"metadata": {},
"source": [
"## Load the model using ORTModelForFeatureExtraction"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "451dbd16",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Framework not specified. Using pt to export to ONNX.\n",
"Using the export variant default. Available variants are:\n",
" - default: The default ONNX variant.\n",
"Using framework PyTorch: 2.1.2\n",
"Overriding 1 configuration item(s)\n",
"\t- use_cache -> False\n",
"/opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.11/site-packages/optimum/onnxruntime/configuration.py:770: FutureWarning: disable_embed_layer_norm will be deprecated soon, use disable_embed_layer_norm_fusion instead, disable_embed_layer_norm_fusion is set to True.\n",
" warnings.warn(\n",
"The argument use_external_data_format in the ORTOptimizer.optimize() method is deprecated and will be removed in optimum 2.0.\n",
"Optimizing model...\n",
"There is no gpu for onnxruntime to do optimization.\n",
"Configuration saved in local_cache/Qdrant/bge-small-en-v1.5-onnx-Q/ort_config.json\n",
"Optimized model saved at: local_cache/Qdrant/bge-small-en-v1.5-onnx-Q (external data format: False; saved all tensor to one file: True)\n"
]
},
{
"data": {
"text/plain": [
"('local_cache/Qdrant/bge-small-en-v1.5-onnx-Q/tokenizer_config.json',\n",
" 'local_cache/Qdrant/bge-small-en-v1.5-onnx-Q/special_tokens_map.json',\n",
" 'local_cache/Qdrant/bge-small-en-v1.5-onnx-Q/vocab.txt',\n",
" 'local_cache/Qdrant/bge-small-en-v1.5-onnx-Q/added_tokens.json',\n",
" 'local_cache/Qdrant/bge-small-en-v1.5-onnx-Q/tokenizer.json')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)\n",
"\n",
"# Remove all existing files in the save_dir using Path.unlink()\n",
"# NOTE: this assumes save_dir contains only regular files — Path.unlink()\n",
"# raises on subdirectories.\n",
"save_dir = Path(save_dir)\n",
"save_dir.mkdir(parents=True, exist_ok=True)\n",
"for p in save_dir.iterdir():\n",
"    p.unlink()\n",
"\n",
"# Load the optimization configuration detailing the optimization we wish to apply\n",
"# NOTE(review): O4 is the most aggressive optimization level; the stderr above\n",
"# reports no GPU was available — confirm O4 is the intended level for this export.\n",
"optimization_config = AutoOptimizationConfig.O4()\n",
"optimizer = ORTOptimizer.from_pretrained(model)\n",
"\n",
"optimizer.optimize(save_dir=save_dir, optimization_config=optimization_config, use_external_data_format=True)\n",
"model = ORTModelForFeatureExtraction.from_pretrained(save_dir)\n",
"tokenizer.save_pretrained(save_dir)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3587d3c4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"model_optimized.onnx: 100%|██████████| 66.5M/66.5M [00:09<00:00, 7.37MB/s]\n"
]
}
],
"source": [
"model.push_to_hub(save_directory=save_dir, repository_id=repository_id, use_auth_token=True)"
]
},
{
"cell_type": "markdown",
"id": "fde9f7a3",
"metadata": {},
"source": [
"## Trying out the model from Huggingface Hub"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "be18d371",
"metadata": {},
"outputs": [],
"source": [
"onnx_model = ORTModelForFeatureExtraction.from_pretrained(repository_id)\n",
"onnx_tokenizer = AutoTokenizer.from_pretrained(repository_id)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "532bd348",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[-0.2167, 0.0514, 0.0928, 0.1594, 0.2467, 0.3481, -0.0795, 0.1916,\n",
" 0.2227, -0.1297, 0.2020, -0.1873, 0.2221, 0.3651, 0.2194, -0.0692,\n",
" 0.1239, 0.2137, 0.0195, -0.2582, 0.2084, -0.1736, -0.0366, -0.2664,\n",
" -0.2339, 0.2233, -0.0657, -0.2686, -0.2866, -0.2278, 0.0309, 0.0677,\n",
" 0.2661, 0.1537, -0.0069, -0.3319, -0.3038, 0.2219, 0.3027, -0.2240,\n",
" -0.0523, 0.1749, -0.2705, 0.1487, -0.3244, -0.2069, -0.2114, -0.1821,\n",
" -0.1516, 0.2255, -0.2053, -0.2625, -0.0964, 0.3533, 0.2315, 0.1583,\n",
" 0.2405, -0.1198, -0.2908, 0.0707, 0.1949, 0.2105, -0.1731, 0.2771,\n",
" 0.2203, -0.1494, 0.0959, -0.1590, -0.1761, -0.0311, 0.3467, 0.2385,\n",
" 0.0964, 0.1245, 0.0470, -0.1691, -0.1228, -0.2064, -0.1982, -0.2398,\n",
" 0.0165, 0.0306, -0.1663, -0.0887, -0.1120, -0.2306, 0.1256, -0.2352,\n",
" 0.1686, -0.4168, -0.1018, -0.1619, -0.1757, -0.3001, -0.2155, -0.2885,\n",
" 0.1868, 0.1945, 0.0881, 0.6445, -0.3311, 0.1964, -0.0292, -0.2312,\n",
" -0.0040, -0.2386, -0.2235, -0.2751, -0.1402, -0.1766, 0.2363, -0.4445,\n",
" 0.2357, -0.0622, -0.0413, -0.0783, 0.1433, 0.2122, 0.2078, 0.1495,\n",
" -0.3446, 0.1067, 0.1613, 0.2959, 0.2626, -0.4740, -0.1082, 0.3239,\n",
" -0.1803, 0.0997, -0.2434, 0.2762, -0.2583, 0.3141, 0.0112, 0.1745,\n",
" 0.2053, -0.1768, -0.1194, 0.0859, -0.3041, -0.1982, 0.2145, -0.0532,\n",
" -0.2455, 0.2052, -0.1348, -0.2428, -0.2488, -0.2168, 0.2666, 0.2777,\n",
" -0.1524, -0.2402, 0.0133, 0.2093, 0.2458, -0.1074, -0.2419, -0.2991,\n",
" -0.1752, 0.0802, -0.2866, 0.2294, -0.2542, -0.2342, 0.3443, -0.3247,\n",
" -0.1445, -0.0290, 0.1546, 0.0302, -0.1828, 0.2438, 0.0751, 0.0462,\n",
" -0.1671, -0.3660, -0.0694, -0.0095, 0.1836, -0.1069, 0.0990, 0.2472,\n",
" 0.0868, -0.0425, 0.2824, -0.2537, 0.1205, 0.1889, -0.2124, 0.0136,\n",
" -0.1349, 0.0912, -0.0262, -0.2412, 0.2757, 0.2797, 0.0145, -0.2441,\n",
" 0.2213, -0.0217, -0.1446, 0.2326, 0.0941, 0.1203, -0.1906, -0.1431,\n",
" -0.0583, 0.2193, 0.1849, -0.0185, 0.0284, -0.1525, 0.2094, -0.1811,\n",
" 0.2608, 0.2625, -0.0797, 0.2194, -0.1662, -0.0763, -0.1854, -0.3313,\n",
" -0.1781, -0.1981, 0.1939, -0.2452, -0.2262, -0.0456, -0.1490, 0.2020,\n",
" 0.1871, -0.1022, 0.2299, -0.2127, 0.1315, 0.0536, -0.0557, 0.1494,\n",
" 0.2217, -0.3015, -0.3018, 0.1369, 0.1541, 0.2631, -0.2188, 0.1107,\n",
" -0.1987, 0.2343, 0.0984, -0.2731, -0.1080, -0.0538, -0.2362, -0.2376,\n",
" -0.1168, 0.2176, 0.0687, 0.0626, 0.1006, -0.1661, -0.1415, -0.1248,\n",
" -0.2594, 0.2187, -0.1395, -0.2153, -0.4498, -0.1936, -0.1451, -0.1676,\n",
" 0.2641, -0.1155, -0.0400, 0.2400, -0.3057, -0.0794, -0.2021, 0.0650,\n",
" 0.1192, -0.0939, 0.1083, 0.1714, -0.2247, -0.3154, -0.1720, 0.1512,\n",
" -0.2241, -0.3141, 0.2288, -0.1172, -0.1860, -0.2495, 0.0397, -0.0322,\n",
" 0.2313, 0.2013, 0.0296, 0.2663, 0.0793, 0.2103, -0.1596, 0.2250,\n",
" 0.1019, 0.2159, -0.2309, 0.1712, 0.2584, 0.1692, -0.0331, 0.1986,\n",
" 0.2638, 0.1509, 0.0965, -0.1657, -0.1721, 0.2107, 0.2588, -0.2725,\n",
" 0.1890, -0.0129, 0.0519, -0.1113, 0.2591, 0.2191, -0.2598, 0.2012,\n",
" 0.1182, -0.0599, 0.1190, 0.1726, -0.2416, -0.1674, -0.0343, 0.1782,\n",
" 0.0531, 0.3363, -0.0606, 0.1282, 0.3418, 0.2447, -0.3291, 0.1275,\n",
" -0.0873, 0.0312, -0.2825, -0.1609, -0.1207, -0.3035, -0.2454, 0.2148,\n",
" -0.2081, 0.0118, 0.2411, 0.2439, 0.1598, -0.3169, -0.2042, 0.2784,\n",
" 0.1087, -0.1713, 0.2854, 0.1956, 0.1949, 0.3777, -0.3244, 0.1883,\n",
" 0.0507, -0.1668, -0.0155, 0.0348, 0.1672, -0.1819, -0.1733, -0.1864,\n",
" -0.2057, 0.0486, -0.1979, 0.0774, -0.0145, -0.0991, -0.2121, 0.0822]]),\n",
" 'Qdrant: The ultimate tool for high-dimensional indexing and search.',\n",
" 5,\n",
" 5)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onnx_quant_embed = pipeline(\"feature-extraction\", model=onnx_model, accelerator=\"ort\", tokenizer=onnx_tokenizer,return_tensors=True)\n",
"embeddings = onnx_quant_embed(inputs=english_texts)\n",
"F.normalize(embeddings[4])[:,0], english_texts[4], len(embeddings), len(english_texts)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading