community: Implement bind_tools for ChatTongyi #20725

Merged
merged 20 commits on May 16, 2024
Changes from 6 commits
89 changes: 66 additions & 23 deletions docs/docs/integrations/chat/tongyi.ipynb
@@ -26,37 +26,37 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# Install the package\n",
"%pip install --upgrade --quiet dashscope"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"outputs": [],
"source": [
"# Get a new token: https://help.aliyun.com/document_detail/611472.html?spm=a2c4g.2399481.0.0\n",
"from getpass import getpass\n",
@@ -66,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -82,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -94,8 +94,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"chat resp: content='Hello! How' additional_kwargs={} example=False\n",
"chat resp: content=' can I assist you today?' additional_kwargs={} example=False\n"
"chat resp: content='Hello' id='run-1df2c54b-94c4-4f84-8716-ed2f52cd42a9'\n",
"chat resp: content='!' id='run-1df2c54b-94c4-4f84-8716-ed2f52cd42a9'\n",
"chat resp: content=' How' id='run-1df2c54b-94c4-4f84-8716-ed2f52cd42a9'\n",
"chat resp: content=' can I assist you today' id='run-1df2c54b-94c4-4f84-8716-ed2f52cd42a9'\n",
"chat resp: content='?' id='run-1df2c54b-94c4-4f84-8716-ed2f52cd42a9'\n",
"chat resp: content='' response_metadata={'finish_reason': 'stop', 'request_id': '2ffa6db7-09d7-96b7-8bb2-4c59dadf467b', 'token_usage': {'input_tokens': 20, 'output_tokens': 9, 'total_tokens': 29}} id='run-1df2c54b-94c4-4f84-8716-ed2f52cd42a9'\n"
]
}
],
@@ -113,16 +117,24 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/cheese/PARA/Projects/langchain-contribution/langchain/libs/core/langchain_core/_api/deprecation.py:117: LangChainDeprecationWarning: The function `__call__` was deprecated in LangChain 0.1.7 and will be removed in 0.2.0. Use invoke instead.\n",
" warn_deprecated(\n"
]
},
{
"data": {
"text/plain": [
"AIMessageChunk(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
"AIMessage(content=\"J'aime programmer.\", response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'stop', 'request_id': 'e5533cb5-a2ab-9204-b318-e6b056e24cbe', 'token_usage': {'input_tokens': 36, 'output_tokens': 5, 'total_tokens': 41}}, id='run-44282599-b1a6-4e2b-a192-f707b286b5d8-0')"
]
},
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -141,12 +153,43 @@
"chatLLM(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool calling"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"content='' additional_kwargs={'tool_calls': [{'function': {'name': 'multiply', 'arguments': '{\"first_int\": 5, \"second_int\": 42}'}, 'id': '', 'type': 'function'}]} response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': '2830e26a-1a8f-94ac-834b-e4dfcd0ded56', 'token_usage': {'input_tokens': 200, 'output_tokens': 25, 'total_tokens': 225}} id='run-8bf406e7-e44b-481c-983c-ecba51fd1206-0' tool_calls=[{'name': 'multiply', 'args': {'first_int': 5, 'second_int': 42}, 'id': ''}]\n"
]
}
],
"source": [
"from langchain_community.chat_models.tongyi import ChatTongyi\n",
"from langchain_core.tools import tool\n",
"\n",
"@tool\n",
"def multiply(first_int: int, second_int: int) -> int:\n",
" \"\"\"Multiply two integers together.\"\"\"\n",
" return first_int * second_int\n",
"\n",
"llm = ChatTongyi(model=\"qwen-turbo\")\n",
"\n",
"llm_with_tools = llm.bind_tools([multiply])\n",
"\n",
"msg = llm_with_tools.invoke(\"What's 5 times forty two\")\n",
"\n",
"print(msg)"
]
}
],
"metadata": {
@@ -165,7 +208,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.2"
}
},
"nbformat": 4,
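Beyond what the notebook cell above prints, a minimal sketch of completing the tool-call round trip (running the requested tool and feeding the result back as a ToolMessage), assuming the multiply tool and llm_with_tools from that cell; the printed values are illustrative:

from langchain_core.messages import HumanMessage, ToolMessage

messages = [HumanMessage("What's 5 times forty two")]
ai_msg = llm_with_tools.invoke(messages)
messages.append(ai_msg)

# Run each requested tool and report its result back to the model.
for tool_call in ai_msg.tool_calls:
    result = multiply.invoke(tool_call["args"])  # {'first_int': 5, 'second_int': 42} -> 210
    messages.append(ToolMessage(content=str(result), tool_call_id=tool_call["id"]))

print(llm_with_tools.invoke(messages).content)  # e.g. "5 times 42 is 210."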
107 changes: 100 additions & 7 deletions libs/community/langchain_community/chat_models/tongyi.py
@@ -12,6 +12,8 @@
List,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
@@ -20,6 +22,7 @@
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
@@ -32,6 +35,8 @@
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
@@ -42,8 +47,11 @@
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from requests.exceptions import HTTPError
from tenacity import (
before_sleep_log,
@@ -88,8 +96,14 @@ def convert_dict_to_message(
)
else:
additional_kwargs = {}

return (
AIMessageChunk(content=content)
AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
Review comment (Collaborator): Are these tool calls fully formed in a streaming context?

If not, you can specify tool_call_chunks on AIMessageChunk instead, with args as a (partial JSON) string. See this example:

tool_call_chunks=tool_call_chunks,

Reply (Contributor, author): Thanks a lot, I have made some updates about this.
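For reference, a minimal sketch of the tool_call_chunks pattern suggested above, with illustrative partial-JSON args and a made-up id:

from langchain_core.messages import AIMessageChunk

# During streaming, arguments may arrive as incomplete JSON, so they are
# carried as strings in tool_call_chunks rather than as parsed tool_calls.
chunk = AIMessageChunk(
    content="",
    tool_call_chunks=[
        {"name": "multiply", "args": '{"first_int": 5, ', "id": "call_0", "index": 0}
    ],
)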

invalid_tool_calls=invalid_tool_calls,
)
if is_chunk
else AIMessage(
content=content,
Expand All @@ -104,6 +118,23 @@ def convert_dict_to_message(
if is_chunk
else SystemMessage(content=content)
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return (
ToolMessageChunk(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs,
)
if is_chunk
else ToolMessage(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs,
)
)
else:
return (
ChatMessageChunk(role=role, content=content)
@@ -117,11 +148,25 @@ def convert_message_chunk_to_message(message_chunk: BaseMessageChunk) -> BaseMessage:
if isinstance(message_chunk, HumanMessageChunk):
return HumanMessage(content=message_chunk.content)
elif isinstance(message_chunk, AIMessageChunk):
return AIMessage(content=message_chunk.content)
# assert message_chunk is None
return (
AIMessage(
content=message_chunk.content,
tool_calls=message_chunk.additional_kwargs["tool_calls"],
Review comment (Collaborator): This is confusing, because in convert_dict_to_message we are parsing tool calls (e.g., in the call to parse_tool_call), but here we are passing them in although they are already parsed.

At this point, is the following true?

if message_chunk.additional_kwargs["tool_calls"]:
    item = message_chunk.additional_kwargs["tool_calls"][0]

assert isinstance(item["args"], dict)  # not a string

Reply (Contributor, author): Sorry for the delayed response, I'm having a really tough time with my job 😫
This was a mistake; I've rewritten the method referring to langchain_openai's base module.

)
if "tool_calls" in message_chunk.additional_kwargs
else AIMessage(content=message_chunk.content)
)
elif isinstance(message_chunk, SystemMessageChunk):
return SystemMessage(content=message_chunk.content)
elif isinstance(message_chunk, ChatMessageChunk):
return ChatMessage(role=message_chunk.role, content=message_chunk.content)
elif isinstance(message_chunk, ToolMessageChunk):
return ToolMessage(
content=message_chunk.content,
tool_call_id=message_chunk.tool_call_id,
name=message_chunk.name,
)
else:
raise TypeError(f"Got unknown type {message_chunk}")

@@ -136,8 +181,17 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"tool_call_id": message.tool_call_id,
"content": message.content,
"name": message.name,
}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
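As a small sketch, this is what the new tool branch above yields for a ToolMessage (the id and name are illustrative):

from langchain_core.messages import ToolMessage

tool_msg = ToolMessage(content="210", tool_call_id="call_0", name="multiply")
convert_message_to_dict(tool_msg)
# -> {'role': 'tool', 'tool_call_id': 'call_0', 'content': '210', 'name': 'multiply'}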
@@ -373,9 +427,30 @@ def _stream(
params: Dict[str, Any] = self._invocation_params(
messages=messages, stop=stop, stream=True, **kwargs
)
prev_msg_content = ""

for stream_resp, is_last_chunk in generate_with_last_element_mark(
self.stream_completion_with_retry(**params)
):
choice = stream_resp["output"]["choices"][0]
message = choice["message"]
if (
choice["finish_reason"] == "null"
and message["content"] == ""
and "tool_calls" not in message
):
continue

# If it's a tool call response, wait until it's finished
if "tool_calls" in message and choice["finish_reason"] == "null":
continue

# If we are streaming without `incremental_output = True`,
# we need to chop off the previous message content
if not params.get("incremental_output", False):
message["content"] = message["content"].replace(prev_msg_content, "")
prev_msg_content += message["content"]

chunk = ChatGenerationChunk(
**self._chat_generation_from_qwen_resp(
stream_resp, is_chunk=True, is_last_chunk=is_last_chunk
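A standalone sketch of the chop-off logic above, using made-up cumulative chunks to show how deltas are recovered when incremental_output is off:

prev_msg_content = ""
for cumulative in ["Hello", "Hello!", "Hello! How can I help?"]:  # illustrative
    delta = cumulative.replace(prev_msg_content, "")
    prev_msg_content += delta
    print(delta)  # prints "Hello", then "!", then " How can I help?"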
@@ -413,14 +488,13 @@ def _invocation_params(
params = {**self._default_params, **kwargs}
if stop is not None:
params["stop"] = stop
if params.get("stream"):
# According to the Tongyi official docs,
# `incremental_output` with `tools` is not supported yet
if params.get("stream") and not params.get("tools"):
params["incremental_output"] = True

message_dicts = [convert_message_to_dict(m) for m in messages]

# According to the docs, the last message should be a `user` message
if message_dicts[-1]["role"] != "user":
raise ValueError("Last message should be user message.")
# And the `system` message should be the first message if present
system_message_indices = [
i for i, m in enumerate(message_dicts) if m["role"] == "system"
@@ -470,3 +544,22 @@ def _chunk_to_generation(chunk: ChatGenerationChunk) -> ChatGeneration:
message=convert_message_chunk_to_message(chunk.message),
generation_info=chunk.generation_info,
)

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.

Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""

formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
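A short usage sketch of bind_tools with a Pydantic schema, exercising the docstring's claim that Pydantic models are converted automatically (the GetWeather tool is hypothetical):

from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_core.pydantic_v1 import BaseModel, Field

class GetWeather(BaseModel):
    """Get the current weather in a given city."""

    city: str = Field(..., description="The city to look up")

llm = ChatTongyi(model="qwen-turbo")
llm_with_tools = llm.bind_tools([GetWeather])
msg = llm_with_tools.invoke("How's the weather in Hangzhou?")
print(msg.tool_calls)  # e.g. [{'name': 'GetWeather', 'args': {'city': 'Hangzhou'}, 'id': ''}]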