[WIP] Multitask Training With mmdet and yolov5 Models #856

Draft · wants to merge 122 commits into master

Commits (122)
662feb5
add classifier heads
rsomani95 May 31, 2021
a6048cb
add multi augmentation dataset
rsomani95 May 31, 2021
22f7846
add `flatten` utility
rsomani95 May 31, 2021
6275472
hybrid single stage detector, dataloader, prediction
rsomani95 May 31, 2021
ce033d5
lightning adapter
rsomani95 May 31, 2021
0dabe85
add loss weight param
rsomani95 Jun 5, 2021
622e57f
return activated preds in eval mode; import `ClassifierConfig`
rsomani95 Jun 12, 2021
7447244
use `ClassifierConfig`s, doc improvements
rsomani95 Jun 12, 2021
eeee30a
add experimental onnx forward method
rsomani95 Jun 12, 2021
db46a49
rename for consistency with library
rsomani95 Jun 13, 2021
3dcfe4e
placeholders
rsomani95 Jun 13, 2021
9bbe4b3
implement hybrid yolov5
rsomani95 Jun 13, 2021
e56fa0e
add sample usage in docstring
rsomani95 Jun 13, 2021
bcc0256
basic dataloader
rsomani95 Jun 13, 2021
0a76855
fix error
rsomani95 Jun 13, 2021
5deddd8
variable names
rsomani95 Jun 13, 2021
c72acfd
fix
rsomani95 Jun 13, 2021
2bcb094
multi aug dataloader, fix typo
rsomani95 Jun 13, 2021
01ea835
add links to discord where relevant
rsomani95 Jun 13, 2021
6f4da15
keep code consistent
rsomani95 Jun 14, 2021
c7bdfa2
create `utils` module; move common code there
rsomani95 Jun 14, 2021
2e1e2d4
update
rsomani95 Jun 14, 2021
35e0cec
rename with lib naming scheme
rsomani95 Jun 14, 2021
4ce890f
fix imports, docs
rsomani95 Jun 14, 2021
a7b374c
forward api (wip)
rsomani95 Jun 14, 2021
8a8704d
fix return format, add docs for multi-aug dataloader
rsomani95 Jun 14, 2021
516edcc
add icevision style model API
rsomani95 Jun 14, 2021
1a8857f
lightning adapter w/ train step (TODO val)
rsomani95 Jun 14, 2021
73796a5
reorganise module to mimic the rest of the library
rsomani95 Jun 14, 2021
4ad3c08
documentation.
rsomani95 Jun 14, 2021
eac9ed4
move dataloading util to common module
rsomani95 Jun 15, 2021
0817f07
add `addict` safe Dictionary
rsomani95 Jun 15, 2021
7f66f71
remove accidental import
rsomani95 Jun 15, 2021
d20ed45
doc, type anno
rsomani95 Jun 15, 2021
d40edae
higher level dataloading functions
rsomani95 Jun 15, 2021
2acc516
TODO val step
rsomani95 Jun 15, 2021
d73d418
-___-
rsomani95 Jun 15, 2021
7bc3604
add doc for high level `model` creator
rsomani95 Jun 15, 2021
08bc1be
fix bug where I forgot to enter the dict key, making the output a `se…
rsomani95 Jun 15, 2021
83f1768
return same outputs regardless of mode;
rsomani95 Jun 15, 2021
3fb8a4b
correct tuple unpacking
rsomani95 Jun 15, 2021
213de99
add `forward_export`, different `step_type`s for exporting
rsomani95 Jun 15, 2021
af7cc05
WIP notebook - move to GPU and re-run.
rsomani95 Jun 15, 2021
00aec88
create common prediction utils for classification
rsomani95 Jun 15, 2021
26b6b30
add `unroll_dict`
rsomani95 Jun 15, 2021
d092e98
add yolov5 multitask raw predictions converter
rsomani95 Jun 15, 2021
b50309b
rename `forward_export` -> `forward_eval`; minor changes
rsomani95 Jun 15, 2021
63385d8
add higher level pred funs to yolov5... will this work?
rsomani95 Jun 15, 2021
648cfe7
add validation code. lets gooo
rsomani95 Jun 15, 2021
4dafd91
revert to common fwd method for train/eval mode;
rsomani95 Jun 15, 2021
c4204ae
add todos
rsomani95 Jun 15, 2021
6e0e99f
add classification metrics
rsomani95 Jun 15, 2021
71056ae
bugfix
rsomani95 Jun 15, 2021
08895b3
forgot to log metrics....
rsomani95 Jun 15, 2021
7017661
bugfixxxeessss
rsomani95 Jun 15, 2021
6726b85
successful training example with lightning
rsomani95 Jun 15, 2021
3f62a73
minor polishing
rsomani95 Jun 15, 2021
8eb9cff
add tensortuple dtype
rsomani95 Jun 16, 2021
5d12378
* modularise `forward` to skip classif / detection specific parts of …
rsomani95 Jun 16, 2021
a09f981
multi aug forward for yolov5
rsomani95 Jun 16, 2021
d06fa63
minor __repr__ bugfix
rsomani95 Jun 16, 2021
01231ea
properly unpack multi aug data
rsomani95 Jun 16, 2021
472e128
update w/ multi aug example
rsomani95 Jun 16, 2021
52e27b2
bugfix
rsomani95 Jun 17, 2021
32dd165
simplify forward method
rsomani95 Jun 17, 2021
dcbe55c
param freezing scheme
rsomani95 Jun 17, 2021
c9d310d
simplify forward modes
rsomani95 Jun 18, 2021
620ee51
add `extrace_features`; minor cleanup
rsomani95 Jun 18, 2021
8033f51
super awkard test scaffolding
rsomani95 Jun 18, 2021
3195385
rename `build_classifier_heads` -> `build_classification_modules`
rsomani95 Jun 18, 2021
3d820ba
move classifiers init location;
rsomani95 Jun 18, 2021
8993694
modularise `forward_inference`
rsomani95 Jun 18, 2021
0d3e08a
make pooling inputs optional when not using fpn inputs
rsomani95 Jun 19, 2021
054a1cc
move to `arch` folder
rsomani95 Jun 20, 2021
462f85b
add unfreezing; modularise freezing, param groups as pseudo mixins - …
rsomani95 Jun 20, 2021
d5037b5
fix import path
rsomani95 Jun 21, 2021
f271a38
fix import path... again
rsomani95 Jun 21, 2021
ffca2f0
batchnorm freezing
rsomani95 Jun 21, 2021
a63136b
add warning
rsomani95 Jun 21, 2021
8948a0e
store fpn dims as an attribute
rsomani95 Jun 22, 2021
d08d4a6
safety mechanism
rsomani95 Jun 22, 2021
ecd693b
dumb bugfix
rsomani95 Jun 22, 2021
19ea13e
klsvbaolskfbvjklfb WTF
rsomani95 Jun 22, 2021
2a198fb
model unfreezing bugfix
rsomani95 Jun 22, 2021
a65c8be
**hangs head in shame**
rsomani95 Jun 22, 2021
1200f58
freezing interface
rsomani95 Jun 26, 2021
f932b10
rename func
rsomani95 Jun 26, 2021
69e4f82
bugfix
rsomani95 Jun 27, 2021
31dddac
* higher level freeze/unfreeze detector
rsomani95 Jun 27, 2021
6428245
move wts to gpu if available (reqd)
rsomani95 Jun 28, 2021
f3b94f8
fix formatting
rsomani95 Jun 28, 2021
89b4fec
iterate on freezing interface
rsomani95 Jun 28, 2021
a253d58
cast loss func weights to fp32 (double by default)
rsomani95 Jun 28, 2021
b60ee28
try moving away from functional approach to avoid cryptic errors
rsomani95 Jun 28, 2021
18e7737
bugfix
rsomani95 Jun 28, 2021
ee5ef8e
...
rsomani95 Jun 28, 2021
4002c4d
patch
rsomani95 Jun 28, 2021
ca52794
log `total_loss` for easier model checkpointing
rsomani95 Jun 30, 2021
0cdfdcb
mystical bugfix
rsomani95 Jun 30, 2021
a2740a0
flexibility to define custom record loading logic
rsomani95 Jun 30, 2021
1ba47e8
model frezinggg buuugggggfix.. i hope
rsomani95 Jun 30, 2021
f4a59d7
use `load_record` for data validation
rsomani95 Jul 1, 2021
67bf463
typo
rsomani95 Jul 1, 2021
9900a38
extra safe record unloading (experiment)
rsomani95 Jul 3, 2021
8de99b2
generalise to all yolo architectures
rsomani95 Jul 6, 2021
fc52a6f
parametrise `num_bbone_blocks`
rsomani95 Jul 12, 2021
aa4b515
fix subtle bbone block idxs bug
rsomani95 Jul 12, 2021
cfe306b
#$%^&*!!!! really need some tests
rsomani95 Jul 12, 2021
56913dd
add `freeze_neck/fpn`
rsomani95 Jul 24, 2021
fcf2d3d
refactor classification dataloading
rsomani95 Jul 26, 2021
2c34448
some more refactoring. TODO: test
rsomani95 Jul 26, 2021
4561be1
refactor logging methods
rsomani95 Jul 26, 2021
72092cd
add speedup
rsomani95 Jul 27, 2021
718fb52
move stuff around
rsomani95 Jul 27, 2021
c2b958c
remove debug mode
rsomani95 Jul 28, 2021
5b50c4b
temp `im2tensor` patch while we discuss on Discord
rsomani95 Jul 30, 2021
e33f9a1
vastly simplified pipeline
rsomani95 Jul 30, 2021
2539faf
patch `set_img` to include `torch.Tensor`
rsomani95 Jul 30, 2021
6b17cb8
proper tfm dispatching
rsomani95 Jul 30, 2021
4ebc006
cleanup
rsomani95 Jul 31, 2021
8cc5474
remove unused arg
rsomani95 Aug 9, 2021
d14a453
auto calculate fpn dims
rsomani95 Aug 10, 2021
6 changes: 4 additions & 2 deletions icevision/core/record_components.py
@@ -111,14 +111,16 @@ def __init__(self, task=tasks.common):
         super().__init__(task=task)
         self.img = None

-    def set_img(self, img: Union[PIL.Image.Image, np.ndarray]):
-        assert isinstance(img, (PIL.Image.Image, np.ndarray))
+    def set_img(self, img: Union[PIL.Image.Image, np.ndarray, torch.Tensor]):
+        assert isinstance(img, (PIL.Image.Image, np.ndarray, torch.Tensor))
         self.img = img
         if isinstance(img, PIL.Image.Image):
             width, height = img.size  # PIL's `.size` is (width, height)
         elif isinstance(img, np.ndarray):
             height, width, _ = self.img.shape
+        elif isinstance(img, torch.Tensor):
+            _, height, width = self.img.shape
         # this should set on SizeRecordComponent
         self.composite.set_img_size(ImgSize(width=width, height=height), original=True)
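For reference, the three accepted input types expose their dimensions under different conventions, which is what the hunk above dispatches on. A minimal standalone sketch (not part of the diff):

import numpy as np
import PIL.Image
import torch

pil_img = PIL.Image.new("RGB", (224, 160))        # PIL stores (width, height)
np_img = np.zeros((160, 224, 3), dtype=np.uint8)  # numpy images are HWC
tensor_img = torch.zeros(3, 160, 224)             # torch tensors are CHW

print(pil_img.size)      # (224, 160) -> width, height
print(np_img.shape)      # (160, 224, 3) -> height, width, channels
print(tensor_img.shape)  # torch.Size([3, 160, 224]) -> channels, height, width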
13 changes: 12 additions & 1 deletion icevision/imports.py
@@ -51,7 +51,7 @@
     CosineAnnealingWarmRestarts,
 )

-from torchvision.transforms.functional import to_tensor as im2tensor
+from torchvision.transforms.functional import to_tensor

 from loguru import logger

@@ -92,3 +92,14 @@ def __str__(self):

     def __repr__(self):
         return str(self)
+
+
+def im2tensor(pic: Union[np.ndarray, PIL.Image.Image, torch.Tensor]):
+    if isinstance(pic, torch.Tensor):
+        return pic
+    elif isinstance(pic, (np.ndarray, PIL.Image.Image)):
+        return to_tensor(pic)
+    else:
+        raise TypeError(
+            f"Expected {np.ndarray} | {PIL.Image.Image} | {torch.Tensor}, got {type(pic)}"
+        )
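A quick usage sketch of the new dispatching `im2tensor`, assuming the helper above is importable (e.g. via `from icevision.imports import im2tensor`):

import numpy as np
import PIL.Image
import torch

img_np = np.zeros((160, 224, 3), dtype=np.uint8)
img_pil = PIL.Image.fromarray(img_np)
img_t = torch.rand(3, 160, 224)

assert im2tensor(img_np).shape == (3, 160, 224)   # HWC uint8 -> CHW float in [0, 1]
assert im2tensor(img_pil).shape == (3, 160, 224)  # PIL goes through `to_tensor` too
assert im2tensor(img_t) is img_t                  # tensors pass through unchanged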
2 changes: 2 additions & 0 deletions icevision/models/multitask/classification_heads/__init__.py
@@ -0,0 +1,2 @@
from .builder import *
from .head import *
47 changes: 47 additions & 0 deletions icevision/models/multitask/classification_heads/builder.py
@@ -0,0 +1,47 @@
from typing import Dict
from .head import CLASSIFICATION_HEADS, ImageClassificationHead, ClassifierConfig
import torch.nn as nn

__all__ = ["build_classifier_heads", "build_classifier_heads_from_configs"]


def build_classifier_heads(configs: Dict[str, dict]) -> nn.ModuleDict:
    """
    Build classification heads from `configs`, a dict of dicts mapping head
    names to config dicts. A head is created for each key in the input dictionary.

    Expected to be used with `mmdet` models, as it uses the
    `CLASSIFICATION_HEADS` registry internally.

    Returns:
        an `nn.ModuleDict` mapping keys from `configs` to classifier heads
    """
    heads = nn.ModuleDict()
    for name, config in configs.items():
        head = CLASSIFICATION_HEADS.build(config)
        heads.update({name: head})
    return heads


def build_classifier_heads_from_configs(
    configs: Dict[str, ClassifierConfig] = None
) -> nn.ModuleDict:
    """
    Build an `nn.ModuleDict` of `ImageClassificationHead`s from a dict of
    `ClassifierConfig`s keyed by head name.
    """
    if configs is None:
        return nn.ModuleDict()

    assert isinstance(configs, dict), f"Expected a `dict`, got {type(configs)}"
    if not all(isinstance(cfg, ClassifierConfig) for cfg in configs.values()):
        raise ValueError(
            "Expected all values in `configs` to be of type `ClassifierConfig`, "
            "but one or more are not"
        )

    heads = nn.ModuleDict()
    for name, config in configs.items():
        head = ImageClassificationHead.from_config(config)
        heads.update({name: head})
    return heads
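A hedged usage sketch of the builder, assuming the `ClassifierConfig` fields defined in `head.py` below; the task names ("shot_composition", "lighting") are hypothetical:

import torch

configs = {
    "shot_composition": ClassifierConfig(out_classes=5, num_fpn_features=512),
    "lighting": ClassifierConfig(out_classes=3, num_fpn_features=512, multilabel=True),
}
heads = build_classifier_heads_from_configs(configs)

features = torch.rand(2, 512, 7, 7)  # (batch, channels, H, W) from a backbone
for name, head in heads.items():
    print(name, head(features).shape)  # -> (2, out_classes) logits per task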
216 changes: 216 additions & 0 deletions icevision/models/multitask/classification_heads/head.py
@@ -0,0 +1,216 @@
# Hacked together by Rahul & Farid

from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Union, Optional, Dict
from torch import Tensor
from dataclasses import dataclass

TensorList = List[Tensor]
TensorDict = Dict[str, Tensor]

MODELS = Registry("models", parent=MMCV_MODELS)
CLASSIFICATION_HEADS = MODELS

__all__ = ["ImageClassificationHead", "ClassifierConfig"]


class Passthrough(nn.Module):
    def forward(self, x):
        return x


"""
`ClassifierConfig` is useful for instantiating `ImageClassificationHead`
in different settings. If using `mmdet`, we don't use this, as the config
is then a regular dictionary.

When using yolov5, we can easily pass this config around to create the model.
Often, it'll be used inside a dictionary of configs.
"""


@dataclass
class ClassifierConfig:
    out_classes: int
    num_fpn_features: int = 512
    fpn_keys: Union[List[str], List[int], None] = None
    dropout: Optional[float] = 0.2
    pool_inputs: bool = True
    # Loss function args
    loss_func: Optional[nn.Module] = None
    activation: Optional[nn.Module] = None
    multilabel: bool = False
    loss_func_wts: Optional[Tensor] = None
    loss_weight: float = 1.0
    # Post activation processing
    thresh: Optional[float] = None
    topk: Optional[int] = None

    def __post_init__(self):
        if isinstance(self.fpn_keys, int):
            self.fpn_keys = [self.fpn_keys]

        if self.loss_func_wts is not None:
            if not self.multilabel:
                self.loss_func_wts = self.loss_func_wts.to(torch.float32)
            if torch.cuda.is_available():
                self.loss_func_wts = self.loss_func_wts.cuda()

        if self.multilabel:
            if self.topk is None and self.thresh is None:
                self.thresh = 0.5
        else:
            if self.topk is None and self.thresh is None:
                self.topk = 1


@CLASSIFICATION_HEADS.register_module(name="ImageClassificationHead")
class ImageClassificationHead(nn.Module):
    """
    Image classification head that optionally takes `fpn_keys` features from
    an FPN, average pools and concatenates them into a single tensor of
    `num_fpn_features` total features, then runs a linear layer to `out_classes`:

        fpn_features: List[Tensor] => AvgPool => Flatten => Linear

    Also includes `compute_loss` to match the design of other
    components of object detection systems.
    To use your own loss function, pass it into `loss_func`.
    If `loss_func` is None (the default), we create one based on the other args:
    if `multilabel` is True, one-hot encoded targets are expected and
    `nn.BCEWithLogitsLoss` is used; else `nn.CrossEntropyLoss` is used
    and targets are expected to be integers.
    NOTE: Not all loss function args are exposed
    """

    def __init__(
        self,
        out_classes: int,
        num_fpn_features: int,
        fpn_keys: Union[List[str], List[int], None] = None,
        dropout: Optional[float] = 0.2,
        pool_inputs: bool = True,  # ONLY for advanced use cases where input feature maps are already pooled
        # Loss function args
        loss_func: Optional[nn.Module] = None,
        activation: Optional[nn.Module] = None,
        multilabel: bool = False,
        loss_func_wts: Optional[Tensor] = None,
        loss_weight: float = 1.0,
        # Final postprocessing args
        thresh: Optional[float] = None,
        topk: Optional[int] = None,
    ):
        super().__init__()

        # Setup loss function & activation
        self.multilabel = multilabel
        self.loss_func, self.loss_func_wts, self.loss_weight = (
            loss_func,
            loss_func_wts,
            loss_weight,
        )
        self.activation = activation
        self.pool_inputs = pool_inputs
        self.thresh, self.topk = thresh, topk

        # Setup head
        self.fpn_keys = fpn_keys

        layers = [
            nn.Dropout(dropout) if dropout else Passthrough(),
            nn.Linear(num_fpn_features, out_classes),
        ]
        if self.pool_inputs:
            layers.insert(0, nn.Flatten(1))
        self.classifier = nn.Sequential(*layers)

        self.setup_loss_function()
        self.setup_postprocessing()

    def setup_postprocessing(self):
        if self.multilabel:
            if self.topk is None and self.thresh is None:
                self.thresh = 0.5
        else:
            if self.topk is None and self.thresh is None:
                self.topk = 1

    def setup_loss_function(self):
        if self.loss_func is None:
            if self.multilabel:
                self.loss_func = nn.BCEWithLogitsLoss(pos_weight=self.loss_func_wts)
                self.activation = nn.Sigmoid()
            else:
                self.loss_func = nn.CrossEntropyLoss(weight=self.loss_func_wts)
                self.activation = nn.Softmax(-1)

    @classmethod
    def from_config(cls, config: ClassifierConfig):
        return cls(**config.__dict__)

    # TODO: Make it run with regular features as well
    def forward(self, features: Union[Tensor, TensorDict, TensorList]):
        """
        Sequence of outputs from an FPN or regular feature extractor
        => Avg. Pool each into 1 dimension
        => Concatenate into a single tensor
        => Linear layer -> output classes

        If `self.fpn_keys` is specified, the specific (int|str) indices are grabbed
        from `features` for the pooling layer; else _all_ of them are used
        """
        if isinstance(features, (list, dict, tuple)):
            # Grab specific features if specified
            if self.fpn_keys is not None:
                pooled_features = [
                    F.adaptive_avg_pool2d(features[k], 1) for k in self.fpn_keys
                ]
            # If no `fpn_keys` exist, pool & concat all the feature maps (could be expensive)
            else:
                # NOTE: iterate `.values()` for dicts; iterating a dict directly yields keys
                feats = features.values() if isinstance(features, dict) else features
                pooled_features = [F.adaptive_avg_pool2d(feat, 1) for feat in feats]
            pooled_features = torch.cat(pooled_features, dim=1)

        # If doing regular (non-FPN) feature extraction, we don't need `fpn_keys` and
        # just avg. pool the last layer's features
        elif isinstance(features, Tensor):
            pooled_features = (
                F.adaptive_avg_pool2d(features, 1) if self.pool_inputs else features
            )
        else:
            raise TypeError(
                f"Expected TensorList|TensorDict|Tensor|tuple, got {type(features)}"
            )

        return self.classifier(pooled_features)

    # TorchVision style API
    def compute_loss(self, predictions, targets):
        return self.loss_weight * self.loss_func(predictions, targets)

    def postprocess(self, predictions):
        return self.activation(predictions)

    # MMDet style API
    def forward_train(self, x, gt_label) -> Tensor:
        preds = self(x)
        return self.loss_weight * self.loss_func(preds, gt_label)

    def forward_activate(self, x):
        "Run forward pass with activation function"
        x = self(x)
        return self.activation(x)
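A minimal end-to-end sketch of the head's TorchVision-style API; the FPN level names and feature shapes are illustrative assumptions:

import torch

head = ImageClassificationHead(
    out_classes=4,
    num_fpn_features=256 + 256,  # channels from the two selected FPN levels
    fpn_keys=["p4", "p5"],
)

fpn_features = {
    "p3": torch.rand(2, 256, 28, 28),
    "p4": torch.rand(2, 256, 14, 14),
    "p5": torch.rand(2, 256, 7, 7),
}
logits = head(fpn_features)                # pools p4 & p5, concats -> (2, 4) logits
targets = torch.tensor([0, 3])             # integer targets (not multilabel)
loss = head.compute_loss(logits, targets)  # CrossEntropyLoss scaled by loss_weight
probs = head.postprocess(logits)           # softmax over the last dimension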
71 changes: 71 additions & 0 deletions icevision/models/multitask/data/dataloading_utils.py
@@ -0,0 +1,71 @@
"""
This may be a temporary file that may eventually be removed,
as it only slightly modifies an existing function.
"""

__all__ = [
    "unload_records",
    "assign_classification_targets_from_record",
    "massage_multi_aug_classification_data",
]


import torch
from icevision.core.record_type import RecordType
from typing import Any, Dict, Optional, Callable, Sequence, Tuple
from icevision.core.record_components import ClassificationLabelsRecordComponent
from torch import tensor


def unload_records(
    build_batch: Callable, build_batch_kwargs: Optional[Dict] = None
) -> Callable[[Sequence[RecordType]], Tuple[Tuple[Any, ...], Sequence[RecordType]]]:
    """
    Wraps `build_batch` so that records are unloaded after batch creation,
    to avoid carrying them around. Also optionally accepts `build_batch_kwargs`
    that are passed into `build_batch`. These aren't accepted as keyword
    arguments, because those are reserved for PyTorch's `DataLoader` class,
    which is used later in this chain of function calls.

    Args:
        build_batch (Callable): A collate function that describes how to mash records
            into a batch of inputs for a model
        build_batch_kwargs (Optional[Dict], optional): Keyword arguments to pass into
            `build_batch`. Defaults to None.

    Returns:
        A collate function that returns the batched inputs along with the
        (now unloaded) records
    """
    build_batch_kwargs = build_batch_kwargs or {}
    assert isinstance(build_batch_kwargs, dict)

    def inner(records):
        tupled_output, records = build_batch(records, **build_batch_kwargs)
        for record in records:
            record.unload()
        return tupled_output, records

    return inner


def assign_classification_targets_from_record(classification_labels: dict, record):
    for comp in record.components:
        name = comp.task.name
        if isinstance(comp, ClassificationLabelsRecordComponent):
            if comp.is_multilabel:
                labels = comp.one_hot_encoded()
                classification_labels[name].append(labels)
            else:
                labels = comp.label_ids
                classification_labels[name].extend(labels)


def massage_multi_aug_classification_data(
    classification_data, classification_targets, target_key: str
):
    for group in classification_data.values():
        group[target_key] = {
            task: tensor(classification_targets[task]) for task in group["tasks"]
        }
        group["images"] = torch.stack(group["images"])

    return {k: dict(v) for k, v in classification_data.items()}
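A sketch of how `unload_records` is meant to wrap a collate function before handing it to a `DataLoader`; the `DummyRecord` and toy `build_batch` are stand-ins for illustration, not the library's own:

import torch

class DummyRecord:
    """Stand-in for an icevision record (hypothetical, for illustration)."""
    def __init__(self):
        self.img = torch.rand(3, 64, 64)
    def unload(self):
        self.img = None  # free the image once the batch is built

def build_batch(records, pad_value=0):
    # Toy collate fn: stack images; real ones also build detection targets
    images = torch.stack([r.img for r in records])
    return (images,), records

collate_fn = unload_records(build_batch, build_batch_kwargs={"pad_value": 0})
(batch,), records = collate_fn([DummyRecord() for _ in range(4)])
assert records[0].img is None  # records were unloaded after batching
# In practice: DataLoader(dataset, batch_size=8, collate_fn=collate_fn)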