diff --git a/.github/workflows/pr-e2e-tests.yaml b/.github/workflows/pr-e2e-tests.yaml index 0480f907c..79550dd68 100644 --- a/.github/workflows/pr-e2e-tests.yaml +++ b/.github/workflows/pr-e2e-tests.yaml @@ -45,6 +45,10 @@ jobs: ports: - 7687:7687 - 7474:7474 + qdrant: + image: qdrant/qdrant + ports: + - 6333:6333 steps: - name: Install graphviz package diff --git a/.github/workflows/scheduled-e2e-tests.yaml b/.github/workflows/scheduled-e2e-tests.yaml index 59ddaa425..77e495bc9 100644 --- a/.github/workflows/scheduled-e2e-tests.yaml +++ b/.github/workflows/scheduled-e2e-tests.yaml @@ -52,6 +52,10 @@ jobs: credentials: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} + qdrant: + image: qdrant/qdrant + ports: + - 6333:6333 steps: - name: Install graphviz package diff --git a/CHANGELOG.md b/CHANGELOG.md index a1bcd2a1c..8eb1b98b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ - Added support for Cohere LLM and embeddings - added optional dependency to `cohere`. - Added support for Anthropic LLM - added optional dependency to `anthropic`. - Added support for MistralAI LLM - added optional dependency to `mistralai`. +- Added support for Qdrant - added optional dependency to `qdrant-client`. ### Fixed - Resolved import issue with the Vertex AI Embeddings class. diff --git a/docs/source/api.rst b/docs/source/api.rst index 5eff5add6..7022240d2 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -163,6 +163,12 @@ PineconeNeo4jRetriever .. autoclass:: neo4j_graphrag.retrievers.external.pinecone.pinecone.PineconeNeo4jRetriever :members: search +QdrantNeo4jRetriever +==================== + +.. 
autoclass:: neo4j_graphrag.retrievers.external.qdrant.qdrant.QdrantNeo4jRetriever + :members: search + ******** Embedder diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index 2de97079e..930b05631 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -327,6 +327,8 @@ We provide implementations for the following retrievers: - Use this retriever when vectors are saved in a Weaviate vector database * - :ref:`PineconeNeo4jRetriever ` - Use this retriever when vectors are saved in a Pinecone vector database + * - :ref:`QdrantNeo4jRetriever <qdrant-neo4j-retriever-user-guide>` + - Use this retriever when vectors are saved in a Qdrant vector database Retrievers all expose a `search` method that we will discuss in the next sections. @@ -672,6 +674,35 @@ Pinecone Retrievers Also see :ref:`pineconeneo4jretriever`. +.. _qdrant-neo4j-retriever-user-guide: + +Qdrant Retrievers +----------------- + +.. note:: + + In order to import this retriever, the Qdrant Python client must be installed: + `pip install qdrant-client` + + +.. code:: python + + from qdrant_client import QdrantClient + from neo4j_graphrag.retrievers import QdrantNeo4jRetriever + + client = QdrantClient(...) # construct the Qdrant client instance + + retriever = QdrantNeo4jRetriever( + driver=driver, + client=client, + collection_name="my-collection", + id_property_external="neo4j_id", # The payload field holding the value of the corresponding Neo4j node's id property + id_property_neo4j="id", + embedder=embedder, + ) + +See :ref:`qdrantneo4jretriever`. + Other Retrievers =================== diff --git a/examples/qdrant/README.md b/examples/qdrant/README.md new file mode 100644 index 000000000..57d9b14f9 --- /dev/null +++ b/examples/qdrant/README.md @@ -0,0 +1,31 @@ +### Start services locally + +Run the following command to spin up Neo4j and Qdrant containers.
+ +```bash +docker compose -f tests/e2e/docker-compose.yml up +``` + +### Write data (once) + +Run this from the project root to write data to both Neo4j and Qdrant. + +```bash +poetry run python -m tests.e2e.qdrant_e2e.populate_dbs +``` + +### Install Qdrant client + +```bash +pip install qdrant-client +``` + +### Search + +```bash +# search by vector +poetry run python -m examples.qdrant.vector_search + +# search by text, with embeddings generated locally +poetry run python -m examples.qdrant.text_search +``` diff --git a/examples/qdrant/__init__.py b/examples/qdrant/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/qdrant/text_search.py b/examples/qdrant/text_search.py new file mode 100644 index 000000000..5c79fceee --- /dev/null +++ b/examples/qdrant/text_search.py @@ -0,0 +1,27 @@ +from langchain_huggingface.embeddings import HuggingFaceEmbeddings +from neo4j import GraphDatabase +from neo4j_graphrag.retrievers import QdrantNeo4jRetriever +from qdrant_client import QdrantClient + +NEO4J_URL = "neo4j://localhost:7687" +NEO4J_AUTH = ("neo4j", "password") + + +def main() -> None: + with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver: + embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2") + retriever = QdrantNeo4jRetriever( + driver=neo4j_driver, + client=QdrantClient(url="http://localhost:6333"), + collection_name="Jeopardy", + id_property_external="neo4j_id", + id_property_neo4j="id", + embedder=embedder, # type: ignore + ) + + res = retriever.search(query_text="biology", top_k=2) + print(res) + + +if __name__ == "__main__": + main() diff --git a/examples/qdrant/vector_search.py b/examples/qdrant/vector_search.py new file mode 100644 index 000000000..f57d4bb63 --- /dev/null +++ b/examples/qdrant/vector_search.py @@ -0,0 +1,25 @@ +from neo4j import GraphDatabase +from neo4j_graphrag.retrievers import QdrantNeo4jRetriever +from qdrant_client import QdrantClient + +from examples.embedding_biology import
EMBEDDING_BIOLOGY + +NEO4J_URL = "neo4j://localhost:7687" +NEO4J_AUTH = ("neo4j", "password") + + +def main() -> None: + with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver: + retriever = QdrantNeo4jRetriever( + driver=neo4j_driver, + client=QdrantClient(url="http://localhost:6333"), + collection_name="Jeopardy", + id_property_external="neo4j_id", + id_property_neo4j="id", + ) + res = retriever.search(query_vector=EMBEDDING_BIOLOGY, top_k=2) + print(res) + + +if __name__ == "__main__": + main() diff --git a/poetry.lock b/poetry.lock index e2e030c85..ff6b82ba4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1607,6 +1607,32 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "h2" +version = "4.1.0" +description = "HTTP/2 State-Machine based protocol implementation" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d"}, + {file = "h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb"}, +] + +[package.dependencies] +hpack = ">=4.0,<5" +hyperframe = ">=6.0,<7" + +[[package]] +name = "hpack" +version = "4.0.0" +description = "Pure-Python HPACK header compression" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "hpack-4.0.0-py3-none-any.whl", hash = "sha256:84a076fad3dc9a9f8063ccb8041ef100867b1878b25ef0ee63847a5d53818a6c"}, + {file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"}, +] + [[package]] name = "httpcore" version = "1.0.5" @@ -1642,6 +1668,7 @@ files = [ 
[package.dependencies] anyio = "*" certifi = "*" +h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} httpcore = "==1.*" idna = "*" sniffio = "*" @@ -1697,6 +1724,17 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gr torch = ["safetensors[torch]", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] +[[package]] +name = "hyperframe" +version = "6.0.1" +description = "HTTP/2 framing layer for Python" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15"}, + {file = "hyperframe-6.0.1.tar.gz", hash = "sha256:ae510046231dc8e9ecb1a6586f63d2347bf4c8905914aa84ba585ae85f28a914"}, +] + [[package]] name = "identify" version = "2.6.1" @@ -3296,6 +3334,25 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "portalocker" +version = "2.10.1" +description = "Wraps the portalocker recipe for easy usage" +optional = false +python-versions = ">=3.8" +files = [ + {file = "portalocker-2.10.1-py3-none-any.whl", hash = "sha256:53a5984ebc86a025552264b459b46a2086e269b21823cb572f8f28ee759e45bf"}, + {file = "portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f"}, +] + +[package.dependencies] +pywin32 = {version = ">=226", markers = "platform_system == \"Windows\""} + +[package.extras] +docs = ["sphinx (>=1.7.1)"] +redis = ["redis"] +tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"] + [[package]] name = "pre-commit" version = "3.8.0" @@ -3697,6 +3754,29 @@ files = [ {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, ] +[[package]] +name = "pywin32" 
+version = "306" +description = "Python for Window Extensions" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, + {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, + {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, + {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, + {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, + {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, + {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, + {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, + {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, + {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, + {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, + {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, + {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, + {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, +] + [[package]] name = "pyyaml" version = 
"6.0.2" @@ -3759,6 +3839,33 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] +[[package]] +name = "qdrant-client" +version = "1.11.3" +description = "Client library for the Qdrant vector search engine" +optional = false +python-versions = ">=3.8" +files = [ + {file = "qdrant_client-1.11.3-py3-none-any.whl", hash = "sha256:fcf040b58203ed0827608c9ad957da671b1e31bf27e5e35b322c1b577b6ec133"}, + {file = "qdrant_client-1.11.3.tar.gz", hash = "sha256:5a155d8281a224ac18acef512eae2f5e9a0907975d52a7627ec66fa6586d0285"}, +] + +[package.dependencies] +grpcio = ">=1.41.0" +grpcio-tools = ">=1.41.0" +httpx = {version = ">=0.20.0", extras = ["http2"]} +numpy = [ + {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, + {version = ">=1.26", markers = "python_version >= \"3.12\""}, +] +portalocker = ">=2.7.0,<3.0.0" +pydantic = ">=1.10.8" +urllib3 = ">=1.26.14,<3" + +[package.extras] +fastembed = ["fastembed (==0.3.6)"] +fastembed-gpu = ["fastembed-gpu (==0.3.6)"] + [[package]] name = "regex" version = "2024.9.11" @@ -5035,20 +5142,6 @@ files = [ [package.dependencies] types-urllib3 = "*" -[[package]] -name = "types-requests" -version = "2.32.0.20240914" -description = "Typing stubs for requests" -optional = false -python-versions = ">=3.8" -files = [ - {file = "types-requests-2.32.0.20240914.tar.gz", hash = "sha256:2850e178db3919d9bf809e434eef65ba49d0e7e33ac92d588f4a5e295fffd405"}, - {file = "types_requests-2.32.0.20240914-py3-none-any.whl", hash = "sha256:59c2f673eb55f32a99b2894faf6020e1a9f4a402ad0f192bfee0b64469054310"}, -] - -[package.dependencies] -urllib3 = ">=2" - [[package]] name = "types-urllib3" version = "1.26.25.14" @@ -5113,23 +5206,6 @@ brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotl secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = 
["PySocks (>=1.5.6,!=1.5.7,<2.0)"] -[[package]] -name = "urllib3" -version = "2.2.3" -description = "HTTP library with thread-safe connection pooling, file post, and more." -optional = false -python-versions = ">=3.8" -files = [ - {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, - {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, -] - -[package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -h2 = ["h2 (>=4,<5)"] -socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] -zstd = ["zstandard (>=0.18.0)"] - [[package]] name = "validators" version = "0.34.0" @@ -5389,10 +5465,10 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -external-clients = ["anthropic", "cohere", "google-cloud-aiplatform", "mistralai", "pinecone-client", "weaviate-client"] +external-clients = ["anthropic", "cohere", "google-cloud-aiplatform", "mistralai", "pinecone-client", "qdrant-client", "weaviate-client"] kg-creation-tools = ["pygraphviz", "pygraphviz"] [metadata] lock-version = "2.0" python-versions = "^3.9.0" -content-hash = "ecb73179848945e83c04273f59675703c9874be1db45ed9b810890c37b12dbfb" +content-hash = "3b13b34740f98c74b1d6a73331b298d4a31574883764e8aec4df0f9f3b312469" diff --git a/pyproject.toml b/pyproject.toml index 7b9890997..9a8aa63be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ from = "src" python = "^3.9.0" neo4j = "^5.17.0" pydantic = "^2.6.3" +urllib3 = "<2" weaviate-client = {version = "^4.6.1", optional = true} pinecone-client = {version = "^4.1.0", optional = true} types-mock = "^5.1.0.20240425" @@ -45,6 +46,7 @@ google-cloud-aiplatform = {version = "^1.66.0", optional = true} cohere = {version = "^5.9.0", optional = true} anthropic = { version = "^0.34.2", optional = true} mistralai = {version = "^1.0.3", optional = true} +qdrant-client = {version 
= "^1.11.3", optional = true} [tool.poetry.group.dev.dependencies] pylint = "^3.1.0" @@ -81,10 +83,10 @@ google-cloud-aiplatform = {version = "^1.66.0"} cohere = {version = "^5.9.0"} anthropic = { version = "^0.34.2"} mistralai = {version = "^1.0.3"} - +qdrant-client = {version = "^1.11.3"} [tool.poetry.extras] -external_clients = ["weaviate-client", "pinecone-client", "google-cloud-aiplatform", "cohere", "anthropic", "mistralai"] +external_clients = ["weaviate-client", "pinecone-client", "google-cloud-aiplatform", "cohere", "anthropic", "mistralai", "qdrant-client"] kg_creation_tools = ["pygraphviz"] [build-system] diff --git a/src/neo4j_graphrag/retrievers/__init__.py b/src/neo4j_graphrag/retrievers/__init__.py index 94d0aea4e..595eac93b 100644 --- a/src/neo4j_graphrag/retrievers/__init__.py +++ b/src/neo4j_graphrag/retrievers/__init__.py @@ -40,3 +40,10 @@ __all__.append("WeaviateNeo4jRetriever") except ImportError: pass + +try: + from .external.qdrant.qdrant import QdrantNeo4jRetriever # noqa: F401 + + __all__.append("QdrantNeo4jRetriever") +except ImportError: + pass diff --git a/src/neo4j_graphrag/retrievers/external/qdrant/__init__.py b/src/neo4j_graphrag/retrievers/external/qdrant/__init__.py new file mode 100644 index 000000000..c0199c144 --- /dev/null +++ b/src/neo4j_graphrag/retrievers/external/qdrant/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class QdrantNeo4jRetriever(ExternalRetriever):
    """
    Provides retrieval method using vector search over embeddings with a Qdrant database.

    The nearest neighbours of the query vector are fetched from a Qdrant
    collection. The identifier stored in each point's payload (under
    ``id_property_external``) is then passed, together with the point's score,
    to the Cypher query built by ``get_match_query`` and executed on Neo4j.

    Example:

    .. code-block:: python

      from neo4j import GraphDatabase
      from neo4j_graphrag.retrievers import QdrantNeo4jRetriever
      from qdrant_client import QdrantClient

      with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver:
          client = QdrantClient()
          retriever = QdrantNeo4jRetriever(
              driver=neo4j_driver,
              client=client,
              collection_name="my_collection",
              id_property_external="neo4j_id",
              id_property_neo4j="id",
          )
          embedding = ...
          retriever.search(query_vector=embedding, top_k=2)

    Args:
        driver (neo4j.Driver): The Neo4j Python driver.
        client (QdrantClient): The Qdrant client object.
        collection_name (str): The name of the Qdrant collection to use.
        id_property_neo4j (str): The name of the Neo4j node property that's used as the identifier for relating matches from Qdrant to Neo4j nodes. Required.
        id_property_external (str): The name of the Qdrant payload property with identifier that refers to a corresponding Neo4j node id property. Defaults to "id".
        embedder (Optional[Embedder]): Embedder object to embed query text.
        return_properties (Optional[list[str]]): List of node properties to return.
        retrieval_query (Optional[str]): Cypher retrieval query forwarded to ``get_match_query`` when the Neo4j lookup query is built.
        result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
        neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, the server's default database is used ("neo4j" unless configured otherwise).

    Raises:
        RetrieverInitializationError: If validation of the input arguments fails.
    """

    def __init__(
        self,
        driver: neo4j.Driver,
        client: QdrantClient,
        collection_name: str,
        id_property_neo4j: str,
        id_property_external: str = "id",
        embedder: Optional[Embedder] = None,
        return_properties: Optional[list[str]] = None,
        retrieval_query: Optional[str] = None,
        result_formatter: Optional[
            Callable[[neo4j.Record], RetrieverResultItem]
        ] = None,
        neo4j_database: Optional[str] = None,
    ):
        # Validate every argument through the pydantic models first so that a
        # bad driver/client/option surfaces as one RetrieverInitializationError.
        try:
            driver_model = Neo4jDriverModel(driver=driver)
            client_model = QdrantClientModel(client=client)
            embedder_model = EmbedderModel(embedder=embedder) if embedder else None
            validated_data = QdrantNeo4jRetrieverModel(
                driver_model=driver_model,
                client_model=client_model,
                collection_name=collection_name,
                id_property_neo4j=id_property_neo4j,
                id_property_external=id_property_external,
                embedder_model=embedder_model,
                return_properties=return_properties,
                retrieval_query=retrieval_query,
                result_formatter=result_formatter,
                neo4j_database=neo4j_database,
            )
        except ValidationError as e:
            raise RetrieverInitializationError(e.errors()) from e

        super().__init__(
            driver=driver,
            id_property_external=validated_data.id_property_external,
            id_property_neo4j=validated_data.id_property_neo4j,
            neo4j_database=neo4j_database,
        )
        self.driver = validated_data.driver_model.driver
        self.client = validated_data.client_model.client
        self.collection_name = validated_data.collection_name
        self.embedder = (
            validated_data.embedder_model.embedder
            if validated_data.embedder_model
            else None
        )
        self.return_properties = validated_data.return_properties
        self.retrieval_query = validated_data.retrieval_query
        self.result_formatter = validated_data.result_formatter

    def get_search_results(
        self,
        query_vector: Optional[list[float]] = None,
        query_text: Optional[str] = None,
        top_k: int = 5,
        **kwargs: Any,
    ) -> RawSearchResult:
        """Get the top_k nearest neighbour embeddings using Qdrant for either provided query_vector or query_text.
        If query_text is provided, then the provided embedder is used to generate the query_vector.

        The search itself runs in Qdrant (``QdrantClient.query_points``); only the
        payload field named by ``id_property_external`` is requested. The resulting
        ``[identifier, score]`` pairs are then sent to Neo4j as the ``match_params``
        parameter of the query returned by ``get_match_query``.

        Example:

        .. code-block:: python

          from neo4j import GraphDatabase
          from neo4j_graphrag.retrievers import QdrantNeo4jRetriever
          from qdrant_client import QdrantClient

          with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver:
              client = QdrantClient()
              retriever = QdrantNeo4jRetriever(
                  driver=neo4j_driver,
                  client=client,
                  collection_name="my_collection",
                  id_property_external="neo4j_id",
                  id_property_neo4j="id",
              )
              embedding = ...
              retriever.search(query_vector=embedding, top_k=2)


        Args:
            query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbours of. Defaults to None.
            query_text (Optional[str], optional): The text to get the closest neighbours of. Defaults to None.
            top_k (int): The number of neighbours to return. Defaults to 5.
            kwargs: Additional keyword arguments to pass to QdrantClient.query_points().

        Raises:
            SearchValidationError: If validation of the input arguments fails.
            EmbeddingRequiredError: If no embedder is provided when using text as an input.
            KeyError: If a point returned by Qdrant has no payload value for ``id_property_external``.

        Returns:
            RawSearchResult: The results of the search query as a list of neo4j.Record and an optional metadata dict
        """

        try:
            validated_data = VectorSearchModel(
                query_vector=query_vector,
                query_text=query_text,
                top_k=top_k,
            )
        except ValidationError as e:
            raise SearchValidationError(e.errors()) from e

        if validated_data.query_text:
            if self.embedder:
                query_vector = self.embedder.embed_query(validated_data.query_text)
                logger.debug("Locally generated query vector: %s", query_vector)
            else:
                logger.error("No embedder provided for query_text.")
                raise EmbeddingRequiredError("No embedder provided for query_text.")

        points = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            limit=top_k,
            with_payload=[self.id_property_external],
            **kwargs,
        ).points

        result_tuples = []
        for point in points:
            payload = point.payload
            # Explicit check instead of `assert`: assertions are stripped under
            # `python -O`, which would turn a missing payload into an opaque
            # TypeError. A point without the identifier cannot be related to a
            # Neo4j node, so fail loudly with the same exception type a missing
            # payload key already produced.
            if payload is None or self.id_property_external not in payload:
                raise KeyError(
                    f"Qdrant point {point.id!r} has no payload field "
                    f"{self.id_property_external!r}"
                )
            # The identifier is stringified before being sent to Neo4j.
            result_tuples.append(
                [f"{payload[self.id_property_external]}", point.score]
            )

        search_query = get_match_query(
            return_properties=self.return_properties,
            retrieval_query=self.retrieval_query,
        )

        parameters = {
            "match_params": result_tuples,
            "id_property": self.id_property_neo4j,
        }

        logger.debug("Qdrant Store Cypher parameters: %s", parameters)
        logger.debug("Qdrant Store Cypher query: %s", search_query)

        records, _, _ = self.driver.execute_query(
            search_query, parameters, database_=self.neo4j_database
        )

        return RawSearchResult(records=records)
+# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Callable, Optional + +import neo4j +from pydantic import ( + BaseModel, + ConfigDict, + field_validator, +) +from qdrant_client import QdrantClient + +from neo4j_graphrag.types import ( + EmbedderModel, + Neo4jDriverModel, + RetrieverResultItem, +) + + +class QdrantClientModel(BaseModel): + client: QdrantClient + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("client") + def check_client(cls, value: QdrantClient) -> QdrantClient: + if not isinstance(value, QdrantClient): + raise ValueError("Provided client needs to be of type QdrantClient") + return value + + +class QdrantNeo4jRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + client_model: QdrantClientModel + collection_name: str + id_property_external: str + id_property_neo4j: str + embedder_model: Optional[EmbedderModel] = None + return_properties: Optional[list[str]] = None + retrieval_query: Optional[str] = None + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None + neo4j_database: Optional[str] = None diff --git a/tests/e2e/docker-compose.yml b/tests/e2e/docker-compose.yml index 3eaa9f065..1ffdb449b 100644 --- a/tests/e2e/docker-compose.yml +++ b/tests/e2e/docker-compose.yml @@ -34,3 +34,7 @@ services: NEO4J_AUTH: neo4j/password NEO4J_ACCEPT_LICENSE_AGREEMENT: "eval" NEO4J_PLUGINS: "[\"apoc\"]" + qdrant: + image: qdrant/qdrant + ports: + - 6333:6333 diff --git a/tests/e2e/qdrant_e2e/__init__.py b/tests/e2e/qdrant_e2e/__init__.py new file mode 100644 index 
000000000..e69de29bb diff --git a/tests/e2e/qdrant_e2e/populate_dbs.py b/tests/e2e/qdrant_e2e/populate_dbs.py new file mode 100644 index 000000000..00cfd4133 --- /dev/null +++ b/tests/e2e/qdrant_e2e/populate_dbs.py @@ -0,0 +1,54 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any + +import neo4j +from neo4j import GraphDatabase +from qdrant_client import QdrantClient, models + +from ..utils import build_data_objects, populate_neo4j + + +def populate_dbs( + neo4j_driver: neo4j.Driver, client: QdrantClient, collection_name: str = "Jeopardy" +) -> None: + neo4j_objects, question_objs = build_data_objects("qdrant") + + if client.collection_exists(collection_name): + client.delete_collection(collection_name) + + client.create_collection( + collection_name=collection_name, + vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE), + ) + + populate_qdrant(client, question_objs, collection_name) + + populate_neo4j(neo4j_driver, neo4j_objects) + + +def populate_qdrant( + client: QdrantClient, question_objs: list[Any], collection_name: str +) -> None: + client.upsert(collection_name=collection_name, points=question_objs) + + +if __name__ == "__main__": + NEO4J_AUTH = ("neo4j", "password") + NEO4J_URL = "neo4j://localhost:7687" + with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver: + populate_dbs(neo4j_driver, 
QdrantClient(url="http://localhost:6333")) diff --git a/tests/e2e/qdrant_e2e/test_qdrant_e2e.py b/tests/e2e/qdrant_e2e/test_qdrant_e2e.py new file mode 100644 index 000000000..bb7f0d952 --- /dev/null +++ b/tests/e2e/qdrant_e2e/test_qdrant_e2e.py @@ -0,0 +1,105 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any, Generator + +import pytest +from langchain_huggingface.embeddings import HuggingFaceEmbeddings +from neo4j import Driver +from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.retrievers import QdrantNeo4jRetriever +from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem +from qdrant_client import QdrantClient + +from ..utils import EMBEDDING_BIOLOGY +from .populate_dbs import populate_dbs + + +@pytest.fixture(scope="module") +def sentence_transformer_embedder() -> Generator[HuggingFaceEmbeddings, Any, Any]: + embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2") + yield embedder + + +@pytest.fixture(scope="module") +def qdrant_client() -> Generator[Any, Any, Any]: + client = QdrantClient(url="http://localhost:6333") + yield client + client.close() + + +@pytest.fixture(scope="module") +def populate_qdrant_neo4j(driver: Driver, qdrant_client: QdrantClient) -> None: + driver.execute_query("MATCH (n) DETACH DELETE n") + populate_dbs(driver, qdrant_client, "Jeopardy") + + 
+@pytest.mark.usefixtures("populate_qdrant_neo4j") +def test_qdrant_neo4j_vector_input(driver: Driver, qdrant_client: QdrantClient) -> None: + retriever = QdrantNeo4jRetriever( + driver=driver, + client=qdrant_client, + collection_name="Jeopardy", + id_property_external="neo4j_id", + id_property_neo4j="id", + ) + + top_k = 1 + results = retriever.search(query_vector=EMBEDDING_BIOLOGY, top_k=top_k) + + assert isinstance(results, RetrieverResult) + assert len(results.items) == top_k + assert isinstance(results.items[0], RetrieverResultItem) + print("Results are: ", results.items) + pattern = ( + r" " + r"score=0.2[0-9]+>" + ) + assert re.match(pattern, results.items[0].content) + + +@pytest.mark.usefixtures("populate_qdrant_neo4j") +def test_qdrant_neo4j_text_input_local_embedder( + driver: Driver, + qdrant_client: QdrantClient, + sentence_transformer_embedder: Embedder, +) -> None: + retriever = QdrantNeo4jRetriever( + driver=driver, + client=qdrant_client, + collection_name="Jeopardy", + id_property_external="neo4j_id", + id_property_neo4j="id", + embedder=sentence_transformer_embedder, + ) + + top_k = 2 + results = retriever.search(query_text="biology", top_k=top_k) + + assert isinstance(results, RetrieverResult) + assert len(results.items) == top_k + assert isinstance(results.items[0], RetrieverResultItem) + pattern = ( + r" " + r"score=0.2[0-9]+>" + ) + assert re.match(pattern, results.items[0].content) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 934e72fcc..5ef1c63a4 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -22,6 +22,7 @@ import neo4j import weaviate.classes as wvc from neo4j_graphrag.indexes import create_vector_index, drop_index_if_exists +from qdrant_client import models BASE_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -466,7 +467,7 @@ def populate_neo4j( def build_data_objects( - q_vector_fmt: Literal["weaviate", "pinecone", "neo4j"], + q_vector_fmt: Literal["weaviate", "pinecone", "neo4j", "qdrant"], ) -> 
tuple[dict[str, Any], list[Any]]: # read file from disk # this file is from https://github.com/weaviate-tutorials/quickstart/tree/main/data @@ -488,7 +489,7 @@ def build_data_objects( ] neo4j_objs["nodes"] += unique_categories - for d in data: + for i, d in enumerate(data): id = hashlib.md5(d["Question"].encode()).hexdigest() question_properties = { "id": f"question_{id}", @@ -541,7 +542,17 @@ def build_data_objects( elif q_vector_fmt == "neo4j": # vector inserted into the neo4j object, nothing to do here pass + elif q_vector_fmt == "qdrant": + question_objs.append( + models.PointStruct( + id=i, + payload={"neo4j_id": f"question_{id}"}, + vector=d["vector"], + ) + ) else: - raise ValueError("q_vector_fmt must be either weaviate, pinecone or neo4j") + raise ValueError( + "q_vector_fmt must be either weaviate, pinecone, neo4j or qdrant" + ) return neo4j_objs, question_objs diff --git a/tests/unit/retrievers/external/test_qdrant.py b/tests/unit/retrievers/external/test_qdrant.py new file mode 100644 index 000000000..e5b678a6d --- /dev/null +++ b/tests/unit/retrievers/external/test_qdrant.py @@ -0,0 +1,266 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest import mock +from unittest.mock import MagicMock + +import neo4j +import pytest +from neo4j_graphrag.exceptions import RetrieverInitializationError +from neo4j_graphrag.retrievers import QdrantNeo4jRetriever +from neo4j_graphrag.retrievers.external.utils import get_match_query +from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem +from qdrant_client import QdrantClient +from qdrant_client.http.models import QueryResponse, ScoredPoint + + +@pytest.fixture(scope="function") +def client() -> MagicMock: + return MagicMock(spec=QdrantClient) + + +def test_qdrant_retriever_search_happy_path( + driver: MagicMock, client: MagicMock +) -> None: + retriever = QdrantNeo4jRetriever( + driver=driver, + client=client, + collection_name="dummy-text", + id_property_neo4j="sync_id", + id_property_external="sync_id", + ) + with mock.patch.object(retriever, "client") as mock_client: + top_k = 5 + mock_client.query_points.return_value = QueryResponse( + points=[ + ScoredPoint( + id=i, + version=0, + score=i / top_k, + payload={ + "sync_id": f"node_{i}", + }, + ) + for i in range(top_k) + ] + ) + driver.execute_query.return_value = ( + [ + neo4j.Record({"node": {"sync_id": f"node_{i}"}, "score": i / top_k}) + for i in range(top_k) + ], + None, + None, + ) + query_vector = [1.0 for _ in range(1536)] + search_query = get_match_query() + records = retriever.search(query_vector=query_vector) + + driver.execute_query.assert_called_once_with( + search_query, + { + "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], + "id_property": "sync_id", + }, + database_=None, + ) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="", + metadata=None, + ) + for i in range(top_k) + ], + metadata={"__retriever": "QdrantNeo4jRetriever"}, + ) + + +def test_invalid_neo4j_database_name(driver: MagicMock, client: MagicMock) -> None: + with pytest.raises(RetrieverInitializationError) as exc_info: + QdrantNeo4jRetriever( + 
driver=driver, + client=client, + collection_name="dummy-text", + id_property_neo4j="sync_id", + neo4j_database=42, # type: ignore + ) + + assert "neo4j_database" in str(exc_info.value) + assert "Input should be a valid string" in str(exc_info.value) + + +def test_qdrant_retriever_search_return_properties( + driver: MagicMock, client: MagicMock +) -> None: + retriever = QdrantNeo4jRetriever( + driver=driver, + client=client, + collection_name="dummy-text", + id_property_neo4j="sync_id", + id_property_external="sync_id", + return_properties=["sync_id"], + ) + with mock.patch.object(retriever, "client") as mock_client: + top_k = 5 + mock_client.query_points.return_value = QueryResponse( + points=[ + ScoredPoint( + id=i, + version=0, + score=i / top_k, + payload={ + "sync_id": f"node_{i}", + }, + ) + for i in range(top_k) + ] + ) + driver.execute_query.return_value = ( + [ + neo4j.Record({"node": {"sync_id": f"node_{i}"}, "score": i / top_k}) + for i in range(top_k) + ], + None, + None, + ) + query_vector = [1.0 for _ in range(1536)] + search_query = get_match_query(return_properties=["sync_id"]) + records = retriever.search( + query_vector=query_vector, + ) + + driver.execute_query.assert_called_once_with( + search_query, + { + "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], + "id_property": "sync_id", + }, + database_=None, + ) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="", + metadata=None, + ) + for i in range(top_k) + ], + metadata={"__retriever": "QdrantNeo4jRetriever"}, + ) + + +def test_qdrant_retriever_search_retrieval_query( + driver: MagicMock, client: MagicMock +) -> None: + retrieval_query = "WITH node MATCH (node)--(m) RETURN n, m LIMIT 10" + retriever = QdrantNeo4jRetriever( + driver=driver, + client=client, + collection_name="dummy-text", + id_property_neo4j="sync_id", + id_property_external="sync_id", + retrieval_query=retrieval_query, + ) + with mock.patch.object(retriever, "client") as 
mock_client: + top_k = 5 + mock_client.query_points.return_value = QueryResponse( + points=[ + ScoredPoint( + id=i, + version=0, + score=i / top_k, + payload={ + "sync_id": f"node_{i}", + }, + ) + for i in range(top_k) + ] + ) + driver.execute_query.return_value = ( + [ + neo4j.Record({"node": {"sync_id": f"node_{i}"}, "score": i / top_k}) + for i in range(top_k) + ], + None, + None, + ) + query_vector = [1.0 for _ in range(1536)] + search_query = get_match_query(retrieval_query=retrieval_query) + records = retriever.search( + query_vector=query_vector, + ) + + driver.execute_query.assert_called_once_with( + search_query, + { + "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], + "id_property": "sync_id", + }, + database_=None, + ) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="", + metadata=None, + ) + for i in range(top_k) + ], + metadata={"__retriever": "QdrantNeo4jRetriever"}, + ) + + +def test_qdrant_retriever_invalid_return_properties( + driver: MagicMock, client: MagicMock +) -> None: + with pytest.raises(RetrieverInitializationError) as exc_info: + QdrantNeo4jRetriever( + driver=driver, + client=client, + collection_name="dummy-text", + id_property_neo4j="dummy-text", + return_properties=42, # type: ignore + ) + + assert "return_properties" in str(exc_info.value) + assert "Input should be a valid list" in str(exc_info.value) + + +def test_qdrant_retriever_invalid_retrieval_query( + driver: MagicMock, client: MagicMock +) -> None: + with pytest.raises(RetrieverInitializationError) as exc_info: + QdrantNeo4jRetriever( + driver=driver, + client=client, + collection_name="dummy-text", + id_property_neo4j="dummy-text", + retrieval_query=42, # type: ignore + ) + + assert "retrieval_query" in str(exc_info.value) + assert "Input should be a valid string" in str(exc_info.value)