Skip to content

Latest commit

 

History

History
231 lines (164 loc) · 4.9 KB

customize_sot_model.md

File metadata and controls

231 lines (164 loc) · 4.9 KB

Customize SOT Models

We basically categorize model components into 4 types.

  • backbone: usually an FCN network to extract feature maps, e.g., ResNet, MobileNet.
  • neck: the component between backbones and heads, e.g., ChannelMapper, FPN.
  • head: the component for specific tasks, e.g., tracking bbox prediction.
  • loss: the component in head for calculating losses, e.g., FocalLoss, L1Loss.

Add a new backbones

Here we show how to develop new components with an example of MobileNet.

1. Define a new backbone (e.g. MobileNet)

Create a new file mmtrack/models/backbones/mobilenet.py.

import torch.nn as nn

from mmdet.models.builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

    def init_weights(self, pretrained=None):
        pass

2. Import the module

You can either add the following line to mmtrack/models/backbones/__init__.py

from .mobilenet import MobileNet

or alternatively add

custom_imports = dict(
    imports=['mmtrack.models.backbones.mobilenet'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

3. Use the backbone in your config file

model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...

Add a new neck

1. Define a neck (e.g. MyFPN)

Create a new file mmtrack/models/necks/my_fpn.py.

from mmdet.models.builder import NECKS

@NECKS.register_module()
class MyFPN(nn.Module):

    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

2. Import the module

You can either add the following line to mmtrack/models/necks/__init__.py,

from .my_fpn import MyFPN

or alternatively add

custom_imports = dict(
    imports=['mmtrack.models.necks.my_fpn.py'],
    allow_failed_imports=False)

to the config file and avoid modifying the original code.

3. Modify the config file

neck=dict(
    type='MyFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)

Add a new head

1. Define a head (e.g. MyHead)

Create a new file mmtrack/models/track_heads/my_head.py.

from mmdet.models import HEADS

@HEADS.register_module()
class MyHead(nn.Module):

    def __init__(self,
                arg1,
                arg2):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

2. Import the module

You can either add the following line to mmtrack/models/track_heads/__init__.py,

from .my_head import MyHead

or alternatively add

custom_imports = dict(
    imports=['mmtrack.models.track_heads.my_head.py'],
    allow_failed_imports=False)

to the config file and avoid modifying the original code.

3. Modify the config file

head=dict(
    type='MyHead',
    arg1=xxx,
    arg2=xxx)

Add a new loss

Assume you want to add a new loss as MyLoss, for bounding box regression. To add a new loss function, the users need implement it in mmtrack/models/losses/my_loss.py. The decorator weighted_loss enable the loss to be weighted for each element.

import torch
import torch.nn as nn

from mmdet.models import LOSSES, weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

Then the users need to add it in the mmtrack/models/losses/__init__.py.

from .my_loss import MyLoss, my_loss

Alternatively, you can add

custom_imports=dict(
    imports=['mmtrack.models.losses.my_loss'])

to the config file and achieve the same goal.

To use it, modify the loss_xxx field. Since MyLoss is for regression, you need to modify the loss_bbox field in the head.

loss_bbox=dict(type='MyLoss', loss_weight=1.0))