From 33ad2addd816b01ec2bfe668d8b6254a84835912 Mon Sep 17 00:00:00 2001
From: MaxJa4
Date: Sun, 17 Dec 2023 14:10:26 +0100
Subject: [PATCH] Make linter happy

---
 .../traffic_light_detection/src/traffic_light_config.py   | 2 +-
 .../traffic_light_detection/traffic_light_inference.py    | 7 ++++---
 .../src/traffic_light_detection/traffic_light_training.py | 4 ++--
 code/perception/src/traffic_light_node.py                 | 4 ++--
 code/perception/src/vision_node.py                        | 6 +++---
 .../experiments/object-detection-model_evaluation/yolo.py | 8 ++++++--
 6 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/code/perception/src/traffic_light_detection/src/traffic_light_config.py b/code/perception/src/traffic_light_detection/src/traffic_light_config.py
index fd175da9..e1c720c2 100644
--- a/code/perception/src/traffic_light_detection/src/traffic_light_config.py
+++ b/code/perception/src/traffic_light_detection/src/traffic_light_config.py
@@ -17,7 +17,7 @@ def __init__(self):
         # Amount of epochs to train
         # One epoch: Training with all images from training dataset once
         self.NUM_WORKERS = 4
-        self.NUM_CLASSES = 5  # Traffic light states: green, yellow, red, back, side
+        self.NUM_CLASSES = 5  # States: green, yellow, red, back, side
         self.NUM_CHANNELS = 3  # RGB encoded images
 
         # Inference
diff --git a/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_inference.py b/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_inference.py
index 8277cb28..ada59c5f 100644
--- a/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_inference.py
+++ b/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_inference.py
@@ -2,9 +2,10 @@
 import torch.cuda
 import torchvision.transforms as t
 
-from traffic_light_detection.src.traffic_light_detection.transforms import Normalize, ResizeAndPadToSquare, \
-    load_image
-from traffic_light_detection.src.traffic_light_detection.classification_model import ClassificationModel
+from traffic_light_detection.src.traffic_light_detection.transforms \
+    import Normalize, ResizeAndPadToSquare, load_image
+from traffic_light_detection.src.traffic_light_detection.classification_model \
+    import ClassificationModel
 from torchvision.transforms import ToTensor
 from traffic_light_detection.src.traffic_light_config import TrafficLightConfig
 
diff --git a/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_training.py b/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_training.py
index 4ab43748..1322d024 100644
--- a/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_training.py
+++ b/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_training.py
@@ -11,8 +11,8 @@
 import sys
 import os
 sys.path.append(os.path.abspath(sys.path[0] + '/..'))
-from traffic_light_detection.transforms import Normalize, ResizeAndPadToSquare, \
-    load_image  # noqa: E402
+from traffic_light_detection.transforms import Normalize, \
+    ResizeAndPadToSquare, load_image  # noqa: E402
 from data_generation.weights_organizer import WeightsOrganizer  # noqa: E402
 from traffic_light_detection.classification_model import ClassificationModel \
     # noqa: E402
diff --git a/code/perception/src/traffic_light_node.py b/code/perception/src/traffic_light_node.py
index e27c6d1a..6f67b5b1 100755
--- a/code/perception/src/traffic_light_node.py
+++ b/code/perception/src/traffic_light_node.py
@@ -5,9 +5,9 @@
 from rospy.numpy_msg import numpy_msg
 from sensor_msgs.msg import Image as ImageMsg
 from perception.msg import TrafficLightState
-from std_msgs.msg import Header
 from cv_bridge import CvBridge
-from traffic_light_detection.src.traffic_light_detection.traffic_light_inference import TrafficLightInference
+from traffic_light_detection.src.traffic_light_detection.traffic_light_inference \
+    import TrafficLightInference  # noqa: E501
 
 
 class TrafficLightNode(CompatibleNode):
diff --git a/code/perception/src/vision_node.py b/code/perception/src/vision_node.py
index 69fecb96..0eaf4a51 100755
--- a/code/perception/src/vision_node.py
+++ b/code/perception/src/vision_node.py
@@ -18,7 +18,6 @@
 from cv_bridge import CvBridge
 from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
 import numpy as np
-from time import perf_counter
 from ultralytics import NAS, YOLO, RTDETR, SAM, FastSAM
 """
 VisionNode:
@@ -188,7 +187,7 @@ def predict_ultralytics(self, image):
         cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
         # print(cv_image.shape)
 
-        output = self.model(cv_image, half=True, verbose=False, retina_masks=True)
+        output = self.model(cv_image, half=True, verbose=False)
 
         if 9 in output[0].boxes.cls:
             self.process_traffic_lights(output[0], cv_image, image.header)
@@ -212,7 +211,8 @@ def process_traffic_lights(self, prediction, cv_image, image_header):
             box = box[0:4].astype(int)
             segmented = cv_image[box[1]:box[3], box[0]:box[2]]
 
-            traffic_light_image = self.bridge.cv2_to_imgmsg(segmented, encoding="rgb8")
+            traffic_light_image = self.bridge.cv2_to_imgmsg(segmented,
+                                                            encoding="rgb8")
             traffic_light_image.header = image_header
             self.traffic_light_publisher.publish(traffic_light_image)
 
diff --git a/doc/06_perception/experiments/object-detection-model_evaluation/yolo.py b/doc/06_perception/experiments/object-detection-model_evaluation/yolo.py
index f7ff342d..39d727b7 100644
--- a/doc/06_perception/experiments/object-detection-model_evaluation/yolo.py
+++ b/doc/06_perception/experiments/object-detection-model_evaluation/yolo.py
@@ -1,5 +1,8 @@
 '''
-Docs: https://docs.ultralytics.com/modes/predict/, https://docs.ultralytics.com/tasks/detect/#models, https://docs.ultralytics.com/models/yolo-nas
+Docs:
+https://docs.ultralytics.com/modes/predict/
+https://docs.ultralytics.com/tasks/detect/#models
+https://docs.ultralytics.com/models/yolo-nas
 '''
 
 import os
@@ -35,6 +38,7 @@
     image_path = os.path.join(IMAGE_BASE_FOLDER, IMAGES_FOR_TEST[p])
     img = Image.open(image_path)
 
-    _ = model.predict(source=img, save=True, save_conf=True, line_width=1, half=True)
+    _ = model.predict(source=img, save=True, save_conf=True,
+                      line_width=1, half=True)
 
 del model
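The noqa codes touched by this patch (E402: module import not at top of file, E501: line too long) are flake8 error codes, so flake8 is presumably the linter being appeased here. Below is a minimal sketch of how the check might be re-run locally; it assumes flake8 is installed and the repository root is the working directory, and the lint target paths are an illustrative guess rather than the project's actual CI configuration.

# Illustrative only: re-run flake8 over the directories touched by the patch.
# The chosen paths are assumptions; adjust them to match the project's CI setup.
import subprocess
import sys

result = subprocess.run(
    [sys.executable, "-m", "flake8", "code/perception/src", "doc/06_perception"],
    capture_output=True,
    text=True,
)
print(result.stdout)          # any remaining findings such as E402/E501
sys.exit(result.returncode)   # non-zero exit code means the linter is still unhappy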