README and Black formatter

agonzc34 committed Jan 9, 2025
1 parent f83dc2e commit 6f8cd07
Showing 3 changed files with 103 additions and 10 deletions.
84 changes: 82 additions & 2 deletions README.md
@@ -734,8 -734,7 @@ rclpy.shutdown()

</details>

-#### chat_llama_ros
-
+#### chat_llama_ros (Chat + LVM)
<details>
<summary>Click to expand</summary>

@@ -778,6 +777,73 @@ rclpy.shutdown()

</details>

#### 🎉 ***NEW*** chat_llama_ros (Tools) 🎉

<details>
<summary>Click to expand</summary>

The current implementation of Tools lets the model invoke user-defined tools without requiring a model specifically trained for tool use.

```python
import rclpy
from llama_ros.langchain import ChatLlamaROS
from langchain_core.messages import HumanMessage
from langchain.tools import tool
from random import randint

rclpy.init()


@tool
def get_inhabitants(city: str) -> int:
    """Get the current number of inhabitants of a city"""
    return randint(4_000_000, 8_000_000)


@tool
def get_curr_temperature(city: str) -> int:
    """Get the current temperature of a city"""
    return randint(20, 30)


chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_llama_template=True)

messages = [
    HumanMessage(
        "What is the current temperature in Madrid? And its inhabitants?"
    )
]

# Bind the tools; tool_choice="any" asks the model to call a tool
llm_tools = chat.bind_tools(
    [get_inhabitants, get_curr_temperature], tool_choice="any"
)

all_tools_res = llm_tools.invoke(messages)
messages.append(all_tools_res)

for tool_call in all_tools_res.tool_calls:
    # Look up and execute the tool requested by the model
    selected_tool = {
        "get_inhabitants": get_inhabitants,
        "get_curr_temperature": get_curr_temperature,
    }[tool_call["name"]]

    tool_msg = selected_tool.invoke(tool_call)

    formatted_output = (
        f"{tool_call['name']}({', '.join(tool_call['args'].values())})"
        f" = {tool_msg.content}"
    )
    print(f"Calling tool: {formatted_output}")

    # Feed the tool result back to the model
    tool_msg.additional_kwargs = {"args": tool_call["args"]}
    messages.append(tool_msg)

res = chat.invoke(messages)

print(f"Response: {res.content}")

rclpy.shutdown()
```
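
When run, the example prints output along these lines (illustrative only: the tool values are random, the call order is chosen by the model, and the final answer depends on the loaded model):

```
Calling tool: get_curr_temperature(Madrid) = 24
Calling tool: get_inhabitants(Madrid) = 5431210
Response: The current temperature in Madrid is 24 degrees, and the city has around 5,431,210 inhabitants.
```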

</details>

## Demos

### LLM Demo
@@ -868,6 +934,20 @@ ros2 run llama_demos chatllama_demo_node

[ChatLlamaROS demo](https://github-production-user-asset-6210df.s3.amazonaws.com/55236157/363094669-c6de124a-4e91-4479-99b6-685fecb0ac20.webm?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240830%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240830T081232Z&X-Amz-Expires=300&X-Amz-Signature=f937758f4bcbaec7683e46ddb057fb642dc86a33cc8c736fca3b5ce2bf06ddac&X-Amz-SignedHeaders=host&actor_id=55236157&key_id=0&repo_id=622137360)

### Tools Demo

```shell
ros2 llama launch MiniCPM-2.6.yaml
```

```shell
ros2 run llama_demos chatllama_tools_node
```

[Tools ChatLlama](https://github.com/user-attachments/assets/b912ee29-1466-4d6a-888b-9a2d9c16ae1d)

#### Full Demo (LLM + chat template + RAG + Reranking + Stream)

28 changes: 21 additions & 7 deletions llama_demos/llama_demos/chatllama_tools_node.py
@@ -29,7 +29,7 @@
import rclpy
from rclpy.node import Node
from llama_ros.langchain import ChatLlamaROS
-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.messages import HumanMessage
from langchain.tools import tool
from random import randint

@@ -63,9 +63,11 @@ def send_prompt(self) -> None:
"What is the current temperature in Madrid? And its inhabitants?"
)
]

self.get_logger().info(f"\nPrompt: {messages[0].content}")

llm_tools = self.chat.bind_tools(
-[get_inhabitants, get_curr_temperature], tool_choice='any'
+[get_inhabitants, get_curr_temperature], tool_choice="any"
)

self.initial_time = time.time()
@@ -76,20 +78,32 @@ def send_prompt(self) -> None:

for tool in all_tools_res.tool_calls:
selected_tool = {
"get_inhabitants": get_inhabitants, "get_curr_temperature": get_curr_temperature
"get_inhabitants": get_inhabitants,
"get_curr_temperature": get_curr_temperature
}[tool['name']]

tool_msg = selected_tool.invoke(tool)
-tool_msg.additional_kwargs = {'args': tool['args']}

formatted_output = f"{tool['name']}({''.join(tool['args'].values())}) = {tool_msg.content}"
self.get_logger().info(f'Calling tool: {formatted_output}')

tool_msg.additional_kwargs = {"args": tool["args"]}
messages.append(tool_msg)

res = self.chat.invoke(messages)

self.eval_time = time.time()

-self.get_logger().info(res.content)
+self.get_logger().info(f"\nResponse: {res.content}")

self.get_logger().info(f"Time to generate tools: {self.tools_time - self.initial_time} s")
self.get_logger().info(f"Time to generate last response: {self.eval_time - self.tools_time} s")
time_generate_tools = self.tools_time - self.initial_time
time_last_response = self.eval_time - self.tools_time
self.get_logger().info(
f"Time to generate tools: {time_generate_tools:.2} s"
)
self.get_logger().info(
f"Time to generate last response: {time_last_response:.2} s"
)


def main():
1 change: 0 additions & 1 deletion llama_ros/CMakeLists.txt
@@ -102,7 +102,6 @@ install(TARGETS
)

install(DIRECTORY
-templates
DESTINATION share/${PROJECT_NAME}
)

