From fd8b12b41d6efe85f93d46aa8827d9d9e6bfc291 Mon Sep 17 00:00:00 2001
From: Yusuf Çiçek
Date: Wed, 10 Apr 2024 00:11:33 +0300
Subject: [PATCH] Fuse per-frame optical flow features into the model via a DataFusionBlock

* Extract optical flow features over multiple frames at training and inference time
* Treat the incoming frames as a ring buffer and extract per-frame optical flow features
* Fuse raw frames and optical flow features with a DataFusionBlock
* Make the dataset-ratio configuration follow the number of training datasets
* Fix torch and numpy function calls that break under Python 3.9 and above
---
 ...S-VideoInstanceSegmentation_long_bs16.yaml |   1 +
 demo_video/demo.py                            |   8 +-
 demo_video/predictor.py                       |  42 +++-
 demo_video/visualizer.py                      | 205 ++++++------------
 .../data_video/dataset_mapper.py              |  20 ++
 .../data_video/datasets/ytvis.py              |   5 +-
 .../datasets/ytvis_api/ytvoseval.py           |   4 +-
 mask2former_video/modeling/matcher.py         |   2 +-
 mask2former_video/video_maskformer_model.py   |  22 +-
 maskfreevis/config/__init__.py                |   1 +
 maskfreevis/config/config.py                  |  13 ++
 maskfreevis/config/defaults.py                |   7 +
 maskfreevis/data_fusion_modeling/__init__.py  |   5 +
 maskfreevis/data_fusion_modeling/base.py      |  25 +++
 maskfreevis/data_fusion_modeling/build.py     |  28 +++
 maskfreevis/data_fusion_modeling/config.py    |  22 ++
 .../data_fusion_blocks.py                     | 103 +++++++++
 .../data_fusion_modeling/optical_flow.py      |  30 +++
 maskfreevis/demo.py                           | 133 ++++++++++++
 maskfreevis/train.py                          |  77 +++++++
 maskfreevis/utils.py                          |  36 +++
 train_net_video.py                            |   7 +-
 22 files changed, 643 insertions(+), 153 deletions(-)
 create mode 100644 maskfreevis/config/__init__.py
 create mode 100644 maskfreevis/config/config.py
 create mode 100644 maskfreevis/config/defaults.py
 create mode 100644 maskfreevis/data_fusion_modeling/__init__.py
 create mode 100644 maskfreevis/data_fusion_modeling/base.py
 create mode 100644 maskfreevis/data_fusion_modeling/build.py
 create mode 100644 maskfreevis/data_fusion_modeling/config.py
 create mode 100644 maskfreevis/data_fusion_modeling/data_fusion_blocks.py
 create mode 100644 maskfreevis/data_fusion_modeling/optical_flow.py
 create mode 100644 maskfreevis/demo.py
 create mode 100644 maskfreevis/train.py
 create mode 100644 maskfreevis/utils.py

diff --git a/configs/youtubevis_2019/Base-YouTubeVIS-VideoInstanceSegmentation_long_bs16.yaml b/configs/youtubevis_2019/Base-YouTubeVIS-VideoInstanceSegmentation_long_bs16.yaml
index b2666d5..ceed848 100644
--- a/configs/youtubevis_2019/Base-YouTubeVIS-VideoInstanceSegmentation_long_bs16.yaml
+++ b/configs/youtubevis_2019/Base-YouTubeVIS-VideoInstanceSegmentation_long_bs16.yaml
@@ -17,6 +17,7 @@ MODEL:
 DATASETS:
   TRAIN: ("coco_2017_train_fake", "ytvis_2019_train",)
   TEST: ("ytvis_2019_val",)
+  DATASET_RATIO: (1.0, 0.75)
 SOLVER:
   IMS_PER_BATCH: 16
   BASE_LR: 0.0001
diff --git a/demo_video/demo.py b/demo_video/demo.py
index 2178fa0..5e836f5 100644
--- a/demo_video/demo.py
+++ b/demo_video/demo.py
@@ -22,6 +22,9 @@ from mask2former_video import add_maskformer2_video_config
 from predictor import VisualizationDemo
 import imageio
+import random
+from maskfreevis.config import get_cfg
+from maskfreevis.data_fusion_modeling import add_data_fusion_block_config
 
 # constants
 WINDOW_NAME = "mask2former video demo"
@@ -31,6 +34,7 @@ def setup_cfg(args):
     add_deeplab_config(cfg)
     add_maskformer2_config(cfg)
     add_maskformer2_video_config(cfg)
+    add_data_fusion_block_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
     cfg.freeze()
@@ -107,14 +111,14 @@ def test_opencv_video_format(codec, file_ext):
     # assert args.input, "The input path(s) was not found"
print('args input:', args.input) args.input = args.input[0] - for file_name in os.listdir(args.input): + for file_name in random.sample(os.listdir(args.input), 20): input_path_list = sorted([args.input + file_name + '/' + f for f in os.listdir(args.input + file_name)]) print('input path list:', input_path_list) if len(input_path_list) == 0: continue vid_frames = [] for path in input_path_list: - img = read_image(path, format="BGR") + img = read_image(path, format=cfg.INPUT.FORMAT) vid_frames.append(img) start_time = time.time() with autocast(): diff --git a/demo_video/predictor.py b/demo_video/predictor.py index 8ef7f66..e6bffa9 100644 --- a/demo_video/predictor.py +++ b/demo_video/predictor.py @@ -5,15 +5,22 @@ from collections import deque import cv2 import torch -from visualizer import TrackVisualizer +import copy from detectron2.data import MetadataCatalog from detectron2.engine.defaults import DefaultPredictor from detectron2.structures import Instances from detectron2.utils.video_visualizer import VideoVisualizer -from detectron2.utils.visualizer import ColorMode +from detectron2.utils.visualizer import ColorMode, Visualizer +from maskfreevis.data_fusion_modeling import extract_optical_flow_dense_matrix + +try: + from .visualizer import TrackVisualizer +except: + from visualizer import TrackVisualizer + class VisualizationDemo(object): - def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): + def __init__(self, cfg, metadata=None, instance_mode=ColorMode.IMAGE, parallel=False): """ Args: cfg (CfgNode): @@ -24,6 +31,8 @@ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): self.metadata = MetadataCatalog.get( cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" ) + if metadata is not None: + self.metadata = metadata self.cpu_device = torch.device("cpu") self.instance_mode = instance_mode self.parallel = parallel @@ -87,6 +96,10 @@ class VideoPredictor(DefaultPredictor): inputs = cv2.imread("input.jpg") outputs = pred(inputs) """ + def __init__(self, cfg): + super().__init__(cfg) + self.data_fusion_status = cfg.MODEL.DATAFUSION.STATUS + def __call__(self, frames): """ Args: @@ -96,18 +109,35 @@ def __call__(self, frames): the output of the model for one image only. See :doc:`/tutorials/models` for details about the format. """ - with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 + with torch.inference_mode(): # https://github.com/sphinx-doc/sphinx/issues/4258 input_frames = [] - for original_image in frames: + optical_flow_matrixes = [] + for image_idx, original_image in enumerate(frames): # Apply pre-processing to image. 
if self.input_format == "RGB": # whether the model expects BGR inputs or RGB + prev_frame = copy.deepcopy(frames[image_idx - 1]) + current_frame = copy.deepcopy(frames[image_idx]) original_image = original_image[:, :, ::-1] + else: + prev_frame = copy.deepcopy(frames[image_idx - 1][:, :, ::-1]) + current_frame = copy.deepcopy(frames[image_idx][:, :, ::-1]) + + if self.data_fusion_status: + optical_flow_matrix = extract_optical_flow_dense_matrix(prev_frame, current_frame) + optical_flow_matrix = self.aug.get_transform(optical_flow_matrix).apply_image(optical_flow_matrix) + optical_flow_matrix = torch.as_tensor(optical_flow_matrix.astype("float32").transpose(2, 0, 1)) + optical_flow_matrixes.append(optical_flow_matrix) + else: + del prev_frame + del current_frame + height, width = original_image.shape[:2] image = self.aug.get_transform(original_image).apply_image(original_image) image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) input_frames.append(image) - inputs = {"image": input_frames, "height": height, "width": width} + + inputs = {"image": input_frames, "height": height, "width": width, "optical_flow": optical_flow_matrixes} predictions = self.model([inputs]) return predictions diff --git a/demo_video/visualizer.py b/demo_video/visualizer.py index 080f5d4..dbfdd02 100644 --- a/demo_video/visualizer.py +++ b/demo_video/visualizer.py @@ -2,10 +2,18 @@ import torch import numpy as np import matplotlib.colors as mplc + +from PIL import Image +from detectron2.structures import BoxMode +from detectron2.utils.file_io import PathManager +from detectron2.utils.colormap import random_color from detectron2.utils.visualizer import ColorMode, GenericMask, Visualizer, _create_text_labels + + _ID_JITTERS = [[0.9047944201469568, 0.3241718265806123, 0.33443746665210006], [0.4590171386127151, 0.9095038146383864, 0.3143840671974788], [0.4769356899795538, 0.5044406738441948, 0.5354530846360839], [0.00820945625670777, 0.24099210193126785, 0.15471834055332978], [0.6195684374237388, 0.4020380013509799, 0.26100266066404676], [0.08281237756545068, 0.05900744492710419, 0.06106221202154216], [0.2264886829978755, 0.04925271007292076, 0.10214429345996079], [0.1888247470009874, 0.11275000298612425, 0.46112894830685514], [0.37415767691880975, 0.844284596118331, 0.950471611180866], [0.3817344218157631, 0.3483259270707101, 0.6572989333690541], [0.2403115731054466, 0.03078280287279167, 0.5385975692534737], [0.7035076951650824, 0.12352084932325424, 0.12873080308790197], [0.12607434914489934, 0.111244793010015, 0.09333334699716023], [0.6551607300342269, 0.7003064103554443, 0.4131794512286162], [0.13592107365596595, 0.5390702818232149, 0.004540643174930525], [0.38286244894454347, 0.709142545393449, 0.529074791609835], [0.4279376583651734, 0.5634708596431771, 0.8505569717104301], [0.3460488523902999, 0.464769595519293, 0.6676839675477276], [0.8544063246675081, 0.5041190233407755, 0.9081217697141578], [0.9207009090747208, 0.2403865944739051, 0.05375410999863772], [0.6515786136947107, 0.6299918449948327, 0.45292029442034387], [0.986174217295693, 0.2424849846977214, 0.3981993323108266], [0.22101915872994693, 0.3408589198278038, 0.006381420347677524], [0.3159785813515982, 0.1145748921741011, 0.595754317197274], [0.10263421488052715, 0.5864139253490858, 0.23908000741142432], [0.8272999391532938, 0.6123527260897751, 0.3365197327803193], [0.5269583712937912, 0.25668929554516506, 0.7888411215078127], [0.2433880265410031, 0.7240751234287827, 0.8483215810528648], [0.7254601709704898, 0.8316525547295984, 
0.9325253855921963], [0.5574483824856672, 0.2935331727879944, 0.6594839453793155], [0.6209642371433579, 0.054030693198821256, 0.5080873988178534], [0.9055507077365624, 0.12865888619203514, 0.9309191861440005], [0.9914469722960537, 0.3074114506206205, 0.8762107657323488], [0.4812682518247371, 0.15055826298548158, 0.9656340505308308], [0.6459219454316445, 0.9144794010251625, 0.751338812155106], [0.860840174209798, 0.8844626353077639, 0.3604624506769899], [0.8194991672032272, 0.926399617787601, 0.8059222327343247], [0.6540413175393658, 0.04579445254618297, 0.26891917826531275], [0.37778835833987046, 0.36247927666109536, 0.7989799305827889], [0.22738304978177726, 0.9038018263773739, 0.6970838854138303], [0.6362015495896184, 0.527680794236961, 0.5570915425178721], [0.6436401915860954, 0.6316925317144524, 0.9137151236993912], [0.04161828388587163, 0.3832413349082706, 0.6880829921949752], [0.7768167825719299, 0.8933821497682587, 0.7221278391266809], [0.8632760876301346, 0.3278628094906323, 0.8421587587114462], [0.8556499133262127, 0.6497385872901932, 0.5436895688477963], [0.9861940318610894, 0.03562313777386272, 0.9183454677106616], [0.8042586091176366, 0.6167222703170994, 0.24181981557207644], [0.9504247117633057, 0.3454233714011461, 0.6883727005547743], [0.9611909135491202, 0.46384154263898114, 0.32700443315058914], [0.523542176970206, 0.446222414615845, 0.9067402987747814], [0.7536954008682911, 0.6675512338797588, 0.22538238957839196], [0.1554052265688285, 0.05746097492966129, 0.8580358872587424], [0.8540838640971405, 0.9165504335482566, 0.6806982829158964], [0.7065090319405029, 0.8683059983962002, 0.05167128320624026], [0.39134812961899124, 0.8910075505622979, 0.7639815712623922], [0.1578117311479783, 0.20047326898284668, 0.9220177338840568], [0.2017488993096358, 0.6949259970936679, 0.8729196864798128], [0.5591089340651949, 0.15576770423813258, 0.1469857469387812], [0.14510398622626974, 0.24451497734532168, 0.46574271993578786], [0.13286397822351492, 0.4178244533944635, 0.03728728952131943], [0.556463206310225, 0.14027595183361663, 0.2731537988657907], [0.4093837966398032, 0.8015225687789814, 0.8033567296903834], [0.527442563956637, 0.902232617214431, 0.7066626674362227], [0.9058355503297827, 0.34983989180213004, 0.8353262183839384], [0.7108382186953104, 0.08591307895133471, 0.21434688012521974], [0.22757345065207668, 0.7943075496583976, 0.2992305547627421], [0.20454109788173636, 0.8251670332103687, 0.012981987094547232], [0.7672562637297392, 0.005429019973062554, 0.022163616037108702], [0.37487345910117564, 0.5086240194440863, 0.9061216063654387], [0.9878004014101087, 0.006345852772772331, 0.17499753379350858], [0.030061528704491303, 0.1409704315546606, 0.3337131835834506], [0.5022506782611504, 0.5448435505388706, 0.40584238936140726], [0.39560774627423445, 0.8905943695833262, 0.5850815030921116], [0.058615671926786406, 0.5365713844300387, 0.1620457551256279], [0.41843842882069693, 0.1536005983609976, 0.3127878501592438], [0.05947621790155899, 0.5412421167331932, 0.2611322146455659], [0.5196159938235607, 0.7066461551682705, 0.970261497412556], [0.30443031606149007, 0.45158581060034975, 0.4331841153149706], [0.8848298403933996, 0.7241791700943656, 0.8917110054596072], [0.5720260591898779, 0.3072801598203052, 0.8891066705989902], [0.13964015336177327, 0.2531778096760302, 0.5703756837403124], [0.2156307542329836, 0.4139947500641685, 0.87051676884144], [0.10800455881891169, 0.05554646035458266, 0.2947027428551443], [0.35198009410633857, 0.365849666213808, 0.06525787683513773], 
[0.5223264108118847, 0.9032195574351178, 0.28579084943315025], [0.7607724246546966, 0.3087194381828555, 0.6253235528354899], [0.5060485442077824, 0.19173600467625274, 0.9931175692203702], [0.5131805830323746, 0.07719515392040577, 0.923212006754969], [0.3629762141280106, 0.02429179642710888, 0.6963754952399983], [0.7542592485456767, 0.6478893299494212, 0.3424965345400731], [0.49944574453364454, 0.6775665366832825, 0.33758796076989583], [0.010621818120767679, 0.8221571611173205, 0.5186257457566332], [0.5857910304290109, 0.7178133992025467, 0.9729243483606071], [0.16987399482717613, 0.9942570210657463, 0.18120758122552927], [0.016362572521240848, 0.17582788603087263, 0.7255176922640298], [0.10981764283706419, 0.9078582203470377, 0.7638063718334003], [0.9252097840441119, 0.3330197086990039, 0.27888705301420136], [0.12769972651171546, 0.11121470804891687, 0.12710743734391716], [0.5753520518360334, 0.2763862879599456, 0.6115636613363361]] _OFF_WHITE = (1.0, 1.0, 240.0 / 255) + class TrackVisualizer(Visualizer): def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE): super().__init__( @@ -29,147 +37,75 @@ def _jitter(self, color, id): res = np.clip(vec + color, 0, 1) return tuple(res) - def overlay_instances( - self, - *, - boxes=None, - labels=None, - masks=None, - keypoints=None, - assigned_colors=None, - alpha=0.5 - ): + def draw_dataset_dict(self, dic): """ + Draw annotations/segmentations in Detectron2 Dataset format. + Args: - boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`, - or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image, - or a :class:`RotatedBoxes`, - or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format - for the N objects in a single image, - labels (list[str]): the text to be displayed for each instance. - masks (masks-like object): Supported types are: - * :class:`detectron2.structures.PolygonMasks`, - :class:`detectron2.structures.BitMasks`. - * list[list[ndarray]]: contains the segmentation masks for all objects in one image. - The first level of the list corresponds to individual instances. The second - level to all the polygon that compose the instance, and the third level - to the polygon coordinates. The third level should have the format of - [x0, y0, x1, y1, ..., xn, yn] (n >= 3). - * list[ndarray]: each ndarray is a binary mask of shape (H, W). - * list[dict]: each dict is a COCO-style RLE. - keypoints (Keypoint or array like): an array-like object of shape (N, K, 3), - where the N is the number of instances and K is the number of keypoints. - The last dimension corresponds to (x, y, visibility or score). - assigned_colors (list[matplotlib.colors]): a list of colors, where each color - corresponds to each mask or box in the image. Refer to 'matplotlib.colors' - for full list of formats that the colors are accepted in. + dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format. + Returns: output (VisImage): image object with visualizations. 
""" - num_instances = 0 - if boxes is not None: - boxes = self._convert_boxes(boxes) - num_instances = len(boxes) - if masks is not None: - # print('masks:', masks) - #masks = self._convert_masks(masks) - if num_instances: - assert len(masks) == num_instances + annos = dic.get("annotations", None) + if annos: + if "segmentation" in annos[0]: + masks = [x["segmentation"] for x in annos] else: - num_instances = len(masks) - if keypoints is not None: - if num_instances: - assert len(keypoints) == num_instances + masks = None + if "keypoints" in annos[0]: + keypts = [x["keypoints"] for x in annos] + keypts = np.array(keypts).reshape(len(annos), -1, 3) else: - num_instances = len(keypoints) - keypoints = self._convert_keypoints(keypoints) - if labels is not None: - assert len(labels) == num_instances - if assigned_colors is None: - assigned_colors = [random_color(ii, rgb=True, maximum=1) for ii in range(num_instances)] - if num_instances == 0: - return self.output - if boxes is not None and boxes.shape[1] == 5: - return self.overlay_rotated_instances( - boxes=boxes, labels=labels, assigned_colors=assigned_colors + keypts = None + + boxes = [ + BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS) + if len(x["bbox"]) == 4 + else x["bbox"] + for x in annos + ] + + colors = None + category_ids = [x["category_id"] for x in annos] + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]], id) + for id, c in enumerate(category_ids) + ] + names = self.metadata.get("thing_classes", None) + labels = _create_text_labels( + category_ids, + scores=None, + class_names=names, + is_crowd=[x.get("iscrowd", 0) for x in annos], ) - # Display in largest to smallest order to reduce occlusion. - areas = None - if boxes is not None: - areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) - elif masks is not None: - areas = np.asarray([x.sum() for x in masks]) - if areas is not None: - sorted_idxs = np.argsort(-areas).tolist() - # Re-order overlapped instances in descending order. - boxes = boxes[sorted_idxs] if boxes is not None else None - labels = [labels[k] for k in sorted_idxs] if labels is not None else None - masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None - assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] - keypoints = keypoints[sorted_idxs] if keypoints is not None else None - for i in range(num_instances): - color = assigned_colors[i] - # if boxes is not None: - # self.draw_box(boxes[i], edge_color=color) - if masks is not None: - #self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha) - binary_mask = masks[i].astype(np.uint8) - #alpha = 0.7 - #print('binary mask:', binary_mask) - self.draw_binary_mask( - binary_mask, - color=color, - edge_color=None, # _OFF_WHITE - alpha=alpha, - ) - if False: - # if labels is not None: - # first get a box - if boxes is not None: - x0, y0, x1, y1 = boxes[i] - text_pos = (x0, y0) # if drawing boxes, put text on the box corner. - horiz_align = "left" - elif masks is not None: - # skip small mask without polygon - if len(masks[i].polygons) == 0: - continue - x0, y0, x1, y1 = masks[i].bbox() - # draw text in the center (defined by median) when box is not drawn - # median is less sensitive to outliers. - text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1] - horiz_align = "center" - else: - continue # drawing the box confidence for keypoints isn't very useful. 
- # for small objects, draw text at the side to avoid occlusion - instance_area = (y1 - y0) * (x1 - x0) - if ( - instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale - or y1 - y0 < 40 * self.output.scale - ): - if y1 >= self.output.height - 5: - text_pos = (x1, y0) - else: - text_pos = (x0, y1) - height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width) - lighter_color = self._change_color_brightness(color, brightness_factor=0.7) - font_size = ( - np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) - * 0.5 - * self._default_font_size - ) - # self.draw_text( - # labels[i], - # text_pos, - # color=lighter_color, - # horizontal_alignment=horiz_align, - # font_size=font_size, - # ) - # draw keypoints - if keypoints is not None: - for keypoints_per_instance in keypoints: - self.draw_and_connect_keypoints(keypoints_per_instance) + self.overlay_instances( + labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors + ) + + sem_seg = dic.get("sem_seg", None) + if sem_seg is None and "sem_seg_file_name" in dic: + with PathManager.open(dic["sem_seg_file_name"], "rb") as f: + sem_seg = Image.open(f) + sem_seg = np.asarray(sem_seg, dtype="uint8") + if sem_seg is not None: + self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5) + + pan_seg = dic.get("pan_seg", None) + if pan_seg is None and "pan_seg_file_name" in dic: + with PathManager.open(dic["pan_seg_file_name"], "rb") as f: + pan_seg = Image.open(f) + pan_seg = np.asarray(pan_seg) + from panopticapi.utils import rgb2id + + pan_seg = rgb2id(pan_seg) + if pan_seg is not None: + segments_info = dic["segments_info"] + pan_seg = torch.tensor(pan_seg) + self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5) return self.output - + def draw_instance_predictions(self, predictions): """ Draw instance-level prediction results on an image. @@ -185,12 +121,9 @@ def draw_instance_predictions(self, predictions): scores = preds.scores if preds.has("scores") else None classes = preds.pred_classes if preds.has("pred_classes") else None labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None)) - if labels is not None: - labels = ["[{}] ".format(_id) + l for _id, l in enumerate(labels)] if preds.has("pred_masks"): masks = np.asarray(preds.pred_masks) - print('enter here==========') - # masks = [GenericMask(x, self.output.height, self.output.width) for x in masks] + masks = [GenericMask(x, self.output.height, self.output.width) for x in masks] else: masks = None if classes is None: diff --git a/mask2former_video/data_video/dataset_mapper.py b/mask2former_video/data_video/dataset_mapper.py index c0e23cb..f6fa382 100644 --- a/mask2former_video/data_video/dataset_mapper.py +++ b/mask2former_video/data_video/dataset_mapper.py @@ -25,6 +25,9 @@ from pycocotools import mask as coco_mask +from maskfreevis.data_fusion_modeling import extract_optical_flow_dense_matrix + + __all__ = ["YTVISDatasetMapper", "CocoClipDatasetMapper"] def seed_everything(seed): @@ -156,6 +159,7 @@ def __init__( num_classes: int = 40, src_dataset_name: str = "", tgt_dataset_name: str = "", + data_fusion_block_status: bool = False, ): """ NOTE: this interface is experimental. 
@@ -175,6 +179,7 @@ def __init__( self.sampling_frame_range = sampling_frame_range self.sampling_frame_shuffle = sampling_frame_shuffle self.num_classes = num_classes + self.data_fusion_block_status = data_fusion_block_status if not is_tgt: self.src_metadata = MetadataCatalog.get(src_dataset_name) @@ -212,6 +217,7 @@ def from_config(cls, cfg, is_train: bool = True, is_tgt: bool = True): sampling_frame_num = cfg.INPUT.SAMPLING_FRAME_NUM sampling_frame_range = cfg.INPUT.SAMPLING_FRAME_RANGE sampling_frame_shuffle = cfg.INPUT.SAMPLING_FRAME_SHUFFLE + data_fusion_block_status = cfg.MODEL.DATAFUSION.STATUS ret = { "is_train": is_train, @@ -224,6 +230,7 @@ def from_config(cls, cfg, is_train: bool = True, is_tgt: bool = True): "sampling_frame_shuffle": sampling_frame_shuffle, "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, "tgt_dataset_name": cfg.DATASETS.TRAIN[-1], + "data_fusion_block_status": data_fusion_block_status, } return ret @@ -273,6 +280,7 @@ def __call__(self, dataset_dict): dataset_dict["image"] = [] dataset_dict["instances"] = [] dataset_dict["file_names"] = [] + dataset_dict["optical_flow"] = [] for frame_idx in selected_idx: dataset_dict["file_names"].append(file_names[frame_idx]) @@ -280,6 +288,18 @@ def __call__(self, dataset_dict): image = utils.read_image(file_names[frame_idx], format=self.image_format) utils.check_image_size(dataset_dict, image) + if self.data_fusion_block_status: + prev_image = utils.read_image(file_names[frame_idx - 1], format=self.image_format) + utils.check_image_size(dataset_dict, prev_image) + + optical_flow_matrix = extract_optical_flow_dense_matrix(prev_image, image) + + optical_flow_aug_input = T.AugInput(optical_flow_matrix) + optical_flow_transforms = self.augmentations(optical_flow_aug_input) + optical_flow_matrix = optical_flow_aug_input.image + + dataset_dict["optical_flow"].append(torch.as_tensor(np.ascontiguousarray(optical_flow_matrix.transpose(2, 0, 1)))) + aug_input = T.AugInput(image) transforms = self.augmentations(aug_input) image = aug_input.image diff --git a/mask2former_video/data_video/datasets/ytvis.py b/mask2former_video/data_video/datasets/ytvis.py index 22fc227..b845919 100644 --- a/mask2former_video/data_video/datasets/ytvis.py +++ b/mask2former_video/data_video/datasets/ytvis.py @@ -148,7 +148,10 @@ def _get_ytvis_2021_instances_meta(): def load_ytvis_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None): - from .ytvis_api.ytvos import YTVOS + try: + from .ytvis_api.ytvos import YTVOS + except: + from ytvis_api.ytvos import YTVOS timer = Timer() json_file = PathManager.get_local_path(json_file) diff --git a/mask2former_video/data_video/datasets/ytvis_api/ytvoseval.py b/mask2former_video/data_video/datasets/ytvis_api/ytvoseval.py index 9248bc9..1ee3036 100644 --- a/mask2former_video/data_video/datasets/ytvis_api/ytvoseval.py +++ b/mask2former_video/data_video/datasets/ytvis_api/ytvoseval.py @@ -406,8 +406,8 @@ def accumulate(self, p = None): tps = np.logical_and( dtm, np.logical_not(dtIg) ) fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) ) - tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float) - fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float) + tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float32) + fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float32) for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): tp = np.array(tp) fp = np.array(fp) diff --git a/mask2former_video/modeling/matcher.py b/mask2former_video/modeling/matcher.py index 9c4c58b..d8571bb 100644 --- 
a/mask2former_video/modeling/matcher.py
+++ b/mask2former_video/modeling/matcher.py
@@ -96,7 +96,7 @@ def masks_to_boxes_new(masks: torch.Tensor) -> torch.Tensor:
     hMask = torch.logical_or(torch.arange(h).unsqueeze(0).to(boxes) < boxes[:, 1, None], torch.arange(h).unsqueeze(0).to(boxes) >= boxes[:, 3, None])
     wMask = torch.logical_or(torch.arange(w).unsqueeze(0).to(boxes) < boxes[:, 0, None], torch.arange(w).unsqueeze(0).to(boxes) >= boxes[:, 2, None])
-    mem_mask = torch.logical_or(hMask.unsqueeze(2), wMask.unsqueeze(1)).float()
+    mem_mask = torch.bitwise_or(hMask.unsqueeze(2), wMask.unsqueeze(1)).float()
     # print('mem mask shape:', mem_mask.shape)
     mem_mask = 1.0 - mem_mask.view(n, -1, masks.shape[-2], masks.shape[-1])
     return mem_mask
diff --git a/mask2former_video/video_maskformer_model.py b/mask2former_video/video_maskformer_model.py
index 6ec914d..167ce0c 100644
--- a/mask2former_video/video_maskformer_model.py
+++ b/mask2former_video/video_maskformer_model.py
@@ -21,6 +21,8 @@ import cv2
 import numpy as np
 
+from maskfreevis.data_fusion_modeling import DataFusionBlock, build_optical_flow_fusion_block
+
 def unfold_wo_center(x, kernel_size, dilation):
     assert x.dim() == 4
     assert kernel_size % 2 == 1
@@ -133,6 +135,7 @@ def __init__(
         pixel_std: Tuple[float],
         # video
         num_frames,
+        data_fusion_block: DataFusionBlock
     ):
         """
         Args:
@@ -159,6 +162,7 @@ def __init__(
             test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
         """
         super().__init__()
+        self.data_fusion_block = data_fusion_block
         self.backbone = backbone
         self.sem_seg_head = sem_seg_head
         self.criterion = criterion
@@ -179,6 +183,7 @@ def __init__(
 
     @classmethod
     def from_config(cls, cfg):
+        data_fusion_block = build_optical_flow_fusion_block(cfg)
         backbone = build_backbone(cfg)
         sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
 
@@ -235,6 +240,7 @@ def from_config(cls, cfg):
             "pixel_std": cfg.MODEL.PIXEL_STD,
             # video
             "num_frames": cfg.INPUT.SAMPLING_FRAME_NUM,
+            "data_fusion_block": data_fusion_block
         }
 
     @property
@@ -267,11 +273,15 @@ def forward(self, batched_inputs):
                     Each dict contains keys "id", "category_id", "isthing".
         """
         images = []
-
+        optical_flow_matrixes = []
+
         for video in batched_inputs:
             for frame in video["image"]:
                 images.append(frame.to(self.device))
 
+            for optical_flow_matrix in video["optical_flow"]:
+                optical_flow_matrixes.append(optical_flow_matrix.to(self.device))
+
         is_coco = (len(images) == 8) or (len(images) == 4)  # change here, 4 is for swinl with bs 1 which cannot afford batch size 2
         if self.training and not is_coco:
             k_size = 3
@@ -288,7 +298,13 @@ def forward(self, batched_inputs):
         images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         images = ImageList.from_tensors(images, self.size_divisibility)
 
-        features = self.backbone(images.tensor)
+        fusioned_tensor = images.tensor
+        if self.data_fusion_block is not None:
+            optical_flow_matrixes = [(x - self.pixel_mean) / self.pixel_std for x in optical_flow_matrixes]
+            optical_flow_matrixes = ImageList.from_tensors(optical_flow_matrixes, self.size_divisibility)
+            fusioned_tensor = self.data_fusion_block(images.tensor, optical_flow_matrixes.tensor)
+
+        features = self.backbone(fusioned_tensor)
         outputs = self.sem_seg_head(features)
 
         if self.training:
@@ -389,7 +405,7 @@ def inference_video(self, pred_cls, pred_masks, img_size, output_height, output_width):
         pred_masks = pred_masks[:, :, : img_size[0], : img_size[1]]
 
         pred_masks = F.interpolate(
-            pred_masks, size=(output_height, output_width), mode="bilinear", align_corners=False
+            pred_masks, size=(int(output_height), int(output_width)), mode="bilinear", align_corners=False
         )
 
         masks = pred_masks > 0.
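
For reference, a minimal sketch (not part of the diff itself) of how the new fusion path is switched on from a config. The MODEL.DATAFUSION keys follow the defaults added in maskfreevis/data_fusion_modeling/config.py below; the values shown are only examples.

    # Illustrative only: enable the optical flow fusion path introduced by this patch.
    from maskfreevis.config import get_cfg
    from maskfreevis.data_fusion_modeling import (
        add_data_fusion_block_config,
        build_optical_flow_fusion_block,
    )

    cfg = get_cfg()
    add_data_fusion_block_config(cfg)

    cfg.MODEL.DATAFUSION.STATUS = True                   # off by default
    cfg.MODEL.DATAFUSION.NAME = "build_optical_flow_fusion_block"
    cfg.MODEL.DATAFUSION.RAW_IMAGE.OUT_FEATURES = 8      # width of the 1x1 conv on the raw image
    cfg.MODEL.DATAFUSION.OPTICAL_FLOW.OUT_FEATURES = 8   # width of the 1x1 conv on the flow map
    cfg.MODEL.DATAFUSION.OUT_FEATURES = 3                # keep 3 channels so the backbone stem is unchanged

    fusion_block = build_optical_flow_fusion_block(cfg)  # returns None when STATUS is False

When STATUS is False, VideoMaskFormer.from_config receives None and the forward pass above feeds images.tensor to the backbone unchanged, so the backbone sees the same input as before.
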
diff --git a/maskfreevis/config/__init__.py b/maskfreevis/config/__init__.py
new file mode 100644
index 0000000..1731097
--- /dev/null
+++ b/maskfreevis/config/__init__.py
@@ -0,0 +1 @@
+from .config import get_cfg
\ No newline at end of file
diff --git a/maskfreevis/config/config.py b/maskfreevis/config/config.py
new file mode 100644
index 0000000..3501d32
--- /dev/null
+++ b/maskfreevis/config/config.py
@@ -0,0 +1,13 @@
+from detectron2.config import CfgNode
+
+
+def get_cfg() -> CfgNode:
+    """
+    Get a copy of the default config.
+
+    Returns:
+        a detectron2 CfgNode instance.
+    """
+    from .defaults import _C
+
+    return _C.clone()
diff --git a/maskfreevis/config/defaults.py b/maskfreevis/config/defaults.py
new file mode 100644
index 0000000..d694f6f
--- /dev/null
+++ b/maskfreevis/config/defaults.py
@@ -0,0 +1,7 @@
+from detectron2.config.defaults import _C
+from detectron2.config import CfgNode as CN
+
+# -----------------------------------------------------------------------------
+# Dataset
+# -----------------------------------------------------------------------------
+_C.DATASETS.DATASET_RATIO = ()
\ No newline at end of file
diff --git a/maskfreevis/data_fusion_modeling/__init__.py b/maskfreevis/data_fusion_modeling/__init__.py
new file mode 100644
index 0000000..49de5f8
--- /dev/null
+++ b/maskfreevis/data_fusion_modeling/__init__.py
@@ -0,0 +1,5 @@
+from .base import DataFusionBlock
+from .build import build_optical_flow_fusion_block
+from .config import add_data_fusion_block_config
+from .data_fusion_blocks import OpticalFlowFusionBlock
+from .optical_flow import extract_optical_flow_dense_matrix
\ No newline at end of file
diff --git a/maskfreevis/data_fusion_modeling/base.py b/maskfreevis/data_fusion_modeling/base.py
new file mode 100644
index 0000000..78b6e2c
--- /dev/null
+++ b/maskfreevis/data_fusion_modeling/base.py
@@ -0,0 +1,25 @@
+from torch import nn
+
+from abc import ABCMeta, abstractmethod
+
+
+class DataFusionBlock(nn.Module, metaclass=ABCMeta):
+    """
+    Abstract base class for data fusion blocks.
+    """
+
+    def __init__(self):
+        """
+        The `__init__` method of any subclass can specify its own set of arguments.
+        """
+        super().__init__()
+
+    @abstractmethod
+    def forward(self):
+        """
+        Subclasses must override this method, but adhere to the same return type.
+
+        Returns:
+            torch.Tensor: the fused feature map
+        """
+        pass
\ No newline at end of file
diff --git a/maskfreevis/data_fusion_modeling/build.py b/maskfreevis/data_fusion_modeling/build.py
new file mode 100644
index 0000000..984ab15
--- /dev/null
+++ b/maskfreevis/data_fusion_modeling/build.py
@@ -0,0 +1,28 @@
+from .base import DataFusionBlock
+
+from detectron2.utils.registry import Registry
+
+DATAFUSION_REGISTRY = Registry("DATAFUSION")
+DATAFUSION_REGISTRY.__doc__ = """
+Registry for data fusion blocks, which fuse raw images with additional feature maps.
+
+The registered object will be called with `obj(cfg)`
+and expected to return a `nn.Module` object.
+"""
+
+
+def build_optical_flow_fusion_block(cfg):
+    """
+    Build a data fusion block from `cfg.MODEL.DATAFUSION.NAME`.
+    If `cfg.MODEL.DATAFUSION.STATUS` is true, it returns the module built by the
+    registered factory (an OpticalFlowFusionBlock by default); otherwise it returns None.
+    """
+    data_fusion_status = cfg.MODEL.DATAFUSION.STATUS
+
+    if data_fusion_status:
+        data_fusion_block_name = cfg.MODEL.DATAFUSION.NAME
+        data_fusion_block = DATAFUSION_REGISTRY.get(data_fusion_block_name)(cfg)
+        assert isinstance(data_fusion_block, DataFusionBlock)
+        return data_fusion_block
+
+    return None
\ No newline at end of file
diff --git a/maskfreevis/data_fusion_modeling/config.py b/maskfreevis/data_fusion_modeling/config.py
new file mode 100644
index 0000000..7af8f9c
--- /dev/null
+++ b/maskfreevis/data_fusion_modeling/config.py
@@ -0,0 +1,22 @@
+from detectron2.config import CfgNode as CN
+
+
+def add_data_fusion_block_config(cfg):
+
+    cfg.MODEL.DATAFUSION = CN()
+    cfg.MODEL.DATAFUSION.STATUS = False
+    cfg.MODEL.DATAFUSION.NAME = "build_optical_flow_fusion_block"
+    cfg.MODEL.DATAFUSION.NORM = "BN"
+
+    # Channel size of the data fusion block output that is fed to the backbone
+    cfg.MODEL.DATAFUSION.OUT_FEATURES = 3
+
+    # Input/output channels of the 1x1 conv that takes the raw image as input
+    cfg.MODEL.DATAFUSION.RAW_IMAGE = CN()
+    cfg.MODEL.DATAFUSION.RAW_IMAGE.IN_FEATURES = 3
+    cfg.MODEL.DATAFUSION.RAW_IMAGE.OUT_FEATURES = 8
+
+    # Input/output channels of the 1x1 conv that takes the optical flow feature map as input
+    cfg.MODEL.DATAFUSION.OPTICAL_FLOW = CN()
+    cfg.MODEL.DATAFUSION.OPTICAL_FLOW.IN_FEATURES = 3
+    cfg.MODEL.DATAFUSION.OPTICAL_FLOW.OUT_FEATURES = 8
diff --git a/maskfreevis/data_fusion_modeling/data_fusion_blocks.py b/maskfreevis/data_fusion_modeling/data_fusion_blocks.py
new file mode 100644
index 0000000..2786814
--- /dev/null
+++ b/maskfreevis/data_fusion_modeling/data_fusion_blocks.py
@@ -0,0 +1,103 @@
+import torch
+from torch import nn
+
+import torch.nn.functional as F
+import fvcore.nn.weight_init as weight_init
+
+from detectron2.layers.wrappers import Conv2d
+from detectron2.layers.batch_norm import get_norm
+
+from .base import DataFusionBlock
+from .build import DATAFUSION_REGISTRY
+
+
+class OpticalFlowFusionBlock(DataFusionBlock):
+    """
+    OpticalFlowFusionBlock takes two features. The first is the raw 3-channel image;
+    the second is the 3-channel optical flow feature map extracted from the raw frames.
+    Each input is projected by a 1x1 conv, the projections are concatenated, and a final
+    1x1 conv maps the result back down so the output is ready to feed the backbone stem.
+    """
+
+    def __init__(self,
+                 raw_image_in_channels: int = 3,
+                 raw_image_out_channels: int = 8,
+                 optical_flow_in_channels: int = 3,
+                 optical_flow_out_channels: int = 8,
+                 fusioned_feature_out_channels: int = 3,
+                 norm: str = "BN"):
+
+        super().__init__()
+
+        self.conv1 = Conv2d(
+            raw_image_in_channels,
+            raw_image_out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+            norm=get_norm(norm, raw_image_out_channels),
+        )
+        weight_init.c2_msra_fill(self.conv1)
+
+        self.conv2 = Conv2d(
+            optical_flow_in_channels,
+            optical_flow_out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+            norm=get_norm(norm, optical_flow_out_channels),
+        )
+        weight_init.c2_msra_fill(self.conv2)
+
+        fusion_conv_in_channel = raw_image_out_channels + optical_flow_out_channels
+        self.conv3 = Conv2d(
+            fusion_conv_in_channel,
+            fusioned_feature_out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+            norm=get_norm(norm, fusioned_feature_out_channels),
+        )
+        weight_init.c2_msra_fill(self.conv3)
+
+    def forward(self,
+                raw_image: torch.Tensor,
+                optical_flow_image: torch.Tensor):
+
+        raw_image_feature = self.conv1(raw_image)
+        raw_image_feature = F.relu_(raw_image_feature)
+
+        optical_flow_image_feature = self.conv2(optical_flow_image)
+        optical_flow_image_feature = F.relu_(optical_flow_image_feature)
+
+        concanated_feature = torch.cat((raw_image_feature, optical_flow_image_feature), dim=1)
+        concanated_feature = self.conv3(concanated_feature)
+        fusioned_feature = F.relu_(concanated_feature)
+
+        return fusioned_feature
+
+
+@DATAFUSION_REGISTRY.register()
+def build_optical_flow_fusion_block(cfg):
+    """
+    Reads the parameters of the optical flow data fusion block from cfg.
+
+    Returns:
+        an OpticalFlowFusionBlock whose output is used to feed the backbone.
+    """
+    raw_image_in_channels = cfg.MODEL.DATAFUSION.RAW_IMAGE.IN_FEATURES
+    raw_image_out_channels = cfg.MODEL.DATAFUSION.RAW_IMAGE.OUT_FEATURES
+    optical_flow_in_channels = cfg.MODEL.DATAFUSION.OPTICAL_FLOW.IN_FEATURES
+    optical_flow_out_channels = cfg.MODEL.DATAFUSION.OPTICAL_FLOW.OUT_FEATURES
+    fusioned_feature_out_channels = cfg.MODEL.DATAFUSION.OUT_FEATURES
+    norm = cfg.MODEL.DATAFUSION.NORM
+
+    return OpticalFlowFusionBlock(raw_image_in_channels=raw_image_in_channels,
+                                  raw_image_out_channels=raw_image_out_channels,
+                                  optical_flow_in_channels=optical_flow_in_channels,
+                                  optical_flow_out_channels=optical_flow_out_channels,
+                                  fusioned_feature_out_channels=fusioned_feature_out_channels,
+                                  norm=norm)
\ No newline at end of file
diff --git a/maskfreevis/data_fusion_modeling/optical_flow.py b/maskfreevis/data_fusion_modeling/optical_flow.py
new file mode 100644
index 0000000..71c1ff9
--- /dev/null
+++ b/maskfreevis/data_fusion_modeling/optical_flow.py
@@ -0,0 +1,30 @@
+import cv2 as cv
+import numpy as np
+
+
+def extract_optical_flow_dense_matrix(prev_img: np.ndarray,
+                                      next_img: np.ndarray) -> np.ndarray:
+    """
+    Extract a dense optical flow feature map between two images with the Farneback
+    method (cv.calcOpticalFlowFarneback), encoding flow angle and magnitude as an
+    HSV image that is converted back to RGB.
+
+    Args:
+        prev_img (np.ndarray): (H,W,3) RGB image in the 0-255 range, either float or uint8
+        next_img (np.ndarray): (H,W,3) RGB image in the 0-255 range, either float or uint8
+
+    Returns:
+        (H,W,3) RGB dense optical flow feature map in the 0-255 range, same dtype as the input
+    """
+
+    prev_gray_img = cv.cvtColor(prev_img, cv.COLOR_RGB2GRAY)
+    prev_img_hsv_array = np.zeros_like(prev_img)
+    prev_img_hsv_array[..., 1] = 255
+
+    # TODO: revisit the calcOpticalFlowFarneback parameters.
+ next_gray_img = cv.cvtColor(next_img, cv.COLOR_BGR2GRAY) + flow_between_imgs = cv.calcOpticalFlowFarneback(prev_gray_img, next_gray_img, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + magnitude, angle = cv.cartToPolar(flow_between_imgs[..., 0], flow_between_imgs[..., 1]) + prev_img_hsv_array[..., 0] = angle * 180 / np.pi / 2 + prev_img_hsv_array[..., 2] = cv.normalize(magnitude, None, 0, 255, cv.NORM_MINMAX) + optical_flow_rgb_img = cv.cvtColor(prev_img_hsv_array, cv.COLOR_HSV2RGB) + return optical_flow_rgb_img \ No newline at end of file diff --git a/maskfreevis/demo.py b/maskfreevis/demo.py new file mode 100644 index 0000000..6d903cb --- /dev/null +++ b/maskfreevis/demo.py @@ -0,0 +1,133 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) + +import copy +import random +import imageio + +from torch.cuda.amp import autocast + +from detectron2.data import MetadataCatalog +from detectron2.data.catalog import DatasetCatalog +from detectron2.utils.visualizer import ColorMode, Visualizer +from detectron2.data.detection_utils import read_image +from detectron2.projects.deeplab import add_deeplab_config + +from mask2former import add_maskformer2_config +from mask2former_video import add_maskformer2_video_config +from mask2former_video.data_video.datasets.ytvis import register_ytvis_instances + +from maskfreevis.config import get_cfg +from maskfreevis.data_fusion_modeling import add_data_fusion_block_config + +from demo_video.predictor import VisualizationDemo +from demo_video.visualizer import TrackVisualizer + + +def setup_config(metadata): + cfg = get_cfg() + add_deeplab_config(cfg) + add_maskformer2_config(cfg) + add_maskformer2_video_config(cfg) + add_data_fusion_block_config(cfg) + cfg.merge_from_file("configs/youtubevis_2019/video_maskformer2_R50_bs16_8ep.yaml") + cfg.MODEL.WEIGHTS = "mfvis_models/model_final_r50_0466.pth" + cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = len(metadata.thing_classes) + cfg.freeze() + return cfg + +def register_datasets(): + global trainDatasetName + global testDatasetName + + trainAnnPath = "datasets/ytvis_2019/train.json" + trainImagesDir = "datasets/ytvis_2019/train/JPEGImages" + + testAnnPath = "datasets/ytvis_2019/test.json" + testImagesDir = "datasets/ytvis_2019/test/JPEGImages" + + ret = {"thing_colors": [[0, 0, 142], [220, 20, 60]]} + + try: + register_ytvis_instances(trainDatasetName, ret, trainAnnPath, trainImagesDir) + register_ytvis_instances(testDatasetName, ret, testAnnPath, testImagesDir) + except: + pass + +def get_dataset_dict(): + global trainDatasetName + global testDatasetName + + train_metadata = MetadataCatalog.get(testDatasetName) + dataset_dicts = DatasetCatalog.get(testDatasetName) + return train_metadata, dataset_dicts + +def extract_frame_dic(dic, frame_idx): + frame_dic = copy.deepcopy(dic) + annos = frame_dic.get("annotations", None) + if annos: + frame_dic["annotations"] = annos[frame_idx] + + return frame_dic + +def visualize_dataset_and_predict(): + global trainDatasetName + global testDatasetName + + trainDatasetName = "ytvis_2019_train" + testDatasetName = "ytvis_2019_test_unity" + + seedParam = 42 + numSamples = 5 + threshValue = 0.2 + dirname = "ytvis_2019_train_gt_visualize" + predictOutputDir = f"ytvis_2019_train_predict_visualize_thresh_{threshValue}" + + random.seed(seedParam) + register_datasets() + train_metadata, dataset_dicts = get_dataset_dict() + cfg = setup_config(train_metadata) + demo = VisualizationDemo(cfg, train_metadata) + + for d in random.sample(dataset_dicts, 
numSamples): + isFileExist = True + vid_name = d["file_names"][0].split('/')[-2] + if not os.path.isdir(os.path.join(dirname, vid_name)): + isFileExist = False + os.makedirs(os.path.join(dirname, vid_name), exist_ok=True) + + video_frames = [] + gt_images = [] + + for idx, file_name in enumerate(d["file_names"]): + img = read_image(file_name, format="BGR") + video_frames.append(img) + + if not isFileExist: + visualizer = TrackVisualizer(img[:, :, ::-1], metadata=train_metadata, scale=1.0, instance_mode=ColorMode.SEGMENTATION) + vis = visualizer.draw_dataset_dict(extract_frame_dic(d, idx)) + fpath = os.path.join(dirname, vid_name, file_name.split('/')[-1]) + vis.save(fpath) + gt_images.append(vis.get_image()) + + if gt_images: + imageio.mimsave(os.path.join(dirname, vid_name) + ".gif", gt_images, fps=5) + + if not os.path.isdir(os.path.join(predictOutputDir, vid_name)): + with autocast(): + predictions, visualized_output = demo.run_on_video(video_frames, threshValue) + + predicted_images = [] + os.makedirs(os.path.join(predictOutputDir, vid_name), exist_ok=True) + + for path, vis_output in zip(d["file_names"], visualized_output): + out_filename = os.path.join(predictOutputDir, vid_name, os.path.basename(path)) + vis_output.save(out_filename) + predicted_images.append(vis_output.get_image()) + imageio.mimsave(os.path.join(predictOutputDir, vid_name) + ".gif", predicted_images, fps=5) + + +if __name__ == "__main__": + visualize_dataset_and_predict() \ No newline at end of file diff --git a/maskfreevis/train.py b/maskfreevis/train.py new file mode 100644 index 0000000..0b79eee --- /dev/null +++ b/maskfreevis/train.py @@ -0,0 +1,77 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) + +from detectron2.data import MetadataCatalog +from detectron2.evaluation import COCOEvaluator, inference_on_dataset +from detectron2.projects.deeplab import add_deeplab_config + +from mask2former import add_maskformer2_config +from mask2former_video import add_maskformer2_video_config +from mask2former_video.data_video.datasets.ytvis import register_ytvis_instances + +from maskfreevis.config import get_cfg +from maskfreevis.data_fusion_modeling import add_data_fusion_block_config + +from train_net_video import Trainer +from utils import ValidationLoss, CUDAMemoryOptimizer + + +# Define datasets +trainDatasetName = "ytvis_train_unity" +trainAnnPath = "datasets/ytvis_2019/train.json" +trainImagesDir = "datasets/ytvis_2019/train/JPEGImages" + +validDatasetName = "ytvis_valid_unity" +validAnnPath = "datasets/ytvis_2019/valid.json" +validImagesDir = "datasets/ytvis_2019/valid/JPEGImages" + +testDatasetName = "ytvis_test_unity" +testAnnPath = "datasets/ytvis_2019/test.json" +testImagesDir = "datasets/ytvis_2019/test/JPEGImages" + +ret = { + "thing_classes": ["arac", "insan"], + "thing_colors": [(0, 0, 142), (220, 20, 60)], +} + +register_ytvis_instances(trainDatasetName, ret, trainAnnPath, trainImagesDir) +register_ytvis_instances(validDatasetName, ret, validAnnPath, validImagesDir) +register_ytvis_instances(testDatasetName, ret, testAnnPath, testImagesDir) + +train_metadata = MetadataCatalog.get(trainDatasetName) + +# Set model configs +cfg = get_cfg() +add_deeplab_config(cfg) +add_maskformer2_config(cfg) +add_maskformer2_video_config(cfg) +add_data_fusion_block_config(cfg) +cfg.merge_from_file("configs/youtubevis_2019/video_maskformer2_R50_bs16_8ep.yaml") + +cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = len(train_metadata.thing_classes) +cfg.MODEL.WEIGHTS = 
"mfvis_models/model_final_r50_0466.pth" +cfg.DATALOADER.NUM_WORKERS = 4 +cfg.DATASETS.TRAIN = (trainDatasetName,) +cfg.DATASETS.TEST = (validDatasetName,) +cfg.SOLVER.IMS_PER_BATCH = 1 +cfg.SOLVER.MAX_ITER = 50000 +cfg.MODEL.DEVICE = "cuda" +cfg.OUTPUT_DIR = "models/MaskFreeVIS_R50_50k_iter_1x_pretrained" +os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) + +with open("models/MaskFreeVIS_R50_50k_iter_1x_pretrained/config.yaml", "w") as f: + f.write(cfg.dump()) + +# Model train +trainer = Trainer(cfg) +trainer.register_hooks([CUDAMemoryOptimizer(), + ValidationLoss(cfg)]) +trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1] +trainer.resume_or_load(resume=True) +trainer.train() + +# Model eval +cfg.DATASETS.TEST = (testDatasetName,) +Trainer.test(cfg, trainer.model) \ No newline at end of file diff --git a/maskfreevis/utils.py b/maskfreevis/utils.py new file mode 100644 index 0000000..3871081 --- /dev/null +++ b/maskfreevis/utils.py @@ -0,0 +1,36 @@ +import torch +import detectron2.utils.comm as comm + +from detectron2.engine import HookBase +from train_net_video import Trainer + + +class ValidationLoss(HookBase): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg.clone() + self.cfg.DATASETS.TRAIN = cfg.DATASETS.TEST + self._loader = iter(Trainer.build_train_loader(self.cfg)) + + def after_step(self): + data = next(self._loader) + with torch.no_grad(): + loss_dict = self.trainer.model(data) + + losses = sum(loss_dict.values()) + assert torch.isfinite(losses).all(), loss_dict + + loss_dict_reduced = {"val_" + k: v.item() for k, v in + comm.reduce_dict(loss_dict).items()} + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + if comm.is_main_process(): + self.trainer.storage.put_scalars(total_val_loss=losses_reduced, + **loss_dict_reduced) + + +class CUDAMemoryOptimizer(HookBase): + def __init__(self): + super().__init__() + + def after_step(self): + torch.cuda.empty_cache() diff --git a/train_net_video.py b/train_net_video.py index 884694d..78cf48f 100644 --- a/train_net_video.py +++ b/train_net_video.py @@ -75,6 +75,10 @@ def build_evaluator(cls, cfg, dataset_name, output_folder=None): @classmethod def build_train_loader(cls, cfg): + assert len(cfg.DATASETS.TRAIN) == len(cfg.DATASETS.DATASET_RATIO), ( + f"cfg.DATASETS.TRAIN length and cfg.DATASETS.DATASET_RATIO length should be equal, " + f"{len(cfg.DATASETS.TRAIN)} != {len(cfg.DATASETS.DATASET_RATIO)}") + mappers = [] for d_i, dataset_name in enumerate(cfg.DATASETS.TRAIN): if dataset_name.startswith('coco'): @@ -92,8 +96,7 @@ def build_train_loader(cls, cfg): build_detection_train_loader(cfg, mapper=mapper, dataset_name=dataset_name) for mapper, dataset_name in zip(mappers, cfg.DATASETS.TRAIN) ] - DATASET_RATIO = [1.0, 0.75] - combined_data_loader = build_combined_loader(cfg, loaders, DATASET_RATIO) + combined_data_loader = build_combined_loader(cfg, loaders, cfg.DATASETS.DATASET_RATIO) return combined_data_loader