diff --git a/docs/docs/integrations/chat/tongyi.ipynb b/docs/docs/integrations/chat/tongyi.ipynb index 5c74e930a2819..6b517937c1799 100644 --- a/docs/docs/integrations/chat/tongyi.ipynb +++ b/docs/docs/integrations/chat/tongyi.ipynb @@ -26,14 +26,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "# Install the package\n", "%pip install --upgrade --quiet dashscope" @@ -48,15 +56,7 @@ "outputs_hidden": false } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " ········\n" - ] - } - ], + "outputs": [], "source": [ "# Get a new token: https://help.aliyun.com/document_detail/611472.html?spm=a2c4g.2399481.0.0\n", "from getpass import getpass\n", @@ -94,8 +94,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "chat resp: content='Hello! How' additional_kwargs={} example=False\n", - "chat resp: content=' can I assist you today?' additional_kwargs={} example=False\n" + "chat resp: content='Hello' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n", + "chat resp: content='!' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n", + "chat resp: content=' How' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n", + "chat resp: content=' can I assist you today' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n", + "chat resp: content='?' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n", + "chat resp: content='' response_metadata={'finish_reason': 'stop', 'request_id': '921db2c5-4d53-9a89-8e87-e4ad6a671237', 'token_usage': {'input_tokens': 20, 'output_tokens': 9, 'total_tokens': 29}} id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n" ] } ], @@ -116,10 +120,18 @@ "execution_count": 5, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/cheese/PARA/Projects/langchain-contribution/langchain/libs/core/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The method `BaseChatModel.__call__` was deprecated in langchain-core 0.1.7 and will be removed in 0.2.0. Use invoke instead.\n", + " warn_deprecated(\n" + ] + }, { "data": { "text/plain": [ - "AIMessageChunk(content=\"J'aime programmer.\", additional_kwargs={}, example=False)" + "AIMessage(content=\"J'adore programmer.\", response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'stop', 'request_id': 'ae725086-0ffa-9728-8c72-b204c7bc7eeb', 'token_usage': {'input_tokens': 36, 'output_tokens': 6, 'total_tokens': 42}}, id='run-060cc103-ef5f-4c8a-af40-792ac7f40c26-0')" ] }, "execution_count": 5, @@ -149,18 +161,65 @@ "ChatTongyi supports tool calling API that lets you describe tools and their arguments, and have the model return a JSON object with a tool to invoke and the inputs to that tool." 
] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Use with `bind_tools`" + ] + }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "content='' additional_kwargs={'tool_calls': [{'function': {'name': 'multiply', 'arguments': '{\"first_int\": 5, \"second_int\": 42}'}, 'id': '', 'type': 'function'}]} response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': '4acf0e36-44af-987a-a0c0-8b5c5eaa1a8b', 'token_usage': {'input_tokens': 200, 'output_tokens': 25, 'total_tokens': 225}} id='run-0ecd0f09-1d20-4e55-a4f3-f14d1f710ae7-0' tool_calls=[{'name': 'multiply', 'args': {'first_int': 5, 'second_int': 42}, 'id': ''}]\n" + ] + } + ], + "source": [ + "from langchain_community.chat_models.tongyi import ChatTongyi\n", + "from langchain_core.tools import tool\n", + "\n", + "\n", + "@tool\n", + "def multiply(first_int: int, second_int: int) -> int:\n", + " \"\"\"Multiply two integers together.\"\"\"\n", + " return first_int * second_int\n", + "\n", + "\n", + "llm = ChatTongyi(model=\"qwen-turbo\")\n", + "\n", + "llm_with_tools = llm.bind_tools([multiply])\n", + "\n", + "msg = llm_with_tools.invoke(\"What's 5 times forty two\")\n", + "\n", + "print(msg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Construct args manually" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'name': 'get_current_weather', 'arguments': '{\"location\": \"San Francisco\"}'}, 'id': '', 'type': 'function'}]}, response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': 'dae79197-8780-9b7e-8c15-6a83e2a53534', 'token_usage': {'input_tokens': 229, 'output_tokens': 19, 'total_tokens': 248}}, id='run-9e06f837-582b-473b-bb1f-5e99a68ecc10-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco'}, 'id': ''}])" + "AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'name': 'get_current_weather', 'arguments': '{\"location\": \"San Francisco\"}'}, 'id': '', 'type': 'function'}]}, response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': '87ef33d2-5c6b-9457-91e2-39faad7120eb', 'token_usage': {'input_tokens': 229, 'output_tokens': 19, 'total_tokens': 248}}, id='run-7939ba7f-e3f7-46f8-980b-30499b52723c-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco'}, 'id': ''}])" ] }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -224,7 +283,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/libs/community/langchain_community/chat_models/tongyi.py b/libs/community/langchain_community/chat_models/tongyi.py index 943cace9733e9..785bc79320657 100644 --- a/libs/community/langchain_community/chat_models/tongyi.py +++ b/libs/community/langchain_community/chat_models/tongyi.py @@ -2,6 +2,7 @@ import asyncio import functools +import json import logging from typing import ( Any, @@ -12,6 +13,8 @@ List, Mapping, Optional, + Sequence, + Type, Union, cast, ) @@ -20,6 +23,7 @@ AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.language_models import LanguageModelInput from 
langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( AIMessage, @@ -32,6 +36,8 @@ HumanMessageChunk, SystemMessage, SystemMessageChunk, + ToolMessage, + ToolMessageChunk, ) from langchain_core.output_parsers.openai_tools import ( make_invalid_tool_call, @@ -42,8 +48,11 @@ ChatGenerationChunk, ChatResult, ) -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils.function_calling import convert_to_openai_tool from requests.exceptions import HTTPError from tenacity import ( before_sleep_log, @@ -68,6 +77,7 @@ def convert_dict_to_message( """Convert a dict to a message.""" role = _dict["role"] content = _dict["content"] + if role == "user": return ( HumanMessageChunk(content=content) @@ -79,17 +89,39 @@ def convert_dict_to_message( invalid_tool_calls = [] if "tool_calls" in _dict: additional_kwargs = {"tool_calls": _dict["tool_calls"]} - for raw_tool_call in _dict["tool_calls"]: - try: - tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) - except Exception as e: - invalid_tool_calls.append( - make_invalid_tool_call(raw_tool_call, str(e)) - ) + + for index, value in enumerate(_dict["tool_calls"]): + if is_chunk: + try: + tool_calls.append( + { + "name": value["function"].get("name"), + "args": value["function"].get("arguments"), + "id": value.get("id"), + # Tongyi does not respond with index, + # use index in the list instead + "index": index, + } + ) + except KeyError: + pass + else: + try: + parsed_tool = parse_tool_call(value, return_id=True) + if parsed_tool: + tool_calls.append(parsed_tool) + except Exception as e: + invalid_tool_calls.append(make_invalid_tool_call(value, str(e))) else: additional_kwargs = {} + return ( - AIMessageChunk(content=content) + AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + tool_call_chunks=tool_calls, + id=_dict.get("id"), + ) if is_chunk else AIMessage( content=content, @@ -104,6 +136,23 @@ def convert_dict_to_message( if is_chunk else SystemMessage(content=content) ) + elif role == "tool": + additional_kwargs = {} + if "name" in _dict: + additional_kwargs["name"] = _dict["name"] + return ( + ToolMessageChunk( + content=_dict.get("content", ""), + tool_call_id=_dict.get("tool_call_id"), + additional_kwargs=additional_kwargs, + ) + if is_chunk + else ToolMessage( + content=_dict.get("content", ""), + tool_call_id=_dict.get("tool_call_id"), + additional_kwargs=additional_kwargs, + ) + ) else: return ( ChatMessageChunk(role=role, content=content) @@ -113,17 +162,23 @@ def convert_dict_to_message( def convert_message_chunk_to_message(message_chunk: BaseMessageChunk) -> BaseMessage: - """Convert a message chunk to a message.""" - if isinstance(message_chunk, HumanMessageChunk): - return HumanMessage(content=message_chunk.content) - elif isinstance(message_chunk, AIMessageChunk): - return AIMessage(content=message_chunk.content) - elif isinstance(message_chunk, SystemMessageChunk): - return SystemMessage(content=message_chunk.content) - elif isinstance(message_chunk, ChatMessageChunk): - return ChatMessage(role=message_chunk.role, content=message_chunk.content) - else: - raise TypeError(f"Got unknown type {message_chunk}") + """Convert a message chunk to a message. 
+ + Args: + chunk: Message chunk to convert. + + Returns: + Message. + """ + if not isinstance(message_chunk, BaseMessageChunk): + return message_chunk + # chunk classes always have the equivalent non-chunk class as their first parent + ignore_keys = ["type"] + if isinstance(message_chunk, AIMessageChunk): + ignore_keys.append("tool_call_chunks") + return message_chunk.__class__.__mro__[1]( + **{k: v for k, v in message_chunk.__dict__.items() if k not in ignore_keys} + ) def convert_message_to_dict(message: BaseMessage) -> dict: @@ -136,8 +191,17 @@ def convert_message_to_dict(message: BaseMessage) -> dict: message_dict = {"role": "user", "content": message.content} elif isinstance(message, AIMessage): message_dict = {"role": "assistant", "content": message.content} + if "tool_calls" in message.additional_kwargs: + message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] elif isinstance(message, SystemMessage): message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolMessage): + message_dict = { + "role": "tool", + "tool_call_id": message.tool_call_id, + "content": message.content, + "name": message.name, + } else: raise TypeError(f"Got unknown type {message}") return message_dict @@ -256,11 +320,57 @@ def stream_completion_with_retry(self, **kwargs: Any) -> Any: @retry_decorator def _stream_completion_with_retry(**_kwargs: Any) -> Any: responses = self.client.call(**_kwargs) + prev_resp = None + for resp in responses: - yield check_response(resp) + # If we are streaming without `incremental_output = True`, + # we need to calculate the delta response manually + if _kwargs.get("stream") and not _kwargs.get( + "incremental_output", False + ): + if prev_resp is None: + delta_resp = resp + else: + delta_resp = self.subtract_client_response(resp, prev_resp) + prev_resp = resp + yield check_response(delta_resp) + else: + yield check_response(resp) return _stream_completion_with_retry(**kwargs) + def subtract_client_response(self, resp: Any, prev_resp: Any) -> Any: + """Subtract prev response from curr response. 
+ + Useful when streaming without `incremental_output = True` + """ + + resp_copy = json.loads(json.dumps(resp)) + choice = resp_copy["output"]["choices"][0] + message = choice["message"] + + prev_resp_copy = json.loads(json.dumps(prev_resp)) + prev_choice = prev_resp_copy["output"]["choices"][0] + prev_message = prev_choice["message"] + + message["content"] = message["content"].replace(prev_message["content"], "") + + if message.get("tool_calls"): + for index, tool_call in enumerate(message["tool_calls"]): + function = tool_call["function"] + + if prev_message.get("tool_calls"): + prev_function = prev_message["tool_calls"][index]["function"] + + function["name"] = function["name"].replace( + prev_function["name"], "" + ) + function["arguments"] = function["arguments"].replace( + prev_function["arguments"], "" + ) + + return resp_copy + async def astream_completion_with_retry(self, **kwargs: Any) -> Any: """Because the dashscope SDK doesn't provide an async API, we wrap `stream_generate_with_retry` with an async generator.""" @@ -301,16 +411,16 @@ def _generate( ) -> ChatResult: generations = [] if self.streaming: - generation: Optional[ChatGenerationChunk] = None + generation_chunk: Optional[ChatGenerationChunk] = None for chunk in self._stream( messages, stop=stop, run_manager=run_manager, **kwargs ): - if generation is None: - generation = chunk + if generation_chunk is None: + generation_chunk = chunk else: - generation += chunk - assert generation is not None - generations.append(self._chunk_to_generation(generation)) + generation_chunk += chunk + assert generation_chunk is not None + generations.append(self._chunk_to_generation(generation_chunk)) else: params: Dict[str, Any] = self._invocation_params( messages=messages, stop=stop, **kwargs @@ -373,9 +483,19 @@ def _stream( params: Dict[str, Any] = self._invocation_params( messages=messages, stop=stop, stream=True, **kwargs ) + for stream_resp, is_last_chunk in generate_with_last_element_mark( self.stream_completion_with_retry(**params) ): + choice = stream_resp["output"]["choices"][0] + message = choice["message"] + if ( + choice["finish_reason"] == "null" + and message["content"] == "" + and "tool_calls" not in message + ): + continue + chunk = ChatGenerationChunk( **self._chat_generation_from_qwen_resp( stream_resp, is_chunk=True, is_last_chunk=is_last_chunk @@ -413,14 +533,13 @@ def _invocation_params( params = {**self._default_params, **kwargs} if stop is not None: params["stop"] = stop - if params.get("stream"): + # According to the Tongyi official docs, + # `incremental_output` with `tools` is not supported yet + if params.get("stream") and not params.get("tools"): params["incremental_output"] = True message_dicts = [convert_message_to_dict(m) for m in messages] - # According to the docs, the last message should be a `user` message - if message_dicts[-1]["role"] != "user": - raise ValueError("Last message should be user message.") # And the `system` message should be the first message if present system_message_indices = [ i for i, m in enumerate(message_dicts) if m["role"] == "system" @@ -470,3 +589,22 @@ def _chunk_to_generation(chunk: ChatGenerationChunk) -> ChatGeneration: message=convert_message_chunk_to_message(chunk.message), generation_info=chunk.generation_info, ) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. 
+ + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/community/langchain_community/llms/tongyi.py b/libs/community/langchain_community/llms/tongyi.py index 6254609ece7f3..8e13b6e03f192 100644 --- a/libs/community/langchain_community/llms/tongyi.py +++ b/libs/community/langchain_community/llms/tongyi.py @@ -55,17 +55,17 @@ def _create_retry_decorator(llm: Tongyi) -> Callable[[Any], Any]: def check_response(resp: Any) -> Any: """Check the response from the completion call.""" - if resp.status_code == 200: + if resp["status_code"] == 200: return resp - elif resp.status_code in [400, 401]: + elif resp["status_code"] in [400, 401]: raise ValueError( - f"status_code: {resp.status_code} \n " - f"code: {resp.code} \n message: {resp.message}" + f"status_code: {resp['status_code']} \n " + f"code: {resp['code']} \n message: {resp['message']}" ) else: raise HTTPError( - f"HTTP error occurred: status_code: {resp.status_code} \n " - f"code: {resp.code} \n message: {resp.message}", + f"HTTP error occurred: status_code: {resp['status_code']} \n " + f"code: {resp['code']} \n message: {resp['message']}", response=resp, ) diff --git a/libs/community/tests/integration_tests/chat_models/test_tongyi.py b/libs/community/tests/integration_tests/chat_models/test_tongyi.py index 3e0a8f9442c72..120046f7d176e 100644 --- a/libs/community/tests/integration_tests/chat_models/test_tongyi.py +++ b/libs/community/tests/integration_tests/chat_models/test_tongyi.py @@ -1,11 +1,14 @@ """Test Alibaba Tongyi Chat Model.""" -from typing import Any, cast + +from typing import Any, List, cast from langchain_core.callbacks import CallbackManager from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages.ai import AIMessageChunk +from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.pydantic_v1 import SecretStr +from langchain_core.pydantic_v1 import BaseModel, SecretStr from pytest import CaptureFixture from langchain_community.chat_models.tongyi import ChatTongyi @@ -138,3 +141,76 @@ def test_multiple_messages() -> None: assert isinstance(generation, ChatGeneration) assert isinstance(generation.text, str) assert generation.text == generation.message.content + + +class GenerateUsername(BaseModel): + "Get a username based on someone's name and hair color." 
+ + name: str + hair_color: str + + +def test_tool_use() -> None: + llm = ChatTongyi(model="qwen-turbo", temperature=0) + llm_with_tool = llm.bind_tools(tools=[GenerateUsername]) + msgs: List = [HumanMessage("Sally has green hair, what would her username be?")] + ai_msg = llm_with_tool.invoke(msgs) + # assert ai_msg is None + # ai_msg.content = " " + + assert isinstance(ai_msg, AIMessage) + assert isinstance(ai_msg.tool_calls, list) + assert len(ai_msg.tool_calls) == 1 + tool_call = ai_msg.tool_calls[0] + assert "args" in tool_call + + tool_msg = ToolMessage( + "sally_green_hair", + tool_call_id=ai_msg.tool_calls[0]["id"], + name=ai_msg.tool_calls[0]["name"], + ) + msgs.extend([ai_msg, tool_msg]) + llm_with_tool.invoke(msgs) + + # Test streaming + ai_messages = llm_with_tool.stream(msgs) + first = True + for message in ai_messages: + if first: + gathered = message + first = False + else: + gathered = gathered + message # type: ignore + assert isinstance(gathered, AIMessageChunk) + + streaming_tool_msg = ToolMessage( + "sally_green_hair", + name=tool_call["name"], + tool_call_id=tool_call["id"] if tool_call["id"] else " ", + ) + msgs.extend([gathered, streaming_tool_msg]) + llm_with_tool.invoke(msgs) + + +def test_manual_tool_call_msg() -> None: + """Test passing in manually construct tool call message.""" + llm = ChatTongyi(model="qwen-turbo", temperature=0) + llm_with_tool = llm.bind_tools(tools=[GenerateUsername]) + msgs: List = [ + HumanMessage("Sally has green hair, what would her username be?"), + AIMessage( + content=" ", + tool_calls=[ + ToolCall( + name="GenerateUsername", + args={"name": "Sally", "hair_color": "green"}, + id="foo", + ) + ], + ), + ToolMessage("sally_green_hair", tool_call_id="foo"), + ] + output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs)) + assert output.content + # Should not have called the tool again. + assert not output.tool_calls and not output.invalid_tool_calls
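
A minimal end-to-end sketch of the tool-calling round trip that the integration tests above exercise: bind a tool, let the model emit a tool call, feed the result back as a ToolMessage, then invoke again for the final answer. This is illustrative only, not part of the patch; it assumes a valid DASHSCOPE_API_KEY in the environment and access to the qwen-turbo model, and it reuses the empty-tool-call-id workaround from test_tool_use, since Tongyi currently returns '' as the id.

    from langchain_core.messages import HumanMessage, ToolMessage
    from langchain_core.tools import tool

    from langchain_community.chat_models.tongyi import ChatTongyi


    @tool
    def multiply(first_int: int, second_int: int) -> int:
        """Multiply two integers together."""
        return first_int * second_int


    llm = ChatTongyi(model="qwen-turbo", temperature=0)
    llm_with_tools = llm.bind_tools([multiply])

    messages = [HumanMessage("What's 5 times forty two?")]

    # First call: the model responds with a tool call instead of text.
    ai_msg = llm_with_tools.invoke(messages)
    tool_call = ai_msg.tool_calls[0]

    # Execute the tool locally with the arguments the model produced.
    result = multiply.invoke(tool_call["args"])

    messages.extend(
        [
            ai_msg,
            ToolMessage(
                content=str(result),
                # Tongyi returns an empty tool call id, so fall back to a
                # placeholder, as the tests above do.
                tool_call_id=tool_call["id"] or " ",
            ),
        ]
    )

    # Second call: the model reads the ToolMessage and answers in text.
    final = llm_with_tools.invoke(messages)
    print(final.content)

Because `_invocation_params` now disables `incremental_output` whenever tools are bound, the streaming variant of this flow goes through `subtract_client_response` to recover per-chunk deltas from Tongyi's cumulative responses.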
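
A toy illustration of that delta computation, under the same assumption the implementation makes, namely that each cumulative response extends the previous one, so `str.replace` strips the already-seen text (illustrative only, with made-up response dicts shaped like the dashscope payloads):

    # Two consecutive cumulative responses, as returned when streaming
    # without incremental_output.
    prev = {"output": {"choices": [{"message": {"content": "Hello! How"}}]}}
    curr = {
        "output": {
            "choices": [{"message": {"content": "Hello! How can I assist you today?"}}]
        }
    }

    prev_content = prev["output"]["choices"][0]["message"]["content"]
    curr_content = curr["output"]["choices"][0]["message"]["content"]

    # Same operation subtract_client_response applies to content (and to
    # tool call names/arguments): remove the previously seen text.
    delta = curr_content.replace(prev_content, "")
    print(delta)  # " can I assist you today?"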