diff --git a/cvm/utils/coco.py b/cvm/utils/coco.py new file mode 100644 index 0000000..a0aa188 --- /dev/null +++ b/cvm/utils/coco.py @@ -0,0 +1,110 @@ +import copy +import os + +import torch +import torch.utils.data +import torchvision +from PIL import Image +from pycocotools import mask as coco_mask +from .seg_transforms import Compose + + +class FilterAndRemapCocoCategories: + def __init__(self, categories, remap=True): + self.categories = categories + self.remap = remap + + def __call__(self, image, anno): + anno = [obj for obj in anno if obj["category_id"] in self.categories] + if not self.remap: + return image, anno + anno = copy.deepcopy(anno) + for obj in anno: + obj["category_id"] = self.categories.index(obj["category_id"]) + return image, anno + + def __repr__(self) -> str: + return self.__class__.__name__ + f'()' + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask: + def __call__(self, image, anno): + w, h = image.size + segmentations = [obj["segmentation"] for obj in anno] + cats = [obj["category_id"] for obj in anno] + if segmentations: + masks = convert_coco_poly_to_mask(segmentations, h, w) + cats = torch.as_tensor(cats, dtype=masks.dtype) + # merge all instance masks into a single segmentation map + # with its corresponding categories + target, _ = (masks * cats[:, None, None]).max(dim=0) + # discard overlapping instances + target[masks.sum(0) > 1] = 255 + else: + target = torch.zeros((h, w), dtype=torch.uint8) + target = Image.fromarray(target.numpy()) + return image, target + + def __repr__(self) -> str: + return self.__class__.__name__ + f'()' + + +def _coco_remove_images_without_annotations(dataset, cat_list=None): + def _has_valid_annotation(anno): + # if it's empty, there is no annotation + if len(anno) == 0: + return False + # if more than 1k pixels occupied in the image + return sum(obj["area"] for obj in anno) > 1000 + + assert isinstance(dataset, torchvision.datasets.CocoDetection) + ids = [] + for ds_idx, img_id in enumerate(dataset.ids): + ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = dataset.coco.loadAnns(ann_ids) + if cat_list: + anno = [obj for obj in anno if obj["category_id"] in cat_list] + if _has_valid_annotation(anno): + ids.append(ds_idx) + + dataset = torch.utils.data.Subset(dataset, ids) + return dataset + + +def get_coco(root, image_set, transforms): + PATHS = { + "train": ("train2017", os.path.join("annotations", "instances_train2017.json")), + "val": ("val2017", os.path.join("annotations", "instances_val2017.json")), + # "train": ("val2017", os.path.join("annotations", "instances_val2017.json")) + } + CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72] + + transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms]) + + img_folder, ann_file = PATHS[image_set] + img_folder = os.path.join(root, img_folder) + ann_file = os.path.join(root, ann_file) + + dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) + + if image_set == "train": + dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST) + + return dataset diff --git a/cvm/utils/factory.py b/cvm/utils/factory.py index 280ef45..73bef29 100644 --- a/cvm/utils/factory.py +++ b/cvm/utils/factory.py @@ -11,6 +11,7 @@ import cvm from cvm.data.samplers import RASampler +from cvm.utils.coco import get_coco from . import seg_transforms as ST from .utils import group_params, list_datasets, get_world_size from cvm.data.constants import * @@ -202,12 +203,13 @@ def create_optimizer( no_bias_bn_wd: bool = False, **kwargs ): - params = group_params( - params, - weight_decay, - no_bias_bn_wd, - lr - ) + if isinstance(params, nn.Module): + params = group_params( + params, + weight_decay, + no_bias_bn_wd, + lr + ) if name == 'sgd': return optim.SGD( @@ -285,6 +287,8 @@ def _get_dataset_mean_or_std(name, attr): return MNIST_MEAN if attr == 'mean' else MNIST_STD if name.startswith('voc') or name.startswith('sbd'): return VOC_MEAN if attr == 'mean' else VOC_STD + if name.startswith('coco'): + return VOC_MEAN if attr == 'mean' else VOC_STD return IMAGE_MEAN if attr == 'mean' else IMAGE_STD @@ -393,7 +397,8 @@ def create_segmentation_transforms( ): ops = [] if is_training: - ops.append(ST.RandomCrop(crop_size, pad_if_needed=True, padding=padding)) + # ops.append(ST.RandomCrop(crop_size, pad_if_needed=True, padding=padding)) + ops.append(ST.RandomResizedCrop(crop_size, (0.5, 2.0), interpolation=interpolation)) if hflip > 0.0: ops.append(ST.RandomHorizontalFlip(hflip)) else: @@ -417,17 +422,25 @@ def create_dataset( dataset = datasets.__dict__[name] params = inspect.signature(dataset.__init__).parameters.keys() + if 'mode' in params and 'image_set' in params: + return datasets.__dict__[name]( + path.expanduser(root), + mode='segmentation', + image_set='train' if is_training else 'val', + download=(download and is_training) + ) + if 'image_set' in params: return datasets.__dict__[name]( path.expanduser(root), image_set='train' if is_training else 'val', - download=download + download=(download and is_training) ) return datasets.__dict__[name]( path.expanduser(root), train=is_training, - download=download + download=(download and is_training) ) elif name == 'ImageNet': return datasets.ImageFolder( @@ -463,6 +476,7 @@ def create_loader( ra_repetitions: int = 0, transform: T.Compose = None, taskname: str = 'classification', + collate_fn=None, **kwargs ): assert taskname in ['classification', 'segmentation'], f'Unknown task: {taskname}.' @@ -499,16 +513,8 @@ def create_loader( ), 'dali') # Pytorch/Vision else: - if isinstance(dataset, str): - dataset = create_dataset( - dataset, - root=root, - is_training=is_training, - **kwargs - ) - if taskname == 'classification': - dataset.transform = transform or create_transforms( + transform = transform or create_transforms( is_training=is_training, random_scale=kwargs.get('random_scale', [0.08, 1.0]), interpolation=T.InterpolationMode(interpolation), @@ -529,7 +535,7 @@ def create_loader( dataset_image_size=_get_dataset_image_size(dataset), ) elif taskname == 'segmentation': - dataset.transforms = transform or create_segmentation_transforms( + transform = transform or create_segmentation_transforms( is_training=is_training, interpolation=T.InterpolationMode(interpolation), hflip=hflip, @@ -539,6 +545,24 @@ def create_loader( crop_size=crop_size if is_training else val_crop_size, ) + if dataset == 'CocoDetection': + dataset = get_coco( + root=root, + image_set='train' if is_training else 'val', + transforms=transform + ) + elif isinstance(dataset, str): + dataset = create_dataset( + dataset, + root=root, + is_training=is_training, + **kwargs + ) + if taskname == 'classification': + dataset.transform = transform + elif taskname == 'segmentation': + dataset.transforms = transform + sampler = None if distributed: if ra_repetitions > 0 and is_training: @@ -552,5 +576,7 @@ def create_loader( num_workers=workers, pin_memory=pin_memory, sampler=sampler, - shuffle=((not distributed) and is_training) + shuffle=((not distributed) and is_training), + collate_fn=collate_fn, + drop_last=(is_training and taskname == 'segmentation') ), 'torch') diff --git a/cvm/utils/utils.py b/cvm/utils/utils.py index dd02efc..55e2413 100644 --- a/cvm/utils/utils.py +++ b/cvm/utils/utils.py @@ -23,7 +23,7 @@ 'named_layers', 'AverageMeter', 'module_parameters', 'group_params', 'list_models', 'list_datasets', 'is_dist_avail_and_initialized', 'get_world_size', - 'init_distributed_mode', 'mask_to_label' + 'init_distributed_mode', 'mask_to_label', 'seg_collate_fn' ] @@ -277,3 +277,19 @@ def mask_to_label(masks, num_classes): for j in range(num_classes): labels[i][j] = bool((masks[i] == j).sum()) return labels.float() + + +def cat_list(images, fill_value=0): + max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) + batch_shape = (len(images),) + max_size + batched_imgs = images[0].new(*batch_shape).fill_(fill_value) + for img, pad_img in zip(images, batched_imgs): + pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img) + return batched_imgs + + +def seg_collate_fn(batch): + images, targets = list(zip(*batch)) + batched_imgs = cat_list(images, fill_value=0) + batched_targets = cat_list(targets, fill_value=255) + return batched_imgs, batched_targets diff --git a/cvm/version.py b/cvm/version.py index 1ac739d..423b798 100644 --- a/cvm/version.py +++ b/cvm/version.py @@ -1 +1 @@ -__version__ = '0.0.18' \ No newline at end of file +__version__ = '0.0.19' \ No newline at end of file diff --git a/train.py b/train.py index 1324b86..ea34f9d 100644 --- a/train.py +++ b/train.py @@ -36,6 +36,7 @@ def parse_args(): parser.add_argument('--model-path', type=str, default=None) parser.add_argument('--num-classes', type=int, default=1000, metavar='N', help='number of label classes') + parser.add_argument('--in-channels', type=int, default=3, metavar='N') parser.add_argument('--bn-eps', type=float, default=None) parser.add_argument('--bn-momentum', type=float, default=None) @@ -201,6 +202,7 @@ def validate(val_loader, model, criterion): model = create_model( args.model, + in_channels=args.in_channels, num_classes=args.num_classes, dropout_rate=args.dropout_rate, drop_path_rate=args.drop_path_rate, diff --git a/train_seg.py b/train_seg.py index 7feea87..fc9b1c1 100644 --- a/train_seg.py +++ b/train_seg.py @@ -21,10 +21,10 @@ def parse_args(): help='number of data loading workers pre GPU. (default: 4)') parser.add_argument('--batch-size', type=int, default=1, metavar='N', help='mini-batch size, this is the total batch size of all GPUs. (default: 256)') - parser.add_argument('--crop-size', type=int, default=320) + parser.add_argument('--crop-size', type=int, default=480) parser.add_argument('--crop-padding', type=int, default=4, metavar='S') - parser.add_argument('--val-resize-size', type=int, default=384) - parser.add_argument('--val-crop-size', type=int, default=384) + parser.add_argument('--val-resize-size', type=int, default=520) + parser.add_argument('--val-crop-size', type=int, default=520) # model parser.add_argument('--model', type=str, default='seg/fcn_regnet_x_400mf', choices=list_models(), @@ -120,7 +120,7 @@ def train(train_loader, model, criterion, optimizer, scheduler, scaler, epoch, a outputs = model(images) loss = criterion(outputs['out'], targets) if args.aux_loss: - loss += criterion(outputs['aux'], targets) + loss += 0.5 * criterion(outputs['aux'], targets) scaler.scale(loss).backward() scaler.step(optimizer) @@ -161,13 +161,7 @@ def validate(val_loader, model, args): confmat.update(predictions.argmax(1).flatten(), targets.flatten()) confmat.all_reduce() - iou = [f'{v*100:>4.1f}' for v in confmat.iou] - pa = [f'{v*100:>4.1f}' for v in confmat.mean_pa] - logger.info(f'\nPA = {pa}' - f'\ngloabal PA = {confmat.pa*100:>4.1f}' - f'\nIoU = {iou}' - f'\nmean IoU = {confmat.mean_iou*100:>4.1f}') - + logger.info(f'gloabal PA = {confmat.pa*100:>5.2f}, mean IoU = {confmat.mean_iou*100:>5.2f}') if __name__ == '__main__': assert torch.cuda.is_available(), 'CUDA IS NOT AVAILABLE!!' @@ -205,7 +199,15 @@ def validate(val_loader, model, args): local_rank=args.local_rank ) - optimizer = create_optimizer(args.optim, model, **dict(vars(args))) + params_to_optimize = [ + {"params": [p for p in model.module.backbone.parameters() if p.requires_grad]}, + {"params": [p for p in model.module.decode_head.parameters() if p.requires_grad]}, + ] + if args.aux_loss: + params = [p for p in model.module.aux_head.parameters() if p.requires_grad] + params_to_optimize.append({"params": params, "lr": args.lr * 10}) + + optimizer = create_optimizer(args.optim, params_to_optimize, **dict(vars(args))) criterion = nn.CrossEntropyLoss(ignore_index=255) train_loader = create_loader( @@ -220,6 +222,7 @@ def validate(val_loader, model, args): root=args.data_dir, is_training=False, taskname='segmentation', + collate_fn=seg_collate_fn, **(dict(vars(args))) ) @@ -238,7 +241,7 @@ def validate(val_loader, model, args): if args.local_rank == 0: logger.info(f'Model: \n{model}') - if not args.dali: + if not args.dali and isinstance(train_loader.dataset, (torchvision.datasets.VisionDataset)): logger.info(f'Training: \n{train_loader.dataset.transforms}') logger.info(f'Validation: \n{val_loader.dataset.transforms}') logger.info(f'Optimizer: \n{optimizer}')