patch: Added support for passing params as keyword arg (#38)
* Update for passing params as keyword arg
MateuszOssGit authored Nov 13, 2024
1 parent e2f7f28 commit 9b5b9dd
Showing 5 changed files with 544 additions and 132 deletions.
114 changes: 108 additions & 6 deletions libs/ibm/langchain_ibm/chat_models.py
@@ -24,6 +24,7 @@
from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models import ModelInference # type: ignore
from ibm_watsonx_ai.foundation_models.schema import ( # type: ignore
    BaseSchema,
    TextChatParameters,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -64,6 +65,7 @@
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.function_calling import (
    convert_to_openai_function,
    convert_to_openai_tool,
@@ -73,7 +75,11 @@
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

-from langchain_ibm.utils import check_for_attribute, extract_params
+from langchain_ibm.utils import (
+    check_duplicate_chat_params,
+    check_for_attribute,
+    extract_chat_params,
+)

logger = logging.getLogger(__name__)

@@ -460,7 +466,54 @@ class ChatWatsonx(BaseChatModel):
"""Version of the CPD instance."""

params: Optional[Union[dict, TextChatParameters]] = None
"""Model parameters to use during request generation."""
"""Model parameters to use during request generation.
Note:
`ValueError` is raised if the same Chat generation parameter is provided
within the params attribute and as keyword argument."""

    frequency_penalty: Optional[float] = None
    """Positive values penalize new tokens based on their existing frequency in the
    text so far, decreasing the model's likelihood to repeat the same line verbatim."""

    logprobs: Optional[bool] = None
    """Whether to return log probabilities of the output tokens or not.
    If true, returns the log probabilities of each output token returned
    in the content of message."""

    top_logprobs: Optional[int] = None
    """An integer specifying the number of most likely tokens to return at each
    token position, each with an associated log probability. The option logprobs
    must be set to true if this parameter is used."""

    max_tokens: Optional[int] = None
    """The maximum number of tokens that can be generated in the chat completion.
    The total length of input tokens and generated tokens is limited by the
    model's context length."""

    n: Optional[int] = None
    """How many chat completion choices to generate for each input message.
    Note that you will be charged based on the number of generated tokens across
    all of the choices. Keep n as 1 to minimize costs."""

    presence_penalty: Optional[float] = None
    """Positive values penalize new tokens based on whether they appear in the
    text so far, increasing the model's likelihood to talk about new topics."""

    temperature: Optional[float] = None
    """What sampling temperature to use. Higher values like 0.8 will make the
    output more random, while lower values like 0.2 will make it more focused
    and deterministic.
    We generally recommend altering this or top_p but not both."""

    top_p: Optional[float] = None
    """An alternative to sampling with temperature, called nucleus sampling,
    where the model considers the results of the tokens with top_p probability
    mass. So 0.1 means only the tokens comprising the top 10% probability mass
    are considered.
    We generally recommend altering this or temperature but not both."""

    verify: Union[str, bool, None] = None
    """You can pass one of following as verify:
@@ -473,7 +526,7 @@ class ChatWatsonx(BaseChatModel):
"""Model ID validation."""

streaming: bool = False
""" Whether to stream the results or not. """
"""Whether to stream the results or not."""

watsonx_model: ModelInference = Field(default=None, exclude=True) #: :meta private:

@@ -526,6 +579,30 @@ def lc_secrets(self) -> Dict[str, str]:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that credentials and python package exists in environment."""
self.params = self.params or {}

if isinstance(self.params, BaseSchema):
self.params = self.params.to_dict()

check_duplicate_chat_params(self.params, self.__dict__)

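        # Any overlap between `params` and the individual keyword fields was
        # rejected above, so folding the non-None field values into `params`
        # cannot silently overwrite an existing entry.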
        self.params.update(
            {
                k: v
                for k, v in {
                    "frequency_penalty": self.frequency_penalty,
                    "logprobs": self.logprobs,
                    "top_logprobs": self.top_logprobs,
                    "max_tokens": self.max_tokens,
                    "n": self.n,
                    "presence_penalty": self.presence_penalty,
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                }.items()
                if v is not None
            }
        )

        if isinstance(self.watsonx_client, APIClient):
            watsonx_model = ModelInference(
                model_id=self.model_id,
@@ -608,9 +685,10 @@ def _generate(
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
        updated_params = self._merge_params(params, kwargs)

        response = self.watsonx_model.chat(
-            messages=message_dicts, **(kwargs | {"params": params})
+            messages=message_dicts, **(kwargs | {"params": updated_params})
        )
        return self._create_chat_result(response)

@@ -622,6 +700,7 @@ def _stream(
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
        updated_params = self._merge_params(params, kwargs)

        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        base_generation_info: dict = {}
@@ -630,7 +709,7 @@ def _stream(
        is_first_tool_chunk = True

        for chunk in self.watsonx_model.chat_stream(
-            messages=message_dicts, **(kwargs | {"params": params})
+            messages=message_dicts, **(kwargs | {"params": updated_params})
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
@@ -664,10 +743,33 @@

            yield generation_chunk

    @staticmethod
    def _merge_params(params: dict, kwargs: dict) -> dict:
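        # Pop the recognized chat-generation kwargs out of the call kwargs so
        # they are carried inside `params` rather than as top-level chat() arguments.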
        param_updates = {}
        for k in [
            "frequency_penalty",
            "logprobs",
            "top_logprobs",
            "max_tokens",
            "n",
            "presence_penalty",
            "temperature",
            "top_p",
        ]:
            if kwargs.get(k) is not None:
                param_updates[k] = kwargs.pop(k)

        if kwargs.get("params"):
            merged_params = merge_dicts(params, param_updates)
        else:
            merged_params = params | param_updates

        return merged_params

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]], **kwargs: Any
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
-        params = extract_params(kwargs, self.params)
+        params = extract_chat_params(kwargs, self.params)

        if stop is not None:
            if params and "stop_sequences" in params:
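Taken together, the chat_models.py changes let chat generation parameters be supplied either bundled in `params` or as individual keyword arguments, with duplicates rejected. A minimal usage sketch of the new surface (the model ID, URL, and project ID below are placeholders, and a watsonx.ai API key is assumed to be available, e.g. via the `WATSONX_APIKEY` environment variable):

```python
from langchain_ibm import ChatWatsonx

# New in this commit: chat parameters as plain keyword arguments.
chat = ChatWatsonx(
    model_id="ibm/granite-13b-chat-v2",       # placeholder model ID
    url="https://us-south.ml.cloud.ibm.com",  # placeholder endpoint
    project_id="PROJECT_ID",                  # placeholder project
    temperature=0.2,
    max_tokens=200,
)

# Still supported: the same parameters bundled in the `params` dict.
chat = ChatWatsonx(
    model_id="ibm/granite-13b-chat-v2",
    url="https://us-south.ml.cloud.ibm.com",
    project_id="PROJECT_ID",
    params={"temperature": 0.2, "max_tokens": 200},
)

# Supplying the same parameter both ways now raises at validation time.
try:
    ChatWatsonx(
        model_id="ibm/granite-13b-chat-v2",
        url="https://us-south.ml.cloud.ibm.com",
        project_id="PROJECT_ID",
        temperature=0.2,
        params={"temperature": 0.9},
    )
except ValueError as err:
    print(err)  # Duplicate parameters found in params and keyword arguments: ['temperature']
```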
29 changes: 29 additions & 0 deletions libs/ibm/langchain_ibm/utils.py
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Any, Dict, Optional, Union

from ibm_watsonx_ai.foundation_models.schema import BaseSchema # type: ignore
@@ -28,3 +29,31 @@ def extract_params(
        params = params.to_dict()

    return params or {}


def extract_chat_params(
    kwargs: Dict[str, Any],
    default_params: Optional[Union[BaseSchema, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    if kwargs.get("params") is not None:
        params = kwargs.pop("params")
        check_duplicate_chat_params(params, kwargs)
    elif default_params is not None:
        params = deepcopy(default_params)
    else:
        params = None

    if isinstance(params, BaseSchema):
        params = params.to_dict()

    return params or {}


def check_duplicate_chat_params(params: dict, kwargs: dict) -> None:
    duplicate_keys = {k for k, v in kwargs.items() if v is not None and k in params}

    if duplicate_keys:
        raise ValueError(
            f"Duplicate parameters found in params and keyword arguments: "
            f"{list(duplicate_keys)}"
        )
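A brief sketch of how the two helpers interact (plain dicts, hypothetical values):

```python
# An explicit `params` kwarg wins over `default_params`, and the remaining
# kwargs are screened against it for duplicates.
kwargs = {"params": {"temperature": 0.2}, "max_tokens": 100}
params = extract_chat_params(kwargs, default_params={"temperature": 0.7})
# -> {"temperature": 0.2}; "max_tokens" stays in kwargs for the caller.

# Overlapping keys with non-None values raise:
check_duplicate_chat_params({"temperature": 0.2}, {"temperature": 0.9})
# ValueError: Duplicate parameters found in params and keyword arguments: ['temperature']
```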