From b314e6dbc0ff8500faaeca84d1e8f83946729c60 Mon Sep 17 00:00:00 2001
From: liushuaikobe
Date: Sun, 5 May 2024 14:25:37 +0800
Subject: [PATCH] community[minor]: Add tool_calls support for ChatTongyi Model

---
 .../langchain_community/chat_models/tongyi.py | 258 ++++++++++++++----
 1 file changed, 205 insertions(+), 53 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/tongyi.py b/libs/community/langchain_community/chat_models/tongyi.py
index 943cace9733e9..6eecb8254307d 100644
--- a/libs/community/langchain_community/chat_models/tongyi.py
+++ b/libs/community/langchain_community/chat_models/tongyi.py
@@ -10,8 +10,11 @@
     Dict,
     Iterator,
     List,
+    Literal,
     Mapping,
     Optional,
+    Sequence,
+    Type,
     Union,
     cast,
 )
@@ -20,6 +23,7 @@
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
@@ -32,7 +36,9 @@
     HumanMessageChunk,
     SystemMessage,
     SystemMessageChunk,
+    ToolMessage,
 )
+from langchain_core.messages.utils import message_chunk_to_message
 from langchain_core.output_parsers.openai_tools import (
     make_invalid_tool_call,
     parse_tool_call,
@@ -42,8 +48,11 @@
     ChatGenerationChunk,
     ChatResult,
 )
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.utils.function_calling import convert_to_openai_tool
 from requests.exceptions import HTTPError
 from tenacity import (
     before_sleep_log,
@@ -75,29 +84,45 @@ def convert_dict_to_message(
             else HumanMessage(content=content)
         )
     elif role == "assistant":
-        tool_calls = []
-        invalid_tool_calls = []
-        if "tool_calls" in _dict:
-            additional_kwargs = {"tool_calls": _dict["tool_calls"]}
-            for raw_tool_call in _dict["tool_calls"]:
-                try:
-                    tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
-                except Exception as e:
-                    invalid_tool_calls.append(
-                        make_invalid_tool_call(raw_tool_call, str(e))
+        additional_kwargs = {}
+        if is_chunk:
+            tool_call_chunks = []
+            if "tool_calls" in _dict:
+                additional_kwargs["tool_calls"] = _dict["tool_calls"]
+                for idx, raw_tool_call in enumerate(_dict["tool_calls"]):
+                    tool_call_chunks.append(
+                        {
+                            "name": raw_tool_call.get("function", {}).get("name"),
+                            "args": raw_tool_call.get("function", {}).get("arguments"),
+                            "id": raw_tool_call.get("id"),
+                            "index": idx,
+                        }
                     )
+            return _AITongyiMessageChunk(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                tool_call_chunks=tool_call_chunks,
+            )
         else:
-            additional_kwargs = {}
-        return (
-            AIMessageChunk(content=content)
-            if is_chunk
-            else AIMessage(
+            tool_calls = []
+            invalid_tool_calls = []
+            if "tool_calls" in _dict:
+                additional_kwargs["tool_calls"] = _dict["tool_calls"]
+                for raw_tool_call in _dict["tool_calls"]:
+                    try:
+                        tool_calls.append(
+                            parse_tool_call(raw_tool_call, return_id=True)
+                        )
+                    except Exception as e:
+                        invalid_tool_calls.append(
+                            make_invalid_tool_call(raw_tool_call, str(e))
+                        )
+            return AIMessage(
                 content=content,
                 additional_kwargs=additional_kwargs,
                 tool_calls=tool_calls,
                 invalid_tool_calls=invalid_tool_calls,
             )
-        )
     elif role == "system":
         return (
             SystemMessageChunk(content=content)
@@ -114,21 +139,16 @@ def convert_message_chunk_to_message(message_chunk: BaseMessageChunk) -> BaseMessage:
     """Convert a message chunk to a message."""
-    if isinstance(message_chunk, HumanMessageChunk):
-        return HumanMessage(content=message_chunk.content)
-    elif isinstance(message_chunk, AIMessageChunk):
-        return AIMessage(content=message_chunk.content)
-    elif isinstance(message_chunk, SystemMessageChunk):
-        return SystemMessage(content=message_chunk.content)
-    elif isinstance(message_chunk, ChatMessageChunk):
-        return ChatMessage(role=message_chunk.role, content=message_chunk.content)
+    if isinstance(message_chunk, _AITongyiMessageChunk):
+        return message_chunk_to_message(
+            cast(AIMessageChunk, message_chunk_to_message(message_chunk))
+        )
     else:
-        raise TypeError(f"Got unknown type {message_chunk}")
+        return message_chunk_to_message(message_chunk)
 
 
 def convert_message_to_dict(message: BaseMessage) -> dict:
     """Convert a message to a dict."""
-    message_dict: Dict[str, Any]
     if isinstance(message, ChatMessage):
         message_dict = {"role": message.role, "content": message.content}
@@ -136,8 +156,16 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict = {"role": "user", "content": message.content}
     elif isinstance(message, AIMessage):
         message_dict = {"role": "assistant", "content": message.content}
+        if "tool_calls" in message.additional_kwargs:
+            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
     elif isinstance(message, SystemMessage):
         message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, ToolMessage):
+        message_dict = {
+            "role": "tool",
+            "content": message.content,
+            "name": message.name or message.additional_kwargs.get("name"),
+        }
     else:
         raise TypeError(f"Got unknown type {message}")
     return message_dict
@@ -157,6 +185,37 @@ def _create_retry_decorator(llm: ChatTongyi) -> Callable[[Any], Any]:
     )
 
 
+def _remove_prefix(text: str, prefix: str) -> str:
+    if prefix and text.startswith(prefix):
+        return text[len(prefix) :]
+    return text
+
+
+class _AITongyiMessageChunk(AIMessageChunk):
+    """Message chunk from the Tongyi LLM whose `__add__` keeps the latest
+    complete `tool_calls` snapshot instead of concatenating streamed fragments.
+    """
+
+    type: Literal["_AITongyiMessageChunk"] = "_AITongyiMessageChunk"  # type: ignore[assignment] # noqa: E501
+
+    @classmethod
+    def get_lc_namespace(cls) -> List[str]:
+        return ["langchain_community", "chat_models"]
+
+    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
+        super_add_result = super().__add__(other)
+        if isinstance(other, _AITongyiMessageChunk):
+            return self.__class__(
+                example=self.example,
+                content=super_add_result.content,
+                additional_kwargs=other.additional_kwargs,
+                tool_call_chunks=other.tool_call_chunks,
+                response_metadata=super_add_result.response_metadata,
+                id=super_add_result.id,
+            )
+        return super_add_result
+
+
 class ChatTongyi(BaseChatModel):
     """Alibaba Tongyi Qwen chat models API.
@@ -227,17 +286,6 @@ def validate_environment(cls, values: Dict) -> Dict:
         return values
 
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling Tongyi Qwen API."""
-        return {
-            "model": self.model_name,
-            "top_p": self.top_p,
-            "api_key": cast(SecretStr, self.dashscope_api_key).get_secret_value(),
-            "result_format": "message",
-            **self.model_kwargs,
-        }
-
     def completion_with_retry(self, **kwargs: Any) -> Any:
         """Use tenacity to retry the completion call."""
         retry_decorator = _create_retry_decorator(self)
@@ -373,14 +421,21 @@ def _stream(
         params: Dict[str, Any] = self._invocation_params(
             messages=messages, stop=stop, stream=True, **kwargs
         )
+        incremental_output = params.get("incremental_output")
+        previous_resp: Any = None
         for stream_resp, is_last_chunk in generate_with_last_element_mark(
             self.stream_completion_with_retry(**params)
         ):
             chunk = ChatGenerationChunk(
                 **self._chat_generation_from_qwen_resp(
-                    stream_resp, is_chunk=True, is_last_chunk=is_last_chunk
+                    stream_resp,
+                    previous_resp=previous_resp,
+                    is_chunk=True,
+                    is_last_chunk=is_last_chunk,
                 )
             )
+            if not incremental_output:
+                previous_resp = stream_resp
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             yield chunk
@@ -395,14 +450,21 @@ async def _astream(
         params: Dict[str, Any] = self._invocation_params(
             messages=messages, stop=stop, stream=True, **kwargs
         )
+        incremental_output = params.get("incremental_output")
+        previous_resp: Any = None
         async for stream_resp, is_last_chunk in agenerate_with_last_element_mark(
             self.astream_completion_with_retry(**params)
         ):
             chunk = ChatGenerationChunk(
                 **self._chat_generation_from_qwen_resp(
-                    stream_resp, is_chunk=True, is_last_chunk=is_last_chunk
+                    stream_resp,
+                    previous_resp=previous_resp,
+                    is_chunk=True,
+                    is_last_chunk=is_last_chunk,
                 )
             )
+            if not incremental_output:
+                previous_resp = stream_resp
             if run_manager:
                 await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             yield chunk
@@ -410,18 +472,29 @@ def _invocation_params(
         self, messages: List[BaseMessage], stop: Any, **kwargs: Any
     ) -> Dict[str, Any]:
-        params = {**self._default_params, **kwargs}
+        params = {
+            "model": self.model_name,
+            "top_p": self.top_p,
+            "api_key": cast(SecretStr, self.dashscope_api_key).get_secret_value(),
+            "result_format": "message",
+            **self.model_kwargs,
+            **kwargs,
+        }
         if stop is not None:
             params["stop"] = stop
-        if params.get("stream"):
-            params["incremental_output"] = True
+
+        # The default value of `incremental_output` is `False` in the
+        # DashScope API, and it only takes effect when `stream` is `True`.
+        # To prevent unexpected behavior, we drop `incremental_output`
+        # from the parameters whenever it would have no effect.
+        if not params.get("stream") or not params.get("incremental_output"):
+            if "incremental_output" in params:
+                del params["incremental_output"]
 
         message_dicts = [convert_message_to_dict(m) for m in messages]
-        # According to the docs, the last message should be a `user` message
-        if message_dicts[-1]["role"] != "user":
-            raise ValueError("Last message should be user message.")
-        # And the `system` message should be the first message if present
+        # The `system` message must be unique and, if present,
+        # must be the first message in the conversation.
         system_message_indices = [
             i for i, m in enumerate(message_dicts) if m["role"] == "system"
         ]
@@ -435,14 +508,95 @@ def _invocation_params(
         return params
 
     def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
-        if llm_outputs[0] is None:
-            return {}
-        return llm_outputs[0]
+        return llm_outputs[0] or {}
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            **kwargs: Any additional parameters to bind.
+
+        Example:
+            .. code-block:: python
+
+                from langchain_core.pydantic_v1 import BaseModel, Field
+                from langchain_community.chat_models.tongyi import ChatTongyi
+
+                class GetWeather(BaseModel):
+                    '''Get the current weather in a given location'''
+
+                    location: str = Field(
+                        ...,
+                        description="The city and state, e.g. San Francisco, CA",
+                    )
+
+
+                llm = ChatTongyi(model="qwen-max")
+                llm_with_tools = llm.bind_tools([GetWeather])
+                llm_with_tools.invoke("what is the weather like in Hangzhou, China")
+
+                # -> AIMessage(
+                #     content='',
+                #     id='run-f3bb9ff7-fbf5-43d4-880c-f28a5391c307-0',
+                #     tool_calls=[{
+                #         'name': 'GetWeather',
+                #         'args': {'location': 'Hangzhou, China'},
+                #         'id': ''
+                #     }],
+                #     response_metadata={
+                #         'model_name': 'qwen-max',
+                #         'finish_reason': 'tool_calls',
+                #         ...
+                #     }
+                #     additional_kwargs={'tool_calls': [{...}]}
+                # )
+        """
+
+        # According to the DashScope documentation:
+        # 1. the `tools` parameter has exactly the same format
+        # as OpenAI's, so we can use the `convert_to_openai_tool`
+        # function directly to convert the tools;
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        # 2. the `incremental_output` parameter is not supported
+        # when `tools` are provided, so we disable it explicitly.
+        return self.bind(tools=formatted_tools, incremental_output=False, **kwargs)
 
     @staticmethod
     def _chat_generation_from_qwen_resp(
-        resp: Any, is_chunk: bool = False, is_last_chunk: bool = True
+        resp: Any,
+        previous_resp: Any = None,
+        is_chunk: bool = False,
+        is_last_chunk: bool = True,
     ) -> Dict[str, Any]:
+        choice = resp["output"]["choices"][0]
+        raw_message = choice["message"]
+
+        message_dict = {"role": raw_message["role"]}
+        # If `previous_resp` is given (`incremental_output` must be False
+        # in this case), strip its content from the current response's
+        # content, since each streamed response repeats the full text so far.
+        if previous_resp is not None:
+            previous_content = previous_resp["output"]["choices"][0]["message"][
+                "content"
+            ]
+            message_dict["content"] = _remove_prefix(
+                raw_message["content"], prefix=previous_content
+            )
+        else:
+            message_dict["content"] = raw_message["content"]
+        if "tool_calls" in raw_message:
+            message_dict["tool_calls"] = raw_message["tool_calls"]
+
+        message = convert_dict_to_message(message_dict, is_chunk=is_chunk)
+
         # According to the response from dashscope,
         # each chunk's `generation_info` overwrites the previous one.
         # Besides, the `merge_dicts` method,
         # which is used to concatenate the `generation_info` of the chunks,
         # does not support merging of values of type int.
         # Therefore, we adopt the `generation_info` of the last chunk
         # and discard the `generation_info` of the intermediate chunks.
-        choice = resp["output"]["choices"][0]
-        message = convert_dict_to_message(choice["message"], is_chunk=is_chunk)
         if is_last_chunk:
             return dict(
                 message=message,
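
Usage sketch (not part of the patch): the snippet below exercises both code
paths this diff adds, parsed `tool_calls` on invoke and snapshot-style chunk
accumulation on stream. It assumes the `dashscope` SDK is installed and
`DASHSCOPE_API_KEY` is set; the `GetWeather` schema and the prompts are
illustrative only.

    from langchain_community.chat_models.tongyi import ChatTongyi
    from langchain_core.pydantic_v1 import BaseModel, Field


    class GetWeather(BaseModel):
        """Get the current weather in a given location"""

        location: str = Field(
            ..., description="The city and state, e.g. San Francisco, CA"
        )


    llm = ChatTongyi(model="qwen-max")
    # bind_tools() converts the pydantic model to the OpenAI tool schema and
    # forces incremental_output=False, which DashScope requires with tools.
    llm_with_tools = llm.bind_tools([GetWeather])

    # Invocation path: the assistant message carries parsed `tool_calls`.
    msg = llm_with_tools.invoke("What is the weather like in Hangzhou, China?")
    print(msg.tool_calls)

    # Streaming path: chunks are _AITongyiMessageChunk, so `+` keeps the
    # latest complete tool-call snapshot instead of concatenating fragments.
    acc = None
    for chunk in llm_with_tools.stream(
        "What is the weather like in Hangzhou, China?"
    ):
        acc = chunk if acc is None else acc + chunk
    print(acc.tool_calls if acc is not None else None)

Taking `other.tool_call_chunks` wholesale in `_AITongyiMessageChunk.__add__`
is sound here because, with `incremental_output=False`, DashScope streams the
complete message state in every chunk; `_chat_generation_from_qwen_resp` then
strips the previously seen content prefix so `chunk.text` stays incremental.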