From f7731ba87aa6cefa5fc23cbebaefd4cf1af2a729 Mon Sep 17 00:00:00 2001 From: jhpiedrahitao Date: Mon, 11 Nov 2024 11:03:48 -0500 Subject: [PATCH] add tool calling to samabstudio chat model docs --- docs/docs/integrations/chat/sambastudio.ipynb | 97 ++++++++++++++++++- 1 file changed, 93 insertions(+), 4 deletions(-) diff --git a/docs/docs/integrations/chat/sambastudio.ipynb b/docs/docs/integrations/chat/sambastudio.ipynb index 64dd05fd96b8c..e719354038fb0 100644 --- a/docs/docs/integrations/chat/sambastudio.ipynb +++ b/docs/docs/integrations/chat/sambastudio.ipynb @@ -34,7 +34,7 @@ "\n", "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", - "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n", + "| ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n", "\n", "## Setup\n", "\n", @@ -119,20 +119,20 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain_community.chat_models.sambanova import ChatSambaStudio\n", "\n", "llm = ChatSambaStudio(\n", - " model=\"Meta-Llama-3-70B-Instruct-4096\", # set if using a CoE endpoint\n", + " model=\"Meta-Llama-3-70B-Instruct-4096\", # set if using a Bundle endpoint\n", " max_tokens=1024,\n", " temperature=0.7,\n", " top_k=1,\n", " top_p=0.01,\n", " do_sample=True,\n", - " process_prompt=\"True\", # set if using a CoE endpoint\n", + " process_prompt=\"True\", # set if using a Bundle endpoint\n", ")" ] }, @@ -349,6 +349,95 @@ " print(chunk.content, end=\"\", flush=True)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool calling" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "\n", + "from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage\n", + "from langchain_core.tools import tool\n", + "\n", + "\n", + "@tool\n", + "def get_time(kind: str = \"both\") -> str:\n", + " \"\"\"Returns current date, current time or both.\n", + " Args:\n", + " kind: date, time or both\n", + " \"\"\"\n", + " if kind == \"date\":\n", + " date = datetime.now().strftime(\"%m/%d/%Y\")\n", + " return f\"Current date: {date}\"\n", + " elif kind == \"time\":\n", + " time = datetime.now().strftime(\"%H:%M:%S\")\n", + " return f\"Current time: {time}\"\n", + " else:\n", + " date = datetime.now().strftime(\"%m/%d/%Y\")\n", + " time = datetime.now().strftime(\"%H:%M:%S\")\n", + " return f\"Current date: {date}, Current time: {time}\"\n", + "\n", + "\n", + "def invoke_tools(tool_calls, messages):\n", + " for tool_call in tool_calls:\n", + " selected_tool = {\"get_time\": get_time}[tool_call[\"name\"].lower()]\n", + " tool_output = selected_tool.invoke(tool_call[\"args\"])\n", + " print(f\"Tool output: {tool_output}\")\n", + " messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n", + " return messages\n", + "\n", + "\n", + "tools = [get_time]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm_with_tools = llm.bind_tools(tools=tools)\n", + "messages = [\n", + " HumanMessage(\n", + " content=\"I need to schedule a meeting for two weeks from today. 
Can you tell me the exact date of the meeting?\"\n", + " )\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Intermediate model response: [{'name': 'get_time', 'args': {'kind': 'date'}, 'id': 'call_4092d5dd21cd4eb494', 'type': 'tool_call'}]\n", + "Tool output: Current date: 11/07/2024\n", + "final response: The meeting will be exactly two weeks from today, which would be 11/21/2024.\n" + ] + } + ], + "source": [ + "response = llm_with_tools.invoke(messages)\n", + "if response.tool_calls:\n", + " print(f\"Intermediate model response: {response.tool_calls}\")\n", + " messages.append(response)\n", + " messages = invoke_tools(response.tool_calls, messages)\n", + "response = llm.invoke(messages)\n", + "\n", + "print(f\"final response: {response.content}\")" + ] + }, { "cell_type": "markdown", "metadata": {},