Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* refactor dataset

* add customer dataset and associated config file

* add script of check dataset information

* modify script of check dataset

* format

* gitignore
  • Loading branch information
CoinCheung authored Jun 27, 2022
1 parent a54e37a commit 8c6d0d2
Show file tree
Hide file tree
Showing 16 changed files with 189 additions and 30 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ venv.bak/


## Coin:
data
data/
play.py
preprocess_data.py
res/
Expand Down
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,14 @@ frankfurt_000001_079206_leftImg8bit.png,frankfurt_000001_079206_gtFine_labelIds.
...
```
Each line is a pair of training sample and ground truth image path, which are separated by a single comma `,`.
Then you need to change the `im_root` and `train/val_im_anns` fields in the configuration files. If what is shown in `cityscapes_cv2.py` is not clear, you can also look at `coco.py`.

I recommend that you check the information of your dataset with this script:
```
$ python tools/check_dataset_info.py --im_root /path/to/your/data_root --im_anns /path/to/your/anno_file
```
This will print some of the information of your dataset.

Then you need to change the `im_root` and `train/val_im_anns` fields in the config file. I prepared a demo config file for you named [`bisenet_customer.py`](./configs/bisenet_customer.py). You can start from this config file.


## train
Expand Down
23 changes: 23 additions & 0 deletions configs/bisenet_customer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

# Demo configuration for training on a custom dataset.
# `dataset` names the CustomerDataset class; point `im_root`, `train_im_anns`
# and `val_im_anns` at your own data before using this file.
cfg = {
    'model_type': 'bisenetv1',
    'n_cats': 20,
    'num_aux_heads': 2,
    'lr_start': 1e-2,
    'weight_decay': 5e-4,
    'warmup_iters': 1000,
    'max_iter': 80000,
    'dataset': 'CustomerDataset',
    # paths below still refer to the cityscapes layout and are meant to be edited
    'im_root': './datasets/cityscapes',
    'train_im_anns': './datasets/cityscapes/train.txt',
    'val_im_anns': './datasets/cityscapes/val.txt',
    'scales': [0.75, 2.],
    'cropsize': [512, 512],
    'eval_crop': [512, 512],
    'eval_scales': [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
    'ims_per_gpu': 8,
    'eval_ims_per_gpu': 2,
    'use_fp16': True,
    'use_sync_bn': False,
    'respth': './res',
}
2 changes: 2 additions & 0 deletions lib/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .get_dataloader import get_data_loader
1 change: 0 additions & 1 deletion lib/base_dataset.py → lib/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import cv2
import numpy as np

from lib.sampler import RepeatedDistSampler



Expand Down
4 changes: 2 additions & 2 deletions lib/cityscapes_cv2.py → lib/data/cityscapes_cv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import cv2
import numpy as np

import lib.transform_cv2 as T
from lib.base_dataset import BaseDataset
import lib.data.transform_cv2 as T
from lib.data.base_dataset import BaseDataset


labels_info = [
Expand Down
6 changes: 3 additions & 3 deletions lib/coco.py → lib/data/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
import cv2
import numpy as np

import lib.transform_cv2 as T
from lib.sampler import RepeatedDistSampler
from lib.base_dataset import BaseDataset
import lib.data.transform_cv2 as T
from lib.data.base_dataset import BaseDataset

'''
91(thing) + 91(stuff) = 182 classes, label proportions are:
Expand Down Expand Up @@ -51,6 +50,7 @@ def __init__(self, dataroot, annpath, trans_func=None, mode='train'):
super(CocoStuff, self).__init__(
dataroot, annpath, trans_func, mode)
self.n_cats = 171 # 91 stuff, 91 thing, 11 of thing have no annos
self.lb_ignore = 255

## label mapping, remove non-existing labels
missing = [11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90]
Expand Down
21 changes: 21 additions & 0 deletions lib/data/customer_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-


import lib.data.transform_cv2 as T
from lib.data.base_dataset import BaseDataset


class CustomerDataset(BaseDataset):
    """Dataset wrapper for user-supplied data, built on ``BaseDataset``.

    The constructor forwards all of its arguments unchanged to
    ``BaseDataset.__init__``, then sets the ignored label value to 255 and
    installs a ``T.ToTensor`` transform with fixed normalization statistics.

    NOTE(review): the mean/std values look like placeholders; presumably
    they should be replaced with the statistics of your own dataset -- confirm.
    """

    def __init__(self, dataroot, annpath, trans_func=None, mode='train'):
        super().__init__(dataroot, annpath, trans_func, mode)
        self.lb_ignore = 255

        # normalization statistics, given in rgb order
        rgb_mean = (0.4, 0.4, 0.4)
        rgb_std = (0.2, 0.2, 0.2)
        self.to_tensor = T.ToTensor(mean=rgb_mean, std=rgb_std)


10 changes: 6 additions & 4 deletions lib/get_dataloader.py → lib/data/get_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist

import lib.transform_cv2 as T
from lib.sampler import RepeatedDistSampler
from lib.cityscapes_cv2 import CityScapes
from lib.coco import CocoStuff
import lib.data.transform_cv2 as T
from lib.data.sampler import RepeatedDistSampler

from lib.data.cityscapes_cv2 import CityScapes
from lib.data.coco import CocoStuff
from lib.data.customer_dataset import CustomerDataset



Expand Down
File renamed without changes.
File renamed without changes.
14 changes: 0 additions & 14 deletions run.sh

This file was deleted.

121 changes: 121 additions & 0 deletions tools/check_dataset_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@

import os
import os.path as osp
import argparse
from tqdm import tqdm

import cv2
import numpy as np


# Command-line interface: the dataset root, the annotation list to scan, and
# the label value whose pixels are skipped in the label statistics.
parse = argparse.ArgumentParser()
for opt_flag, opt_dest, opt_type, opt_default in (
        ('--im_root', 'im_root', str, './datasets/cityscapes'),
        ('--im_anns', 'im_anns', str, './datasets/cityscapes/train.txt'),
        ('--lb_ignore', 'lb_ignore', int, 255),
):
    parse.add_argument(
        opt_flag, dest=opt_dest, type=opt_type, default=opt_default)
args = parse.parse_args()

lb_ignore = args.lb_ignore


# Load the annotation list: every line holds an image path and a label path
# separated by a single comma, both relative to the dataset root.
with open(args.im_anns, 'r') as fr:
    lines = fr.read().splitlines()

n_pairs = len(lines)
# Unpacking into exactly two names rejects lines without exactly one comma.
rel_pairs = [tuple(line.split(',')) for line in lines]
impaths = [osp.join(args.im_root, rel_im) for rel_im, _ in rel_pairs]
lbpaths = [osp.join(args.im_root, rel_lb) for _, rel_lb in rel_pairs]


## shapes
# Single pass over all image/label pairs that records the extreme image
# shapes (by area, by height, by width) and the range of label values,
# not counting pixels equal to lb_ignore.
max_shape_area, min_shape_area = [0, 0], [100000, 100000]
max_shape_height, min_shape_height = [0, 0], [100000, 100000]
max_shape_width, min_shape_width = [0, 0], [100000, 100000]
max_lb_val, min_lb_val = -1, 10000000
for impth, lbpth in tqdm(zip(impaths, lbpaths), total=n_pairs):
    im = cv2.imread(impth)
    lb = cv2.imread(lbpth, 0)
    # cv2.imread reports failure by returning None instead of raising, so
    # name the offending file here rather than failing later on None.
    if im is None:
        raise OSError(f'failed to read image: {impth}')
    if lb is None:
        raise OSError(f'failed to read label image: {lbpth}')
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    if im.shape[:2] != lb.shape:
        raise ValueError(
            f'image/label size mismatch: {impth} is {im.shape[:2]}, '
            f'{lbpth} is {lb.shape}')

    shape = lb.shape
    area = shape[0] * shape[1]
    if area > max_shape_area[0] * max_shape_area[1]:
        max_shape_area = shape
    if area < min_shape_area[0] * min_shape_area[1]:
        min_shape_area = shape

    if shape[0] > max_shape_height[0]:
        max_shape_height = shape
    if shape[0] < min_shape_height[0]:
        min_shape_height = shape

    if shape[1] > max_shape_width[1]:
        max_shape_width = shape
    if shape[1] < min_shape_width[1]:
        min_shape_width = shape

    lb = lb[lb != lb_ignore]
    if lb.size > 0:
        # int() keeps the running extrema as Python ints; NumPy integer
        # scalars would make later arithmetic on them prone to wrap-around.
        max_lb_val = max(max_lb_val, int(lb.max()))
        min_lb_val = min(min_lb_val, int(lb.min()))

## label info
# Histogram of label values over the whole dataset, built over the value
# range [min_lb_val, max_lb_val] measured by the scan above (no fixed range).
# Plain Python ints keep the range arithmetic free of NumPy scalar wrap-around.
min_lb_val, max_lb_val = int(min_lb_val), int(max_lb_val)
# If no non-ignored pixel was seen the range is empty; clamp to an empty
# histogram instead of asking np.zeros for a negative size.
lb_minlength = max(max_lb_val + 1 - min_lb_val, 0)
lb_hist = np.zeros(lb_minlength)
for lbpth in tqdm(lbpaths):
    lb = cv2.imread(lbpth, 0)
    # Shift labels so that min_lb_val falls into bin 0. Widen to int64 first
    # so the offset is applied independently of the image's own dtype.
    lb = lb[lb != lb_ignore].astype(np.int64) - min_lb_val
    lb_hist += np.bincount(lb, minlength=lb_minlength)

# Bin index + min_lb_val maps a histogram slot back to its label value.
lb_missing_vals = [ind + min_lb_val
        for ind, el in enumerate(lb_hist.tolist()) if el == 0]
lb_ratios = (lb_hist / lb_hist.sum()).tolist()


## pixel mean/std
# Two passes over the images: the first accumulates the per-channel mean,
# the second the squared deviation from that mean. Pixel values are scaled
# by 1/255 and the channel axis is reversed so results are in rgb order.
# Accumulators are float64: summing millions of pixels in float32 loses
# precision once the running sum dwarfs the individual values.
rgb_mean = np.zeros(3, dtype=np.float64)
n_pixels = 0
for impth in tqdm(impaths):
    im = cv2.imread(impth)[:, :, ::-1].astype(np.float64)
    im = im.reshape(-1, 3) / 255.
    n_pixels += im.shape[0]
    rgb_mean += im.sum(axis=0)
rgb_mean = (rgb_mean / n_pixels)

rgb_std = np.zeros(3, dtype=np.float64)
for impth in tqdm(impaths):
    im = cv2.imread(impth)[:, :, ::-1].astype(np.float64)
    im = im.reshape(-1, 3) / 255.

    a = (im - rgb_mean.reshape(1, 3)) ** 2
    rgb_std += a.sum(axis=0)
rgb_std = (rgb_std / n_pixels) ** 0.5

# Convert to plain lists of floats for printing.
rgb_mean = rgb_mean.tolist()
rgb_std = rgb_std.tolist()


# Final report. print('\n') emits the newline character plus print's own
# line ending, which leaves visual gaps between the report sections.
section_gap = '\n'

print(section_gap)
print(f'there are {n_pairs} lines in {args.im_anns}, which means {n_pairs} image/label image pairs')
print(section_gap)

# image shape extremes
shape_report = (
    f'max and min image shapes by area are: {max_shape_area}, {min_shape_area}',
    f'max and min image shapes by height are: {max_shape_height}, {min_shape_height}',
    f'max and min image shapes by width are: {max_shape_width}, {min_shape_width}',
)
for report_line in shape_report:
    print(report_line)
print(section_gap)

# label statistics
label_report = (
    f'we ignore label value of {args.lb_ignore} in label images',
    f'label values are within range of [{min_lb_val}, {max_lb_val}]',
    f'label values that are missing: {lb_missing_vals}',
    'ratios of each label value: ',
)
for report_line in label_report:
    print(report_line)
print('\t', lb_ratios)
print(section_gap)

# pixel statistics
print('pixel mean rgb: ', rgb_mean)
print('pixel std rgb: ', rgb_std)
2 changes: 1 addition & 1 deletion tools/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from lib.models import model_factory
from configs import set_cfg_from_file
from lib.logger import setup_logger
from lib.get_dataloader import get_data_loader
from lib.data import get_data_loader


def get_round_size(size, divisor=32):
Expand Down
2 changes: 1 addition & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from lib.models import model_factory
from configs import set_cfg_from_file
from lib.get_dataloader import get_data_loader
from lib.data import get_data_loader
from tools.evaluate import eval_model
from lib.ohem_ce_loss import OhemCELoss
from lib.lr_scheduler import WarmupPolyLrScheduler
Expand Down
2 changes: 1 addition & 1 deletion tools/train_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from lib.models import model_factory
from configs import set_cfg_from_file
from lib.get_dataloader import get_data_loader
from lib.data import get_data_loader
from evaluate import eval_model
from lib.ohem_ce_loss import OhemCELoss
from lib.lr_scheduler import WarmupPolyLrScheduler
Expand Down

0 comments on commit 8c6d0d2

Please sign in to comment.