diff --git a/README.md b/README.md
index 000de83..12dede7 100644
--- a/README.md
+++ b/README.md
@@ -734,8 +734,7 @@ rclpy.shutdown()
-#### chat_llama_ros
-
+#### chat_llama_ros (Chat + LVM)
Click to expand @@ -778,6 +777,73 @@ rclpy.shutdown()
+#### 🎉 ***NEW*** chat_llama_ros (Tools) 🎉
+
+<details>
+<summary>Click to expand</summary>
+
+The current implementation of tools allows executing tools without requiring a model explicitly trained for tool use.
+
+```python
+import rclpy
+from llama_ros.langchain import ChatLlamaROS
+from langchain_core.messages import HumanMessage
+from langchain.tools import tool
+from random import randint
+
+rclpy.init()
+
+
+@tool
+def get_inhabitants(city: str) -> int:
+    """Get the current number of inhabitants of a city"""
+    return randint(4_000_000, 8_000_000)
+
+
+@tool
+def get_curr_temperature(city: str) -> int:
+    """Get the current temperature of a city"""
+    return randint(20, 30)
+
+
+chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_llama_template=True)
+
+messages = [
+    HumanMessage(
+        "What is the current temperature in Madrid? And its inhabitants?"
+    )
+]
+
+llm_tools = chat.bind_tools(
+    [get_inhabitants, get_curr_temperature], tool_choice="any"
+)
+
+all_tools_res = llm_tools.invoke(messages)
+messages.append(all_tools_res)
+
+for tool_call in all_tools_res.tool_calls:
+    selected_tool = {
+        "get_inhabitants": get_inhabitants,
+        "get_curr_temperature": get_curr_temperature,
+    }[tool_call["name"]]
+
+    tool_msg = selected_tool.invoke(tool_call)
+
+    formatted_output = (
+        f"{tool_call['name']}({', '.join(map(str, tool_call['args'].values()))})"
+        f" = {tool_msg.content}"
+    )
+    print(f"Calling tool: {formatted_output}")
+
+    tool_msg.additional_kwargs = {"args": tool_call["args"]}
+    messages.append(tool_msg)
+
+res = chat.invoke(messages)
+
+print(f"Response: {res.content}")
+
+rclpy.shutdown()
+```
+
+</details>
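+
+Under the hood, `bind_tools` builds a JSON schema that is used as the sampling `grammar_schema`, so the model is forced to answer with a `tool_calls` array; `ChatLlamaROS` then parses that array into the `tool_calls` field of the returned `AIMessage`. The standalone snippet below only illustrates that contract (the values and the `_0` id suffix are made up; the real implementation generates a random UUID for each id):
+
+```python
+import json
+
+# Example of what the model is constrained to generate when tools are bound
+# (values are illustrative only):
+raw_model_output = json.dumps(
+    {
+        "tool_calls": [
+            {"name": "get_curr_temperature", "arguments": {"city": "Madrid"}},
+            {"name": "get_inhabitants", "arguments": {"city": "Madrid"}},
+        ]
+    }
+)
+
+# Each entry becomes a LangChain tool call on the AIMessage, which is what the
+# loop in the example above iterates over:
+parsed = json.loads(raw_model_output)
+tool_calls = [
+    {
+        "name": t["name"],
+        "args": t["arguments"],
+        "type": "tool_call",
+        "id": f"{t['name']}_0",
+    }
+    for t in parsed["tool_calls"]
+]
+print(tool_calls)
+```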
+ ## Demos ### LLM Demo @@ -868,6 +934,20 @@ ros2 run llama_demos chatllama_demo_node [ChatLlamaROS demo](https://github-production-user-asset-6210df.s3.amazonaws.com/55236157/363094669-c6de124a-4e91-4479-99b6-685fecb0ac20.webm?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240830%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240830T081232Z&X-Amz-Expires=300&X-Amz-Signature=f937758f4bcbaec7683e46ddb057fb642dc86a33cc8c736fca3b5ce2bf06ddac&X-Amz-SignedHeaders=host&actor_id=55236157&key_id=0&repo_id=622137360) +### Tools Demo + +```shell +ros2 llama launch MiniCPM-2.6.yaml +``` + +```shell +ros2 run llama_demos chatllama_tools_node +``` + + + +[Tools ChatLlama](https://github.com/user-attachments/assets/b912ee29-1466-4d6a-888b-9a2d9c16ae1d) + #### Full Demo (LLM + chat template + RAG + Reranking + Stream) ```shell diff --git a/llama_demos/CMakeLists.txt b/llama_demos/CMakeLists.txt index 42a1d0c..375e690 100644 --- a/llama_demos/CMakeLists.txt +++ b/llama_demos/CMakeLists.txt @@ -44,5 +44,11 @@ install(PROGRAMS RENAME chatllama_demo_node ) +install(PROGRAMS + llama_demos/chatllama_tools_node.py + DESTINATION lib/${PROJECT_NAME} + RENAME chatllama_tools_node +) + ament_python_install_package(${PROJECT_NAME}) ament_package() diff --git a/llama_demos/llama_demos/chatllama_demo_node.py b/llama_demos/llama_demos/chatllama_demo_node.py index 4acd478..8505fb7 100644 --- a/llama_demos/llama_demos/chatllama_demo_node.py +++ b/llama_demos/llama_demos/chatllama_demo_node.py @@ -54,6 +54,7 @@ def send_prompt(self) -> None: self.chat = ChatLlamaROS( temp=0.2, penalty_last_n=8, + use_gguf_template=False, ) self.prompt = ChatPromptTemplate.from_messages( diff --git a/llama_demos/llama_demos/chatllama_tools_node.py b/llama_demos/llama_demos/chatllama_tools_node.py new file mode 100644 index 0000000..8209a16 --- /dev/null +++ b/llama_demos/llama_demos/chatllama_tools_node.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 + +# MIT License + +# Copyright (c) 2024 Alejandro González Cantón +# Copyright (c) 2024 Miguel Ángel González Santamarta + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+
+
+import time
+
+import rclpy
+from rclpy.node import Node
+from llama_ros.langchain import ChatLlamaROS
+from langchain_core.messages import HumanMessage
+from langchain.tools import tool
+from random import randint
+
+
+@tool
+def get_inhabitants(city: str) -> int:
+    """Get the current number of inhabitants of a city"""
+    return randint(4_000_000, 8_000_000)
+
+
+@tool
+def get_curr_temperature(city: str) -> int:
+    """Get the current temperature of a city"""
+    return randint(20, 30)
+
+
+class ChatLlamaToolsDemoNode(Node):
+
+    def __init__(self) -> None:
+        super().__init__("chat_tools_demo_node")
+
+        self.initial_time = -1
+        self.tools_time = -1
+        self.eval_time = -1
+
+    def send_prompt(self) -> None:
+        self.chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_llama_template=True)
+
+        messages = [
+            HumanMessage(
+                "What is the current temperature in Madrid? And its inhabitants?"
+            )
+        ]
+
+        self.get_logger().info(f"\nPrompt: {messages[0].content}")
+
+        llm_tools = self.chat.bind_tools(
+            [get_inhabitants, get_curr_temperature], tool_choice="any"
+        )
+
+        self.initial_time = time.time()
+        all_tools_res = llm_tools.invoke(messages)
+        self.tools_time = time.time()
+
+        messages.append(all_tools_res)
+
+        for tool_call in all_tools_res.tool_calls:
+            selected_tool = {
+                "get_inhabitants": get_inhabitants,
+                "get_curr_temperature": get_curr_temperature,
+            }[tool_call["name"]]
+
+            tool_msg = selected_tool.invoke(tool_call)
+
+            formatted_output = (
+                f"{tool_call['name']}({', '.join(map(str, tool_call['args'].values()))})"
+                f" = {tool_msg.content}"
+            )
+            self.get_logger().info(f"Calling tool: {formatted_output}")
+
+            tool_msg.additional_kwargs = {"args": tool_call["args"]}
+            messages.append(tool_msg)
+
+        res = self.chat.invoke(messages)
+
+        self.eval_time = time.time()
+
+        self.get_logger().info(f"\nResponse: {res.content}")
+
+        time_generate_tools = self.tools_time - self.initial_time
+        time_last_response = self.eval_time - self.tools_time
+        self.get_logger().info(f"Time to generate tools: {time_generate_tools:.2f} s")
+        self.get_logger().info(
+            f"Time to generate last response: {time_last_response:.2f} s"
+        )
+
+
+def main():
+    rclpy.init()
+    node = ChatLlamaToolsDemoNode()
+    node.send_prompt()
+    rclpy.shutdown()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llama_ros/CMakeLists.txt b/llama_ros/CMakeLists.txt
index 129b7d1..4ff73eb 100644
--- a/llama_ros/CMakeLists.txt
+++ b/llama_ros/CMakeLists.txt
@@ -101,5 +101,9 @@ install(TARGETS
   DESTINATION lib/${PROJECT_NAME}
 )
 
+install(DIRECTORY
+  DESTINATION share/${PROJECT_NAME}
+)
+
 ament_python_install_package(${PROJECT_NAME})
 ament_package()
diff --git a/llama_ros/llama_ros/langchain/chat_llama_ros.py b/llama_ros/llama_ros/langchain/chat_llama_ros.py
index a9234be..fc8027f 100644
--- a/llama_ros/llama_ros/langchain/chat_llama_ros.py
+++ b/llama_ros/llama_ros/langchain/chat_llama_ros.py
@@ -21,24 +21,99 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-from typing import Any, List, Optional, Dict, Iterator +from typing import ( + Any, + Callable, + List, + Literal, + Optional, + Dict, + Iterator, + Sequence, + Type, + Union, + Tuple, +) +from operator import itemgetter +from langchain_core.output_parsers import ( + PydanticToolsParser, + JsonOutputKeyToolsParser, + PydanticOutputParser, + JsonOutputParser, +) +from langchain_core.output_parsers.base import OutputParserLike +from langchain_core.runnables import RunnablePassthrough, RunnableMap +from langchain_core.utils.pydantic import is_basemodel_subclass import base64 import cv2 import numpy as np +import jinja2 +from jinja2.sandbox import ImmutableSandboxedEnvironment +from pydantic import BaseModel +import uuid +from ament_index_python.packages import get_package_share_directory from llama_ros.langchain import LlamaROSCommon from llama_msgs.msg import Message -from llama_msgs.srv import FormatChatMessages from action_msgs.msg import GoalStatus +from llama_msgs.srv import Detokenize +from llama_msgs.srv import FormatChatMessages +from llama_msgs.action import GenerateResponse +from langchain_core.utils.function_calling import convert_to_openai_tool +import json from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain_core.language_models import BaseChatModel -from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk +from langchain_core.tools import BaseTool +from langchain_core.runnables import Runnable +from langchain_core.language_models import LanguageModelInput +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +DEFAULT_TEMPLATE = """{% if tools_grammar %} + {{- '<|im_start|>assistant\n' }} + {{- 'You are an assistant. You output in JSON format. The key "tool_calls" is a list of possible tools. For each tool, the format is {name, arguments}. 
You can use the following tools:' }} + {% for tool in tools_grammar %} + {% if not loop.last %} + {{- tool }} + {% else %} + {{- tool + '<|im_end|>' }} + {% endif %} + {% endfor %} +{% endif %} + +{% for message in messages %} + {% if (loop.last and add_generation_prompt) or not loop.last %} + {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }} + {% else %} + {{- '<|im_start|>' + message['role'] + '\n' + message['content'] }} + {% endif %} +{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} + {{- '<|im_start|>assistant' }} +{% endif %} +""" + class ChatLlamaROS(BaseChatModel, LlamaROSCommon): + use_llama_template: bool = False + + use_gguf_template: bool = True + + jinja_env: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment( + loader=jinja2.BaseLoader(), + trim_blocks=True, + lstrip_blocks=True, + ) + @property def _default_params(self) -> Dict[str, Any]: return {} @@ -47,42 +122,184 @@ def _default_params(self) -> Dict[str, Any]: def _llm_type(self) -> str: return "chatllamaros" - def _messages_to_chat_messages( - self, messages: List[BaseMessage] - ) -> tuple[FormatChatMessages.Request, Optional[str], Optional[np.ndarray]]: + def _json_schema_to_definition(self, input_json): + # Extract the tool name + tool_name = input_json["properties"]["name"]["const"] + + # Extract and map arguments to desired format + properties = input_json["properties"]["arguments"]["properties"] + transformed_properties = {arg: prop["type"] for (arg, prop) in properties.items()} + + # Create the transformed object + return {"name": tool_name, "arguments": transformed_properties} + + def _generate_prompt(self, messages: List[dict[str, str]], **kwargs) -> str: + tools_grammar = kwargs.get("tools_grammar", None) + + if self.use_llama_template: + chat_template = DEFAULT_TEMPLATE + else: + chat_template = self.model_metadata.tokenizer.chat_template + + formatted_tools = [] + if tools_grammar: + list_options = json.loads(tools_grammar)["properties"]["tool_calls"]["items"] + for key in list_options.keys(): + if key.endswith("Of"): + list_key = key + + formatted_tools = [ + self._json_schema_to_definition(tool) for tool in list_options[list_key] + ] + formatted_tools = [ + {key: tool[key] for key in sorted(tool, reverse=True)} + for tool in formatted_tools + ] + + bos_token = self.llama_client.detokenize( + Detokenize.Request(tokens=[self.model_metadata.tokenizer.bos_token_id]) + ).text + + if self.use_gguf_template or self.use_llama_template: + formatted_prompt = self.jinja_env.from_string(chat_template).render( + messages=messages, + add_generation_prompt=True, + bos_token=bos_token, + tools_grammar=[json.dumps(tool) for tool in formatted_tools], + ) + return formatted_prompt + else: + ros_messages = [ + Message(content=message["content"], role=message["role"]) + for message in messages + ] + return self.llama_client.format_chat_prompt( + FormatChatMessages.Request(messages=ros_messages) + ).formatted_prompt - chat_messages = FormatChatMessages.Request() + def _convert_content( + self, content: Union[Dict[str, str], str, List[str], List[Dict[str, str]]] + ) -> List[Dict[str, str]]: + if isinstance(content, str): + return {"type": "text", "text": content} + if isinstance(content, list) and len(content) == 1: + return self._convert_content(content[0]) + elif isinstance(content, list): + return [self._convert_content(c) for c in content] + elif isinstance(content, dict): + if content["type"] == "text": + return {"type": "text", "text": 
content["text"]} + elif content["type"] == "image_url": + image_text = content["image_url"]["url"] + if "data:image" in image_text: + image_data = image_text.split(",")[-1] + decoded_image = base64.b64decode(image_data) + np_image = np.frombuffer(decoded_image, np.uint8) + image = cv2.imdecode(np_image, cv2.IMREAD_COLOR) + + return {"type": "image", "image": image} + else: + image_url = image_text + return {"type": "image_url", "image_url": image_url} + + def _convert_message_to_dict(self, message: BaseMessage) -> list[dict[str, str]]: + if isinstance(message, HumanMessage): + converted_msg = [ + {"role": "user", "content": self._convert_content(message.content)} + ] + return converted_msg + + elif isinstance(message, AIMessage): + all_messages = [] + + contents = self._convert_content(message.content) + if isinstance(contents, dict): + contents = [contents] + contents = [ + content + for content in contents + if content["type"] == "text" and content["text"] != "" + ] + + all_messages.extend( + [{"role": "assistant", "content": content} for content in contents] + ) + + return all_messages + + elif isinstance(message, SystemMessage): + converted_msg = [ + {"role": "system", "content": self._convert_content(message.content)} + ] + return converted_msg + + elif isinstance(message, ToolMessage): + tool_args = message.additional_kwargs.get("args", {}) + formatted_args = ", ".join([f"{value}" for _, value in tool_args.items()]) + formatted_content = f"{message.name}({formatted_args}): {message.content}" + return [ + { + "role": "tool", + "content": {"type": "text", "text": formatted_content}, + "tool_call_id": message.tool_call_id, + } + ] + else: + raise ValueError(f"Unsupported message type: {type(message)}") + + def _extract_data_from_messages( + self, messages: List[BaseMessage] + ) -> Tuple[Dict[str, str], Optional[str], Optional[str]]: + new_messages = [] image_url = None image = None - for message in messages: - role = message.type - if role.lower() == "human": - role = "user" + def process_content(role, content): + nonlocal image, image_url + if isinstance(content, str): + new_messages.append({"role": role, "content": content}) + elif isinstance(content, dict): + if content["type"] == "text": + new_messages.append({"role": role, "content": content["text"]}) + elif content["type"] == "image": + image = content["image"] + elif content["type"] == "image_url": + image_url = content["image_url"] - if type(message.content) == str: - chat_messages.messages.append(Message(role=role, content=message.content)) + for message in messages: + role = message["role"] + content = message["content"] + if isinstance(content, list): + for single_content in content: + process_content(role, single_content) else: - for single_content in message.content: - if type(single_content) == str: - chat_messages.messages.append( - Message(role=role, content=single_content) - ) - elif single_content["type"] == "text": - chat_messages.messages.append( - Message(role=role, content=single_content["text"]) - ) - elif single_content["type"] == "image_url": - image_text = single_content["image_url"]["url"] - if "data:image" in image_text: - image_data = image_text.split(",")[-1] - decoded_image = base64.b64decode(image_data) - np_image = np.frombuffer(decoded_image, np.uint8) - image = cv2.imdecode(np_image, cv2.IMREAD_COLOR) - else: - image_url = image_text - - return chat_messages, image_url, image + process_content(role, content) + + return new_messages, image_url, image + + def _create_chat_generations( + self, response: 
GenerateResponse.Result, method: str + ) -> List[BaseMessage]: + chat_gen = None + + if method == "function_calling": + ai_message = AIMessage(content="", tool_calls=[]) + parsed_output = json.loads(response.text) + for tool in parsed_output["tool_calls"]: + ai_message.tool_calls.append( + { + "name": tool["name"], + "args": tool["arguments"], + "type": "tool_call", + "id": f'{tool["name"]}_{uuid.uuid4()}', + } + ) + + chat_gen = ChatGeneration(message=ai_message) + else: + chat_gen = ChatGeneration(message=AIMessage(content=response.text)) + + return ChatResult(generations=[chat_gen]) def _generate( self, @@ -91,23 +308,25 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: + dict_messages = [] + for message in messages: + dict_messages.extend(self._convert_message_to_dict(message)) - llama_client = self.llama_client.get_instance() + chat_messages, image_url, image = self._extract_data_from_messages(dict_messages) - chat_messages, image_url, image = self._messages_to_chat_messages(messages) - formatted_prompt = llama_client.format_chat_prompt(chat_messages).formatted_prompt + formatted_prompt = self._generate_prompt(chat_messages, **kwargs) goal_action = self._create_action_goal( formatted_prompt, stop, image_url, image, **kwargs ) result, status = self.llama_client.generate_response(goal_action) + response = result.response if status != GoalStatus.STATUS_SUCCEEDED: return "" - generation = ChatGeneration(message=AIMessage(content=result.response.text)) - return ChatResult(generations=[generation]) + return self._create_chat_generations(response, kwargs.get("method", "chat")) def _stream( self, @@ -116,11 +335,18 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + if kwargs.get("method") == "function_calling": + raise ValueError( + "Streaming is not supported when using 'function_calling' method." + ) + + dict_messages = [] + for message in messages: + dict_messages.extend(self._convert_message_to_dict(message)) - llama_client = self.llama_client.get_instance() + chat_messages, image_url, image = self._extract_data_from_messages(dict_messages) - chat_messages, image_url, image = self._messages_to_chat_messages(messages) - formatted_prompt = llama_client.format_chat_prompt(chat_messages).formatted_prompt + formatted_prompt = self._generate_prompt(chat_messages) goal_action = self._create_action_goal( formatted_prompt, stop, image_url, image, **kwargs @@ -135,3 +361,143 @@ def _stream( ) yield ChatGenerationChunk(message=AIMessageChunk(content=pt.text)) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + *, + tool_choice: Optional[Union[Dict[str, Dict], bool, str]] = None, + method: Literal[ + "function_calling", "json_schema", "json_mode" + ] = "function_calling", + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model + + tool_choice: does not currently support "any", "auto" choices like OpenAI + tool-calling API. should be a dict of the form to force this tool + {"type": "function", "function": {"name": <>}}. 
+ """ + + formatted_tools = [] + + for tool in tools: + formatted_tools.append(convert_to_openai_tool(tool)["function"]) + + tool_names = [ft["name"] for ft in formatted_tools] + valid_choices = ["all", "one", "any"] + + is_valid_choice = tool_choice in valid_choices + + chosen_tool = [f for f in formatted_tools if f["name"] == tool_choice] + + if not chosen_tool and not is_valid_choice: + raise ValueError( + f"Tool choice {tool_choice=} was specified, but the only " + f"provided tools were {tool_names}." + ) + + grammar = {} + + if method == "json_mode" or method == "json_schema": + grammar = chosen_tool[0]["parameters"] + else: + grammar = { + "type": "object", + "properties": { + "tool_calls": { + "type": "array", + "items": {"type": "object"}, + "maxItems": 10, + } + }, + "required": ["tool_calls"], + } + + if chosen_tool: + grammar["properties"]["tool_calls"]["items"]["oneOf"] = [] + new_action = { + "type": "object", + "properties": { + "name": {"type": "string", "const": chosen_tool[0]["name"]}, + "arguments": chosen_tool[0]["parameters"], + }, + "required": ["name", "arguments"], + } + grammar["properties"]["tool_calls"]["items"]["oneOf"].append(new_action) + else: + grammar["properties"]["tool_calls"]["items"][f"{tool_choice}Of"] = [] + for tool in formatted_tools: + new_action = { + "type": "object", + "properties": { + "name": {"type": "string", "const": tool["name"]}, + "arguments": tool["parameters"], + }, + "required": ["name", "arguments"], + } + grammar["properties"]["tool_calls"]["items"][ + f"{tool_choice}Of" + ].append(new_action) + + return super().bind(tools_grammar=json.dumps(grammar), method=method, **kwargs) + + def with_structured_output( + self, + schema: Optional[Union[Dict, Type[BaseModel], Type]] = None, + *, + include_raw: bool = False, + method: Literal[ + "function_calling", "json_schema", "json_mode" + ] = "function_calling", + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + if kwargs: + raise ValueError(f"Received unsupported arguments {kwargs}") + is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema) + + if method == "json_mode" or method == "json_schema": + tool_name = schema.__name__ + + llm = self.bind_tools([schema], tool_choice=tool_name, method=method) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + elif method == "function_calling": + if schema is None: + raise ValueError( + "schema must be specified when method is not 'json_mode'. " + "Received None." + ) + schema = convert_to_openai_tool(schema)["function"] + tool_name = schema["name"] + + llm = self.bind_tools([schema], tool_choice=tool_name, method=method) + if is_pydantic_schema: + output_parser: OutputParserLike = PydanticToolsParser( + tools=[schema], # type: ignore[list-item] + first_tool_only=True, # type: ignore[list-item] + ) + else: + output_parser = JsonOutputKeyToolsParser( + key_name=tool_name, first_tool_only=True + ) + else: + raise ValueError( + f"Unrecognized method argument. Expected one of 'function_calling' or " + f"'json_mode'. 
Received: '{method}'" + ) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | output_parser diff --git a/llama_ros/llama_ros/langchain/llama_ros_common.py b/llama_ros/llama_ros/langchain/llama_ros_common.py index 0ed7c31..9adfd91 100644 --- a/llama_ros/llama_ros/langchain/llama_ros_common.py +++ b/llama_ros/llama_ros/langchain/llama_ros_common.py @@ -33,13 +33,16 @@ from llama_ros.llama_client_node import LlamaClientNode from llama_msgs.action import GenerateResponse +from llama_msgs.srv import GetMetadata from llama_msgs.msg import LogitBias +from llama_msgs.msg import Metadata class LlamaROSCommon(BaseLanguageModel, ABC): llama_client: LlamaClientNode = None cv_bridge: CvBridge = CvBridge() + model_metadata: Metadata = None # sampling params n_prev: int = 64 @@ -92,6 +95,9 @@ class Config: @classmethod def validate_environment(cls, values: Dict) -> Dict: values["llama_client"] = LlamaClientNode.get_instance() + values["model_metadata"] = ( + values["llama_client"].get_metadata(GetMetadata.Request()).metadata + ) return values def cancel(self) -> None: @@ -103,6 +109,8 @@ def _create_action_goal( stop: Optional[List[str]] = None, image_url: Optional[str] = None, image: Optional[np.ndarray] = None, + tools_grammar: Optional[str] = None, + **kwargs ) -> GenerateResponse.Result: goal = GenerateResponse.Goal() @@ -167,7 +175,9 @@ def _create_action_goal( goal.sampling_config.samplers_sequence = self.samplers_sequence goal.sampling_config.grammar = self.grammar - goal.sampling_config.grammar_schema = self.grammar_schema + goal.sampling_config.grammar_schema = ( + tools_grammar if tools_grammar else self.grammar_schema + ) goal.sampling_config.penalty_prompt_tokens = self.penalty_prompt_tokens goal.sampling_config.use_penalty_prompt_tokens = self.use_penalty_prompt_tokens diff --git a/llama_ros/llama_ros/llama_client_node.py b/llama_ros/llama_ros/llama_client_node.py index 1d4625a..eceb862 100644 --- a/llama_ros/llama_ros/llama_client_node.py +++ b/llama_ros/llama_ros/llama_client_node.py @@ -125,6 +125,7 @@ def __init__(self, namespace: str = "llama") -> None: self._spin_thread.start() def get_metadata(self, req: GetMetadata.Request) -> GetMetadata: + self._get_metadata_srv_client.wait_for_service() return self._get_metadata_srv_client.call(req) def tokenize(self, req: Tokenize.Request) -> Tokenize.Response: diff --git a/llama_ros/src/llama_ros/llama.cpp b/llama_ros/src/llama_ros/llama.cpp index e68d261..b4b3ff2 100644 --- a/llama_ros/src/llama_ros/llama.cpp +++ b/llama_ros/src/llama_ros/llama.cpp @@ -369,7 +369,7 @@ struct Metadata Llama::get_metadata() { metadata.tokenizer.add_bos_token = this->get_metadata("tokenizer.ggml.add_bos_token", 8) == "true"; metadata.tokenizer.chat_template = - this->get_metadata("tokenizer.chat_template", 2048); + this->get_metadata("tokenizer.chat_template", 4096); return metadata; }
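
Note on the new `grammar_schema` plumbing: the `tools_grammar` string produced by `bind_tools` is what `LlamaROSCommon._create_action_goal` now places in `goal.sampling_config.grammar_schema`. As a rough, illustrative reconstruction (not taken from the PR's output, and with argument descriptions omitted), the schema generated for the two demo tools bound with `tool_choice="any"` would look approximately like this:

```python
import json


# Schema of one tool-call entry: a constant "name" plus the tool's argument schema.
def tool_entry(name: str) -> dict:
    return {
        "type": "object",
        "properties": {
            "name": {"type": "string", "const": name},
            "arguments": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
        "required": ["name", "arguments"],
    }


# With tool_choice="any", bind_tools collects every tool under an "anyOf" list
# inside the "tool_calls" array, which is then enforced as the sampling grammar.
tools_grammar = {
    "type": "object",
    "properties": {
        "tool_calls": {
            "type": "array",
            "items": {
                "type": "object",
                "anyOf": [
                    tool_entry("get_inhabitants"),
                    tool_entry("get_curr_temperature"),
                ],
            },
            "maxItems": 10,
        }
    },
    "required": ["tool_calls"],
}

print(json.dumps(tools_grammar, indent=2))
```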