Skip to content

Commit

Permalink
Update docs on visualization callback
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed May 22, 2024
1 parent 27592e7 commit 772c245
Showing 1 changed file with 12 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,30 @@ class ExtremeBatchOBBVisualizationCallback(ExtremeBatchCaseVisualizationCallback
"""
ExtremeBatchOBBVisualizationCallback
Visualizes worst/best batch in an epoch for pose estimation task.
This class visualize horizontally-stacked GT and predicted poses.
It requires a key 'gt_samples' (List[PoseEstimationSample]) to be present in additional_batch_items dictionary.
Visualizes worst/best batch in an epoch for OBB detection task.
This class visualize horizontally-stacked GT and predicted boxes.
It requires a key 'gt_samples' (List[OBBSample]) to be present in additional_batch_items dictionary.
Supported models: YoloNASPose
Supported datasets: COCOPoseEstimationDataset
Supported models: YoloNAS-R
Supported datasets: DOTAOBBDataset
Example usage in Yaml config:
training_hyperparams:
phase_callbacks:
- ExtremeBatchPoseEstimationVisualizationCallback:
keypoint_colors: ${dataset_params.keypoint_colors}
edge_colors: ${dataset_params.edge_colors}
edge_links: ${dataset_params.edge_links}
loss_to_monitor: YoloNASPoseLoss/loss
- ExtremeBatchOBBVisualizationCallback:
loss_to_monitor: YoloNASRLoss/loss
max: True
freq: 1
max_images: 16
enable_on_train_loader: True
enable_on_valid_loader: True
post_prediction_callback:
_target_: super_gradients.training.models.pose_estimation_models.yolo_nas_pose.YoloNASPosePostPredictionCallback
pose_confidence_threshold: 0.01
nms_iou_threshold: 0.7
pre_nms_max_predictions: 300
post_nms_max_predictions: 30
_target_: super_gradients.training.models.detection_models.yolo_nas_r.yolo_nas_r_post_prediction_callback.YoloNASRPostPredictionCallback
score_threshold: 0.25
pre_nms_max_predictions: 4096
post_nms_max_predictions: 512
nms_iou_threshold: 0.6
:param metric: Metric, will be the metric which is monitored.
Expand All @@ -73,9 +70,6 @@ class ExtremeBatchOBBVisualizationCallback(ExtremeBatchCaseVisualizationCallback
the minimum (default=False).
:param freq: int, epoch frequency to perform all of the above (default=1).
"""

def __init__(
Expand Down

0 comments on commit 772c245

Please sign in to comment.