diff --git a/README.md b/README.md
index 000de83..12dede7 100644
--- a/README.md
+++ b/README.md
@@ -734,8 +734,7 @@ rclpy.shutdown()
-#### chat_llama_ros
-
+#### chat_llama_ros (Chat + LVM)
Click to expand
@@ -778,6 +777,73 @@ rclpy.shutdown()
+#### 🎉 ***NEW*** chat_llama_ros (Tools) 🎉
+
+
+Click to expand
+
+The current implementation of Tools allows executing tools without requiring a model fine-tuned for tool use: the tools bound to the chat are converted into a JSON grammar that constrains the model's output to valid tool calls.
+
+```python
+import rclpy
+from llama_ros.langchain import ChatLlamaROS
+from langchain_core.messages import HumanMessage
+from langchain.tools import tool
+from random import randint
+
+rclpy.init()
+
+
+@tool
+def get_inhabitants(city: str) -> int:
+    """Get the current number of inhabitants of a city"""
+    return randint(4_000_000, 8_000_000)
+
+
+@tool
+def get_curr_temperature(city: str) -> int:
+    """Get the current temperature of a city"""
+    return randint(20, 30)
+
+
+chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_llama_template=True)
+
+messages = [
+    HumanMessage(
+        "What is the current temperature in Madrid? And its inhabitants?"
+    )
+]
+
+llm_tools = chat.bind_tools(
+    [get_inhabitants, get_curr_temperature], tool_choice="any"
+)
+
+all_tools_res = llm_tools.invoke(messages)
+messages.append(all_tools_res)
+
+for tool_call in all_tools_res.tool_calls:
+    selected_tool = {
+        "get_inhabitants": get_inhabitants,
+        "get_curr_temperature": get_curr_temperature,
+    }[tool_call["name"]]
+
+    tool_msg = selected_tool.invoke(tool_call)
+    print(f"{tool_call['name']}({', '.join(map(str, tool_call['args'].values()))}) = {tool_msg.content}")
+
+    tool_msg.additional_kwargs = {"args": tool_call["args"]}
+    messages.append(tool_msg)
+
+res = chat.invoke(messages)
+print(f"Response: {res.content}")
+
+rclpy.shutdown()
+```
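+
+Besides `tool_choice="any"`, a specific tool can be forced by passing its name to `bind_tools`. A minimal sketch, reusing the `chat` instance and the tools defined above (the prompt string and variable names are only illustrative):
+
+```python
+forced = chat.bind_tools([get_curr_temperature], tool_choice="get_curr_temperature")
+res = forced.invoke("How hot is it in Madrid right now?")
+print(res.tool_calls)  # a single get_curr_temperature call with its arguments
+```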
+
+
+
## Demos
### LLM Demo
@@ -868,6 +934,20 @@ ros2 run llama_demos chatllama_demo_node
[ChatLlamaROS demo](https://github-production-user-asset-6210df.s3.amazonaws.com/55236157/363094669-c6de124a-4e91-4479-99b6-685fecb0ac20.webm?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240830%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240830T081232Z&X-Amz-Expires=300&X-Amz-Signature=f937758f4bcbaec7683e46ddb057fb642dc86a33cc8c736fca3b5ce2bf06ddac&X-Amz-SignedHeaders=host&actor_id=55236157&key_id=0&repo_id=622137360)
+### Tools Demo
+
+```shell
+ros2 llama launch MiniCPM-2.6.yaml
+```
+
+```shell
+ros2 run llama_demos chatllama_tools_node
+```
+
+
+
+[Tools ChatLlama](https://github.com/user-attachments/assets/b912ee29-1466-4d6a-888b-9a2d9c16ae1d)
+
#### Full Demo (LLM + chat template + RAG + Reranking + Stream)
```shell
diff --git a/llama_demos/CMakeLists.txt b/llama_demos/CMakeLists.txt
index 42a1d0c..375e690 100644
--- a/llama_demos/CMakeLists.txt
+++ b/llama_demos/CMakeLists.txt
@@ -44,5 +44,11 @@ install(PROGRAMS
RENAME chatllama_demo_node
)
+install(PROGRAMS
+ llama_demos/chatllama_tools_node.py
+ DESTINATION lib/${PROJECT_NAME}
+ RENAME chatllama_tools_node
+)
+
ament_python_install_package(${PROJECT_NAME})
ament_package()
diff --git a/llama_demos/llama_demos/chatllama_demo_node.py b/llama_demos/llama_demos/chatllama_demo_node.py
index 4acd478..8505fb7 100644
--- a/llama_demos/llama_demos/chatllama_demo_node.py
+++ b/llama_demos/llama_demos/chatllama_demo_node.py
@@ -54,6 +54,7 @@ def send_prompt(self) -> None:
self.chat = ChatLlamaROS(
temp=0.2,
penalty_last_n=8,
+ use_gguf_template=False,
)
self.prompt = ChatPromptTemplate.from_messages(
diff --git a/llama_demos/llama_demos/chatllama_tools_node.py b/llama_demos/llama_demos/chatllama_tools_node.py
new file mode 100644
index 0000000..8209a16
--- /dev/null
+++ b/llama_demos/llama_demos/chatllama_tools_node.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+
+# MIT License
+
+# Copyright (c) 2024 Alejandro González Cantón
+# Copyright (c) 2024 Miguel Ángel González Santamarta
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import time
+
+import rclpy
+from rclpy.node import Node
+from llama_ros.langchain import ChatLlamaROS
+from langchain_core.messages import HumanMessage
+from langchain.tools import tool
+from random import randint
+
+
+@tool
+def get_inhabitants(city: str) -> int:
+ """Get the current temperature of a city"""
+ return randint(4_000_000, 8_000_000)
+
+
+@tool
+def get_curr_temperature(city: str) -> int:
+ """Get the current temperature of a city"""
+ return randint(20, 30)
+
+
+class ChatLlamaToolsDemoNode(Node):
+
+ def __init__(self) -> None:
+ super().__init__("chat_tools_demo_node")
+
+ self.initial_time = -1
+ self.tools_time = -1
+ self.eval_time = -1
+
+ def send_prompt(self) -> None:
+ self.chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_llama_template=True)
+
+ messages = [
+ HumanMessage(
+ "What is the current temperature in Madrid? And its inhabitants?"
+ )
+ ]
+
+ self.get_logger().info(f"\nPrompt: {messages[0].content}")
+
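+        # tool_choice="any" builds an "anyOf" grammar, so the model may call any
+        # of the bound tools.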
+ llm_tools = self.chat.bind_tools(
+ [get_inhabitants, get_curr_temperature], tool_choice="any"
+ )
+
+ self.initial_time = time.time()
+ all_tools_res = llm_tools.invoke(messages)
+ self.tools_time = time.time()
+
+ messages.append(all_tools_res)
+
+ for tool in all_tools_res.tool_calls:
+ selected_tool = {
+ "get_inhabitants": get_inhabitants,
+ "get_curr_temperature": get_curr_temperature,
+ }[tool["name"]]
+
+ tool_msg = selected_tool.invoke(tool)
+
+ formatted_output = (
+ f"{tool['name']}({''.join(tool['args'].values())}) = {tool_msg.content}"
+ )
+ self.get_logger().info(f"Calling tool: {formatted_output}")
+
+ tool_msg.additional_kwargs = {"args": tool["args"]}
+ messages.append(tool_msg)
+
+ res = self.chat.invoke(messages)
+
+ self.eval_time = time.time()
+
+ self.get_logger().info(f"\nResponse: {res.content}")
+
+ time_generate_tools = self.tools_time - self.initial_time
+ time_last_response = self.eval_time - self.tools_time
+ self.get_logger().info(f"Time to generate tools: {time_generate_tools:.2} s")
+ self.get_logger().info(
+ f"Time to generate last response: {time_last_response:.2} s"
+ )
+
+
+def main():
+ rclpy.init()
+ node = ChatLlamaToolsDemoNode()
+ node.send_prompt()
+ rclpy.shutdown()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llama_ros/CMakeLists.txt b/llama_ros/CMakeLists.txt
index 129b7d1..4ff73eb 100644
--- a/llama_ros/CMakeLists.txt
+++ b/llama_ros/CMakeLists.txt
@@ -101,5 +101,9 @@ install(TARGETS
DESTINATION lib/${PROJECT_NAME}
)
+install(DIRECTORY
+ DESTINATION share/${PROJECT_NAME}
+)
+
ament_python_install_package(${PROJECT_NAME})
ament_package()
diff --git a/llama_ros/llama_ros/langchain/chat_llama_ros.py b/llama_ros/llama_ros/langchain/chat_llama_ros.py
index a9234be..fc8027f 100644
--- a/llama_ros/llama_ros/langchain/chat_llama_ros.py
+++ b/llama_ros/llama_ros/langchain/chat_llama_ros.py
@@ -21,24 +21,99 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
-from typing import Any, List, Optional, Dict, Iterator
+from typing import (
+ Any,
+ Callable,
+ List,
+ Literal,
+ Optional,
+ Dict,
+ Iterator,
+ Sequence,
+ Type,
+ Union,
+ Tuple,
+)
+from operator import itemgetter
+from langchain_core.output_parsers import (
+ PydanticToolsParser,
+ JsonOutputKeyToolsParser,
+ PydanticOutputParser,
+ JsonOutputParser,
+)
+from langchain_core.output_parsers.base import OutputParserLike
+from langchain_core.runnables import RunnablePassthrough, RunnableMap
+from langchain_core.utils.pydantic import is_basemodel_subclass
import base64
import cv2
import numpy as np
+import jinja2
+from jinja2.sandbox import ImmutableSandboxedEnvironment
+from pydantic import BaseModel
+import uuid
+from ament_index_python.packages import get_package_share_directory
from llama_ros.langchain import LlamaROSCommon
from llama_msgs.msg import Message
-from llama_msgs.srv import FormatChatMessages
from action_msgs.msg import GoalStatus
+from llama_msgs.srv import Detokenize
+from llama_msgs.srv import FormatChatMessages
+from llama_msgs.action import GenerateResponse
+from langchain_core.utils.function_calling import convert_to_openai_tool
+import json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk
+from langchain_core.tools import BaseTool
+from langchain_core.runnables import Runnable
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ HumanMessage,
+ SystemMessage,
+ ToolMessage,
+)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
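+
+# ChatML-style fallback chat template used when use_llama_template=True. It
+# injects the JSON tool definitions extracted from the tools grammar before the
+# conversation so the model knows which tools it may call.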
+DEFAULT_TEMPLATE = """{% if tools_grammar %}
+ {{- '<|im_start|>assistant\n' }}
+ {{- 'You are an assistant. You output in JSON format. The key "tool_calls" is a list of possible tools. For each tool, the format is {name, arguments}. You can use the following tools:' }}
+ {% for tool in tools_grammar %}
+ {% if not loop.last %}
+ {{- tool }}
+ {% else %}
+ {{- tool + '<|im_end|>' }}
+ {% endif %}
+ {% endfor %}
+{% endif %}
+
+{% for message in messages %}
+ {% if (loop.last and add_generation_prompt) or not loop.last %}
+ {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
+ {% else %}
+ {{- '<|im_start|>' + message['role'] + '\n' + message['content'] }}
+ {% endif %}
+{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
+ {{- '<|im_start|>assistant' }}
+{% endif %}
+"""
+
class ChatLlamaROS(BaseChatModel, LlamaROSCommon):
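+    # use_llama_template: format prompts with the bundled DEFAULT_TEMPLATE.
+    # use_gguf_template: format prompts with the chat template stored in the
+    # model's GGUF metadata. If both are False, prompt formatting is delegated
+    # to the llama_ros FormatChatMessages service.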
+ use_llama_template: bool = False
+
+ use_gguf_template: bool = True
+
+ jinja_env: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
+ loader=jinja2.BaseLoader(),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
@property
def _default_params(self) -> Dict[str, Any]:
return {}
@@ -47,42 +122,184 @@ def _default_params(self) -> Dict[str, Any]:
def _llm_type(self) -> str:
return "chatllamaros"
- def _messages_to_chat_messages(
- self, messages: List[BaseMessage]
- ) -> tuple[FormatChatMessages.Request, Optional[str], Optional[np.ndarray]]:
+ def _json_schema_to_definition(self, input_json):
+ # Extract the tool name
+ tool_name = input_json["properties"]["name"]["const"]
+
+ # Extract and map arguments to desired format
+ properties = input_json["properties"]["arguments"]["properties"]
+ transformed_properties = {arg: prop["type"] for (arg, prop) in properties.items()}
+
+ # Create the transformed object
+ return {"name": tool_name, "arguments": transformed_properties}
+
+ def _generate_prompt(self, messages: List[dict[str, str]], **kwargs) -> str:
+ tools_grammar = kwargs.get("tools_grammar", None)
+
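+        # Pick the chat template: the bundled DEFAULT_TEMPLATE or the one
+        # embedded in the model's GGUF metadata.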
+ if self.use_llama_template:
+ chat_template = DEFAULT_TEMPLATE
+ else:
+ chat_template = self.model_metadata.tokenizer.chat_template
+
+ formatted_tools = []
+ if tools_grammar:
+ list_options = json.loads(tools_grammar)["properties"]["tool_calls"]["items"]
+ for key in list_options.keys():
+ if key.endswith("Of"):
+ list_key = key
+
+ formatted_tools = [
+ self._json_schema_to_definition(tool) for tool in list_options[list_key]
+ ]
+ formatted_tools = [
+ {key: tool[key] for key in sorted(tool, reverse=True)}
+ for tool in formatted_tools
+ ]
+
+ bos_token = self.llama_client.detokenize(
+ Detokenize.Request(tokens=[self.model_metadata.tokenizer.bos_token_id])
+ ).text
+
+ if self.use_gguf_template or self.use_llama_template:
+ formatted_prompt = self.jinja_env.from_string(chat_template).render(
+ messages=messages,
+ add_generation_prompt=True,
+ bos_token=bos_token,
+ tools_grammar=[json.dumps(tool) for tool in formatted_tools],
+ )
+ return formatted_prompt
+ else:
+ ros_messages = [
+ Message(content=message["content"], role=message["role"])
+ for message in messages
+ ]
+ return self.llama_client.format_chat_prompt(
+ FormatChatMessages.Request(messages=ros_messages)
+ ).formatted_prompt
- chat_messages = FormatChatMessages.Request()
+    def _convert_content(
+        self, content: Union[Dict[str, str], str, List[str], List[Dict[str, str]]]
+    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
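+        # Normalize LangChain message content (plain strings, dicts, or lists of
+        # either) into typed dicts; base64 data URLs are decoded into OpenCV
+        # images and plain URLs are kept as image_url entries.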
+ if isinstance(content, str):
+ return {"type": "text", "text": content}
+ if isinstance(content, list) and len(content) == 1:
+ return self._convert_content(content[0])
+ elif isinstance(content, list):
+ return [self._convert_content(c) for c in content]
+ elif isinstance(content, dict):
+ if content["type"] == "text":
+ return {"type": "text", "text": content["text"]}
+ elif content["type"] == "image_url":
+ image_text = content["image_url"]["url"]
+ if "data:image" in image_text:
+ image_data = image_text.split(",")[-1]
+ decoded_image = base64.b64decode(image_data)
+ np_image = np.frombuffer(decoded_image, np.uint8)
+ image = cv2.imdecode(np_image, cv2.IMREAD_COLOR)
+
+ return {"type": "image", "image": image}
+ else:
+ image_url = image_text
+ return {"type": "image_url", "image_url": image_url}
+
+ def _convert_message_to_dict(self, message: BaseMessage) -> list[dict[str, str]]:
+ if isinstance(message, HumanMessage):
+ converted_msg = [
+ {"role": "user", "content": self._convert_content(message.content)}
+ ]
+ return converted_msg
+
+ elif isinstance(message, AIMessage):
+ all_messages = []
+
+ contents = self._convert_content(message.content)
+ if isinstance(contents, dict):
+ contents = [contents]
+ contents = [
+ content
+ for content in contents
+ if content["type"] == "text" and content["text"] != ""
+ ]
+
+ all_messages.extend(
+ [{"role": "assistant", "content": content} for content in contents]
+ )
+
+ return all_messages
+
+ elif isinstance(message, SystemMessage):
+ converted_msg = [
+ {"role": "system", "content": self._convert_content(message.content)}
+ ]
+ return converted_msg
+
+ elif isinstance(message, ToolMessage):
+ tool_args = message.additional_kwargs.get("args", {})
+ formatted_args = ", ".join([f"{value}" for _, value in tool_args.items()])
+ formatted_content = f"{message.name}({formatted_args}): {message.content}"
+ return [
+ {
+ "role": "tool",
+ "content": {"type": "text", "text": formatted_content},
+ "tool_call_id": message.tool_call_id,
+ }
+ ]
+ else:
+ raise ValueError(f"Unsupported message type: {type(message)}")
+
+    def _extract_data_from_messages(
+        self, messages: List[Dict[str, Any]]
+    ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[np.ndarray]]:
+ new_messages = []
image_url = None
image = None
- for message in messages:
- role = message.type
- if role.lower() == "human":
- role = "user"
+ def process_content(role, content):
+ nonlocal image, image_url
+ if isinstance(content, str):
+ new_messages.append({"role": role, "content": content})
+ elif isinstance(content, dict):
+ if content["type"] == "text":
+ new_messages.append({"role": role, "content": content["text"]})
+ elif content["type"] == "image":
+ image = content["image"]
+ elif content["type"] == "image_url":
+ image_url = content["image_url"]
- if type(message.content) == str:
- chat_messages.messages.append(Message(role=role, content=message.content))
+ for message in messages:
+ role = message["role"]
+ content = message["content"]
+ if isinstance(content, list):
+ for single_content in content:
+ process_content(role, single_content)
else:
- for single_content in message.content:
- if type(single_content) == str:
- chat_messages.messages.append(
- Message(role=role, content=single_content)
- )
- elif single_content["type"] == "text":
- chat_messages.messages.append(
- Message(role=role, content=single_content["text"])
- )
- elif single_content["type"] == "image_url":
- image_text = single_content["image_url"]["url"]
- if "data:image" in image_text:
- image_data = image_text.split(",")[-1]
- decoded_image = base64.b64decode(image_data)
- np_image = np.frombuffer(decoded_image, np.uint8)
- image = cv2.imdecode(np_image, cv2.IMREAD_COLOR)
- else:
- image_url = image_text
-
- return chat_messages, image_url, image
+ process_content(role, content)
+
+ return new_messages, image_url, image
+
+    def _create_chat_generations(
+        self, response: GenerateResponse.Result, method: str
+    ) -> ChatResult:
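+        # With function_calling, the grammar-constrained JSON response is parsed
+        # into AIMessage.tool_calls; otherwise the raw text becomes the reply.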
+ chat_gen = None
+
+ if method == "function_calling":
+ ai_message = AIMessage(content="", tool_calls=[])
+ parsed_output = json.loads(response.text)
+ for tool in parsed_output["tool_calls"]:
+ ai_message.tool_calls.append(
+ {
+ "name": tool["name"],
+ "args": tool["arguments"],
+ "type": "tool_call",
+ "id": f'{tool["name"]}_{uuid.uuid4()}',
+ }
+ )
+
+ chat_gen = ChatGeneration(message=ai_message)
+ else:
+ chat_gen = ChatGeneration(message=AIMessage(content=response.text))
+
+ return ChatResult(generations=[chat_gen])
def _generate(
self,
@@ -91,23 +308,25 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
+ dict_messages = []
+ for message in messages:
+ dict_messages.extend(self._convert_message_to_dict(message))
- llama_client = self.llama_client.get_instance()
+ chat_messages, image_url, image = self._extract_data_from_messages(dict_messages)
- chat_messages, image_url, image = self._messages_to_chat_messages(messages)
- formatted_prompt = llama_client.format_chat_prompt(chat_messages).formatted_prompt
+ formatted_prompt = self._generate_prompt(chat_messages, **kwargs)
goal_action = self._create_action_goal(
formatted_prompt, stop, image_url, image, **kwargs
)
result, status = self.llama_client.generate_response(goal_action)
+ response = result.response
if status != GoalStatus.STATUS_SUCCEEDED:
return ""
- generation = ChatGeneration(message=AIMessage(content=result.response.text))
- return ChatResult(generations=[generation])
+ return self._create_chat_generations(response, kwargs.get("method", "chat"))
def _stream(
self,
@@ -116,11 +335,18 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
+ if kwargs.get("method") == "function_calling":
+ raise ValueError(
+ "Streaming is not supported when using 'function_calling' method."
+ )
+
+ dict_messages = []
+ for message in messages:
+ dict_messages.extend(self._convert_message_to_dict(message))
- llama_client = self.llama_client.get_instance()
+ chat_messages, image_url, image = self._extract_data_from_messages(dict_messages)
- chat_messages, image_url, image = self._messages_to_chat_messages(messages)
- formatted_prompt = llama_client.format_chat_prompt(chat_messages).formatted_prompt
+ formatted_prompt = self._generate_prompt(chat_messages)
goal_action = self._create_action_goal(
formatted_prompt, stop, image_url, image, **kwargs
@@ -135,3 +361,143 @@ def _stream(
)
yield ChatGenerationChunk(message=AIMessageChunk(content=pt.text))
+
+ def bind_tools(
+ self,
+ tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+ *,
+ tool_choice: Optional[Union[Dict[str, Dict], bool, str]] = None,
+ method: Literal[
+ "function_calling", "json_schema", "json_mode"
+ ] = "function_calling",
+ **kwargs: Any,
+ ) -> Runnable[LanguageModelInput, BaseMessage]:
+ """Bind tool-like objects to this chat model
+
+ tool_choice: does not currently support "any", "auto" choices like OpenAI
+ tool-calling API. should be a dict of the form to force this tool
+ {"type": "function", "function": {"name": <>}}.
+ """
+
+ formatted_tools = []
+
+ for tool in tools:
+ formatted_tools.append(convert_to_openai_tool(tool)["function"])
+
+ tool_names = [ft["name"] for ft in formatted_tools]
+ valid_choices = ["all", "one", "any"]
+
+ is_valid_choice = tool_choice in valid_choices
+
+ chosen_tool = [f for f in formatted_tools if f["name"] == tool_choice]
+
+ if not chosen_tool and not is_valid_choice:
+ raise ValueError(
+ f"Tool choice {tool_choice=} was specified, but the only "
+ f"provided tools were {tool_names}."
+ )
+
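+        # Build the JSON schema used to constrain generation (it is later passed
+        # to llama.cpp as the goal's grammar_schema): the chosen tool's raw
+        # parameter schema for json_mode/json_schema, or a {"tool_calls": [...]}
+        # wrapper describing the bound tools for function_calling.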
+ grammar = {}
+
+ if method == "json_mode" or method == "json_schema":
+ grammar = chosen_tool[0]["parameters"]
+ else:
+ grammar = {
+ "type": "object",
+ "properties": {
+ "tool_calls": {
+ "type": "array",
+ "items": {"type": "object"},
+ "maxItems": 10,
+ }
+ },
+ "required": ["tool_calls"],
+ }
+
+ if chosen_tool:
+ grammar["properties"]["tool_calls"]["items"]["oneOf"] = []
+ new_action = {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string", "const": chosen_tool[0]["name"]},
+ "arguments": chosen_tool[0]["parameters"],
+ },
+ "required": ["name", "arguments"],
+ }
+ grammar["properties"]["tool_calls"]["items"]["oneOf"].append(new_action)
+ else:
+ grammar["properties"]["tool_calls"]["items"][f"{tool_choice}Of"] = []
+ for tool in formatted_tools:
+ new_action = {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string", "const": tool["name"]},
+ "arguments": tool["parameters"],
+ },
+ "required": ["name", "arguments"],
+ }
+ grammar["properties"]["tool_calls"]["items"][
+ f"{tool_choice}Of"
+ ].append(new_action)
+
+ return super().bind(tools_grammar=json.dumps(grammar), method=method, **kwargs)
+
+ def with_structured_output(
+ self,
+ schema: Optional[Union[Dict, Type[BaseModel], Type]] = None,
+ *,
+ include_raw: bool = False,
+ method: Literal[
+ "function_calling", "json_schema", "json_mode"
+ ] = "function_calling",
+ **kwargs: Any,
+ ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
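+        # json_mode / json_schema constrain the output directly with the schema,
+        # while function_calling goes through bind_tools and parses the emitted
+        # tool call.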
+ if kwargs:
+ raise ValueError(f"Received unsupported arguments {kwargs}")
+ is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
+
+ if method == "json_mode" or method == "json_schema":
+ tool_name = schema.__name__
+
+ llm = self.bind_tools([schema], tool_choice=tool_name, method=method)
+ output_parser = (
+ PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
+ if is_pydantic_schema
+ else JsonOutputParser()
+ )
+ elif method == "function_calling":
+ if schema is None:
+ raise ValueError(
+ "schema must be specified when method is not 'json_mode'. "
+ "Received None."
+ )
+ schema = convert_to_openai_tool(schema)["function"]
+ tool_name = schema["name"]
+
+ llm = self.bind_tools([schema], tool_choice=tool_name, method=method)
+ if is_pydantic_schema:
+ output_parser: OutputParserLike = PydanticToolsParser(
+ tools=[schema], # type: ignore[list-item]
+ first_tool_only=True, # type: ignore[list-item]
+ )
+ else:
+ output_parser = JsonOutputKeyToolsParser(
+ key_name=tool_name, first_tool_only=True
+ )
+ else:
+ raise ValueError(
+ f"Unrecognized method argument. Expected one of 'function_calling' or "
+ f"'json_mode'. Received: '{method}'"
+ )
+
+ if include_raw:
+ parser_assign = RunnablePassthrough.assign(
+ parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+ )
+ parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+ parser_with_fallback = parser_assign.with_fallbacks(
+ [parser_none], exception_key="parsing_error"
+ )
+ return RunnableMap(raw=llm) | parser_with_fallback
+ else:
+ return llm | output_parser
diff --git a/llama_ros/llama_ros/langchain/llama_ros_common.py b/llama_ros/llama_ros/langchain/llama_ros_common.py
index 0ed7c31..9adfd91 100644
--- a/llama_ros/llama_ros/langchain/llama_ros_common.py
+++ b/llama_ros/llama_ros/langchain/llama_ros_common.py
@@ -33,13 +33,16 @@
from llama_ros.llama_client_node import LlamaClientNode
from llama_msgs.action import GenerateResponse
+from llama_msgs.srv import GetMetadata
from llama_msgs.msg import LogitBias
+from llama_msgs.msg import Metadata
class LlamaROSCommon(BaseLanguageModel, ABC):
llama_client: LlamaClientNode = None
cv_bridge: CvBridge = CvBridge()
+ model_metadata: Metadata = None
# sampling params
n_prev: int = 64
@@ -92,6 +95,9 @@ class Config:
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
values["llama_client"] = LlamaClientNode.get_instance()
+ values["model_metadata"] = (
+ values["llama_client"].get_metadata(GetMetadata.Request()).metadata
+ )
return values
def cancel(self) -> None:
@@ -103,6 +109,8 @@ def _create_action_goal(
stop: Optional[List[str]] = None,
image_url: Optional[str] = None,
image: Optional[np.ndarray] = None,
+ tools_grammar: Optional[str] = None,
+ **kwargs
) -> GenerateResponse.Result:
goal = GenerateResponse.Goal()
@@ -167,7 +175,9 @@ def _create_action_goal(
goal.sampling_config.samplers_sequence = self.samplers_sequence
goal.sampling_config.grammar = self.grammar
- goal.sampling_config.grammar_schema = self.grammar_schema
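+    # A tools grammar bound via bind_tools overrides any user-provided
+    # grammar_schema so sampling is constrained to the tool-call format.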
+ goal.sampling_config.grammar_schema = (
+ tools_grammar if tools_grammar else self.grammar_schema
+ )
goal.sampling_config.penalty_prompt_tokens = self.penalty_prompt_tokens
goal.sampling_config.use_penalty_prompt_tokens = self.use_penalty_prompt_tokens
diff --git a/llama_ros/llama_ros/llama_client_node.py b/llama_ros/llama_ros/llama_client_node.py
index 1d4625a..eceb862 100644
--- a/llama_ros/llama_ros/llama_client_node.py
+++ b/llama_ros/llama_ros/llama_client_node.py
@@ -125,6 +125,7 @@ def __init__(self, namespace: str = "llama") -> None:
self._spin_thread.start()
def get_metadata(self, req: GetMetadata.Request) -> GetMetadata:
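+        # Block until the metadata service is available before calling it.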
+ self._get_metadata_srv_client.wait_for_service()
return self._get_metadata_srv_client.call(req)
def tokenize(self, req: Tokenize.Request) -> Tokenize.Response:
diff --git a/llama_ros/src/llama_ros/llama.cpp b/llama_ros/src/llama_ros/llama.cpp
index e68d261..b4b3ff2 100644
--- a/llama_ros/src/llama_ros/llama.cpp
+++ b/llama_ros/src/llama_ros/llama.cpp
@@ -369,7 +369,7 @@ struct Metadata Llama::get_metadata() {
metadata.tokenizer.add_bos_token =
this->get_metadata("tokenizer.ggml.add_bos_token", 8) == "true";
metadata.tokenizer.chat_template =
- this->get_metadata("tokenizer.chat_template", 2048);
+ this->get_metadata("tokenizer.chat_template", 4096);
return metadata;
}