COCO Segmentation
ffiirree committed Jan 22, 2022
1 parent 0a4b33d commit 111c732
Showing 6 changed files with 192 additions and 35 deletions.
110 changes: 110 additions & 0 deletions cvm/utils/coco.py
@@ -0,0 +1,110 @@
import copy
import os

import torch
import torch.utils.data
import torchvision
from PIL import Image
from pycocotools import mask as coco_mask
from .seg_transforms import Compose


class FilterAndRemapCocoCategories:
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap

def __call__(self, image, anno):
anno = [obj for obj in anno if obj["category_id"] in self.categories]
if not self.remap:
return image, anno
anno = copy.deepcopy(anno)
for obj in anno:
obj["category_id"] = self.categories.index(obj["category_id"])
return image, anno

    def __repr__(self) -> str:
        return self.__class__.__name__ + '()'


def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
if masks:
masks = torch.stack(masks, dim=0)
else:
masks = torch.zeros((0, height, width), dtype=torch.uint8)
return masks


class ConvertCocoPolysToMask:
def __call__(self, image, anno):
w, h = image.size
segmentations = [obj["segmentation"] for obj in anno]
cats = [obj["category_id"] for obj in anno]
if segmentations:
masks = convert_coco_poly_to_mask(segmentations, h, w)
cats = torch.as_tensor(cats, dtype=masks.dtype)
# merge all instance masks into a single segmentation map
# with its corresponding categories
target, _ = (masks * cats[:, None, None]).max(dim=0)
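            # (per pixel, the max over instances picks the highest category id
            #  among the instances covering it; uncovered pixels stay 0 = background)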
# discard overlapping instances
target[masks.sum(0) > 1] = 255
else:
target = torch.zeros((h, w), dtype=torch.uint8)
target = Image.fromarray(target.numpy())
return image, target

    def __repr__(self) -> str:
        return self.__class__.__name__ + '()'


def _coco_remove_images_without_annotations(dataset, cat_list=None):
def _has_valid_annotation(anno):
# if it's empty, there is no annotation
if len(anno) == 0:
return False
        # keep only images whose annotations cover more than 1000 pixels in total
return sum(obj["area"] for obj in anno) > 1000

assert isinstance(dataset, torchvision.datasets.CocoDetection)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = dataset.coco.loadAnns(ann_ids)
if cat_list:
anno = [obj for obj in anno if obj["category_id"] in cat_list]
if _has_valid_annotation(anno):
ids.append(ds_idx)

dataset = torch.utils.data.Subset(dataset, ids)
return dataset


def get_coco(root, image_set, transforms):
PATHS = {
"train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
"val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
# "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
}
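    # COCO category ids of the 21 PASCAL VOC classes (background first),
    # so COCO targets are remapped onto VOC-style labels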
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])

img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)

dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)

return dataset
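As a sanity check, here is a minimal sketch of driving the new loader on its own. The data root is a placeholder, the empty Compose stands in for the real augmentation pipeline, and it assumes Compose follows the usual (image, target) calling convention:

    from cvm.utils.coco import get_coco
    from cvm.utils.seg_transforms import Compose

    dataset = get_coco(
        root='/data/coco',       # assumed layout: val2017/ and annotations/instances_val2017.json
        image_set='val',
        transforms=Compose([]),  # no-op placeholder; training would pass resize/crop/flip ops
    )
    image, target = dataset[0]
    # `target` is a PIL image whose pixel values are indices into CAT_LIST,
    # with 255 marking overlapping instances (ignored later by the loss)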
66 changes: 46 additions & 20 deletions cvm/utils/factory.py
@@ -11,6 +11,7 @@

import cvm
from cvm.data.samplers import RASampler
from cvm.utils.coco import get_coco
from . import seg_transforms as ST
from .utils import group_params, list_datasets, get_world_size
from cvm.data.constants import *
@@ -202,12 +203,13 @@ def create_optimizer(
no_bias_bn_wd: bool = False,
**kwargs
):
params = group_params(
params,
weight_decay,
no_bias_bn_wd,
lr
)
if isinstance(params, nn.Module):
params = group_params(
params,
weight_decay,
no_bias_bn_wd,
lr
)

if name == 'sgd':
return optim.SGD(
@@ -285,6 +287,8 @@ def _get_dataset_mean_or_std(name, attr):
return MNIST_MEAN if attr == 'mean' else MNIST_STD
if name.startswith('voc') or name.startswith('sbd'):
return VOC_MEAN if attr == 'mean' else VOC_STD
if name.startswith('coco'):
return VOC_MEAN if attr == 'mean' else VOC_STD
return IMAGE_MEAN if attr == 'mean' else IMAGE_STD


@@ -393,7 +397,8 @@ def create_segmentation_transforms(
):
ops = []
if is_training:
ops.append(ST.RandomCrop(crop_size, pad_if_needed=True, padding=padding))
# ops.append(ST.RandomCrop(crop_size, pad_if_needed=True, padding=padding))
ops.append(ST.RandomResizedCrop(crop_size, (0.5, 2.0), interpolation=interpolation))
if hflip > 0.0:
ops.append(ST.RandomHorizontalFlip(hflip))
else:
@@ -417,17 +422,25 @@ def create_dataset(
dataset = datasets.__dict__[name]
params = inspect.signature(dataset.__init__).parameters.keys()

if 'mode' in params and 'image_set' in params:
return datasets.__dict__[name](
path.expanduser(root),
mode='segmentation',
image_set='train' if is_training else 'val',
download=(download and is_training)
)

if 'image_set' in params:
return datasets.__dict__[name](
path.expanduser(root),
image_set='train' if is_training else 'val',
download=download
download=(download and is_training)
)

return datasets.__dict__[name](
path.expanduser(root),
train=is_training,
download=download
download=(download and is_training)
)
elif name == 'ImageNet':
return datasets.ImageFolder(
@@ -463,6 +476,7 @@ def create_loader(
ra_repetitions: int = 0,
transform: T.Compose = None,
taskname: str = 'classification',
collate_fn=None,
**kwargs
):
assert taskname in ['classification', 'segmentation'], f'Unknown task: {taskname}.'
@@ -499,16 +513,8 @@
), 'dali')
# Pytorch/Vision
else:
if isinstance(dataset, str):
dataset = create_dataset(
dataset,
root=root,
is_training=is_training,
**kwargs
)

if taskname == 'classification':
dataset.transform = transform or create_transforms(
transform = transform or create_transforms(
is_training=is_training,
random_scale=kwargs.get('random_scale', [0.08, 1.0]),
interpolation=T.InterpolationMode(interpolation),
@@ -529,7 +535,7 @@
dataset_image_size=_get_dataset_image_size(dataset),
)
elif taskname == 'segmentation':
dataset.transforms = transform or create_segmentation_transforms(
transform = transform or create_segmentation_transforms(
is_training=is_training,
interpolation=T.InterpolationMode(interpolation),
hflip=hflip,
@@ -539,6 +545,24 @@
crop_size=crop_size if is_training else val_crop_size,
)

if dataset == 'CocoDetection':
dataset = get_coco(
root=root,
image_set='train' if is_training else 'val',
transforms=transform
)
elif isinstance(dataset, str):
dataset = create_dataset(
dataset,
root=root,
is_training=is_training,
**kwargs
)
if taskname == 'classification':
dataset.transform = transform
elif taskname == 'segmentation':
dataset.transforms = transform

sampler = None
if distributed:
if ra_repetitions > 0 and is_training:
@@ -552,5 +576,7 @@
num_workers=workers,
pin_memory=pin_memory,
sampler=sampler,
shuffle=((not distributed) and is_training)
shuffle=((not distributed) and is_training),
collate_fn=collate_fn,
drop_last=(is_training and taskname == 'segmentation')
), 'torch')
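With the new isinstance check, create_optimizer accepts either a module (group_params builds the groups, as before) or a pre-built list of parameter groups. A hedged sketch of both call styles; the lr and weight_decay keyword names follow the signature visible in this diff, but the values here are assumptions:

    import torch.nn as nn
    from cvm.utils.factory import create_optimizer

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 21, 1))

    # 1) pass a module: group_params() builds the parameter groups
    opt_a = create_optimizer('sgd', model, lr=0.01, weight_decay=1e-4)

    # 2) pass explicit groups: train_seg.py below uses this to give the aux head a 10x lr
    groups = [
        {'params': model[0].parameters()},
        {'params': model[1].parameters(), 'lr': 0.1},
    ]
    opt_b = create_optimizer('sgd', groups, lr=0.01, weight_decay=1e-4)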
18 changes: 17 additions & 1 deletion cvm/utils/utils.py
@@ -23,7 +23,7 @@
'named_layers', 'AverageMeter',
'module_parameters', 'group_params', 'list_models',
'list_datasets', 'is_dist_avail_and_initialized', 'get_world_size',
'init_distributed_mode', 'mask_to_label'
'init_distributed_mode', 'mask_to_label', 'seg_collate_fn'
]


@@ -277,3 +277,19 @@ def mask_to_label(masks, num_classes):
for j in range(num_classes):
labels[i][j] = bool((masks[i] == j).sum())
return labels.float()


def cat_list(images, fill_value=0):
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
batch_shape = (len(images),) + max_size
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
for img, pad_img in zip(images, batched_imgs):
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
return batched_imgs


def seg_collate_fn(batch):
images, targets = list(zip(*batch))
batched_imgs = cat_list(images, fill_value=0)
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets
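A quick, illustrative check of the padding behavior: cat_list pads each tensor up to the per-dimension maximum over the batch, so seg_collate_fn zero-pads images and pads targets with 255, which the training loss then ignores (ignore_index=255):

    import torch

    img_a, tgt_a = torch.rand(3, 480, 520), torch.zeros(480, 520, dtype=torch.long)
    img_b, tgt_b = torch.rand(3, 500, 480), torch.zeros(500, 480, dtype=torch.long)

    images, targets = seg_collate_fn([(img_a, tgt_a), (img_b, tgt_b)])
    print(images.shape)   # torch.Size([2, 3, 500, 520])
    print(targets.shape)  # torch.Size([2, 500, 520])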
2 changes: 1 addition & 1 deletion cvm/version.py
@@ -1 +1 @@
__version__ = '0.0.18'
__version__ = '0.0.19'
2 changes: 2 additions & 0 deletions train.py
@@ -36,6 +36,7 @@ def parse_args():
parser.add_argument('--model-path', type=str, default=None)
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes')
parser.add_argument('--in-channels', type=int, default=3, metavar='N')
parser.add_argument('--bn-eps', type=float, default=None)
parser.add_argument('--bn-momentum', type=float, default=None)

@@ -201,6 +202,7 @@ def validate(val_loader, model, criterion):

model = create_model(
args.model,
in_channels=args.in_channels,
num_classes=args.num_classes,
dropout_rate=args.dropout_rate,
drop_path_rate=args.drop_path_rate,
29 changes: 16 additions & 13 deletions train_seg.py
@@ -21,10 +21,10 @@ def parse_args():
                        help='number of data loading workers per GPU. (default: 4)')
parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                        help='mini-batch size; this is the total batch size across all GPUs. (default: 1)')
parser.add_argument('--crop-size', type=int, default=320)
parser.add_argument('--crop-size', type=int, default=480)
parser.add_argument('--crop-padding', type=int, default=4, metavar='S')
parser.add_argument('--val-resize-size', type=int, default=384)
parser.add_argument('--val-crop-size', type=int, default=384)
parser.add_argument('--val-resize-size', type=int, default=520)
parser.add_argument('--val-crop-size', type=int, default=520)

# model
parser.add_argument('--model', type=str, default='seg/fcn_regnet_x_400mf', choices=list_models(),
@@ -120,7 +120,7 @@ def train(train_loader, model, criterion, optimizer, scheduler, scaler, epoch, a
outputs = model(images)
loss = criterion(outputs['out'], targets)
if args.aux_loss:
loss += criterion(outputs['aux'], targets)
loss += 0.5 * criterion(outputs['aux'], targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
@@ -161,13 +161,7 @@ def validate(val_loader, model, args):
confmat.update(predictions.argmax(1).flatten(), targets.flatten())

confmat.all_reduce()
iou = [f'{v*100:>4.1f}' for v in confmat.iou]
pa = [f'{v*100:>4.1f}' for v in confmat.mean_pa]
logger.info(f'\nPA = {pa}'
                f'\nglobal PA = {confmat.pa*100:>4.1f}'
f'\nIoU = {iou}'
f'\nmean IoU = {confmat.mean_iou*100:>4.1f}')

    logger.info(f'global PA = {confmat.pa*100:>5.2f}, mean IoU = {confmat.mean_iou*100:>5.2f}')

if __name__ == '__main__':
assert torch.cuda.is_available(), 'CUDA IS NOT AVAILABLE!!'
@@ -205,7 +199,15 @@ def validate(val_loader, model, args):
local_rank=args.local_rank
)

optimizer = create_optimizer(args.optim, model, **dict(vars(args)))
params_to_optimize = [
{"params": [p for p in model.module.backbone.parameters() if p.requires_grad]},
{"params": [p for p in model.module.decode_head.parameters() if p.requires_grad]},
]
if args.aux_loss:
params = [p for p in model.module.aux_head.parameters() if p.requires_grad]
params_to_optimize.append({"params": params, "lr": args.lr * 10})

optimizer = create_optimizer(args.optim, params_to_optimize, **dict(vars(args)))
criterion = nn.CrossEntropyLoss(ignore_index=255)

train_loader = create_loader(
@@ -220,6 +222,7 @@
root=args.data_dir,
is_training=False,
taskname='segmentation',
collate_fn=seg_collate_fn,
**(dict(vars(args)))
)

@@ -238,7 +241,7 @@

if args.local_rank == 0:
logger.info(f'Model: \n{model}')
if not args.dali:
        if not args.dali and isinstance(train_loader.dataset, torchvision.datasets.VisionDataset):
logger.info(f'Training: \n{train_loader.dataset.transforms}')
logger.info(f'Validation: \n{val_loader.dataset.transforms}')
logger.info(f'Optimizer: \n{optimizer}')
