Skip to content

Commit

Permalink
Manipulation demo: Add streamlit interface (#337)
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Dąbrowski <[email protected]>
  • Loading branch information
knicked authored Dec 23, 2024
1 parent 542616e commit 9564f4f
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 37 deletions.
8 changes: 7 additions & 1 deletion docs/demos/manipulation.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ This demo showcases the capabilities of RAI in performing manipulation tasks usi
ros2 launch examples/manipulation-demo.launch.py game_launcher:=path/to/RAIManipulationDemo.GameLauncher
```

2. In the second terminal, run the interactive prompt:
2. In the second terminal, run the streamlit interface:

```shell
streamlit run examples/manipulation-demo-streamlit.py
```

Alternatively, you can run the simpler command-line version, which also serves as an example of how to use the RAI API for you own applications:

```shell
python examples/manipulation-demo.py
Expand Down
74 changes: 74 additions & 0 deletions examples/manipulation-demo-streamlit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language goveself.rning permissions and
# limitations under the License.

import importlib

import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke
from rai.messages import HumanMultimodalMessage

manipulation_demo = importlib.import_module("manipulation-demo")


@st.cache_resource
def initialize_graph():
return manipulation_demo.create_agent()


def main():
st.set_page_config(
page_title="RAI Manipulation Demo",
page_icon=":robot:",
)
st.title("RAI Manipulation Demo")
st.markdown("---")

st.sidebar.header("Tool Calls History")

if "graph" not in st.session_state:
graph = initialize_graph()
st.session_state["graph"] = graph

if "messages" not in st.session_state:
st.session_state["messages"] = [
AIMessage(content="Hi! I am a robotic arm. What can I do for you?")
]

prompt = st.chat_input()
for msg in st.session_state.messages:
if isinstance(msg, AIMessage):
if msg.content:
st.chat_message("assistant").write(msg.content)
elif isinstance(msg, HumanMultimodalMessage):
continue
elif isinstance(msg, HumanMessage):
st.chat_message("user").write(msg.content)
elif isinstance(msg, ToolMessage):
with st.sidebar.expander(f"Tool: {msg.name}", expanded=False):
st.code(msg.content, language="json")

if prompt:
st.session_state.messages.append(HumanMessage(content=prompt))
st.chat_message("user").write(prompt)
with st.chat_message("assistant"):
st_callback = get_streamlit_cb(st.container())
streamlit_invoke(
st.session_state["graph"], st.session_state.messages, [st_callback]
)


if __name__ == "__main__":
main()
84 changes: 48 additions & 36 deletions examples/manipulation-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import rclpy
import rclpy.qos
from langchain_core.messages import HumanMessage

from rai.agents.conversational_agent import create_conversational_agent
Expand All @@ -21,45 +22,56 @@
from rai.tools.ros.native import GetCameraImage, Ros2GetTopicsNamesAndTypesTool
from rai.utils.model_initialization import get_llm_model

rclpy.init()
node = RaiBaseNode(node_name="manipulation_demo")
node.declare_parameter("conversion_ratio", 1.0)

tools = [
GetObjectPositionsTool(
node=node,
target_frame="panda_link0",
source_frame="RGBDCamera5",
camera_topic="/color_image5",
depth_topic="/depth_image5",
camera_info_topic="/color_camera_info5",
),
MoveToPointTool(node=node, manipulator_frame="panda_link0"),
GetCameraImage(node=node),
Ros2GetTopicsNamesAndTypesTool(node=node),
]
def create_agent():
rclpy.init()
node = RaiBaseNode(node_name="manipulation_demo")
node.declare_parameter("conversion_ratio", 1.0)
node.qos_profile.reliability = rclpy.qos.ReliabilityPolicy.RELIABLE

llm = get_llm_model(model_type="complex_model")
tools = [
GetObjectPositionsTool(
node=node,
target_frame="panda_link0",
source_frame="RGBDCamera5",
camera_topic="/color_image5",
depth_topic="/depth_image5",
camera_info_topic="/color_camera_info5",
),
MoveToPointTool(node=node, manipulator_frame="panda_link0"),
GetCameraImage(node=node),
Ros2GetTopicsNamesAndTypesTool(node=node),
]

system_prompt = """
You are a robotic arm with interfaces to detect and manipulate objects.
Here are the coordinates information:
x - front to back (positive is forward)
y - left to right (positive is right)
z - up to down (positive is up)
llm = get_llm_model(model_type="complex_model", streaming=True)

Before starting the task, make sure to grab the camera image to understand the environment.
"""
system_prompt = """
You are a robotic arm with interfaces to detect and manipulate objects.
Here are the coordinates information:
x - front to back (positive is forward)
y - left to right (positive is right)
z - up to down (positive is up)
agent = create_conversational_agent(
llm=llm,
tools=tools,
system_prompt=system_prompt,
)
Before starting the task, make sure to grab the camera image to understand the environment.
"""

messages = []
while True:
prompt = input("Enter a prompt: ")
messages.append(HumanMessage(content=prompt))
output = agent.invoke({"messages": messages})
output["messages"][-1].pretty_print()
agent = create_conversational_agent(
llm=llm,
tools=tools,
system_prompt=system_prompt,
)
return agent


def main():
agent = create_agent()
messages = []
while True:
prompt = input("Enter a prompt: ")
messages.append(HumanMessage(content=prompt))
output = agent.invoke({"messages": messages})
output["messages"][-1].pretty_print()


if __name__ == "__main__":
main()

0 comments on commit 9564f4f

Please sign in to comment.