feat: Simplify ChatWatsonx implementation, added base tests (#31)

MateuszOssGit authored Oct 14, 2024
1 parent 921dd85 commit 086cc32
Showing 4 changed files with 369 additions and 234 deletions.
121 changes: 96 additions & 25 deletions libs/ibm/langchain_ibm/chat_models.py
@@ -1,5 +1,6 @@
 """IBM watsonx.ai large language chat models wrapper."""

+import hashlib
 import json
 import logging
 from operator import itemgetter
@@ -158,11 +159,37 @@ def _format_message_content(content: Any) -> Any:
     return formatted_content


-def _convert_message_to_dict(message: BaseMessage) -> dict:
+def _base62_encode(num: int) -> str:
+    """Encodes a number in base62 and ensures result is of a specified length."""
+    base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+    if num == 0:
+        return base62[0]
+    arr = []
+    base = len(base62)
+    while num:
+        num, rem = divmod(num, base)
+        arr.append(base62[rem])
+    arr.reverse()
+    return "".join(arr)
+
+
+def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
+    """Convert a tool call ID to a Mistral-compatible format"""
+    hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
+    hash_int = int.from_bytes(hash_bytes, byteorder="big")
+    base62_str = _base62_encode(hash_int)
+    if len(base62_str) >= 9:
+        return base62_str[:9]
+    else:
+        return base62_str.rjust(9, "0")
+
+
+def _convert_message_to_dict(message: BaseMessage, model_id: str | None) -> dict:
     """Convert a LangChain message to a dictionary.

     Args:
         message: The LangChain message.
+        model_id: Type of model to use.

     Returns:
         The dictionary.
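
Note: the two helpers above exist to shrink or pad arbitrary tool call IDs into
exactly nine base62 characters. A minimal sanity check, as an illustration only
(it assumes the private helpers are importable from langchain_ibm.chat_models at
this commit; the input ID is made up):

    from langchain_ibm.chat_models import (
        _base62_encode,
        _convert_tool_call_id_to_mistral_compatible,
    )

    # SHA-256 produces a 256-bit integer, so its base62 encoding is ~43 chars;
    # the helper keeps the first 9 (right-padding with "0" only in degenerate cases).
    fixed_id = _convert_tool_call_id_to_mistral_compatible("call_abc")
    assert len(fixed_id) == 9
    assert _base62_encode(0) == "0" and _base62_encode(61) == "Z"
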
@@ -199,6 +226,22 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         # If tool calls present, content null value should be None not empty string.
         if "function_call" in message_dict or "tool_calls" in message_dict:
             message_dict["content"] = message_dict["content"] or None
+
+        # Workaround for "mistralai/mistral-large" model when id < 9
+        if model_id and model_id.startswith("mistralai"):
+            tool_calls = message_dict.get("tool_calls", [])
+            if (
+                isinstance(tool_calls, list)
+                and tool_calls
+                and isinstance(tool_calls[0], dict)
+            ):
+                tool_call_id = tool_calls[0].get("id", "")
+                if len(tool_call_id) < 9:
+                    tool_call_id = _convert_tool_call_id_to_mistral_compatible(
+                        tool_call_id
+                    )
+
+                message_dict["tool_calls"][0]["id"] = tool_call_id
     elif isinstance(message, SystemMessage):
         message_dict["role"] = "system"
     elif isinstance(message, FunctionMessage):
@@ -207,6 +250,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict["role"] = "tool"
         message_dict["tool_call_id"] = message.tool_call_id

+        # Workaround for "mistralai/mistral-large" model when tool_call_id < 9
+        if model_id and model_id.startswith("mistralai"):
+            tool_call_id = message_dict.get("tool_call_id", "")
+            if len(tool_call_id) < 9:
+                tool_call_id = _convert_tool_call_id_to_mistral_compatible(tool_call_id)
+
+            message_dict["tool_call_id"] = tool_call_id
+
         supported_props = {"content", "role", "tool_call_id"}
         message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
     else:
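
Note: Mistral's API expects tool call IDs of exactly nine alphanumeric
characters, which appears to be what both workarounds above compensate for. A
sketch of the effect on a ToolMessage (hypothetical message; assumes the private
_convert_message_to_dict is importable from langchain_ibm.chat_models):

    from langchain_core.messages import ToolMessage
    from langchain_ibm.chat_models import _convert_message_to_dict

    msg = ToolMessage(content="42", tool_call_id="abc123")  # only 6 characters
    d = _convert_message_to_dict(msg, "mistralai/mistral-large")
    assert d["role"] == "tool"
    assert len(d["tool_call_id"]) == 9  # hashed and re-encoded by the helper
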
@@ -218,7 +269,7 @@ def _convert_delta_to_message_chunk(
     _dict: Mapping[str, Any],
     default_class: Type[BaseMessageChunk],
     call_id: str,
-    finish_reason: str,
+    is_first_tool_chunk: bool,
 ) -> BaseMessageChunk:
     id_ = call_id
     role = cast(str, _dict.get("role"))
@@ -235,11 +286,9 @@ def _convert_delta_to_message_chunk(
         try:
             tool_call_chunks = [
                 tool_call_chunk(
-                    name=rtc["function"].get("name")
-                    if finish_reason is not None
-                    else None,
+                    name=rtc["function"].get("name") if is_first_tool_chunk else None,
                     args=rtc["function"].get("arguments"),
-                    id=call_id if finish_reason is not None else None,
+                    id=rtc.get("id") if is_first_tool_chunk else None,
                     index=rtc["index"],
                 )
                 for rtc in raw_tool_calls
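
Note: emitting name/id only on the first tool chunk matters because LangChain
concatenates string fields when merging tool_call_chunks across a stream;
repeating the name on every delta would accumulate into "get_weatherget_weather...".
An illustration with made-up deltas (langchain-core API):

    from langchain_core.messages import AIMessageChunk
    from langchain_core.messages.tool import tool_call_chunk

    first = AIMessageChunk(
        content="",
        tool_call_chunks=[
            tool_call_chunk(name="get_weather", args='{"city": ', id="id123456x", index=0)
        ],
    )
    rest = AIMessageChunk(
        content="",
        tool_call_chunks=[
            tool_call_chunk(name=None, args='"Paris"}', id=None, index=0)
        ],
    )
    merged = first + rest
    print(merged.tool_calls)
    # [{'name': 'get_weather', 'args': {'city': 'Paris'}, 'id': 'id123456x', ...}]
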
@@ -271,19 +320,17 @@ def _convert_delta_to_message_chunk(


 def _convert_chunk_to_generation_chunk(
-    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
+    chunk: dict,
+    default_chunk_class: Type,
+    base_generation_info: Optional[Dict],
+    is_first_chunk: bool,
+    is_first_tool_chunk: bool,
 ) -> Optional[ChatGenerationChunk]:
     token_usage = chunk.get("usage")
     choices = chunk.get("choices", [])

     usage_metadata: Optional[UsageMetadata] = (
-        UsageMetadata(
-            input_tokens=token_usage.get("prompt_tokens", 0),
-            output_tokens=token_usage.get("completion_tokens", 0),
-            total_tokens=token_usage.get("total_tokens", 0),
-        )
-        if token_usage
-        else None
+        _create_usage_metadata(token_usage, is_first_chunk) if token_usage else None
     )

     if len(choices) == 0:
@@ -298,7 +345,7 @@ def _convert_chunk_to_generation_chunk(
         return None

     message_chunk = _convert_delta_to_message_chunk(
-        choice["delta"], default_chunk_class, chunk["id"], choice["finish_reason"]
+        choice["delta"], default_chunk_class, chunk["id"], is_first_tool_chunk
     )
     generation_info = {**base_generation_info} if base_generation_info else {}

@@ -547,11 +594,9 @@ def _generate(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        should_stream = stream if stream is not None else self.streaming
-        if should_stream:
+        if self.streaming:
             stream_iter = self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             )
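
Note: with the per-call stream parameter removed, _generate now honors only the
constructor-level flag. A hypothetical usage sketch (credentials elided, so not
runnable as-is; the endpoint URL and project ID are placeholders):

    from langchain_ibm import ChatWatsonx

    chat = ChatWatsonx(
        model_id="mistralai/mistral-large",
        url="https://us-south.ml.cloud.ibm.com",  # example region endpoint
        project_id="<project-id>",                # placeholder, not a real ID
        streaming=True,  # _generate now consults only this flag
    )
    result = chat.invoke("Hello!")  # streams internally, then aggregates
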
@@ -577,6 +622,7 @@ def _stream(
         base_generation_info: dict = {}

         is_first_chunk = True
+        is_first_tool_chunk = True

         for chunk in self.watsonx_model.chat_stream(
             messages=message_dicts, **(kwargs | {"params": params})
@@ -587,6 +633,8 @@ def _stream(
                 chunk,
                 default_chunk_class,
                 base_generation_info if is_first_chunk else {},
+                is_first_chunk,
+                is_first_tool_chunk,
             )
             if generation_chunk is None:
                 continue
@@ -596,7 +644,19 @@ def _stream(
                 run_manager.on_llm_new_token(
                     generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                 )
+            if hasattr(generation_chunk.message, "tool_calls") and isinstance(
+                generation_chunk.message.tool_calls, list
+            ):
+                first_tool_call = (
+                    generation_chunk.message.tool_calls[0]
+                    if generation_chunk.message.tool_calls
+                    else None
+                )
+                if isinstance(first_tool_call, dict) and first_tool_call.get("name"):
+                    is_first_tool_chunk = False
+
             is_first_chunk = False

             yield generation_chunk
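
Note: the loop above flips is_first_tool_chunk only after a yielded chunk
actually carries a named tool call, so the first tool delta keeps its real
name/id and every later delta contributes arguments only. A standalone trace of
that flag logic (made-up deltas):

    deltas = [
        [],                          # plain content chunk
        [{"name": "get_weather"}],   # first tool chunk -> flag flips after use
        [{"name": None}],            # argument-only tool chunk
    ]
    is_first_tool_chunk = True
    seen = []
    for tool_calls in deltas:
        seen.append(is_first_tool_chunk)
        first_tool_call = tool_calls[0] if tool_calls else None
        if isinstance(first_tool_call, dict) and first_tool_call.get("name"):
            is_first_tool_chunk = False
    print(seen)  # [True, True, False]
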

@@ -626,7 +686,7 @@ def _create_message_dicts(
                 "`stop_sequences` found in both the input and default params."
             )
             params = (params or {}) | {"stop_sequences": stop}
-        message_dicts = [_convert_message_to_dict(m) for m in messages]
+        message_dicts = [_convert_message_to_dict(m, self.model_id) for m in messages]
         return message_dicts, params

@@ -643,11 +703,7 @@ def _create_chat_result(
             message = _convert_dict_to_message(res["message"], response["id"])

             if token_usage and isinstance(message, AIMessage):
-                message.usage_metadata = {
-                    "input_tokens": token_usage.get("prompt_tokens", 0),
-                    "output_tokens": token_usage.get("completion_tokens", 0),
-                    "total_tokens": token_usage.get("total_tokens", 0),
-                }
+                message.usage_metadata = _create_usage_metadata(token_usage, True)
             generation_info = generation_info or {}
             generation_info["finish_reason"] = (
                 res.get("finish_reason")
@@ -957,7 +1013,9 @@ class AnswerWithJustification(BaseModel):
                     "Received None."
                 )
             # specifying a tool.
-            llm = self.bind_tools([schema], tool_choice="auto")
+            tool_name = convert_to_openai_tool(schema)["function"]["name"]
+            tool_choice = {"type": "function", "function": {"name": tool_name}}
+            llm = self.bind_tools([schema], tool_choice=tool_choice)
             if is_pydantic_schema:
                 output_parser: OutputParserLike = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
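
Note: with_structured_output previously bound the schema with
tool_choice="auto", which allowed the model to skip calling the tool; the change
forces the schema's own tool instead. A sketch of what now gets bound
(hypothetical schema; convert_to_openai_tool comes from langchain-core):

    from langchain_core.utils.function_calling import convert_to_openai_tool
    from pydantic import BaseModel

    class AnswerWithJustification(BaseModel):
        answer: str
        justification: str

    tool_name = convert_to_openai_tool(AnswerWithJustification)["function"]["name"]
    tool_choice = {"type": "function", "function": {"name": tool_name}}
    print(tool_name)  # AnswerWithJustification
    # llm = chat.bind_tools([AnswerWithJustification], tool_choice=tool_choice)
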
@@ -1014,3 +1072,16 @@ def _lc_invalid_tool_call_to_watsonx_tool_call(
             "arguments": invalid_tool_call["args"],
         },
     }
+
+
+def _create_usage_metadata(
+    oai_token_usage: dict, is_first_chunk: bool
+) -> UsageMetadata:
+    input_tokens = oai_token_usage.get("prompt_tokens", 0) if is_first_chunk else 0
+    output_tokens = oai_token_usage.get("completion_tokens", 0)
+    total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
+    return UsageMetadata(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+    )
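
Note: the is_first_chunk flag in _create_usage_metadata guards against double
counting: streamed chunks can each repeat prompt_tokens, so prompt tokens are
credited only once and later chunks fall back to a computed total. A worked
example with made-up token counts (assumes the module-level helper is in scope):

    first = _create_usage_metadata(
        {"prompt_tokens": 12, "completion_tokens": 1, "total_tokens": 13}, True
    )
    later = _create_usage_metadata({"prompt_tokens": 12, "completion_tokens": 5}, False)
    # first: input_tokens=12, output_tokens=1, total_tokens=13
    # later: input_tokens=0 (prompt already counted), output_tokens=5, total_tokens=0+5=5
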