Commit

feat: Added WatsonxRerank integration, update logic of passing `params` (#33)

* add: Added WatsonxRerank integration
MateuszOssGit authored Oct 29, 2024
1 parent b14a876 commit a605c74
Showing 17 changed files with 900 additions and 278 deletions.
3 changes: 2 additions & 1 deletion libs/ibm/langchain_ibm/__init__.py
@@ -1,5 +1,6 @@
from langchain_ibm.chat_models import ChatWatsonx
from langchain_ibm.embeddings import WatsonxEmbeddings
from langchain_ibm.llms import WatsonxLLM
from langchain_ibm.rerank import WatsonxRerank

__all__ = ["WatsonxLLM", "WatsonxEmbeddings", "ChatWatsonx"]
__all__ = ["WatsonxLLM", "WatsonxEmbeddings", "ChatWatsonx", "WatsonxRerank"]
25 changes: 4 additions & 21 deletions libs/ibm/langchain_ibm/chat_models.py
@@ -24,7 +24,6 @@
from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models import ModelInference # type: ignore
from ibm_watsonx_ai.foundation_models.schema import ( # type: ignore
BaseSchema,
TextChatParameters,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -74,7 +73,7 @@
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_ibm.utils import check_for_attribute
from langchain_ibm.utils import check_for_attribute, extract_params

logger = logging.getLogger(__name__)

@@ -668,32 +667,16 @@ def _stream(
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]], **kwargs: Any
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = (
{
**(
self.params.to_dict()
if isinstance(self.params, BaseSchema)
else self.params
)
}
if self.params
else {}
)
params = params | {
**(
kwargs.get("params", {}).to_dict()
if isinstance(kwargs.get("params", {}), BaseSchema)
else kwargs.get("params", {})
)
}
params = extract_params(kwargs, self.params)

if stop is not None:
if params and "stop_sequences" in params:
raise ValueError(
"`stop_sequences` found in both the input and default params."
)
params = (params or {}) | {"stop_sequences": stop}
message_dicts = [_convert_message_to_dict(m, self.model_id) for m in messages]
return message_dicts, params
return message_dicts, params or {}

def _create_chat_result(
self, response: dict, generation_info: Optional[Dict] = None
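The inline merging logic in `_create_message_dicts` is replaced by the shared `extract_params` helper (added to `utils.py` below), so a `params` value supplied at call time now takes precedence over the `params` set on the instance. A minimal sketch of how this looks from the caller's side; the model ID and project ID are placeholders, and it assumes `WATSONX_APIKEY` is set and that call-time keyword arguments reach `_generate` as in standard LangChain chat models:

```python
from langchain_ibm import ChatWatsonx

# Placeholder identifiers; the API key is read from the WATSONX_APIKEY env var.
chat = ChatWatsonx(
    model_id="ibm/granite-13b-chat-v2",
    url="https://us-south.ml.cloud.ibm.com",
    project_id="<your-project-id>",
    params={"max_tokens": 200},  # instance default; a TextChatParameters schema also works
)

# A call-time `params` dict is used instead of the instance default for this
# request, because extract_params prefers kwargs["params"] over self.params.
answer = chat.invoke(
    "Summarize watsonx.ai in one sentence.",
    params={"max_tokens": 50, "temperature": 0},
)
```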
19 changes: 10 additions & 9 deletions libs/ibm/langchain_ibm/embeddings.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import logging
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models.embeddings import Embeddings # type: ignore
@@ -10,7 +8,7 @@
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_ibm.utils import check_for_attribute
from langchain_ibm.utils import check_for_attribute, extract_params

logger = logging.getLogger(__name__)

@@ -63,7 +61,7 @@ class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
version: Optional[SecretStr] = None
"""Version of the CPD instance."""

params: Optional[dict] = None
params: Optional[Dict] = None
"""Model parameters to use during request generation."""

verify: Union[str, bool, None] = None
@@ -151,10 +149,13 @@ def validate_environment(self) -> Self:

return self

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
"""Embed search docs."""
return self.watsonx_embed.embed_documents(texts=texts)
params = extract_params(kwargs, self.params)
return self.watsonx_embed.embed_documents(
texts=texts, **(kwargs | {"params": params})
)

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str, **kwargs: Any) -> List[float]:
"""Embed query text."""
return self.embed_documents([text])[0]
return self.embed_documents([text], **kwargs)[0]
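`embed_documents` and `embed_query` now accept `**kwargs`, so request parameters can be supplied per call as well as at construction time; the resolved `params` dict is forwarded to the underlying `ibm_watsonx_ai` embeddings client. A hedged usage sketch with placeholder identifiers (the parameter name shown is illustrative):

```python
from langchain_ibm import WatsonxEmbeddings

# Placeholder identifiers; the API key is read from the WATSONX_APIKEY env var.
embeddings = WatsonxEmbeddings(
    model_id="ibm/slate-125m-english-rtrvr",
    url="https://us-south.ml.cloud.ibm.com",
    project_id="<your-project-id>",
    params={"truncate_input_tokens": 128},  # instance-level default
)

# A per-call `params` value takes precedence over the instance default.
vectors = embeddings.embed_documents(
    ["watsonx.ai is IBM's enterprise AI studio."],
    params={"truncate_input_tokens": 512},
)
query_vector = embeddings.embed_query("What is watsonx.ai?")
```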
11 changes: 4 additions & 7 deletions libs/ibm/langchain_ibm/llms.py
@@ -16,7 +16,7 @@
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_ibm.utils import check_for_attribute
from langchain_ibm.utils import check_for_attribute, extract_params

logger = logging.getLogger(__name__)
textgen_valid_params = [
@@ -305,12 +305,9 @@ def _override_chat_params(
def _get_chat_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
params = (
{**self.params, **kwargs.pop("params", {})}
if self.params
else kwargs.pop("params", {})
)
params, kwargs = self._override_chat_params(params, **kwargs)
params = extract_params(kwargs, self.params)

params, kwargs = self._override_chat_params(params or {}, **kwargs)
if stop is not None:
if params and "stop_sequences" in params:
raise ValueError(
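`_get_chat_params` in `llms.py` switches to the same helper, and the `stop` handling after it is unchanged: a `stop` argument is folded into `params` as `stop_sequences`, and supplying stop sequences both ways is rejected. A short sketch of that behaviour, under the same placeholder and credential assumptions as above:

```python
from langchain_ibm import WatsonxLLM

llm = WatsonxLLM(
    model_id="ibm/granite-13b-instruct-v2",
    url="https://us-south.ml.cloud.ibm.com",
    project_id="<your-project-id>",
)

# `stop` is translated into params["stop_sequences"] before the request.
text = llm.invoke("List three IBM products:", stop=["\n\n"])

# Passing stop sequences both via `stop` and via `params` raises ValueError:
# "`stop_sequences` found in both the input and default params."
try:
    llm.invoke("Hello", stop=["\n"], params={"stop_sequences": ["\n"]})
except ValueError as err:
    print(err)
```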
232 changes: 232 additions & 0 deletions libs/ibm/langchain_ibm/rerank.py
@@ -0,0 +1,232 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union

from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models import Rerank # type: ignore
from ibm_watsonx_ai.foundation_models.schema import ( # type: ignore
RerankParameters,
)
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils.utils import secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_ibm.utils import check_for_attribute, extract_params


class WatsonxRerank(BaseDocumentCompressor):
"""Document compressor that uses `watsonx Rerank API`."""

model_id: str
"""Type of model to use."""

project_id: Optional[str] = None
"""ID of the Watson Studio project."""

space_id: Optional[str] = None
"""ID of the Watson Studio space."""

url: SecretStr = Field(
alias="url", default_factory=secret_from_env("WATSONX_URL", default=None)
)
"""URL to the Watson Machine Learning or CPD instance."""

apikey: Optional[SecretStr] = Field(
alias="apikey", default_factory=secret_from_env("WATSONX_APIKEY", default=None)
)
"""API key to the Watson Machine Learning or CPD instance."""

token: Optional[SecretStr] = Field(
alias="token", default_factory=secret_from_env("WATSONX_TOKEN", default=None)
)
"""Token to the CPD instance."""

password: Optional[SecretStr] = Field(
alias="password",
default_factory=secret_from_env("WATSONX_PASSWORD", default=None),
)
"""Password to the CPD instance."""

username: Optional[SecretStr] = Field(
alias="username",
default_factory=secret_from_env("WATSONX_USERNAME", default=None),
)
"""Username to the CPD instance."""

instance_id: Optional[SecretStr] = Field(
alias="instance_id",
default_factory=secret_from_env("WATSONX_INSTANCE_ID", default=None),
)
"""Instance_id of the CPD instance."""

version: Optional[SecretStr] = None
"""Version of the CPD instance."""

params: Optional[Union[dict, RerankParameters]] = None
"""Model parameters to use during request generation."""

verify: Union[str, bool, None] = None
"""You can pass one of following as verify:
* the path to a CA_BUNDLE file
* the path of directory with certificates of trusted CAs
* True - default path to truststore will be taken
* False - no verification will be made"""

validate_model: bool = True
"""Model ID validation."""

streaming: bool = False
""" Whether to stream the results or not. """

watsonx_rerank: Rerank = Field(default=None, exclude=True) #: :meta private:

watsonx_client: Optional[APIClient] = Field(default=None, exclude=True)

model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
protected_namespaces=(),
)

@property
def lc_secrets(self) -> Dict[str, str]:
"""A map of constructor argument names to secret ids.
For example:
{
"url": "WATSONX_URL",
"apikey": "WATSONX_APIKEY",
"token": "WATSONX_TOKEN",
"password": "WATSONX_PASSWORD",
"username": "WATSONX_USERNAME",
"instance_id": "WATSONX_INSTANCE_ID",
}
"""
return {
"url": "WATSONX_URL",
"apikey": "WATSONX_APIKEY",
"token": "WATSONX_TOKEN",
"password": "WATSONX_PASSWORD",
"username": "WATSONX_USERNAME",
"instance_id": "WATSONX_INSTANCE_ID",
}

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that credentials and python package exists in environment."""
if isinstance(self.watsonx_client, APIClient):
watsonx_rerank = Rerank(
model_id=self.model_id,
params=self.params,
api_client=self.watsonx_client,
project_id=self.project_id,
space_id=self.space_id,
verify=self.verify,
)
self.watsonx_rerank = watsonx_rerank

else:
check_for_attribute(self.url, "url", "WATSONX_URL")

if "cloud.ibm.com" in self.url.get_secret_value():
check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
else:
if not self.token and not self.password and not self.apikey:
raise ValueError(
"Did not find 'token', 'password' or 'apikey',"
" please add an environment variable"
" `WATSONX_TOKEN`, 'WATSONX_PASSWORD' or 'WATSONX_APIKEY' "
"which contains it,"
" or pass 'token', 'password' or 'apikey'"
" as a named parameter."
)
elif self.token:
check_for_attribute(self.token, "token", "WATSONX_TOKEN")
elif self.password:
check_for_attribute(self.password, "password", "WATSONX_PASSWORD")
check_for_attribute(self.username, "username", "WATSONX_USERNAME")
elif self.apikey:
check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
check_for_attribute(self.username, "username", "WATSONX_USERNAME")

if not self.instance_id:
check_for_attribute(
self.instance_id, "instance_id", "WATSONX_INSTANCE_ID"
)

credentials = Credentials(
url=self.url.get_secret_value() if self.url else None,
api_key=self.apikey.get_secret_value() if self.apikey else None,
token=self.token.get_secret_value() if self.token else None,
password=self.password.get_secret_value() if self.password else None,
username=self.username.get_secret_value() if self.username else None,
instance_id=self.instance_id.get_secret_value()
if self.instance_id
else None,
version=self.version.get_secret_value() if self.version else None,
verify=self.verify,
)

watsonx_rerank = Rerank(
model_id=self.model_id,
credentials=credentials,
params=self.params,
project_id=self.project_id,
space_id=self.space_id,
verify=self.verify,
)
self.watsonx_rerank = watsonx_rerank

return self

def rerank(
self,
documents: Sequence[Union[str, Document, dict]],
query: str,
**kwargs: Any,
) -> List[Dict[str, Any]]:
if len(documents) == 0: # to avoid empty api call
return []
docs = [
doc.page_content if isinstance(doc, Document) else doc for doc in documents
]
params = extract_params(kwargs, self.params)

results = self.watsonx_rerank.generate(
query=query, inputs=docs, **(kwargs | {"params": params})
)
result_dicts = []
for res in results["results"]:
result_dicts.append(
{"index": res.get("index"), "relevance_score": res.get("score")}
)
return result_dicts

def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
**kwargs: Any,
) -> Sequence[Document]:
"""
Compress documents using watsonx's rerank API.
Args:
documents: A sequence of documents to compress.
query: The query to use for compressing the documents.
callbacks: Callbacks to run during the compression process.
Returns:
A sequence of compressed documents.
"""
compressed = []
for res in self.rerank(documents, query, **kwargs):
doc = documents[res["index"]]
doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
doc_copy.metadata["relevance_score"] = res["relevance_score"]
compressed.append(doc_copy)
return compressed
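The new `WatsonxRerank` compressor wraps `ibm_watsonx_ai.foundation_models.Rerank`: `rerank()` returns `(index, relevance_score)` pairs, and `compress_documents()` copies the selected documents and stores the score in `metadata["relevance_score"]`. A hedged end-to-end sketch; the model ID and project ID are placeholders, and a reachable watsonx.ai instance with `WATSONX_APIKEY` set is assumed:

```python
from langchain_core.documents import Document

from langchain_ibm import WatsonxRerank

reranker = WatsonxRerank(
    model_id="cross-encoder/ms-marco-minilm-l-12-v2",  # placeholder rerank model ID
    url="https://us-south.ml.cloud.ibm.com",
    project_id="<your-project-id>",
    params={"truncate_input_tokens": 512},  # optional; a RerankParameters schema also works
)

docs = [
    Document(page_content="watsonx.ai is IBM's enterprise AI studio."),
    Document(page_content="The Eiffel Tower is in Paris."),
]

# compress_documents returns copies of the input documents with a
# relevance_score added to each document's metadata, in the order
# the rerank API returned them.
ranked = reranker.compress_documents(docs, query="What is watsonx.ai?")
for doc in ranked:
    print(doc.metadata["relevance_score"], doc.page_content)
```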
20 changes: 20 additions & 0 deletions libs/ibm/langchain_ibm/utils.py
@@ -1,3 +1,6 @@
from typing import Any, Dict, Optional, Union

from ibm_watsonx_ai.foundation_models.schema import BaseSchema # type: ignore
from pydantic import SecretStr


@@ -8,3 +11,20 @@ def check_for_attribute(value: SecretStr | None, key: str, env_key: str) -> None
f" `{env_key}` which contains it, or pass"
f" `{key}` as a named parameter."
)


def extract_params(
kwargs: Dict[str, Any],
default_params: Optional[Union[BaseSchema, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
if kwargs.get("params") is not None:
params = kwargs.pop("params")
elif default_params is not None:
params = default_params
else:
params = None

if isinstance(params, BaseSchema):
params = params.to_dict()

return params or {}
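`extract_params` is now the single place where call-time and constructor-time parameters are reconciled: a `params` entry popped from `kwargs` wins outright, otherwise the instance default is used, and a `BaseSchema` instance is converted to a plain dict via `to_dict()`. A small offline sketch using plain dicts:

```python
from langchain_ibm.utils import extract_params

defaults = {"max_tokens": 200, "temperature": 0.7}

# Call-time params take precedence and are popped out of kwargs,
# so they are not forwarded to the client a second time.
kwargs = {"params": {"max_tokens": 50}, "concurrency_limit": 5}
assert extract_params(kwargs, defaults) == {"max_tokens": 50}
assert "params" not in kwargs

# With no call-time params the defaults are returned unchanged,
# and with neither an empty dict comes back.
assert extract_params({}, defaults) == defaults
assert extract_params({}) == {}
```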