From 12e5ec6de3dc48e3c03f6b78d7888c77d3f0393e Mon Sep 17 00:00:00 2001 From: Alex Sherstinsky Date: Wed, 24 Apr 2024 13:31:01 -0700 Subject: [PATCH] community: Support both Predibase SDK-v1 and SDK-v2 in Predibase-LangChain integration (#20859) --- docs/docs/integrations/llms/predibase.ipynb | 27 ++- docs/docs/integrations/providers/predibase.md | 23 +- .../langchain_community/llms/predibase.py | 202 +++++++++++++----- .../tests/unit_tests/llms/test_predibase.py | 20 +- 4 files changed, 205 insertions(+), 67 deletions(-) diff --git a/docs/docs/integrations/llms/predibase.ipynb b/docs/docs/integrations/llms/predibase.ipynb index fabd36d75fb30..fc5e43bd463b7 100644 --- a/docs/docs/integrations/llms/predibase.ipynb +++ b/docs/docs/integrations/llms/predibase.ipynb @@ -63,12 +63,13 @@ "source": [ "from langchain_community.llms import Predibase\n", "\n", - "# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n", + "# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n", "model = Predibase(\n", " model=\"mistral-7b\",\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"e2e_nlg\",\n", " adapter_version=1,\n", - " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", ")" ] }, @@ -83,8 +84,9 @@ "# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n", "model = Predibase(\n", " model=\"mistral-7b\",\n", - " adapter_id=\"predibase/e2e_nlg\",\n", " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", + " adapter_id=\"predibase/e2e_nlg\",\n", ")" ] }, @@ -122,7 +124,9 @@ "from langchain_community.llms import Predibase\n", "\n", 
"model = Predibase(\n", - " model=\"mistral-7b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n", + " model=\"mistral-7b\",\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", ")" ] }, @@ -136,12 +140,13 @@ }, "outputs": [], "source": [ - "# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n", + "# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n", "model = Predibase(\n", " model=\"mistral-7b\",\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"e2e_nlg\",\n", " adapter_version=1,\n", - " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", ")" ] }, @@ -154,8 +159,9 @@ "# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n", "llm = Predibase(\n", " model=\"mistral-7b\",\n", - " adapter_id=\"predibase/e2e_nlg\",\n", " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", + " adapter_id=\"predibase/e2e_nlg\",\n", ")" ] }, @@ -247,13 +253,14 @@ "\n", "model = Predibase(\n", " model=\"my-base-LLM\",\n", - " adapter_id=\"my-finetuned-adapter-id\", # Supports both, Predibase-hosted and HuggingFace-hosted model repositories.\n", - " # adapter_version=1, # optional (returns the latest, if omitted)\n", " predibase_api_key=os.environ.get(\n", " \"PREDIBASE_API_TOKEN\"\n", " ), # Adapter argument is optional.\n", + " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", + " adapter_id=\"my-finetuned-adapter-id\", # Supports 
both, Predibase-hosted and HuggingFace-hosted adapter repositories.\n", + " adapter_version=1, # required for Predibase-hosted adapters (ignored for HuggingFace-hosted adapters)\n", ")\n", - "# replace my-finetuned-LLM with the name of your model in Predibase" + "# replace my-base-LLM with the name of your choice of a serverless base model in Predibase" ] }, { diff --git a/docs/docs/integrations/providers/predibase.md b/docs/docs/integrations/providers/predibase.md index 5a88ff117f372..8f27b818f4a38 100644 --- a/docs/docs/integrations/providers/predibase.md +++ b/docs/docs/integrations/providers/predibase.md @@ -17,7 +17,11 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" from langchain_community.llms import Predibase -model = Predibase(model="mistral-7b"", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) +model = Predibase( + model="mistral-7b", + predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"), + predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) +) response = model("Can you recommend me a nice dry wine?") print(response) @@ -31,8 +35,14 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" from langchain_community.llms import Predibase -# The fine-tuned adapter is hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version). -model = Predibase(model="mistral-7b"", adapter_id="e2e_nlg", adapter_version=1, predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) +# The fine-tuned adapter is hosted at Predibase (adapter_version must be specified). 
+model = Predibase( + model="mistral-7b", + predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"), + predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) + adapter_id="e2e_nlg", + adapter_version=1, +) response = model("Can you recommend me a nice dry wine?") print(response) @@ -47,7 +57,12 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" from langchain_community.llms import Predibase # The fine-tuned adapter is hosted at HuggingFace (adapter_version does not apply and will be ignored). -model = Predibase(model="mistral-7b"", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) +model = Predibase( + model="mistral-7b", + predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"), + predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) + adapter_id="predibase/e2e_nlg", +) response = model("Can you recommend me a nice dry wine?") print(response) diff --git a/libs/community/langchain_community/llms/predibase.py b/libs/community/langchain_community/llms/predibase.py index e3f5da7fd9e34..b45ff1d1ea9a3 100644 --- a/libs/community/langchain_community/llms/predibase.py +++ b/libs/community/langchain_community/llms/predibase.py @@ -1,3 +1,4 @@ +import os from typing import Any, Dict, List, Mapping, Optional, Union from langchain_core.callbacks import CallbackManagerForLLMRun @@ -17,13 +18,15 @@ class Predibase(LLM): An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a fine-tuned LLM adapter, whose base model is the `model` parameter; the fine-tuned adapter must be compatible with its base model; - otherwise, an error is raised. If a Predibase ID references the - fine-tuned adapter, then the `adapter_version` in the adapter repository can - be optionally specified; omitting it defaults to the most recent version. + otherwise, an error is raised. 
If the fine-tuned adapter is hosted at Predibase, + then `adapter_version` in the adapter repository must be specified. + + An optional `predibase_sdk_version` parameter defaults to the latest SDK version. + """ model: str predibase_api_key: SecretStr + predibase_sdk_version: Optional[str] = None adapter_id: Optional[str] = None adapter_version: Optional[int] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) @@ -46,65 +49,139 @@ def _call( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: - try: - from predibase import PredibaseClient - from predibase.pql import get_session - from predibase.pql.api import ( - ServerResponseError, - Session, - ) - from predibase.resource.llm.interface import ( - HuggingFaceLLM, - LLMDeployment, - ) - from predibase.resource.llm.response import GeneratedResponse - from predibase.resource.model import Model - - session: Session = get_session( - token=self.predibase_api_key.get_secret_value(), - gateway="https://api.app.predibase.com/v1", - serving_endpoint="serving.app.predibase.com", - ) - pc: PredibaseClient = PredibaseClient(session=session) - except ImportError as e: - raise ImportError( - "Could not import Predibase Python package. " - "Please install it with `pip install predibase`." - ) from e - except ValueError as e: - raise ValueError("Your API key is not correct. 
Please try again") from e options: Dict[str, Union[str, float]] = ( self.model_kwargs or self.default_options_for_generation ) - base_llm_deployment: LLMDeployment = pc.LLM( - uri=f"pb://deployments/{self.model}" + if self._is_deprecated_sdk_version(): + try: + from predibase import PredibaseClient + from predibase.pql import get_session + from predibase.pql.api import ( + ServerResponseError, + Session, + ) + from predibase.resource.llm.interface import ( + HuggingFaceLLM, + LLMDeployment, + ) + from predibase.resource.llm.response import GeneratedResponse + from predibase.resource.model import Model + + session: Session = get_session( + token=self.predibase_api_key.get_secret_value(), + gateway="https://api.app.predibase.com/v1", + serving_endpoint="serving.app.predibase.com", + ) + pc: PredibaseClient = PredibaseClient(session=session) + except ImportError as e: + raise ImportError( + "Could not import Predibase Python package. " + "Please install it with `pip install predibase`." + ) from e + except ValueError as e: + raise ValueError("Your API key is not correct. Please try again") from e + + base_llm_deployment: LLMDeployment = pc.LLM( + uri=f"pb://deployments/{self.model}" + ) + result: GeneratedResponse + if self.adapter_id: + """ + Attempt to retrieve the fine-tuned adapter from a Predibase + repository. If absent, then load the fine-tuned adapter + from a HuggingFace repository. + """ + adapter_model: Union[Model, HuggingFaceLLM] + try: + adapter_model = pc.get_model( + name=self.adapter_id, + version=self.adapter_version, + model_id=None, + ) + except ServerResponseError: + # Predibase does not recognize the adapter ID (query HuggingFace). 
+ adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}") + result = base_llm_deployment.with_adapter(model=adapter_model).generate( + prompt=prompt, + options=options, + ) + else: + result = base_llm_deployment.generate( + prompt=prompt, + options=options, + ) + return result.response + + from predibase import Predibase + + os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com" + predibase: Predibase = Predibase( + api_token=self.predibase_api_key.get_secret_value() + ) + + import requests + from lorax.client import Client as LoraxClient + from lorax.errors import GenerationError + from lorax.types import Response + + lorax_client: LoraxClient = predibase.deployments.client( + deployment_ref=self.model ) - result: GeneratedResponse + + response: Response if self.adapter_id: """ Attempt to retrieve the fine-tuned adapter from a Predibase repository. If absent, then load the fine-tuned adapter from a HuggingFace repository. """ - adapter_model: Union[Model, HuggingFaceLLM] + if self.adapter_version: + # Since the adapter version is provided, query the Predibase repository. + pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}" + try: + response = lorax_client.generate( + prompt=prompt, + adapter_id=pb_adapter_id, + **options, + ) + except GenerationError as ge: + raise ValueError( + f"""An adapter with the ID "{pb_adapter_id}" cannot be \ +found in the Predibase repository of fine-tuned adapters.""" + ) from ge + else: + # The adapter version is omitted, + # hence look for the adapter ID in the HuggingFace repository. + try: + response = lorax_client.generate( + prompt=prompt, + adapter_id=self.adapter_id, + adapter_source="hub", + **options, + ) + except GenerationError as ge: + raise ValueError( + f"""Either an adapter with the ID "{self.adapter_id}" \ +cannot be found in a HuggingFace repository, or it is incompatible with the \ +base model (please make sure that the adapter configuration is consistent). 
+""" + ) from ge + else: try: - adapter_model = pc.get_model( - name=self.adapter_id, - version=self.adapter_version, - model_id=None, + response = lorax_client.generate( + prompt=prompt, + **options, ) - except ServerResponseError: - # Predibase does not recognize the adapter ID (query HuggingFace). - adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}") - result = base_llm_deployment.with_adapter(model=adapter_model).generate( - prompt=prompt, - options=options, - ) - else: - result = base_llm_deployment.generate( - prompt=prompt, - options=options, - ) - return result.response + except requests.JSONDecodeError as jde: + raise ValueError( + f"""An LLM with the deployment ID "{self.model}" cannot be found \ +at Predibase (please refer to \ +"https://docs.predibase.com/user-guide/inference/models" for the list of \ +supported models). +""" + ) from jde + response_text = response.generated_text + + return response_text @property def _identifying_params(self) -> Mapping[str, Any]: @@ -112,3 +189,26 @@ def _identifying_params(self) -> Mapping[str, Any]: return { **{"model_kwargs": self.model_kwargs}, } + + def _is_deprecated_sdk_version(self) -> bool: + try: + import semantic_version + from predibase.version import __version__ as current_version + from semantic_version.base import Version + + sdk_semver_deprecated: Version = semantic_version.Version( + version_string="2024.4.8" + ) + actual_current_version: str = self.predibase_sdk_version or current_version + sdk_semver_current: Version = semantic_version.Version( + version_string=actual_current_version + ) + return not ( + (sdk_semver_current > sdk_semver_deprecated) + or ("+dev" in actual_current_version) + ) + except ImportError as e: + raise ImportError( + "Could not import Predibase Python package. " + "Please install it with `pip install semantic_version predibase`." 
+ ) from e diff --git a/libs/community/tests/unit_tests/llms/test_predibase.py b/libs/community/tests/unit_tests/llms/test_predibase.py index 9a9fba7f0effc..e2b2dc128d9a9 100644 --- a/libs/community/tests/unit_tests/llms/test_predibase.py +++ b/libs/community/tests/unit_tests/llms/test_predibase.py @@ -19,6 +19,22 @@ def test_api_key_masked_when_passed_via_constructor( assert captured.out == "**********" +def test_specifying_predibase_sdk_version_argument() -> None: + llm = Predibase( + model="my_llm", + predibase_api_key="secret-api-key", + ) + assert not llm.predibase_sdk_version + + legacy_predibase_sdk_version = "2024.4.8" + llm = Predibase( + model="my_llm", + predibase_api_key="secret-api-key", + predibase_sdk_version=legacy_predibase_sdk_version, + ) + assert llm.predibase_sdk_version == legacy_predibase_sdk_version + + def test_specifying_adapter_id_argument() -> None: llm = Predibase(model="my_llm", predibase_api_key="secret-api-key") assert not llm.adapter_id @@ -33,8 +49,8 @@ def test_specifying_adapter_id_argument() -> None: llm = Predibase( model="my_llm", - adapter_id="my-other-hf-adapter", predibase_api_key="secret-api-key", + adapter_id="my-other-hf-adapter", ) assert llm.adapter_id == "my-other-hf-adapter" assert llm.adapter_version is None @@ -55,9 +71,9 @@ def test_specifying_adapter_id_and_adapter_version_arguments() -> None: llm = Predibase( model="my_llm", + predibase_api_key="secret-api-key", adapter_id="my-other-hf-adapter", adapter_version=3, - predibase_api_key="secret-api-key", ) assert llm.adapter_id == "my-other-hf-adapter" assert llm.adapter_version == 3