Skip to content

Commit

Permalink
Make linter happy
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxJa4 committed Dec 17, 2023
1 parent 19b10dd commit 33ad2ad
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self):
# Amount of epochs to train
# One epoch: Training with all images from training dataset once
self.NUM_WORKERS = 4
self.NUM_CLASSES = 5 # Traffic light states: green, yellow, red, back, side
self.NUM_CLASSES = 5 # States: green, yellow, red, back, side
self.NUM_CHANNELS = 3 # RGB encoded images

# Inference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import torch.cuda
import torchvision.transforms as t
from traffic_light_detection.src.traffic_light_detection.transforms import Normalize, ResizeAndPadToSquare, \
load_image
from traffic_light_detection.src.traffic_light_detection.classification_model import ClassificationModel
from traffic_light_detection.src.traffic_light_detection.transforms \
import Normalize, ResizeAndPadToSquare, load_image
from traffic_light_detection.src.traffic_light_detection.classification_model \
import ClassificationModel
from torchvision.transforms import ToTensor
from traffic_light_detection.src.traffic_light_config import TrafficLightConfig

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import sys
import os
sys.path.append(os.path.abspath(sys.path[0] + '/..'))
from traffic_light_detection.transforms import Normalize, ResizeAndPadToSquare, \
load_image # noqa: E402
from traffic_light_detection.transforms import Normalize, \
ResizeAndPadToSquare, load_image # noqa: E402
from data_generation.weights_organizer import WeightsOrganizer # noqa: E402
from traffic_light_detection.classification_model import ClassificationModel \
# noqa: E402
Expand Down
4 changes: 2 additions & 2 deletions code/perception/src/traffic_light_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from rospy.numpy_msg import numpy_msg
from sensor_msgs.msg import Image as ImageMsg
from perception.msg import TrafficLightState
from std_msgs.msg import Header
from cv_bridge import CvBridge
from traffic_light_detection.src.traffic_light_detection.traffic_light_inference import TrafficLightInference
from traffic_light_detection.src.traffic_light_detection.traffic_light_inference \
import TrafficLightInference # noqa: E501


class TrafficLightNode(CompatibleNode):
Expand Down
6 changes: 3 additions & 3 deletions code/perception/src/vision_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from cv_bridge import CvBridge
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
import numpy as np
from time import perf_counter
from ultralytics import NAS, YOLO, RTDETR, SAM, FastSAM
"""
VisionNode:
Expand Down Expand Up @@ -188,7 +187,7 @@ def predict_ultralytics(self, image):
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
# print(cv_image.shape)

output = self.model(cv_image, half=True, verbose=False, retina_masks=True)
output = self.model(cv_image, half=True, verbose=False)

if 9 in output[0].boxes.cls:
self.process_traffic_lights(output[0], cv_image, image.header)
Expand All @@ -212,7 +211,8 @@ def process_traffic_lights(self, prediction, cv_image, image_header):
box = box[0:4].astype(int)
segmented = cv_image[box[1]:box[3], box[0]:box[2]]

traffic_light_image = self.bridge.cv2_to_imgmsg(segmented, encoding="rgb8")
traffic_light_image = self.bridge.cv2_to_imgmsg(segmented,
encoding="rgb8")
traffic_light_image.header = image_header
self.traffic_light_publisher.publish(traffic_light_image)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
'''
Docs: https://docs.ultralytics.com/modes/predict/, https://docs.ultralytics.com/tasks/detect/#models, https://docs.ultralytics.com/models/yolo-nas
Docs:
https://docs.ultralytics.com/modes/predict/
https://docs.ultralytics.com/tasks/detect/#models
https://docs.ultralytics.com/models/yolo-nas
'''

import os
Expand Down Expand Up @@ -35,6 +38,7 @@
image_path = os.path.join(IMAGE_BASE_FOLDER, IMAGES_FOR_TEST[p])
img = Image.open(image_path)

_ = model.predict(source=img, save=True, save_conf=True, line_width=1, half=True)
_ = model.predict(source=img, save=True, save_conf=True,
line_width=1, half=True)

del model

0 comments on commit 33ad2ad

Please sign in to comment.