From 8844b07d7e21609204013555ee35ca260f6add2d Mon Sep 17 00:00:00 2001 From: Leon Okrusch Date: Fri, 17 Nov 2023 12:12:04 +0100 Subject: [PATCH] feat(53): Panoptic Segmentation --- code/agent/launch/agent.launch | 4 +- code/perception/launch/perception.launch | 15 ++- code/perception/src/p_testing_node.py | 102 ++++++++++++++++++ .../efficientps/model.py | 2 +- code/perception/src/segmentation_node.py | 48 +++++---- 5 files changed, 147 insertions(+), 24 deletions(-) create mode 100755 code/perception/src/p_testing_node.py diff --git a/code/agent/launch/agent.launch b/code/agent/launch/agent.launch index 4291717f..e469fc8a 100644 --- a/code/agent/launch/agent.launch +++ b/code/agent/launch/agent.launch @@ -7,6 +7,7 @@ + @@ -14,7 +15,6 @@ - @@ -22,5 +22,5 @@ - + diff --git a/code/perception/launch/perception.launch b/code/perception/launch/perception.launch index bc8c8f9c..6c9d8e7c 100644 --- a/code/perception/launch/perception.launch +++ b/code/perception/launch/perception.launch @@ -2,18 +2,26 @@ - + + + + + + - + + diff --git a/code/perception/src/p_testing_node.py b/code/perception/src/p_testing_node.py new file mode 100755 index 00000000..ba406902 --- /dev/null +++ b/code/perception/src/p_testing_node.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 + +from ros_compatibility.node import CompatibleNode +import ros_compatibility as roscomp +import torch +from torchvision.models.segmentation import deeplabv3_resnet101 +from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights +import torchvision.transforms as t +import cv2 +from rospy.numpy_msg import numpy_msg +from sensor_msgs.msg import Image +from cv_bridge import CvBridge + + +class PerceptionTestingNode(CompatibleNode): + def __init__(self, name, **kwargs): + # starting comment + + super().__init__(name, **kwargs) + # self.model = torch.hub.load('pytorch/vision:v0.10.0', + # 'deeplabv3_resnet50', pretrained=True) + + self.model = deeplabv3_resnet101(DeepLabV3_ResNet101_Weights) + # self.model.eval() + # print("Model Test: ", self.model(torch.zeros((1,3,720,1280)))) + + self.bridge = CvBridge() + + self.role_name = self.get_param("role_name", "hero") + self.side = self.get_param("side", "Center") + self.setup_camera_subscriptions() + self.setup_camera_publishers() + + def setup_camera_subscriptions(self): + self.new_subscription( + msg_type=numpy_msg(Image), + callback=self.handle_camera_image, + topic=f"/carla/{self.role_name}/{self.side}/image", + qos_profile=1 + ) + + def setup_camera_publishers(self): + self.publisher = self.new_publisher( + msg_type=numpy_msg(Image), + topic=f"/paf/{self.role_name}/{self.side}/segmented_image", + qos_profile=1 + ) + + def handle_camera_image(self, image): + self.model.eval() + self.loginfo(f"got image from camera {self.side}") + + cv_image = self.bridge.imgmsg_to_cv2(img_msg=image, + desired_encoding='passthrough') + cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR) + """ + image_array = np.frombuffer(image.data, dtype=np.uint8) + print(image_array.shape) + image_array = image_array.reshape((image.height, image.width, -1)) + print(image_array.shape) + # remove alpha channel + image_array = image_array[:, :, :3] + print(image_array.shape)""" + + preprocess = t.Compose([ + t.ToTensor(), + t.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + input_image = preprocess(cv_image).unsqueeze(dim=0) + prediction = self.model(input_image)['out'][0] + # prediction = id2rgb(prediction) + # print(prediction) + print(prediction.shape) + + masked_image = self.create_mask(prediction, input_image) + self.publisher.publish(self.bridge.cv2_to_imgmsg(masked_image)) + + pass + + def create_mask(self, model_output, input_image): + palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + colors = torch.as_tensor([i for i in range(21)])[:, None] * palette + colors = (colors % 255).numpy().astype("uint8") + r = Image.fromarray(model_output.byte().cpu().numpy()) + r = r.resize(input_image.shape[2], input_image.shape[3]) + r.putpalette(colors) + return r + + def run(self): + self.spin() + pass + # while True: + # self.spin() + + +if __name__ == "__main__": + roscomp.init("PerceptionTestingNode") + # try: + + node = PerceptionTestingNode("PerceptionTestingNode") + node.run() diff --git a/code/perception/src/panoptic_segmentation/efficientps/model.py b/code/perception/src/panoptic_segmentation/efficientps/model.py index 60962502..60ced17f 100644 --- a/code/perception/src/panoptic_segmentation/efficientps/model.py +++ b/code/perception/src/panoptic_segmentation/efficientps/model.py @@ -2,7 +2,7 @@ import torch from torch.optim.lr_scheduler import ReduceLROnPlateau from .fpn.two_way_fpn import TwoWayFpn -import pytorch_lightning as pl +import lightning as pl from .backbone.modify_efficientnet import \ generate_backbone_EfficientPS, \ output_feature_size diff --git a/code/perception/src/segmentation_node.py b/code/perception/src/segmentation_node.py index b840d9b0..eb0d882d 100755 --- a/code/perception/src/segmentation_node.py +++ b/code/perception/src/segmentation_node.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 - +""" import pathlib import numpy as np import ros_compatibility as roscomp @@ -13,25 +13,25 @@ from panoptic_segmentation.efficientps import EfficientPS as EfficientPS from panoptic_segmentation.train_net import add_custom_param -from panoptic_segmentation.efficientps.panoptic_segmentation_module import \ - panoptic_segmentation_module +from panoptic_segmentation.efficientps.panoptic_segmentation_module +import panoptic_segmentation_module from detectron2.config import get_cfg import torchvision.transforms.functional as F from detectron2.structures import Instances, BitMasks, Boxes -CFG_FILE_PATH = pathlib.Path( - __file__).parent / "panoptic_segmentation" / "config.yaml" -MODEL_PATH = pathlib.Path( - __file__).parent.parent / \ - "models/panoptic_segmentation/efficientps.ckpt" +CFG_FILE_PATH = pathlib.Path(__file__).parent / +"panoptic_segmentation" / "config.yaml" + +MODEL_PATH = pathlib.Path(__file__).parent.parent / +"src/panoptic_segmentation/efficientps/model.pth" class SegmentationNode(CompatibleNode): - """ + This node runs the panoptic segmentation model on the camera images and publishes the segmented results. - """ + def __init__(self, name, **kwargs): super().__init__(name, **kwargs) @@ -44,12 +44,13 @@ def __init__(self, name, **kwargs): self.model, self.transform, self.model_cfg = self.load_model() # warm up - self.predict(np.zeros((720, 1280, 3))) + #self.predict(np.zeros((720, 1280, 3))) - self.setup_camera_subscriptions() - self.setup_camera_publishers() + #self.setup_camera_subscriptions() + #self.setup_camera_publishers() def setup_camera_subscriptions(self): + self.new_subscription( msg_type=numpy_msg(Image), callback=self.handle_camera_image, @@ -58,6 +59,7 @@ def setup_camera_subscriptions(self): ) def setup_camera_publishers(self): + self.publisher = self.new_publisher( msg_type=numpy_msg(Image), topic=f"/paf/{self.role_name}/{self.side}/segmented_image", @@ -66,6 +68,7 @@ def setup_camera_publishers(self): @staticmethod def load_model(): + cfg = get_cfg() cfg['train'] = False add_custom_param(cfg) @@ -76,18 +79,21 @@ def load_model(): A.Normalize(mean=cfg.TRANSFORM.NORMALIZE.MEAN, std=cfg.TRANSFORM.NORMALIZE.STD), ]) - - model = EfficientPS.load_from_checkpoint( + "model = EfficientPS.load_from_checkpoint( cfg=cfg, checkpoint_path=str(MODEL_PATH) ) + model = EfficientPS(cfg) + model.load_state_dict(torch.load(MODEL_PATH)) + print(model) model.eval() model.freeze() - model.to(torch.device("cuda:0")) + model.to(torch.device("cuda:0")) #add device definition before - return model, transform, cfg + return model, transform, None def predict(self, image: np.ndarray): + self.loginfo(f"predicting image shape: {image.shape}") # expand # image = np.expand_dims(image, axis=0) @@ -113,7 +119,9 @@ def predict(self, image: np.ndarray): result = segmented_result.cpu().numpy() # self.loginfo(f"predictions: {prediction.shape}") - return result + return resultskip python linting +code/perception/src/p_testing_node.py:17:27: W291 trailing w + self.loginfo("predicting something") def handle_camera_image(self, image): self.loginfo(f"got image from camera {self.side}") @@ -149,6 +157,8 @@ def handle_camera_image(self, image): self.publisher.publish(msg) self.loginfo(f"prediction shape: {prediction.shape}") + self.loginfo("reveived image from camera") + pass def run(self): self.spin() @@ -168,3 +178,5 @@ def run(self): # finally: # roscomp.shutdown() # + +"""