community: Add tool_calls support for ChatTongyi Model, allowing it to be used as the LLM of the latest tool calling agent #21366

Closed
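What this change enables, as a minimal sketch: ChatTongyi can now back the generic tool-calling agent. The agent wiring below is standard langchain usage rather than part of this diff; the `get_weather` tool and the `qwen-max` model name are illustrative assumptions, and a DASHSCOPE_API_KEY is assumed to be set in the environment.

from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool


@tool
def get_weather(location: str) -> str:
    """Get the current weather in a given location."""
    # Stub tool for illustration only.
    return f"It is always sunny in {location}."


prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant."),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)
llm = ChatTongyi(model="qwen-max")
agent = create_tool_calling_agent(llm, [get_weather], prompt)
executor = AgentExecutor(agent=agent, tools=[get_weather])
executor.invoke({"input": "What is the weather like in Hangzhou?"})

Under the hood, `bind_tools` (added below) converts the tool schemas to the OpenAI format that dashscope accepts and disables `incremental_output`, which dashscope does not support when tools are provided.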
258 changes: 205 additions & 53 deletions libs/community/langchain_community/chat_models/tongyi.py
@@ -10,8 +10,11 @@
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
@@ -20,6 +23,7 @@
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
@@ -32,7 +36,9 @@
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
)
from langchain_core.messages.utils import message_chunk_to_message
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
@@ -42,8 +48,11 @@
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from requests.exceptions import HTTPError
from tenacity import (
before_sleep_log,
@@ -75,29 +84,45 @@ def convert_dict_to_message(
else HumanMessage(content=content)
)
elif role == "assistant":
        additional_kwargs = {}
        if is_chunk:
            tool_call_chunks = []
            if "tool_calls" in _dict:
                additional_kwargs["tool_calls"] = _dict["tool_calls"]
                for idx, raw_tool_call in enumerate(_dict["tool_calls"]):
                    tool_call_chunks.append(
                        {
                            "name": raw_tool_call.get("function", {}).get("name"),
                            "args": raw_tool_call.get("function", {}).get("arguments"),
                            "id": raw_tool_call.get("id"),
                            "index": idx,
                        }
                    )
            return _AITongyiMessageChunk(
                content=content,
                additional_kwargs=additional_kwargs,
                tool_call_chunks=tool_call_chunks,
            )
        else:
            tool_calls = []
            invalid_tool_calls = []
            if "tool_calls" in _dict:
                additional_kwargs["tool_calls"] = _dict["tool_calls"]
                for raw_tool_call in _dict["tool_calls"]:
                    try:
                        tool_calls.append(
                            parse_tool_call(raw_tool_call, return_id=True)
                        )
                    except Exception as e:
                        invalid_tool_calls.append(
                            make_invalid_tool_call(raw_tool_call, str(e))
                        )
            return AIMessage(
                content=content,
                additional_kwargs=additional_kwargs,
                tool_calls=tool_calls,
                invalid_tool_calls=invalid_tool_calls,
            )
elif role == "system":
return (
SystemMessageChunk(content=content)
@@ -114,30 +139,33 @@

def convert_message_chunk_to_message(message_chunk: BaseMessageChunk) -> BaseMessage:
"""Convert a message chunk to a message."""
    if isinstance(message_chunk, _AITongyiMessageChunk):
        return message_chunk_to_message(
            cast(AIMessageChunk, message_chunk_to_message(message_chunk))
        )
    return message_chunk_to_message(message_chunk)


def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a message to a dict."""

message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"name": message.name or message.additional_kwargs.get("name"),
}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
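# A sketch of the tool-message mapping above, with hypothetical values:
#   ToolMessage(content="sunny", name="get_weather", tool_call_id="call_0")
#   -> {"role": "tool", "content": "sunny", "name": "get_weather"}
# (note that `tool_call_id` is not forwarded to the dashscope payload).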
@@ -157,6 +185,37 @@ def _create_retry_decorator(llm: ChatTongyi) -> Callable[[Any], Any]:
)


def _remove_prefix(text: str, prefix: str) -> str:
if prefix and text.startswith(prefix):
return text[len(prefix) :]
return text
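# e.g. _remove_prefix("Hello, Hangzhou", "Hello, ") -> "Hangzhou" (equivalent to
# `str.removeprefix` on Python >= 3.9); used by `_chat_generation_from_qwen_resp`
# to turn cumulative (non-incremental) streaming content back into per-chunk deltas.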


class _AITongyiMessageChunk(AIMessageChunk):
"""Message chunk from Tongyi LLM,
which handles the `tool_calls` stream appropriately.
"""

type: Literal["_AITongyiMessageChunk"] = "_AITongyiMessageChunk" # type: ignore[assignment] # noqa: E501

@classmethod
def get_lc_namespace(cls) -> List[str]:
return ["langchain_community", "chat_models"]

def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
super_add_result = super().__add__(other)
if isinstance(other, _AITongyiMessageChunk):
return self.__class__(
example=self.example,
content=super_add_result.content,
additional_kwargs=other.additional_kwargs,
tool_call_chunks=other.tool_call_chunks,
response_metadata=super_add_result.response_metadata,
id=super_add_result.id,
)
return super_add_result
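# A sketch of the merge semantics above, with hypothetical chunk values: when
# `incremental_output` is disabled, each streamed chunk carries the cumulative
# tool-call state, so `__add__` keeps only the latest `additional_kwargs` and
# `tool_call_chunks` instead of concatenating them the way the default
# `AIMessageChunk.__add__` does:
#
#   first = _AITongyiMessageChunk(
#       content="",
#       tool_call_chunks=[{"name": "GetWeather", "args": "", "id": "", "index": 0}],
#   )
#   second = _AITongyiMessageChunk(
#       content="",
#       tool_call_chunks=[
#           {"name": "GetWeather", "args": '{"location": "Hangzhou"}',
#            "id": "", "index": 0}
#       ],
#   )
#   assert (first + second).tool_call_chunks == second.tool_call_chunks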


class ChatTongyi(BaseChatModel):
"""Alibaba Tongyi Qwen chat models API.

@@ -227,17 +286,6 @@ def validate_environment(cls, values: Dict) -> Dict:

return values

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Tongyi Qwen API."""
return {
"model": self.model_name,
"top_p": self.top_p,
"api_key": cast(SecretStr, self.dashscope_api_key).get_secret_value(),
"result_format": "message",
**self.model_kwargs,
}

def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self)
@@ -373,14 +421,21 @@ def _stream(
params: Dict[str, Any] = self._invocation_params(
messages=messages, stop=stop, stream=True, **kwargs
)
incremental_output = params.get("incremental_output")
previous_resp: Any = None
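        # When `incremental_output` is disabled (e.g. when tools are bound),
        # each streamed response contains the cumulative output so far, so the
        # previous response is kept and later stripped off as a prefix.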
for stream_resp, is_last_chunk in generate_with_last_element_mark(
self.stream_completion_with_retry(**params)
):
chunk = ChatGenerationChunk(
**self._chat_generation_from_qwen_resp(
                    stream_resp,
previous_resp=previous_resp,
is_chunk=True,
is_last_chunk=is_last_chunk,
)
)
if not incremental_output:
previous_resp = stream_resp
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
@@ -395,33 +450,51 @@ async def _astream(
params: Dict[str, Any] = self._invocation_params(
messages=messages, stop=stop, stream=True, **kwargs
)
incremental_output = params.get("incremental_output")
previous_resp: Any = None
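        # As in `_stream`: without `incremental_output`, each response is
        # cumulative, so keep the previous one around for prefix removal.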
async for stream_resp, is_last_chunk in agenerate_with_last_element_mark(
self.astream_completion_with_retry(**params)
):
chunk = ChatGenerationChunk(
**self._chat_generation_from_qwen_resp(
                    stream_resp,
previous_resp=previous_resp,
is_chunk=True,
is_last_chunk=is_last_chunk,
)
)
if not incremental_output:
previous_resp = stream_resp
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk

def _invocation_params(
self, messages: List[BaseMessage], stop: Any, **kwargs: Any
) -> Dict[str, Any]:
        params = {
"model": self.model_name,
"top_p": self.top_p,
"api_key": cast(SecretStr, self.dashscope_api_key).get_secret_value(),
"result_format": "message",
**self.model_kwargs,
**kwargs,
}
if stop is not None:
params["stop"] = stop
if params.get("stream"):
params["incremental_output"] = True

        # the default value of `incremental_output` is `False` in the LLM API,
        # and it only takes effect when `stream` is `True`.
        # So, to prevent unexpected behavior,
        # we drop `incremental_output` when it is unnecessary.
if not params.get("stream") or not params.get("incremental_output"):
if "incremental_output" in params:
del params["incremental_output"]

message_dicts = [convert_message_to_dict(m) for m in messages]

# According to the docs, the last message should be a `user` message
if message_dicts[-1]["role"] != "user":
raise ValueError("Last message should be user message.")
        # the `system` message should always be unique,
        # and if present, it should be the first message
system_message_indices = [
i for i, m in enumerate(message_dicts) if m["role"] == "system"
]
@@ -435,23 +508,102 @@ def _invocation_params(
return params

def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        return llm_outputs[0] or {}

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.

Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to bind.

Example:
.. code-block:: python

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.chat_models.tongyi import ChatTongyi

class GetWeather(BaseModel):
'''Get the current weather in a given location'''

location: str = Field(
...,
description="The city and state, e.g. San Francisco, CA",
)


llm = ChatTongyi(model="qwen-max")
llm_with_tools = llm.bind_tools([GetWeather])
llm_with_tools.invoke("what is the weather like in HangZhou, China")

# -> AIMessage(
# content='',
# id='run-f3bb9ff7-fbf5-43d4-880c-f28a5391c307-0',
# tool_calls=[{
# 'name': 'GetWeather',
# 'args': {'location': 'Hangzhou, China'},
# 'id': ''
# }],
# response_metadata={
# 'model_name': 'qwen-max',
# 'finish_reason': 'tool_calls',
# ...
# }
# additional_kwargs={'tool_calls': [{...}]}
# )
"""

        # According to the dashscope documentation:
        # 1. the `tools` parameter has exactly the same format as OpenAI's,
        #    so we can use the `convert_to_openai_tool` function directly
        #    to convert the tools.
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
# 2. the `incremental_output` parameter is not supported
# when `tools` are provided.
return self.bind(tools=formatted_tools, incremental_output=False, **kwargs)

@staticmethod
def _chat_generation_from_qwen_resp(
        resp: Any,
previous_resp: Any = None,
is_chunk: bool = False,
is_last_chunk: bool = True,
) -> Dict[str, Any]:
choice = resp["output"]["choices"][0]
raw_message = choice["message"]

message_dict = {"role": raw_message["role"]}
        # if `previous_resp` is not None
        # (`incremental_output` should be `False` in this case),
        # we try to remove its content as the prefix of the current response's content
if previous_resp is not None:
previous_content = previous_resp["output"]["choices"][0]["message"][
"content"
]
message_dict["content"] = _remove_prefix(
raw_message["content"], prefix=previous_content
)
else:
message_dict["content"] = raw_message["content"]
if "tool_calls" in raw_message:
message_dict["tool_calls"] = raw_message["tool_calls"]

message = convert_dict_to_message(message_dict, is_chunk=is_chunk)

        # According to the response from dashscope,
        # each chunk's `generation_info` overwrites the previous one.
        # Besides, the `merge_dicts` method,
        # which is used to concatenate `generation_info` in `GenerationChunk`,
        # does not support merging int values.
        # Therefore, we adopt the `generation_info` of the last chunk
        # and discard the `generation_info` of the intermediate chunks.
choice = resp["output"]["choices"][0]
message = convert_dict_to_message(choice["message"], is_chunk=is_chunk)
if is_last_chunk:
return dict(
message=message,