diff --git a/llama_demos/llama_demos/chatllama_tools_demo_node.py b/llama_demos/llama_demos/chatllama_tools_demo_node.py index 8c878ba..f2ef9fa 100644 --- a/llama_demos/llama_demos/chatllama_tools_demo_node.py +++ b/llama_demos/llama_demos/chatllama_tools_demo_node.py @@ -66,7 +66,9 @@ def send_prompt(self) -> None: ] self.get_logger().info(f"\nPrompt: {messages[0].content}") - llm_tools = self.chat.bind_tools([get_inhabitants, get_curr_temperature]) + llm_tools = self.chat.bind_tools( + [get_inhabitants, get_curr_temperature], tool_choice="any" + ) self.initial_time = time.time() all_tools_res = llm_tools.invoke(messages) diff --git a/llama_ros/llama_ros/langchain/chat_llama_ros.py b/llama_ros/llama_ros/langchain/chat_llama_ros.py index 2e15aea..6cb96c9 100644 --- a/llama_ros/llama_ros/langchain/chat_llama_ros.py +++ b/llama_ros/llama_ros/langchain/chat_llama_ros.py @@ -79,7 +79,7 @@ DEFAULT_TEMPLATE = """{% if tools_grammar %} {{- '<|im_start|>assistant\n' }} - {{- 'You are an assistant. Output in JSON format. The key "tool_calls" is a list of tools in the format: {name, arguments}. Available tools are:' }} + {{- 'You are an assistant. Output in JSON format. The key "tool_calls" is a list of tools in the format {name, arguments}. Available tools are:' }} {% for tool in tools_grammar %} {% if not loop.last %} {{- tool }} @@ -401,9 +401,8 @@ def bind_tools( tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], *, tool_choice: Optional[ - Union[dict, str, Literal["all", "one", "any"], bool] - ] = "any", - only_tool_calling: bool = True, + Union[dict, str, Literal["auto", "all", "one", "any"], bool] + ] = "auto", method: Literal[ "function_calling", "json_schema", "json_mode" ] = "function_calling", @@ -412,7 +411,7 @@ def bind_tools( formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools] tool_names = [ft["name"] for ft in formatted_tools] - valid_choices = ["all", "one", "any"] + valid_choices = ["auto", "all", "one", "any"] is_valid_choice = tool_choice in valid_choices chosen_tool = [f for f in formatted_tools if f["name"] == tool_choice] @@ -471,6 +470,8 @@ def bind_tools( ] else: + only_tool_calling = tool_choice != "auto" + tool_choice = "any" if not only_tool_calling else tool_choice tool_calls["properties"]["tool_calls"]["items"][f"{tool_choice}Of"] = [] for tool in formatted_tools: