diff --git a/.gitignore b/.gitignore
index b039fb9..3addd7b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,4 +135,4 @@ dmypy.json
 data/*.json
 data/*.jsonl
 dbs/meilisearch/meili_data
-dbs/qdrant/onnx_model/onnx
\ No newline at end of file
+*/*/onnx_model/onnx
\ No newline at end of file
diff --git a/dbs/weaviate/.env.example b/dbs/weaviate/.env.example
new file mode 100644
index 0000000..b64ef4e
--- /dev/null
+++ b/dbs/weaviate/.env.example
@@ -0,0 +1,13 @@
+WEAVIATE_VERSION = "1.18.4"
+WEAVIATE_PORT = 8080
+WEAVIATE_HOST = "localhost"
+WEAVIATE_SERVICE = "weaviate"
+API_PORT = 8005
+EMBEDDING_MODEL_CHECKPOINT = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
+ONNX_MODEL_FILENAME = "model_optimized_quantized.onnx"
+
+# Container image tag
+TAG = "0.1.0"
+
+# Docker project namespace (defaults to the current folder name if not set)
+COMPOSE_PROJECT_NAME = weaviate_wine
\ No newline at end of file
diff --git a/dbs/weaviate/Dockerfile b/dbs/weaviate/Dockerfile
new file mode 100644
index 0000000..2a47f56
--- /dev/null
+++ b/dbs/weaviate/Dockerfile
@@ -0,0 +1,13 @@
+FROM python:3.10-slim-bullseye
+
+WORKDIR /wine
+
+COPY ./requirements.txt /wine/requirements.txt
+
+RUN pip install --no-cache-dir -U pip wheel setuptools
+RUN pip install --no-cache-dir -r /wine/requirements.txt
+
+COPY ./api /wine/api
+COPY ./schemas /wine/schemas
+
+EXPOSE 8000
\ No newline at end of file
diff --git a/dbs/weaviate/Dockerfile.onnxruntime b/dbs/weaviate/Dockerfile.onnxruntime
new file mode 100644
index 0000000..f62d37c
--- /dev/null
+++ b/dbs/weaviate/Dockerfile.onnxruntime
@@ -0,0 +1,14 @@
+FROM python:3.10-slim-bullseye
+
+WORKDIR /wine
+
+COPY ./requirements-onnx.txt /wine/requirements-onnx.txt
+
+RUN pip install --no-cache-dir -U pip wheel setuptools
+RUN pip install --no-cache-dir -r /wine/requirements-onnx.txt
+
+COPY ./api /wine/api
+COPY ./schemas /wine/schemas
+COPY ./onnx_model /wine/onnx_model
+
+EXPOSE 8000
\ No newline at end of file
diff --git a/dbs/weaviate/README.md b/dbs/weaviate/README.md
new file mode 100644
index 0000000..dfebb47
--- /dev/null
+++ b/dbs/weaviate/README.md
@@ -0,0 +1,244 @@
+# Weaviate
+
+[Weaviate](https://weaviate.io/) is an ML-first vector search database written in Go. It allows users to store data objects and vector embeddings, scaling to billions of objects with sub-millisecond search results. The primary use case for a vector database is to retrieve results that are most semantically similar to the input natural language query. The semantic similarity is obtained by comparing the sentence embeddings (which are n-dimensional vectors) between the input query and the data stored in the database. Most vector DBs, including Weaviate, store both the metadata (as JSON) and the sentence embeddings of the text on which we want to search (as vectors), allowing us to perform much more flexible searches than keyword-only search databases. Weaviate even allows hybrid searches, giving developers the flexibility to decide which search methods work best on the data at hand.
+
+Code is provided for ingesting the wine reviews dataset into Weaviate. In addition, a query API written in FastAPI is provided, allowing a user to query the available endpoints. As always in FastAPI, documentation is available via OpenAPI (http://localhost:8005/docs).
+
+* Unlike "normal" databases, in a vector DB, the vectorization process is the biggest bottleneck, and because a lot of vector DBs are relatively new, they do not yet support async indexing (although they might, soon).
+  * It doesn't make sense to focus on async requests for vector DBs at present -- rather, it makes more sense to focus on speeding up the vectorization process
+* [Pydantic](https://docs.pydantic.dev) is used for schema validation, both prior to data ingestion and during API request handling
+* For ease of reproducibility during development, the whole setup is orchestrated and deployed via Docker
+
+## Setup
+
+Note that this code base has been tested in Python 3.10, and requires a minimum of Python 3.10 to work. Install dependencies via `requirements.txt`.
+
+```sh
+# Set up the environment for the first time
+python -m venv weaviate_venv  # python -> python 3.10
+
+# Activate the environment (for subsequent runs)
+source weaviate_venv/bin/activate
+
+python -m pip install -r requirements.txt
+```
+
+---
+
+## Step 1: Set up containers
+
+Docker compose files are provided, which start a persistent-volume Weaviate database with credentials specified in `.env`. The `WEAVIATE_SERVICE` variable in the environment file tells the `fastapi` service (which runs in a separate container) how to reach the database service downstream. Both containers communicate with one another over the common network that they share, on the exact port numbers specified.
+
+The database and API services can be restarted at any time for maintenance and updates by simply running the `docker restart <container_name>` command.
+
+**💡 Note:** The setup shown here would not be ideal in production, as there are other details related to security and scalability that are not addressed via simple Docker setups, but it is a good starting point to begin experimenting!
+
+### Option 1: Use `sbert` model
+
+If using the `sbert` model [from the sentence-transformers repo](https://www.sbert.net/) directly, use the provided `docker-compose.yml` to initiate separate containers: one that runs Weaviate, and another that serves as an API on top of the database.
+
+**⚠️ Note**: This approach will attempt to run `sbert` on a GPU if available, and if not, on CPU (while utilizing all CPU cores). This approach may not yield the fastest vectorization if using CPU-only -- a more optimized version is provided [below](#option-2-use-onnxruntime-model).
+
+```
+docker compose -f docker-compose.yml up -d
+```
+Tear down the services using the following command.
+
+```
+docker compose -f docker-compose.yml down
+```
+
+### Option 2: Use `onnxruntime` model
+
+An approach to make the sentence embedding vector generation process more efficient is to optimize and quantize the original `sbert` model via [ONNX (Open Neural Network Exchange)](https://huggingface.co/docs/transformers/serialization). This framework provides a standard interface for optimizing deep learning models and their computational graphs, so that they can be executed much faster and with fewer resources on specialized runtimes and hardware.
+
+To deploy the services with the optimized `sbert` model, use the provided `docker-compose-onnx.yml` to initiate separate containers: one that runs Weaviate, and another that serves as an API on top of the database.
+
+**⚠️ Note**: This approach requires additional packages from Hugging Face, on top of the `sbert` modules. **Currently (as of early 2023), they only work on Python 3.10**. For this section, make sure to use Python 3.10 only, as the ONNX dependencies may fail to install via `pip` on newer versions.
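+
+**💡 Tip:** With either option, once the containers are up, you can confirm that the Weaviate database is ready to accept requests by querying its readiness endpoint (part of Weaviate's standard REST API) on the `WEAVIATE_PORT` specified in `.env`:
+
+```sh
+curl http://localhost:8080/v1/.well-known/ready
+```
+
+A `2xx` response means the database is ready to index and serve queries.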
+
+```
+docker compose -f docker-compose-onnx.yml up -d
+```
+Tear down the services using the following command.
+
+```
+docker compose -f docker-compose-onnx.yml down
+```
+
+
+## Step 2: Ingest the data
+
+We ingest both the JSON data (for full-text search and filtering) and the sentence embedding vectors (for similarity search) into Weaviate. For this dataset, it's reasonable to expect that a simple concatenation of fields like `title`, `country`, `province`, `variety` and `description` would result in a useful vector that can be compared against a search query, also vectorized in the same embedding space.
+
+As an example, consider the following data snippet from the `data/` directory in this repo:
+
+```json
+"variety": "Red Blend",
+"country": "Italy",
+"province": "Tuscany",
+"title": "Castello San Donato in Perano 2009 Riserva (Chianti Classico)",
+"description": "Made from a blend of 85% Sangiovese and 15% Merlot, this ripe wine delivers soft plum, black currants, clove and cracked pepper sensations accented with coffee and espresso notes. A backbone of firm tannins give structure. Drink now through 2019."
+```
+
+The above fields are concatenated for vectorization, and then indexed along with the data to Weaviate.
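+
+To make this concrete, below is a minimal sketch of how a single review can be vectorized with the `sentence-transformers` checkpoint used in this repo -- the field values come from the snippet above, and the concatenation mirrors the `to_vectorize` logic in `schemas/wine.py` (the description is truncated here for brevity):
+
+```python
+from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
+
+# Concatenate the searchable fields in the same order as the data loader:
+# variety, country, province, title, description
+fields = [
+    "Red Blend",
+    "Italy",
+    "Tuscany",
+    "Castello San Donato in Perano 2009 Riserva (Chianti Classico)",
+    "Made from a blend of 85% Sangiovese and 15% Merlot, ...",
+]
+vector = model.encode(" ".join(fields))
+print(vector.shape)  # (384,) -> this model produces 384-dimensional embeddings
+```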
+
+### Choice of embedding model
+
+[SentenceTransformers](https://www.sbert.net/) is a Python framework for a range of sentence and text embeddings. It results from extensive work on fine-tuning BERT to work well on semantic similarity tasks using Siamese BERT networks, where the model is trained to predict the similarity between sentence pairs. The original work is [described here](https://arxiv.org/abs/1908.10084).
+
+#### Why use sentence transformers?
+
+Although larger and more powerful text embedding models exist (such as [OpenAI embeddings](https://platform.openai.com/docs/guides/embeddings)), they are not free and charge per token of text, which can become really expensive at scale. SentenceTransformers are free and open-source, and have been optimized over the years to utilize all CPU cores and to reduce model size while maintaining quality. A full list of sentence transformer models [is in the project page](https://www.sbert.net/docs/pretrained_models.html).
+
+For this work, it makes sense to use one of the fastest models in this list, the `multi-qa-MiniLM-L6-cos-v1` **uncased** model. As per the docs, it was tuned for semantic search and question answering, and generates sentence embeddings for single sentences or paragraphs up to a maximum sequence length of 512. It was trained on 215M question answer pairs from various sources. Compared to the more general-purpose `all-MiniLM-L6-v2` model, it shows slightly improved performance on semantic search tasks while running at a similar speed. [See the sbert docs](https://www.sbert.net/docs/pretrained_models.html) for more details on performance comparisons between the various pretrained models.
+
+### Build ONNX optimized model files
+
+A key step, if using ONNX runtime to speed up vectorization, is to build optimized and quantized models from the base `sbert` model. This is done by running the script `onnx_optimizer.py` in the `onnx_model/` directory.
+
+The optimization and quantization are done using a modified version of [the methods in this blog post](https://www.philschmid.de/optimize-sentence-transformers). We only perform dynamic quantization for now, as static quantization requires a very hardware- and OS-specific set of instructions that don't generalize -- it only makes sense to do this in a production environment that is expected to serve thousands of requests in a short time. As further reading, a detailed explanation of the difference between static and dynamic quantization [is available in the Hugging Face docs](https://huggingface.co/docs/optimum/concept_guides/quantization).
+
+```sh
+cd onnx_model
+python onnx_optimizer.py  # python -> python 3.10
+```
+
+Running this script generates a new directory `onnx_model/onnx` with the optimized and quantized models, along with their associated model config files.
+
+* `model_optimized.onnx`
+* `model_optimized_quantized.onnx`
+
+The `model_optimized_quantized.onnx` is a dynamically-quantized model file that is ~26% smaller in size than the original model in this case, and generates sentence embeddings roughly 1.8x faster than the original sentence transformers model, due to the optimized ONNX runtime. A more detailed blog post benchmarking these numbers will be published shortly!
+
+### Run data loader
+
+Data is ingested into the Weaviate database through the scripts in the `scripts` directory. The scripts validate the input JSON data via [Pydantic](https://docs.pydantic.dev), and then index both the JSON data and the vectors to Weaviate using the [Weaviate Python client](https://github.com/weaviate/weaviate-python-client).
+
+As mentioned before, the fields `variety`, `country`, `province`, `title` and `description` are concatenated, vectorized, and then indexed to Weaviate.
+
+#### Option 1: Use `sbert`
+
+If running on a MacBook or a machine without a GPU, it's possible to generate sentence embeddings using the original `sbert` model as per the `EMBEDDING_MODEL_CHECKPOINT` variable in the `.env` file.
+
+```sh
+cd scripts
+python bulk_index_sbert.py
+```
+
+#### Option 2: Use `onnx` quantized model
+
+If running on a remote Linux CPU instance, it is highly recommended to use the quantized ONNX version of the `EMBEDDING_MODEL_CHECKPOINT` model specified in `.env`. On the appropriate hardware (the quantization config targets AVX-512 VNNI instructions on modern Intel chips), it can vastly outperform the original `sbert` model on a conventional CPU, allowing for lower-cost and higher-throughput indexing for much larger datasets, all with very low memory consumption (under 2 GB).
+
+```sh
+cd scripts
+python bulk_index_onnx.py
+```
+
+### Time to index dataset
+
+Because vectorizing a large dataset can be an expensive step, part of the goal of this exercise is to see whether we can do so on CPU, with the fewest resources possible.
+
+In short, we are able to index all 129,971 wine reviews from the dataset in **28 min 30 sec**. The conditions under which this indexing time was achieved are listed below.
+
+* Ubuntu 22.04 EC2 `t2.xlarge` instance on AWS (1 CPU with 4 cores, 16 GB of RAM)
+* Python 3.10.10 (did not use Python 3.11 because ONNX doesn't support it yet)
+* Quantized ONNX version of the `sentence-transformers/multi-qa-MiniLM-L6-cos-v1` sentence transformer
+* Weaviate version `1.18.4`
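+
+If you want to test the pipeline on a smaller subset before committing to the full dataset, both loader scripts accept optional arguments (defined via `argparse` at the bottom of each script):
+
+```sh
+cd scripts
+# Index only the first 1000 reviews, in chunks of 512
+python bulk_index_onnx.py --limit 1000 --chunksize 512
+```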
+
+## Step 3: Test API
+
+Once the data has been successfully loaded into Weaviate and the containers are up and running, we can test out a search query via an HTTP request as follows.
+
+```sh
+curl -X 'GET' \
+  'http://0.0.0.0:8005/wine/search?terms=tuscany%20red&max_price=100&country=Italy'
+```
+
+This cURL request passes the search terms "**tuscany red**", along with the country "Italy" and a maximum price of "100" to the `/wine/search/` endpoint, which is then parsed into a working filter query to Weaviate by the FastAPI backend. The query runs and retrieves results that are semantically similar to the input query for red Tuscan wines, and, if the setup was done correctly, we should see the following response:
+
+```json
+[
+  {
+    "wineID": 8456,
+    "country": "Italy",
+    "province": "Tuscany",
+    "title": "Petra 2008 Petra Red (Toscana)",
+    "description": "From one of Italy's most important showcase designer wineries, this blend of Cabernet Sauvignon and Merlot lives up to its super Tuscan celebrity. It is gently redolent of dark chocolate, ripe fruit, leather, tobacco and crushed black pepper—the bouquet's elegant moderation is one of its strongest points. The mouthfeel is rich, creamy and long. Drink after 2018.",
+    "points": 92,
+    "price": 80.0,
+    "variety": "Red Blend",
+    "winery": "Petra"
+  },
+  {
+    "wineID": 896,
+    "country": "Italy",
+    "province": "Tuscany",
+    "title": "Le Buche 2006 Giuseppe Olivi Memento Red (Toscana)",
+    "description": "Le Buche is an interesting winery to watch, and its various Tuscan blends show great promise. Memento is equal parts Sangiovese and Syrah with a soft, velvety texture and a bright berry finish.",
+    "points": 90,
+    "price": 45.0,
+    "variety": "Red Blend",
+    "winery": "Le Buche"
+  },
+  {
+    "wineID": 9343,
+    "country": "Italy",
+    "province": "Tuscany",
+    "title": "Poggio Mandorlo 2008 Red (Toscana)",
+    "description": "Made from Merlot and Cabernet Franc, this structured red offers aromas of black currant, toast, graphite and a whiff of cedar. The firm palate offers coconut, coffee, grilled sage and red berry alongside bracing tannins. Drink sooner rather than later to capture the fruit richness.",
+    "points": 89,
+    "price": 60.0,
+    "variety": "Red Blend",
+    "winery": "Poggio Mandorlo"
+  }
+]
+```
+
+Not bad! This example correctly returns some highly rated Tuscan red wines from Italy, along with their prices. More specific search queries, such as those mentioning low/high acidity or particular flavour profiles, can also be entered to get more relevant results by country.
+
+## Step 4: Extend the API
+
+The API can be easily extended with the provided structure.
+
+- The `schemas` directory houses the Pydantic schemas, both for the data input as well as for the endpoint outputs
+  - As the data model gets more complex, we can add more files and separate the ingestion logic from the API logic here
+- The `api/routers` directory contains the endpoint routes so that we can provide additional endpoints that answer more business questions (a sketch of one such router is shown below)
+  - For example: "What are the top rated wines from Argentina?"
+  - In general, it makes sense to organize specific business use cases into their own router files
+- The `api/main.py` file collects all the routes and schemas to run the API
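+
+As a minimal sketch (not existing code in this repo), a new router answering the Argentina-style question above could look like the following. It reuses the Weaviate client attached to the app at startup, and assumes the installed `weaviate-client` version supports result sorting via `with_sort` (Weaviate itself supports sorting from v1.13 onwards):
+
+```python
+from fastapi import APIRouter, Query, Request
+
+top_rated_router = APIRouter()
+
+
+@top_rated_router.get("/top_rated_by_country")
+def top_rated_by_country(
+    request: Request,
+    country: str = Query(description="Country name to get top rated wines for"),
+) -> list[dict] | None:
+    # Sort by points in descending order instead of by vector similarity
+    response = (
+        request.app.client.query.get(
+            "Wine",
+            ["wineID", "title", "country", "points", "price"],
+        )
+        .with_where({"path": "country", "operator": "Equal", "valueText": country})
+        .with_sort({"path": ["points"], "order": "desc"})
+        .with_limit(5)
+        .do()
+    )
+    return response["data"]["Get"]["Wine"]
+```
+
+The new router would then be registered in `api/main.py` via `app.include_router(top_rated_router, prefix="/wine", tags=["wine"])`.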
+
+#### Existing endpoints
+
+As an example, some search endpoints are implemented and can be accessed via the API at the following URLs.
+
+```
+GET
+/wine/search
+Semantic similarity search
+```
+
+```
+GET
+/wine/search_by_country
+Semantic similarity search for wines by country
+```
+
+```
+GET
+/wine/search_by_filters
+Semantic similarity search for wines by country, price and points (review ratings)
+```
+
+```
+GET
+/wine/count_by_country
+Get counts of wines by country
+```
+
+```
+GET
+/wine/count_by_filters
+Get counts of wines by country, price and points (review ratings)
+```
\ No newline at end of file
diff --git a/dbs/weaviate/api/__init__.py b/dbs/weaviate/api/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dbs/weaviate/api/config.py b/dbs/weaviate/api/config.py
new file mode 100644
index 0000000..7740296
--- /dev/null
+++ b/dbs/weaviate/api/config.py
@@ -0,0 +1,14 @@
+from pydantic import BaseSettings
+
+
+class Settings(BaseSettings):
+    weaviate_service: str
+    weaviate_port: str
+    weaviate_host: str
+    api_port: str
+    embedding_model_checkpoint: str
+    onnx_model_filename: str
+    tag: str
+
+    class Config:
+        env_file = ".env"
diff --git a/dbs/weaviate/api/main.py b/dbs/weaviate/api/main.py
new file mode 100644
index 0000000..ca22ef4
--- /dev/null
+++ b/dbs/weaviate/api/main.py
@@ -0,0 +1,80 @@
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+from functools import lru_cache
+
+import weaviate
+from fastapi import FastAPI
+
+from api.config import Settings
+from api.routers.wine import wine_router
+
+try:
+    from optimum.onnxruntime import ORTModelForCustomTasks
+    from optimum.pipelines import pipeline
+    from transformers import AutoTokenizer
+
+    model_type = "onnx"
+except ModuleNotFoundError:
+    from sentence_transformers import SentenceTransformer
+
+    model_type = "sbert"
+
+
+@lru_cache()
+def get_settings():
+    # Use lru_cache to avoid loading .env file for every request
+    return Settings()
+
+
+def get_embedding_pipeline(onnx_path, model_filename: str):
+    """
+    Create a sentence embedding pipeline using the optimized ONNX model, if available in the environment
+    """
+    # Reload tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(onnx_path)
+    optimized_model = ORTModelForCustomTasks.from_pretrained(onnx_path, file_name=model_filename)
+    embedding_pipeline = pipeline("feature-extraction", model=optimized_model, tokenizer=tokenizer)
+    return embedding_pipeline
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+    """Async context manager for Weaviate database connection."""
+    settings = get_settings()
+    model_checkpoint = settings.embedding_model_checkpoint
+    if model_type == "sbert":
+        app.model = SentenceTransformer(model_checkpoint)
+        app.model_type = "sbert"
+    elif model_type == "onnx":
+        app.model = get_embedding_pipeline(
+            "onnx_model/onnx", model_filename=settings.onnx_model_filename
+        )
+        app.model_type = "onnx"
+    # Create Weaviate client
+    HOST = settings.weaviate_service
+    PORT = settings.weaviate_port
+    app.client = weaviate.Client(f"http://{HOST}:{PORT}")
+    print("Successfully connected to Weaviate")
+    yield
+    print("Successfully closed Weaviate connection and released resources")
+
+
+app = FastAPI(
+    title="REST API for wine reviews on Weaviate",
+    description=(
+        "Query from a Weaviate database of 130k wine reviews from the Wine Enthusiast magazine"
+    ),
+    version=get_settings().tag,
+    lifespan=lifespan,
+)
+
+
+@app.get("/", include_in_schema=False)
+async def root():
+    return {
+        "message": "REST API for querying Weaviate database of 130k wine reviews from the Wine
Enthusiast magazine" + } + + +# Attach routes +app.include_router(wine_router, prefix="/wine", tags=["wine"]) diff --git a/dbs/weaviate/api/routers/wine.py b/dbs/weaviate/api/routers/wine.py new file mode 100644 index 0000000..64a2156 --- /dev/null +++ b/dbs/weaviate/api/routers/wine.py @@ -0,0 +1,338 @@ +from fastapi import APIRouter, HTTPException, Query, Request +from schemas.retriever import CountByCountry, SimilaritySearch + +wine_router = APIRouter() + + +# --- Routes --- + + +@wine_router.get( + "/search", + response_model=list[SimilaritySearch], + response_description="Search for wines via semantically similar terms", +) +def search_by_similarity( + request: Request, + terms: str = Query( + description="Specify terms to search for in the variety, title and description" + ), +) -> list[SimilaritySearch] | None: + CLASS_NAME = "Wine" + result = _search_by_similarity(request, CLASS_NAME, terms) + if not result: + raise HTTPException( + status_code=404, + detail=f"No wine with the provided terms '{terms}' found in database - please try again", + ) + return result + + +@wine_router.get( + "/search_by_country", + response_model=list[SimilaritySearch], + response_description="Search for wines via semantically similar terms from a particular country", +) +def search_by_similarity_and_country( + request: Request, + terms: str = Query( + description="Specify terms to search for in the variety, title and description" + ), + country: str = Query(description="Country name to search for wines from"), +) -> list[SimilaritySearch] | None: + CLASS_NAME = "Wine" + result = _search_by_similarity_and_country(request, CLASS_NAME, terms, country) + if not result: + raise HTTPException( + status_code=404, + detail=f"No wine with the provided terms '{terms}' found in database - please try again", + ) + return result + + +@wine_router.get( + "/search_by_filters", + response_model=list[SimilaritySearch], + response_description="Search for wines via semantically similar terms with added filters", +) +def search_by_similarity_and_filters( + request: Request, + terms: str = Query( + description="Specify terms to search for in the variety, title and description" + ), + country: str = Query(description="Country name to search for wines from"), + points: int = Query(default=85, description="Minimum number of points for a wine"), + price: float = Query(default=100.0, description="Maximum price for a wine"), +) -> list[SimilaritySearch] | None: + CLASS_NAME = "Wine" + result = _search_by_similarity_and_filters(request, CLASS_NAME, terms, country, points, price) + if not result: + raise HTTPException( + status_code=404, + detail=f"No wine with the provided terms '{terms}' found in database - please try again", + ) + return result + + +@wine_router.get( + "/count_by_country", + response_model=CountByCountry, + response_description="Get counts of wine for a particular country", +) +def count_by_country( + request: Request, + country: str = Query(description="Country name to get counts for"), +) -> CountByCountry | None: + CLASS_NAME = "Wine" + result = _count_by_country(request, CLASS_NAME, country) + if not result: + raise HTTPException( + status_code=404, + detail=f"No wine with the provided country '{country}' found in database - please try again", + ) + return result + + +@wine_router.get( + "/count_by_filters", + response_model=CountByCountry, + response_description="Get counts of wine for a particular country, filtered by points and price", +) +def count_by_filters( + request: Request, + country: str = 
Query(description="Country name to get counts for"), + points: int = Query(default=85, description="Minimum number of points for a wine"), + price: float = Query(default=100.0, description="Maximum price for a wine"), +) -> CountByCountry | None: + CLASS_NAME = "Wine" + result = _count_by_filters(request, CLASS_NAME, country, points, price) + if not result: + raise HTTPException( + status_code=404, + detail=f"No wine with the provided country '{country}' found in database - please try again", + ) + return result + + +# --- Helper functions --- + + +def _search_by_similarity( + request: Request, class_name: str, terms: str +) -> list[SimilaritySearch] | None: + # Convert input text query into a vector for lookup in the db + if request.app.model_type == "sbert": + vector = request.app.model.encode(terms, show_progress_bar=False, batch_size=128).tolist() + elif request.app.model_type == "onnx": + vector = request.app.model(terms)[0][0] + + near_vec = {"vector": vector} + response = ( + request.app.client.query.get( + class_name, + [ + "wineID", + "title", + "description", + "country", + "province", + "points", + "price", + "variety", + "winery", + "_additional {certainty}", + ], + ) + .with_near_vector(near_vec) + .with_limit(5) + .do() + ) + try: + payload = response["data"]["Get"][class_name] + return payload + except Exception as e: + print(f"Error {e}: Did not obtain appropriate response from Weaviate") + return None + + +def _search_by_similarity_and_country( + request: Request, + class_name: str, + terms: str, + country: str, +) -> list[SimilaritySearch] | None: + # Convert input text query into a vector for lookup in the db + if request.app.model_type == "sbert": + vector = request.app.model.encode(terms, show_progress_bar=False, batch_size=128).tolist() + elif request.app.model_type == "onnx": + vector = request.app.model(terms)[0][0] + + near_vec = {"vector": vector} + where_filter = { + "path": "country", + "operator": "Equal", + "valueText": country, + } + response = ( + request.app.client.query.get( + class_name, + [ + "wineID", + "title", + "description", + "country", + "province", + "points", + "price", + "variety", + "winery", + "_additional {certainty}", + ], + ) + .with_near_vector(near_vec) + .with_where(where_filter) + .with_limit(5) + .do() + ) + try: + payload = response["data"]["Get"][class_name] + return payload + except Exception as e: + print(f"Error {e}: Did not obtain appropriate response from Weaviate") + return None + + +def _search_by_similarity_and_filters( + request: Request, + class_name: str, + terms: str, + country: str, + points: int, + price: float, +) -> list[SimilaritySearch] | None: + # Convert input text query into a vector for lookup in the db + if request.app.model_type == "sbert": + vector = request.app.model.encode(terms, show_progress_bar=False, batch_size=128).tolist() + elif request.app.model_type == "onnx": + vector = request.app.model(terms)[0][0] + + near_vec = {"vector": vector} + where_filter = { + "operator": "And", + "operands": [ + { + "path": "country", + "operator": "Equal", + "valueText": country, + }, + { + "path": "price", + "operator": "LessThan", + "valueNumber": price, + }, + { + "path": "points", + "operator": "GreaterThan", + "valueInt": points, + }, + ], + } + response = ( + request.app.client.query.get( + class_name, + [ + "wineID", + "title", + "description", + "country", + "province", + "points", + "price", + "variety", + "winery", + "_additional {certainty}", + ], + ) + .with_near_vector(near_vec) + 
.with_where(where_filter) + .with_limit(5) + .do() + ) + try: + payload = response["data"]["Get"][class_name] + return payload + except Exception as e: + print(f"Error {e}: Did not obtain appropriate response from Weaviate") + return None + + +def _count_by_country( + request: Request, + class_name: str, + country: str, +) -> CountByCountry | None: + where_filter = { + "operator": "And", + "operands": [ + { + "path": "country", + "operator": "Equal", + "valueText": country, + } + ], + } + response = ( + request.app.client.query.aggregate(class_name) + .with_where(where_filter) + .with_fields("meta {count}") + .do() + ) + try: + payload = response["data"]["Aggregate"][class_name] + count = payload[0]["meta"] + return count + except Exception as e: + print(f"Error {e}: Did not obtain appropriate response from Weaviate") + return None + + +def _count_by_filters( + request: Request, + class_name: str, + country: str, + points: int, + price: float, +) -> CountByCountry | None: + where_filter = { + "operator": "And", + "operands": [ + { + "path": "country", + "operator": "Equal", + "valueText": country, + }, + { + "path": "price", + "operator": "LessThan", + "valueNumber": price, + }, + { + "path": "points", + "operator": "GreaterThan", + "valueInt": points, + }, + ], + } + response = ( + request.app.client.query.aggregate(class_name) + .with_where(where_filter) + .with_fields("meta {count}") + .do() + ) + try: + payload = response["data"]["Aggregate"][class_name] + count = payload[0]["meta"] + return count + except Exception as e: + print(f"Error {e}: Did not obtain appropriate response from Weaviate") + return None diff --git a/dbs/weaviate/docker-compose-onnx.yml b/dbs/weaviate/docker-compose-onnx.yml new file mode 100644 index 0000000..877451e --- /dev/null +++ b/dbs/weaviate/docker-compose-onnx.yml @@ -0,0 +1,43 @@ +version: "3.9" + +services: + weaviate: + image: semitechnologies/weaviate:${WEAVIATE_VERSION} + ports: + - ${WEAVIATE_PORT}:8080 + restart: on-failure:0 + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + CLUSTER_HOSTNAME: 'node1' + volumes: + - weaviate_data:/var/lib/weaviate + networks: + - wine + + fastapi: + image: weaviate_wine_fastapi:${TAG} + build: + context: . + dockerfile: Dockerfile.onnxruntime + restart: unless-stopped + env_file: + - .env + ports: + - ${API_PORT}:8000 + depends_on: + - weaviate + volumes: + - ./:/wine + networks: + - wine + command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload + +volumes: + weaviate_data: + +networks: + wine: + driver: bridge \ No newline at end of file diff --git a/dbs/weaviate/docker-compose.yml b/dbs/weaviate/docker-compose.yml new file mode 100644 index 0000000..828c85d --- /dev/null +++ b/dbs/weaviate/docker-compose.yml @@ -0,0 +1,43 @@ +version: "3.9" + +services: + weaviate: + image: semitechnologies/weaviate:${WEAVIATE_VERSION} + ports: + - ${WEAVIATE_PORT}:8080 + restart: on-failure:0 + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + CLUSTER_HOSTNAME: 'node1' + volumes: + - weaviate_data:/var/lib/weaviate + networks: + - wine + + fastapi: + image: weaviate_wine_fastapi:${TAG} + build: + context: . 
+      dockerfile: Dockerfile
+    restart: unless-stopped
+    env_file:
+      - .env
+    ports:
+      - ${API_PORT}:8000
+    depends_on:
+      - weaviate
+    volumes:
+      - ./:/wine
+    networks:
+      - wine
+    command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
+
+volumes:
+  weaviate_data:
+
+networks:
+  wine:
+    driver: bridge
\ No newline at end of file
diff --git a/dbs/weaviate/onnx_model/onnx_optimizer.py b/dbs/weaviate/onnx_model/onnx_optimizer.py
new file mode 100644
index 0000000..4ee37ca
--- /dev/null
+++ b/dbs/weaviate/onnx_model/onnx_optimizer.py
@@ -0,0 +1,135 @@
+"""
+This script is a modified version of the method shown in this blog post:
+https://www.philschmid.de/optimize-sentence-transformers
+
+It uses the ONNX Runtime to dynamically optimize and quantize a sentence transformers model for better CPU performance.
+
+Using the quantized version of `sentence-transformers/multi-qa-MiniLM-L6-cos-v1` allows us to:
+  * Generate similar quality sentence embeddings as the original model, but with a roughly 1.8x speedup in vectorization time
+  * Reduce the model size from 86 MB to around 63 MB, a roughly 26% reduction in file size
+"""
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from optimum.onnxruntime import ORTModelForCustomTasks, ORTOptimizer, ORTQuantizer
+from optimum.onnxruntime.configuration import AutoQuantizationConfig, OptimizationConfig
+from sklearn.metrics.pairwise import cosine_similarity
+from transformers import AutoModel, AutoTokenizer, Pipeline
+
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[
+        0
+    ]  # First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+        input_mask_expanded.sum(1), min=1e-9
+    )
+
+
+class SentenceEmbeddingPipeline(Pipeline):
+    def _sanitize_parameters(self, **kwargs):
+        # We don't have any hyperparameters to sanitize
+        preprocess_kwargs = {}
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, inputs):
+        encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors="pt")
+        return encoded_inputs
+
+    def _forward(self, model_inputs):
+        outputs = self.model(**model_inputs)
+        return {"outputs": outputs, "attention_mask": model_inputs["attention_mask"]}
+
+    def postprocess(self, model_outputs):
+        # Perform mean pooling
+        sentence_embeddings = mean_pooling(
+            model_outputs["outputs"], model_outputs["attention_mask"]
+        )
+        # Normalize embeddings
+        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+        return sentence_embeddings
+
+
+def optimize_model(model_id: str, onnx_path: Path) -> None:
+    """
+    Optimize ONNX model for CPU performance
+    """
+    model = ORTModelForCustomTasks.from_pretrained(model_id, export=True)
+    # Create ORTOptimizer and define optimization configuration
+    optimizer = ORTOptimizer.from_pretrained(model)
+    # Save models to local disk (the tokenizer is loaded at module level in the __main__ block below)
+    model.save_pretrained(onnx_path)
+    tokenizer.save_pretrained(onnx_path)
+    # Set optimization_level = 99 -> enable all optimizations
+    optimization_config = OptimizationConfig(optimization_level=99)
+    # Apply the optimization configuration to the model
+    optimizer.optimize(
+        optimization_config=optimization_config,
+        save_dir=onnx_path,
+    )
+
+
+def quantize_optimized_model(onnx_path: Path) -> None:
+    """
+    Quantize an already optimized ONNX model for even better CPU performance
+    """
+    # Create ORTQuantizer and define quantization
configuration + quantizer = ORTQuantizer.from_pretrained(onnx_path, file_name="model_optimized.onnx") + quantization_config = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True) + # Apply the quantization configuration to the model + quantizer.quantize( + quantization_config=quantization_config, + save_dir=onnx_path, + ) + + +def generate_similarities(source_sentence: str, sentences: list[str], pipeline: Pipeline) -> None: + source_sentence_embedding = pipeline(source_sentence).tolist()[0] + + for sentence in sentences: + sentence_embedding = pipeline(sentence).tolist()[0] + similarity = cosine_similarity([source_sentence_embedding], [sentence_embedding])[0] + print(f"Similarity between '{source_sentence}' and '{sentence}': {similarity}") + + +def main() -> None: + """ + Generate optimized and quantized ONNX models from a vanilla sentence transformer model + """ + # Init vanilla sentence transformer pipeline + print("---\nLoading vanilla sentence transformer model\n---") + vanilla_pipeline = SentenceEmbeddingPipeline(model=vanilla_model, tokenizer=tokenizer) + # Print out pairwise similarities + generate_similarities(source_sentence, sentences, vanilla_pipeline) + + # Save model to ONNX + Path("onnx").mkdir(exist_ok=True) + onnx_path = Path("onnx") + + # First, dynamically optimize an existing sentence transformer model + optimize_model(model_id, onnx_path) + # Next, dynamically quantize the optimized model + quantize_optimized_model(onnx_path) + + # Init quantized ONNX pipeline + print("---\nLoading quantized ONNX model\n---") + model_filename = "model_optimized_quantized.onnx" + quantized_model = ORTModelForCustomTasks.from_pretrained(onnx_path, file_name=model_filename) + quantized_pipeline = SentenceEmbeddingPipeline(model=quantized_model, tokenizer=tokenizer) + # Print out pairwise similarities + generate_similarities(source_sentence, sentences, quantized_pipeline) + + +if __name__ == "__main__": + # Example sentences we want sentence embeddings for + source_sentence = "I'm very happy" + sentences = ["I am so glad", "I'm so sad", "My dog is missing", "The universe is so vast!"] + + model_id = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" + # Load AutoModel from huggingface model repository + tokenizer = AutoTokenizer.from_pretrained(model_id) + vanilla_model = AutoModel.from_pretrained(model_id) + + main() diff --git a/dbs/weaviate/requirements-onnx.txt b/dbs/weaviate/requirements-onnx.txt new file mode 100644 index 0000000..899a370 --- /dev/null +++ b/dbs/weaviate/requirements-onnx.txt @@ -0,0 +1,86 @@ +aiohttp==3.8.4 +aiosignal==1.3.1 +anyio==3.6.2 +async-timeout==4.0.2 +attrs==23.1.0 +Authlib==1.2.0 +catalogue==2.0.8 +certifi==2022.12.7 +cffi==1.15.1 +charset-normalizer==3.1.0 +click==8.1.3 +cmake==3.26.3 +coloredlogs==15.0.1 +cryptography==40.0.2 +datasets==2.11.0 +decorator==5.1.1 +dill==0.3.6 +evaluate==0.4.0 +fastapi==0.95.1 +filelock==3.12.0 +flatbuffers==23.3.3 +frozenlist==1.3.3 +fsspec==2023.4.0 +grpcio==1.54.0 +grpcio-tools==1.48.2 +h11==0.14.0 +h2==4.1.0 +hpack==4.0.0 +httpcore==0.17.0 +httpx==0.24.0 +huggingface-hub==0.13.4 +humanfriendly==10.0 +hyperframe==6.0.1 +idna==3.4 +Jinja2==3.1.2 +joblib==1.2.0 +lit==16.0.1 +MarkupSafe==2.1.2 +mpmath==1.3.0 +multidict==6.0.4 +multiprocess==0.70.14 +networkx==3.1 +nltk==3.8.1 +numpy==1.24.2 +onnx==1.13.1 +onnxruntime==1.14.1 +optimum==1.8.2 +packaging==23.1 +pandas==2.0.0 +Pillow==9.5.0 +protobuf==3.20.2 +pyarrow==11.0.0 +pycparser==2.21 +pydantic==1.10.7 +python-dateutil==2.8.2 +python-dotenv==1.0.0 
+pytz==2023.3
+PyYAML==6.0
+qdrant-client==1.1.5
+regex==2023.3.23
+requests==2.28.2
+responses==0.18.0
+scikit-learn==1.2.2
+scipy==1.10.1
+sentence-transformers==2.2.2
+sentencepiece==0.1.98
+six==1.16.0
+sniffio==1.3.0
+srsly==2.4.6
+starlette==0.26.1
+sympy==1.11.1
+threadpoolctl==3.1.0
+tokenizers==0.13.3
+torch==2.0.0
+torchvision==0.15.1
+tqdm==4.65.0
+transformers==4.28.1
+triton==2.0.0
+typing_extensions==4.5.0
+tzdata==2023.3
+urllib3==1.26.15
+uvicorn==0.21.1
+validators==0.20.0
+weaviate-client==3.16.1
+xxhash==3.2.0
+yarl==1.9.1
diff --git a/dbs/weaviate/requirements.txt b/dbs/weaviate/requirements.txt
new file mode 100644
index 0000000..a0e90e7
--- /dev/null
+++ b/dbs/weaviate/requirements.txt
@@ -0,0 +1,10 @@
+weaviate-client>=3.16.1
+transformers==4.28.1
+sentence-transformers==2.2.2
+pydantic>=1.10.7, <2.0.0
+fastapi>=0.95.0, <1.0.0
+httpx>=0.24.0
+aiohttp>=3.8.4
+uvicorn>=0.21.0, <1.0.0
+python-dotenv>=1.0.0
+srsly>=2.4.6
\ No newline at end of file
diff --git a/dbs/weaviate/schemas/__init__.py b/dbs/weaviate/schemas/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dbs/weaviate/schemas/retriever.py b/dbs/weaviate/schemas/retriever.py
new file mode 100644
index 0000000..0521979
--- /dev/null
+++ b/dbs/weaviate/schemas/retriever.py
@@ -0,0 +1,32 @@
+from pydantic import BaseModel
+
+
+class SimilaritySearch(BaseModel):
+    wineID: int
+    country: str
+    province: str | None
+    title: str
+    description: str | None
+    points: int
+    price: float | str | None
+    variety: str | None
+    winery: str | None
+
+    class Config:
+        extra = "ignore"
+        schema_extra = {
+            "example": {
+                "wineID": 3845,
+                "country": "Italy",
+                "title": "Castellinuzza e Piuca 2010 Chianti Classico",
+                "description": "This gorgeous Chianti Classico boasts lively cherry, strawberry and violet aromas. The mouthwatering palate shows concentrated wild-cherry flavor layered with mint, white pepper and clove. It has fresh acidity and firm tannins that will develop complexity with more bottle age. A textbook Chianti Classico.",
+                "points": 93,
+                "price": 16,
+                "variety": "Red Blend",
+                "winery": "Castellinuzza e Piuca",
+            }
+        }
+
+
+class CountByCountry(BaseModel):
+    count: int | None
diff --git a/dbs/weaviate/schemas/wine.py b/dbs/weaviate/schemas/wine.py
new file mode 100644
index 0000000..0aaf930
--- /dev/null
+++ b/dbs/weaviate/schemas/wine.py
@@ -0,0 +1,69 @@
+from pydantic import BaseModel, root_validator
+
+
+class Wine(BaseModel):
+    id: int
+    points: int
+    title: str
+    description: str | None
+    price: float | None
+    variety: str | None
+    winery: str | None
+    vineyard: str | None
+    country: str | None
+    province: str | None
+    region_1: str | None
+    region_2: str | None
+    taster_name: str | None
+    taster_twitter_handle: str | None
+
+    class Config:
+        extra = "allow"
+        allow_population_by_field_name = True
+        validate_assignment = True
+        schema_extra = {
+            "example": {
+                "id": 45100,
+                "points": 85,
+                "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)",
+                "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate.
Raisin and cooked berry flavors finish plump, with earthy notes.", + "price": 10.0, + "variety": "Merlot", + "winery": "Balduzzi", + "vineyard": "Reserva", + "country": "Chile", + "province": "Maule Valley", + "region_1": "null", + "region_2": "null", + "taster_name": "Michael Schachner", + "taster_twitter_handle": "@wineschach", + } + } + + @root_validator(pre=True) + def _get_vineyard(cls, values): + "Rename designation to vineyard" + vineyard = values.pop("designation", None) + if vineyard: + values["vineyard"] = vineyard.strip() + return values + + @root_validator + def _fill_country_unknowns(cls, values): + "Fill in missing country values with 'Unknown', as we always want this field to be queryable" + country = values.get("country") + if not country: + values["country"] = "Unknown" + return values + + @root_validator + def _add_to_vectorize_fields(cls, values): + "Add a field to_vectorize that will be used to create sentence embeddings" + variety = values.get("variety", "") + country = values.get("country", "Unknown") + province = values.get("province", "") + title = values.get("title", "") + description = values.get("description", "") + to_vectorize = list(filter(None, [variety, country, province, title, description])) + values["to_vectorize"] = " ".join(to_vectorize).strip() + return values diff --git a/dbs/weaviate/scripts/__init__.py b/dbs/weaviate/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dbs/weaviate/scripts/bulk_index_onnx.py b/dbs/weaviate/scripts/bulk_index_onnx.py new file mode 100644 index 0000000..b058f57 --- /dev/null +++ b/dbs/weaviate/scripts/bulk_index_onnx.py @@ -0,0 +1,157 @@ +import argparse +import json +import os +import sys +from functools import lru_cache +from pathlib import Path +from typing import Any, Iterator + +import srsly +import weaviate +from dotenv import load_dotenv +from optimum.onnxruntime import ORTModelForCustomTasks +from optimum.pipelines import pipeline +from pydantic.main import ModelMetaclass +from tqdm import tqdm +from transformers import AutoTokenizer +from weaviate.client import Client + +sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) +from api.config import Settings +from schemas.wine import Wine + +load_dotenv() +# Custom types +JsonBlob = dict[str, Any] + + +class FileNotFoundError(Exception): + pass + + +@lru_cache() +def get_settings(): + # Use lru_cache to avoid loading .env file for every request + return Settings() + + +def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[tuple[JsonBlob, ...]]: + """ + Break a large iterable into an iterable of smaller iterables of size `chunksize` + """ + for i in range(0, len(item_list), chunksize): + yield tuple(item_list[i : i + chunksize]) + + +def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]: + """Get all line-delimited json files (.jsonl) from a directory with a given prefix""" + file_path = data_dir / filename + if not file_path.is_file(): + # File may not have been uncompressed yet so try to do that first + data = srsly.read_gzip_jsonl(file_path) + # This time if it isn't there it really doesn't exist + if not file_path.is_file(): + raise FileNotFoundError(f"No valid .jsonl file found in `{data_dir}`") + else: + data = srsly.read_gzip_jsonl(file_path) + return data + + +def validate( + data: list[JsonBlob], + model: ModelMetaclass, + exclude_none: bool = False, +) -> list[JsonBlob]: + validated_data = [model(**item).dict(exclude_none=exclude_none) for item in data] + return validated_data 
+ + +def get_embedding_pipeline(onnx_path, model_filename: str) -> pipeline: + """ + Create a sentence embedding pipeline using the optimized ONNX model + """ + # Reload tokenizer + tokenizer = AutoTokenizer.from_pretrained(onnx_path) + optimized_model = ORTModelForCustomTasks.from_pretrained(onnx_path, file_name=model_filename) + embedding_pipeline = pipeline("feature-extraction", model=optimized_model, tokenizer=tokenizer) + return embedding_pipeline + + +def create_or_update_schema(client: Client) -> None: + # Create a schema with no vectorizer (we will be adding our own vectors) + with open("settings/schema.json", "r") as f: + schema = json.load(f) + class_names = [class_["class"] for class_ in schema["classes"]] + assert class_names, "No classes found in schema, please check schema definition and try again" + if not client.schema.get()["classes"]: + print(f"Creating schema with classes: {', '.join(class_names)}") + client.schema.create(schema) + else: + print(f"Existing schema found, deleting it & creating it again...") + client.schema.delete_all() + client.schema.create(schema) + + +def main(chunked_data: Iterator[tuple[JsonBlob, ...]]) -> None: + settings = get_settings() + CLASS_NAME = "Wine" + HOST = settings.weaviate_host + PORT = settings.weaviate_port + client = weaviate.Client(f"http://{HOST}:{PORT}") + # Add schema + create_or_update_schema(client) + + # Preload optimized, quantized ONNX sentence transformers model + # NOTE: This requires that the script ../onnx_model/onnx_optimizer.py has been run beforehand + pipeline = get_embedding_pipeline(ONNX_PATH, model_filename="model_optimized_quantized.onnx") + + counter = 0 + for chunk in chunked_data: + orig_data = validate(chunk, Wine, exclude_none=True) + counter += len(orig_data) + ids = [item.pop("id") for item in orig_data] + # Rename "id" (Weaviate reserves the "id" key for its own uuid assignment, so we can't use it) + data = [{"wineID": id, **fields} for id, fields in zip(ids, orig_data)] + to_vectorize = [text.pop("to_vectorize") for text in data] + sentence_embeddings = [] + for text in tqdm(to_vectorize, desc="Generating sentence embeddings"): + sentence_embeddings.append(pipeline(text.lower())[0][0]) + try: + # Use a context manager to manage batch flushing + with client.batch as batch: + batch.batch_size = 64 + batch.dynamic = True + for i, item in enumerate(data): + batch.add_data_object( + item, + CLASS_NAME, + vector=sentence_embeddings[i], + ) + print(f"Indexed ID range {min(ids)}-{max(ids)} to db") + print(f"Indexed {counter} items in total") + except Exception as e: + print(f"{e}: Failed to index items in the ID range {min(ids)}-{max(ids)} to db") + + +if __name__ == "__main__": + # fmt: off + parser = argparse.ArgumentParser("Bulk index database from the wine reviews JSONL data") + parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes") + parser.add_argument("--chunksize", type=int, default=1024, help="Size of each chunk to break the dataset into before processing") + parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use") + args = vars(parser.parse_args()) + # fmt: on + + LIMIT = args["limit"] + DATA_DIR = Path(__file__).parents[3] / "data" + ONNX_PATH = Path(__file__).parents[1] / "onnx_model" / "onnx" + FILENAME = args["filename"] + CHUNKSIZE = args["chunksize"] + + data = list(get_json_data(DATA_DIR, FILENAME)) + + if data: + data = data[:LIMIT] if LIMIT > 0 else data + 
chunked_data = chunk_iterable(data, CHUNKSIZE) + main(chunked_data) + print("Finished execution!") diff --git a/dbs/weaviate/scripts/bulk_index_sbert.py b/dbs/weaviate/scripts/bulk_index_sbert.py new file mode 100644 index 0000000..11cc456 --- /dev/null +++ b/dbs/weaviate/scripts/bulk_index_sbert.py @@ -0,0 +1,143 @@ +import argparse +import json +import os +import sys +from functools import lru_cache +from pathlib import Path +from typing import Any, Iterator + +import srsly +import weaviate +from dotenv import load_dotenv +from pydantic.main import ModelMetaclass +from sentence_transformers import SentenceTransformer +from tqdm import tqdm +from weaviate.client import Client + +sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) +from api.config import Settings +from schemas.wine import Wine + +load_dotenv() +# Custom types +JsonBlob = dict[str, Any] + + +class FileNotFoundError(Exception): + pass + + +@lru_cache() +def get_settings(): + # Use lru_cache to avoid loading .env file for every request + return Settings() + + +def chunk_iterable(item_list: list[JsonBlob], chunksize: int) -> Iterator[tuple[JsonBlob, ...]]: + """ + Break a large iterable into an iterable of smaller iterables of size `chunksize` + """ + for i in range(0, len(item_list), chunksize): + yield tuple(item_list[i : i + chunksize]) + + +def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]: + """Get all line-delimited json files (.jsonl) from a directory with a given prefix""" + file_path = data_dir / filename + if not file_path.is_file(): + # File may not have been uncompressed yet so try to do that first + data = srsly.read_gzip_jsonl(file_path) + # This time if it isn't there it really doesn't exist + if not file_path.is_file(): + raise FileNotFoundError(f"No valid .jsonl file found in `{data_dir}`") + else: + data = srsly.read_gzip_jsonl(file_path) + return data + + +def validate( + data: list[JsonBlob], + model: ModelMetaclass, + exclude_none: bool = False, +) -> list[JsonBlob]: + validated_data = [model(**item).dict(exclude_none=exclude_none) for item in data] + return validated_data + + +def create_or_update_schema(client: Client) -> None: + # Create a schema with no vectorizer (we will be adding our own vectors) + with open("settings/schema.json", "r") as f: + schema = json.load(f) + class_names = [class_["class"] for class_ in schema["classes"]] + assert class_names, "No classes found in schema, please check schema definition and try again" + if not client.schema.get()["classes"]: + print(f"Creating schema with classes: {', '.join(class_names)}") + client.schema.create(schema) + else: + print(f"Existing schema found, deleting it & creating it again...") + client.schema.delete_all() + client.schema.create(schema) + + +def main(chunked_data: Iterator[tuple[JsonBlob, ...]]) -> None: + settings = get_settings() + CLASS_NAME = "Wine" + HOST = settings.weaviate_host + PORT = settings.weaviate_port + client = weaviate.Client(f"http://{HOST}:{PORT}") + # Add schema + create_or_update_schema(client) + + # Load a sentence transformer model for semantic similarity from a specified checkpoint + model_id = settings.embedding_model_checkpoint + model = SentenceTransformer(model_id) + + counter = 0 + for chunk in chunked_data: + orig_data = validate(chunk, Wine, exclude_none=True) + counter += len(orig_data) + ids = [item.pop("id") for item in orig_data] + # Rename "id" (Weaviate reserves the "id" key for its own uuid assignment, so we can't use it) + data = [{"wineID": id, **fields} for 
id, fields in zip(ids, orig_data)] + to_vectorize = [text.pop("to_vectorize") for text in data] + sentence_embeddings = [] + for text in tqdm(to_vectorize, desc="Generating sentence embeddings"): + sentence_embeddings.append(model.encode(text)) + try: + # Use a context manager to manage batch flushing + with client.batch as batch: + batch.batch_size = 64 + batch.dynamic = False + for i, item in enumerate(data): + batch.add_data_object( + item, + CLASS_NAME, + vector=sentence_embeddings[i], + ) + print(f"Indexed ID range {min(ids)}-{max(ids)} to db") + print(f"Indexed {counter} items in total") + except Exception as e: + print(f"{e}: Failed to index items in the ID range {min(ids)}-{max(ids)} to db") + + +if __name__ == "__main__": + # fmt: off + parser = argparse.ArgumentParser("Bulk index database from the wine reviews JSONL data") + parser.add_argument("--limit", type=int, default=0, help="Limit the size of the dataset to load for testing purposes") + parser.add_argument("--chunksize", type=int, default=512, help="Size of each chunk to break the dataset into before processing") + parser.add_argument("--filename", type=str, default="winemag-data-130k-v2.jsonl.gz", help="Name of the JSONL zip file to use") + args = vars(parser.parse_args()) + # fmt: on + + LIMIT = args["limit"] + DATA_DIR = Path(__file__).parents[3] / "data" + FILENAME = args["filename"] + CHUNKSIZE = args["chunksize"] + + data = list(get_json_data(DATA_DIR, FILENAME)) + + if data: + data = data[:LIMIT] if LIMIT > 0 else data + chunked_data = chunk_iterable(data, CHUNKSIZE) + main(chunked_data) + print("Finished execution!") diff --git a/dbs/weaviate/scripts/settings/schema.json b/dbs/weaviate/scripts/settings/schema.json new file mode 100644 index 0000000..81496f6 --- /dev/null +++ b/dbs/weaviate/scripts/settings/schema.json @@ -0,0 +1,50 @@ +{ + "classes": [ + { + "class": "Wine", + "vectorizer": "none", + "properties": [ + { + "name": "wineID", + "dataType": ["int"] + }, + { + "name": "points", + "dataType": ["int"] + }, + { + "name": "variety", + "dataType": ["text"] + }, + { + "name": "title", + "dataType": ["text"] + }, + { + "name": "description", + "dataType": ["text"] + }, + { + "name": "price", + "dataType": ["number"] + }, + { + "name": "country", + "dataType": ["text"] + }, + { + "name": "province", + "dataType": ["text"] + }, + { + "name": "taster_name", + "dataType": ["text"] + }, + { + "name": "taster_twitter_handle", + "dataType": ["text"] + } + ] + } + ] +} \ No newline at end of file