diff --git a/icevision/core/record_components.py b/icevision/core/record_components.py index 911bb1d1f..3df79ab25 100644 --- a/icevision/core/record_components.py +++ b/icevision/core/record_components.py @@ -111,14 +111,16 @@ def __init__(self, task=tasks.common): super().__init__(task=task) self.img = None - def set_img(self, img: Union[PIL.Image.Image, np.ndarray]): - assert isinstance(img, (PIL.Image.Image, np.ndarray)) + def set_img(self, img: Union[PIL.Image.Image, np.ndarray, torch.Tensor]): + assert isinstance(img, (PIL.Image.Image, np.ndarray, torch.Tensor)) self.img = img if isinstance(img, PIL.Image.Image): height, width = img.shape elif isinstance(img, np.ndarray): # else: height, width, _ = self.img.shape + elif isinstance(img, torch.Tensor): + _, height, width = self.img.shape # this should set on SizeRecordComponent self.composite.set_img_size(ImgSize(width=width, height=height), original=True) diff --git a/icevision/imports.py b/icevision/imports.py index 1f19a4669..6ee9b658d 100644 --- a/icevision/imports.py +++ b/icevision/imports.py @@ -51,7 +51,7 @@ CosineAnnealingWarmRestarts, ) -from torchvision.transforms.functional import to_tensor as im2tensor +from torchvision.transforms.functional import to_tensor from loguru import logger @@ -92,3 +92,14 @@ def __str__(self): def __repr__(self): return str(self) + + +def im2tensor(pic: Union[np.ndarray, PIL.Image.Image, torch.Tensor]): + if isinstance(pic, torch.Tensor): + return pic + elif isinstance(pic, (np.ndarray, PIL.Image.Image)): + return to_tensor(pic) + else: + raise TypeError( + f"Expected {np.ndarray} | {PIL.Image.Image} | {torch.Tensor}, got {type(pic)}" + ) diff --git a/icevision/models/multitask/__init__.py b/icevision/models/multitask/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/classification_heads/__init__.py b/icevision/models/multitask/classification_heads/__init__.py new file mode 100644 index 000000000..c82f63d71 --- /dev/null +++ b/icevision/models/multitask/classification_heads/__init__.py @@ -0,0 +1,2 @@ +from .builder import * +from .head import * diff --git a/icevision/models/multitask/classification_heads/builder.py b/icevision/models/multitask/classification_heads/builder.py new file mode 100644 index 000000000..1d81b99d0 --- /dev/null +++ b/icevision/models/multitask/classification_heads/builder.py @@ -0,0 +1,47 @@ +from typing import Dict +from .head import CLASSIFICATION_HEADS, ImageClassificationHead, ClassifierConfig +import torch.nn as nn + +__all__ = ["build_classifier_heads", "build_classifier_heads_from_configs"] + +# Enter dict of dicts as `cfg` +def build_classifier_heads(configs: Dict[str, Dict[str, dict]]) -> nn.ModuleDict: + """ + Build classification head from a config which is a dict of dicts. + A head is created for each key in the input dictionary. 
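+
+    Illustrative sketch (task names and sizes are hypothetical; every key in a
+    config besides `type` is forwarded to the head's constructor):
+
+        configs = {
+            "shot_framing": dict(type="ImageClassificationHead", out_classes=3, num_fpn_features=512),
+            "color_saturation": dict(type="ImageClassificationHead", out_classes=5, num_fpn_features=512),
+        }
+        heads = build_classifier_heads(configs)  # nn.ModuleDict with one head per key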
+ + Expected to be used with `mmdet` models as it uses the + `CLASSIFICATION_HEADS` registry internally + + Returns: + a `nn.ModuleDict()` mapping keys from `configs` to classifier heads + """ + heads = nn.ModuleDict() + # if configs is not None: + for name, config in configs.items(): + head = CLASSIFICATION_HEADS.build(config) + heads.update({name: head}) + return heads + + +def build_classifier_heads_from_configs( + configs: Dict[str, ClassifierConfig] = None +) -> nn.ModuleDict: + """ + Build a `nn.ModuleDict` of `ImageClassificationHead`s from a list of `ClassifierConfig`s + """ + if configs is None: + return nn.ModuleDict() + + assert isinstance(configs, dict), f"Expected a `dict`, got {type(configs)}" + if not all(isinstance(cfg, ClassifierConfig) for cfg in configs.values()): + raise ValueError( + f"Expected a `list` of `ClassifierConfig`s \n" + f"Either one or more elements in the list are not of type `ClassifierConfig`" + ) + + heads = nn.ModuleDict() + for name, config in configs.items(): + head = ImageClassificationHead.from_config(config) + heads.update({name: head}) + return heads diff --git a/icevision/models/multitask/classification_heads/head.py b/icevision/models/multitask/classification_heads/head.py new file mode 100644 index 000000000..0b987de26 --- /dev/null +++ b/icevision/models/multitask/classification_heads/head.py @@ -0,0 +1,216 @@ +# Hacked together by Rahul & Farid + +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.utils import Registry + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import List, Union, Optional, Dict +from torch import Tensor +from functools import partial +from collections import namedtuple +from dataclasses import dataclass + +TensorList = List[Tensor] +TensorDict = Dict[str, Tensor] + +MODELS = Registry("models", parent=MMCV_MODELS) +CLASSIFICATION_HEADS = MODELS + +__all__ = ["ImageClassificationHead", "ClassifierConfig"] + + +class Passthrough(nn.Module): + def forward(self, x): + return x + + +""" +`ClassifierConfig` is useful to instantiate `ImageClassificationHead` +in different settings. If using `mmdet`, we don't use this as the config +is then a regular dictionary. 
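+
+A minimal construction sketch (values are hypothetical; only `out_classes` is
+required, everything else falls back to the defaults defined below):
+
+    config = ClassifierConfig(out_classes=4, num_fpn_features=512, multilabel=True)
+    head = ImageClassificationHead.from_config(config)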
+ +When using yolov5, we can easily pass around this config to create the model +Often, it'll be used inside a dictionary of configs +""" + + +@dataclass +class ClassifierConfig: + # classifier_name: str + out_classes: int + num_fpn_features: int = 512 + fpn_keys: Union[List[str], List[int], None] = None + dropout: Optional[float] = 0.2 + pool_inputs: bool = True + # Loss function args + loss_func: Optional[nn.Module] = None + activation: Optional[nn.Module] = None + multilabel: bool = False + loss_func_wts: Optional[Tensor] = None + loss_weight: float = 1.0 + # Post activation processing + thresh: Optional[float] = None + topk: Optional[int] = None + + def __post_init__(self): + if isinstance(self.fpn_keys, int): + self.fpn_keys = [self.fpn_keys] + + if self.loss_func_wts is not None: + if not self.multilabel: + self.loss_func_wts = self.loss_func_wts.to(torch.float32) + if torch.cuda.is_available(): + self.loss_func_wts = self.loss_func_wts.cuda() + + if self.multilabel: + if self.topk is None and self.thresh is None: + self.thresh = 0.5 + else: + if self.topk is None and self.thresh is None: + self.topk = 1 + + +@CLASSIFICATION_HEADS.register_module(name="ImageClassificationHead") +class ImageClassificationHead(nn.Module): + """ + Image classification head that optionally takes `fpn_keys` features from + an FPN, average pools and concatenates them into a single tensor + of shape `num_features` and then runs a linear layer to `out_classes + + fpn_features: [List[Tensor]] => AvgPool => Flatten => Linear` + + Also includes `compute_loss` to match the design of other + components of object detection systems. + To use your own loss function, pass it into `loss_func`. + If `loss_func` is None (by default), we create one based on other args: + If `multilabel` is true, one-hot encoded targets are expected and + nn.BCEWithLogitsLoss is used, else nn.CrossEntropyLoss is used + and targets are expected to be integers + NOTE: Not all loss function args are exposed + """ + + def __init__( + self, + out_classes: int, + num_fpn_features: int, + fpn_keys: Union[List[str], List[int], None] = None, + dropout: Optional[float] = 0.2, + pool_inputs: bool = True, # ONLY for advanced use cases where input feature maps are already pooled + # Loss function args + loss_func: Optional[nn.Module] = None, + activation: Optional[nn.Module] = None, + multilabel: bool = False, + loss_func_wts: Optional[Tensor] = None, + loss_weight: float = 1.0, + # Final postprocessing args + thresh: Optional[float] = None, + topk: Optional[int] = None, + ): + super().__init__() + + # Setup loss function & activation + self.multilabel = multilabel + self.loss_func, self.loss_func_wts, self.loss_weight = ( + loss_func, + loss_func_wts, + loss_weight, + ) + self.activation = activation + self.pool_inputs = pool_inputs + self.thresh, self.topk = thresh, topk + + # Setup head + self.fpn_keys = fpn_keys + + layers = [ + nn.Dropout(dropout) if dropout else Passthrough(), + nn.Linear(num_fpn_features, out_classes), + ] + layers.insert(0, nn.Flatten(1)) if self.pool_inputs else None + self.classifier = nn.Sequential(*layers) + + self.setup_loss_function() + self.setup_postprocessing() + + def setup_postprocessing(self): + if self.multilabel: + if self.topk is None and self.thresh is None: + self.thresh = 0.5 + else: + if self.topk is None and self.thresh is None: + self.topk = 1 + + def setup_loss_function(self): + if self.loss_func is None: + if self.multilabel: + self.loss_func = nn.BCEWithLogitsLoss(pos_weight=self.loss_func_wts) + # 
self.loss_func = partial( + # F.binary_cross_entropy_with_logits, pos_weight=self.loss_func_wts + # ) + self.activation = nn.Sigmoid() + # self.activation = torch.sigmoid # nn.Sigmoid() + else: + # self.loss_func = nn.CrossEntropyLoss(self.loss_func_wts) + self.loss_func = nn.CrossEntropyLoss(weight=self.loss_func_wts) + # self.loss_func = partial(F.cross_entropy, weight=self.loss_func_wts) + self.activation = nn.Softmax(-1) + # self.activation = partial(F.softmax, dim=-1) # nn.Softmax(-1) + + @classmethod + def from_config(cls, config: ClassifierConfig): + return cls(**config.__dict__) + + # TODO: Make it run with regular features as well + def forward(self, features: Union[Tensor, TensorDict, TensorList]): + """ + Sequence of outputs from an FPN or regular feature extractor + => Avg. Pool each into 1 dimension + => Concatenate into single tensor + => Linear layer -> output classes + + If `self.fpn_keys` is specified, it grabs the specific (int|str) indices from + `features` for the pooling layer, else it takes _all_ of them + """ + if isinstance(features, (list, dict, tuple)): + # Grab specific features if specified + if self.fpn_keys is not None: + pooled_features = [ + F.adaptive_avg_pool2d(features[k], 1) for k in self.fpn_keys + ] + # If no `fpn_keys` exist, concat all the feature maps (could be expensive) + else: + pooled_features = [F.adaptive_avg_pool2d(feat, 1) for feat in features] + pooled_features = torch.cat(pooled_features, dim=1) + + # If doing regular (non-FPN) feature extraction, we don't need `fpn_keys` and + # just avg. pool the last layer's features + elif isinstance(features, Tensor): + pooled_features = ( + F.adaptive_avg_pool2d(features, 1) if self.pool_inputs else features + ) + else: + raise TypeError( + f"Expected TensorList|TensorDict|Tensor|tuple, got {type(features)}" + ) + + return self.classifier(pooled_features) + + # TorchVision style API + def compute_loss(self, predictions, targets): + return self.loss_weight * self.loss_func(predictions, targets) + + def postprocess(self, predictions): + return self.activation(predictions) + + # MMDet style API + def forward_train(self, x, gt_label) -> Tensor: + preds = self(x) + return self.loss_weight * self.loss_func(preds, gt_label) + + def forward_activate(self, x): + "Run forward pass with activation function" + x = self(x) + return self.activation(x) diff --git a/icevision/models/multitask/data/__init__.py b/icevision/models/multitask/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/data/dataloading_utils.py b/icevision/models/multitask/data/dataloading_utils.py new file mode 100644 index 000000000..0665e0a1c --- /dev/null +++ b/icevision/models/multitask/data/dataloading_utils.py @@ -0,0 +1,71 @@ +""" +This may be a temporary file that may eventually be removed, +as it only slightly modifies an existing function. +""" + +__all__ = [ + "unload_records", + "assign_classification_targets_from_record", + "massage_multi_aug_classification_data", +] + + +import torch +from icevision.core.record_type import RecordType +from typing import Any, Dict, Optional, Callable, Sequence, Tuple +from icevision.core.record_components import ClassificationLabelsRecordComponent +from torch import tensor + + +def unload_records( + build_batch: Callable, build_batch_kwargs: Optional[Dict] = None +) -> Tuple[Tuple[Any, ...], Sequence[RecordType]]: + """ + This decorator function unloads records to not carry them around after batch creation. 
+ It also optionally accepts `build_batch_kwargs` that are to be passed into + `build_batch`. These aren't accepted as keyword arguments as those are reserved + for PyTorch's DataLoader class which is used later in this chain of function calls + + Args: + build_batch (Callable): A collate function that describes how to mash records + into a batch of inputs for a model + build_batch_kwargs (Optional[Dict], optional): Keyword arguments to pass into + `build_batch`. Defaults to None. + + Returns: + Tuple[Tuple[Any, ...], Sequence[RecordType]]: [description] + """ + build_batch_kwargs = build_batch_kwargs or {} + assert isinstance(build_batch_kwargs, dict) + + def inner(records): + tupled_output, records = build_batch(records, **build_batch_kwargs) + for record in records: + record.unload() + return tupled_output, records + + return inner + + +def assign_classification_targets_from_record(classification_labels: dict, record): + for comp in record.components: + name = comp.task.name + if isinstance(comp, ClassificationLabelsRecordComponent): + if comp.is_multilabel: + labels = comp.one_hot_encoded() + classification_labels[name].append(labels) + else: + labels = comp.label_ids + classification_labels[name].extend(labels) + + +def massage_multi_aug_classification_data( + classification_data, classification_targets, target_key: str +): + for group in classification_data.values(): + group[target_key] = { + task: tensor(classification_targets[task]) for task in group["tasks"] + } + group["images"] = torch.stack(group["images"]) + + return {k: dict(v) for k, v in classification_data.items()} diff --git a/icevision/models/multitask/data/dataset.py b/icevision/models/multitask/data/dataset.py new file mode 100644 index 000000000..67c8cab61 --- /dev/null +++ b/icevision/models/multitask/data/dataset.py @@ -0,0 +1,234 @@ +from icevision.imports import * +from icevision.core import * +from icevision.core.tasks import Task +from torch.utils.data import Dataset +from icevision.data.dataset import Dataset as RecordDataset +from icevision.utils.utils import normalize, flatten + +import icevision.tfms as tfms +import torchvision.transforms as Tfms +import albumentations as A + +__all__ = ["HybridAugmentationsRecordDataset", "RecordDataset"] + + +class HybridAugmentationsRecordDataset(Dataset): + """ + A Dataset that allows you to apply different augmentations to different tasks in your + record. `detection_transforms` are applied to the `detection` task specifically, and + `classification_transforms_groups` describe how to group and apply augmentations to + the classification tasks in the record. + + This object stores the records internally and dynamically attaches an `img` component + to each task when being fetched. Some basic validation is done on init to ensure that + the given transforms cover all tasks described in the record. + + Important NOTE: All images are returned as normalised numpy arrays upon fetching. If + running in `debug` mode, normalisation is skipped and PIL Images are returned inside + the record instead. This is done to facilitate visual inspection of the transforms + applied to the images + + Arguments: + * records: A list of records where only the `common` attribute has an `img`. Upon fetching, + _each_ task in the record will have an `img` attribute added to it based on the + `classification_transforms_groups` + * classification_transforms_groups - Icevision albumentations adapter for detection transforms. 
+ * norm_mean : norm mean stats + * norm_std : norm stdev stats + * debug : If true, prints info & unnormalised `PIL.Image`s are returned on fetching items + + Usage: + Sample record: + BaseRecord + + common: + - Image ID: 4 + - Filepath: sample_image.png + - Image: 640x640x3 Image + - Image size ImgSize(width=640, height=640) + color_saturation: + - Class Map: + - Labels: [1] + shot_composition: + - Class Map: + - Labels: [1] + detection: + - BBoxes: [] + - Class Map: + - Labels: [1] + shot_framing: + - Class Map: + - Labels: [3] + + classification_transforms_groups = { + "group1": dict( + tasks=["shot_composition"], + transforms=Tfms.Compose([ + Tfms.Resize((IMG_HEIGHT, IMG_WIDTH)), + Tfms.RandomPerspective(), + ]) + ), + "group2": dict( + tasks=["color_saturation", "shot_framing"], + transforms=Tfms.Compose([ + Tfms.Resize((IMG_HEIGHT, IMG_WIDTH)), + Tfms.RandomPerspective(), + Tfms.RandomHorizontalFlip(), + Tfms.RandomVerticalFlip(), + ]) + ) + } + import icevision.tfms as tfms + detection_transforms = tfms.A.Adapter([ + tfms.A.Normalize(), + tfms.A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH), + tfms.A.PadIfNeeded(img_H, img_W, border_mode=cv2.BORDER_CONSTANT), + ]) + + dset = HybridAugmentationsRecordDataset( + records=records, + classification_transforms_groups=classification_transforms_groups, + detection_transforms=detection_transforms, + ) + + Returned Record Example: + Note that unlike the input record, each task has an `Image` attribute which + is after the transforms have been applied. In the dataloader, these task specific + images must be used, and the `record.common.img` is just the original image + untransformed that shouldn't be used to train the model + + BaseRecord + + common: + - Image ID: 4 + - Filepath: sample_image.png + - Image: 640x640x3 Image + - Image size ImgSize(width=640, height=640) + color_saturation: + - Image: 640x640x3 Image + - Class Map: + - Labels: [1] + shot_composition: + - Class Map: + - Labels: [1] + - Image: 640x640x3 Image + detection: + - BBoxes: [] + - Image: 640x640x3 Image + - Class Map: + - Labels: [1] + shot_framing: + - Class Map: + - Labels: [3] + - Image: 640x640x3 Image + """ + + def __init__( + self, + records: List[dict], + classification_transforms_groups: dict, + detection_transforms: Optional[tfms.Transform] = None, + # norm_mean: Collection[float] = [0.485, 0.456, 0.406], + # norm_std: Collection[float] = [0.229, 0.224, 0.225], + debug: bool = False, + ): + "Return `PIL.Image` when `debug=True`" + self.records = records + self.classification_transforms_groups = classification_transforms_groups + self.detection_transforms = detection_transforms + # self.norm_mean = norm_mean + # self.norm_std = norm_std + self.debug = debug + self.validate() + + def validate(self): + """ + Input args validation + * Ensure that each value in the `classification_transforms_groups` dict + has a "tasks" and "transforms" key + * Ensure the number of tasks mentioned in `classification_transforms_groups` + match up _exactly_ with the tasks in the record + """ + for group in self.classification_transforms_groups.values(): + assert set(group.keys()).issuperset( + ["tasks", "transforms"] + ), f"Invalid keys in `classification_transforms_groups`" + + missing_tasks = [] + record = self.load_record(0) + for attr in flatten( + [g["tasks"] for g in self.classification_transforms_groups.values()] + ): + if not hasattr(record, attr): + missing_tasks += [attr] + if not missing_tasks == []: + raise ValueError( + f"`classification_transforms_groups` has more groups 
than are present in the `record`. \n" + f"Missing the following tasks: {missing_tasks}" + ) + + def __len__(self): + return len(self.records) + + def load_record(self, i: int): + """ + Simple record loader. Externalised for easy subclassing for custom behavior + like loading cached records from disk + """ + return self.records[i].load() + + @staticmethod + def dispatch_classification_tfms( + tfm: Union[A.Compose, Tfms.Compose], image: PIL.Image.Image + ): + "Dispatch albu / torchvision transforms with appropriate inp / out formats" + assert isinstance(image, PIL.Image.Image) + if isinstance(tfm, A.Compose): + return tfm(image=np.array(image))["image"] + elif isinstance(tfm, Tfms.Compose): + return tfm(image) + else: + raise TypeError(f"Only Albu | Torchvision transforms supported") + + def __getitem__(self, i): + record = self.load_record(i) + + # Keep a copy of the orig img as it gets modified by albu + original_img: PIL.Image.Image = deepcopy(record.img) + + # Do detection transform and assign it to the detection task + if self.detection_transforms is not None: + record = self.detection_transforms(record) + record.add_component(ImageRecordComponent(Task("detection"))) + record.detection.set_img(record.img) + + if self.debug: + print(f"Fetching Item #{i}") + + # Do classification transforms + if self.classification_transforms_groups is not None: + for group in self.classification_transforms_groups.values(): + img_tfms = group["transforms"] + tfmd_img = self.dispatch_classification_tfms(img_tfms, original_img) + if self.debug: + print(f" Group: {group['tasks']}, ID: {id(tfmd_img)}") + + # NOTE: + # Setting the same img twice (to diff parts in memory) but it's ok cuz we will unload the record later + for task in group["tasks"]: + # TODO FIXME: This adds a component but doesn't display it when printing + # Also, doing `set_img` overrides the base `record.img` + # record.add_component(ImageRecordComponent(Task(task))) + comp = getattr(record, task) + comp.add_component(ImageRecordComponent()) + comp.set_img(tfmd_img) + if self.debug: + print(f" - Task: {task}, ID: {id(tfmd_img)}") + + return record + + def __repr__(self): + return f"<{self.__class__.__name__} with {len(self.records)} items and {len(self.classification_transforms_groups)+1} groups>" diff --git a/icevision/models/multitask/engines/__init__.py b/icevision/models/multitask/engines/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/engines/lightning/__init__.py b/icevision/models/multitask/engines/lightning/__init__.py new file mode 100644 index 000000000..9a5eeb3d9 --- /dev/null +++ b/icevision/models/multitask/engines/lightning/__init__.py @@ -0,0 +1 @@ +from .lightning_model_adapter import * diff --git a/icevision/models/multitask/engines/lightning/lightning_model_adapter.py b/icevision/models/multitask/engines/lightning/lightning_model_adapter.py new file mode 100644 index 000000000..e3eb0424c --- /dev/null +++ b/icevision/models/multitask/engines/lightning/lightning_model_adapter.py @@ -0,0 +1,59 @@ +__all__ = ["MultiTaskLightningModelAdapter"] + +import pytorch_lightning as pl +from icevision.imports import * +from icevision.metrics import * +from icevision.engines.lightning import LightningModelAdapter +from icevision.models.multitask.utils.dtypes import * + + +class MultiTaskLightningModelAdapter(LightningModelAdapter): + def compute_and_log_classification_metrics( + self, + classification_preds: TensorDict, # activated predictions + yb: TensorDict, + on_step: bool = False, 
+ # prefix: str = "valid", + ): + if not set(classification_preds.keys()) == set(yb.keys()): + raise RuntimeError( + f"Mismatch between prediction and target items. Predictions have " + f"{classification_preds.keys()} keys and targets have {yb.keys()} keys" + ) + # prefix = f"{prefix}/" if not prefix == "" else "" + prefix = "valid/" + for (name, metric), (_, preds) in zip( + self.classification_metrics.items(), classification_preds.items() + ): + self.log( + f"{prefix}{metric.__class__.__name__.lower()}_{name}", # accuracy_{task_name} + metric(preds, yb[name].type(torch.int)), + on_step=on_step, + on_epoch=True, + ) + + def log_losses( + self, + mode: str, + detection_loss: Tensor, + classification_total_loss: Tensor, + classification_losses: TensorDict, + ): + log_vars = dict( + total_loss=detection_loss + classification_total_loss, + detection_loss=detection_loss, + classification_total_loss=classification_total_loss, + **{ + f"classification_loss_{name}": loss + for name, loss in classification_losses.items() + }, + ) + for k, v in log_vars.items(): + self.log(f"{mode}/{k}", v.item() if isinstance(v, torch.Tensor) else v) + + def validation_epoch_end(self, outs): + self.finalize_metrics() + + # Modest speedup (See https://pytorch-lightning.readthedocs.io/en/stable/benchmarking/performance.html#zero-grad-set-to-none-true) + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + optimizer.zero_grad(set_to_none=True) diff --git a/icevision/models/multitask/mmdet/__init__.py b/icevision/models/multitask/mmdet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/mmdet/dataloaders.py b/icevision/models/multitask/mmdet/dataloaders.py new file mode 100644 index 000000000..e2f2b02d3 --- /dev/null +++ b/icevision/models/multitask/mmdet/dataloaders.py @@ -0,0 +1,159 @@ +# from icevision.all import * +from icevision.imports import * +from icevision.core import * +from icevision.models.multitask.utils.dtypes import * +from icevision.models.multitask.mmdet.dtypes import * +from icevision.models.mmdet.common.utils import convert_background_from_zero_to_last +from icevision.models.utils import unload_records +from icevision.models.mmdet.common.bbox.dataloaders import ( + _img_tensor, + _img_meta, + _labels, + _bboxes, +) +from icevision.models.multitask.data.dataloading_utils import * +from collections import defaultdict + + +def build_multi_aug_batch( + records: Sequence[RecordType], classification_transform_groups: dict +) -> Tuple[ + Dict[str, Union[DataDictClassification, DataDictDetection]], Sequence[RecordType] +]: + """ + Docs: + Take as inputs `records` and `classification_transform_groups` and return + a tuple of dictionaries, one for detection data and the other for classification. 
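+
+    One possible way to wire this into a `DataLoader` (a sketch; assumes
+    `from functools import partial`, `torch.utils.data.DataLoader`, and that
+    `groups` matches the transform groups the dataset was built with):
+
+        dl = DataLoader(
+            train_ds,
+            batch_size=8,
+            collate_fn=partial(
+                build_multi_aug_batch, classification_transform_groups=groups
+            ),
+        )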
+ + See `icevision.models.multitask.data.dataset.HybridAugmentationsRecordDataset` + for example of what `records` and `classification_transform_groups` look like + + Returns: + A nested data dictionary - (`detection_data`, `classification_data`) and + the loaded records + { + `detection_data`: + { + "detection": dict( + images: Tensor = ..., + img_metas: Dict[ + 'img_shape': HWC tuple, + 'pad_shape': HWC tuple, + 'scale_factor': np.ndarray + ] = ..., + gt_bboxes: Tensor = ..., + gt_bbox_labels: Tensor = ..., + ) + } + + `classification_data`: + { + "group1": dict( + tasks = ["shot_composition"], + images: Tensor = ..., + classification_labels=dict( + "shot_composition": Tensor = ..., + ) + ), + "group2": dict( + tasks = ["color_saturation", "shot_framing"], + images: Tensor = ..., + classification_labels=dict( + "color_saturation": Tensor = ..., + "shot_framing": Tensor = ..., + ) + ) + } + } + """ + # NOTE: `detection` is ALWAYS treated as a distinct group + det_images, bbox_labels, bboxes, img_metas = [], [], [], [] + classification_data = defaultdict(lambda: defaultdict(list)) + classification_labels = defaultdict(list) + + for record in records: + # Create detection data + det_images.append(_img_tensor(record.detection)) + img_metas.append(_img_meta(record)) + bbox_labels.append(_labels(record)) + bboxes.append(_bboxes(record)) + + # Get classification images for each group + for key, group in classification_transform_groups.items(): + task = getattr(record, group["tasks"][0]) + # assert (record.color_saturation.img == record.shot_framing.img).all() + + classification_data[key]["tasks"] = group["tasks"] + classification_data[key]["images"].append(_img_tensor(task)) + + # Get classification labels for each group + assign_classification_targets_from_record(classification_labels, record) + # for comp in record.components: + # name = comp.task.name + # if isinstance(comp, ClassificationLabelsRecordComponent): + # if comp.is_multilabel: + # labels = comp.one_hot_encoded() + # classification_labels[name].append(labels) + # else: + # labels = comp.label_ids + # classification_labels[name].extend(labels) + + # Massage data + classification_data = massage_multi_aug_classification_data( + classification_data, classification_labels, "classification_labels" + ) + # for group in classification_data.values(): + # group["classification_labels"] = { + # task: tensor(classification_labels[task]) for task in group["tasks"] + # } + # group["images"] = torch.stack(group["images"]) + # classification_data = {k: dict(v) for k, v in classification_data.items()} + + detection_data = { + "img": torch.stack(det_images), + "img_metas": img_metas, + "gt_bboxes": bboxes, + "gt_bbox_labels": bbox_labels, + } + + data = dict(detection=detection_data, classification=classification_data) + return data, records + + +@unload_records +def build_single_aug_batch(records: Sequence[RecordType]): + """ + Regular `mmdet` dataloader but with classification added in + """ + images, bbox_labels, bboxes, img_metas = [], [], [], [] + classification_labels = defaultdict(list) + + for record in records: + images.append(_img_tensor(record)) + img_metas.append(_img_meta(record)) + bbox_labels.append(_labels(record)) + bboxes.append(_bboxes(record)) + + # Loop through and create classifier dict of inputs + assign_classification_targets_from_record(classification_labels, record) + # for comp in record.components: + # name = comp.task.name + # if isinstance(comp, ClassificationLabelsRecordComponent): + # if comp.is_multilabel: + # 
labels = comp.one_hot_encoded() + # classification_labels[name].append(labels) + # else: + # labels = comp.label_ids + # classification_labels[name].extend(labels) + + classification_labels = {k: tensor(v) for k, v in classification_labels.items()} + + data = { + "img": torch.stack(images), + "img_metas": img_metas, + "gt_bboxes": bboxes, + "gt_bbox_labels": bbox_labels, + "gt_classification_labels": classification_labels, + } + + return data, records diff --git a/icevision/models/multitask/mmdet/dtypes.py b/icevision/models/multitask/mmdet/dtypes.py new file mode 100644 index 000000000..feac21d0b --- /dev/null +++ b/icevision/models/multitask/mmdet/dtypes.py @@ -0,0 +1,8 @@ +from icevision.imports import * +from icevision.models.multitask.utils.dtypes import * + +ClassificationGroupDataDict = Dict[str, Union[List[str], Tensor, TensorDict]] +DataDictClassification = Dict[str, ClassificationGroupDataDict] +DataDictDetection = Union[ + TensorDict, ArrayDict, Dict[str, Union[Tuple[int], ImgMetadataDict]] +] diff --git a/icevision/models/multitask/mmdet/fastai/__init__.py b/icevision/models/multitask/mmdet/fastai/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/mmdet/lightning/__init__.py b/icevision/models/multitask/mmdet/lightning/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/mmdet/models/__init__.py b/icevision/models/multitask/mmdet/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/mmdet/pl_adapter.py b/icevision/models/multitask/mmdet/pl_adapter.py new file mode 100644 index 000000000..8b2e7d831 --- /dev/null +++ b/icevision/models/multitask/mmdet/pl_adapter.py @@ -0,0 +1,98 @@ +# Modified from `icevision.models.mmdet.lightning.model_adapter` +# NOTE `torchmetrics` comes installed with `pytorch-lightning` +# We could in theory also do `pl.metrics` + +# import pytorch_lightning.metrics as tm +from icevision.models.multitask.utils.prediction import extract_classifier_pred_cfgs +import torchmetrics as tm +from icevision.all import * +from mmcv.utils import ConfigDict +from loguru import logger +from icevision.models.multitask.mmdet.single_stage import ( + ForwardType, + HybridSingleStageDetector, +) +from icevision.models.multitask.mmdet.prediction import * +from icevision.models.multitask.utils.dtypes import * +from icevision.models.multitask.engines.lightning import MultiTaskLightningModelAdapter + +__all__ = ["HybridSingleStageDetectorLightningAdapter"] + + +class HybridSingleStageDetectorLightningAdapter(MultiTaskLightningModelAdapter): + """""" + + def __init__(self, model: HybridSingleStageDetector, metrics: List[Metric] = None): + super().__init__() + self.metrics = metrics or [] + self.model = model + + self.classification_metrics = nn.ModuleDict() + for name, head in model.classifier_heads.items(): + if head.multilabel: + thresh = head.thresh if head.thresh is not None else 0.5 + metric = tm.Accuracy(threshold=thresh, subset_accuracy=True) + else: + metric = tm.Accuracy(threshold=0.01, top_k=1) + self.classification_metrics[name] = metric + self.post_init() + + def post_init(self): + pass + + # ======================== TRAINING METHODS ======================== # + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def training_step(self, batch: Tuple[dict, Sequence[RecordType]], batch_idx): + # Unpack batch into dict + list of records + data, samples = batch + + # Get model outputs - dict of 
losses and vars to log + step_type = ForwardType.TRAIN_MULTI_AUG + if "img_metas" in data.keys(): + step_type = ForwardType.TRAIN + + outputs = self.model.train_step(data=data, step_type=step_type) + + # Log losses + self._log_vars(outputs["log_vars"], "train") + + return outputs["loss"] + + def validation_step(self, batch, batch_idx): + data, records = batch + if self.debug: + logger.info(f"Validation Step: {data.keys()}") + logger.info(f"Batch Idx: {batch_idx}") + + self.model.eval() + with torch.no_grad(): + # get losses + outputs = self.model.train_step(data=data, step_type=ForwardType.TRAIN) + raw_preds = self.model(data=data, forward_type=ForwardType.EVAL) + self.compute_and_log_classification_metrics( + classification_preds=raw_preds["classification_results"], + yb=data["gt_classification_labels"], + ) + + preds = self.convert_raw_predictions( + batch=data, raw_preds=raw_preds, records=records + ) + self.accumulate_metrics(preds) + self.log_losses(outputs["log_vars"], "valid") + + # TODO: is train and eval model automatically set by lighnting? + self.model.train() + + # ======================== LOGGING METHODS ======================== # + + def convert_raw_predictions(self, batch, raw_preds, records): + return convert_raw_predictions( + batch=batch, + raw_preds=raw_preds, + records=records, + detection_threshold=0.0, + classification_configs=extract_classifier_pred_cfgs(self.model), + ) diff --git a/icevision/models/multitask/mmdet/prediction.py b/icevision/models/multitask/mmdet/prediction.py new file mode 100644 index 000000000..1f0daa108 --- /dev/null +++ b/icevision/models/multitask/mmdet/prediction.py @@ -0,0 +1,233 @@ +# Modified from icevision.models.mmdet.common.bbox.prediction + +from icevision.all import * +from icevision.models.mmdet.common.bbox.prediction import _unpack_raw_bboxes + +from ..utils import * + + +__all__ = [ + "predict", + "predict_from_dl", + "convert_raw_prediction", + "convert_raw_predictions", + "finalize_classifier_preds", +] + +from icevision.imports import * +from icevision.utils import * +from icevision.core import * +from icevision.data import * +from icevision.core.tasks import Task +from icevision.models.utils import _predict_from_dl +from icevision.models.mmdet.common.utils import * +from icevision.models.mmdet.common.bbox.dataloaders import build_infer_batch +from icevision.models.mmdet.common.utils import convert_background_from_last_to_zero +from icevision.models.multitask.utils.prediction import finalize_classifier_preds + + +@torch.no_grad() +def _predict_batch( + model: nn.Module, + batch: Sequence[torch.Tensor], + records: Sequence[BaseRecord], + classification_configs: dict, + detection_threshold: float = 0.5, + keep_images: bool = False, + device: Optional[torch.device] = None, +): + device = device or model_device(model) + batch["img"] = [img.to(device) for img in batch["img"]] + + raw_preds = model(return_loss=False, rescale=False, **batch) + return convert_raw_predictions( + batch=batch, + raw_preds=raw_preds, + records=records, + classification_configs=classification_configs, + keep_images=keep_images, + detection_threshold=detection_threshold, + ) + + +def predict( + model: nn.Module, + dataset: Dataset, + classification_configs: dict, + detection_threshold: float = 0.5, + keep_images: bool = False, + device: Optional[torch.device] = None, +) -> List[Prediction]: + batch, records = build_infer_batch(dataset) + + return _predict_batch( + model=model, + batch=batch, + records=records, + classification_configs=classification_configs, 
+ detection_threshold=detection_threshold, + keep_images=keep_images, + device=device, + ) + + +@torch.no_grad() +def _predict_from_dl( + predict_fn, + model: nn.Module, + infer_dl: DataLoader, + keep_images: bool = False, + show_pbar: bool = True, + **predict_kwargs, +) -> List[Prediction]: + all_preds = [] + for batch, records in pbar(infer_dl, show=show_pbar): + preds = predict_fn( + model=model, + batch=batch, + records=records, + keep_images=keep_images, + **predict_kwargs, + ) + all_preds.extend(preds) + + return all_preds + + +def predict_from_dl( + model: nn.Module, + infer_dl: DataLoader, + # classification_configs: dict, + show_pbar: bool = True, + keep_images: bool = False, + **predict_kwargs, +): + _predict_batch_fn = partial(_predict_batch, keep_images=keep_images) + # FIXME `classification_configs` needs to be passed in as **predict_kwargs + return _predict_from_dl( + predict_fn=_predict_batch_fn, + model=model, + # classification_configs=classification_configs, + infer_dl=infer_dl, + show_pbar=show_pbar, + keep_images=keep_images, + **predict_kwargs, + ) + + +def convert_raw_predictions( + batch, + raw_preds, + records: Sequence[BaseRecord], + classification_configs: dict, + detection_threshold: float, + keep_images: bool = False, +): + + # In inference, both "img" and "img_metas" are lists. Check out the `build_infer_batch()` definition + # We need to convert that to a batch similar to train and valid batches + if isinstance(batch["img"], list): + batch = { + "img": batch["img"][0], + "img_metas": batch["img_metas"][0], + } + bbox_preds, classification_preds = ( + raw_preds["bbox_results"], + raw_preds["classification_results"], + ) + + # Convert dicts of sequences into a form that we can iterate over in a for loop + # A test / infer dataloader will not have "gt_classification_labels" as a key + if "gt_classification_labels" in batch: + gt_classification_labels = [ + dict(zip(batch["gt_classification_labels"], t)) + for t in zipsafe(*batch["gt_classification_labels"].values()) + ] + batch["gt_classification_labels"] = gt_classification_labels + classification_preds = [ + dict(zip(classification_preds, t)) + for t in zipsafe(*classification_preds.values()) + ] + batch_list = [dict(zip(batch, t)) for t in zipsafe(*batch.values())] + + return [ + convert_raw_prediction( + sample=sample, + raw_bbox_pred=bbox_pred, + raw_classification_pred=classification_pred, + classification_configs=classification_configs, + record=record, + detection_threshold=detection_threshold, + keep_image=keep_images, + ) + for sample, bbox_pred, classification_pred, record in zip( + batch_list, bbox_preds, classification_preds, records + ) + ] + + +def convert_raw_prediction( + sample, + raw_bbox_pred: dict, + raw_classification_pred: TensorDict, + classification_configs: dict, + record: BaseRecord, + detection_threshold: float, + keep_image: bool = False, +): + # convert predictions + raw_bboxes = raw_bbox_pred + scores, labels, bboxes = _unpack_raw_bboxes(raw_bboxes) + + keep_mask = scores > detection_threshold + keep_scores = scores[keep_mask] + keep_labels = labels[keep_mask] + keep_bboxes = [BBox.from_xyxy(*o) for o in bboxes[keep_mask]] + + keep_labels = convert_background_from_last_to_zero( + label_ids=keep_labels, class_map=record.detection.class_map + ) + + # TODO: Refactor with functions from `...multitask.utils.prediction` + pred = BaseRecord( + [ + FilepathRecordComponent(), + ScoresRecordComponent(), + ImageRecordComponent(), + InstancesLabelsRecordComponent(), + BBoxesRecordComponent(), + 
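+            # Below: one ScoresRecordComponent and one ClassificationLabelsRecordComponent
+            # are added per classification task described in `classification_configs`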
*[ScoresRecordComponent(Task(task)) for task in classification_configs], + *[ + ClassificationLabelsRecordComponent( + Task(task), is_multilabel=cfg.multilabel + ) + for task, cfg in classification_configs.items() + ], + ] + ) + pred.detection.set_class_map(record.detection.class_map) + pred.detection.set_scores(keep_scores) + pred.detection.set_labels_by_id(keep_labels) + pred.detection.set_bboxes(keep_bboxes) + pred.above_threshold = keep_mask + + # TODO: Refactor with functions from `...multitask.utils.prediction` + for task, classification_pred in raw_classification_pred.items(): + labels, scores = finalize_classifier_preds( + pred=classification_pred, + cfg=classification_configs[task], + record=record, + task=task, + ) + pred.set_filepath(record.filepath) + getattr(pred, task).set_class_map(getattr(record, task).class_map) + getattr(pred, task).set_scores(scores) + getattr(pred, task).set_labels(labels) + + if keep_image: + image = mmdet_tensor_to_image(sample["img"]) + + pred.set_img(image) + record.set_img(image) + + return Prediction(pred=pred, ground_truth=record) diff --git a/icevision/models/multitask/mmdet/single_stage.py b/icevision/models/multitask/mmdet/single_stage.py new file mode 100644 index 000000000..e496b4a15 --- /dev/null +++ b/icevision/models/multitask/mmdet/single_stage.py @@ -0,0 +1,348 @@ +from typing import Dict, List +from collections import OrderedDict +from icevision.models.multitask.classification_heads import * + + +import torch +import torch.nn as nn +import torch.distributed as dist +from torch import Tensor + +from icevision.models.mmdet.utils import * +from mmcv import Config, ConfigDict +from mmdet.models.builder import DETECTORS +from mmdet.models.builder import build_backbone, build_detector, build_head, build_neck +from mmdet.models.detectors.single_stage import SingleStageDetector +from mmdet.core.bbox import * +from typing import Union, List, Dict, Tuple, Optional + +from icevision.models.multitask.mmdet.dataloaders import ( + TensorDict, + ClassificationGroupDataDict, + DataDictClassification, + DataDictDetection, +) +import numpy as np +from icevision.models.multitask.utils.model import * +from icevision.models.multitask.utils.dtypes import * + + +__all__ = [ + "HybridSingleStageDetector", + "build_backbone", + "build_detector", + "build_head", + "build_neck", +] + + +@DETECTORS.register_module(name="HybridSingleStageDetector") +class HybridSingleStageDetector(SingleStageDetector): + # TODO: Add weights for loss functions + def __init__( + self, + backbone: Union[dict, ConfigDict], + neck: Union[dict, ConfigDict], + bbox_head: Union[dict, ConfigDict], + classification_heads: Union[None, dict, ConfigDict] = None, + # keypoint_heads=None, # TODO Someday SOON. 
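+        # `classification_heads` is expected to look roughly like (sketch):
+        #   {"task_name": dict(type="ImageClassificationHead", out_classes=..., num_fpn_features=...)}
+        # and is built into an `nn.ModuleDict` by `build_classifier_heads` below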
+ train_cfg: Union[None, dict, ConfigDict] = None, + test_cfg: Union[None, dict, ConfigDict] = None, + pretrained=None, + init_cfg: Union[None, dict, ConfigDict] = None, + ): + super(HybridSingleStageDetector, self).__init__( + # Use `init_cfg` post mmdet 2.12 + # backbone, neck, bbox_head, train_cfg, test_cfg, pretrained, init_cfg + backbone=ConfigDict(backbone), + neck=ConfigDict(neck), + bbox_head=ConfigDict(bbox_head), + train_cfg=ConfigDict(train_cfg), + test_cfg=ConfigDict(test_cfg), + pretrained=pretrained, + init_cfg=ConfigDict(init_cfg), + ) + if classification_heads is not None: + self.classifier_heads = build_classifier_heads(classification_heads) + + def train_step( + self, + data: dict, + step_type: ForwardType = ForwardType.TRAIN, + ) -> Dict[str, Union[Tensor, TensorDict, int]]: + """ + A single iteration step (over a batch) + Args: + data: The output of dataloader. Typically `self.fwd_train_data_keys` or + `self.fwd_eval_data_keys` + step_type (Enum): ForwardType.TRAIN | ForwardType.EVAL | ForwardType.TRAIN_MULTI_AUG + + Returns: + dict[str, Union[Tensor, TensorDict, int]] + * `loss` : summed losses for backprop + * `log_vars` : variables to be logged + * `num_samples` : batch size per GPU when using DDP + """ + losses = self(data=data, step_type=step_type) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(data["img_metas"]) + if "img_metas" in data.keys() + else len(data["detection"]["img_metas"]), + ) + return outputs + + # @auto_fp16(apply_to=("img",)) + def forward(self, data: dict, step_type: ForwardType): + """ + Calls either `self.forward_train`, `self.forward_eval` or + `self.forward_multi_aug_train` depending on the value of `step_type` + + No TTA supported unlike all other mmdet models + """ + if step_type is ForwardType.TRAIN_MULTI_AUG: + return self.forward_multi_aug_train(data) + + elif step_type is ForwardType.TRAIN: + return self.forward_train(data, gt_bboxes_ignore=None) + + elif step_type is ForwardType.EVAL: + return self.forward_eval(data, rescale=False) + + else: + raise RuntimeError( + f"Invalid `step_type`. Received: {type(step_type.__class__)}; Expected: {ForwardType.__class__}" + ) + + fwd_multi_aug_train_data_keys = ["detection", "classification"] + fwd_train_data_keys = [ + "img", + "gt_bboxes", + "gt_bbox_labels", + "gt_classification_labels", + ] + fwd_eval_data_keys = ["img", "img_metas"] + + def forward_multi_aug_train( + self, + data: Dict[str, Union[DataDictClassification, DataDictDetection]], + ) -> Dict[str, Tensor]: + """ + Forward method where multiple views of the same image are passed. + The model does a dedicated forward pass for the `detection` images + and dedicated forward passes for each `classification` group. See + the dataloader docs for more details + Args: + data : a dictionary with two keys - + `detection` and `classification`. 
See the dataloader docs for + more details on the exact structure + + Returns: + dict[str, Tensor] + * `loss_classification`: Dictionary of classification losses where each key + corresponds to the classification head / task name + * `loss_cls`: Bbox classification loss + * `loss_bbox`: Bbox regression loss + """ + assert set(data.keys()).issuperset(self.fwd_multi_aug_train_data_keys) + # detection_img, img_metas, gt_bboxes, gt_bbox_labels = data["detection"].values() + super(SingleStageDetector, self).forward_train( + data["detection"]["img"], + data["detection"]["img_metas"], + ) + detection_features = self.extract_feat(data["detection"]["img"]) + + losses = self.bbox_head.forward_train( + x=detection_features, + img_metas=data["detection"]["img_metas"], + gt_bboxes=data["detection"]["gt_bboxes"], + gt_labels=data["detection"]["gt_bbox_labels"], + # NOTE we do not return `gt_bboxes_ignore` in the dataloader + gt_bboxes_ignore=data["detection"].get("gt_bboxes_ignore", None), + ) + + # Compute features per _group_, then do a forward pass through each + # classification head in that group to compute the loss + classification_losses = {} + for group, data in data["classification"].items(): + classification_features = self.extract_feat(data["images"]) + for task in data["tasks"]: + head = self.classifier_heads[task] + classification_losses[task] = head.forward_train( + x=classification_features, + gt_label=data["classification_labels"][task], + ) + + losses["loss_classification"] = classification_losses + return losses + + def forward_train(self, data: dict, gt_bboxes_ignore=None) -> Dict[str, Tensor]: + """ + Forward pass + Args: + img: Normalised input images of shape (N, C, H, W). + img_metas: A List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + gt_bboxes: List of gt bboxes in `xyxy` format for each image + gt_labels: Integer class indices corresponding to each box + gt_classification_labels: Dict of ground truths per classification task + gt_bboxes_ignore (None | list[Tensor]): Specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor] + * `loss_classification`: Dictionary of classification losses where each key + corresponds to the classification head / task name + * `loss_cls`: Bbox classification loss + * `loss_bbox`: Bbox regression loss + """ + assert set(data.keys()).issuperset(self.fwd_train_data_keys) + super(SingleStageDetector, self).forward_train(data["img"], data["img_metas"]) + features = self.extract_feat(data["img"]) + losses = self.bbox_head.forward_train( + x=features, + img_metas=data["img_metas"], + gt_bboxes=data["gt_bboxes"], + gt_labels=data["gt_bbox_labels"], + gt_bboxes_ignore=gt_bboxes_ignore, + ) + + classification_losses = { + name: head.forward_train( + x=features, + gt_label=data["gt_classification_labels"][name], + ) + for name, head in self.classifier_heads.items() + } + losses["loss_classification"] = classification_losses + return losses + + # Maintain API + # Placeholder in case we want to do TTA during eval? + def simple_test(self, *args): + return self.forward_eval(*args) + + def forward_eval( + self, data: dict, rescale: bool = False + ) -> Dict[str, Union[TensorDict, List[np.ndarray]]]: + """ + TODO Update mmdet docstring + + Eval / test function on a single image (without TTA). 
Returns raw predictions of + the model that can be processed in `convert_raw_predictions` + + Args: + imgs: List of multiple images + img_metas: List of image metadata. + rescale: Whether to rescale the results. + + Returns: + { + "bbox_results": List[ArrayList], + "classification_results": TensorDict + } + + bbox_results: Nested list of BBox results The outer list corresponds + to each image. The inner list + corresponds to each class. + classification_results: Dictionary of activated outputs for each classification head + """ + assert set(data.keys()).issuperset(self.fwd_eval_data_keys) + # Raw outputs from network + img, img_metas = data["img"], data["img_metas"] + features = self.extract_feat(img) + bbox_outs = self.bbox_head(features) + classification_results = { + name: head.forward_activate(features) + for name, head in self.classifier_heads.items() + } + + # Get original input shape to support onnx dynamic shape + if torch.onnx.is_in_onnx_export(): + # get shape as tensor + img_shape = torch._shape_as_tensor(img)[2:] + img_metas[0]["img_shape_for_onnx"] = img_shape + + bbox_list = self.bbox_head.get_bboxes(*bbox_outs, img_metas, rescale=rescale) + + # Skip post-processing when exporting to ONNX + if torch.onnx.is_in_onnx_export(): + return bbox_list, classification_results + + bbox_results = [ + bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) + for det_bboxes, det_labels in bbox_list + ] + return { + "bbox_results": bbox_results, + "classification_results": classification_results, + } + + # NOTE: This is experimental + def forward_onnx(self, one_img: Tensor, one_img_metas: List[ImgMetadataDict]): + """ """ + # assert torch.onnx.is_in_onnx_export() + assert len(one_img) == len(one_img_metas) == 1 + + img, img_metas = one_img, one_img_metas + + features = self.extract_feat(img) + bbox_outs = self.bbox_head(features) + classification_results = { + name: head.forward_activate(features) + for name, head in self.classifier_heads.items() + } + + img_shape = torch._shape_as_tensor(img)[2:] # Gets (H, W) + img_metas[0]["img_shape_for_onnx"] = img_shape + bbox_list = self.bbox_head.get_bboxes(*bbox_outs, img_metas, rescale=False) + + return bbox_list, list(classification_results.values()) + + def _parse_losses( + self, losses: Dict[str, Union[Tensor, TensorDict, TensorList]] + ) -> tuple: + # TODO: Pass weights into loss + # NOTE: This is where you can pass in weights for each loss function + r"""Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, coming typically from `self.train_step` + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \ + which may be a weighted sum of all losses, log_vars contains \ + all the variables to be sent to the logger. 
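+
+        Illustrative example (head names are hypothetical): with classifier heads
+        named "framing" and "saturation", `log_vars` would contain `loss_cls`,
+        `loss_bbox`, `loss_classification_framing`, `loss_classification_saturation`
+        and the summed `loss`.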
+ """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + elif isinstance(loss_value, dict): + # Unroll classification losses returned as a dict + for k, v in loss_value.items(): + log_vars[f"loss_classification_{k}"] = v + else: + raise TypeError( + f"{loss_name} is not a tensor or list or dict of tensors" + ) + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/icevision/models/multitask/ultralytics/__init__.py b/icevision/models/multitask/ultralytics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/ultralytics/yolov5/__init__.py b/icevision/models/multitask/ultralytics/yolov5/__init__.py new file mode 100644 index 000000000..d2cabe4f0 --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/__init__.py @@ -0,0 +1,33 @@ +""" +The following imports are not strictly necessary if you're only using the high level API, + but are nice to have for quick dev / debugging, and to use some of the lower level APIs. + They must be imported before the rest of the imports to respect namespaces +""" + +import icevision.tfms as tfms +from icevision.models.multitask.classification_heads import * +from icevision.models.multitask.utils import * +from icevision.data.dataset import Dataset +from yolov5.utils.loss import ComputeLoss + + +""" +The following imports are what are essential for the high level API, and in line with + the way modules are imported with the rest of the library +""" + +from icevision.models.multitask.ultralytics.yolov5.dataloaders import * +from icevision.models.multitask.ultralytics.yolov5.model import * +from icevision.models.multitask.ultralytics.yolov5.prediction import * +from icevision.models.multitask.ultralytics.yolov5.utils import * +from icevision.models.multitask.ultralytics.yolov5.backbones import * +from icevision.models.multitask.ultralytics.yolov5.arch.yolo_hybrid import * + + +from icevision.soft_dependencies import SoftDependencies + +if SoftDependencies.fastai: + from icevision.models.multitask.ultralytics.yolov5 import fastai + +if SoftDependencies.pytorch_lightning: + from icevision.models.multitask.ultralytics.yolov5 import lightning diff --git a/icevision/models/multitask/ultralytics/yolov5/arch/__init__.py b/icevision/models/multitask/ultralytics/yolov5/arch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/ultralytics/yolov5/arch/model_freezing.py b/icevision/models/multitask/ultralytics/yolov5/arch/model_freezing.py new file mode 100644 index 000000000..dd13c6211 --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/arch/model_freezing.py @@ -0,0 +1,196 @@ +import torch +import torch.nn as nn +import numpy as np + +from torch import Tensor +from torch.nn import Parameter + +from typing import Collection, Union, List, Tuple +from icevision.utils.torch_utils import params +from icevision.utils.utils import flatten +from loguru import logger + +logger = 
logger.opt(colors=True) + +__all__ = ["FreezingInterfaceExtension"] + + +class FreezingInterfaceExtension: + """ + Model freezing and unfreezing extensions for `HybridYOLOV5` + Note that the BatchNorm layers are also frozen, but that part is not + defined here, but in the main module's `.train()` method directly + """ + + def _get_params_stem(self) -> List[nn.Parameter]: + return params(self.model[0]) + + def _get_params_backbone(self) -> List[List[Parameter]]: + return [params(m) for m in self.model[1 : self.bbone_blocks_end_idx]] + + def _get_params_neck(self) -> List[List[Parameter]]: + return [params(m) for m in self.model[self.bbone_blocks_end_idx :][:-1]] + + def _get_params_bbox_head(self) -> List[List[Parameter]]: + return params(self.model[-1]) + + def _get_params_classifier_heads(self) -> List[List[Parameter]]: + return [params(self.classifier_heads)] + + def _set_param_grad_stem(self, mode: bool): + for p in flatten(self._get_params_stem()): + p.requires_grad = mode + + def _set_param_grad_backbone(self, mode: bool, bbone_blocks: Collection[int]): + error_msg = f""" + `bbone_blocks` must be a list|tuple of values between 0-{self.num_bbone_blocks} specifying which blocks to set this state for + """ + + if not isinstance(bbone_blocks, (list, tuple)): + raise TypeError(error_msg) + if not all(isinstance(x, int) for x in bbone_blocks): + raise TypeError(error_msg) + if not bbone_blocks == []: + if not 0 <= bbone_blocks[0] <= self.num_bbone_blocks - 1: + raise ValueError(error_msg) + + pgs = np.array(self._get_params_backbone(), dtype="object") + for p in flatten(pgs[bbone_blocks]): + p.requires_grad = mode + + def _set_param_grad_neck(self, mode: bool): + for p in flatten(self._get_params_neck()): + p.requires_grad = mode + + def _set_param_grad_bbox_head(self, mode: bool): + for p in flatten(self._get_params_bbox_head()): + p.requires_grad = mode + + def _set_param_grad_classifier_heads(self, mode: bool): + for p in flatten(self._get_params_classifier_heads()): + p.requires_grad = mode + + def freeze( + self, + stem: bool = False, + bbone_blocks: int = 0, # between 0 to self.num_bbone_blocks + neck: bool = False, + bbox_head: bool = False, + classifier_heads: bool = False, + ): + """ + Freeze selected parts of the network. By default, none of the parts are frozen, you need + to manually set each arg's value to `True` if you want to freeze it. If you don't want + this fine grained control, see `.freeze_detector()`, `.freeze_backbone()`, `.freeze_classifier_heads()` + + Args: + stem (bool, optional): Freeze the first conv layer. Defaults to True. + bbone_blocks (int, optional): Number of blocks to freeze. If 0, none are frozen; if ==self.num_bbone_blocks, all are frozen. + neck (bool, optional): Freeze the neck (FPN). Defaults to False. + bbox_head (bool, optional): Freeze the bounding box head (the `Detect` module). Defaults to False. + classifier_heads (bool, optional): Freeze all the classification heads. Defaults to False. + """ + if stem: + self._set_param_grad_stem(False) + if bbone_blocks: + self._set_param_grad_backbone(False, [i for i in range(bbone_blocks)]) + if neck: + self._set_param_grad_neck(False) + if bbox_head: + self._set_param_grad_bbox_head(False) + if classifier_heads: + self._set_param_grad_classifier_heads(False) + + def unfreeze( + self, + stem: bool = False, + bbone_blocks: int = 0, + neck: bool = False, + bbox_head: bool = False, + classifier_heads: bool = False, + ): + """ + Unfreeze specific parts of the model. By default all parts are kept frozen. 
+ You need to manually set whichever part you want to unfreeze by passing that arg as `True`. + See `.unfreeze_detector()`, `.unfreeze_backbone()`, `.unfreeze_classifier_heads()` methods if you + don't want this fine grained control. + + Note that `bbone_blocks` works differently from `.freeze()`. `bbone_blocks=3` will unfreeze + the _last 3_ blocks, and `bbone_blocks=self.num_bbone_blocks` will unfreeze _all_ the blocks + """ + if stem: + self._set_param_grad_stem(True) + if bbone_blocks: + self._set_param_grad_backbone( + True, + [ + i + for i in range( + self.num_bbone_blocks - bbone_blocks, self.num_bbone_blocks + ) + ], + ) + if neck: + self._set_param_grad_neck(True) + if bbox_head: + self._set_param_grad_bbox_head(True) + if classifier_heads: + self._set_param_grad_classifier_heads(True) + + def freeze_detector(self): + "Freezes the entire detector i.e. stem, bbone, neck, bbox head" + self.freeze( + stem=True, bbone_blocks=self.num_bbone_blocks, neck=True, bbox_head=True + ) + + def unfreeze_detector(self): + "Unfreezes the entire detector i.e. stem, bbone, neck, bbox head" + self.unfreeze( + stem=True, bbone_blocks=self.num_bbone_blocks, neck=True, bbox_head=True + ) + + def freeze_backbone(self, fpn=True): + "Freezes the entire backbone, optionally without the neck/fpn" + self.freeze( + stem=True, bbone_blocks=self.num_bbone_blocks, neck=True if fpn else False + ) + + def freeze_neck(self): + "Freeze the FPN/Neck" + self.freeze(neck=True) + + def freeze_fpn(self): + "Freeze the FPN/Neck" + self.freeze_neck() + + def unfreeze_backbone(self, fpn=True): + "Unfreezes the entire backbone, optionally without the neck/fpn" + self.unfreeze( + stem=True, bbone_blocks=self.num_bbone_blocks, neck=True if fpn else False + ) + + def freeze_classifier_heads(self): + "Freezes just the classification heads" + self.freeze(classifier_heads=True) + + def unfreeze_classifier_heads(self): + "Unfreezes just the classification heads" + self.unfreeze(classifier_heads=True) + + def freeze_specific_classifier_heads( + self, names: Union[str, List[str], None] = None, _grad: bool = False + ): + "Freeze all, one or a few classifier heads" + if isinstance(names, str): + names = [names] + if names is None: + names = list(self.classifier_heads.keys()) + + for name in names: + for p in flatten(params(self.classifier_heads[name])): + p.requires_grad = _grad + + def unfreeze_specific_classifier_heads( + self, names: Union[str, List[str], None] = None + ): + self.freeze_specific_classifier_heads(names=names, _grad=True) diff --git a/icevision/models/multitask/ultralytics/yolov5/arch/param_groups.py b/icevision/models/multitask/ultralytics/yolov5/arch/param_groups.py new file mode 100644 index 000000000..1845c441f --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/arch/param_groups.py @@ -0,0 +1,43 @@ +""" +This file defines how to get parameter groups from the `HybridYOLOV5` +model. It is expected to be used along with the other classes in this +submodule, but is defined in a distinct file for easier referencing +and if one wanted to define a custom param_groups functions +""" + +from typing import List +from torch.nn import Parameter +from icevision.utils.utils import flatten +from icevision.utils.torch_utils import check_all_model_params_in_groups2 + +__all__ = ["ParamGroupsExtension"] + + +class ParamGroupsExtension: + """ + Splits the model into distinct parameter groups to pass differential + learning rates to. 
Given the structure of the model, you must note + that the param groups are not returned sequentially. The last returned + group is the classifier heads, and the second last is bbox head, and you + may want to apply the same LR to both. The `lr=slice(1e-3)` syntax will not + work for that and you'd have to manually pass in a sequence of + `len(param_groups)` (5) learning rates instead + + Param Groups: + 1. Stem - The first conv layer + 2. Backbone - Layers 1:10 + 3. Neck - The FPN layers i.e. layers 10:23 (24?) + 4. BBox Head - The `Detect` module, which is the last layer in `self.model` + 5. Classifier Heads + """ + + def param_groups(self) -> List[List[Parameter]]: + param_groups = [ + flatten(self._get_params_stem()), + flatten(self._get_params_backbone()), + flatten(self._get_params_neck()), + flatten(self._get_params_bbox_head()), + flatten(self._get_params_classifier_heads()), + ] + check_all_model_params_in_groups2(self, param_groups=param_groups) + return param_groups diff --git a/icevision/models/multitask/ultralytics/yolov5/arch/yolo_hybrid.py b/icevision/models/multitask/ultralytics/yolov5/arch/yolo_hybrid.py new file mode 100644 index 000000000..c4e0f2c6e --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/arch/yolo_hybrid.py @@ -0,0 +1,401 @@ +""" +Multitask implementation of YOLO-V5. +Supports the following tasks: + * Object Detection + * Image Classification + +See https://discord.com/channels/735877944085446747/770279401791160400/853698548750745610 + for a more detailed discussion +""" + +__all__ = ["HybridYOLOV5", "ClassifierConfig"] + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from pathlib import Path +from torch import Tensor +from torch.nn.parameter import Parameter +from icevision.models.multitask.classification_heads.head import ( + ClassifierConfig, + ImageClassificationHead, + Passthrough, +) +from icevision.models.multitask.utils.dtypes import * +from icevision.models.multitask.classification_heads.builder import ( + build_classifier_heads_from_configs, +) +from icevision.models.multitask.utils.model import ForwardType, set_bn_eval +from icevision.models.multitask.ultralytics.yolov5.arch.model_freezing import * +from icevision.models.multitask.ultralytics.yolov5.arch.param_groups import * + +# from .yolo import * +from yolov5.models.yolo import * + +from typing import Collection, Dict, Optional, List, Tuple, Union +from copy import deepcopy +from loguru import logger + +logger = logger.opt(colors=True) + + +# fmt: off +YOLO_FEATURE_MAP_DIMS = { + # models/*yaml + "yolov5s": [128, 256, 512], # (128, 32, 32), (256, 16, 16), (512, 8, 8) + "yolov5m": [192, 384, 768], # (192, 32, 32), (384, 16, 16), (768, 8, 8) + "yolov5l": [256, 512, 1024], # (256, 32, 32), (512, 16, 16), (1024, 8, 8) + "yolov5x": [320, 640, 1280], # (320, 32, 32), (640, 16, 16), (1280, 8, 8) + + # models/hub/*yaml + "yolov3-spp": [256, 512, 1024], # (256, 32, 32), (512, 16, 16), (1024, 8, 8) + "yolov3-tiny": [256, 512], # (256, 16, 16), (512, 8, 8) + "yolov3": [256, 512, 1024], # (256, 32, 32), (512, 16, 16), (1024, 8, 8) + "yolov5-fpn": [256, 512, 1024], # (256, 32, 32), (512, 16, 16), (1024, 8, 8) + "yolov5-p2": [256, 512, 1024], # (256, 32, 32), (512, 16, 16), (1024, 8, 8) + "yolov5-p6": [256, 512, 768, 1024], # (256, 32, 32), (512, 16, 16), (768, 8, 8), (1024, 4, 4) + "yolov5-p7": [256, 512, 768, 1024, 1280], # (256, 32, 32), (512, 16, 16), (768, 8, 8), (1024, 4, 4), (1280, 2, 2) + "yolov5-panet": [256, 512, 1024], # (256, 32, 32), 
(512, 16, 16), (1024, 8, 8) + "yolov5l6": [256, 512, 768, 1024], # (256, 32, 32), (512, 16, 16), (768, 8, 8), (1024, 4, 4) + "yolov5m6": [192, 384, 576, 768], # (192, 32, 32), (384, 16, 16), (576, 8, 8), (768, 4, 4) + "yolov5s6": [128, 256, 384, 512], # (128, 32, 32), (256, 16, 16), (384, 8, 8), (512, 4, 4) + "yolov5x6": [320, 640, 960, 1280], # (320, 32, 32), (640, 16, 16), (960, 8, 8), (1280, 4, 4) + "yolov5s-transformer": [128, 256, 512], # (128, 32, 32), (256, 16, 16), (512, 8, 8) +} +# fmt: on + + +class HybridYOLOV5( + nn.Module, + FreezingInterfaceExtension, + ParamGroupsExtension, +): + """ + Info: + Create a multitask variant of any YOLO model from ultralytics + Currently, multitasking detection + classification is supported. An + arbitrary number of classification heads can be created by passing + in a dictionary of `ClassifierConfig`s where the keys are names of tasks + + Sample Usage: + HybridYOLOV5( + cfg="models/yolov5s.yaml", + classifier_configs=dict( + classifier_head_1=ClassifierConfig(out_classes=10), + classifier_head_2=ClassifierConfig(out_classes=20, multilabel=True), + ), + ) + """ + + # HACK sort of... as subclassing is a bit problematic with super(...).__init__() + fuse = Model.fuse + nms = Model.nms + _initialize_biases = Model._initialize_biases + _print_biases = Model._print_biases + autoshape = Model.autoshape + info = Model.info + + def __init__( + self, + cfg, # Path to `.yaml` config + ch=3, # Num. input channels (3 for RGB image) + nc=None, # Num. bbox classes + anchors=None, + classifier_configs: Dict[str, ClassifierConfig] = None, + ): + super(HybridYOLOV5, self).__init__() + + if isinstance(cfg, dict): + self.yaml = cfg # model dict + else: # is *.yaml + import yaml # for torch hub + + self.yaml_file = Path(cfg).name + with open(cfg) as f: + self.yaml = yaml.safe_load(f) # model dict + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") + self.yaml["nc"] = nc # override yaml value + if anchors: + logger.info(f"Overriding model.yaml anchors with anchors={anchors}") + self.yaml["anchors"] = round(anchors) # override yaml value + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) + self.names = [str(i) for i in range(self.yaml["nc"])] # default names + self.inplace = self.yaml.get("inplace", True) + # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) + + self.classifier_configs = classifier_configs + self.build_classification_modules() + self.post_layers_init() + + # Build strides, anchors + m = self.model[-1] # Detect() + if isinstance(m, Detect): + s = 256 # 2x min stride + m.inplace = self.inplace + # NOTE: This is the only modified line before classifier heads + # because we are now returning 2 outputs, not one + m.stride = torch.tensor( + # Index into [0] because [1]th index is the classification preds + [s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))[0]] + # [s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))] + ) # forward + m.anchors /= m.stride.view(-1, 1, 1) + check_anchor_order(m) + self.stride = m.stride + self._initialize_biases() # only run once + # logger.info('Strides: %s' % m.stride.tolist()) + + self.fpn_dims = YOLO_FEATURE_MAP_DIMS[Path(self.yaml_file).stem] + # self.fpn_dims = self.extract_features(torch.rand(1, 3, 224, 224)) + self.num_fpn_dims = len(self.fpn_dims) + + # Init weights, biases + initialize_weights(self) + self.info() 
+ logger.success(f"Built *{Path(self.yaml_file).stem}* model successfully") + + self.post_init() + + def post_layers_init(self): + """ + Run before doing test forward passes for determining the `Detect` (bbox_head) hparams. + If you want to inject custom modules into the model, this is the place to do it + """ + pass + + def post_init(self): + pass + + def train(self, mode: bool = True): + "Set model to training mode, while freezing non trainable layers' BN statistics" + super(HybridYOLOV5, self).train(mode) + set_bn_eval(self) + return self + + @property + def num_bbone_blocks(self) -> int: + return len(self.yaml["backbone"]) - 1 + + @property + def bbone_blocks_start_idx(self) -> int: + return 1 + + @property + def bbone_blocks_end_idx(self) -> int: + return len(self.yaml["backbone"]) + + def build_classification_modules(self, verbose: bool = True): + """ + Description: + Build classifier heads from `self.classifier_configs`. + Does checks to see if `num_fpn_features` are given and if they are + correct for each classifier config, and corrects them if not + """ + arch = Path(self.yaml_file).stem + # fpn_dims = np.array(YOLO_FEATURE_MAP_DIMS[arch]) + fpn_dims = [ + o.shape[1] for o in self.extract_features(torch.rand(1, 3, 640, 640)) + ] + + for task, cfg in self.classifier_configs.items(): + num_fpn_features = ( + sum(fpn_dims) if cfg.fpn_keys is None else sum(fpn_dims[cfg.fpn_keys]) + ) + + if cfg.num_fpn_features is None: + cfg.num_fpn_features = num_fpn_features + + elif cfg.num_fpn_features != num_fpn_features: + if verbose: + logger.warning( + f"Incompatible `num_fpn_features={cfg.num_fpn_features}` detected in task '{task}'. " + f"Replacing with the correct dimensions: {num_fpn_features}" + ) + cfg.num_fpn_features = num_fpn_features + + self.classifier_heads = build_classifier_heads_from_configs( + self.classifier_configs + ) + if verbose: + logger.success(f"Built classifier heads successfully") + + def forward( + self, + x: Union[Tensor, dict], + profile=False, + # forward_detection: bool = True, + # forward_classification: bool = True, + # activate_classification: bool = False, + step_type=ForwardType.TRAIN, + ) -> Tuple[Union[Tensor, TensorList], TensorDict]: + "Forward method that is dispatched based on `step_type`" + + if step_type is ForwardType.TRAIN or step_type is ForwardType.EVAL: + # Assume that model is set to `.eval()` mode before calling this function...? + return self.forward_once(x, profile=profile) + + elif step_type is ForwardType.INFERENCE: + return self.forward_inference(x) + + elif step_type is ForwardType.TRAIN_MULTI_AUG: + return self.forward_multi_augment(x) + + else: + raise RuntimeError( + f"Invalid `step_type`. Received: {type(step_type.__class__)}; Expected: {ForwardType.__class__}" + ) + + def forward_inference(self, x): + # You may export model in training mode? 
+ if not self.training: + (det_out, _), clf_out = self.forward_once(x, activate_classification=True) + if self.training: + self.classifier_heads.eval() # Turn off dropout + det_out, clf_out = self.forward_once(x, activate_classification=True) + return det_out, tuple(clf_out.values()) + + # This is here for API compatibility with the main repo; will likely not be used + def forward_augment(self, x): + raise NotImplementedError + + def extract_features(self, x: Tensor): + return self.forward_once( + x, forward_detection=False, forward_classification=False + )[0] + + def forward_multi_augment(self, data: dict) -> Tuple[TensorList, TensorDict]: + """ + Description: + Multi augmentation training where we do multiple forward passes over the + same batch, going through different parts of the network each time. + + Detection and classification are treated separately, and within classification, + you can group together different tasks. A `group` has multiple `tasks`, so we + extract features once per group, then iterate over each head for that group's + `tasks`, and compute the outputs from these features + + Args: + data (dict): Input container with the following structure: + ```python + xb = torch.zeros(1, 3, 224, 224) + multi_aug_data = dict( + detection={"images": xb}, + classification={ + "group_1": dict( + tasks=["framing", "saturation"], + images=x, + ) + } + ) + ``` + Each group in data["classification"]'s `tasks` must correspond to + a key in `self.classifier_heads` + + Raises: + RuntimeError: If model is not in `training` mode (as a safety check) + + Returns: + Tuple[TensorList, TensorDict]: Tuple of `(detection_preds, classification_preds)` + """ + if not self.training: + raise RuntimeError(f"Can only run `forward_multi_augment` in training mode") + + # Detection forward pass + xb = data["detection"]["images"] + detection_preds, _ = self.forward_once( + xb, forward_detection=True, forward_classification=False + ) + + # Classification forward pass + classification_preds = {} + for group, datum in data["classification"].items(): + xb = datum["images"] + features, _ = self.forward_once( + xb, forward_detection=False, forward_classification=False + ) + for name in datum["tasks"]: + head = self.classifier_heads[name] + classification_preds[name] = head(features) + + return detection_preds, classification_preds + + def forward_once( + self, + x, + profile=False, # Will fail + forward_detection: bool = True, + forward_classification: bool = True, + activate_classification: bool = False, + ) -> Tuple[Union[TensorList, Tuple[Tensor, TensorList]], TensorDict]: + """ + Returns: + A tuple of two elements `(detection_preds, classification_preds)`: + 1) A TensorList in training mode, and a Tuple[Tensor, TensorList] in + eval mode where the first element (Tensor) is the inference output and + second is the training output (for loss computation). If `forward_detection` is + False, the list of FPN features are returned right before feeding into the bbox + head i.e. the `Detect` module which can be accessed via `self.model[-1]` + 2) A TensorDict of classification predictions. 
If `forward_classification` is + False, an empty dictionary is returned + """ + y, dt = [], [] # outputs + classification_preds: Dict[str, Tensor] = {} + for m in self.model: + if m.f != -1: # if not from previous layer + x = ( + y[m.f] + if isinstance(m.f, int) + else [x if j == -1 else y[j] for j in m.f] + ) # from earlier layers + + if profile: + o = ( + thop.profile(m, inputs=(x,), verbose=False)[0] / 1e9 * 2 + if thop + else 0 + ) # FLOPs + t = time_synchronized() + for _ in range(10): + _ = m(x) + dt.append((time_synchronized() - t) * 100) + if m == self.model[0]: + logger.info( + f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}" + ) + logger.info(f"{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}") + + """ + This is where the feature maps are passed into the classification heads. + Is there a cleaner way to do this? It's tricky as the whole model is wrapped in an + `nn.Sequential` container and we can't access attribues like `.backbone` or `.neck`. + We know for certain that `Detect` is the last layer in the model, so this should be + safe to do. + """ + if isinstance(m, Detect): + if forward_classification: + for name, head in self.classifier_heads.items(): + classification_preds[name] = ( + head.forward_activate(x) + if activate_classification + else head(x) + ) + + if not forward_detection: + if profile: + logger.info("%.1fms total" % sum(dt)) + return x, classification_preds + + x = m(x) # run + y.append(x if m.i in self.save else None) # save output + + if profile: + logger.info("%.1fms total" % sum(dt)) + + return x, classification_preds diff --git a/icevision/models/multitask/ultralytics/yolov5/backbones.py b/icevision/models/multitask/ultralytics/yolov5/backbones.py new file mode 100644 index 000000000..0ee5f34b8 --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/backbones.py @@ -0,0 +1,17 @@ +""" +This file is redundant in terms of code as it uses the exact same code +as `icevision.models.ultralytics.yolov5.backbones` + +We're keeping it to maintain structure, and in case we want to change +something in the future that is multitask model specific +""" + +from icevision.models.multitask.ultralytics.yolov5.utils import * +from icevision.models.ultralytics.yolov5.backbones import * + +__all__ = [ + "small", + "medium", + "large", + "extra_large", +] diff --git a/icevision/models/multitask/ultralytics/yolov5/dataloaders.py b/icevision/models/multitask/ultralytics/yolov5/dataloaders.py new file mode 100644 index 000000000..405576dc8 --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/dataloaders.py @@ -0,0 +1,221 @@ +""" +YOLO-V5 dataloaders for multitask training. 
+ +The model uses a peculiar format for bounding box annotations where the + length of the tensor is the total number of bounding boxes for that batch +The first dimension is the index of the image that the box belongs to/ +See https://discord.com/channels/735877944085446747/770279401791160400/853691059338084372 + for a more thorough explanation +""" + +from icevision.imports import * +from icevision.core import * +from icevision.models.utils import * +from icevision.models.ultralytics.yolov5.dataloaders import ( + _build_train_sample as _build_train_detection_sample, +) +from icevision.models.ultralytics.yolov5.dataloaders import build_infer_batch, infer_dl +from icevision.models.multitask.utils.dtypes import * +from icevision.models.multitask.data.dataset import HybridAugmentationsRecordDataset +from icevision.models.multitask.data.dataloading_utils import * +from torch.utils.data import Dataset + + +__all__ = [ + "build_single_aug_batch", # <- build_train_batch, build_valid_batch + "build_multi_aug_batch", # <- build_train_batch + "build_infer_batch", + "train_dl", + "train_dl_multi_aug", + "valid_dl", + "infer_dl", +] + + +def build_single_aug_batch( + records: Sequence[RecordType], +) -> Tuple[TensorList, TensorDict, Sequence[RecordType]]: + """Builds a batch in the format required by the model when training. + + # Arguments + records: A `Sequence` of records. + + # Returns + A tuple with two items. The first will be a tuple like `(images, detection_targets, classification_targets)` + in the input format required by the model. The second will be an updated list + of the input records. + + # Examples + + Use the result of this function to feed the model. + ```python + batch, records = build_train_batch(records) + outs = model(*batch) + ``` + """ + images, detection_targets = [], [] + classification_targets = defaultdict(list) + + for i, record in enumerate(records): + image, detection_target = _build_train_detection_sample(record) + images.append(image) + + # See file header for more info on why this is done + if detection_target.numel() > 0: + detection_target[:, 0] = i + + detection_targets.append(detection_target) + assign_classification_targets_from_record(classification_targets, record) + + classification_targets = {k: tensor(v) for k, v in classification_targets.items()} + + return ( + torch.stack(images, 0), + torch.cat(detection_targets, 0), + classification_targets, + ), records + + +def build_multi_aug_batch( + records: Sequence[RecordType], classification_transform_groups: dict +): + """ + Docs: + Take as inputs `records` and `classification_transform_groups` and return + a tuple of dictionaries, one for detection data and the other for classification. + + See `icevision.models.multitask.data.dataset.HybridAugmentationsRecordDataset` + for example of what `records` and `classification_transform_groups` look like + + + Returns: + A tuple with two items: + 1. A tuple with two dictionaries - (`detection_data`, `classification_data`) + `detection_data`: + { + "detection": dict( + images: Tensor = ..., # (N,C,H,W) + targets: Tensor = ..., # of shape (num_boxes, 6) + # (img_idx, box_class_idx, **bbox_relative_coords) + ) + } + `classification_data`: + { + "group1": dict( + tasks = ["shot_composition"], + images: Tensor = ..., + targets=dict( + "shot_composition": Tensor = ..., + ) + ), + "group2": dict( + tasks = ["color_saturation", "shot_framing"], + images: Tensor = ..., + targets=dict( + "color_saturation": Tensor = ..., + "shot_framing": Tensor = ..., + ) + ) + } + 2. 
Loaded records (same ones passed as inputs) + """ + detection_images = [] + detection_targets = [] + classification_data = defaultdict(lambda: defaultdict(list)) + classification_targets = defaultdict(list) + + for i, record in enumerate(records): + detection_image, detection_target = _build_train_detection_sample(record) + detection_images.append(detection_image) + + # See file header for more info on why this is done + if detection_target.numel() > 0: + detection_target[:, 0] = i + + detection_targets.append(detection_target) + for key, group in classification_transform_groups.items(): + task = getattr(record, group["tasks"][0]) + classification_data[key]["tasks"] = group["tasks"] + classification_data[key]["images"].append(im2tensor(task.img)) + + assign_classification_targets_from_record(classification_targets, record) + record.unload() # NOTE: Safety mechanism + + # Massage data + classification_data = massage_multi_aug_classification_data( + classification_data, classification_targets, "targets" + ) + + detection_data = dict( + images=torch.stack(detection_images, 0), + targets=torch.cat(detection_targets, 0), + ) + + return (detection_data, classification_data), records + + +def train_dl(dataset: Dataset, batch_tfms=None, **dataloader_kwargs) -> DataLoader: + """ + A `DataLoader` with a custom `collate_fn` that batches records as required for feeding a YOLO-V5 model. + + Args: + dataset (Dataset): A `Dataset` that returns a transformed record upon indexing + batch_tfms: ... # TODO + **dataloader_kwargs: Keyword arguments that will be internally passed to a Pytorch `DataLoader`. + The parameter `collate_fn` is already defined internally and cannot be passed here. + + Returns: + DataLoader: A PyTorch `DataLoader` + """ + return transform_dl( + dataset=dataset, + build_batch=build_single_aug_batch, + batch_tfms=batch_tfms, + **dataloader_kwargs, + ) + + +def valid_dl(dataset: Dataset, batch_tfms=None, **dataloader_kwargs) -> DataLoader: + """ + A `DataLoader` with a custom `collate_fn` that batches items as required for validating the YOLO-V5 model. + + Args: + dataset (Dataset): A `Dataset` that returns a transformed record upon indexing + batch_tfms: ... # TODO + **dataloader_kwargs: Keyword arguments that will be internally passed to a Pytorch `DataLoader`. + The parameter `collate_fn` is already defined internally and cannot be passed here. + + Returns: + DataLoader: A PyTorch `DataLoader` + """ + return train_dl(dataset=dataset, batch_tfms=batch_tfms, **dataloader_kwargs) + + +def train_dl_multi_aug( + dataset: HybridAugmentationsRecordDataset, + classification_transform_groups: dict, + **dataloader_kwargs, +) -> DataLoader: + """ + A `DataLoader` meant to work with `HybridAugmentationsRecordDataset`, a multitasking + dataset, where individual or groups of tasks receive their own unique transforms. + `batch_tfms` is not yet implemented for this DataLoader. + + Args: + dataset (HybridAugmentationsRecordDataset): A custom dataset that groups tasks and returns + records where _each_ task has its own `img` + + classification_transform_groups (dict): The exact same dictionary that is passed to + HybridAugmentationsRecordDataset`, describing how to group and transform classification tasks. + See the dataset's docs for more details. 
+ + Returns: + DataLoader: A PyTorch `DataLoader` + """ + collate_fn = unload_records( + build_batch=build_multi_aug_batch, + build_batch_kwargs=dict( + classification_transform_groups=classification_transform_groups + ), + ) + return DataLoader(dataset=dataset, collate_fn=collate_fn, **dataloader_kwargs) diff --git a/icevision/models/multitask/ultralytics/yolov5/fastai/__init__.py b/icevision/models/multitask/ultralytics/yolov5/fastai/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icevision/models/multitask/ultralytics/yolov5/lightning/__init__.py b/icevision/models/multitask/ultralytics/yolov5/lightning/__init__.py new file mode 100644 index 000000000..33a818e05 --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/lightning/__init__.py @@ -0,0 +1 @@ +from icevision.models.multitask.ultralytics.yolov5.lightning.model_adapter import * diff --git a/icevision/models/multitask/ultralytics/yolov5/lightning/model_adapter.py b/icevision/models/multitask/ultralytics/yolov5/lightning/model_adapter.py new file mode 100644 index 000000000..233e7c9b8 --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/lightning/model_adapter.py @@ -0,0 +1,143 @@ +# Modified from `icevision.models.mmdet.lightning.model_adapter` +# NOTE `torchmetrics` comes installed with `pytorch-lightning` +# We could in theory also do `pl.metrics` + + +from icevision.models.multitask.classification_heads.head import TensorDict +import torchmetrics as tm +import pytorch_lightning as pl + +from icevision.imports import * +from icevision.metrics import * +from icevision.core import * + +from loguru import logger +from icevision.models.multitask.ultralytics.yolov5.arch.yolo_hybrid import HybridYOLOV5 +from icevision.models.multitask.utils.prediction import * +from icevision.models.multitask.ultralytics.yolov5.prediction import ( + convert_raw_predictions, +) +from icevision.models.multitask.utils.model import ForwardType +from icevision.models.multitask.engines.lightning import MultiTaskLightningModelAdapter +from yolov5.utils.loss import ComputeLoss + + +class HybridYOLOV5LightningAdapter(MultiTaskLightningModelAdapter): + """ """ + + def __init__( + self, + model: HybridYOLOV5, + metrics: List[Metric] = None, + debug: bool = False, + ): + super().__init__() + self.metrics = metrics or [] + self.model = model + self.debug = debug + self.compute_loss = ComputeLoss(model) + + self.classification_metrics = nn.ModuleDict() + for name, head in model.classifier_heads.items(): + if head.multilabel: + thresh = head.thresh if head.thresh is not None else 0.5 + metric = tm.Accuracy(threshold=thresh, subset_accuracy=True) + else: + metric = tm.Accuracy(threshold=0.01, top_k=1) + self.classification_metrics[name] = metric + self.post_init() + + def post_init(self): + pass + + # ======================== TRAINING METHODS ======================== # + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def training_step(self, batch: Tuple[dict, Sequence[RecordType]], batch_idx): + # batch will ALWAYS return a tuple of 2 elements - batched inputs, records + tupled_inputs, _ = batch + if isinstance(tupled_inputs[0], torch.Tensor): + (xb, detection_targets, classification_targets) = tupled_inputs + detection_preds, classification_preds = self( + xb, step_type=ForwardType.TRAIN + ) + + elif isinstance(tupled_inputs[0], dict): + # TODO: Model method not yet implemented + data = dict(detection=tupled_inputs[0], classification=tupled_inputs[1]) + detection_targets = 
data["detection"]["targets"] + + # Go through (a nested dict) each task inside each group and fetch targets + classification_targets = {} + for group, datum in data["classification"].items(): + classification_targets.update(datum["targets"]) + + detection_preds, classification_preds = self( + data, step_type=ForwardType.TRAIN_MULTI_AUG + ) + + detection_loss = self.compute_loss(detection_preds, detection_targets)[0] + + # Iterate through each head and compute classification losses + classification_losses = { + name: head.compute_loss( + predictions=classification_preds[name], + targets=classification_targets[name], + ) + for name, head in self.model.classifier_heads.items() + } + total_classification_loss = sum(classification_losses.values()) + + self.log_losses( + "train", detection_loss, total_classification_loss, classification_losses + ) + + return detection_loss + total_classification_loss + + def validation_step(self, batch, batch_idx): + tupled_inputs, records = batch + (xb, detection_targets, classification_targets) = tupled_inputs + + with torch.no_grad(): + # Get bbox preds and unactivated classifier preds, ready to feed to loss funcs + (inference_det_preds, training_det_preds), classification_preds = self( + xb, step_type=ForwardType.EVAL + ) + + detection_loss = self.compute_loss(training_det_preds, detection_targets)[0] + classification_losses = { + name: head.compute_loss( + predictions=classification_preds[name], + targets=classification_targets[name], + ) + for name, head in self.model.classifier_heads.items() + } + total_classification_loss = sum(classification_losses.values()) + + # Run activation function on classification predictions + classification_preds = { + name: head.postprocess(classification_preds[name]) + for name, head in self.model.classifier_heads.items() + } + self.compute_and_log_classification_metrics( + classification_preds=classification_preds, + yb=classification_targets, + ) + + preds = convert_raw_predictions( + batch=xb, + records=records, + raw_detection_preds=inference_det_preds, + raw_classification_preds=classification_preds, + classification_configs=extract_classifier_pred_cfgs(self.model), + detection_threshold=0.001, + nms_iou_threshold=0.6, + keep_images=False, + ) + + self.accumulate_metrics(preds) + self.log_losses( + "valid", detection_loss, total_classification_loss, classification_losses + ) diff --git a/icevision/models/multitask/ultralytics/yolov5/model.py b/icevision/models/multitask/ultralytics/yolov5/model.py new file mode 100644 index 000000000..9e923ad17 --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/model.py @@ -0,0 +1,114 @@ +""" +Largely copied over from `icevision.models.ultralytics.yolov5.model` +The only aspect added is the ability to pass in a `Dict[str, ClassifierCongig]` to + create the classification heads +""" + + +__all__ = ["model"] + +from icevision.imports import * +from icevision.utils import * + +import yaml +import yolov5 +from yolov5.utils.google_utils import attempt_download +from yolov5.utils.torch_utils import intersect_dicts +from yolov5.utils.general import check_img_size + +# from icevision.models.ultralytics.yolov5.utils import * +from icevision.models.multitask.ultralytics.yolov5.utils import * +from icevision.models.ultralytics.yolov5.backbones import * + +from icevision.models.multitask.ultralytics.yolov5.arch.yolo_hybrid import HybridYOLOV5 +from icevision.models.multitask.classification_heads import ClassifierConfig + +yolo_dir = get_root_dir() / "yolo" 
+yolo_dir.mkdir(exist_ok=True) + + +def model( + backbone: YoloV5BackboneConfig, + num_detection_classes: int, + img_size: int, # must be multiple of 32 + device: Optional[torch.device] = None, + classifier_configs: Dict[str, ClassifierConfig] = None, +) -> HybridYOLOV5: + """ + Build a `HybridYOLOV5` Multitask Model with detection & classification heads. + + Args: + backbone (YoloV5BackboneConfig): Config from `icevision.models.ultralytics.yolov5.backbones.{}` + num_detection_classes (int): Number of object detection classes (including background) + img_size (int): Size of input images (assumes square inputs) + classifier_configs (Dict[str, ClassifierConfig], optional): A dictionary mapping of `ClassifierConfig`s + where each key corresponds to the name of the task in the input records. Defaults to None. + + Returns: + HybridYOLOV5: A multitask YOLO-V5 model with one detection head and `len(classifier_configs)` + classification heads + """ + model_name = backbone.model_name + pretrained = backbone.pretrained + + # this is to remove background from ClassMap as discussed + # here: https://github.com/ultralytics/yolov5/issues/2950 + # and here: https://discord.com/channels/735877944085446747/782062040168267777/836692604224536646 + # so we should pass `num_detection_classes=parser.class_map.num_detection_classes` + num_detection_classes -= 1 + + device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) + + cfg_filepath = Path(yolov5.__file__).parent / f"models/{model_name}.yaml" + if pretrained: + weights_path = yolo_dir / f"{model_name}.pt" + + with open(Path(yolov5.__file__).parent / "data/hyp.finetune.yaml") as f: + hyp = yaml.load(f, Loader=yaml.SafeLoader) + + attempt_download(weights_path) # download if not found locally + sys.path.insert(0, str(Path(yolov5.__file__).parent)) + ckpt = torch.load(weights_path, map_location=device) # load checkpoint + sys.path.remove(str(Path(yolov5.__file__).parent)) + if hyp.get("anchors"): + ckpt["model"].yaml["anchors"] = round(hyp["anchors"]) # force autoanchor + model = HybridYOLOV5( + cfg_filepath or ckpt["model"].yaml, + ch=3, + nc=num_detection_classes, + classifier_configs=classifier_configs, + ).to(device) + exclude = [] # exclude keys + state_dict = ckpt["model"].float().state_dict() # to FP32 + state_dict = intersect_dicts( + state_dict, model.state_dict(), exclude=exclude + ) # intersect + model.load_state_dict(state_dict, strict=False) # load + else: + with open(Path(yolov5.__file__).parent / "data/hyp.scratch.yaml") as f: + hyp = yaml.load(f, Loader=yaml.SafeLoader) + + model = HybridYOLOV5( + cfg_filepath, + ch=3, + nc=num_detection_classes, + anchors=hyp.get("anchors"), + classifier_configs=classifier_configs, + ).to(device) + + gs = int(model.stride.max()) # grid size (max stride) + nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj']) + imgsz = check_img_size(img_size, gs) # verify imgsz are gs-multiples + + hyp["box"] *= 3.0 / nl # scale to layers + hyp["cls"] *= num_detection_classes / 80.0 * 3.0 / nl # scale to classes and layers + hyp["obj"] *= (imgsz / 640) ** 2 * 3.0 / nl # scale to image size and layers + model.nc = num_detection_classes # attach number of classes to model + model.hyp = hyp # attach hyperparameters to model + model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou) + + return model diff --git a/icevision/models/multitask/ultralytics/yolov5/prediction.py b/icevision/models/multitask/ultralytics/yolov5/prediction.py new file mode 
100644 index 000000000..e984c0c20 --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/prediction.py @@ -0,0 +1,130 @@ +""" +Largely copied over from `icevision.models.ultralytics.yolov5.prectiction`, but with +classification added +""" + +from icevision.models.multitask.utils.model import ForwardType +from icevision.utils.utils import unroll_dict +from icevision.imports import * +from icevision.utils import * +from icevision.core import * +from icevision.data import * +from icevision.models.utils import _predict_from_dl + +from icevision.models.multitask.ultralytics.yolov5.dataloaders import * +from icevision.models.ultralytics.yolov5.prediction import ( + convert_raw_predictions as convert_raw_detection_predictions, +) +from icevision.models.multitask.utils.prediction import * + + +@torch.no_grad() +def _predict_batch( + model: nn.Module, + batch: Sequence[Tensor], + records: Sequence[BaseRecord], + detection_threshold: float = 0.25, + nms_iou_threshold: float = 0.45, + keep_images: bool = False, + device: Optional[torch.device] = None, +) -> List[Prediction]: + # device issue addressed on discord: https://discord.com/channels/735877944085446747/770279401791160400/832361687855923250 + if device is not None: + raise ValueError( + "For YOLOv5 device can only be specified during model creation, " + "for more info take a look at the discussion here: " + "https://discord.com/channels/735877944085446747/770279401791160400/832361687855923250" + ) + grid = model.model[-1].grid[-1] + # if `grid.numel() == 1` it means the grid isn't initialized yet and we can't + # trust it's device (will always be CPU) + device = grid.device if grid.numel() > 1 else model_device(model) + + batch = batch[0].to(device) + model = model.eval().to(device) + + (det_preds, _), classif_preds = model(batch, step_type=ForwardType.INFERENCE) + classification_configs = extract_classifier_pred_cfgs(model) + + return convert_raw_predictions( + batch=batch, + raw_detection_preds=det_preds, + raw_classification_preds=classif_preds, + records=records, + classification_configs=classification_configs, + detection_threshold=detection_threshold, + nms_iou_threshold=nms_iou_threshold, + keep_images=keep_images, + ) + + +def predict( + model: nn.Module, + dataset: Dataset, + detection_threshold: float = 0.25, + nms_iou_threshold: float = 0.45, + keep_images: bool = False, + device: Optional[torch.device] = None, +) -> List[Prediction]: + batch, records = build_infer_batch(dataset) + return _predict_batch( + model=model, + batch=batch, + records=records, + detection_threshold=detection_threshold, + nms_iou_threshold=nms_iou_threshold, + keep_images=keep_images, + device=device, + ) + + +def predict_from_dl( + model: nn.Module, + infer_dl: DataLoader, + show_pbar: bool = True, + keep_images: bool = False, + **predict_kwargs, +): + return _predict_from_dl( + predict_fn=_predict_batch, + model=model, + infer_dl=infer_dl, + show_pbar=show_pbar, + keep_images=keep_images, + **predict_kwargs, + ) + + +def convert_raw_predictions( + batch, + raw_detection_preds: Tensor, + raw_classification_preds: TensorDict, + records: Sequence[BaseRecord], + classification_configs: dict, + detection_threshold: float = 0.4, + nms_iou_threshold: float = 0.6, + keep_images: bool = False, +): + preds = convert_raw_detection_predictions( + batch=batch, + raw_preds=raw_detection_preds, + records=records, + detection_threshold=detection_threshold, + nms_iou_threshold=nms_iou_threshold, + keep_images=keep_images, + ) + for pred, raw_classification_pred 
in zipsafe( + preds, unroll_dict(raw_classification_preds) + ): + add_classification_components_to_pred_record( + pred_record=pred.pred, + classification_configs=classification_configs, + ) + postprocess_and_add_classification_preds_to_record( + gt_record=pred.ground_truth, + pred_record=pred.pred, + classification_configs=classification_configs, + raw_classification_pred=raw_classification_pred, + ) + + return preds diff --git a/icevision/models/multitask/ultralytics/yolov5/utils.py b/icevision/models/multitask/ultralytics/yolov5/utils.py new file mode 100644 index 000000000..5ae9ea83c --- /dev/null +++ b/icevision/models/multitask/ultralytics/yolov5/utils.py @@ -0,0 +1,29 @@ +""" +This file is redundant in terms of code as it uses the exact same code +as `icevision.models.ultralytics.yolov5.utils.YoloV5BackboneConfig` + +We're keeping it to maintain structure, and in case we want to change +something in the future that is multitask model specific +""" + +from icevision.models.ultralytics.yolov5.utils import YoloV5BackboneConfig + +__all__ = ["YoloV5BackboneConfig"] + +# from icevision.imports import * +# from icevision.backbones import BackboneConfig + + +# class YoloV5MultitaskBackboneConfig(BackboneConfig): +# def __init__(self, model_name: str): +# self.model_name = model_name +# self.pretrained: bool + +# def __call__(self, pretrained: bool = True) -> "YoloV5MultitaskBackboneConfig": +# """Completes the configuration of the backbone + +# # Arguments +# pretrained: If True, use a pretrained backbone (on COCO). +# """ +# self.pretrained = pretrained +# return self diff --git a/icevision/models/multitask/utils/__init__.py b/icevision/models/multitask/utils/__init__.py new file mode 100644 index 000000000..3b4ead489 --- /dev/null +++ b/icevision/models/multitask/utils/__init__.py @@ -0,0 +1,2 @@ +from .dtypes import * +from .model import * diff --git a/icevision/models/multitask/utils/dtypes.py b/icevision/models/multitask/utils/dtypes.py new file mode 100644 index 000000000..bc305fa6f --- /dev/null +++ b/icevision/models/multitask/utils/dtypes.py @@ -0,0 +1,20 @@ +from typing import Dict, List, Tuple, Union +from torch import Tensor +import numpy as np +import torch + +__all__ = [ + "ImgMetadataDict", + "TensorList", + "TensorTuple", + "TensorDict", + "ArrayList", + "ArrayDict", +] + +ImgMetadataDict = Dict[str, Union[Tuple[int], np.ndarray]] +TensorList = List[Tensor] +TensorDict = Dict[str, Tensor] +TensorTuple = Tuple[Tensor] +ArrayList = List[np.ndarray] +ArrayDict = Dict[str, np.ndarray] diff --git a/icevision/models/multitask/utils/model.py b/icevision/models/multitask/utils/model.py new file mode 100644 index 000000000..710313bc5 --- /dev/null +++ b/icevision/models/multitask/utils/model.py @@ -0,0 +1,24 @@ +from enum import Enum +from torch.nn.modules.batchnorm import _BatchNorm +from torch import nn + +__all__ = ["ForwardType", "set_bn_eval"] + + +class ForwardType(Enum): + TRAIN_MULTI_AUG = 1 + TRAIN = 2 + EVAL = 3 + INFERENCE = 4 + # EXPORT_ONNX = 5 + # EXPORT_TORCHSCRIPT = 6 + # EXPORT_COREML = 7 + + +# Modified from from https://github.com/fastai/fastai/blob/4decc673ba811a41c6e3ab648aab96dd27244ff7/fastai/callback/training.py#L43-L49 +def set_bn_eval(m: nn.Module) -> None: + "Set bn layers in eval mode for all recursive, non-trainable children of `m`." 
+ for l in m.children(): + if isinstance(l, _BatchNorm) and not next(l.parameters()).requires_grad: + l.eval() + set_bn_eval(l) diff --git a/icevision/models/multitask/utils/prediction.py b/icevision/models/multitask/utils/prediction.py new file mode 100644 index 000000000..93f1bc465 --- /dev/null +++ b/icevision/models/multitask/utils/prediction.py @@ -0,0 +1,110 @@ +from icevision.imports import * +from icevision.core import * +from icevision.utils import Dictionary +from icevision.models.multitask.classification_heads.head import ( + ImageClassificationHead, + ClassifierConfig, + TensorDict, +) +from icevision.core.tasks import Task + + +# __all__ = ["finalize_classifier_preds"] + + +def finalize_classifier_preds( + pred, cfg: Dictionary, record: RecordType, task: str +) -> tuple: + """ + Analyse preds post-activations based on `cfg` arguments; return the + relevant scores and string labels derived from `record` + + Can compute the following: + * top-k (`cfg` defaults to 1 for single-label problems) + * filter preds by threshold + """ + + # pred = np.array(pred) + pred = pred.detach().cpu().numpy() + + if cfg.topk is not None: + index = np.argsort(pred)[-cfg.topk :] # argsort gives idxs in ascending order + value = pred[index] + + elif cfg.thresh is not None: + index = np.where(pred > cfg.thresh)[0] # index into the tuple + value = pred[index] + + labels = [getattr(record, task).class_map._id2class[i] for i in index] + scores = pred[index].tolist() + + return labels, scores + + +def extract_classifier_pred_cfgs(model: nn.Module): + return { + name: Dictionary(multilabel=head.multilabel, topk=head.topk, thresh=head.thresh) + for name, head in model.classifier_heads.items() + } + + +def add_classification_components_to_pred_record( + pred_record: RecordType, classification_configs: dict +): + """ + Adds `ClassificationLabelsRecordComponent` and `ScoresRecordComponent` to `pred_record` + for each task; where the keys of `classification_configs` are the names of the tasks + + Args: + pred_record (RecordType) + classification_configs (dict) + + Returns: + [type]: [description] + """ + r = pred_record + for name, cfg in classification_configs.items(): + r.add_component(ScoresRecordComponent(Task(name=name))) + r.add_component( + ClassificationLabelsRecordComponent( + Task(name=name), is_multilabel=cfg.multilabel + ) + ) + return r + + +def postprocess_and_add_classification_preds_to_record( + gt_record: RecordType, + pred_record: RecordType, + classification_configs: dict, + raw_classification_pred: TensorDict, +): + """ + Postprocesses predictions based on `classification_configs` and adds the results + to `pred_record`. Uses `gt_record` to set the `pred_record`'s class maps + + Args: + gt_record (RecordType) + pred_record (RecordType) + + classification_configs (dict): A dict that describes how to postprocess raw + classification preds. Note that the raw preds are assumed to have already gone + through an activation function like Softmax or Sigmoid. For example: + dict( + multilabel=False, topk=1, thresh=None + ) + + raw_classification_pred (TensorDict): Container whose preds will be processed. 
Is + expected to have the exact same keys as `classification_configs` + """ + for task, classification_pred in raw_classification_pred.items(): + labels, scores = finalize_classifier_preds( + pred=classification_pred, + cfg=classification_configs[task], + record=gt_record, + task=task, + ) + # sub_record = getattr(pred_record, task) + getattr(pred_record, task).set_class_map(getattr(gt_record, task).class_map) + getattr(pred_record, task).set_labels(labels) + getattr(pred_record, task).set_scores(scores) diff --git a/icevision/utils/utils.py b/icevision/utils/utils.py index b37e80ac0..039c3ae3a 100644 --- a/icevision/utils/utils.py +++ b/icevision/utils/utils.py @@ -16,9 +16,12 @@ "denormalize_imagenet", "denormalize_mask", "patch_class_to_main", + "flatten", + "Dictionary", ] from icevision.imports import * +from addict import Dict as _Dict def notnone(x): @@ -109,3 +112,50 @@ def patch_class_to_main(cls): setattr(__main__, cls.__name__, cls) cls.__module__ = "__main__" return cls + + +def flatten(x: Any) -> List[Any]: + import pandas as pd + + flattened_list = [] + for item in x: + if isinstance(item, (tuple, list, np.ndarray, pd.Series)): + [flattened_list.append(i) for i in item] + else: + flattened_list.append(item) + return flattened_list + + +def unroll_dict(x: dict) -> List[dict]: + """ + Unroll a dictionary into a list of dictionaries where the key is repeated. + Useful when you want to throw a dictionary into a for loop + + Args: + x (dict) + + Returns: + List[dict] + + Example: + x = dict( + location=[[0.8, 0.2], [0.9, 0.1]], + lighting=[[0.6, 0.4], [0.2, 0.8]] + ) + unroll_dict(x) == [ + {"location": [0.8, 0.2], "lighting": [0.6, 0.4]}, + {"location": [0.9, 0.1], "lighting": [0.2, 0.8]}, + ] + """ + return [dict(zip(x, t)) for t in zipsafe(*x.values())] + + +[ + {"location": [0.8, 0.2], "lighting": [0.6, 0.4]}, + {"location": [0.9, 0.1], "lighting": [0.2, 0.8]}, +] + + +class Dictionary(_Dict): + def __missing__(self, key): + raise KeyError(key) diff --git a/notebooks/multitask.ipynb b/notebooks/multitask.ipynb new file mode 100644 index 000000000..b587a47d8 --- /dev/null +++ b/notebooks/multitask.ipynb @@ -0,0 +1,446 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3a909b36", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bcfe126", + "metadata": {}, + "outputs": [], + "source": [ + "from icevision.imports import *\n", + "from icevision.models.multitask.ultralytics.yolov5 import *\n", + "from icevision.data.data_splitter import *\n", + "from icevision.visualize import *\n", + "from icevision.metrics import *\n", + "\n", + "import icedata.datasets.exdark_trimmed as exdark" + ] + }, + { + "cell_type": "markdown", + "id": "ce2db1e1", + "metadata": {}, + "source": [ + "#### Regular Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02713d8d", + "metadata": {}, + "outputs": [], + "source": [ + "IMG_SIZE=512\n", + "data_dir = exdark.load_data()\n", + "data_dir = Path(\"/Users/rahulsomani/datasets/ExDark-Trimmed/\")\n", + "parser = exdark.parser(data_dir)\n", + "\n", + "train_records, valid_records = parser.parse(data_splitter=RandomSplitter([0.8, 0.2]))\n", + "train_tfms = tfms.A.Adapter(\n", + " [\n", + " *tfms.A.aug_tfms(size=IMG_SIZE, lightning=None),\n", + " tfms.A.Normalize(),\n", + " ]\n", + ")\n", + "valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(IMG_SIZE), tfms.A.Normalize()])\n", + "\n", + 
"train_ds = Dataset(train_records, tfm=train_tfms)\n", + "valid_ds = Dataset(valid_records, tfm=valid_tfms)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bd020e0", + "metadata": {}, + "outputs": [], + "source": [ + "dl_train = train_dl(train_ds, batch_size=32)\n", + "dl_valid = valid_dl(valid_ds, batch_size=64)" + ] + }, + { + "cell_type": "markdown", + "id": "4f577bcc", + "metadata": {}, + "source": [ + "#### Multi Augmentation Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af5e084", + "metadata": {}, + "outputs": [], + "source": [ + "from icevision.models.multitask.data.dataset import HybridAugmentationsRecordDataset\n", + "import torchvision.transforms as Tfms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60d54941", + "metadata": {}, + "outputs": [], + "source": [ + "detection_train_transforms = tfms.A.Adapter(\n", + " [\n", + " # tfms.A.Normalize(), # NOTE: Normalizing happens inside the `Dataset` itself\n", + " tfms.A.Resize(height=IMG_SIZE, width=IMG_SIZE),\n", + " tfms.A.RandomSizedBBoxSafeCrop(\n", + " width=IMG_SIZE, height=IMG_SIZE, erosion_rate=0.2\n", + " ),\n", + " # tfms.A.PadIfNeeded(IMG_HEIGHT, IMG_WIDTH, border_mode=cv2.BORDER_CONSTANT),\n", + " tfms.A.ChannelDropout(p=0.05),\n", + " tfms.A.HorizontalFlip(p=0.5),\n", + " tfms.A.VerticalFlip(p=0.2),\n", + " tfms.A.ColorJitter(p=0.3), # This may destroy some information for lighting\n", + " tfms.A.JpegCompression(p=0.1),\n", + " ]\n", + ")\n", + "\n", + "valid_transforms = tfms.A.Adapter(\n", + " [\n", + " tfms.A.Normalize(),\n", + " tfms.A.Resize(height=IMG_SIZE, width=IMG_SIZE),\n", + " ]\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99c68ac0", + "metadata": {}, + "outputs": [], + "source": [ + "classification_tfms = dict(\n", + " group_1=dict(\n", + " tasks=[\"lighting\"],\n", + " transforms=Tfms.Compose(\n", + " [\n", + " Tfms.RandomPerspective(),\n", + " Tfms.Resize((IMG_SIZE, IMG_SIZE)),\n", + " Tfms.RandomHorizontalFlip(),\n", + " Tfms.RandomVerticalFlip(),\n", + " Tfms.RandomAffine(degrees=20),\n", + " Tfms.RandomAutocontrast(),\n", + " ]\n", + " )\n", + " ),\n", + " group_2=dict(\n", + " tasks=[\"location\"],\n", + " transforms=Tfms.Compose(\n", + " [\n", + " Tfms.RandomPerspective(),\n", + " Tfms.Resize((IMG_SIZE, IMG_SIZE)),\n", + " Tfms.RandomHorizontalFlip(),\n", + " Tfms.RandomVerticalFlip(),\n", + " Tfms.RandomAffine(degrees=20),\n", + " Tfms.RandomAutocontrast(),\n", + " Tfms.RandomChoice(\n", + " [Tfms.ColorJitter(), Tfms.RandomGrayscale(), Tfms.RandomEqualize()]\n", + " ),\n", + " ]\n", + " )\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f64157d", + "metadata": {}, + "outputs": [], + "source": [ + "train_ds = HybridAugmentationsRecordDataset(\n", + " records=train_records,\n", + " classification_transforms_groups=classification_tfms,\n", + " detection_transforms=detection_train_transforms,\n", + ")\n", + "valid_ds = Dataset(valid_records, tfm=valid_transforms)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8257640d", + "metadata": {}, + "outputs": [], + "source": [ + "train_ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2659baa6", + "metadata": {}, + "outputs": [], + "source": [ + "train_ds[2]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "119791cf", + "metadata": {}, + "outputs": [], + "source": [ + "valid_ds[0]" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "c0064494", + "metadata": {}, + "outputs": [], + "source": [ + "dl_train = train_dl_multi_aug(train_ds, classification_tfms, batch_size=8)\n", + "dl_valid = valid_dl(valid_ds, batch_size=8)" + ] + }, + { + "cell_type": "markdown", + "id": "3840e81e", + "metadata": {}, + "source": [ + "### Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cac65c58", + "metadata": {}, + "outputs": [], + "source": [ + "hybrid_model = model(\n", + " backbone=backbones.small(pretrained=True),\n", + " # backbone=backbones.large(pretrained=True),\n", + " num_detection_classes=len(parser.CLASS_MAPS['detection']),\n", + " classifier_configs={\n", + " name: ClassifierConfig(out_classes=len(cm))\n", + " for name, cm in parser.CLASS_MAPS.items() if not name==\"detection\"\n", + " },\n", + " img_size=IMG_SIZE,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25f8d008", + "metadata": {}, + "outputs": [], + "source": [ + "from torch import optim\n", + "import pytorch_lightning as pl\n", + "\n", + "class LightModel(lightning.HybridYOLOV5LightningAdapter):\n", + " def configure_optimizers(self):\n", + " return optim.Adam(self.parameters(), lr=1e-4)\n", + "\n", + "pl_model = LightModel(\n", + " model=hybrid_model,\n", + " metrics=[COCOMetric(metric_type=COCOMetricType.bbox)],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "317995e3", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=20, gpus=[0])\n", + "trainer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eeb20f86", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(pl_model, dl_train, dl_valid)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad93fcdd", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "ce718acc", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1d7c7e9", + "metadata": {}, + "outputs": [], + "source": [ + "from icevision.models.multitask.ultralytics.yolov5.prediction import *" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5ff993a", + "metadata": {}, + "outputs": [], + "source": [ + "valid_ds = Dataset(valid_records[:20], tfm=valid_tfms)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ed84e62", + "metadata": {}, + "outputs": [], + "source": [ + "preds = predict(\n", + " model=pl_model.model,\n", + " dataset=valid_ds,\n", + " detection_threshold=0.4,\n", + " keep_images=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0782f8f9", + "metadata": {}, + "outputs": [], + "source": [ + "draw_sample = partial(draw_sample, denormalize_fn=denormalize_imagenet, return_as_pil_img=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f2a8ec7", + "metadata": {}, + "outputs": [], + "source": [ + "import fastcore.all as fastcore\n", + "import PIL\n", + "import PIL.Image\n", + "\n", + "@fastcore.patch\n", + "def __or__(self: PIL.Image.Image, other: PIL.Image.Image):\n", + " \"Horizontally stack two PIL Images\"\n", + " assert isinstance(other, PIL.Image.Image)\n", + " widths, heights = zip(*(i.size for i in [self, other]))\n", + "\n", + " new_img = PIL.Image.new(\"RGB\", (sum(widths), max(heights)))\n", + " x_offset = 0\n", + " for img in [self, other]:\n", + " new_img.paste(img, (x_offset, 0))\n", + " x_offset += 
img.size[0]\n", + " return new_img" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f826796", + "metadata": {}, + "outputs": [], + "source": [ + "pred = preds[19]\n", + "p, gt = pred.pred, pred.ground_truth\n", + "\n", + "draw_sample(gt) | draw_sample(p)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7c50992", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/setup.cfg b/setup.cfg index 0edd535a5..6789f67e0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,6 +37,7 @@ install_requires = loguru >=0.5.3 pillow > 8.0.0 importlib-metadata>=1;python_version<"3.8" + addict [options.extras_require] all = diff --git a/tests/multitask/__init__.py b/tests/multitask/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/multitask/ultralytics/__init__.py b/tests/multitask/ultralytics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/multitask/ultralytics/yolov5/__init__.py b/tests/multitask/ultralytics/yolov5/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/multitask/ultralytics/yolov5/test_yolo_hybrid.py b/tests/multitask/ultralytics/yolov5/test_yolo_hybrid.py new file mode 100644 index 000000000..f4314cedf --- /dev/null +++ b/tests/multitask/ultralytics/yolov5/test_yolo_hybrid.py @@ -0,0 +1,48 @@ +# from numpy.lib.arraysetops import isin +# import pytest +# from icevision.imports import * +# from icevision.models.multitask.ultralytics.yolov5.yolo_hybrid import * +# from icevision.models.multitask.utils import * + + +# @pytest.fixture +# def model(): +# return HybridYOLOV5( +# cfg="models/yolov5m.yaml", +# classifier_configs=dict( +# framing=ClassifierConfig(out_classes=10, num_fpn_features=10), +# saturation=ClassifierConfig(out_classes=20, num_fpn_features=None), +# ), +# ) + + +# def x(): +# return torch.rand(1, 3, 224, 224) + + +# def test_forward(model, x): +# det_out, clf_out = model.forward_once(x) +# assert isinstance(det_out, TensorList) +# assert isinstance(clf_out, TensorDict) +# assert det_out[0].ndim == 5 + + +# def test_forward_eval(model, x): +# det_out, clf_out = model.forward_once(x) + +# assert len(det_out == 2) +# assert isinstance(det_out[0], Tensor) +# assert isinstance(det_out[1], TensorList) + + +# def test_feature_extraction(model, x): +# det_out, clf_out = model.forward_once( +# forward_detection=False, forward_classification=False +# ) +# assert det_out[0].ndim == 3 +# assert clf_out == {} + + +# def test_fwd_inference(model, x): +# det_out, clf_out = 
model.forward_once(x, activate_classification=True) +# assert torch.allclose(clf_out["framing"].sum(), tensor(1.0))