From 12e5ec6de3dc48e3c03f6b78d7888c77d3f0393e Mon Sep 17 00:00:00 2001 From: Alex Sherstinsky Date: Wed, 24 Apr 2024 13:31:01 -0700 Subject: [PATCH 1/3] community: Support both Predibase SDK-v1 and SDK-v2 in Predibase-LangChain integration (#20859) --- docs/docs/integrations/llms/predibase.ipynb | 27 ++- docs/docs/integrations/providers/predibase.md | 23 +- .../langchain_community/llms/predibase.py | 202 +++++++++++++----- .../tests/unit_tests/llms/test_predibase.py | 20 +- 4 files changed, 205 insertions(+), 67 deletions(-) diff --git a/docs/docs/integrations/llms/predibase.ipynb b/docs/docs/integrations/llms/predibase.ipynb index fabd36d75fb30..fc5e43bd463b7 100644 --- a/docs/docs/integrations/llms/predibase.ipynb +++ b/docs/docs/integrations/llms/predibase.ipynb @@ -63,12 +63,13 @@ "source": [ "from langchain_community.llms import Predibase\n", "\n", - "# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n", + "# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n", "model = Predibase(\n", " model=\"mistral-7b\",\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"e2e_nlg\",\n", " adapter_version=1,\n", - " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", ")" ] }, @@ -83,8 +84,9 @@ "# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n", "model = Predibase(\n", " model=\"mistral-7b\",\n", - " adapter_id=\"predibase/e2e_nlg\",\n", " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", + " adapter_id=\"predibase/e2e_nlg\",\n", ")" ] }, @@ -122,7 +124,9 @@ "from langchain_community.llms import Predibase\n", "\n", 
"model = Predibase(\n", - " model=\"mistral-7b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n", + " model=\"mistral-7b\",\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", ")" ] }, @@ -136,12 +140,13 @@ }, "outputs": [], "source": [ - "# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n", + "# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n", "model = Predibase(\n", " model=\"mistral-7b\",\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"e2e_nlg\",\n", " adapter_version=1,\n", - " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", ")" ] }, @@ -154,8 +159,9 @@ "# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n", "llm = Predibase(\n", " model=\"mistral-7b\",\n", - " adapter_id=\"predibase/e2e_nlg\",\n", " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", + " adapter_id=\"predibase/e2e_nlg\",\n", ")" ] }, @@ -247,13 +253,14 @@ "\n", "model = Predibase(\n", " model=\"my-base-LLM\",\n", - " adapter_id=\"my-finetuned-adapter-id\", # Supports both, Predibase-hosted and HuggingFace-hosted model repositories.\n", - " # adapter_version=1, # optional (returns the latest, if omitted)\n", " predibase_api_key=os.environ.get(\n", " \"PREDIBASE_API_TOKEN\"\n", " ), # Adapter argument is optional.\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", + " adapter_id=\"my-finetuned-adapter-id\", # Supports 
both, Predibase-hosted and HuggingFace-hosted adapter repositories.\n", + " adapter_version=1, # required for Predibase-hosted adapters (ignored for HuggingFace-hosted adapters)\n", ")\n", - "# replace my-finetuned-LLM with the name of your model in Predibase" + "# replace my-base-LLM with the name of your choice of a serverless base model in Predibase" ] }, { diff --git a/docs/docs/integrations/providers/predibase.md b/docs/docs/integrations/providers/predibase.md index 5a88ff117f372..8f27b818f4a38 100644 --- a/docs/docs/integrations/providers/predibase.md +++ b/docs/docs/integrations/providers/predibase.md @@ -17,7 +17,11 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" from langchain_community.llms import Predibase -model = Predibase(model="mistral-7b"", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) +model = Predibase( + model="mistral-7b", + predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"), + predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) +) response = model("Can you recommend me a nice dry wine?") print(response) @@ -31,8 +35,14 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" from langchain_community.llms import Predibase -# The fine-tuned adapter is hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version). -model = Predibase(model="mistral-7b"", adapter_id="e2e_nlg", adapter_version=1, predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) +# The fine-tuned adapter is hosted at Predibase (adapter_version must be specified). 
+model = Predibase( + model="mistral-7b", + predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"), + predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) + adapter_id="e2e_nlg", + adapter_version=1, +) response = model("Can you recommend me a nice dry wine?") print(response) @@ -47,7 +57,12 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" from langchain_community.llms import Predibase # The fine-tuned adapter is hosted at HuggingFace (adapter_version does not apply and will be ignored). -model = Predibase(model="mistral-7b"", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) +model = Predibase( + model="mistral-7b", + predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"), + predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) + adapter_id="predibase/e2e_nlg", +) response = model("Can you recommend me a nice dry wine?") print(response) diff --git a/libs/community/langchain_community/llms/predibase.py b/libs/community/langchain_community/llms/predibase.py index e3f5da7fd9e34..b45ff1d1ea9a3 100644 --- a/libs/community/langchain_community/llms/predibase.py +++ b/libs/community/langchain_community/llms/predibase.py @@ -1,3 +1,4 @@ +import os from typing import Any, Dict, List, Mapping, Optional, Union from langchain_core.callbacks import CallbackManagerForLLMRun @@ -17,13 +18,15 @@ class Predibase(LLM): An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a fine-tuned LLM adapter, whose base model is the `model` parameter; the fine-tuned adapter must be compatible with its base model; - otherwise, an error is raised. If a Predibase ID references the - fine-tuned adapter, then the `adapter_version` in the adapter repository can - be optionally specified; omitting it defaults to the most recent version. + otherwise, an error is raised. 
If the fine-tuned adapter is hosted at Predibase, + then `adapter_version` in the adapter repository must be specified. + + An optional `predibase_sdk_version` parameter defaults to latest SDK version. """ model: str predibase_api_key: SecretStr + predibase_sdk_version: Optional[str] = None adapter_id: Optional[str] = None adapter_version: Optional[int] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) @@ -46,65 +49,139 @@ def _call( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: - try: - from predibase import PredibaseClient - from predibase.pql import get_session - from predibase.pql.api import ( - ServerResponseError, - Session, - ) - from predibase.resource.llm.interface import ( - HuggingFaceLLM, - LLMDeployment, - ) - from predibase.resource.llm.response import GeneratedResponse - from predibase.resource.model import Model - - session: Session = get_session( - token=self.predibase_api_key.get_secret_value(), - gateway="https://api.app.predibase.com/v1", - serving_endpoint="serving.app.predibase.com", - ) - pc: PredibaseClient = PredibaseClient(session=session) - except ImportError as e: - raise ImportError( - "Could not import Predibase Python package. " - "Please install it with `pip install predibase`." - ) from e - except ValueError as e: - raise ValueError("Your API key is not correct. 
Please try again") from e options: Dict[str, Union[str, float]] = ( self.model_kwargs or self.default_options_for_generation ) - base_llm_deployment: LLMDeployment = pc.LLM( - uri=f"pb://deployments/{self.model}" + if self._is_deprecated_sdk_version(): + try: + from predibase import PredibaseClient + from predibase.pql import get_session + from predibase.pql.api import ( + ServerResponseError, + Session, + ) + from predibase.resource.llm.interface import ( + HuggingFaceLLM, + LLMDeployment, + ) + from predibase.resource.llm.response import GeneratedResponse + from predibase.resource.model import Model + + session: Session = get_session( + token=self.predibase_api_key.get_secret_value(), + gateway="https://api.app.predibase.com/v1", + serving_endpoint="serving.app.predibase.com", + ) + pc: PredibaseClient = PredibaseClient(session=session) + except ImportError as e: + raise ImportError( + "Could not import Predibase Python package. " + "Please install it with `pip install predibase`." + ) from e + except ValueError as e: + raise ValueError("Your API key is not correct. Please try again") from e + + base_llm_deployment: LLMDeployment = pc.LLM( + uri=f"pb://deployments/{self.model}" + ) + result: GeneratedResponse + if self.adapter_id: + """ + Attempt to retrieve the fine-tuned adapter from a Predibase + repository. If absent, then load the fine-tuned adapter + from a HuggingFace repository. + """ + adapter_model: Union[Model, HuggingFaceLLM] + try: + adapter_model = pc.get_model( + name=self.adapter_id, + version=self.adapter_version, + model_id=None, + ) + except ServerResponseError: + # Predibase does not recognize the adapter ID (query HuggingFace). 
+ adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}") + result = base_llm_deployment.with_adapter(model=adapter_model).generate( + prompt=prompt, + options=options, + ) + else: + result = base_llm_deployment.generate( + prompt=prompt, + options=options, + ) + return result.response + + from predibase import Predibase + + os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com" + predibase: Predibase = Predibase( + api_token=self.predibase_api_key.get_secret_value() + ) + + import requests + from lorax.client import Client as LoraxClient + from lorax.errors import GenerationError + from lorax.types import Response + + lorax_client: LoraxClient = predibase.deployments.client( + deployment_ref=self.model ) - result: GeneratedResponse + + response: Response if self.adapter_id: """ Attempt to retrieve the fine-tuned adapter from a Predibase repository. If absent, then load the fine-tuned adapter from a HuggingFace repository. """ - adapter_model: Union[Model, HuggingFaceLLM] + if self.adapter_version: + # Since the adapter version is provided, query the Predibase repository. + pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}" + try: + response = lorax_client.generate( + prompt=prompt, + adapter_id=pb_adapter_id, + **options, + ) + except GenerationError as ge: + raise ValueError( + f"""An adapter with the ID "{pb_adapter_id}" cannot be \ +found in the Predibase repository of fine-tuned adapters.""" + ) from ge + else: + # The adapter version is omitted, + # hence look for the adapter ID in the HuggingFace repository. + try: + response = lorax_client.generate( + prompt=prompt, + adapter_id=self.adapter_id, + adapter_source="hub", + **options, + ) + except GenerationError as ge: + raise ValueError( + f"""Either an adapter with the ID "{self.adapter_id}" \ +cannot be found in a HuggingFace repository, or it is incompatible with the \ +base model (please make sure that the adapter configuration is consistent). 
+""" + ) from ge + else: try: - adapter_model = pc.get_model( - name=self.adapter_id, - version=self.adapter_version, - model_id=None, + response = lorax_client.generate( + prompt=prompt, + **options, ) - except ServerResponseError: - # Predibase does not recognize the adapter ID (query HuggingFace). - adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}") - result = base_llm_deployment.with_adapter(model=adapter_model).generate( - prompt=prompt, - options=options, - ) - else: - result = base_llm_deployment.generate( - prompt=prompt, - options=options, - ) - return result.response + except requests.JSONDecodeError as jde: + raise ValueError( + f"""An LLM with the deployment ID "{self.model}" cannot be found \ +at Predibase (please refer to \ +"https://docs.predibase.com/user-guide/inference/models" for the list of \ +supported models). +""" + ) from jde + response_text = response.generated_text + + return response_text @property def _identifying_params(self) -> Mapping[str, Any]: @@ -112,3 +189,26 @@ def _identifying_params(self) -> Mapping[str, Any]: return { **{"model_kwargs": self.model_kwargs}, } + + def _is_deprecated_sdk_version(self) -> bool: + try: + import semantic_version + from predibase.version import __version__ as current_version + from semantic_version.base import Version + + sdk_semver_deprecated: Version = semantic_version.Version( + version_string="2024.4.8" + ) + actual_current_version: str = self.predibase_sdk_version or current_version + sdk_semver_current: Version = semantic_version.Version( + version_string=actual_current_version + ) + return not ( + (sdk_semver_current > sdk_semver_deprecated) + or ("+dev" in actual_current_version) + ) + except ImportError as e: + raise ImportError( + "Could not import Predibase Python package. " + "Please install it with `pip install semantic_version predibase`." 
+ ) from e diff --git a/libs/community/tests/unit_tests/llms/test_predibase.py b/libs/community/tests/unit_tests/llms/test_predibase.py index 9a9fba7f0effc..e2b2dc128d9a9 100644 --- a/libs/community/tests/unit_tests/llms/test_predibase.py +++ b/libs/community/tests/unit_tests/llms/test_predibase.py @@ -19,6 +19,22 @@ def test_api_key_masked_when_passed_via_constructor( assert captured.out == "**********" +def test_specifying_predibase_sdk_version_argument() -> None: + llm = Predibase( + model="my_llm", + predibase_api_key="secret-api-key", + ) + assert not llm.predibase_sdk_version + + legacy_predibase_sdk_version = "2024.4.8" + llm = Predibase( + model="my_llm", + predibase_api_key="secret-api-key", + predibase_sdk_version=legacy_predibase_sdk_version, + ) + assert llm.predibase_sdk_version == legacy_predibase_sdk_version + + def test_specifying_adapter_id_argument() -> None: llm = Predibase(model="my_llm", predibase_api_key="secret-api-key") assert not llm.adapter_id @@ -33,8 +49,8 @@ def test_specifying_adapter_id_argument() -> None: llm = Predibase( model="my_llm", - adapter_id="my-other-hf-adapter", predibase_api_key="secret-api-key", + adapter_id="my-other-hf-adapter", ) assert llm.adapter_id == "my-other-hf-adapter" assert llm.adapter_version is None @@ -55,9 +71,9 @@ def test_specifying_adapter_id_and_adapter_version_arguments() -> None: llm = Predibase( model="my_llm", + predibase_api_key="secret-api-key", adapter_id="my-other-hf-adapter", adapter_version=3, - predibase_api_key="secret-api-key", ) assert llm.adapter_id == "my-other-hf-adapter" assert llm.adapter_version == 3 From 0186e4e633f1577bfe29a13b6f2115338f177b02 Mon Sep 17 00:00:00 2001 From: Martin Kolb Date: Wed, 24 Apr 2024 22:47:27 +0200 Subject: [PATCH 2/3] community[patch]: Advanced filtering for HANA Cloud Vector Engine (#20821) - **Description:** This PR adds support for advanced filtering to the integration of HANA Vector Engine. 
The newly supported filtering operators are: $eq, $ne, $gt, $gte, $lt, $lte, $between, $in, $nin, $like, $and, $or - **Issue:** N/A - **Dependencies:** no new dependencies added Added integration tests to: `libs/community/tests/integration_tests/vectorstores/test_hanavector.py` Description of the new capabilities in notebook: `docs/docs/integrations/vectorstores/hanavector.ipynb` --- .../vectorstores/sap_hanavector.ipynb | 173 ++++++++++++++++++ .../vectorstores/hanavector.py | 116 ++++++++++-- .../vectorstores/test_hanavector.py | 172 ++++++++++++++++- 3 files changed, 448 insertions(+), 13 deletions(-) diff --git a/docs/docs/integrations/vectorstores/sap_hanavector.ipynb b/docs/docs/integrations/vectorstores/sap_hanavector.ipynb index 42e89eb21f556..37f9c86ecc615 100644 --- a/docs/docs/integrations/vectorstores/sap_hanavector.ipynb +++ b/docs/docs/integrations/vectorstores/sap_hanavector.ipynb @@ -357,6 +357,179 @@ "print(len(docs))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced filtering\n", + "In addition to the basic value-based filtering capabilities, it is possible to use more advanced filtering.\n", + "The table below shows the available filter operators.\n", + "\n", + "| Operator | Semantic |\n", + "|----------|-------------------------|\n", + "| `$eq` | Equality (==) |\n", + "| `$ne` | Inequality (!=) |\n", + "| `$lt` | Less than (<) |\n", + "| `$lte` | Less than or equal (<=) |\n", + "| `$gt` | Greater than (>) |\n", + "| `$gte` | Greater than or equal (>=) |\n", + "| `$in` | Contained in a set of given values (in) |\n", + "| `$nin` | Not contained in a set of given values (not in) |\n", + "| `$between` | Between the range of two boundary values |\n", + "| `$like` | Text equality based on the \"LIKE\" semantics in SQL (using \"%\" as wildcard) |\n", + "| `$and` | Logical \"and\", supporting 2 or more operands |\n", + "| `$or` | Logical \"or\", supporting 2 or more operands |" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare some test documents\n", + "docs = [\n", + " Document(\n", + " page_content=\"First\",\n", + " metadata={\"name\": \"adam\", \"is_active\": True, \"id\": 1, \"height\": 10.0},\n", + " ),\n", + " Document(\n", + " page_content=\"Second\",\n", + " metadata={\"name\": \"bob\", \"is_active\": False, \"id\": 2, \"height\": 5.7},\n", + " ),\n", + " Document(\n", + " page_content=\"Third\",\n", + " metadata={\"name\": \"jane\", \"is_active\": True, \"id\": 3, \"height\": 2.4},\n", + " ),\n", + "]\n", + "\n", + "db = HanaDB(\n", + " connection=connection,\n", + " embedding=embeddings,\n", + " table_name=\"LANGCHAIN_DEMO_ADVANCED_FILTER\",\n", + ")\n", + "\n", + "# Delete already existing documents from the table\n", + "db.delete(filter={})\n", + "db.add_documents(docs)\n", + "\n", + "\n", + "# Helper function for printing filter results\n", + "def print_filter_result(result):\n", + " if len(result) == 0:\n", + " print(\"\")\n", + " for doc in result:\n", + " print(doc.metadata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Filtering with `$ne`, `$gt`, `$gte`, `$lt`, `$lte`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "advanced_filter = {\"id\": {\"$ne\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"id\": {\"$gt\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"id\": {\"$gte\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"id\": {\"$lt\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + 
"print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"id\": {\"$lte\": 1}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Filtering with `$between`, `$in`, `$nin`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "advanced_filter = {\"id\": {\"$between\": (1, 2)}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"name\": {\"$in\": [\"adam\", \"bob\"]}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"name\": {\"$nin\": [\"adam\", \"bob\"]}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Text filtering with `$like`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "advanced_filter = {\"name\": {\"$like\": \"a%\"}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"name\": {\"$like\": \"%a%\"}}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Combined filtering with `$and`, `$or`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"advanced_filter = {\"$or\": [{\"id\": 1}, {\"name\": \"bob\"}]}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"$and\": [{\"id\": 1}, {\"id\": 2}]}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))\n", + "\n", + "advanced_filter = {\"$or\": [{\"id\": 1}, {\"id\": 2}, {\"id\": 3}]}\n", + "print(f\"Filter: {advanced_filter}\")\n", + "print_filter_result(db.similarity_search(\"just testing\", k=5, filter=advanced_filter))" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/libs/community/langchain_community/vectorstores/hanavector.py b/libs/community/langchain_community/vectorstores/hanavector.py index cc8822c888e20..ca595dec93533 100644 --- a/libs/community/langchain_community/vectorstores/hanavector.py +++ b/libs/community/langchain_community/vectorstores/hanavector.py @@ -8,6 +8,7 @@ TYPE_CHECKING, Any, Callable, + Dict, Iterable, List, Optional, @@ -34,6 +35,27 @@ DistanceStrategy.EUCLIDEAN_DISTANCE: ("L2DISTANCE", "ASC"), } +COMPARISONS_TO_SQL = { + "$eq": "=", + "$ne": "<>", + "$lt": "<", + "$lte": "<=", + "$gt": ">", + "$gte": ">=", +} + +IN_OPERATORS_TO_SQL = { + "$in": "IN", + "$nin": "NOT IN", +} + +BETWEEN_OPERATOR = "$between" + +LIKE_OPERATOR = "$like" + +LOGICAL_OPERATORS_TO_SQL = {"$and": "AND", "$or": "OR"} + + default_distance_strategy = DistanceStrategy.COSINE default_table_name: str = "EMBEDDINGS" default_content_column: str = "VEC_TEXT" @@ -404,29 +426,99 @@ def similarity_search_by_vector( # type: ignore[override] return [doc for doc, _ in docs_and_scores] def _create_where_by_filter(self, filter): # type: ignore[no-untyped-def] + query_tuple = [] + where_str = "" + if filter: + where_str, query_tuple = self._process_filter_object(filter) + where_str = " WHERE " + where_str + return where_str, query_tuple + + def 
_process_filter_object(self, filter): # type: ignore[no-untyped-def] query_tuple = [] where_str = "" if filter: for i, key in enumerate(filter.keys()): - if i == 0: - where_str += " WHERE " - else: + filter_value = filter[key] + if i != 0: where_str += " AND " - where_str += f" JSON_VALUE({self.metadata_column}, '$.{key}') = ?" - - if isinstance(filter[key], bool): - if filter[key]: - query_tuple.append("true") + # Handling of 'special' boolean operators "$and", "$or" + if key in LOGICAL_OPERATORS_TO_SQL: + logical_operator = LOGICAL_OPERATORS_TO_SQL[key] + logical_operands = filter_value + for j, logical_operand in enumerate(logical_operands): + if j != 0: + where_str += f" {logical_operator} " + ( + where_str_logical, + query_tuple_logical, + ) = self._process_filter_object(logical_operand) + where_str += where_str_logical + query_tuple += query_tuple_logical + continue + + operator = "=" + sql_param = "?" + + if isinstance(filter_value, bool): + query_tuple.append("true" if filter_value else "false") + elif isinstance(filter_value, int) or isinstance(filter_value, str): + query_tuple.append(filter_value) + elif isinstance(filter_value, Dict): + # Handling of 'special' operators starting with "$" + special_op = next(iter(filter_value)) + special_val = filter_value[special_op] + # "$eq", "$ne", "$lt", "$lte", "$gt", "$gte" + if special_op in COMPARISONS_TO_SQL: + operator = COMPARISONS_TO_SQL[special_op] + if isinstance(special_val, bool): + query_tuple.append("true" if filter_value else "false") + elif isinstance(special_val, float): + sql_param = "CAST(? as float)" + query_tuple.append(special_val) + else: + query_tuple.append(special_val) + # "$between" + elif special_op == BETWEEN_OPERATOR: + between_from = special_val[0] + between_to = special_val[1] + operator = "BETWEEN" + sql_param = "? AND ?" 
+ query_tuple.append(between_from) + query_tuple.append(between_to) + # "$like" + elif special_op == LIKE_OPERATOR: + operator = "LIKE" + query_tuple.append(special_val) + # "$in", "$nin" + elif special_op in IN_OPERATORS_TO_SQL: + operator = IN_OPERATORS_TO_SQL[special_op] + if isinstance(special_val, list): + for i, list_entry in enumerate(special_val): + if i == 0: + sql_param = "(" + sql_param = sql_param + "?" + if i == (len(special_val) - 1): + sql_param = sql_param + ")" + else: + sql_param = sql_param + "," + query_tuple.append(list_entry) + else: + raise ValueError( + f"Unsupported value for {operator}: {special_val}" + ) else: - query_tuple.append("false") - elif isinstance(filter[key], int) or isinstance(filter[key], str): - query_tuple.append(filter[key]) + raise ValueError(f"Unsupported operator: {special_op}") else: raise ValueError( - f"Unsupported filter data-type: {type(filter[key])}" + f"Unsupported filter data-type: {type(filter_value)}" ) + where_str += ( + f" JSON_VALUE({self.metadata_column}, '$.{key}')" + f" {operator} {sql_param}" + ) + return where_str, query_tuple def delete( # type: ignore[override] diff --git a/libs/community/tests/integration_tests/vectorstores/test_hanavector.py b/libs/community/tests/integration_tests/vectorstores/test_hanavector.py index c725c534a9e9c..6a1992cc748c3 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_hanavector.py +++ b/libs/community/tests/integration_tests/vectorstores/test_hanavector.py @@ -2,7 +2,7 @@ import os import random -from typing import List +from typing import Any, Dict, List import numpy as np import pytest @@ -12,6 +12,23 @@ from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, ) +from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import ( + DOCUMENTS, + TYPE_1_FILTERING_TEST_CASES, + TYPE_2_FILTERING_TEST_CASES, + TYPE_3_FILTERING_TEST_CASES, + TYPE_4_FILTERING_TEST_CASES, + TYPE_5_FILTERING_TEST_CASES, +) + 
+TYPE_4B_FILTERING_TEST_CASES = [ + # Test $nin, which is missing in TYPE_4_FILTERING_TEST_CASES + ( + {"name": {"$nin": ["adam", "bob"]}}, + [3], + ), +] + try: from hdbcli import dbapi @@ -924,3 +941,156 @@ def test_hanavector_table_mixed_case_names(texts: List[str]) -> None: # check results of similarity search assert texts[0] == vectordb.similarity_search(texts[0], 1)[0].page_content + + +@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") +def test_hanavector_enhanced_filter_1() -> None: + table_name = "TEST_TABLE_ENHANCED_FILTER_1" + # Delete table if it exists + drop_table(test_setup.conn, table_name) + + vectorDB = HanaDB( + connection=test_setup.conn, + embedding=embedding, + table_name=table_name, + ) + + vectorDB.add_documents(DOCUMENTS) + + +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES) +@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") +def test_pgvector_with_with_metadata_filters_1( + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + table_name = "TEST_TABLE_ENHANCED_FILTER_1" + drop_table(test_setup.conn, table_name) + + vectorDB = HanaDB( + connection=test_setup.conn, + embedding=embedding, + table_name=table_name, + ) + + vectorDB.add_documents(DOCUMENTS) + + docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) + ids = [doc.metadata["id"] for doc in docs] + assert len(ids) == len(expected_ids), test_filter + assert set(ids).issubset(expected_ids), test_filter + + +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES) +@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") +def test_pgvector_with_with_metadata_filters_2( + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + table_name = "TEST_TABLE_ENHANCED_FILTER_2" + drop_table(test_setup.conn, table_name) + + vectorDB = HanaDB( + connection=test_setup.conn, + embedding=embedding, + table_name=table_name, + ) + + 
vectorDB.add_documents(DOCUMENTS) + + docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) + ids = [doc.metadata["id"] for doc in docs] + assert len(ids) == len(expected_ids), test_filter + assert set(ids).issubset(expected_ids), test_filter + + +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES) +@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") +def test_pgvector_with_with_metadata_filters_3( + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + table_name = "TEST_TABLE_ENHANCED_FILTER_3" + drop_table(test_setup.conn, table_name) + + vectorDB = HanaDB( + connection=test_setup.conn, + embedding=embedding, + table_name=table_name, + ) + + vectorDB.add_documents(DOCUMENTS) + + docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) + ids = [doc.metadata["id"] for doc in docs] + assert len(ids) == len(expected_ids), test_filter + assert set(ids).issubset(expected_ids), test_filter + + +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES) +@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") +def test_pgvector_with_with_metadata_filters_4( + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + table_name = "TEST_TABLE_ENHANCED_FILTER_4" + drop_table(test_setup.conn, table_name) + + vectorDB = HanaDB( + connection=test_setup.conn, + embedding=embedding, + table_name=table_name, + ) + + vectorDB.add_documents(DOCUMENTS) + + docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) + ids = [doc.metadata["id"] for doc in docs] + assert len(ids) == len(expected_ids), test_filter + assert set(ids).issubset(expected_ids), test_filter + + +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4B_FILTERING_TEST_CASES) +@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") +def test_pgvector_with_with_metadata_filters_4b( + test_filter: Dict[str, Any], + expected_ids: List[int], +) 
-> None: + table_name = "TEST_TABLE_ENHANCED_FILTER_4B" + drop_table(test_setup.conn, table_name) + + vectorDB = HanaDB( + connection=test_setup.conn, + embedding=embedding, + table_name=table_name, + ) + + vectorDB.add_documents(DOCUMENTS) + + docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) + ids = [doc.metadata["id"] for doc in docs] + assert len(ids) == len(expected_ids), test_filter + assert set(ids).issubset(expected_ids), test_filter + + +@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES) +@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") +def test_pgvector_with_with_metadata_filters_5( + test_filter: Dict[str, Any], + expected_ids: List[int], +) -> None: + table_name = "TEST_TABLE_ENHANCED_FILTER_5" + drop_table(test_setup.conn, table_name) + + vectorDB = HanaDB( + connection=test_setup.conn, + embedding=embedding, + table_name=table_name, + ) + + vectorDB.add_documents(DOCUMENTS) + + docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) + ids = [doc.metadata["id"] for doc in docs] + assert len(ids) == len(expected_ids), test_filter + assert set(ids).issubset(expected_ids), test_filter From 13751c32977e3f12c0ef7f47e3e1bdd1db6b1d8c Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Wed, 24 Apr 2024 13:49:21 -0700 Subject: [PATCH 3/3] community: `tigergraph` fixes (#20034) - added guard on the `pyTigerGraph` import - added a missed example page in the `docs/integrations/graphs/` - formatted the `docs/integrations/providers/` page to the consistent format. Added links. 
--- docs/docs/integrations/graphs/tigergraph.mdx | 37 +++++++++++++++++++ .../integrations/providers/tigergraph.mdx | 28 ++++---------- .../graphs/tigergraph_graph.py | 8 +++- 3 files changed, 51 insertions(+), 22 deletions(-) create mode 100644 docs/docs/integrations/graphs/tigergraph.mdx
diff --git a/docs/docs/integrations/graphs/tigergraph.mdx b/docs/docs/integrations/graphs/tigergraph.mdx
new file mode 100644
index 0000000000000..a9901459a0d9b
--- /dev/null
+++ b/docs/docs/integrations/graphs/tigergraph.mdx
@@ -0,0 +1,37 @@
+# TigerGraph
+
+>[TigerGraph](https://www.tigergraph.com/tigergraph-db/) is a natively distributed and high-performance graph database.
+> The storage of data in a graph format of vertices and edges leads to rich relationships,
+> ideal for grounding LLM responses.
+
+A big example of the `TigerGraph` and `LangChain` integration is [presented here](https://github.com/tigergraph/graph-ml-notebooks/blob/main/applications/large_language_models/TigerGraph_LangChain_Demo.ipynb).
+
+## Installation and Setup
+
+Follow the instructions on [how to connect to the `TigerGraph` database](https://docs.tigergraph.com/pytigergraph/current/getting-started/connection).
+
+Install the Python SDK:
+
+```bash
+pip install pyTigerGraph
+```
+
+## Example
+
+To utilize the `TigerGraph InquiryAI` functionality, you can import `TigerGraph` from `langchain_community.graphs`. 
+
+```python
+import pyTigerGraph as tg
+
+conn = tg.TigerGraphConnection(host="DATABASE_HOST_HERE", graphname="GRAPH_NAME_HERE", username="USERNAME_HERE", password="PASSWORD_HERE")
+
+### ==== CONFIGURE INQUIRYAI HOST ====
+conn.ai.configureInquiryAIHost("INQUIRYAI_HOST_HERE")
+
+from langchain_community.graphs import TigerGraph
+
+graph = TigerGraph(conn)
+result = graph.query("How many servers are there?")
+print(result)
+```
+
diff --git a/docs/docs/integrations/providers/tigergraph.mdx b/docs/docs/integrations/providers/tigergraph.mdx
index 50d48d3ec1855..95a62635c83a3 100644
--- a/docs/docs/integrations/providers/tigergraph.mdx
+++ b/docs/docs/integrations/providers/tigergraph.mdx
@@ -1,15 +1,13 @@
 # TigerGraph
 
-What is `TigerGraph`?
-
-**TigerGraph in a nutshell:**
-
-- `TigerGraph` is a natively distributed and high-performance graph database.
-- The storage of data in a graph format of vertices and edges leads to rich relationships, ideal for grouding LLM responses.
-- Get started quickly with `TigerGraph` by visiting [their website](https://tigergraph.com/).
+>[TigerGraph](https://www.tigergraph.com/tigergraph-db/) is a natively distributed and high-performance graph database.
+> The storage of data in a graph format of vertices and edges leads to rich relationships,
+> ideal for grounding LLM responses.
 
 ## Installation and Setup
 
+Follow the instructions on [how to connect to the `TigerGraph` database](https://docs.tigergraph.com/pytigergraph/current/getting-started/connection).
+
 Install the Python SDK:
 
 ```bash
@@ -18,22 +16,10 @@ pip install pyTigerGraph
 
 ## Graph store
 
-### TigerGraph Store
+### TigerGraph
 
-To utilize the `TigerGraph InquiryAI` functionality, you can import `TigerGraph` from `langchain_community.graphs`.
+See a [usage example](/docs/integrations/graphs/tigergraph). 
```python -import pyTigerGraph as tg - -conn = tg.TigerGraphConnection(host="DATABASE_HOST_HERE", graphname="GRAPH_NAME_HERE", username="USERNAME_HERE", password="PASSWORD_HERE") - -### ==== CONFIGURE INQUIRYAI HOST ==== -conn.ai.configureInquiryAIHost("INQUIRYAI_HOST_HERE") - from langchain_community.graphs import TigerGraph - -graph = TigerGraph(conn) -result = graph.query("How many servers are there?") -print(result) ``` - diff --git a/libs/community/langchain_community/graphs/tigergraph_graph.py b/libs/community/langchain_community/graphs/tigergraph_graph.py index f32d43e6e5e2c..84b24218ad104 100644 --- a/libs/community/langchain_community/graphs/tigergraph_graph.py +++ b/libs/community/langchain_community/graphs/tigergraph_graph.py @@ -39,7 +39,13 @@ def get_schema(self) -> str: # type: ignore[override] return str(self._schema) def set_connection(self, conn: Any) -> None: - from pyTigerGraph import TigerGraphConnection + try: + from pyTigerGraph import TigerGraphConnection + except ImportError: + raise ImportError( + "Could not import pyTigerGraph python package. " + "Please install it with `pip install pyTigerGraph`." + ) if not isinstance(conn, TigerGraphConnection): msg = "**conn** parameter must inherit from TigerGraphConnection"