diff --git a/examples/vector_search_with_filters.py b/examples/vector_search_with_filters.py new file mode 100644 index 00000000..bf5fa444 --- /dev/null +++ b/examples/vector_search_with_filters.py @@ -0,0 +1,72 @@ +from neo4j import GraphDatabase +from neo4j_genai import VectorRetriever + +import random +import string +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index + + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +DIMENSION = 1536 + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + + +# Create Embedder object +class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random.random() for _ in range(DIMENSION)] + + +# Generate random strings +def random_str(n: int) -> str: + return "".join([random.choice(string.ascii_letters) for _ in range(n)]) + + +embedder = CustomEmbedder() + +# Creating the index +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="propertyKey", + dimensions=DIMENSION, + similarity_fn="euclidean", +) + +# Initialize the retriever +retriever = VectorRetriever(driver, INDEX_NAME, embedder) + +# Upsert the query +vector = [random.random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (doc:Document {id: $id})" + "ON CREATE SET doc.int_property = $id, " + " doc.short_text_property = toString($id)" + "WITH doc " + "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" + "WITH doc " + "MERGE (author:Author {name: $authorName})" + "MERGE (doc)-[:AUTHORED_BY]->(author)" + "RETURN doc, author" +) +parameters = { + "id": random.randint(0, 10000), + "vector": vector, + "authorName": random_str(10), +} +driver.execute_query(insert_query, parameters) + +# Perform the search +query_text = "Find me a book about Fremen" +print( + retriever.search( + query_text=query_text, top_k=1, filters={"int_property": {"$gt": 100}} + ) +) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py new file mode 100644 index 00000000..ec80ac8c --- /dev/null +++ b/src/neo4j_genai/filters.py @@ -0,0 +1,368 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from typing import Any, Type +from collections import Counter + + +DEFAULT_NODE_ALIAS = "node" + + +class Operator: + """Operator classes are helper classes to build the Cypher queries + from a filter like {"field_name": "field_value"} + They implement two important methods: + - lhs: (left hand side): the node + property to be filtered on + + optional operations on it (see ILikeOperator for instance) + - cleaned_value: a method to make sure the provided parameter values are + consistent with the operator (e.g. LIKE operator only works with string values) + """ + + CYPHER_OPERATOR = None + + def __init__(self, node_alias=DEFAULT_NODE_ALIAS): + self.node_alias = node_alias + + @staticmethod + def safe_field_cypher(field_name: str) -> str: + """This method must be used to escape a field name if + necessary to build a valid Cypher query. See: + https://neo4j.com/docs/cypher-manual/current/syntax/naming/ + + Args: + field_name (str): The initial unescaped field name + + Returns: + The field name potentially surrounded with backticks if needed, + ready to be inserted into a Cypher query. + """ + pattern = r"^[a-z_][0-9a-z_]*$" + if re.match(pattern, field_name, re.IGNORECASE): + return field_name + escaped_field = field_name.replace("`", "``") + return f"`{escaped_field}`" + + def lhs(self, field): + safe_field_cypher = self.safe_field_cypher(field) + return f"{self.node_alias}.{safe_field_cypher}" + + def cleaned_value(self, value): + return value + + +class EqOperator(Operator): + CYPHER_OPERATOR = "=" + + +class NeqOperator(Operator): + CYPHER_OPERATOR = "<>" + + +class LtOperator(Operator): + CYPHER_OPERATOR = "<" + + +class GtOperator(Operator): + CYPHER_OPERATOR = ">" + + +class LteOperator(Operator): + CYPHER_OPERATOR = "<=" + + +class GteOperator(Operator): + CYPHER_OPERATOR = ">=" + + +class InOperator(Operator): + CYPHER_OPERATOR = "IN" + + def cleaned_value(self, value): + for val in value: + if not isinstance(val, (str, int, float)): + raise ValueError(f"Unsupported type: {type(val)} for value: {val}") + return value + + +class NinOperator(InOperator): + CYPHER_OPERATOR = "NOT IN" + + +class LikeOperator(Operator): + CYPHER_OPERATOR = "CONTAINS" + + def cleaned_value(self, value): + if not isinstance(value, str): + raise ValueError(f"Expected string value, got {type(value)}: {value}") + return value.rstrip("%") + + +class ILikeOperator(LikeOperator): + def lhs(self, field): + safe_field_cypher = self.safe_field_cypher(field) + return f"toLower({self.node_alias}.{safe_field_cypher})" + + def cleaned_value(self, value): + value = super().cleaned_value(value) + return value.lower() + + +OPERATOR_PREFIX = "$" + +OPERATOR_EQ = "$eq" +OPERATOR_NE = "$ne" +OPERATOR_LT = "$lt" +OPERATOR_LTE = "$lte" +OPERATOR_GT = "$gt" +OPERATOR_GTE = "$gte" +OPERATOR_BETWEEN = "$between" +OPERATOR_IN = "$in" +OPERATOR_NIN = "$nin" +OPERATOR_LIKE = "$like" +OPERATOR_ILIKE = "$ilike" + +OPERATOR_AND = "$and" +OPERATOR_OR = "$or" + +COMPARISONS_TO_NATIVE = { + OPERATOR_EQ: EqOperator, + OPERATOR_NE: NeqOperator, + OPERATOR_LT: LtOperator, + OPERATOR_LTE: LteOperator, + OPERATOR_GT: GtOperator, + OPERATOR_GTE: GteOperator, + OPERATOR_IN: InOperator, + OPERATOR_NIN: NinOperator, + OPERATOR_LIKE: LikeOperator, + OPERATOR_ILIKE: ILikeOperator, +} + + +LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE).union(LOGICAL_OPERATORS).union({OPERATOR_BETWEEN}) +) + + +class ParameterStore: + """ + Store parameters for a given query. + Determine the parameter name depending on a parameter counter + """ + + def __init__(self): + self._counter = Counter() + self.params = {} + + def _get_params_name(self): + """Find parameter name so that param names are unique. + This function adds a suffix to the key corresponding to the number + of times the key have been used in the query. + E.g. + node.age >= $param_0 AND node.age <= $param_1 + + Args: + key (str): The prefix for the parameter name + Returns: + The full unique parameter name + """ + key = "param" + param_name = f"{key}_{self._counter[key]}" + self._counter[key] += 1 + return param_name + + def add(self, value): + """This function adds a new parameter to the param dict. + It returns the name of the parameter to be used as a placeholder + in the cypher query, e.g. $param_0""" + param_name = self._get_params_name() + self.params[param_name] = value + return param_name + + +def _single_condition_cypher( + field: str, + native_operator_class: Type[Operator], + value: Any, + param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS, +) -> str: + """Return Cypher for field operator value. + + Args: + field: The name of the field being filtered + native_operator_class: The operator class that will be used to generate + the Cypher query + value: filtered value + param_store: ParameterStore objet that will be updated in this function + node_alias: Name of the node being filtered in the Cypher query + Returns: + str: The Cypher condition, e.g. node.`property` = $param_0 + + NB: the param_store argument is mutable, it will be updated in this function + """ + native_op = native_operator_class(node_alias=node_alias) + param_name = param_store.add(native_op.cleaned_value(value)) + query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" + return query_snippet + + +def _handle_field_filter( + field: str, + value: Any, + param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS, +) -> str: + """Create a filter for a specific field. + + Args: + field: Name of field + value: Value to filter + If provided as is then this will be an equality filter + If provided as a dictionary then this will be a filter, the key + will be the operator and the value will be the value to filter by + param_store: ParameterStore objet that will be updated in this function + node_alias: Name of the node being filtered in the Cypher query + + Returns + str: Cypher filter snippet + + NB: the param_store argument is mutable, it will be updated in this function + """ + # first, perform some sanity checks + if not isinstance(field, str): + raise ValueError( + f"Field should be a string but got: {type(field)} with value: {field}" + ) + + if field.startswith(OPERATOR_PREFIX): + raise ValueError( + f"Invalid filter condition. Expected a field but got an operator: " + f"{field}" + ) + + if isinstance(value, dict): + # This is a filter specification e.g. {"$gte": 0} + if len(value) != 1: + raise ValueError( + "Invalid filter condition. Expected a value which " + "is a dictionary with a single key that corresponds to an operator " + f"but got a dictionary with {len(value)} keys. The first few " + f"keys are: {list(value.keys())[:3]}" + ) + operator, filter_value = list(value.items())[0] + operator = operator.lower() + # Verify that that operator is an operator + if operator not in SUPPORTED_OPERATORS: + raise ValueError( + f"Invalid operator: {operator}. " + f"Expected one of {SUPPORTED_OPERATORS}" + ) + else: # if value is not dict, then we assume an equality operator + operator = OPERATOR_EQ + filter_value = value + + # now everything is set, we can start and build the query + # special case for the BETWEEN operator that requires + # two tests (lower_bound <= value <= higher_bound) + if operator == OPERATOR_BETWEEN: + if len(filter_value) != 2: + raise ValueError( + f"Expected lower and upper bounds in a list, got {filter_value}" + ) + low, high = filter_value + param_name_low = param_store.add(low) + param_name_high = param_store.add(high) + query_snippet = f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.{Operator.safe_field_cypher(field)} <= ${param_name_high}" + return query_snippet + # all the other operators are handled through their own classes: + native_op_class = COMPARISONS_TO_NATIVE[operator] + return _single_condition_cypher( + field, native_op_class, filter_value, param_store, node_alias + ) + + +def _construct_metadata_filter( + filter: dict[str, Any], param_store: ParameterStore, node_alias: str +) -> str: + """Construct a metadata filter. This is a recursive function parsing the filter dict + + Args: + filter: A dictionary representing the filter condition. + param_store: ParameterStore objet that will be updated in this function + node_alias: Name of the node being filtered in the Cypher query + + Returns: + str: The Cypher WHERE clause + + NB: the param_store argument is mutable, it will be updated in this function + """ + + if not isinstance(filter, dict): + raise ValueError(f"Filter must be a dictionary, got {type(filter)}") + # if we have more than one entry, this is an implicit "AND" filter + if len(filter) > 1: + return _construct_metadata_filter( + {OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias + ) + # The only operators allowed at the top level are $AND and $OR + # First check if an operator or a field + key, value = list(filter.items())[0] + if not key.startswith("$"): + # it's not an operator, must be a field + return _handle_field_filter( + key, filter[key], param_store, node_alias=node_alias + ) + + # Here we handle the $and and $or operators + if not isinstance(value, list): + raise ValueError(f"Expected a list, but got {type(value)} for value: {value}") + if key.lower() == OPERATOR_AND: + cypher_operator = " AND " + elif key.lower() == OPERATOR_OR: + cypher_operator = " OR " + else: + raise ValueError(f"Unsupported operator: {key}") + query = cypher_operator.join( + [ + f"({ _construct_metadata_filter(el, param_store, node_alias)})" + for el in value + ] + ) + return query + + +def get_metadata_filter( + filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS +) -> tuple[str, dict]: + """Construct the cypher filter snippet based on a filter dict + + Note: the _construct_metadata_filter function is not thread-safe because + of the ParameterStore object. + + Args: + filter (dict): The filters to be converted to Cypher + node_alias (str): The alias of node the filters must be applied on + in the Cypher query + + Return: + A tuple of str, dict where the string is the cypher query and the dict + contains the query parameters + """ + param_store = ParameterStore() + return _construct_metadata_filter( + filter, param_store, node_alias=node_alias + ), param_store.params diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index 132cc144..32ace578 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j import Driver +import neo4j from pydantic import ValidationError from .types import VectorIndexModel, FulltextIndexModel +import logging + + +logger = logging.getLogger(__name__) def create_vector_index( - driver: Driver, + driver: neo4j.Driver, name: str, label: str, property: str, @@ -32,8 +36,11 @@ def create_vector_index( See Cypher manual on [Create vector index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/#indexes-vector-create) + Important: This operation will fail if an index with the same name already exists. + Ensure that the index name provided is unique within the database context. + Args: - driver (Driver): Neo4j Python driver instance. + driver (neo4j.Driver): Neo4j Python driver instance. name (str): The unique name of the index. label (str): The node label to be indexed. property (str): The property key of a node which contains embedding values. @@ -43,6 +50,7 @@ def create_vector_index( Raises: ValueError: If validation of the input arguments fail. + neo4j.exceptions.ClientError: If creation of vector index fails. """ try: VectorIndexModel( @@ -58,17 +66,23 @@ def create_vector_index( except ValidationError as e: raise ValueError(f"Error for inputs to create_vector_index {str(e)}") - query = ( - f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS " - "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" - ) - driver.execute_query( - query, {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn} - ) + try: + query = ( + f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS " + "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" + ) + logger.info(f"Creating vector index named '{name}'") + driver.execute_query( + query, + {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn}, + ) + except neo4j.exceptions.ClientError as e: + logger.error(f"Neo4j vector index creation failed {e}") + raise def create_fulltext_index( - driver: Driver, name: str, label: str, node_properties: list[str] + driver: neo4j.Driver, name: str, label: str, node_properties: list[str] ) -> None: """ This method constructs a Cypher query and executes it @@ -76,14 +90,18 @@ def create_fulltext_index( See Cypher manual on [Create fulltext index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/full-text-indexes/#create-full-text-indexes) + Important: This operation will fail if an index with the same name already exists. + Ensure that the index name provided is unique within the database context. + Args: - driver (Driver): Neo4j Python driver instance. + driver (neo4j.Driver): Neo4j Python driver instance. name (str): The unique name of the index. label (str): The node label to be indexed. node_properties (list[str]): The node properties to create the fulltext index on. Raises: ValueError: If validation of the input arguments fail. + neo4j.exceptions.ClientError: If creation of fulltext index fails. """ try: FulltextIndexModel( @@ -97,26 +115,39 @@ def create_fulltext_index( except ValidationError as e: raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}") - query = ( - "CREATE FULLTEXT INDEX $name " - f"FOR (n:`{label}`) ON EACH " - f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" - ) - driver.execute_query(query, {"name": name}) + try: + query = ( + "CREATE FULLTEXT INDEX $name " + f"FOR (n:`{label}`) ON EACH " + f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" + ) + logger.info(f"Creating fulltext index named '{name}'") + driver.execute_query(query, {"name": name}) + except neo4j.exceptions.ClientError as e: + logger.error(f"Neo4j fulltext index creation failed {e}") + raise -def drop_index(driver: Driver, name: str) -> None: +def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None: """ This method constructs a Cypher query and executes it - to drop a vector index in Neo4j. + to drop an index in Neo4j, if the index exists. See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop) Args: - driver (Driver): Neo4j Python driver instance. + driver (neo4j.Driver): Neo4j Python driver instance. name (str): The name of the index to delete. + + Raises: + neo4j.exceptions.ClientError: If dropping of index fails. """ - query = "DROP INDEX $name IF EXISTS" - parameters = { - "name": name, - } - driver.execute_query(query, parameters) + try: + query = "DROP INDEX $name IF EXISTS" + parameters = { + "name": name, + } + logger.info(f"Dropping index named '{name}'") + driver.execute_query(query, parameters) + except neo4j.exceptions.ClientError as e: + logger.error(f"Dropping Neo4j index failed {e}") + raise diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index b9ab366a..e974a047 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -12,51 +12,163 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Any from neo4j_genai.types import SearchType +from neo4j_genai.filters import get_metadata_filter + + +VECTOR_INDEX_QUERY = ( + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score" +) + +VECTOR_EXACT_QUERY = ( + "WITH node, " + "vector.similarity.cosine(node.`{embedding_node_property}`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" +) + +BASE_VECTOR_EXACT_QUERY = ( + "MATCH (node:`{node_label}`) " + "WHERE node.`{embedding_node_property}` IS NOT NULL " + "AND size(node.`{embedding_node_property}`) = toInteger($embedding_dimension)" +) + +FULL_TEXT_SEARCH_QUERY = ( + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score" +) + + +def _get_hybrid_query() -> str: + return ( + f"CALL {{ {VECTOR_INDEX_QUERY} " + f"RETURN node, score " + f"UNION " + f"{FULL_TEXT_SEARCH_QUERY} " + f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS max " + f"UNWIND nodes AS n " + f"RETURN n.node AS node, (n.score / max) AS score }} " + f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k" + ) + + +def _get_filtered_vector_query( + filters: dict[str, Any], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: + """Build Cypher query for vector search with filters + Uses exact KNN. + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + """ + where_filters, query_params = get_metadata_filter(filters, node_alias="node") + base_query = BASE_VECTOR_EXACT_QUERY.format( + node_label=node_label, + embedding_node_property=embedding_node_property, + ) + vector_query = VECTOR_EXACT_QUERY.format( + embedding_node_property=embedding_node_property, + ) + query_params["embedding_dimension"] = embedding_dimension + return f"{base_query} AND ({where_filters}) {vector_query}", query_params + + +def _get_vector_query( + filters: Optional[dict[str, Any]], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: + """Build the vector query with or without filters + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + + """ + if filters: + return _get_filtered_vector_query( + filters, node_label, embedding_node_property, embedding_dimension + ) + return VECTOR_INDEX_QUERY, {} def get_search_query( search_type: SearchType, return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None, -): - query_map = { - SearchType.VECTOR: "".join( - [ - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) ", - "YIELD node, score ", - get_query_tail(retrieval_query, return_properties), - ] - ), - SearchType.HYBRID: "".join( - [ - "CALL { ", - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) ", - "YIELD node, score ", - "RETURN node, score UNION ", - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) ", - "YIELD node, score ", - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ", - "UNWIND nodes AS n ", - "RETURN n.node AS node, (n.score / max) AS score ", - "} ", - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k ", - get_query_tail( - retrieval_query, return_properties, "RETURN node, score" - ), - ] - ), - } - return query_map[search_type] - - -def get_query_tail( + node_label: Optional[str] = None, + embedding_node_property: Optional[str] = None, + embedding_dimension: Optional[int] = None, + filters: Optional[dict[str, Any]] = None, +) -> tuple[str, dict[str, Any]]: + """Build the search query, including pre-filtering if needed, and return clause. + + Args + search_type: Search type we want to search for: + return_properties (list[str]): list of property names to return. + It can't be provided together with retrieval_query. + retrieval_query (str): the query to use to retrieve the search results + It can't be provided together with return_properties. + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + + Returns: + tuple[str, dict[str, Any]]: query and parameters + + """ + if search_type == SearchType.HYBRID: + if filters: + raise Exception("Filters is not supported with Hybrid Search") + query = _get_hybrid_query() + params = {} + elif search_type == SearchType.VECTOR: + query, params = _get_vector_query( + filters, node_label, embedding_node_property, embedding_dimension + ) + else: + raise ValueError(f"Search type is not supported: {search_type}") + query_tail = _get_query_tail( + retrieval_query, return_properties, fallback_return="RETURN node, score" + ) + return f"{query} {query_tail}", params + + +def _get_query_tail( retrieval_query: Optional[str] = None, return_properties: Optional[list[str]] = None, fallback_return: Optional[str] = None, ) -> str: + """Build the RETURN statement after the search is performed + + Args + return_properties (list[str]): list of property names to return. + It can't be provided together with retrieval_query. + retrieval_query (str): the query to use to retrieve the search results + It can't be provided together with return_properties. + fallback_return (str): the fallback return statement to use to retrieve the search results + + Returns: + str: the RETURN statement + """ if retrieval_query: return retrieval_query if return_properties: diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py index dc483eb6..cbf8abb3 100644 --- a/src/neo4j_genai/retrievers/base.py +++ b/src/neo4j_genai/retrievers/base.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from typing import Any -from neo4j import Driver +import neo4j class Retriever(ABC): @@ -23,7 +23,7 @@ class Retriever(ABC): Abstract class for Neo4j retrievers """ - def __init__(self, driver: Driver): + def __init__(self, driver: neo4j.Driver): self.driver = driver self._verify_version() @@ -57,3 +57,21 @@ def _verify_version(self) -> None: @abstractmethod def search(self, *args, **kwargs) -> Any: pass + + def _fetch_index_infos(self): + """Fetch the node label and embedding property from the index definition""" + query = ( + "SHOW VECTOR INDEXES " + "YIELD name, labelsOrTypes, properties, options " + "WHERE name = $index_name " + "RETURN labelsOrTypes as labels, properties, " + "options.indexConfig.`vector.dimensions` as dimensions" + ) + result = self.driver.execute_query(query, {"index_name": self.index_name}) + try: + result = result.records[0] + except IndexError: + raise Exception(f"No index with name {self.index_name} found") + self._node_label = result["labels"][0] + self._embedding_node_property = result["properties"][0] + self._embedding_dimension = result["dimensions"] diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index 0690555a..fea96a2d 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Optional, Any -from neo4j import Record, Driver +import neo4j from pydantic import ValidationError from neo4j_genai.embedder import Embedder @@ -29,7 +29,7 @@ class HybridRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, vector_index_name: str, fulltext_index_name: str, embedder: Optional[Embedder] = None, @@ -46,7 +46,7 @@ def search( query_text: str, query_vector: Optional[list[float]] = None, top_k: int = 5, - ) -> list[Record]: + ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. If query_vector is provided, then it will be preferred over the embedded query_text @@ -63,7 +63,7 @@ def search( ValueError: If validation of the input arguments fail. ValueError: If no embedder is provided. Returns: - list[Record]: The results of the search query + list[neo4j.Record]: The results of the search query """ try: validated_data = HybridSearchModel( @@ -84,7 +84,7 @@ def search( query_vector = self.embedder.embed_query(query_text) parameters["query_vector"] = query_vector - search_query = get_search_query(SearchType.HYBRID, self.return_properties) + search_query, _ = get_search_query(SearchType.HYBRID, self.return_properties) logger.debug("HybridRetriever Cypher parameters: %s", parameters) logger.debug("HybridRetriever Cypher query: %s", search_query) @@ -96,7 +96,7 @@ def search( class HybridCypherRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, vector_index_name: str, fulltext_index_name: str, retrieval_query: str, @@ -114,7 +114,7 @@ def search( query_vector: Optional[list[float]] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, - ) -> list[Record]: + ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. If query_vector is provided, then it will be preferred over the embedded query_text @@ -132,7 +132,7 @@ def search( ValueError: If validation of the input arguments fail. ValueError: If no embedder is provided. Returns: - list[Record]: The results of the search query + list[neo4j.Record]: The results of the search query """ try: validated_data = HybridCypherSearchModel( @@ -160,7 +160,7 @@ def search( parameters[key] = value del parameters["query_params"] - search_query = get_search_query( + search_query, _ = get_search_query( SearchType.HYBRID, retrieval_query=self.retrieval_query ) diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index 954cd04e..32314801 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Optional, Any -from neo4j import Driver, Record +import neo4j from neo4j_genai.retrievers.base import Retriever from pydantic import ValidationError @@ -39,7 +39,7 @@ class VectorRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, index_name: str, embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, @@ -48,12 +48,17 @@ def __init__( self.index_name = index_name self.return_properties = return_properties self.embedder = embedder + self._node_label = None + self._embedding_node_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, + filters: Optional[dict[str, Any]] = None, ) -> list[VectorSearchRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -75,7 +80,7 @@ def search( """ try: validated_data = VectorSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -93,7 +98,15 @@ def search( parameters["query_vector"] = query_vector del parameters["query_text"] - search_query = get_search_query(SearchType.VECTOR, self.return_properties) + search_query, search_params = get_search_query( + SearchType.VECTOR, + self.return_properties, + node_label=self._node_label, + embedding_node_property=self._embedding_node_property, + embedding_dimension=self._embedding_dimension, + filters=filters, + ) + parameters.update(search_params) logger.debug("VectorRetriever Cypher parameters: %s", parameters) logger.debug("VectorRetriever Cypher query: %s", search_query) @@ -120,7 +133,7 @@ class VectorCypherRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, index_name: str, retrieval_query: str, embedder: Optional[Embedder] = None, @@ -129,6 +142,10 @@ def __init__( self.index_name = index_name self.retrieval_query = retrieval_query self.embedder = embedder + self._node_label = None + self._node_embedding_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, @@ -136,7 +153,8 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, - ) -> list[Record]: + filters: Optional[dict[str, Any]] = None, + ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -154,11 +172,11 @@ def search( ValueError: If no embedder is provided. Returns: - list[Record]: The results of the search query + list[neo4j.Record]: The results of the search query """ try: validated_data = VectorCypherSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -181,9 +199,15 @@ def search( parameters[key] = value del parameters["query_params"] - search_query = get_search_query( - SearchType.VECTOR, retrieval_query=self.retrieval_query + search_query, search_params = get_search_query( + SearchType.VECTOR, + retrieval_query=self.retrieval_query, + node_label=self._node_label, + embedding_node_property=self._node_embedding_property, + embedding_dimension=self._embedding_dimension, + filters=filters, ) + parameters.update(search_params) logger.debug("VectorCypherRetriever Cypher parameters: %s", parameters) logger.debug("VectorCypherRetriever Cypher query: %s", search_query) diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 67a31175..357bb44e 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -15,7 +15,7 @@ from enum import Enum from typing import Any, Literal, Optional from pydantic import BaseModel, PositiveInt, model_validator, field_validator -from neo4j import Driver +import neo4j class VectorSearchRecord(BaseModel): @@ -28,7 +28,7 @@ class IndexModel(BaseModel): @field_validator("driver") def check_driver_is_valid(cls, v): - if not isinstance(v, Driver): + if not isinstance(v, neo4j.Driver): raise ValueError("driver must be an instance of neo4j.Driver") return v @@ -54,7 +54,7 @@ def check_node_properties_not_empty(cls, v): class VectorSearchModel(BaseModel): - index_name: str + vector_index_name: str top_k: PositiveInt = 5 query_vector: Optional[list[float]] = None query_text: Optional[str] = None diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 64cd6504..5e5f3f97 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -19,7 +19,11 @@ import pytest from neo4j import GraphDatabase from neo4j_genai.embedder import Embedder -from neo4j_genai.indexes import drop_index, create_vector_index, create_fulltext_index +from neo4j_genai.indexes import ( + drop_index_if_exists, + create_vector_index, + create_fulltext_index, +) @pytest.fixture(scope="module") @@ -47,8 +51,8 @@ def setup_neo4j(driver): # Delete data and drop indexes to prevent data leakage driver.execute_query("MATCH (n) DETACH DELETE n") - drop_index(driver, vector_index_name) - drop_index(driver, fulltext_index_name) + drop_index_if_exists(driver, vector_index_name) + drop_index_if_exists(driver, fulltext_index_name) # Create a vector index create_vector_index( @@ -74,6 +78,8 @@ def random_str(n: int) -> str: for i in range(10): insert_query = ( "MERGE (doc:Document {id: $id})" + "ON CREATE SET doc.int_property = $i, " + " doc.short_text_property = toString($i)" "WITH doc " "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" "WITH doc " @@ -84,7 +90,8 @@ def random_str(n: int) -> str: parameters = { "id": str(uuid.uuid4()), + "i": i, "vector": vector, - "authorName": random_str(10), + "authorName": random_str(1536), } driver.execute_query(insert_query, parameters) diff --git a/tests/e2e/test_hybrid_e2e.py b/tests/e2e/test_hybrid_e2e.py index f8f54466..3ba48c62 100644 --- a/tests/e2e/test_hybrid_e2e.py +++ b/tests/e2e/test_hybrid_e2e.py @@ -16,7 +16,7 @@ import pytest -from neo4j import Record +import neo4j from neo4j_genai import ( HybridRetriever, @@ -36,7 +36,7 @@ def test_hybrid_retriever_search_text(driver, custom_embedder): assert isinstance(results, list) assert len(results) == 5 for result in results: - assert isinstance(result, Record) + assert isinstance(result, neo4j.Record) @pytest.mark.usefixtures("setup_neo4j") @@ -58,7 +58,7 @@ def test_hybrid_cypher_retriever_search_text(driver, custom_embedder): assert isinstance(results, list) assert len(results) == 5 for record in results: - assert isinstance(record, Record) + assert isinstance(record, neo4j.Record) assert "author.name" in record.keys() @@ -80,7 +80,7 @@ def test_hybrid_retriever_search_vector(driver): assert isinstance(results, list) assert len(results) == 5 for result in results: - assert isinstance(result, Record) + assert isinstance(result, neo4j.Record) @pytest.mark.usefixtures("setup_neo4j") @@ -105,7 +105,7 @@ def test_hybrid_cypher_retriever_search_vector(driver): assert isinstance(results, list) assert len(results) == 5 for record in results: - assert isinstance(record, Record) + assert isinstance(record, neo4j.Record) assert "author.name" in record.keys() @@ -129,4 +129,4 @@ def test_hybrid_retriever_return_properties(driver): assert isinstance(results, list) assert len(results) == 5 for result in results: - assert isinstance(result, Record) + assert isinstance(result, neo4j.Record) diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py index 9bf3f5a4..608dd4d0 100644 --- a/tests/e2e/test_vector_e2e.py +++ b/tests/e2e/test_vector_e2e.py @@ -102,3 +102,24 @@ def test_vector_retriever_return_properties(driver): assert len(results) == 5 for result in results: assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_filters(driver): + retriever = VectorRetriever( + driver, + "vector-index-name", + ) + + top_k = 2 + results = retriever.search( + query_vector=[1.0 for _ in range(1536)], + filters={"int_property": {"$gt": 2}}, + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 2 + for result in results: + assert isinstance(result, VectorSearchRecord) + assert result.node["int_property"] > 2 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b22e58fc..75e0419f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -14,14 +14,14 @@ # limitations under the License. import pytest +import neo4j from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever -from neo4j import Driver from unittest.mock import MagicMock, patch @pytest.fixture(scope="function") def driver(): - return MagicMock(spec=Driver) + return MagicMock(spec=neo4j.Driver) @pytest.fixture(scope="function") diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index b55e3c54..79486835 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -60,7 +60,7 @@ def test_hybrid_search_text_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query(SearchType.HYBRID) records = retriever.search(query_text=query_text, top_k=top_k) @@ -98,7 +98,7 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector( None, None, ] - search_query = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query(SearchType.HYBRID) retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k) @@ -161,7 +161,7 @@ def test_hybrid_retriever_return_properties(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.HYBRID, return_properties) + search_query, _ = get_search_query(SearchType.HYBRID, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) @@ -206,7 +206,9 @@ def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver) None, None, ] - search_query = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.HYBRID, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 69c1f615..c3fd1ade 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -34,8 +34,11 @@ def test_vector_cypher_retriever_initialization(driver): mock_verify.assert_called_once() +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_happy_path(_verify_version_mock, driver): +def test_similarity_search_vector_happy_path( + _verify_version_mock, _fetch_index_infos, driver +): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -46,14 +49,14 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) records = retriever.search(query_vector=query_vector, top_k=top_k) retriever.driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": query_vector, }, @@ -61,8 +64,11 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_happy_path(_verify_version_mock, driver): +def test_similarity_search_text_happy_path( + _verify_version_mock, _fetch_index_infos, driver +): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -75,7 +81,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) records = retriever.search(query_text=query_text, top_k=top_k) @@ -83,7 +89,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -92,8 +98,11 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_return_properties(_verify_version_mock, driver): +def test_similarity_search_text_return_properties( + _verify_version_mock, _fetch_index_infos, driver +): embed_query_vector = [1.0 for _ in range(3)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -111,7 +120,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, return_properties) + search_query, _ = get_search_query(SearchType.VECTOR, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) @@ -119,7 +128,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query.rstrip(), { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -175,8 +184,11 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri ) +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_bad_results(_verify_version_mock, driver): +def test_similarity_search_vector_bad_results( + _verify_version_mock, _fetch_index_infos, driver +): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -187,7 +199,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) with pytest.raises(ValueError): retriever.search(query_vector=query_vector, top_k=top_k) @@ -195,15 +207,16 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): retriever.driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": query_vector, }, ) +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_happy_path(_verify_version_mock, driver): +def test_retrieval_query_happy_path(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -221,7 +234,9 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.VECTOR, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, @@ -232,7 +247,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -240,8 +255,9 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_with_params(_verify_version_mock, driver): +def test_retrieval_query_with_params(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -265,7 +281,9 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.VECTOR, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, @@ -278,7 +296,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, "param": "dummy-param", @@ -288,8 +306,9 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_cypher_error(_verify_version_mock, driver): +def test_retrieval_query_cypher_error(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py new file mode 100644 index 00000000..66d3d6c8 --- /dev/null +++ b/tests/unit/test_filters.py @@ -0,0 +1,598 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch, call + +import pytest + +from neo4j_genai.filters import ( + get_metadata_filter, + _single_condition_cypher, + _handle_field_filter, + _construct_metadata_filter, + Operator, + EqOperator, + NeqOperator, + LtOperator, + GtOperator, + LteOperator, + GteOperator, + InOperator, + NinOperator, + LikeOperator, + ILikeOperator, + ParameterStore, +) + + +@pytest.fixture(scope="function") +def param_store_empty(): + return ParameterStore() + + +def test_param_store(): + ps = ParameterStore() + assert ps.params == {} + ps.add(1) + assert ps.params == {"param_0": 1} + ps.add("some value") + assert ps.params == {"param_0": 1, "param_1": "some value"} + + +def test_operator_field_escape(): + assert Operator.safe_field_cypher("name") == "name" + assert Operator.safe_field_cypher("_name") == "_name" + assert Operator.safe_field_cypher("na_me123") == "na_me123" + # escape if using separators different from underscore + assert Operator.safe_field_cypher("na-me") == "`na-me`" + assert Operator.safe_field_cypher("na me") == "`na me`" + assert Operator.safe_field_cypher("na.me") == "`na.me`" + # escape if name starts with a non alpha character + assert Operator.safe_field_cypher("1name") == "`1name`" + assert Operator.safe_field_cypher("?name") == "`?name`" + # escape if name contains special characters + assert Operator.safe_field_cypher("n*ame") == "`n*ame`" + assert Operator.safe_field_cypher("na_me123%") == "`na_me123%`" + assert Operator.safe_field_cypher("\name") == "`\name`" + # escape the escape character + assert Operator.safe_field_cypher("na`me") == "`na``me`" + + +def test_single_condition_cypher_eq(param_store_empty): + generated = _single_condition_cypher( + "field", EqOperator, "value", param_store=param_store_empty + ) + assert generated == "node.field = $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_eq_node_alias(param_store_empty): + generated = _single_condition_cypher( + "field", EqOperator, "value", node_alias="n", param_store=param_store_empty + ) + assert generated == "n.field = $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_neq(param_store_empty): + generated = _single_condition_cypher( + "field", NeqOperator, "value", param_store=param_store_empty + ) + assert generated == "node.field <> $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_lt(param_store_empty): + generated = _single_condition_cypher( + "field", LtOperator, 10, param_store=param_store_empty + ) + assert generated == "node.field < $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_gt(param_store_empty): + generated = _single_condition_cypher( + "field", GtOperator, 10, param_store=param_store_empty + ) + assert generated == "node.field > $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_lte(param_store_empty): + generated = _single_condition_cypher( + "field", LteOperator, 10, param_store=param_store_empty + ) + assert generated == "node.field <= $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_gte(param_store_empty): + generated = _single_condition_cypher( + "field", GteOperator, 10, param_store=param_store_empty + ) + assert generated == "node.field >= $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_in_int(param_store_empty): + generated = _single_condition_cypher( + "field", InOperator, [1, 2, 3], param_store=param_store_empty + ) + assert generated == "node.field IN $param_0" + assert param_store_empty.params == {"param_0": [1, 2, 3]} + + +def test_single_condition_cypher_in_str(param_store_empty): + generated = _single_condition_cypher( + "field", InOperator, ["a", "b", "c"], param_store=param_store_empty + ) + assert generated == "node.field IN $param_0" + assert param_store_empty.params == {"param_0": ["a", "b", "c"]} + + +def test_single_condition_cypher_in_invalid_type(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _single_condition_cypher( + "field", + InOperator, + [ + {"my_tuple"}, + ], + param_store=param_store_empty, + ) + assert "Unsupported type: " in str(excinfo) + + +def test_single_condition_cypher_nin(param_store_empty): + generated = _single_condition_cypher( + "field", NinOperator, ["a", "b", "c"], param_store=param_store_empty + ) + assert generated == "node.field NOT IN $param_0" + assert param_store_empty.params == {"param_0": ["a", "b", "c"]} + + +def test_single_condition_cypher_like(param_store_empty): + generated = _single_condition_cypher( + "field", LikeOperator, "value", param_store=param_store_empty + ) + assert generated == "node.field CONTAINS $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_ilike(param_store_empty): + generated = _single_condition_cypher( + "field", ILikeOperator, "My Value", param_store=param_store_empty + ) + assert generated == "toLower(node.field) CONTAINS $param_0" + assert param_store_empty.params == {"param_0": "my value"} + + +def test_single_condition_cypher_like_not_a_string(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _single_condition_cypher( + "field", ILikeOperator, 1, param_store=param_store_empty + ) + assert "Expected string value, got " in str(excinfo) + + +def test_single_condition_cypher_escaped_field_name(param_store_empty): + generated = _single_condition_cypher( + "na`me", EqOperator, "value", param_store=param_store_empty + ) + assert generated == "node.`na``me` = $param_0" + + +def test_handle_field_filter_not_a_string(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter(1, "value", param_store=param_store_empty) + assert "Field should be a string but got: with value: 1" in str( + excinfo + ) + + +def test_handle_field_filter_field_start_with_dollar_sign(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter("$field_name", "value", param_store=param_store_empty) + assert ( + "Invalid filter condition. Expected a field but got an operator: $field_name" + in str(excinfo) + ) + + +def test_handle_field_filter_bad_value(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter( + "field", + value={"operator1": "value1", "operator2": "value2"}, + param_store=param_store_empty, + ) + assert "Invalid filter condition" in str(excinfo) + + +def test_handle_field_filter_bad_operator_name(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter( + "field", value={"$invalid": "value"}, param_store=param_store_empty + ) + assert "Invalid operator: $invalid" in str(excinfo) + + +def test_handle_field_filter_operator_between(param_store_empty): + generated = _handle_field_filter( + "field", value={"$between": [0, 1]}, param_store=param_store_empty + ) + assert generated == "$param_0 <= node.field <= $param_1" + assert param_store_empty.params == {"param_0": 0, "param_1": 1} + + +def test_handle_field_filter_operator_between_not_enough_parameters(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter( + "field", + value={ + "$between": [ + 0, + ] + }, + param_store=param_store_empty, + ) + assert "Expected lower and upper bounds in a list, got [0]" in str(excinfo) + + +@patch("neo4j_genai.filters._single_condition_cypher", return_value="condition") +def test_handle_field_filter_implicit_eq( + _single_condition_cypher_mocked, param_store_empty +): + generated = _handle_field_filter( + "field", value="some_value", param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", EqOperator, "some_value", param_store_empty, "node" + ) + assert generated == "condition" + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_eq(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$eq": "some_value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", EqOperator, "some_value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_neq(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$ne": "some_value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", NeqOperator, "some_value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_lt(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$lt": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", LtOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_gt(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$gt": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", GtOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_lte(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$lte": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", LteOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_gte(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$gte": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", GteOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_in(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$in": [1, 2]}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", InOperator, [1, 2], param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_nin(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$nin": [1, 2]}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", NinOperator, [1, 2], param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_like(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$like": "value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", LikeOperator, "value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_ilike(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$ilike": "value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", ILikeOperator, "value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._handle_field_filter") +def test_construct_metadata_filter_filter_is_not_a_dict( + _handle_field_filter_mock, param_store_empty +): + with pytest.raises(ValueError) as excinfo: + _construct_metadata_filter([], param_store_empty, node_alias="n") + assert "Filter must be a dictionary, got " in str(excinfo) + + +@patch("neo4j_genai.filters._handle_field_filter") +def test_construct_metadata_filter_no_operator( + _handle_field_filter_mock, param_store_empty +): + _construct_metadata_filter({"field": "value"}, param_store_empty, node_alias="n") + _handle_field_filter_mock.assert_called_once_with( + "field", "value", param_store_empty, node_alias="n" + ) + + +@patch("neo4j_genai.filters._construct_metadata_filter") +def test_construct_metadata_filter_implicit_and( + _construct_metadata_filter_mock, param_store_empty +): + _construct_metadata_filter( + {"field_1": "value_1", "field_2": "value_2"}, param_store_empty, node_alias="n" + ) + _construct_metadata_filter_mock.assert_has_calls( + [ + call( + {"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, + param_store_empty, + "n", + ), + ] + ) + + +@patch( + "neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"] +) +def test_construct_metadata_filter_explicit_and( + _construct_metadata_filter_mock, param_store_empty +): + generated = _construct_metadata_filter( + {"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, + param_store_empty, + node_alias="n", + ) + _construct_metadata_filter_mock.assert_has_calls( + [ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n"), + ] + ) + assert generated == "(filter1) AND (filter2)" + + +@patch( + "neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"] +) +def test_construct_metadata_filter_or( + _construct_metadata_filter_mock, param_store_empty +): + generated = _construct_metadata_filter( + {"$or": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, + param_store_empty, + node_alias="n", + ) + _construct_metadata_filter_mock.assert_has_calls( + [ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n"), + ] + ) + assert generated == "(filter1) OR (filter2)" + + +def test_construct_metadata_filter_invalid_operator(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _construct_metadata_filter( + {"$invalid": [{}, {}]}, param_store_empty, node_alias="n" + ) + assert "Unsupported operator: $invalid" in str(excinfo) + + +def test_get_metadata_filter_single_field_string(): + filters = {"field": "string_value"} + query, params = get_metadata_filter(filters) + assert query == "node.field = $param_0" + assert params == {"param_0": "string_value"} + + +def test_get_metadata_filter_single_field_int(): + filters = {"field": 28} + query, params = get_metadata_filter(filters) + assert query == "node.field = $param_0" + assert params == {"param_0": 28} + + +def test_get_metadata_filter_single_field_bool(): + filters = {"field": False} + query, params = get_metadata_filter(filters) + assert query == "node.field = $param_0" + assert params == {"param_0": False} + + +def test_get_metadata_filter_explicit_eq_operator(): + filters = {"field": {"$eq": "string_value"}} + query, params = get_metadata_filter(filters) + assert query == "node.field = $param_0" + assert params == {"param_0": "string_value"} + + +def test_get_metadata_filter_neq_operator(): + filters = {"field": {"$ne": "string_value"}} + query, params = get_metadata_filter(filters) + assert query == "node.field <> $param_0" + assert params == {"param_0": "string_value"} + + +def test_get_metadata_filter_lt_operator(): + filters = {"field": {"$lt": 1}} + query, params = get_metadata_filter(filters) + assert query == "node.field < $param_0" + assert params == {"param_0": 1} + + +def test_get_metadata_filter_gt_operator(): + filters = {"field": {"$gt": 1}} + query, params = get_metadata_filter(filters) + assert query == "node.field > $param_0" + assert params == {"param_0": 1} + + +def test_get_metadata_filter_lte_operator(): + filters = {"field": {"$lte": 1}} + query, params = get_metadata_filter(filters) + assert query == "node.field <= $param_0" + assert params == {"param_0": 1} + + +def test_get_metadata_filter_gte_operator(): + filters = {"field": {"$gte": 1}} + query, params = get_metadata_filter(filters) + assert query == "node.field >= $param_0" + assert params == {"param_0": 1} + + +def test_get_metadata_filter_in_operator(): + filters = {"field": {"$in": ["a", "b"]}} + query, params = get_metadata_filter(filters) + assert query == "node.field IN $param_0" + assert params == {"param_0": ["a", "b"]} + + +def test_get_metadata_filter_not_in_operator(): + filters = {"field": {"$nin": ["a", "b"]}} + query, params = get_metadata_filter(filters) + assert query == "node.field NOT IN $param_0" + assert params == {"param_0": ["a", "b"]} + + +def test_get_metadata_filter_like_operator(): + filters = {"field": {"$like": "some_value"}} + query, params = get_metadata_filter(filters) + assert query == "node.field CONTAINS $param_0" + assert params == {"param_0": "some_value"} + + +def test_get_metadata_filter_ilike_operator(): + filters = {"field": {"$ilike": "Some Value"}} + query, params = get_metadata_filter(filters) + assert query == "toLower(node.field) CONTAINS $param_0" + assert params == {"param_0": "some value"} + + +def test_get_metadata_filter_between_operator(): + filters = {"field": {"$between": [0, 1]}} + query, params = get_metadata_filter(filters) + assert query == "$param_0 <= node.field <= $param_1" + assert params == {"param_0": 0, "param_1": 1} + + +def test_get_metadata_filter_implicit_and_condition(): + filters = {"field_1": "string_value", "field_2": True} + query, params = get_metadata_filter(filters) + assert query == "(node.field_1 = $param_0) AND (node.field_2 = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_get_metadata_filter_explicit_and_condition(): + filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} + query, params = get_metadata_filter(filters) + assert query == "(node.field_1 = $param_0) AND (node.field_2 = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_get_metadata_filter_explicit_and_condition_with_operator(): + filters = { + "$and": [{"field_1": {"$ne": "string_value"}}, {"field_2": {"$in": [1, 2]}}] + } + query, params = get_metadata_filter(filters) + assert query == "(node.field_1 <> $param_0) AND (node.field_2 IN $param_1)" + assert params == {"param_0": "string_value", "param_1": [1, 2]} + + +def test_get_metadata_filter_or_condition(): + filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} + query, params = get_metadata_filter(filters) + assert query == "(node.field_1 = $param_0) OR (node.field_2 = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_get_metadata_filter_and_or_combined(): + filters = { + "$and": [ + {"$or": [{"field_1": "string_value"}, {"field_2": True}]}, + {"field_3": 11}, + ] + } + query, params = get_metadata_filter(filters) + assert query == ( + "((node.field_1 = $param_0) OR (node.field_2 = $param_1)) " + "AND (node.field_3 = $param_2)" + ) + assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} + + +# now testing bad filters +def test_get_metadata_filter_field_name_with_dollar_sign(): + filters = {"$field": "value"} + with pytest.raises(ValueError): + get_metadata_filter(filters) + + +def test_get_metadata_filter_and_no_list(): + filters = {"$and": {}} + with pytest.raises(ValueError): + get_metadata_filter(filters) + + +def test_get_metadata_filter_unsupported_operator(): + filters = {"field": {"$unsupported": "value"}} + with pytest.raises(ValueError): + get_metadata_filter(filters) diff --git a/tests/unit/test_indexes.py b/tests/unit/test_indexes.py index 84122684..c5509da9 100644 --- a/tests/unit/test_indexes.py +++ b/tests/unit/test_indexes.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import neo4j.exceptions import pytest from neo4j_genai.indexes import ( create_vector_index, - drop_index, + drop_index_if_exists, create_fulltext_index, ) @@ -68,16 +68,33 @@ def test_create_vector_index_validation_error_dimensions(driver): assert "Error for inputs to create_vector_index" in str(excinfo) +def test_create_vector_index_raises_error_with_neo4j_client_error(driver): + driver.execute_query.side_effect = neo4j.exceptions.ClientError + with pytest.raises(neo4j.exceptions.ClientError): + create_vector_index(driver, "my-index", "People", "name", 2048, "cosine") + + def test_create_vector_index_validation_error_similarity_fn(driver): with pytest.raises(ValueError) as excinfo: create_vector_index(driver, "my-index", "People", "name", 1536, "algebra") assert "Error for inputs to create_vector_index" in str(excinfo) -def test_drop_index(driver): +def test_drop_index_if_exists(driver): drop_query = "DROP INDEX $name IF EXISTS" - drop_index(driver, "my-index") + drop_index_if_exists(driver, "my-index") + + driver.execute_query.assert_called_once_with( + drop_query, + {"name": "my-index"}, + ) + + +def test_drop_index_if_exists_raises_error_with_neo4j_client_error(driver): + drop_query = "DROP INDEX $name IF EXISTS" + + drop_index_if_exists(driver, "my-index") driver.execute_query.assert_called_once_with( drop_query, @@ -102,6 +119,15 @@ def test_create_fulltext_index_happy_path(driver): ) +def test_create_fulltext_index_raises_error_with_neo4j_client_error(driver): + label = "node-label" + text_node_properties = ["property-1", "property-2"] + driver.execute_query.side_effect = neo4j.exceptions.ClientError + + with pytest.raises(neo4j.exceptions.ClientError): + create_fulltext_index(driver, "my-index", label, text_node_properties) + + def test_create_fulltext_index_empty_node_properties(driver): label = "node-label" node_properties = [] diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index 3ce7c774..0ef2b68e 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -12,18 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch -from neo4j_genai.neo4j_queries import get_search_query, get_query_tail +from neo4j_genai.neo4j_queries import get_search_query, _get_query_tail from neo4j_genai.types import SearchType def test_vector_search_basic(): expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " - "YIELD node, score" + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score" ) - result = get_search_query(SearchType.VECTOR) + result, params = get_search_query(SearchType.VECTOR) assert result.strip() == expected.strip() + assert params == {} def test_hybrid_search_basic(): @@ -41,29 +44,78 @@ def test_hybrid_search_basic(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " "RETURN node, score" ) - result = get_search_query(SearchType.HYBRID) + result, _ = get_search_query(SearchType.HYBRID) assert result.strip() == expected.strip() def test_vector_search_with_properties(): properties = ["name", "age"] expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " "YIELD node, score " "RETURN node {.name, .age} as node, score" ) - result = get_search_query(SearchType.VECTOR, return_properties=properties) + result, _ = get_search_query(SearchType.VECTOR, return_properties=properties) assert result.strip() == expected.strip() def test_vector_search_with_retrieval_query(): retrieval_query = "MATCH (n) RETURN n LIMIT 10" expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " "YIELD node, score " + retrieval_query ) - result = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + result, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +@patch("neo4j_genai.neo4j_queries.get_metadata_filter", return_value=["True", {}]) +def test_vector_search_with_filters(_mock): + expected = ( + "MATCH (node:`Label`) " + "WHERE node.`vector` IS NOT NULL " + "AND size(node.`vector`) = toInteger($embedding_dimension)" + " AND (True) " + "WITH node, " + "vector.similarity.cosine(node.`vector`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" + " RETURN node, score" + ) + result, params = get_search_query( + SearchType.VECTOR, + node_label="Label", + embedding_node_property="vector", + embedding_dimension=1, + filters={"field": "value"}, + ) + assert result.strip() == expected.strip() + assert params == {"embedding_dimension": 1} + + +@patch( + "neo4j_genai.neo4j_queries.get_metadata_filter", + return_value=["True", {"param": "value"}], +) +def test_vector_search_with_params_from_filters(_mock): + expected = ( + "MATCH (node:`Label`) " + "WHERE node.`vector` IS NOT NULL " + "AND size(node.`vector`) = toInteger($embedding_dimension)" + " AND (True) " + "WITH node, " + "vector.similarity.cosine(node.`vector`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" + " RETURN node, score" + ) + result, params = get_search_query( + SearchType.VECTOR, + node_label="Label", + embedding_node_property="vector", + embedding_dimension=1, + filters={"field": "value"}, + ) assert result.strip() == expected.strip() + assert params == {"embedding_dimension": 1, "param": "value"} def test_hybrid_search_with_retrieval_query(): @@ -82,7 +134,7 @@ def test_hybrid_search_with_retrieval_query(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + retrieval_query ) - result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + result, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) assert result.strip() == expected.strip() @@ -102,28 +154,28 @@ def test_hybrid_search_with_properties(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " "RETURN node {.name, .age} as node, score" ) - result = get_search_query(SearchType.HYBRID, return_properties=properties) + result, _ = get_search_query(SearchType.HYBRID, return_properties=properties) assert result.strip() == expected.strip() def test_get_query_tail_with_retrieval_query(): retrieval_query = "MATCH (n) RETURN n LIMIT 10" expected = retrieval_query - result = get_query_tail(retrieval_query=retrieval_query) + result = _get_query_tail(retrieval_query=retrieval_query) assert result.strip() == expected.strip() def test_get_query_tail_with_properties(): properties = ["name", "age"] expected = "RETURN node {.name, .age} as node, score" - result = get_query_tail(return_properties=properties) + result = _get_query_tail(return_properties=properties) assert result.strip() == expected.strip() def test_get_query_tail_with_fallback(): fallback = "HELLO" expected = fallback - result = get_query_tail(fallback_return=fallback) + result = _get_query_tail(fallback_return=fallback) assert result.strip() == expected.strip() @@ -133,7 +185,7 @@ def test_get_query_tail_ordering_all(): fallback = "HELLO" expected = retrieval_query - result = get_query_tail( + result = _get_query_tail( retrieval_query=retrieval_query, return_properties=properties, fallback_return=fallback, @@ -146,7 +198,7 @@ def test_get_query_tail_ordering_no_retrieval_query(): fallback = "HELLO" expected = "RETURN node {.name, .age} as node, score" - result = get_query_tail( + result = _get_query_tail( return_properties=properties, fallback_return=fallback, )