diff --git a/docs/docs/integrations/vectorstores/sap_hanavector.ipynb b/docs/docs/integrations/vectorstores/sap_hanavector.ipynb index 42e89eb21f556..37f9c86ecc615 100644 --- a/docs/docs/integrations/vectorstores/sap_hanavector.ipynb +++ b/docs/docs/integrations/vectorstores/sap_hanavector.ipynb @@ -357,6 +357,179 @@ "print(len(docs))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced filtering\n", + "In addition to the basic value-based filtering capabilities, it is possible to use more advanced filtering.\n", + "The table below shows the available filter operators.\n", + "\n", + "| Operator | Semantic |\n", + "|----------|-------------------------|\n", + "| `$eq` | Equality (==) |\n", + "| `$ne` | Inequality (!=) |\n", + "| `$lt` | Less than (<) |\n", + "| `$lte` | Less than or equal (<=) |\n", + "| `$gt` | Greater than (>) |\n", + "| `$gte` | Greater than or equal (>=) |\n", + "| `$in` | Contained in a set of given values (in) |\n", + "| `$nin` | Not contained in a set of given values (not in) |\n", + "| `$between` | Between the range of two boundary values |\n", + "| `$like` | Text equality based on the \"LIKE\" semantics in SQL (using \"%\" as wildcard) |\n", + "| `$and` | Logical \"and\", supporting 2 or more operands |\n", + "| `$or` | Logical \"or\", supporting 2 or more operands |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare some test documents\n", + "docs = [\n", + " Document(\n", + " page_content=\"First\",\n", + " metadata={\"name\": \"adam\", \"is_active\": True, \"id\": 1, \"height\": 10.0},\n", + " ),\n", + " Document(\n", + " page_content=\"Second\",\n", + " metadata={\"name\": \"bob\", \"is_active\": False, \"id\": 2, \"height\": 5.7},\n", + " ),\n", + " Document(\n", + " page_content=\"Third\",\n", + " metadata={\"name\": \"jane\", \"is_active\": True, \"id\": 3, \"height\": 2.4},\n", + " ),\n", + "]\n", + "\n", + "db = HanaDB(\n", + " 
connection=connection,\n", + " embedding=embeddings,\n", + " table_name=\"LANGCHAIN_DEMO_ADVANCED_FILTER\",\n", + ")\n", + "\n", + "# Delete already existing documents from the table\n", + "db.delete(filter={})\n", + "db.add_documents(docs)\n", + "\n", + "\n", + "# Helper function for printing filter results\n", + "def print_filter_result(result):\n", + " if len(result) == 0:\n", + " print(\"\")\n", + " for doc in result:\n", + " print(doc.metadata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Filtering with `$ne`, `$gt`, `$gte`, `$lt`, `$lte`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "advanced_filter = {\"id\": {\"$ne\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"id\": {\"$gt\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"id\": {\"$gte\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"id\": {\"$lt\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"id\": {\"$lte\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Filtering with `$between`, `$in`, `$nin`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "advanced_filter = {\"id\": {\"$between\": (1, 2)}}\n", + "print(f\"Filter: 
{advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"name\": {\"$in\": [\"adam\", \"bob\"]}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"name\": {\"$nin\": [\"adam\", \"bob\"]}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Text filtering with `$like`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "advanced_filter = {\"name\": {\"$like\": \"a%\"}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"name\": {\"$like\": \"%a%\"}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Combined filtering with `$and`, `$or`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "advanced_filter = {\"$or\": [{\"id\": 1}, {\"name\": \"bob\"}]}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"$and\": [{\"id\": 1}, {\"id\": 2}]}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"$or\": [{\"id\": 1}, {\"id\": 2}, {\"id\": 3}]}\n", + "print(f\"Filter: {advanced_filter}\")\n", + 
# --- Advanced metadata-filter translation for HanaDB ----------------------
# Mongo-style comparison operators mapped to their SQL counterparts.
COMPARISONS_TO_SQL = {
    "$eq": "=",
    "$ne": "<>",
    "$lt": "<",
    "$lte": "<=",
    "$gt": ">",
    "$gte": ">=",
}

# Set-membership operators; the operand must be a list of values.
IN_OPERATORS_TO_SQL = {
    "$in": "IN",
    "$nin": "NOT IN",
}

BETWEEN_OPERATOR = "$between"

LIKE_OPERATOR = "$like"

# Boolean combinators; each takes a list of sub-filter dicts as operand.
LOGICAL_OPERATORS_TO_SQL = {"$and": "AND", "$or": "OR"}

# Exported explicitly so that the underscore-prefixed helpers remain
# reachable through "import *" (e.g. from tests).
__all__ = [
    "COMPARISONS_TO_SQL",
    "IN_OPERATORS_TO_SQL",
    "BETWEEN_OPERATOR",
    "LIKE_OPERATOR",
    "LOGICAL_OPERATORS_TO_SQL",
    "_create_where_by_filter",
    "_process_filter_object",
]


def _create_where_by_filter(self, filter):  # type: ignore[no-untyped-def]
    """(HanaDB method) Build the complete SQL WHERE clause for *filter*.

    Returns a tuple ``(where_str, query_tuple)`` where ``where_str`` is
    either the empty string (no/empty filter) or a string starting with
    ``" WHERE "``, and ``query_tuple`` holds the positional bind values
    for the ``?`` placeholders in ``where_str``.
    """
    query_tuple = []
    where_str = ""
    if filter:
        where_str, query_tuple = self._process_filter_object(filter)
        where_str = " WHERE " + where_str
    return where_str, query_tuple


def _process_filter_object(self, filter):  # type: ignore[no-untyped-def]
    """(HanaDB method) Translate a metadata filter dict into SQL.

    Supports plain equality (``{"key": value}``), the comparison
    operators in ``COMPARISONS_TO_SQL``, ``$between``, ``$like``,
    ``$in``/``$nin`` and the logical combinators ``$and``/``$or``
    (whose operand is a list of sub-filters, processed recursively).

    Returns ``(where_str, query_tuple)`` — the predicate fragment (each
    condition reads the key via ``JSON_VALUE`` on the metadata column)
    and the matching bind parameters.

    Raises:
        ValueError: on an unknown ``$``-operator, a non-list operand for
            ``$in``/``$nin``, or an unsupported value type.
    """
    query_tuple = []
    where_str = ""
    if filter:
        for i, key in enumerate(filter.keys()):
            filter_value = filter[key]
            # Top-level conditions are implicitly AND-combined.
            if i != 0:
                where_str += " AND "

            # "$and" / "$or": recurse into every operand and join the
            # resulting fragments with the corresponding SQL keyword.
            if key in LOGICAL_OPERATORS_TO_SQL:
                logical_operator = LOGICAL_OPERATORS_TO_SQL[key]
                logical_operands = filter_value
                for j, logical_operand in enumerate(logical_operands):
                    if j != 0:
                        where_str += f" {logical_operator} "
                    (
                        where_str_logical,
                        query_tuple_logical,
                    ) = self._process_filter_object(logical_operand)
                    where_str += where_str_logical
                    query_tuple += query_tuple_logical
                continue

            operator = "="
            sql_param = "?"

            if isinstance(filter_value, bool):
                # Booleans are stored as JSON "true"/"false" strings.
                query_tuple.append("true" if filter_value else "false")
            elif isinstance(filter_value, int) or isinstance(filter_value, str):
                query_tuple.append(filter_value)
            elif isinstance(filter_value, dict):
                # A 'special' operator, e.g. {"$gt": 5}; only the first
                # key of the dict is considered.
                special_op = next(iter(filter_value))
                special_val = filter_value[special_op]
                # "$eq", "$ne", "$lt", "$lte", "$gt", "$gte"
                if special_op in COMPARISONS_TO_SQL:
                    operator = COMPARISONS_TO_SQL[special_op]
                    if isinstance(special_val, bool):
                        # BUG FIX: bind the operand value itself.  The
                        # previous code tested `filter_value` (the
                        # enclosing dict, which is always truthy), so
                        # e.g. {"$eq": False} was bound as "true".
                        query_tuple.append("true" if special_val else "false")
                    elif isinstance(special_val, float):
                        # Force a float comparison; JSON_VALUE yields a
                        # string otherwise.
                        sql_param = "CAST(? as float)"
                        query_tuple.append(special_val)
                    else:
                        query_tuple.append(special_val)
                # "$between": operand is a (from, to) pair.
                elif special_op == BETWEEN_OPERATOR:
                    between_from = special_val[0]
                    between_to = special_val[1]
                    operator = "BETWEEN"
                    sql_param = "? AND ?"
                    query_tuple.append(between_from)
                    query_tuple.append(between_to)
                # "$like": SQL LIKE semantics, "%" as wildcard.
                elif special_op == LIKE_OPERATOR:
                    operator = "LIKE"
                    query_tuple.append(special_val)
                # "$in", "$nin": build "(?,?,...)" with one placeholder
                # per list entry.
                elif special_op in IN_OPERATORS_TO_SQL:
                    operator = IN_OPERATORS_TO_SQL[special_op]
                    if isinstance(special_val, list):
                        # NOTE: loop index renamed from `i` to `pos` so
                        # it no longer shadows the outer loop variable.
                        for pos, list_entry in enumerate(special_val):
                            if pos == 0:
                                sql_param = "("
                            sql_param = sql_param + "?"
                            if pos == (len(special_val) - 1):
                                sql_param = sql_param + ")"
                            else:
                                sql_param = sql_param + ","
                            query_tuple.append(list_entry)
                    else:
                        raise ValueError(
                            f"Unsupported value for {operator}: {special_val}"
                        )
                else:
                    raise ValueError(f"Unsupported operator: {special_op}")
            else:
                raise ValueError(
                    f"Unsupported filter data-type: {type(filter_value)}"
                )
            where_str += (
                f" JSON_VALUE({self.metadata_column}, '$.{key}')"
                f" {operator} {sql_param}"
            )
    return where_str, query_tuple
# NOTE(review): these tests were copied from the pgvector suite; the names
# ("test_pgvector_with_with_...") were wrong for this file and have been
# corrected, and the six duplicated bodies share two documented helpers.


def _create_filter_test_table(table_name: str) -> HanaDB:
    """Drop any stale table, create a fresh store and load the shared
    filtering fixture DOCUMENTS into it."""
    drop_table(test_setup.conn, table_name)
    vectorDB = HanaDB(
        connection=test_setup.conn,
        embedding=embedding,
        table_name=table_name,
    )
    vectorDB.add_documents(DOCUMENTS)
    return vectorDB


def _assert_filtered_search(
    table_name: str, test_filter: Dict[str, Any], expected_ids: List[int]
) -> None:
    """Run a filtered similarity search and verify the returned ids.

    Equal count plus subset of *expected_ids* together imply that exactly
    the expected documents were returned.
    """
    vectorDB = _create_filter_test_table(table_name)
    docs = vectorDB.similarity_search("meow", k=5, filter=test_filter)
    ids = [doc.metadata["id"] for doc in docs]
    assert len(ids) == len(expected_ids), test_filter
    assert set(ids).issubset(expected_ids), test_filter


@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_enhanced_filter_1() -> None:
    """Smoke test: the fixture documents can be inserted and retrieved.

    The original version inserted the documents but asserted nothing.
    """
    vectorDB = _create_filter_test_table("TEST_TABLE_ENHANCED_FILTER_1")
    assert len(vectorDB.similarity_search("hello", k=10)) == len(DOCUMENTS)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_with_metadata_filters_1(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _assert_filtered_search("TEST_TABLE_ENHANCED_FILTER_1", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_with_metadata_filters_2(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _assert_filtered_search("TEST_TABLE_ENHANCED_FILTER_2", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_with_metadata_filters_3(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _assert_filtered_search("TEST_TABLE_ENHANCED_FILTER_3", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_with_metadata_filters_4(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _assert_filtered_search("TEST_TABLE_ENHANCED_FILTER_4", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4B_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_with_metadata_filters_4b(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    # Covers "$nin", which is missing from TYPE_4_FILTERING_TEST_CASES.
    _assert_filtered_search("TEST_TABLE_ENHANCED_FILTER_4B", test_filter, expected_ids)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed")
def test_hanavector_with_metadata_filters_5(
    test_filter: Dict[str, Any],
    expected_ids: List[int],
) -> None:
    _assert_filtered_search("TEST_TABLE_ENHANCED_FILTER_5", test_filter, expected_ids)