diff --git a/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py b/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py index ffaa6a0a7..bbaee97d8 100644 --- a/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py +++ b/src/super_gradients/training/utils/callbacks/extreme_batch_obb_visualization_callback.py @@ -24,33 +24,30 @@ class ExtremeBatchOBBVisualizationCallback(ExtremeBatchCaseVisualizationCallback """ ExtremeBatchOBBVisualizationCallback - Visualizes worst/best batch in an epoch for pose estimation task. - This class visualize horizontally-stacked GT and predicted poses. - It requires a key 'gt_samples' (List[PoseEstimationSample]) to be present in additional_batch_items dictionary. + Visualizes worst/best batch in an epoch for OBB detection task. + This class visualize horizontally-stacked GT and predicted boxes. + It requires a key 'gt_samples' (List[OBBSample]) to be present in additional_batch_items dictionary. - Supported models: YoloNASPose - Supported datasets: COCOPoseEstimationDataset + Supported models: YoloNAS-R + Supported datasets: DOTAOBBDataset Example usage in Yaml config: training_hyperparams: phase_callbacks: - - ExtremeBatchPoseEstimationVisualizationCallback: - keypoint_colors: ${dataset_params.keypoint_colors} - edge_colors: ${dataset_params.edge_colors} - edge_links: ${dataset_params.edge_links} - loss_to_monitor: YoloNASPoseLoss/loss + - ExtremeBatchOBBVisualizationCallback: + loss_to_monitor: YoloNASRLoss/loss max: True freq: 1 max_images: 16 enable_on_train_loader: True enable_on_valid_loader: True post_prediction_callback: - _target_: super_gradients.training.models.pose_estimation_models.yolo_nas_pose.YoloNASPosePostPredictionCallback - pose_confidence_threshold: 0.01 - nms_iou_threshold: 0.7 - pre_nms_max_predictions: 300 - post_nms_max_predictions: 30 + _target_: super_gradients.training.models.detection_models.yolo_nas_r.yolo_nas_r_post_prediction_callback.YoloNASRPostPredictionCallback + score_threshold: 0.25 + pre_nms_max_predictions: 4096 + post_nms_max_predictions: 512 + nms_iou_threshold: 0.6 :param metric: Metric, will be the metric which is monitored. @@ -73,9 +70,6 @@ class ExtremeBatchOBBVisualizationCallback(ExtremeBatchCaseVisualizationCallback the minimum (default=False). :param freq: int, epoch frequency to perform all of the above (default=1). - - - """ def __init__(