Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 2603 #2671

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified db.sqlite3
Binary file not shown.
51 changes: 41 additions & 10 deletions exercises/static/exercises/dl_digit_classifier/exercise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import cv2
import numpy as np
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, QuantType
from websocket_server import WebsocketServer

from gui import GUI, ThreadGUI
Expand Down Expand Up @@ -45,7 +46,8 @@ def __init__(self):
self.gui = GUI(self.host, self.hal)

# The process function
def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
# The process function
def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
"""
Given a DL model in onnx format, yield prediction per frame.
:param raw_dl_model: raw DL model transferred through websocket
Expand All @@ -60,16 +62,44 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
raw_dl_model_bytes = raw_dl_model.encode('ascii')
raw_dl_model_bytes = base64.b64decode(raw_dl_model_bytes)

# Load ONNX model
# Load and optimize ONNX model
ort_session = None
try:
with open(self.aux_model_fname, "wb") as f:
f.write(raw_dl_model_bytes)
ort_session = onnxruntime.InferenceSession(self.aux_model_fname)

# Load the original model
model = onnx.load(self.aux_model_fname)

# Apply optimizations directly using ONNX Runtime
model_optimized = onnx.optimizer.optimize(model, passes=[
"eliminate_identity",
"eliminate_deadend",
"eliminate_nop_dropout",
"eliminate_nop_transpose",
"fuse_bn_into_conv",
"fuse_consecutive_transposes",
"fuse_pad_into_conv",
"fuse_transpose_into_gemm",
"lift_lexical_references",
"nop_elimination",
"split_init"
])

# Save the optimized model
optimized_model_fname = "optimized_model.onnx"
onnx.save(model_optimized, optimized_model_fname)

# Quantize the model
quantized_model_fname = "quantized_model.onnx"
quantize_dynamic(optimized_model_fname, quantized_model_fname, weight_type=QuantType.QInt8)

# Load the quantized model
ort_session = onnxruntime.InferenceSession(quantized_model_fname)
except Exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
print(str(exc_value))
print("ERROR: Model couldn't be loaded")
print("ERROR: Model couldn't be loaded or optimized")

try:
# Init auxiliar variables used for stabilized predictions
Expand Down Expand Up @@ -102,10 +132,10 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
pred = int(np.argmax(output, axis=1)) # get the index of the max log-probability

end = time.time()
frame_time = round(end-start, 3)
fps = 1.0/frame_time
frame_time = round(end - start, 3)
fps = 1.0 / frame_time
# number of consecutive frames that must be reached to consider a validprediction
n_consecutive_frames = int(fps/2)
n_consecutive_frames = int(fps / 2)

# For stability, only show digit if detected in more than n consecutive frames
if pred != previous_established_pred:
Expand All @@ -122,9 +152,9 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
previous_pred = pred

# Show region used as ROI
cv2.rectangle(input_image,pt2=(w_border, h_border),pt1=(w_border + w_roi, h_border + h_roi),color=(255, 0, 0),thickness=3)
cv2.rectangle(input_image, pt2=(w_border, h_border), pt1=(w_border + w_roi, h_border + h_roi), color=(255, 0, 0), thickness=3)
# Show FPS count
cv2.putText(input_image, "FPS: {}".format(int(fps)), (7,25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1)
cv2.putText(input_image, "FPS: {}".format(int(fps)), (7, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

# Send result
self.gui.showResult(input_image, str(previous_established_pred))
Expand All @@ -136,7 +166,7 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):

self.iteration_counter += 1

# The code should be run for atleast the target time step
# The code should be run for at least the target time step
# If it's less put to sleep
if (ms < self.ideal_cycle):
time.sleep((self.ideal_cycle - ms) / 1000.0)
Expand All @@ -149,6 +179,7 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
exc_type, exc_value, exc_traceback = sys.exc_info()
print(str(exc_value))


# Function to measure the frequency of iterations
def measure_frequency(self):
previous_time = datetime.now()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os.path
from typing import Callable

from src.manager.libs.applications.compatibility.exercise_wrapper_ros2 import CompatibilityExerciseWrapperRos2


class Exercise(CompatibilityExerciseWrapperRos2):
def __init__(self, circuit: str, update_callback: Callable):
current_path = os.path.dirname(__file__)

super(Exercise, self).__init__(exercise_command=f"{current_path}/../../python_template/ros2_humble/exercise.py 0.0.0.0",
gui_command=f"{current_path}/../../python_template/ros2_humble/gui.py 0.0.0.0 {circuit}",
update_callback=update_callback)
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import launch
from launch import LaunchDescription
from launch_ros.actions import Node

def generate_launch_description():
return LaunchDescription([
Node(
package='v4l2_camera',
executable='v4l2_camera_node',
name='v4l2_camera_node',
parameters=[
{'video_device': '/dev/video0'}
],
),
])

Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import json
import os
import rclpy
import cv2
import sys
import base64
import threading
import time
import numpy as np
from datetime import datetime
import websocket
import subprocess
import logging

from hal_interfaces.general.odometry import OdometryNode
from console import start_console


# Graphical User Interface Class
class GUI:
# Initialization function
# The actual initialization
def __init__(self, host):

self.payload = {'image': ''}

# ROS2 init
if not rclpy.ok():
rclpy.init(args=None)


# Image variables
self.image_to_be_shown = None
self.digit_to_be_shown = None
self.image_to_be_shown_updated = False
self.image_show_lock = threading.Lock()
self.host = host
self.client = None



self.ack = False
self.ack_lock = threading.Lock()

# Create the lap object
# TODO: maybe move this to HAL and have it be hybrid


self.client_thread = threading.Thread(target=self.run_websocket)
self.client_thread.start()

def run_websocket(self):
while True:
print("GUI WEBSOCKET CONNECTED")
self.client = websocket.WebSocketApp(self.host, on_message=self.on_message)
self.client.run_forever(ping_timeout=None, ping_interval=0)

# Function to prepare image payload
# Encodes the image as a JSON string and sends through the WS
def payloadImage(self):
with self.image_show_lock:
image_to_be_shown_updated = self.image_to_be_shown_updated
image_to_be_shown = self.image_to_be_shown

image = image_to_be_shown
payload = {'image': '', 'shape': ''}

if not image_to_be_shown_updated:
return payload

shape = image.shape
frame = cv2.imencode('.JPEG', image)[1]
encoded_image = base64.b64encode(frame)

payload['image'] = encoded_image.decode('utf-8')
payload['shape'] = shape

with self.image_show_lock:
self.image_to_be_shown_updated = False

return payload

# Function for student to call
def showImage(self, image):
with self.image_show_lock:
self.image_to_be_shown = image
self.image_to_be_shown_updated = True

# Update the gui
def update_gui(self):
# print("GUI update")
# Payload Image Message
payload = self.payloadImage()
self.payload["image"] = json.dumps(payload)


message = json.dumps(self.payload)
if self.client:
try:
self.client.send(message)
# print(message)
except Exception as e:
print(f"Error sending message: {e}")

def on_message(self, ws, message):
"""Handles incoming messages from the websocket client."""
if message.startswith("#ack"):
# print("on message" + str(message))
self.set_acknowledge(True)

def get_acknowledge(self):
"""Gets the acknowledge status."""
with self.ack_lock:
ack = self.ack

return ack

def set_acknowledge(self, value):
"""Sets the acknowledge status."""
with self.ack_lock:
self.ack = value


class ThreadGUI:
"""Class to manage GUI updates and frequency measurements in separate threads."""

def __init__(self, gui):
"""Initializes the ThreadGUI with a reference to the GUI instance."""
self.gui = gui
self.ideal_cycle = 80
self.real_time_factor = 0
self.frequency_message = {'brain': '', 'gui': ''}
self.iteration_counter = 0
self.running = True

def start(self):
"""Starts the GUI, frequency measurement, and real-time factor threads."""
self.frequency_thread = threading.Thread(target=self.measure_and_send_frequency)
self.gui_thread = threading.Thread(target=self.run)
self.frequency_thread.start()
self.gui_thread.start()
print("GUI Thread Started!")

def measure_and_send_frequency(self):
"""Measures and sends the frequency of GUI updates and brain cycles."""
previous_time = datetime.now()
while self.running:
time.sleep(2)

current_time = datetime.now()
dt = current_time - previous_time
ms = (dt.days * 24 * 60 * 60 + dt.seconds) * 1000 + dt.microseconds / 1000.0
previous_time = current_time
measured_cycle = ms / self.iteration_counter if self.iteration_counter > 0 else 0
self.iteration_counter = 0
brain_frequency = round(1000 / measured_cycle, 1) if measured_cycle != 0 else 0
gui_frequency = round(1000 / self.ideal_cycle, 1)
self.frequency_message = {'brain': brain_frequency, 'gui': gui_frequency}
message = json.dumps(self.frequency_message)
if self.gui.client:
try:
self.gui.client.send(message)
except Exception as e:
print(f"Error sending frequency message: {e}")

def run(self):
"""Main loop to update the GUI at regular intervals."""
while self.running:
start_time = datetime.now()

self.gui.update_gui()
self.iteration_counter += 1
finish_time = datetime.now()

dt = finish_time - start_time
ms = (dt.days * 24 * 60 * 60 + dt.seconds) * 1000 + dt.microseconds / 1000.0
sleep_time = max(0, (50 - ms) / 1000.0)
time.sleep(sleep_time)


# Create a GUI interface
host = "ws://127.0.0.1:2303"
gui_interface = GUI(host)

# Spin a thread to keep the interface updated
thread_gui = ThreadGUI(gui_interface)
thread_gui.start()

# Redirect the console
start_console()

def showImage(image):
gui_interface.showImage(image)

Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
import threading
import cv2

current_frame = None # Global variable to store the frame

class WebcamSubscriber(Node):
def __init__(self):
super().__init__('webcam_subscriber')
self.subscription = self.create_subscription(
Image,
'/image_raw',
self.listener_callback,
10)
self.subscription # prevent unused variable warning
self.bridge = CvBridge()

def listener_callback(self, msg):
global current_frame
self.get_logger().info('Receiving video frame')
current_frame = self.bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')

def run_webcam_node():

webcam_subscriber = WebcamSubscriber()

rclpy.spin(webcam_subscriber)
webcam_subscriber.destroy_node()


# Start the ROS2 node in a separate thread
thread = threading.Thread(target=run_webcam_node)
thread.start()

def getImage():
global current_frame
return current_frame

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[Exercise Documentation Website](https://jderobot.github.io/RoboticsAcademy/exercises/ComputerVision/dl_digit_classifier)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Functions to start and close console
import os
import sys

def start_console():
# Get all the file descriptors and choose the latest one
fds = os.listdir("/dev/pts/")
fds.sort()
console_fd = fds[-2]

sys.stderr = open('/dev/pts/' + console_fd, 'w')
sys.stdout = open('/dev/pts/' + console_fd, 'w')
sys.stdin = open('/dev/pts/' + console_fd, 'w')

def close_console():
sys.stderr.close()
sys.stdout.close()
sys.stdin.close()
Loading