Update QDTrack + YOLOX (#131)
* Update zoo

* Clean up code, update zoo

* Updates to qdtrack configs and yolox

* Fix lint

* Fix qdtrack inference test

* Update test
thomasehuang authored Nov 29, 2023
1 parent c1c8157 commit a631099
Showing 12 changed files with 152 additions and 160 deletions.
13 changes: 8 additions & 5 deletions tests/model/track/qdtrack_test.py
@@ -14,9 +14,9 @@
from vis4d.data.transforms.pad import PadImages
from vis4d.data.transforms.resize import GenResizeParameters, ResizeImages
from vis4d.data.transforms.to_tensor import ToTensor
from vis4d.model.adapter import ModelExpEMAAdapter
from vis4d.model.track.qdtrack import (
REV_KEYS,
YOLOX_REV_KEYS,
FasterRCNNQDTrack,
TrackOut,
YOLOXQDTrack,
@@ -87,10 +87,13 @@ def test_inference_yolox(self):
"""Inference test for YOLOX QDTrack."""
TrackIDCounter.reset() # reset track ID counter
model_weights = (
"https://dl.cv.ethz.ch/vis4d/qdtrack-yolox-ema_bdd100k.ckpt"
"https://dl.cv.ethz.ch/vis4d/bdd100k/qdtrack/"
"qdtrack_yolox_x_25e_bdd100k/qdtrack_yolox_x_25e_bdd100k_c14af2.pt"
)
qdtrack = ModelExpEMAAdapter(YOLOXQDTrack(num_classes=8))
load_model_checkpoint(
qdtrack, model_weights, rev_keys=[("^model.", "")]
)
qdtrack = YOLOXQDTrack(num_classes=8)
load_model_checkpoint(qdtrack, model_weights, rev_keys=YOLOX_REV_KEYS)
qdtrack.eval()

data_root = osp.join(get_test_data("bdd100k_test"), "track/images")
@@ -133,4 +136,4 @@ def test_inference_yolox(self):
for pred, expected in zip(pred_entry, expected_entry):
print("PREDICTION:", pred.shape, pred)
print("EXPECTED:", expected.shape, expected)
assert torch.isclose(pred, expected, atol=1e-4).all().item()
assert torch.isclose(pred, expected, atol=1e-2).all().item()
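The updated test relaxes the comparison tolerance from 1e-4 to 1e-2. As a minimal illustration (with made-up tensors, not the stored testcase data), the following sketch shows how `torch.isclose` behaves under both tolerances:

```python
import torch

# Hypothetical stand-ins for a prediction and its stored reference output.
pred = torch.tensor([100.003, 200.001, 0.9912])
expected = torch.tensor([100.000, 200.000, 0.9910])

# Passes under the relaxed tolerance used after the checkpoint update ...
assert torch.isclose(pred, expected, atol=1e-2).all().item()
# ... but would fail under the previous, tighter tolerance of 1e-4.
assert not torch.isclose(pred, expected, atol=1e-4).all().item()
```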
Binary file modified tests/model/track/testcases/qdtrack-yolox.pt
3 changes: 1 addition & 2 deletions vis4d/config/common/models/qdtrack.py
@@ -37,6 +37,7 @@
CONN_BBOX_2D_TRAIN = {
"images": K.images,
"images_hw": K.input_hw,
"original_hw": K.original_hw,
"frame_ids": K.frame_ids,
"boxes2d": K.boxes2d,
"boxes2d_classes": K.boxes2d_classes,
@@ -67,8 +68,6 @@
"ref_track_ids": pred_key("ref_track_ids"),
}

CONN_BBOX_2D_YOLOX_TRAIN = {**CONN_BBOX_2D_TRAIN, "original_hw": K.original_hw}

CONN_YOLOX_LOSS_2D = {
"cls_outs": pred_key(f"{PRED_PREFIX}.cls_score"),
"reg_outs": pred_key(f"{PRED_PREFIX}.bbox_pred"),
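These `CONN_*` dicts map a model's input argument names to keys in the data batch (or to prediction keys). With `original_hw` folded into `CONN_BBOX_2D_TRAIN`, the separate `CONN_BBOX_2D_YOLOX_TRAIN` becomes redundant and is removed. Below is a minimal sketch of how such a mapping could be resolved, purely for illustration; it is not the actual vis4d data-connector implementation, and the batch values are dummies:

```python
from typing import Any


def connect(connection: dict[str, str], batch: dict[str, Any]) -> dict[str, Any]:
    """Rename batch entries to the argument names a model expects."""
    return {arg: batch[key] for arg, key in connection.items()}


# Key names mirror CONN_BBOX_2D_TRAIN above; the values are placeholders.
conn_bbox_2d_train = {
    "images": "images",
    "images_hw": "input_hw",
    "original_hw": "original_hw",  # now part of the shared train connector
    "frame_ids": "frame_ids",
    "boxes2d": "boxes2d",
    "boxes2d_classes": "boxes2d_classes",
}
batch = {
    "images": "tensor[N,C,H,W]",
    "input_hw": [(720, 1280)],
    "original_hw": [(720, 1280)],
    "frame_ids": [0],
    "boxes2d": "tensor[M,4]",
    "boxes2d_classes": "tensor[M]",
}
model_kwargs = connect(conn_bbox_2d_train, batch)
assert "original_hw" in model_kwargs
```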
8 changes: 5 additions & 3 deletions vis4d/engine/callbacks/yolox_callbacks.py
Expand Up @@ -18,7 +18,7 @@
get_world_size,
synchronize,
)
from vis4d.common.logging import rank_zero_info
from vis4d.common.logging import rank_zero_info, rank_zero_warn
from vis4d.data.const import CommonKeys as K
from vis4d.data.data_pipe import DataPipe
from vis4d.data.typing import DictDataOrList
@@ -65,13 +65,15 @@ def on_train_epoch_end(
found_loss = True
yolox_loss = loss["loss"]
break
assert found_loss, "YOLOXHeadLoss should be in LossModule."
rank_zero_info(
"Switching YOLOX training mode starting next training epoch "
"(turning off strong augmentations, adding L1 loss, switching to "
"validation every epoch)."
)
yolox_loss.loss_l1 = l1_loss # set L1 loss function
if found_loss:
yolox_loss.loss_l1 = l1_loss # set L1 loss function
else:
rank_zero_warn("YOLOXHeadLoss should be in LossModule.")
# Set data pipeline to default DataPipe to skip strong augs.
# Switch to checking validation every epoch.
dataloader = trainer_state["train_dataloader"]
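The callback change replaces a hard assertion with a warning: the YOLOX mode switch now proceeds even when no `YOLOXHeadLoss` is configured, only skipping the L1-loss injection. A self-contained sketch of that behavioral change, using a stub class rather than the real vis4d loss module:

```python
import warnings


class YOLOXHeadLoss:  # stub standing in for vis4d's real loss class
    def __init__(self) -> None:
        self.loss_l1 = None


def maybe_add_l1(losses: list[dict], l1_loss) -> None:
    """Attach the L1 loss if a YOLOXHeadLoss is configured, else warn."""
    for entry in losses:  # mirrors the truncated loop in the diff above
        if isinstance(entry["loss"], YOLOXHeadLoss):
            entry["loss"].loss_l1 = l1_loss
            return
    # Previously this case raised an AssertionError; now only a warning.
    warnings.warn("YOLOXHeadLoss should be in LossModule.")


maybe_add_l1([{"loss": YOLOXHeadLoss()}], l1_loss=lambda p, t: abs(p - t))
maybe_add_l1([], l1_loss=None)  # warns instead of aborting the switch
```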
31 changes: 17 additions & 14 deletions vis4d/engine/optim/scheduler.py
@@ -33,12 +33,28 @@ def __init__(
self.lr_schedulers_cfg = lr_schedulers_cfg
self.lr_schedulers: dict[int, LRSchedulerDict] = {}
super().__init__(optimizer)

self.steps_per_epoch = steps_per_epoch
self._convert_epochs_to_steps()

for i, lr_scheduler_cfg in enumerate(self.lr_schedulers_cfg):
if lr_scheduler_cfg["begin"] == 0:
self._instantiate_lr_scheduler(i, lr_scheduler_cfg)

def _convert_epochs_to_steps(self) -> None:
"""Convert epochs to steps."""
for lr_scheduler_cfg in self.lr_schedulers_cfg:
if (
lr_scheduler_cfg["convert_epochs_to_steps"]
and not lr_scheduler_cfg["epoch_based"]
):
lr_scheduler_cfg["begin"] *= self.steps_per_epoch
lr_scheduler_cfg["end"] *= self.steps_per_epoch
if lr_scheduler_cfg["convert_attributes"] is not None:
for attr in lr_scheduler_cfg["convert_attributes"]:
lr_scheduler_cfg["scheduler"]["init_args"][
attr
] *= self.steps_per_epoch

def _instantiate_lr_scheduler(
self, scheduler_idx: int, lr_scheduler_cfg: LrSchedulerConfig
) -> None:
@@ -49,19 +65,6 @@ def _instantiate_lr_scheduler(
pg["lr"] for pg in self.optimizer.param_groups
]

# Convert epochs to steps
if (
lr_scheduler_cfg["convert_epochs_to_steps"]
and not lr_scheduler_cfg["epoch_based"]
):
lr_scheduler_cfg["begin"] *= self.steps_per_epoch
lr_scheduler_cfg["end"] *= self.steps_per_epoch
if lr_scheduler_cfg["convert_attributes"] is not None:
for attr in lr_scheduler_cfg["convert_attributes"]:
lr_scheduler_cfg["scheduler"]["init_args"][
attr
] *= self.steps_per_epoch

self.lr_schedulers[scheduler_idx] = {
"scheduler": instantiate_classes(
lr_scheduler_cfg["scheduler"], optimizer=self.optimizer
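Moving the conversion into `__init__` (via `_convert_epochs_to_steps`) means the `begin`/`end` boundaries of every configured scheduler are rewritten to step units once, up front, rather than only when a scheduler is instantiated. A small worked example with illustrative numbers; the `max_steps` attribute is hypothetical:

```python
# Worked example of the epoch-to-step conversion (values are illustrative).
steps_per_epoch = 500

lr_scheduler_cfg = {
    "begin": 1,    # originally given in epochs
    "end": 12,     # originally given in epochs
    "epoch_based": False,
    "convert_epochs_to_steps": True,
    "convert_attributes": ["max_steps"],  # hypothetical attribute name
    "scheduler": {"init_args": {"max_steps": 12}},
}

if (
    lr_scheduler_cfg["convert_epochs_to_steps"]
    and not lr_scheduler_cfg["epoch_based"]
):
    lr_scheduler_cfg["begin"] *= steps_per_epoch  # 1 epoch  -> 500 steps
    lr_scheduler_cfg["end"] *= steps_per_epoch    # 12 epochs -> 6000 steps
    for attr in lr_scheduler_cfg["convert_attributes"] or []:
        lr_scheduler_cfg["scheduler"]["init_args"][attr] *= steps_per_epoch

print(lr_scheduler_cfg["begin"], lr_scheduler_cfg["end"])  # 500 6000
```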
27 changes: 1 addition & 26 deletions vis4d/model/track/qdtrack.py
@@ -7,6 +7,7 @@
from torch import Tensor, nn

from vis4d.common.ckpt import load_model_checkpoint
from vis4d.model.detect.yolox import REV_KEYS as YOLOX_REV_KEYS
from vis4d.op.base import BaseModel, CSPDarknet, ResNet
from vis4d.op.box.box2d import scale_and_clip_boxes
from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
@@ -26,37 +27,11 @@
from .util import split_key_ref_indices

REV_KEYS = [
# (r"^detector.rpn_head.mm_dense_head\.", "rpn_head."),
# (r"\.rpn_reg\.", ".rpn_box."),
# (r"^detector.roi_head.mm_roi_head.bbox_head\.", "roi_head."),
# (r"^detector.backbone.mm_backbone\.", "body."),
# (
# r"^detector.backbone.neck.mm_neck.lateral_convs\.",
# "inner_blocks.",
# ),
# (
# r"^detector.backbone.neck.mm_neck.fpn_convs\.",
# "layer_blocks.",
# ),
# (r"\.conv.weight", ".weigh2t"),
# (r"\.conv.bias", ".bias"),
(r"^faster_rcnn_heads\.", "faster_rcnn_head."),
(r"^backbone.body\.", "basemodel."),
(r"^qdtrack\.", "qdtrack_head."),
]

# from old Vis4D checkpoint
YOLOX_REV_KEYS = [
(r"^detector.backbone.mm_backbone\.", "basemodel."),
(r"^bbox_head\.", "yolox_head."),
(r"^detector.backbone.neck.mm_neck\.", "fpn."),
(r"^detector.bbox_head.mm_dense_head\.", "yolox_head."),
(r"^similarity_head\.", "qdtrack_head.similarity_head."),
(r"\.bn\.", ".norm."),
(r"\.conv.weight", ".weight"),
(r"\.conv.bias", ".bias"),
]


class FasterRCNNQDTrackOut(NamedTuple):
"""Output of QDtrack model."""
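For context on the `REV_KEYS` cleanup: each `(pattern, replacement)` pair is a regex rename rule applied to checkpoint state-dict keys so that weights saved under old module names load into the current model layout. A minimal sketch of the idea; the actual logic lives in `vis4d.common.ckpt.load_model_checkpoint`:

```python
import re

# Rename rules as kept in REV_KEYS above.
rev_keys = [
    (r"^faster_rcnn_heads\.", "faster_rcnn_head."),
    (r"^backbone.body\.", "basemodel."),
    (r"^qdtrack\.", "qdtrack_head."),
]


def remap_key(key: str) -> str:
    """Apply all rename rules to a single state-dict key."""
    for pattern, replacement in rev_keys:
        key = re.sub(pattern, replacement, key)
    return key


print(remap_key("backbone.body.layer1.0.conv1.weight"))
# -> "basemodel.layer1.0.conv1.weight"
```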
13 changes: 7 additions & 6 deletions vis4d/zoo/bdd100k/README.md
@@ -90,16 +90,17 @@ The BDD100K dataset contains MOT annotations for 2K videos (1.4K/200/400 for tra

[QDTrack: Quasi-Dense Similarity Learning for Appearance-Only Multiple Object Tracking](https://arxiv.org/abs/2210.06984) [TPAMI, CVPR 2021 Oral]

Authors: [Tobias Fischer](https://tobiasfshr.github.io/), [Thomas E Huang](https://www.thomasehuang.com/), [Jiangmiao Pang](https://scholar.google.com/citations?user=ssSfKpAAAAAJ), [Linlu Qiu](https://linlu-qiu.github.io/), [Haofeng Chen](https://www.haofeng.io/), Qi Li, [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/), [Fisher Yu](https://www.yf.io/)
Authors: [Tobias Fischer*](https://tobiasfshr.github.io/), [Thomas E Huang*](https://www.thomasehuang.com/), [Jiangmiao Pang*](https://scholar.google.com/citations?user=ssSfKpAAAAAJ), [Linlu Qiu](https://linlu-qiu.github.io/), [Haofeng Chen](https://www.haofeng.io/), Qi Li, [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/), [Fisher Yu](https://www.yf.io/)

<details>
<summary>Abstract</summary>
Similarity learning has been recognized as a crucial step for object tracking. However, existing multiple object tracking methods only use sparse ground truth matching as the training objective, while ignoring the majority of the informative regions on the images. In this paper, we present Quasi-Dense Similarity Learning, which densely samples hundreds of region proposals on a pair of images for contrastive learning. We can naturally combine this similarity learning with existing detection methods to build Quasi-Dense Tracking (QDTrack) without turning to displacement regression or motion priors. We also find that the resulting distinctive feature space admits a simple nearest neighbor search at the inference time. Despite its simplicity, QDTrack outperforms all existing methods on MOT, BDD100K, Waymo, and TAO tracking benchmarks. It achieves 68.7 MOTA at 20.3 FPS on MOT17 without using external training data. Compared to methods with similar detectors, it boosts almost 10 points of MOTA and significantly decreases the number of ID switches on BDD100K and Waymo datasets.
Similarity learning has been recognized as a crucial step for object tracking. However, existing multiple object tracking methods only use sparse ground truth matching as the training objective, while ignoring the majority of the informative regions in images. In this paper, we present Quasi-Dense Similarity Learning, which densely samples hundreds of object regions on a pair of images for contrastive learning. We combine this similarity learning with multiple existing object detectors to build Quasi-Dense Tracking (QDTrack), which does not require displacement regression or motion priors. We find that the resulting distinctive feature space admits a simple nearest neighbor search at inference time for object association. In addition, we show that our similarity learning scheme is not limited to video data, but can learn effective instance similarity even from static input, enabling a competitive tracking performance without training on videos or using tracking supervision. We conduct extensive experiments on a wide variety of popular MOT benchmarks. We find that, despite its simplicity, QDTrack rivals the performance of state-of-the-art tracking methods on all benchmarks and sets a new state-of-the-art on the large-scale BDD100K MOT benchmark, while introducing negligible computational overhead to the detector.
</details>

#### Results

| Detector | Base Network | mMOTA-val | mIDF1-val | ID Sw.-val | Scores-val | Config | Weights | Preds | Visuals |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Faster R-CNN | R-50-FPN | | | | [scores]() | [config]() | [model]() | [preds]() | [visuals]() |
| YOLOX-x | CSPNet | | | | [scores]() | [config]() | [model]() | [preds]() | [visuals]() |
| Detector | Base Network | Strong Augs. | mMOTA-val | mIDF1-val | ID Sw.-val | Scores-val | Config | Weights | Preds | Visuals |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Faster R-CNN | R-50-FPN | | 36.1 | 51.8 | 6165 | [scores]() | [config](./qdtrack/qdtrack_frcnn_r50_fpn_1x_bdd100k.py) | [model]() | [preds]() | [visuals]() |
| Faster R-CNN | R-50-FPN | ✓ | 37.7 | 52.7 | 7257 | [scores]() | [config](./qdtrack/qdtrack_frcnn_r50_fpn_augs_1x_bdd100k.py) | [model]() | [preds]() | [visuals]() |
| YOLOX-x | CSPNet | ✓ | 42.3 | 55.1 | 9164 | [scores]() | [config](./qdtrack/qdtrack_yolox_x_50e_bdd100k.py) | [model]() | [preds]() | [visuals]() |
2 changes: 0 additions & 2 deletions vis4d/zoo/bdd100k/__init__.py
@@ -7,7 +7,6 @@
)
from .qdtrack import (
qdtrack_frcnn_r50_fpn_1x_bdd100k,
qdtrack_yolox_s_50e_bdd100k,
qdtrack_yolox_x_50e_bdd100k,
)
from .semantic_fpn import (
@@ -27,6 +26,5 @@
"semantic_fpn_r50_80k_bdd100k": semantic_fpn_r50_80k_bdd100k,
"semantic_fpn_r101_80k_bdd100k": semantic_fpn_r101_80k_bdd100k,
"qdtrack_frcnn_r50_fpn_1x_bdd100k": qdtrack_frcnn_r50_fpn_1x_bdd100k,
"qdtrack_yolox_s_50e_bdd100k": qdtrack_yolox_s_50e_bdd100k,
"qdtrack_yolox_x_50e_bdd100k": qdtrack_yolox_x_50e_bdd100k,
}
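Each entry in the dict above points to a zoo config module, and the YOLOX-s variant is dropped from the registry. A hedged usage sketch, assuming vis4d is installed and that these modules expose `get_config()` as the other configs touched in this commit do:

```python
from vis4d.zoo.bdd100k import qdtrack_yolox_x_50e_bdd100k

config = qdtrack_yolox_x_50e_bdd100k.get_config()
# "qdtrack_yolox_s_50e_bdd100k" is no longer registered in this package.
```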
77 changes: 43 additions & 34 deletions vis4d/zoo/bdd100k/qdtrack/data_yolox.py
@@ -34,6 +34,7 @@
MosaicBoxes2D,
MosaicImages,
)
from vis4d.data.transforms.normalize import NormalizeImages
from vis4d.data.transforms.pad import PadImages
from vis4d.data.transforms.photometric import RandomHSV
from vis4d.data.transforms.post_process import (
@@ -51,6 +52,7 @@
def get_train_dataloader(
data_backend: None | ConfigDict,
image_size: tuple[int, int],
normalize_image: bool,
samples_per_gpu: int,
workers_per_gpu: int,
) -> ConfigDict:
@@ -141,36 +143,36 @@ def get_train_dataloader(
[class_config(PostProcessBoxes2D, min_area=1.0)]
)

batch_transforms = [
class_config(RandomHSV, same_on_batch=False),
class_config(
RandomApply,
transforms=[class_config(FlipImages), class_config(FlipBoxes2D)],
probability=0.5,
same_on_batch=False,
),
class_config(
GenResizeParameters,
shape=image_size,
keep_ratio=True,
scale_range=(0.5, 1.5),
same_on_batch=False,
),
class_config(ResizeImages),
class_config(ResizeBoxes2D),
class_config(GenCropParameters, shape=image_size, same_on_batch=False),
class_config(CropImages),
class_config(CropBoxes2D),
]
if normalize_image:
batch_transforms += [
class_config(NormalizeImages),
class_config(PadImages),
]
else:
batch_transforms += [class_config(PadImages, value=114.0)]
train_batchprocess_cfg = class_config(
compose,
transforms=[
class_config(RandomHSV, same_on_batch=False),
class_config(
RandomApply,
transforms=[
class_config(FlipImages),
class_config(FlipBoxes2D),
],
probability=0.5,
same_on_batch=False,
),
class_config(
GenResizeParameters,
shape=image_size,
keep_ratio=True,
scale_range=(0.5, 1.5),
same_on_batch=False,
),
class_config(ResizeImages),
class_config(ResizeBoxes2D),
class_config(
GenCropParameters, shape=image_size, same_on_batch=False
),
class_config(CropImages),
class_config(CropBoxes2D),
class_config(PadImages, value=114.0),
class_config(ToTensor),
],
compose, transforms=batch_transforms + [class_config(ToTensor)]
)

return class_config(
@@ -192,6 +194,7 @@ def get_train_dataloader(
def get_test_dataloader(
data_backend: None | ConfigDict,
image_size: tuple[int, int],
normalize_image: bool,
samples_per_gpu: int,
workers_per_gpu: int,
) -> ConfigDict:
@@ -218,12 +221,15 @@
compose, transforms=preprocess_transforms
)

if normalize_image:
batch_transforms = [
class_config(NormalizeImages),
class_config(PadImages),
]
else:
batch_transforms = [class_config(PadImages, value=114.0)]
test_batchprocess_cfg = class_config(
compose,
transforms=[
class_config(PadImages, value=114.0),
class_config(ToTensor),
],
compose, transforms=batch_transforms + [class_config(ToTensor)]
)

test_dataset_cfg = class_config(
Expand All @@ -242,6 +248,7 @@ def get_test_dataloader(
def get_bdd100k_track_cfg(
data_backend: None | ConfigDict = None,
image_size: tuple[int, int] = (800, 1440),
normalize_image: bool = False,
samples_per_gpu: int = 2,
workers_per_gpu: int = 2,
) -> DataConfig:
@@ -251,13 +258,15 @@ def get_bdd100k_track_cfg(
data.train_dataloader = get_train_dataloader(
data_backend=data_backend,
image_size=image_size,
normalize_image=normalize_image,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=workers_per_gpu,
)

data.test_dataloader = get_test_dataloader(
data_backend=data_backend,
image_size=image_size,
normalize_image=normalize_image,
samples_per_gpu=1,
workers_per_gpu=1,
)
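Both dataloader builders now take a `normalize_image` flag: with the default `False`, images stay in the 0-255 range and are padded with the YOLOX convention value 114; with `True`, `NormalizeImages` is applied and padding falls back to its default value. A hedged usage sketch; the motivation given for the normalized branch is an assumption, not stated in the diff:

```python
from vis4d.zoo.bdd100k.qdtrack.data_yolox import get_bdd100k_track_cfg

# Default YOLOX-style pipeline: raw 0-255 images, padding value 114.
data_cfg = get_bdd100k_track_cfg(image_size=(800, 1440))

# Normalized pipeline, e.g. for backbones pretrained on normalized inputs
# (assumed use case): NormalizeImages plus default padding.
data_cfg_norm = get_bdd100k_track_cfg(
    image_size=(800, 1440), normalize_image=True
)
```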
10 changes: 4 additions & 6 deletions vis4d/zoo/bdd100k/qdtrack/qdtrack_frcnn_r50_fpn_1x_bdd100k.py
@@ -1,5 +1,5 @@
# pylint: disable=duplicate-code
"""QDTrack-FasterRCNN BDD100K."""
"""QDTrack with Faster R-CNN on BDD100K."""
from __future__ import annotations

import lightning.pytorch as pl
@@ -44,8 +44,8 @@ def get_config() -> ExperimentConfig:

# High level hyper parameters
params = ExperimentParameters()
params.samples_per_gpu = 2
params.workers_per_gpu = 2
params.samples_per_gpu = 4 # batch size = 4 GPUs * 4 samples per GPU = 16
params.workers_per_gpu = 4
params.lr = 0.02
params.num_epochs = 12
config.params = params
@@ -70,9 +70,7 @@ def get_config() -> ExperimentConfig:
)

config.model, config.loss = get_qdtrack_cfg(
num_classes=num_classes,
basemodel=basemodel,
# weights="https://dl.cv.ethz.ch/vis4d/qdtrack_bdd100k_frcnn_res50_heavy_augs.pt", # pylint: disable=line-too-long
num_classes=num_classes, basemodel=basemodel
)

######################################################