Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add vector type search demo #62

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions vector-type-search/get_started.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9d7ecadb-0138-4d13-b0f3-35e081d6aa89",
"metadata": {},
"source": [
"# Prepare Connection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39e406dc-f0fd-451e-8d17-309de4b2284b",
"metadata": {},
"outputs": [],
"source": [
"%pip install pymysql sentence-transformers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a1c9cdb-4112-4cf8-b258-d66a4761213a",
"metadata": {},
"outputs": [],
"source": [
"import pymysql\n",
"def get_connection():\n",
" connection = pymysql.connect(\n",
" host = \"127.0.0.1\",\n",
" port = 4002,\n",
" user = \"root\",\n",
" database = \"public\",\n",
" )\n",
" return connection\n",
"c = get_connection()\n",
"cursor = c.cursor()"
]
},
{
"cell_type": "markdown",
"id": "42acba48-3cd2-44ac-ad20-8756cbd18b45",
"metadata": {},
"source": [
"# Prepare Model"
]
},
{
"cell_type": "markdown",
"id": "596d1521-9771-4cd2-8fe7-1e03ba080202",
"metadata": {},
"source": [
"Note that loading the model may take tens of seconds"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc6a266b-7465-457d-9b9b-3084f0fc574a",
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"embed_model = SentenceTransformer(\"sentence-transformers/msmarco-MiniLM-L12-cos-v5\", trust_remote_code=True)\n",
"embed_model_dims = embed_model.get_sentence_embedding_dimension()\n",
"\n",
"def text_to_embedding(text):\n",
" \"\"\"Generates vector embeddings for the given text.\"\"\"\n",
" embedding = embed_model.encode(text)\n",
" return embedding.tolist()"
]
},
{
"cell_type": "markdown",
"id": "653e08ea-2c74-49cf-bf91-61d513a7d687",
"metadata": {},
"source": [
"# Create a Vector Table"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e385b401-e8f6-48e2-841f-3b87e8acbe72",
"metadata": {},
"outputs": [],
"source": [
"cursor.execute(f\"\"\"\n",
"CREATE TABLE IF NOT EXISTS embedded_documents(\n",
" ts TIMESTAMP TIME INDEX DEFAULT CURRENT_TIMESTAMP,\n",
" document TEXT PRIMARY KEY,\n",
" embedding VECTOR({embed_model_dims}));\n",
"\"\"\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "d5550669-be4e-4709-8b42-7d46fe74b28a",
"metadata": {},
"source": [
"# Store the Vector Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3dd60832-b193-45c2-b0c5-ffbb8023b128",
"metadata": {},
"outputs": [],
"source": [
"documents = [\"dog\", \"fish\", \"tree\"]\n",
"\n",
"def embedding_s(embedding):\n",
" return f\"[{','.join(map(str, embedding))}]\"\n",
"\n",
"def insert_doc(doc):\n",
" embedding = embedding_s(text_to_embedding(doc))\n",
" cursor.execute(f\"\"\"\n",
"INSERT INTO embedded_documents VALUES (DEFAULT, '{doc}', '{embedding}');\n",
" \"\"\");\n",
"\n",
"for doc in documents:\n",
" insert_doc(doc)"
]
},
{
"cell_type": "markdown",
"id": "c6089cb2-18d4-46ad-80f5-e21679f9c606",
"metadata": {},
"source": [
"# Inspect the Vector Table"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "321a8a4b-1b7c-4afe-9ade-7b6bf254b912",
"metadata": {},
"outputs": [],
"source": [
"cursor.execute(\"\"\"\n",
"SELECT * FROM embedded_documents;\n",
"\"\"\");"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "986913a1-32d5-456d-8061-4808a5eaea79",
"metadata": {},
"outputs": [],
"source": [
"for row in cursor:\n",
" print(row)"
]
},
{
"cell_type": "markdown",
"id": "c6230249-c506-4812-8e1e-e3ad6a3f3475",
"metadata": {},
"source": [
"# Search\n",
"\n",
"The search term is \"a swimming animal\" which vector embedding is `[1,2,3]`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "45a2b644-ecf8-4d48-92ba-fef739c9024d",
"metadata": {},
"outputs": [],
"source": [
"def search(query, k):\n",
" query_embedding = embedding_s(text_to_embedding(query))\n",
" cursor.execute(f\"\"\"\n",
"SELECT document, vec_cos_distance(embedding, '{query_embedding}') AS distance\n",
"FROM embedded_documents\n",
"ORDER BY distance\n",
"LIMIT {k};\n",
" \"\"\");\n",
" return cursor.fetchall()\n",
"\n",
"query = \"a swimming animal\"\n",
"res = search(query, 3)\n",
"for doc in res:\n",
" print(doc)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading