From 8c6d0d20cd38ef4c09adc649634e4d5524ceef0c Mon Sep 17 00:00:00 2001 From: CoinCheung <867153576@qq.com> Date: Mon, 27 Jun 2022 17:18:27 +0800 Subject: [PATCH] Dev (#242) * refactor dataset * add customer dataset and associated config file * add script of check dataset information * modify script of check dataset * format * gitignore --- .gitignore | 2 - README.md | 9 ++- configs/bisenet_customer.py | 23 ++++++ lib/data/__init__.py | 2 + lib/{ => data}/base_dataset.py | 1 - lib/{ => data}/cityscapes_cv2.py | 4 +- lib/{ => data}/coco.py | 6 +- lib/data/customer_dataset.py | 21 ++++++ lib/{ => data}/get_dataloader.py | 10 ++- lib/{ => data}/sampler.py | 0 lib/{ => data}/transform_cv2.py | 0 run.sh | 14 ---- tools/check_dataset_info.py | 121 +++++++++++++++++++++++++++++++ tools/evaluate.py | 2 +- tools/train.py | 2 +- tools/train_amp.py | 2 +- 16 files changed, 189 insertions(+), 30 deletions(-) create mode 100644 configs/bisenet_customer.py create mode 100644 lib/data/__init__.py rename lib/{ => data}/base_dataset.py (97%) rename lib/{ => data}/cityscapes_cv2.py (98%) rename lib/{ => data}/coco.py (97%) create mode 100644 lib/data/customer_dataset.py rename lib/{ => data}/get_dataloader.py (87%) rename lib/{ => data}/sampler.py (100%) rename lib/{ => data}/transform_cv2.py (100%) delete mode 100644 run.sh create mode 100644 tools/check_dataset_info.py diff --git a/.gitignore b/.gitignore index 2e880ed..41115ee 100644 --- a/.gitignore +++ b/.gitignore @@ -104,8 +104,6 @@ venv.bak/ ## Coin: -data -data/ play.py preprocess_data.py res/ diff --git a/README.md b/README.md index be685b6..a6282fd 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,14 @@ frankfurt_000001_079206_leftImg8bit.png,frankfurt_000001_079206_gtFine_labelIds. ... ``` Each line is a pair of training sample and ground truth image path, which are separated by a single comma `,`. -Then you need to change the field of `im_root` and `train/val_im_anns` in the configuration files. 
If you found what shows in `cityscapes_cv2.py` is not clear, you can also see `coco.py`. + +I recommend you to check the information of your dataset with the script: +``` +$ python tools/check_dataset_info.py --im_root /path/to/your/data_root --im_anns /path/to/your/anno_file +``` +This will print some of the information of your dataset. + +Then you need to change the field of `im_root` and `train/val_im_anns` in the config file. I prepared a demo config file for you named [`bisenet_customer.py`](./configs/bisenet_customer.py). You can start from this config file. ## train diff --git a/configs/bisenet_customer.py b/configs/bisenet_customer.py new file mode 100644 index 0000000..22f09c7 --- /dev/null +++ b/configs/bisenet_customer.py @@ -0,0 +1,23 @@ + +cfg = dict( + model_type='bisenetv1', + n_cats=20, + num_aux_heads=2, + lr_start=1e-2, + weight_decay=5e-4, + warmup_iters=1000, + max_iter=80000, + dataset='CustomerDataset', + im_root='./datasets/cityscapes', + train_im_anns='./datasets/cityscapes/train.txt', + val_im_anns='./datasets/cityscapes/val.txt', + scales=[0.75, 2.], + cropsize=[512, 512], + eval_crop=[512, 512], + eval_scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + ims_per_gpu=8, + eval_ims_per_gpu=2, + use_fp16=True, + use_sync_bn=False, + respth='./res', +) diff --git a/lib/data/__init__.py b/lib/data/__init__.py new file mode 100644 index 0000000..c275c03 --- /dev/null +++ b/lib/data/__init__.py @@ -0,0 +1,2 @@ + +from .get_dataloader import get_data_loader diff --git a/lib/base_dataset.py b/lib/data/base_dataset.py similarity index 97% rename from lib/base_dataset.py rename to lib/data/base_dataset.py index c7fe606..86eca5d 100644 --- a/lib/base_dataset.py +++ b/lib/data/base_dataset.py @@ -11,7 +11,6 @@ import cv2 import numpy as np -from lib.sampler import RepeatedDistSampler diff --git a/lib/cityscapes_cv2.py b/lib/data/cityscapes_cv2.py similarity index 98% rename from lib/cityscapes_cv2.py rename to lib/data/cityscapes_cv2.py index 9f07803..84e527a 
100644 --- a/lib/cityscapes_cv2.py +++ b/lib/data/cityscapes_cv2.py @@ -11,8 +11,8 @@ import cv2 import numpy as np -import lib.transform_cv2 as T -from lib.base_dataset import BaseDataset +import lib.data.transform_cv2 as T +from lib.data.base_dataset import BaseDataset labels_info = [ diff --git a/lib/coco.py b/lib/data/coco.py similarity index 97% rename from lib/coco.py rename to lib/data/coco.py index 2395fd4..4d0676d 100644 --- a/lib/coco.py +++ b/lib/data/coco.py @@ -11,9 +11,8 @@ import cv2 import numpy as np -import lib.transform_cv2 as T -from lib.sampler import RepeatedDistSampler -from lib.base_dataset import BaseDataset +import lib.data.transform_cv2 as T +from lib.data.base_dataset import BaseDataset ''' 91(thing) + 91(stuff) = 182 classes, label proportions are: @@ -51,6 +50,7 @@ def __init__(self, dataroot, annpath, trans_func=None, mode='train'): super(CocoStuff, self).__init__( dataroot, annpath, trans_func, mode) self.n_cats = 171 # 91 stuff, 91 thing, 11 of thing have no annos + self.lb_ignore = 255 ## label mapping, remove non-existing labels missing = [11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90] diff --git a/lib/data/customer_dataset.py b/lib/data/customer_dataset.py new file mode 100644 index 0000000..f7355e6 --- /dev/null +++ b/lib/data/customer_dataset.py @@ -0,0 +1,21 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import lib.data.transform_cv2 as T +from lib.data.base_dataset import BaseDataset + + +class CustomerDataset(BaseDataset): + + def __init__(self, dataroot, annpath, trans_func=None, mode='train'): + super(CustomerDataset, self).__init__( + dataroot, annpath, trans_func, mode) + self.lb_ignore = 255 + + self.to_tensor = T.ToTensor( + mean=(0.4, 0.4, 0.4), # rgb + std=(0.2, 0.2, 0.2), + ) + + diff --git a/lib/get_dataloader.py b/lib/data/get_dataloader.py similarity index 87% rename from lib/get_dataloader.py rename to lib/data/get_dataloader.py index db9e3c5..6bd1a86 100644 --- a/lib/get_dataloader.py +++ 
b/lib/data/get_dataloader.py @@ -3,10 +3,12 @@ from torch.utils.data import Dataset, DataLoader import torch.distributed as dist -import lib.transform_cv2 as T -from lib.sampler import RepeatedDistSampler -from lib.cityscapes_cv2 import CityScapes -from lib.coco import CocoStuff +import lib.data.transform_cv2 as T +from lib.data.sampler import RepeatedDistSampler + +from lib.data.cityscapes_cv2 import CityScapes +from lib.data.coco import CocoStuff +from lib.data.customer_dataset import CustomerDataset diff --git a/lib/sampler.py b/lib/data/sampler.py similarity index 100% rename from lib/sampler.py rename to lib/data/sampler.py diff --git a/lib/transform_cv2.py b/lib/data/transform_cv2.py similarity index 100% rename from lib/transform_cv2.py rename to lib/data/transform_cv2.py diff --git a/run.sh b/run.sh deleted file mode 100644 index 7dccdfe..0000000 --- a/run.sh +++ /dev/null @@ -1,14 +0,0 @@ - - -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -export CUDA_VISIBLE_DEVICES=2,3 -PORT=12333 -NGPUS=2 -cfg=configs/bisenetv1_city.py -# cfg=configs/bisenetv2_city.py -# cfg=configs/bisenetv1_coco.py -# cfg=configs/bisenetv2_coco.py - -torchrun --nproc_per_node=$NGPUS --master_port $PORT tools/train_amp.py --config $cfg -# python -m torch.distributed.launch --use_env --nproc_per_node=$NGPUS --master_port $PORT tools/train_amp.py --config $cfg - diff --git a/tools/check_dataset_info.py b/tools/check_dataset_info.py new file mode 100644 index 0000000..183ba44 --- /dev/null +++ b/tools/check_dataset_info.py @@ -0,0 +1,121 @@ + +import os +import os.path as osp +import argparse +from tqdm import tqdm + +import cv2 +import numpy as np + + +parse = argparse.ArgumentParser() +parse.add_argument('--im_root', dest='im_root', type=str, default='./datasets/cityscapes',) +parse.add_argument('--im_anns', dest='im_anns', type=str, default='./datasets/cityscapes/train.txt',) +parse.add_argument('--lb_ignore', dest='lb_ignore', type=int, default=255) +args = parse.parse_args() + 
+lb_ignore = args.lb_ignore + + +with open(args.im_anns, 'r') as fr: + lines = fr.read().splitlines() + +n_pairs = len(lines) +impaths, lbpaths = [], [] +for l in lines: + impth, lbpth = l.split(',') + impth = osp.join(args.im_root, impth) + lbpth = osp.join(args.im_root, lbpth) + impaths.append(impth) + lbpaths.append(lbpth) + + +## shapes +max_shape_area, min_shape_area = [0, 0], [100000, 100000] +max_shape_height, min_shape_height = [0, 0], [100000, 100000] +max_shape_width, min_shape_width = [0, 0], [100000, 100000] +max_lb_val, min_lb_val = -1, 10000000 +for impth, lbpth in tqdm(zip(impaths, lbpaths), total=n_pairs): + im = cv2.imread(impth)[:, :, ::-1] + lb = cv2.imread(lbpth, 0) + assert im.shape[:2] == lb.shape + + shape = lb.shape + area = shape[0] * shape[1] + if area > max_shape_area[0] * max_shape_area[1]: + max_shape_area = shape + if area < min_shape_area[0] * min_shape_area[1]: + min_shape_area = shape + + if shape[0] > max_shape_height[0]: + max_shape_height = shape + if shape[0] < min_shape_height[0]: + min_shape_height = shape + + if shape[1] > max_shape_width[1]: + max_shape_width = shape + if shape[1] < min_shape_width[1]: + min_shape_width = shape + + lb = lb[lb != lb_ignore] + if lb.size > 0: + max_lb_val = max(max_lb_val, np.max(lb)) + min_lb_val = min(min_lb_val, np.min(lb)) + +min_lb_val = 0 +max_lb_val = 181 +lb_minlength = 182 +## label info +lb_minlength = max_lb_val+1-min_lb_val +lb_hist = np.zeros(lb_minlength) +for lbpth in tqdm(lbpaths): + lb = cv2.imread(lbpth, 0) + lb = lb[lb != lb_ignore] + min_lb_val + lb_hist += np.bincount(lb, minlength=lb_minlength) + +lb_missing_vals = [ind + min_lb_val + for ind, el in enumerate(lb_hist.tolist()) if el == 0] +lb_ratios = (lb_hist / lb_hist.sum()).tolist() + + +## pixel mean/std +rgb_mean = np.zeros(3).astype(np.float32) +n_pixels = 0 +for impth in tqdm(impaths): + im = cv2.imread(impth)[:, :, ::-1].astype(np.float32) + im = im.reshape(-1, 3) / 255. 
+ n_pixels += im.shape[0] + rgb_mean += im.sum(axis=0) +rgb_mean = (rgb_mean / n_pixels) + +rgb_std = np.zeros(3).astype(np.float32) +for impth in tqdm(impaths): + im = cv2.imread(impth)[:, :, ::-1].astype(np.float32) + im = im.reshape(-1, 3) / 255. + + a = (im - rgb_mean.reshape(1, 3)) ** 2 + rgb_std += a.sum(axis=0) +rgb_std = (rgb_std / n_pixels) ** 0.5 + +rgb_mean = rgb_mean.tolist() +rgb_std = rgb_std.tolist() + + +print('\n') +print(f'there are {n_pairs} lines in {args.im_anns}, which means {n_pairs} image/label image pairs') +print('\n') + +print(f'max and min image shapes by area are: {max_shape_area}, {min_shape_area}') +print(f'max and min image shapes by height are: {max_shape_height}, {min_shape_height}') +print(f'max and min image shapes by width are: {max_shape_width}, {min_shape_width}') +print('\n') + +print(f'we ignore label value of {args.lb_ignore} in label images') +print(f'label values are within range of [{min_lb_val}, {max_lb_val}]') +print(f'label values that are missing: {lb_missing_vals}') +print('ratios of each label value: ') +print('\t', lb_ratios) +print('\n') + +print('pixel mean rgb: ', rgb_mean) +print('pixel std rgb: ', rgb_std) diff --git a/tools/evaluate.py b/tools/evaluate.py index cec751e..3da5c4a 100644 --- a/tools/evaluate.py +++ b/tools/evaluate.py @@ -22,7 +22,7 @@ from lib.models import model_factory from configs import set_cfg_from_file from lib.logger import setup_logger -from lib.get_dataloader import get_data_loader +from lib.data import get_data_loader def get_round_size(size, divisor=32): diff --git a/tools/train.py b/tools/train.py index bf604f9..35a2103 100644 --- a/tools/train.py +++ b/tools/train.py @@ -19,7 +19,7 @@ from lib.models import model_factory from configs import set_cfg_from_file -from lib.get_dataloader import get_data_loader +from lib.data import get_data_loader from tools.evaluate import eval_model from lib.ohem_ce_loss import OhemCELoss from lib.lr_scheduler import WarmupPolyLrScheduler diff --git 
a/tools/train_amp.py b/tools/train_amp.py index 63db177..95be433 100644 --- a/tools/train_amp.py +++ b/tools/train_amp.py @@ -20,7 +20,7 @@ from lib.models import model_factory from configs import set_cfg_from_file -from lib.get_dataloader import get_data_loader +from lib.data import get_data_loader from evaluate import eval_model from lib.ohem_ce_loss import OhemCELoss from lib.lr_scheduler import WarmupPolyLrScheduler