Skip to content

Commit

Permalink
replace "squeeze" with "group" II
Browse files Browse the repository at this point in the history
  • Loading branch information
phinik committed Jan 9, 2024
1 parent f4bc801 commit e18d3a7
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions yoeo/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,19 @@ def _evaluate(model, dataloader, class_config, img_size, iou_thres, conf_thres,
import time
times = []

if class_config.classes_should_be_squeezed():
secondary_metric = Metric(len(class_config.get_squeeze_ids()))
if class_config.classes_should_be_grouped():
secondary_metric = Metric(len(class_config.get_group_ids()))
else:
secondary_metric = None

for _, imgs, bb_targets, mask_targets in tqdm.tqdm(dataloader, desc="Validating"):
# Extract labels
labels += bb_targets[:, 1].tolist()

# If a subset of the detection classes should be squeezed into one class for non-maximum suppression and the
# subsequent AP-computation, we need to squeeze those class labels here.
if class_config.classes_should_be_squeezed():
labels = class_config.squeeze(labels)
# If a subset of the detection classes should be grouped into one class for non-maximum suppression and the

Check warning on line 160 in yoeo/test.py

View workflow job for this annotation

GitHub Actions / linter

trailing whitespace
# subsequent AP-computation, we need to group those class labels here.
if class_config.classes_should_be_grouped():
labels = class_config.group(labels)

# Rescale target
bb_targets[:, 2:] = xywh2xyxy(bb_targets[:, 2:])
Expand All @@ -176,19 +176,19 @@ def _evaluate(model, dataloader, class_config, img_size, iou_thres, conf_thres,
yolo_outputs,
conf_thres=conf_thres,
iou_thres=nms_thres,
group_config=class_config.get_squeeze_config()
group_config=class_config.get_group_config()
)

sample_stat, secondary_stat = get_batch_statistics(
yolo_outputs,

Check warning on line 183 in yoeo/test.py

View workflow job for this annotation

GitHub Actions / linter

trailing whitespace
bb_targets,

Check warning on line 184 in yoeo/test.py

View workflow job for this annotation

GitHub Actions / linter

trailing whitespace
iou_threshold=iou_thres,

Check warning on line 185 in yoeo/test.py

View workflow job for this annotation

GitHub Actions / linter

trailing whitespace
group_config=class_config.get_squeeze_config()
group_config=class_config.get_group_config()
)

sample_metrics += sample_stat

if class_config.classes_should_be_squeezed():
if class_config.classes_should_be_grouped():
secondary_metric += secondary_stat

seg_ious.append(seg_iou(to_cpu(segmentation_outputs), mask_targets, model.num_seg_classes))
Expand Down

0 comments on commit e18d3a7

Please sign in to comment.