Commit
add support for langgraph for mixtral
MateuszOssGit committed Jul 12, 2024
1 parent ab61442 commit 4281572
Showing 3 changed files with 139 additions and 60 deletions.
178 changes: 129 additions & 49 deletions libs/ibm/langchain_ibm/chat_models.py
@@ -43,6 +43,7 @@
     InvalidToolCall,
     SystemMessage,
     SystemMessageChunk,
+    ToolCall,
     ToolMessage,
     ToolMessageChunk,
     convert_to_messages,
@@ -67,14 +68,12 @@
 logger = logging.getLogger(__name__)
 
 
-def _convert_dict_to_message(
-    _dict: Mapping[str, Any], timestamp_id: str
-) -> BaseMessage:
+def _convert_dict_to_message(_dict: Mapping[str, Any], call_id: str) -> BaseMessage:
     """Convert a dictionary to a LangChain message.
     Args:
         _dict: The dictionary.
-        timestamp_id: timestamp id
+        call_id: call id
     Returns:
         The LangChain message.
@@ -106,7 +105,7 @@ def _convert_dict_to_message(
             parsed = {
                 "name": json_parts["function"]["name"] or "",
                 "args": json_parts["function"]["arguments"] or {},
-                "id": timestamp_id,
+                "id": call_id,
             }
             tool_calls.append(parsed)
 
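For illustration, a minimal sketch of the tool-call entry this loop assembles; the json_parts payload shape is assumed from this diff, not taken from watsonx.ai documentation:

    # Hypothetical parsed model output (shape assumed from the diff above).
    json_parts = {"function": {"name": "add", "arguments": {"a": 2, "b": 2}}}
    parsed = {
        "name": json_parts["function"]["name"] or "",
        "args": json_parts["function"]["arguments"] or {},
        "id": "1720785600.0",  # call_id derived from the response timestamp
    }
    # Entries like this populate AIMessage(tool_calls=[...]) downstream.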
@@ -124,6 +123,50 @@
         )
 
 
+def _format_message_content(content: Any) -> Any:
+    """Format message content."""
+    if content and isinstance(content, list):
+        # Remove unexpected block types
+        formatted_content = []
+        for block in content:
+            if (
+                isinstance(block, dict)
+                and "type" in block
+                and block["type"] == "tool_use"
+            ):
+                continue
+            else:
+                formatted_content.append(block)
+    else:
+        formatted_content = content
+
+    return formatted_content
+
+
+def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
+    return {
+        "type": "function",
+        "id": tool_call["id"],
+        "function": {
+            "name": tool_call["name"],
+            "arguments": json.dumps(tool_call["args"]),
+        },
+    }
+
+
+def _lc_invalid_tool_call_to_openai_tool_call(
+    invalid_tool_call: InvalidToolCall,
+) -> dict:
+    return {
+        "type": "function",
+        "id": invalid_tool_call["id"],
+        "function": {
+            "name": invalid_tool_call["name"],
+            "arguments": invalid_tool_call["args"],
+        },
+    }
+
+
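A short sketch of what these new helpers do, assuming ToolCall is the usual langchain_core.messages TypedDict:

    # _format_message_content drops "tool_use" blocks from list content.
    content = [
        {"type": "text", "text": "Calling a tool"},
        {"type": "tool_use", "id": "abc", "name": "add", "input": {}},
    ]
    assert _format_message_content(content) == [
        {"type": "text", "text": "Calling a tool"}
    ]

    # _lc_tool_call_to_openai_tool_call serializes args into a JSON string.
    tc = {"name": "add", "args": {"a": 2, "b": 2}, "id": "abc"}
    assert _lc_tool_call_to_openai_tool_call(tc) == {
        "type": "function",
        "id": "abc",
        "function": {"name": "add", "arguments": '{"a": 2, "b": 2}'},
    }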
 def _convert_message_to_dict(message: BaseMessage) -> dict:
     """Convert a LangChain message to a dictionary.
@@ -133,50 +176,66 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     Returns:
         The dictionary.
     """
-    message_dict: Dict[str, Any]
+    message_dict: Dict[str, Any] = {"content": _format_message_content(message.content)}
+    if (name := message.name or message.additional_kwargs.get("name")) is not None:
+        message_dict["name"] = name
+
+    # populate role and additional message data
     if isinstance(message, ChatMessage):
-        message_dict = {"role": message.role, "content": message.content}
+        message_dict["role"] = message.role
     elif isinstance(message, HumanMessage):
-        message_dict = {"role": "user", "content": message.content}
+        message_dict["role"] = "user"
     elif isinstance(message, AIMessage):
-        message_dict = {"role": "assistant", "content": message.content}
+        message_dict["role"] = "assistant"
         if "function_call" in message.additional_kwargs:
             message_dict["function_call"] = message.additional_kwargs["function_call"]
             # If function call only, content is None not empty string
             if message_dict["content"] == "":
                 message_dict["content"] = None
-        if "tool_calls" in message.additional_kwargs:
+        if message.tool_calls or message.invalid_tool_calls:
+            message_dict["tool_calls"] = [
+                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+            ] + [
+                _lc_invalid_tool_call_to_openai_tool_call(tc)
+                for tc in message.invalid_tool_calls
+            ]
+        elif "tool_calls" in message.additional_kwargs:
             message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
-            # If tool calls only, content is None not empty string
-            if message_dict["content"] == "":
-                message_dict["content"] = None
+            tool_call_supported_props = {"id", "type", "function"}
+            message_dict["tool_calls"] = [
+                {k: v for k, v in tool_call.items() if k in tool_call_supported_props}
+                for tool_call in message_dict["tool_calls"]
+            ]
+        else:
+            pass
+        # If tool calls present, content null value should be None not empty string.
+        if "function_call" in message_dict or "tool_calls" in message_dict:
+            message_dict["content"] = message_dict["content"] or ""
+            message_dict["tool_calls"][0]["name"] = message_dict["tool_calls"][0][
+                "function"
+            ]["name"]
+            message_dict["tool_calls"][0]["args"] = json.loads(
+                message_dict["tool_calls"][0]["function"]["arguments"]
+            )
+
     elif isinstance(message, SystemMessage):
-        message_dict = {"role": "system", "content": message.content}
+        message_dict["role"] = "system"
     elif isinstance(message, FunctionMessage):
-        message_dict = {
-            "role": "function",
-            "content": message.content,
-            "name": message.name,
-        }
+        message_dict["role"] = "function"
     elif isinstance(message, ToolMessage):
-        message_dict = {
-            "role": "tool",
-            "content": message.content,
-            "tool_call_id": "None",
-        }
+        message_dict["role"] = "tool"
+        message_dict["tool_call_id"] = message.tool_call_id
+
+        supported_props = {"content", "role", "tool_call_id"}
+        message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
     else:
         raise TypeError(f"Got unknown type {message}")
-    if "name" in message.additional_kwargs:
-        message_dict["name"] = message.additional_kwargs["name"]
     return message_dict
 
 
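Taken together, an AIMessage carrying a tool call now converts roughly like this (a sketch; the extra name/args keys copied onto the first tool call are specific to this Mixtral path):

    msg = AIMessage(
        content="",
        tool_calls=[{"name": "add", "args": {"a": 2, "b": 2}, "id": "1"}],
    )
    d = _convert_message_to_dict(msg)
    # d["role"] == "assistant" and d["content"] == ""
    # d["tool_calls"][0]["function"] == {"name": "add", "arguments": '{"a": 2, "b": 2}'}
    # d["tool_calls"][0]["name"] == "add"; d["tool_calls"][0]["args"] == {"a": 2, "b": 2}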
 def _convert_delta_to_message_chunk(
     _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
-    id_ = _dict.get("id")
+    id_ = "sample_id"
     role = cast(str, _dict.get("role"))
-    content = cast(str, _dict.get("content") or "")
+    content = cast(str, _dict.get("generated_text") or "")
     additional_kwargs: Dict = {}
     if _dict.get("function_call"):
         function_call = dict(_dict["function_call"])
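The chunk converter now reads generated_text (the watsonx.ai field) instead of content, and stamps a fixed "sample_id" placeholder rather than an id from the payload. A sketch of a delta it would accept, with the payload shape assumed from this diff:

    delta = {"role": "assistant", "generated_text": "Final Answer: 4"}
    chunk = _convert_delta_to_message_chunk(delta, AIMessageChunk)
    # chunk.content == "Final Answer: 4"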
@@ -446,7 +505,17 @@ def _generate(
             return generate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
-        chat_prompt = self._create_chat_prompt(message_dicts)
+        if message_dicts[-1].get("role") == "tool":
+            chat_prompt = (
+                "User: Please summarize given sentences into "
+                "JSON containing Final Answer: '"
+            )
+            for message in message_dicts:
+                if message["content"]:
+                    chat_prompt += message["content"] + "\n"
+            chat_prompt += "'"
+        else:
+            chat_prompt = self._create_chat_prompt(message_dicts)
 
         tools = kwargs.get("tools")
 
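So when the conversation ends with a tool result, the chat template is bypassed and a flat summarization prompt is built instead. For two messages whose contents are "What is 2 + 2?" and "4", the resulting prompt reads:

    User: Please summarize given sentences into JSON containing Final Answer: 'What is 2 + 2?
    4
    '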
@@ -495,15 +564,16 @@ def _generate(
             Remember to end your response with '</endoftext>'
-            {chat_prompt[:-5]}
+            {chat_prompt}
             (reminder to respond in a JSON blob no matter what and use tools only if necessary)"""
 
-            if "tools" in kwargs:
-                del kwargs["tools"]
-            if "tool_choice" in kwargs:
-                del kwargs["tool_choice"]
             params = params | {"stop_sequences": ["</endoftext>"]}
 
+            if "tools" in kwargs:
+                del kwargs["tools"]
+            if "tool_choice" in kwargs:
+                del kwargs["tool_choice"]
+
         response = self.watsonx_model.generate(
             prompt=chat_prompt, **(kwargs | {"params": params})
         )
@@ -517,7 +587,17 @@ def _stream(
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
-        chat_prompt = self._create_chat_prompt(message_dicts)
+        if message_dicts[-1].get("role") == "tool":
+            chat_prompt = (
+                "User: Please summarize given sentences into JSON "
+                "containing Final Answer: '"
+            )
+            for message in message_dicts:
+                if message["content"]:
+                    chat_prompt += message["content"] + "\n"
+            chat_prompt += "'"
+        else:
+            chat_prompt = self._create_chat_prompt(message_dicts)
 
         tools = kwargs.get("tools")
 
@@ -569,11 +649,13 @@ def _stream(
             {chat_prompt[:-5]}
             (reminder to respond in a JSON blob no matter what and use tools only if necessary)"""
 
-            if "tools" in kwargs:
-                del kwargs["tools"]
-            if "tool_choice" in kwargs:
-                del kwargs["tool_choice"]
             params = params | {"stop_sequences": ["</endoftext>"]}
+            if "tools" in kwargs:
+                del kwargs["tools"]
+            if "tool_choice" in kwargs:
+                del kwargs["tool_choice"]
+
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
 
         for chunk in self.watsonx_model.generate_text_stream(
             prompt=chat_prompt, raw_response=True, **(kwargs | {"params": params})
@@ -584,17 +666,15 @@
                 continue
             choice = chunk["results"][0]
 
-            chunk = AIMessageChunk(
-                content=choice["generated_text"],
-            )
+            message_chunk = _convert_delta_to_message_chunk(choice, default_chunk_class)
             generation_info = {}
             if finish_reason := choice.get("stop_reason"):
                 generation_info["finish_reason"] = finish_reason
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
             chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
+                message=message_chunk, generation_info=generation_info or None
             )
             if run_manager:
                 run_manager.on_llm_new_token(
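The streaming loop now funnels each raw result through _convert_delta_to_message_chunk instead of building an AIMessageChunk inline. One iteration, with the raw chunk shape assumed from this diff:

    raw = {"results": [{"generated_text": "4", "stop_reason": "eos_token"}]}
    choice = raw["results"][0]
    message_chunk = _convert_delta_to_message_chunk(choice, AIMessageChunk)
    chunk = ChatGenerationChunk(
        message=message_chunk,
        generation_info={"finish_reason": "eos_token"},
    )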
@@ -658,17 +738,17 @@ def _create_chat_result(self, response: Union[dict]) -> ChatResult:
         generations = []
         sum_of_total_generated_tokens = 0
         sum_of_total_input_tokens = 0
-        timestamp_id = ""
+        call_id = ""
         date_string = response.get("created_at")
         if date_string:
             date_object = datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ")
-            timestamp_id = str(date_object.timestamp())
+            call_id = str(date_object.timestamp())
 
         if response.get("error"):
             raise ValueError(response.get("error"))
 
         for res in response["results"]:
-            message = _convert_dict_to_message(res, timestamp_id)
+            message = _convert_dict_to_message(res, call_id)
             generation_info = dict(finish_reason=res.get("stop_reason"))
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
19 changes: 9 additions & 10 deletions libs/ibm/poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion libs/ibm/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-ibm"
-version = "0.1.9rc0"
+version = "0.1.10"
 description = "An integration package connecting IBM watsonx.ai and LangChain"
 authors = ["IBM"]
 readme = "README.md"

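End to end, these changes are what let a Mixtral model on watsonx.ai drive a LangGraph-style tool loop. A rough usage sketch; the model id, credentials, and the @tool decorator usage are illustrative, not part of this commit:

    from langchain_core.tools import tool
    from langchain_ibm import ChatWatsonx

    @tool
    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    chat = ChatWatsonx(
        model_id="mistralai/mixtral-8x7b-instruct-v01",  # illustrative model id
        url="https://us-south.ml.cloud.ibm.com",
        project_id="YOUR_PROJECT_ID",
    )
    chat_with_tools = chat.bind_tools([add])
    ai_msg = chat_with_tools.invoke("What is 2 + 2? Use the add tool.")
    # ai_msg.tool_calls should contain {"name": "add", "args": {...}, "id": ...},
    # which a LangGraph tool node can execute and return as a ToolMessage.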