From e98a2d1e112a46dc86e4b909363cc83d3a9b14c0 Mon Sep 17 00:00:00 2001 From: willtai Date: Wed, 1 May 2024 10:26:36 +0100 Subject: [PATCH 1/3] Make test folder for src structure and refactored verify_version tests (#23) --- tests/retrievers/test_base.py | 39 +++ tests/retrievers/test_hybrid.py | 231 ++++++++++++++++++ .../test_vector.py} | 228 +---------------- ...{test_queries.py => test_neo4j_queries.py} | 0 4 files changed, 280 insertions(+), 218 deletions(-) create mode 100644 tests/retrievers/test_base.py create mode 100644 tests/retrievers/test_hybrid.py rename tests/{test_retrievers.py => retrievers/test_vector.py} (56%) rename tests/{test_queries.py => test_neo4j_queries.py} (100%) diff --git a/tests/retrievers/test_base.py b/tests/retrievers/test_base.py new file mode 100644 index 000000000..0386b81b6 --- /dev/null +++ b/tests/retrievers/test_base.py @@ -0,0 +1,39 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from neo4j_genai.retrievers.base import Retriever + + +@pytest.mark.parametrize( + "db_version,expected_exception", + [ + (["5.18-aura"], None), + (["5.3-aura"], ValueError), + (["5.19.0"], None), + (["4.3.5"], ValueError), + ], +) +def test_retriever_version_support(driver, db_version, expected_exception): + class MockRetriever(Retriever): + def search(self, *args, **kwargs): + pass + + driver.execute_query.return_value = [[{"versions": db_version}], None, None] + if expected_exception: + with pytest.raises(expected_exception): + MockRetriever(driver=driver) + else: + MockRetriever(driver=driver) diff --git a/tests/retrievers/test_hybrid.py b/tests/retrievers/test_hybrid.py new file mode 100644 index 000000000..b55e3c542 --- /dev/null +++ b/tests/retrievers/test_hybrid.py @@ -0,0 +1,231 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch, MagicMock + +import pytest + +from neo4j_genai import HybridRetriever, HybridCypherRetriever +from neo4j_genai.neo4j_queries import get_search_query +from neo4j_genai.types import SearchType + + +def test_vector_retriever_initialization(driver): + with patch("neo4j_genai.retrievers.base.Retriever._verify_version") as mock_verify: + HybridRetriever( + driver=driver, + vector_index_name="my-index", + fulltext_index_name="fulltext-index", + ) + mock_verify.assert_called_once() + + +def test_vector_cypher_retriever_initialization(driver): + with patch("neo4j_genai.retrievers.base.Retriever._verify_version") as mock_verify: + HybridCypherRetriever( + driver=driver, + vector_index_name="my-index", + fulltext_index_name="fulltext-index", + retrieval_query="", + ) + mock_verify.assert_called_once() + + +@patch("neo4j_genai.HybridRetriever._verify_version") +def test_hybrid_search_text_happy_path(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + vector_index_name = "my-index" + fulltext_index_name = "my-fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + retriever = HybridRetriever( + driver, vector_index_name, fulltext_index_name, custom_embeddings + ) + retriever.driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + search_query = get_search_query(SearchType.HYBRID) + + records = retriever.search(query_text=query_text, top_k=top_k) + + retriever.driver.execute_query.assert_called_once_with( + search_query, + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": embed_query_vector, + }, + ) + custom_embeddings.embed_query.assert_called_once_with(query_text) + assert records == [{"node": "dummy-node", "score": 1.0}] + + +@patch("neo4j_genai.HybridRetriever._verify_version") +def test_hybrid_search_favors_query_vector_over_embedding_vector( + _verify_version_mock, driver +): + embed_query_vector = [1.0 for _ in range(1536)] + query_vector = [2.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + vector_index_name = "my-index" + fulltext_index_name = "my-fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + retriever = HybridRetriever( + driver, vector_index_name, fulltext_index_name, custom_embeddings + ) + retriever.driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + search_query = get_search_query(SearchType.HYBRID) + + retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k) + + retriever.driver.execute_query.assert_called_once_with( + search_query, + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": query_vector, + }, + ) + custom_embeddings.embed_query.assert_not_called() + + +def test_error_when_hybrid_search_only_text_no_embedder(hybrid_retriever): + query_text = "may thy knife chip and shatter" + top_k = 5 + + with pytest.raises(ValueError, match="Embedding method required for text query."): + hybrid_retriever.search( + query_text=query_text, + top_k=top_k, + ) + + +def test_hybrid_search_retriever_search_missing_embedder_for_text( + hybrid_retriever, +): + query_text = "may thy knife chip and shatter" + top_k = 5 + + with pytest.raises(ValueError, match="Embedding method required for text query"): + hybrid_retriever.search( + query_text=query_text, + top_k=top_k, + ) + + +@patch("neo4j_genai.HybridRetriever._verify_version") +def test_hybrid_retriever_return_properties(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + vector_index_name = "my-index" + fulltext_index_name = "my-fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + return_properties = ["node-property-1", "node-property-2"] + retriever = HybridRetriever( + driver, + vector_index_name, + fulltext_index_name, + custom_embeddings, + return_properties, + ) + driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + search_query = get_search_query(SearchType.HYBRID, return_properties) + + records = retriever.search(query_text=query_text, top_k=top_k) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + driver.execute_query.assert_called_once_with( + search_query, + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": embed_query_vector, + }, + ) + assert records == [{"node": "dummy-node", "score": 1.0}] + + +@patch("neo4j_genai.HybridCypherRetriever._verify_version") +def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + vector_index_name = "my-index" + fulltext_index_name = "my-fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata + """ + query_params = { + "param": "dummy-param", + } + retriever = HybridCypherRetriever( + driver, + vector_index_name, + fulltext_index_name, + retrieval_query, + custom_embeddings, + ) + driver.execute_query.return_value = [ + [{"node_id": 123, "text": "dummy-text", "score": 1.0}], + None, + None, + ] + search_query = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + + records = retriever.search( + query_text=query_text, + top_k=top_k, + query_params=query_params, + ) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + + driver.execute_query.assert_called_once_with( + search_query, + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": embed_query_vector, + "param": "dummy-param", + }, + ) + + assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] diff --git a/tests/test_retrievers.py b/tests/retrievers/test_vector.py similarity index 56% rename from tests/test_retrievers.py rename to tests/retrievers/test_vector.py index 2f7a6113a..69c1f6156 100644 --- a/tests/test_retrievers.py +++ b/tests/retrievers/test_vector.py @@ -12,46 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import pytest from unittest.mock import patch, MagicMock +import pytest from neo4j.exceptions import CypherSyntaxError -from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever -from neo4j_genai.retrievers.hybrid import HybridCypherRetriever -from neo4j_genai.types import VectorSearchRecord, SearchType +from neo4j_genai import VectorRetriever, VectorCypherRetriever from neo4j_genai.neo4j_queries import get_search_query +from neo4j_genai.types import SearchType, VectorSearchRecord -def test_vector_retriever_supported_aura_version(driver): - driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None] - - VectorRetriever(driver=driver, index_name="my-index") - - -def test_vector_retriever_no_supported_aura_version(driver): - driver.execute_query.return_value = [[{"versions": ["5.3-aura"]}], None, None] - - with pytest.raises(ValueError) as excinfo: +def test_vector_retriever_initialization(driver): + with patch("neo4j_genai.retrievers.base.Retriever._verify_version") as mock_verify: VectorRetriever(driver=driver, index_name="my-index") + mock_verify.assert_called_once() - assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) - - -def test_vector_retriever_supported_version(driver): - driver.execute_query.return_value = [[{"versions": ["5.19.0"]}], None, None] - - VectorRetriever(driver=driver, index_name="my-index") - -def test_vector_retriever_no_supported_version(driver): - driver.execute_query.return_value = [[{"versions": ["4.3.5"]}], None, None] - - with pytest.raises(ValueError) as excinfo: - VectorRetriever(driver=driver, index_name="my-index") - - assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) +def test_vector_cypher_retriever_initialization(driver): + with patch("neo4j_genai.retrievers.base.Retriever._verify_version") as mock_verify: + VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query="") + mock_verify.assert_called_once() @patch("neo4j_genai.VectorRetriever._verify_version") @@ -329,191 +309,3 @@ def test_retrieval_query_cypher_error(_verify_version_mock, driver): query_text=query_text, top_k=top_k, ) - - -@patch("neo4j_genai.HybridRetriever._verify_version") -def test_hybrid_search_text_happy_path(_verify_version_mock, driver): - embed_query_vector = [1.0 for _ in range(1536)] - custom_embeddings = MagicMock() - custom_embeddings.embed_query.return_value = embed_query_vector - vector_index_name = "my-index" - fulltext_index_name = "my-fulltext-index" - query_text = "may thy knife chip and shatter" - top_k = 5 - retriever = HybridRetriever( - driver, vector_index_name, fulltext_index_name, custom_embeddings - ) - retriever.driver.execute_query.return_value = [ - [{"node": "dummy-node", "score": 1.0}], - None, - None, - ] - search_query = get_search_query(SearchType.HYBRID) - - records = retriever.search(query_text=query_text, top_k=top_k) - - retriever.driver.execute_query.assert_called_once_with( - search_query, - { - "vector_index_name": vector_index_name, - "top_k": top_k, - "query_text": query_text, - "fulltext_index_name": fulltext_index_name, - "query_vector": embed_query_vector, - }, - ) - custom_embeddings.embed_query.assert_called_once_with(query_text) - assert records == [{"node": "dummy-node", "score": 1.0}] - - -@patch("neo4j_genai.HybridRetriever._verify_version") -def test_hybrid_search_favors_query_vector_over_embedding_vector( - _verify_version_mock, driver -): - embed_query_vector = [1.0 for _ in range(1536)] - query_vector = [2.0 for _ in range(1536)] - custom_embeddings = MagicMock() - custom_embeddings.embed_query.return_value = embed_query_vector - vector_index_name = "my-index" - fulltext_index_name = "my-fulltext-index" - query_text = "may thy knife chip and shatter" - top_k = 5 - retriever = HybridRetriever( - driver, vector_index_name, fulltext_index_name, custom_embeddings - ) - retriever.driver.execute_query.return_value = [ - [{"node": "dummy-node", "score": 1.0}], - None, - None, - ] - search_query = get_search_query(SearchType.HYBRID) - - retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k) - - retriever.driver.execute_query.assert_called_once_with( - search_query, - { - "vector_index_name": vector_index_name, - "top_k": top_k, - "query_text": query_text, - "fulltext_index_name": fulltext_index_name, - "query_vector": query_vector, - }, - ) - custom_embeddings.embed_query.assert_not_called() - - -def test_error_when_hybrid_search_only_text_no_embedder(hybrid_retriever): - query_text = "may thy knife chip and shatter" - top_k = 5 - - with pytest.raises(ValueError, match="Embedding method required for text query."): - hybrid_retriever.search( - query_text=query_text, - top_k=top_k, - ) - - -def test_hybrid_search_retriever_search_missing_embedder_for_text( - hybrid_retriever, -): - query_text = "may thy knife chip and shatter" - top_k = 5 - - with pytest.raises(ValueError, match="Embedding method required for text query"): - hybrid_retriever.search( - query_text=query_text, - top_k=top_k, - ) - - -@patch("neo4j_genai.HybridRetriever._verify_version") -def test_hybrid_retriever_return_properties(_verify_version_mock, driver): - embed_query_vector = [1.0 for _ in range(1536)] - custom_embeddings = MagicMock() - custom_embeddings.embed_query.return_value = embed_query_vector - vector_index_name = "my-index" - fulltext_index_name = "my-fulltext-index" - query_text = "may thy knife chip and shatter" - top_k = 5 - return_properties = ["node-property-1", "node-property-2"] - retriever = HybridRetriever( - driver, - vector_index_name, - fulltext_index_name, - custom_embeddings, - return_properties, - ) - driver.execute_query.return_value = [ - [{"node": "dummy-node", "score": 1.0}], - None, - None, - ] - search_query = get_search_query(SearchType.HYBRID, return_properties) - - records = retriever.search(query_text=query_text, top_k=top_k) - - custom_embeddings.embed_query.assert_called_once_with(query_text) - driver.execute_query.assert_called_once_with( - search_query, - { - "vector_index_name": vector_index_name, - "top_k": top_k, - "query_text": query_text, - "fulltext_index_name": fulltext_index_name, - "query_vector": embed_query_vector, - }, - ) - assert records == [{"node": "dummy-node", "score": 1.0}] - - -@patch("neo4j_genai.HybridCypherRetriever._verify_version") -def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver): - embed_query_vector = [1.0 for _ in range(1536)] - custom_embeddings = MagicMock() - custom_embeddings.embed_query.return_value = embed_query_vector - vector_index_name = "my-index" - fulltext_index_name = "my-fulltext-index" - query_text = "may thy knife chip and shatter" - top_k = 5 - retrieval_query = """ - RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata - """ - query_params = { - "param": "dummy-param", - } - retriever = HybridCypherRetriever( - driver, - vector_index_name, - fulltext_index_name, - retrieval_query, - custom_embeddings, - ) - driver.execute_query.return_value = [ - [{"node_id": 123, "text": "dummy-text", "score": 1.0}], - None, - None, - ] - search_query = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) - - records = retriever.search( - query_text=query_text, - top_k=top_k, - query_params=query_params, - ) - - custom_embeddings.embed_query.assert_called_once_with(query_text) - - driver.execute_query.assert_called_once_with( - search_query, - { - "vector_index_name": vector_index_name, - "top_k": top_k, - "query_text": query_text, - "fulltext_index_name": fulltext_index_name, - "query_vector": embed_query_vector, - "param": "dummy-param", - }, - ) - - assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] diff --git a/tests/test_queries.py b/tests/test_neo4j_queries.py similarity index 100% rename from tests/test_queries.py rename to tests/test_neo4j_queries.py From 4167d4b41a274240ab779618acbddaf594aa8342 Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Fri, 3 May 2024 11:49:45 +0200 Subject: [PATCH 2/3] Add embedder object to complete README example code --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8272e13ca..75016337b 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Assumption: Neo4j running with populated vector index in place. ```python from neo4j import GraphDatabase from neo4j_genai import VectorRetriever +from langchain_openai import OpenAIEmbeddings URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") @@ -40,8 +41,11 @@ INDEX_NAME = "embedding-name" # Connect to Neo4j database driver = GraphDatabase.driver(URI, auth=AUTH) +# Create Embedder object +embedder = OpenAIEmbeddings(model="text-embedding-3-large") + # Initialize the retriever -retriever = VectorRetriever(driver, INDEX_NAME) +retriever = VectorRetriever(driver, INDEX_NAME, embedder) # Run the similarity search query_text = "How do I do similarity search in Neo4j?" From 6492ba666024702e882dd3f55f6c45d16693b7d4 Mon Sep 17 00:00:00 2001 From: willtai Date: Fri, 3 May 2024 11:08:29 +0100 Subject: [PATCH 3/3] Setup E2E Test pipeline and add E2E tests for vector and hybrid retrievers (#24) * Added E2E tests, new GitHub workflow, and separated out unit tests Setup neo4j db for e2e tests * Refactor query tail generation to separate function --------- Co-authored-by: Oskar Hane --- .github/workflows/pr-e2e-tests.yaml | 51 +++++++ .github/workflows/pr.yaml | 4 +- .gitignore | 1 + README.md | 4 +- examples/hybrid_cypher_search.py | 2 +- examples/hybrid_search.py | 2 +- examples/openai_search.py | 2 +- examples/similarity_search_for_text.py | 2 +- examples/vector_cypher_retrieval.py | 2 +- src/neo4j_genai/indexes.py | 2 +- src/neo4j_genai/neo4j_queries.py | 57 ++++---- tests/e2e/conftest.py | 90 ++++++++++++ tests/e2e/test_hybrid_e2e.py | 132 ++++++++++++++++++ tests/e2e/test_vector_e2e.py | 104 ++++++++++++++ tests/test_neo4j_queries.py | 76 ---------- tests/unit/__init__.py | 14 ++ tests/{ => unit}/conftest.py | 8 +- tests/unit/retrievers/__init__.py | 14 ++ tests/{ => unit}/retrievers/test_base.py | 0 tests/{ => unit}/retrievers/test_hybrid.py | 0 tests/{ => unit}/retrievers/test_vector.py | 0 tests/{ => unit}/test_indexes.py | 2 +- tests/unit/test_neo4j_queries.py | 153 +++++++++++++++++++++ 23 files changed, 607 insertions(+), 115 deletions(-) create mode 100644 .github/workflows/pr-e2e-tests.yaml create mode 100644 tests/e2e/conftest.py create mode 100644 tests/e2e/test_hybrid_e2e.py create mode 100644 tests/e2e/test_vector_e2e.py delete mode 100644 tests/test_neo4j_queries.py create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/conftest.py (91%) create mode 100644 tests/unit/retrievers/__init__.py rename tests/{ => unit}/retrievers/test_base.py (100%) rename tests/{ => unit}/retrievers/test_hybrid.py (100%) rename tests/{ => unit}/retrievers/test_vector.py (100%) rename tests/{ => unit}/test_indexes.py (98%) create mode 100644 tests/unit/test_neo4j_queries.py diff --git a/.github/workflows/pr-e2e-tests.yaml b/.github/workflows/pr-e2e-tests.yaml new file mode 100644 index 000000000..b09a148ed --- /dev/null +++ b/.github/workflows/pr-e2e-tests.yaml @@ -0,0 +1,51 @@ +name: 'Neo4j-GenAI PR E2E Tests' + +on: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + +jobs: + e2e-tests: + runs-on: ubuntu-latest + strategy: + matrix: + neo4j-version: + - 5 + neo4j-edition: + - community + - enterprise + services: + neo4j: + image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }} + env: + NEO4J_AUTH: neo4j/password + NEO4J_ACCEPT_LICENSE_AGREEMENT: yes + ports: + - 7687:7687 + - 7474:7474 + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + + - name: Configure Poetry + run: | + echo "$HOME/.local/bin" >> $GITHUB_PATH + poetry config virtualenvs.create false + + - name: Install dependencies + run: poetry install + + - name: Run tests + run: poetry run pytest ./tests/e2e diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index eb1606865..f93110e38 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -28,7 +28,7 @@ jobs: run: | poetry run ruff format --check . poetry run ruff check . - - name: Run tests and check coverage + - name: Run unit tests and check coverage run: | - poetry run coverage run -m pytest + poetry run coverage run -m pytest tests/unit poetry run coverage report --fail-under=90 diff --git a/.gitignore b/.gitignore index e9cf65c65..16c7907ed 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist/ htmlcov/ .idea/ .env +docs/build/ diff --git a/README.md b/README.md index 75016337b..bc8eb161d 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ create_vector_index( ### Populating the Neo4j Vector Index -This library does not write to the database, that is up to you. +This library does not write to the database, that is up to you. See below for how to write using Cypher via the Neo4j driver. Assumption: Neo4j running with a defined vector index @@ -165,7 +165,7 @@ Open a new virtual environment and then run the tests. ```bash poetry shell -pytest +pytest tests/unit ``` ## Further information diff --git a/examples/hybrid_cypher_search.py b/examples/hybrid_cypher_search.py index a121b2eb4..9bf8ca231 100644 --- a/examples/hybrid_cypher_search.py +++ b/examples/hybrid_cypher_search.py @@ -58,5 +58,5 @@ def embed_query(self, text: str) -> list[float]: driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "Who are the fremen?" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py index 704a3841a..f035c28c2 100644 --- a/examples/hybrid_search.py +++ b/examples/hybrid_search.py @@ -55,5 +55,5 @@ def embed_query(self, text: str) -> list[float]: driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "Who are the fremen?" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/openai_search.py b/examples/openai_search.py index 87864bdd9..f5e6760d8 100644 --- a/examples/openai_search.py +++ b/examples/openai_search.py @@ -48,5 +48,5 @@ driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "hello world" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index e282b7d57..203c2965e 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -51,5 +51,5 @@ def embed_query(self, text: str) -> list[float]: driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "hello world" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/vector_cypher_retrieval.py b/examples/vector_cypher_retrieval.py index 8ca829e2c..63a818607 100644 --- a/examples/vector_cypher_retrieval.py +++ b/examples/vector_cypher_retrieval.py @@ -63,5 +63,5 @@ def random_str(n: int) -> str: driver.execute_query(insert_query, parameters) # Perform the search -query_text = "Find me the closest text" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=1)) diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index b09901e08..132cc1448 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -115,7 +115,7 @@ def drop_index(driver: Driver, name: str) -> None: driver (Driver): Neo4j Python driver instance. name (str): The name of the index to delete. """ - query = "DROP INDEX $name" + query = "DROP INDEX $name IF EXISTS" parameters = { "name": name, } diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 752677d5f..b9ab366a5 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -23,34 +23,43 @@ def get_search_query( retrieval_query: Optional[str] = None, ): query_map = { - SearchType.VECTOR: ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + SearchType.VECTOR: "".join( + [ + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) ", + "YIELD node, score ", + get_query_tail(retrieval_query, return_properties), + ] ), - SearchType.HYBRID: ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + SearchType.HYBRID: "".join( + [ + "CALL { ", + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) ", + "YIELD node, score ", + "RETURN node, score UNION ", + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) ", + "YIELD node, score ", + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ", + "UNWIND nodes AS n ", + "RETURN n.node AS node, (n.score / max) AS score ", + "} ", + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k ", + get_query_tail( + retrieval_query, return_properties, "RETURN node, score" + ), + ] ), } + return query_map[search_type] - base_query = query_map[search_type] - additional_query = "" +def get_query_tail( + retrieval_query: Optional[str] = None, + return_properties: Optional[list[str]] = None, + fallback_return: Optional[str] = None, +) -> str: if retrieval_query: - additional_query += retrieval_query - elif return_properties: + return retrieval_query + if return_properties: return_properties_cypher = ", ".join([f".{prop}" for prop in return_properties]) - additional_query += "YIELD node, score " - additional_query += f"RETURN node {{{return_properties_cypher}}} as node, score" - else: - additional_query += "RETURN node, score" - - return base_query + additional_query + return f"RETURN node {{{return_properties_cypher}}} as node, score" + return fallback_return if fallback_return else "" diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py new file mode 100644 index 000000000..64cd65042 --- /dev/null +++ b/tests/e2e/conftest.py @@ -0,0 +1,90 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import string +import random +import uuid + +import pytest +from neo4j import GraphDatabase +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import drop_index, create_vector_index, create_fulltext_index + + +@pytest.fixture(scope="module") +def driver(): + uri = "neo4j://localhost:7687" + auth = ("neo4j", "password") + driver = GraphDatabase.driver(uri, auth=auth) + yield driver + driver.close() + + +@pytest.fixture(scope="module") +def custom_embedder(): + class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random.random() for _ in range(1536)] + + return CustomEmbedder() + + +@pytest.fixture(scope="module") +def setup_neo4j(driver): + vector_index_name = "vector-index-name" + fulltext_index_name = "fulltext-index-name" + + # Delete data and drop indexes to prevent data leakage + driver.execute_query("MATCH (n) DETACH DELETE n") + drop_index(driver, vector_index_name) + drop_index(driver, fulltext_index_name) + + # Create a vector index + create_vector_index( + driver, + vector_index_name, + label="Document", + property="propertyKey", + dimensions=1536, + similarity_fn="euclidean", + ) + + # Create a fulltext index + create_fulltext_index( + driver, fulltext_index_name, label="Document", node_properties=["propertyKey"] + ) + + # Insert 10 vectors and authors + vector = [random.random() for _ in range(1536)] + + def random_str(n: int) -> str: + return "".join([random.choice(string.ascii_letters) for _ in range(n)]) + + for i in range(10): + insert_query = ( + "MERGE (doc:Document {id: $id})" + "WITH doc " + "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" + "WITH doc " + "MERGE (author:Author {name: $authorName})" + "MERGE (doc)-[:AUTHORED_BY]->(author)" + "RETURN doc, author" + ) + + parameters = { + "id": str(uuid.uuid4()), + "vector": vector, + "authorName": random_str(10), + } + driver.execute_query(insert_query, parameters) diff --git a/tests/e2e/test_hybrid_e2e.py b/tests/e2e/test_hybrid_e2e.py new file mode 100644 index 000000000..f8f54466e --- /dev/null +++ b/tests/e2e/test_hybrid_e2e.py @@ -0,0 +1,132 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j import Record + +from neo4j_genai import ( + HybridRetriever, + HybridCypherRetriever, +) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_retriever_search_text(driver, custom_embedder): + retriever = HybridRetriever( + driver, "vector-index-name", "fulltext-index-name", custom_embedder + ) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, Record) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_cypher_retriever_search_text(driver, custom_embedder): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = HybridCypherRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + retrieval_query, + custom_embedder, + ) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_retriever_search_vector(driver): + retriever = HybridRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + ) + + top_k = 5 + results = retriever.search( + query_text="Find me a book about Fremen", + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, Record) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_cypher_retriever_search_vector(driver): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = HybridCypherRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + retrieval_query, + ) + + top_k = 5 + results = retriever.search( + query_text="Find me a book about Fremen", + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_retriever_return_properties(driver): + properties = ["name", "age"] + retriever = HybridRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + return_properties=properties, + ) + + top_k = 5 + results = retriever.search( + query_text="Find me a book about Fremen", + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, Record) diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py new file mode 100644 index 000000000..9bf3f5a45 --- /dev/null +++ b/tests/e2e/test_vector_e2e.py @@ -0,0 +1,104 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from neo4j import Record + +from neo4j_genai import VectorRetriever, VectorCypherRetriever +from neo4j_genai.types import VectorSearchRecord + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_search_text(driver, custom_embedder): + retriever = VectorRetriever(driver, "vector-index-name", custom_embedder) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_cypher_retriever_search_text(driver, custom_embedder): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = VectorCypherRetriever( + driver, "vector-index-name", retrieval_query, custom_embedder + ) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_search_vector(driver): + retriever = VectorRetriever(driver, "vector-index-name") + + top_k = 5 + results = retriever.search(query_vector=[1.0 for _ in range(1536)], top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_cypher_retriever_search_vector(driver): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = VectorCypherRetriever(driver, "vector-index-name", retrieval_query) + + top_k = 5 + results = retriever.search(query_vector=[1.0 for _ in range(1536)], top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_return_properties(driver): + properties = ["name", "age"] + retriever = VectorRetriever( + driver, + "vector-index-name", + return_properties=properties, + ) + + top_k = 5 + results = retriever.search( + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, VectorSearchRecord) diff --git a/tests/test_neo4j_queries.py b/tests/test_neo4j_queries.py deleted file mode 100644 index 555388ae2..000000000 --- a/tests/test_neo4j_queries.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from neo4j_genai.neo4j_queries import get_search_query -from neo4j_genai.types import SearchType - - -def test_vector_search_basic(): - expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " - "RETURN node, score" - ) - result = get_search_query(SearchType.VECTOR) - assert result == expected - - -def test_hybrid_search_basic(): - expected = ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " - "RETURN node, score" - ) - result = get_search_query(SearchType.HYBRID) - assert result == expected - - -def test_vector_search_with_properties(): - properties = ["name", "age"] - expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node {.name, .age} as node, score" - ) - result = get_search_query(SearchType.VECTOR, return_properties=properties) - assert result == expected - - -def test_hybrid_search_with_retrieval_query(): - retrieval_query = "MATCH (n) RETURN n LIMIT 10" - expected = ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " - + retrieval_query - ) - result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) - assert result == expected diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 000000000..c0199c144 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/conftest.py b/tests/unit/conftest.py similarity index 91% rename from tests/conftest.py rename to tests/unit/conftest.py index b0359ec0b..b22e58fc6 100644 --- a/tests/conftest.py +++ b/tests/unit/conftest.py @@ -19,18 +19,18 @@ from unittest.mock import MagicMock, patch -@pytest.fixture +@pytest.fixture(scope="function") def driver(): return MagicMock(spec=Driver) -@pytest.fixture +@pytest.fixture(scope="function") @patch("neo4j_genai.VectorRetriever._verify_version") def vector_retriever(_verify_version_mock, driver): return VectorRetriever(driver, "my-index") -@pytest.fixture +@pytest.fixture(scope="function") @patch("neo4j_genai.VectorCypherRetriever._verify_version") def vector_cypher_retriever(_verify_version_mock, driver): retrieval_query = """ @@ -39,7 +39,7 @@ def vector_cypher_retriever(_verify_version_mock, driver): return VectorCypherRetriever(driver, "my-index", retrieval_query) -@pytest.fixture +@pytest.fixture(scope="function") @patch("neo4j_genai.HybridRetriever._verify_version") def hybrid_retriever(_verify_version_mock, driver): return HybridRetriever(driver, "my-index", "my-fulltext-index") diff --git a/tests/unit/retrievers/__init__.py b/tests/unit/retrievers/__init__.py new file mode 100644 index 000000000..c0199c144 --- /dev/null +++ b/tests/unit/retrievers/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/retrievers/test_base.py b/tests/unit/retrievers/test_base.py similarity index 100% rename from tests/retrievers/test_base.py rename to tests/unit/retrievers/test_base.py diff --git a/tests/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py similarity index 100% rename from tests/retrievers/test_hybrid.py rename to tests/unit/retrievers/test_hybrid.py diff --git a/tests/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py similarity index 100% rename from tests/retrievers/test_vector.py rename to tests/unit/retrievers/test_vector.py diff --git a/tests/test_indexes.py b/tests/unit/test_indexes.py similarity index 98% rename from tests/test_indexes.py rename to tests/unit/test_indexes.py index c624607d5..841226845 100644 --- a/tests/test_indexes.py +++ b/tests/unit/test_indexes.py @@ -75,7 +75,7 @@ def test_create_vector_index_validation_error_similarity_fn(driver): def test_drop_index(driver): - drop_query = "DROP INDEX $name" + drop_query = "DROP INDEX $name IF EXISTS" drop_index(driver, "my-index") diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py new file mode 100644 index 000000000..3ce7c7746 --- /dev/null +++ b/tests/unit/test_neo4j_queries.py @@ -0,0 +1,153 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from neo4j_genai.neo4j_queries import get_search_query, get_query_tail +from neo4j_genai.types import SearchType + + +def test_vector_search_basic(): + expected = ( + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score" + ) + result = get_search_query(SearchType.VECTOR) + assert result.strip() == expected.strip() + + +def test_hybrid_search_basic(): + expected = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node, score" + ) + result = get_search_query(SearchType.HYBRID) + assert result.strip() == expected.strip() + + +def test_vector_search_with_properties(): + properties = ["name", "age"] + expected = ( + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node {.name, .age} as node, score" + ) + result = get_search_query(SearchType.VECTOR, return_properties=properties) + assert result.strip() == expected.strip() + + +def test_vector_search_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = ( + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score " + retrieval_query + ) + result = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +def test_hybrid_search_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + + retrieval_query + ) + result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +def test_hybrid_search_with_properties(): + properties = ["name", "age"] + expected = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node {.name, .age} as node, score" + ) + result = get_search_query(SearchType.HYBRID, return_properties=properties) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = retrieval_query + result = get_query_tail(retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_properties(): + properties = ["name", "age"] + expected = "RETURN node {.name, .age} as node, score" + result = get_query_tail(return_properties=properties) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_fallback(): + fallback = "HELLO" + expected = fallback + result = get_query_tail(fallback_return=fallback) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_all(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + properties = ["name", "age"] + fallback = "HELLO" + + expected = retrieval_query + result = get_query_tail( + retrieval_query=retrieval_query, + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_no_retrieval_query(): + properties = ["name", "age"] + fallback = "HELLO" + + expected = "RETURN node {.name, .age} as node, score" + result = get_query_tail( + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip()