Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Support both Predibase SDK-v1 and SDK-v2 in Predibase-LangChain integration #20859

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions docs/docs/integrations/llms/predibase.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n",
"# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n",
"model = Predibase(\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"e2e_nlg\",\n",
" adapter_version=1,\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
")"
]
},
Expand All @@ -83,8 +84,9 @@
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
"model = Predibase(\n",
" model=\"mistral-7b\",\n",
" adapter_id=\"predibase/e2e_nlg\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"predibase/e2e_nlg\",\n",
")"
]
},
Expand Down Expand Up @@ -122,7 +124,9 @@
"from langchain_community.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"mistral-7b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
")"
]
},
Expand All @@ -136,12 +140,13 @@
},
"outputs": [],
"source": [
"# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n",
"# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n",
"model = Predibase(\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"e2e_nlg\",\n",
" adapter_version=1,\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
")"
]
},
Expand All @@ -154,8 +159,9 @@
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
"llm = Predibase(\n",
" model=\"mistral-7b\",\n",
" adapter_id=\"predibase/e2e_nlg\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"predibase/e2e_nlg\",\n",
")"
]
},
Expand Down Expand Up @@ -247,13 +253,14 @@
"\n",
"model = Predibase(\n",
" model=\"my-base-LLM\",\n",
" adapter_id=\"my-finetuned-adapter-id\", # Supports both, Predibase-hosted and HuggingFace-hosted model repositories.\n",
" # adapter_version=1, # optional (returns the latest, if omitted)\n",
" predibase_api_key=os.environ.get(\n",
" \"PREDIBASE_API_TOKEN\"\n",
" ), # Adapter argument is optional.\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"my-finetuned-adapter-id\", # Supports both, Predibase-hosted and HuggingFace-hosted adapter repositories.\n",
" adapter_version=1, # required for Predibase-hosted adapters (ignored for HuggingFace-hosted adapters)\n",
")\n",
"# replace my-finetuned-LLM with the name of your model in Predibase"
"# replace my-base-LLM with the name of your choice of a serverless base model in Predibase"
]
},
{
Expand Down
23 changes: 19 additions & 4 deletions docs/docs/integrations/providers/predibase.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

from langchain_community.llms import Predibase

model = Predibase(model="mistral-7b", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
model = Predibase(
model="mistral-7b",
predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"),
predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)
)

response = model("Can you recommend me a nice dry wine?")
print(response)
Expand All @@ -31,8 +35,14 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

from langchain_community.llms import Predibase

# The fine-tuned adapter is hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).
model = Predibase(model="mistral-7b", adapter_id="e2e_nlg", adapter_version=1, predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
# The fine-tuned adapter is hosted at Predibase (adapter_version must be specified).
model = Predibase(
model="mistral-7b",
predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"),
predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)
adapter_id="e2e_nlg",
adapter_version=1,
)

response = model("Can you recommend me a nice dry wine?")
print(response)
Expand All @@ -47,7 +57,12 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
from langchain_community.llms import Predibase

# The fine-tuned adapter is hosted at HuggingFace (adapter_version does not apply and will be ignored).
model = Predibase(model="mistral-7b", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
model = Predibase(
model="mistral-7b",
predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"),
predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)
adapter_id="predibase/e2e_nlg",
)

response = model("Can you recommend me a nice dry wine?")
print(response)
Expand Down
202 changes: 151 additions & 51 deletions libs/community/langchain_community/llms/predibase.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Any, Dict, List, Mapping, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
Expand All @@ -17,13 +18,15 @@ class Predibase(LLM):
An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a
fine-tuned LLM adapter, whose base model is the `model` parameter; the
fine-tuned adapter must be compatible with its base model;
otherwise, an error is raised. If a Predibase ID references the
fine-tuned adapter, then the `adapter_version` in the adapter repository can
be optionally specified; omitting it defaults to the most recent version.
otherwise, an error is raised. If the fine-tuned adapter is hosted at Predibase,
then `adapter_version` in the adapter repository must be specified.

An optional `predibase_sdk_version` parameter defaults to latest SDK version.
"""

model: str
predibase_api_key: SecretStr
predibase_sdk_version: Optional[str] = None
adapter_id: Optional[str] = None
adapter_version: Optional[int] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
Expand All @@ -46,69 +49,166 @@ def _call(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
try:
from predibase import PredibaseClient
from predibase.pql import get_session
from predibase.pql.api import (
ServerResponseError,
Session,
)
from predibase.resource.llm.interface import (
HuggingFaceLLM,
LLMDeployment,
)
from predibase.resource.llm.response import GeneratedResponse
from predibase.resource.model import Model

session: Session = get_session(
token=self.predibase_api_key.get_secret_value(),
gateway="https://api.app.predibase.com/v1",
serving_endpoint="serving.app.predibase.com",
)
pc: PredibaseClient = PredibaseClient(session=session)
except ImportError as e:
raise ImportError(
"Could not import Predibase Python package. "
"Please install it with `pip install predibase`."
) from e
except ValueError as e:
raise ValueError("Your API key is not correct. Please try again") from e
options: Dict[str, Union[str, float]] = (
self.model_kwargs or self.default_options_for_generation
)
base_llm_deployment: LLMDeployment = pc.LLM(
uri=f"pb://deployments/{self.model}"
if self._is_deprecated_sdk_version():
try:
from predibase import PredibaseClient
from predibase.pql import get_session
from predibase.pql.api import (
ServerResponseError,
Session,
)
from predibase.resource.llm.interface import (
HuggingFaceLLM,
LLMDeployment,
)
from predibase.resource.llm.response import GeneratedResponse
from predibase.resource.model import Model

session: Session = get_session(
token=self.predibase_api_key.get_secret_value(),
gateway="https://api.app.predibase.com/v1",
serving_endpoint="serving.app.predibase.com",
)
pc: PredibaseClient = PredibaseClient(session=session)
except ImportError as e:
raise ImportError(
"Could not import Predibase Python package. "
"Please install it with `pip install predibase`."
) from e
except ValueError as e:
raise ValueError("Your API key is not correct. Please try again") from e

base_llm_deployment: LLMDeployment = pc.LLM(
uri=f"pb://deployments/{self.model}"
)
result: GeneratedResponse
if self.adapter_id:
"""
Attempt to retrieve the fine-tuned adapter from a Predibase
repository. If absent, then load the fine-tuned adapter
from a HuggingFace repository.
"""
adapter_model: Union[Model, HuggingFaceLLM]
try:
adapter_model = pc.get_model(
name=self.adapter_id,
version=self.adapter_version,
model_id=None,
)
except ServerResponseError:
# Predibase does not recognize the adapter ID (query HuggingFace).
adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}")
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
prompt=prompt,
options=options,
)
else:
result = base_llm_deployment.generate(
prompt=prompt,
options=options,
)
return result.response

from predibase import Predibase

os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
predibase: Predibase = Predibase(
api_token=self.predibase_api_key.get_secret_value()
)

import requests
from lorax.client import Client as LoraxClient
from lorax.errors import GenerationError
from lorax.types import Response

lorax_client: LoraxClient = predibase.deployments.client(
deployment_ref=self.model
)
result: GeneratedResponse

response: Response
if self.adapter_id:
"""
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
If absent, then load the fine-tuned adapter from a HuggingFace repository.
"""
adapter_model: Union[Model, HuggingFaceLLM]
if self.adapter_version:
# Since the adapter version is provided, query the Predibase repository.
pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
try:
response = lorax_client.generate(
prompt=prompt,
adapter_id=pb_adapter_id,
**options,
)
except GenerationError as ge:
raise ValueError(
f"""An adapter with the ID "{pb_adapter_id}" cannot be \
found in the Predibase repository of fine-tuned adapters."""
) from ge
else:
# The adapter version is omitted,
# hence look for the adapter ID in the HuggingFace repository.
try:
response = lorax_client.generate(
prompt=prompt,
adapter_id=self.adapter_id,
adapter_source="hub",
**options,
)
except GenerationError as ge:
raise ValueError(
f"""Either an adapter with the ID "{self.adapter_id}" \
cannot be found in a HuggingFace repository, or it is incompatible with the \
base model (please make sure that the adapter configuration is consistent).
"""
) from ge
else:
try:
adapter_model = pc.get_model(
name=self.adapter_id,
version=self.adapter_version,
model_id=None,
response = lorax_client.generate(
prompt=prompt,
**options,
)
except ServerResponseError:
# Predibase does not recognize the adapter ID (query HuggingFace).
adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}")
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
prompt=prompt,
options=options,
)
else:
result = base_llm_deployment.generate(
prompt=prompt,
options=options,
)
return result.response
except requests.JSONDecodeError as jde:
raise ValueError(
f"""An LLM with the deployment ID "{self.model}" cannot be found \
at Predibase (please refer to \
"https://docs.predibase.com/user-guide/inference/models" for the list of \
supported models).
"""
) from jde
response_text = response.generated_text

return response_text

@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
**{"model_kwargs": self.model_kwargs},
}

def _is_deprecated_sdk_version(self) -> bool:
    """Return True when the effective Predibase SDK is the legacy v1 API.

    The installed SDK version can be overridden via the optional
    ``predibase_sdk_version`` attribute; when that is ``None`` the version
    reported by the installed ``predibase`` package is used. A ``+dev``
    build is always treated as current (i.e. not deprecated).

    Raises:
        ImportError: if ``predibase`` or ``semantic_version`` is missing.
    """
    try:
        import semantic_version
        from predibase.version import __version__ as current_version
        from semantic_version.base import Version
    except ImportError as e:
        raise ImportError(
            "Could not import Predibase Python package. "
            "Please install it with `pip install semantic_version predibase`."
        ) from e

    # Releases up to and including 2024.4.8 expose the deprecated SDK-v1 API.
    last_v1_release: Version = semantic_version.Version(
        version_string="2024.4.8"
    )
    effective_version: str = self.predibase_sdk_version or current_version
    if "+dev" in effective_version:
        # Development builds track the latest (SDK-v2) interface.
        return False
    return (
        semantic_version.Version(version_string=effective_version)
        <= last_v1_release
    )
Loading
Loading