From f181fcc78a7cdaacf2f7dd088df30f21a25aed1f Mon Sep 17 00:00:00 2001 From: TheoLisin <87002239+TheoLisin@users.noreply.github.com> Date: Sat, 11 Nov 2023 04:18:46 +0400 Subject: [PATCH] V2 stop (#18) * disable whatcher on shutdown * upd sdk * added logger to whatcher * start train in thread * added stop callback on train * added stop event on validation * beautify * rename * added checkpoints dir --- train/.gitignore | 3 +- train/requirements.txt | 3 +- train/src/globals.py | 5 +- train/src/main.py | 228 +++++++++++++++++++++++++---------- train/src/metrics_watcher.py | 3 + 5 files changed, 178 insertions(+), 64 deletions(-) diff --git a/train/.gitignore b/train/.gitignore index ad5a6b7..8c53cfe 100644 --- a/train/.gitignore +++ b/train/.gitignore @@ -1,3 +1,4 @@ src/__pycache__/ local.env -tempfiles/ \ No newline at end of file +tempfiles/ +runs/ \ No newline at end of file diff --git a/train/requirements.txt b/train/requirements.txt index 46c08cb..e180ce1 100644 --- a/train/requirements.txt +++ b/train/requirements.txt @@ -1,4 +1,5 @@ -supervisely==6.72.110 +supervisely==6.72.182 +# git+https://github.com/supervisely/supervisely.git@v2-stop ultralytics==8.0.112 --extra-index-url https://download.pytorch.org/whl/cu113 torch==1.10.1+cu113 diff --git a/train/src/globals.py b/train/src/globals.py index bd55a43..f64fb9a 100644 --- a/train/src/globals.py +++ b/train/src/globals.py @@ -13,8 +13,11 @@ if sly.is_production(): app_session_id = sly.io.env.task_id() + root_model_checkpoint_dir = sly.app.get_synced_data_dir() else: app_session_id = 777 # for debug + root_model_checkpoint_dir = os.path.join(app_root_directory, "runs") + det_models_data_path = os.path.join(root_source_path, "models", "det_models_data.json") seg_models_data_path = os.path.join(root_source_path, "models", "seg_models_data.json") @@ -29,4 +32,4 @@ train_params_filepath = "training_params.yml" # for debug train_counter, val_counter = 0, 0 center_matches = {} -keypoints_template = None \ No newline at end of file +keypoints_template = None diff --git a/train/src/main.py b/train/src/main.py index 0b9f8fd..9721b39 100644 --- a/train/src/main.py +++ b/train/src/main.py @@ -1,4 +1,5 @@ import os + os.environ["CUDA_VISIBLE_DEVICES"] = "0" from pathlib import Path import numpy as np @@ -144,7 +145,10 @@ def update_globals(new_dataset_ids): ) card_classes = Card( title="Task type & training classes", - description="Select task type and classes, that should be used for training. Supported shapes include rectangle, bitmap, polygon and graph", + description=( + "Select task type and classes, that should be used for training. " + "Supported shapes include rectangle, bitmap, polygon and graph" + ), content=classes_content, collapsable=True, lock_message="Complete the previous step to unlock", @@ -159,7 +163,10 @@ def update_globals(new_dataset_ids): unlabeled_images_select_f = Field( content=unlabeled_images_select, title="What to do with unlabeled images", - description="Sometimes unlabeled images can be used to reduce noise in predictions, sometimes it is a mistake in training data", + description=( + "Sometimes unlabeled images can be used to reduce noise in predictions, " + "sometimes it is a mistake in training data" + ), ) split_data_button = Button("Split data") resplit_data_button = Button( @@ -194,7 +201,10 @@ def update_globals(new_dataset_ids): ### 4. Model selection models_table_notification = NotificationBox( title="List of models in the table below depends on selected task type", - description="If you want to see list of available models for another computer vision task, please, go back to task type & training classes step and change task type", + description=( + "If you want to see list of available models for another computer vision task, " + "please, go back to task type & training classes step and change task type" + ), ) model_tabs_titles = ["Pretrained models", "Custom models"] model_tabs_descriptions = [ @@ -272,7 +282,10 @@ def update_globals(new_dataset_ids): select_train_mode_f = Field( content=select_train_mode, title="Select training mode", - description="Finetune mode - .pt file with pretrained model weights will be downloaded, Scratch mode - model weights will be initialized randomly", + description=( + "Finetune mode - .pt file with pretrained model weights will be downloaded, " + "Scratch mode - model weights will be initialized randomly" + ), ) n_epochs_input = InputNumber(value=100, min=1) n_epochs_input_f = Field(content=n_epochs_input, title="Number of epochs") @@ -320,7 +333,11 @@ def update_globals(new_dataset_ids): additional_config_template_select_f.hide() no_templates_notification = NotificationBox( title="No templates found", - description="There are no templates for this task type in Team Files. You can create custom config and save it as a template to Team Files - you will be able to reuse it in your future experiments", + description=( + "There are no templates for this task type in Team Files. " + "You can create custom config and save it as a template to " + "Team Files - you will be able to reuse it in your future experiments" + ), ) no_templates_notification.hide() train_settings_editor = Editor(language_mode="yaml", height_lines=50) @@ -638,9 +655,15 @@ def select_task(task_type): @select_classes_button.click def select_classes(): selected_classes = classes_table.get_selected_classes() - selected_shapes = [cls.geometry_type.geometry_name() for cls in project_meta.obj_classes if cls.name in selected_classes] + selected_shapes = [ + cls.geometry_type.geometry_name() + for cls in project_meta.obj_classes + if cls.name in selected_classes + ] task_type = task_type_select.get_value() - if task_type == "pose estimation" and ("graph" not in selected_shapes or "rectangle" not in selected_shapes): + if task_type == "pose estimation" and ( + "graph" not in selected_shapes or "rectangle" not in selected_shapes + ): sly.app.show_dialog( title="Pose estimation task requires input project to have at least one class of shape graph and one class of shape rectangle", description="Please, select both classes of shape rectangle and graph or change task type", @@ -741,7 +764,9 @@ def change_file_preview(value): @additional_config_radio.value_changed def change_radio(value): if value == "import template from Team Files": - remote_templates_dir = os.path.join("/yolov8_train", task_type_select.get_value(), "param_templates") + remote_templates_dir = os.path.join( + "/yolov8_train", task_type_select.get_value(), "param_templates" + ) templates = api.file.list(team_id, remote_templates_dir) if len(templates) == 0: no_templates_notification.show() @@ -756,7 +781,9 @@ def change_radio(value): @additional_config_template_select.value_changed def change_template(template): - remote_templates_dir = os.path.join("/yolov8_train", task_type_select.get_value(), "param_templates") + remote_templates_dir = os.path.join( + "/yolov8_train", task_type_select.get_value(), "param_templates" + ) remote_template_path = os.path.join(remote_templates_dir, template) local_template_path = os.path.join(g.app_data_dir, template) api.file.download(team_id, remote_template_path, local_template_path) @@ -768,7 +795,9 @@ def change_template(template): @save_template_button.click def upload_template(): save_template_button.loading = True - remote_templates_dir = os.path.join("/yolov8_train", task_type_select.get_value(), "param_templates") + remote_templates_dir = os.path.join( + "/yolov8_train", task_type_select.get_value(), "param_templates" + ) additional_params = train_settings_editor.get_text() ryaml = ruamel.yaml.YAML() additional_params = ryaml.load(additional_params) @@ -840,19 +869,20 @@ def change_logs_visibility(): @start_training_button.click def start_training(): task_type = task_type_select.get_value() - if sly.is_production(): - local_dir = g.root_source_path - else: - local_dir = g.app_root_directory + + local_dir = g.root_model_checkpoint_dir if task_type == "object detection": necessary_geometries = ["rectangle"] - local_artifacts_dir = os.path.join(local_dir, "runs", "detect", "train") + checkpoint_dir = os.path.join(local_dir, "detect") + local_artifacts_dir = os.path.join(local_dir, "detect", "train") elif task_type == "pose estimation": necessary_geometries = ["graph", "rectangle"] - local_artifacts_dir = os.path.join(local_dir, "runs", "pose", "train") + checkpoint_dir = os.path.join(local_dir, "pose") + local_artifacts_dir = os.path.join(local_dir, "pose", "train") elif task_type == "instance segmentation": necessary_geometries = ["bitmap", "polygon"] - local_artifacts_dir = os.path.join(local_dir, "runs", "segment", "train") + checkpoint_dir = os.path.join(local_dir, "segment") + local_artifacts_dir = os.path.join(local_dir, "segment", "train") sly.logger.info(f"Local artifacts dir: {local_artifacts_dir}") @@ -883,10 +913,15 @@ def start_training(): if task_type != "object detection": unnecessary_classes = [] for cls in project_meta.obj_classes: - if cls.name in selected_classes and cls.geometry_type.geometry_name() not in necessary_geometries: + if ( + cls.name in selected_classes + and cls.geometry_type.geometry_name() not in necessary_geometries + ): unnecessary_classes.append(cls.name) if len(unnecessary_classes) > 0: - sly.Project.remove_classes(g.project_dir, classes_to_remove=unnecessary_classes, inplace=True) + sly.Project.remove_classes( + g.project_dir, classes_to_remove=unnecessary_classes, inplace=True + ) # transfer project to detection task if necessary if task_type == "object detection": sly.Project.to_detection_task(g.project_dir, inplace=True) @@ -907,7 +942,9 @@ def start_training(): description="Val split length is 0 after ignoring images. Please check your data", status="error", ) - raise ValueError("Val split length is 0 after ignoring images. Please check your data") + raise ValueError( + "Val split length is 0 after ignoring images. Please check your data" + ) # split the data train_val_split._project_fs = sly.Project(g.project_dir, sly.OpenMode.READ) train_set, val_set = train_val_split.get_splits() @@ -942,7 +979,9 @@ def download_monitor(monitor, api: sly.Api, progress: sly.Progress): model_filename = selected_model.lower() + ".pt" pretrained = True weights_dst_path = os.path.join(g.app_data_dir, model_filename) - weights_url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}" + weights_url = ( + f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}" + ) with urlopen(weights_url) as file: weights_size = file.length @@ -1089,10 +1128,16 @@ def on_results_file_changed(filepath, pbar): # visualize train batch batch = f"train_batch{x}.jpg" local_train_batches_path = os.path.join(local_artifacts_dir, batch) - if os.path.exists(local_train_batches_path) and batch not in plotted_train_batches and x < 10: + if ( + os.path.exists(local_train_batches_path) + and batch not in plotted_train_batches + and x < 10 + ): plotted_train_batches.append(batch) remote_train_batches_path = os.path.join(remote_images_path, batch) - tf_train_batches_info = api.file.upload(team_id, local_train_batches_path, remote_train_batches_path) + tf_train_batches_info = api.file.upload( + team_id, local_train_batches_path, remote_train_batches_path + ) train_batches_gallery.append(tf_train_batches_info.full_storage_url) if x == 0: train_batches_gallery_f.show() @@ -1110,6 +1155,11 @@ def on_results_file_changed(filepath, pbar): def watcher_func(): watcher.watch() + def disable_watcher(): + watcher.running = False + + app.call_before_shutdown(disable_watcher) + threading.Thread(target=watcher_func, daemon=True).start() if len(train_set) > 300: n_train_batches = math.ceil(len(train_set) / batch_size_input.get_value()) @@ -1132,21 +1182,36 @@ def on_train_batches_file_changed(filepath, pbar): def train_batch_watcher_func(): train_batch_watcher.watch() + def train_batch_watcher_disable(): + train_batch_watcher.running = False + + app.call_before_shutdown(train_batch_watcher_disable) + threading.Thread(target=train_batch_watcher_func, daemon=True).start() - model.train( - data=data_path, - epochs=n_epochs_input.get_value(), - patience=patience_input.get_value(), - batch=batch_size_input.get_value(), - imgsz=image_size_input.get_value(), - save_period=1000, - device=device, - workers=n_workers_input.get_value(), - optimizer=select_optimizer.get_value(), - pretrained=pretrained, - **additional_params, - ) + def stop_on_batch_end_if_needed(*args, **kwargs): + if app.app_is_stopped(): + raise app.StopApp("This error is expected.") + + model.add_callback("on_train_batch_end", stop_on_batch_end_if_needed) + model.add_callback("on_val_batch_end", stop_on_batch_end_if_needed) + + with app.run_with_stop_app_suppression(): + model.train( + data=data_path, + epochs=n_epochs_input.get_value(), + patience=patience_input.get_value(), + batch=batch_size_input.get_value(), + imgsz=image_size_input.get_value(), + save_period=1000, + device=device, + workers=n_workers_input.get_value(), + optimizer=select_optimizer.get_value(), + pretrained=pretrained, + project=checkpoint_dir, + **additional_params, + ) + progress_bar_iters.hide() progress_bar_epochs.hide() watcher.running = False @@ -1178,8 +1243,12 @@ def train_batch_watcher_func(): # visualize additional training results confusion_matrix_path = os.path.join(local_artifacts_dir, "confusion_matrix_normalized.png") if os.path.exists(confusion_matrix_path): - remote_confusion_matrix_path = os.path.join(remote_images_path, "confusion_matrix_normalized.png") - tf_confusion_matrix_info = api.file.upload(team_id, confusion_matrix_path, remote_confusion_matrix_path) + remote_confusion_matrix_path = os.path.join( + remote_images_path, "confusion_matrix_normalized.png" + ) + tf_confusion_matrix_info = api.file.upload( + team_id, confusion_matrix_path, remote_confusion_matrix_path + ) additional_gallery.append(tf_confusion_matrix_info.full_storage_url) additional_gallery_f.show() pr_curve_path = os.path.join(local_artifacts_dir, "PR_curve.png") @@ -1200,18 +1269,24 @@ def train_batch_watcher_func(): pose_f1_curve_path = os.path.join(local_artifacts_dir, "PoseF1_curve.png") if os.path.exists(pose_f1_curve_path): remote_pose_f1_curve_path = os.path.join(remote_images_path, "PoseF1_curve.png") - tf_pose_f1_curve_info = api.file.upload(team_id, pose_f1_curve_path, remote_pose_f1_curve_path) + tf_pose_f1_curve_info = api.file.upload( + team_id, pose_f1_curve_path, remote_pose_f1_curve_path + ) additional_gallery.append(tf_pose_f1_curve_info.full_storage_url) mask_f1_curve_path = os.path.join(local_artifacts_dir, "MaskF1_curve.png") if os.path.exists(mask_f1_curve_path): remote_mask_f1_curve_path = os.path.join(remote_images_path, "MaskF1_curve.png") - tf_mask_f1_curve_info = api.file.upload(team_id, mask_f1_curve_path, remote_mask_f1_curve_path) + tf_mask_f1_curve_info = api.file.upload( + team_id, mask_f1_curve_path, remote_mask_f1_curve_path + ) additional_gallery.append(tf_mask_f1_curve_info.full_storage_url) # rename best checkpoint file results = pd.read_csv(watch_file) results.columns = [col.replace(" ", "") for col in results.columns] - results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + (0.9 * results["metrics/mAP50-95(B)"]) + results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + ( + 0.9 * results["metrics/mAP50-95(B)"] + ) print("Final results:") print(results) best_epoch = results["fitness"].idxmax() @@ -1223,7 +1298,10 @@ def train_batch_watcher_func(): # add geometry config to saved weights for pose estimation task if task_type == "pose estimation": for obj_class in project_meta.obj_classes: - if obj_class.geometry_type.geometry_name() == "graph" and obj_class.name in selected_classes: + if ( + obj_class.geometry_type.geometry_name() == "graph" + and obj_class.name in selected_classes + ): geometry_config = obj_class.geometry_config break weights_filepath = os.path.join(local_artifacts_dir, "weights", best_filename) @@ -1341,7 +1419,7 @@ def auto_train(request: Request): stepper.set_active_step(curr_step) card_train_val_split.unlock() card_train_val_split.uncollapse() - + train_val_split.disable() unlabeled_images_select.disable() split_data_button.hide() @@ -1382,19 +1460,19 @@ def auto_train(request: Request): card_train_progress.unlock() card_train_progress.uncollapse() - if sly.is_production(): - local_dir = g.root_source_path - else: - local_dir = g.app_root_directory + local_dir = g.root_model_checkpoint_dir if task_type == "object detection": necessary_geometries = ["rectangle"] - local_artifacts_dir = os.path.join(local_dir, "runs", "detect", "train") + checkpoint_dir = os.path.join(local_dir, "detect") + local_artifacts_dir = os.path.join(local_dir, "detect", "train") elif task_type == "pose estimation": necessary_geometries = ["graph", "rectangle"] - local_artifacts_dir = os.path.join(local_dir, "runs", "pose", "train") + checkpoint_dir = os.path.join(local_dir, "pose") + local_artifacts_dir = os.path.join(local_dir, "pose", "train") elif task_type == "instance segmentation": necessary_geometries = ["bitmap", "polygon"] - local_artifacts_dir = os.path.join(local_dir, "runs", "segment", "train") + checkpoint_dir = os.path.join(local_dir, "segment") + local_artifacts_dir = os.path.join(local_dir, "segment", "train") sly.logger.info(f"Local artifacts dir: {local_artifacts_dir}") @@ -1433,10 +1511,15 @@ def auto_train(request: Request): if task_type != "object detection": unnecessary_classes = [] for cls in project_meta.obj_classes: - if cls.name in selected_classes and cls.geometry_type.geometry_name() not in necessary_geometries: + if ( + cls.name in selected_classes + and cls.geometry_type.geometry_name() not in necessary_geometries + ): unnecessary_classes.append(cls.name) if len(unnecessary_classes) > 0: - sly.Project.remove_classes(g.project_dir, classes_to_remove=unnecessary_classes, inplace=True) + sly.Project.remove_classes( + g.project_dir, classes_to_remove=unnecessary_classes, inplace=True + ) # transfer project to detection task if necessary if task_type == "object detection": sly.Project.to_detection_task(g.project_dir, inplace=True) @@ -1455,6 +1538,7 @@ def auto_train(request: Request): progress_bar_convert_to_yolo, task_type, ) + # download model def download_monitor(monitor, api: sly.Api, progress: sly.Progress): value = monitor @@ -1478,7 +1562,9 @@ def download_monitor(monitor, api: sly.Api, progress: sly.Progress): model_filename = selected_model.lower() + ".pt" pretrained = True weights_dst_path = os.path.join(g.app_data_dir, model_filename) - weights_url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}" + weights_url = ( + f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}" + ) with urlopen(weights_url) as file: weights_size = file.length @@ -1597,10 +1683,16 @@ def on_results_file_changed(filepath, pbar): # visualize train batch batch = f"train_batch{x}.jpg" local_train_batches_path = os.path.join(local_artifacts_dir, batch) - if os.path.exists(local_train_batches_path) and batch not in plotted_train_batches and x < 10: + if ( + os.path.exists(local_train_batches_path) + and batch not in plotted_train_batches + and x < 10 + ): plotted_train_batches.append(batch) remote_train_batches_path = os.path.join(remote_images_path, batch) - tf_train_batches_info = api.file.upload(team_id, local_train_batches_path, remote_train_batches_path) + tf_train_batches_info = api.file.upload( + team_id, local_train_batches_path, remote_train_batches_path + ) train_batches_gallery.append(tf_train_batches_info.full_storage_url) if x == 0: train_batches_gallery_f.show() @@ -1644,6 +1736,7 @@ def train_batch_watcher_func(): model.train( data=data_path, + project=checkpoint_dir, epochs=state.get("n_epochs", n_epochs_input.get_value()), patience=state.get("patience", patience_input.get_value()), batch=state.get("batch_size", batch_size_input.get_value()), @@ -1706,8 +1799,12 @@ def train_batch_watcher_func(): # visualize additional training results confusion_matrix_path = os.path.join(local_artifacts_dir, "confusion_matrix_normalized.png") if os.path.exists(confusion_matrix_path): - remote_confusion_matrix_path = os.path.join(remote_images_path, "confusion_matrix_normalized.png") - tf_confusion_matrix_info = api.file.upload(team_id, confusion_matrix_path, remote_confusion_matrix_path) + remote_confusion_matrix_path = os.path.join( + remote_images_path, "confusion_matrix_normalized.png" + ) + tf_confusion_matrix_info = api.file.upload( + team_id, confusion_matrix_path, remote_confusion_matrix_path + ) additional_gallery.append(tf_confusion_matrix_info.full_storage_url) additional_gallery_f.show() pr_curve_path = os.path.join(local_artifacts_dir, "PR_curve.png") @@ -1728,18 +1825,24 @@ def train_batch_watcher_func(): pose_f1_curve_path = os.path.join(local_artifacts_dir, "PoseF1_curve.png") if os.path.exists(pose_f1_curve_path): remote_pose_f1_curve_path = os.path.join(remote_images_path, "PoseF1_curve.png") - tf_pose_f1_curve_info = api.file.upload(team_id, pose_f1_curve_path, remote_pose_f1_curve_path) + tf_pose_f1_curve_info = api.file.upload( + team_id, pose_f1_curve_path, remote_pose_f1_curve_path + ) additional_gallery.append(tf_pose_f1_curve_info.full_storage_url) mask_f1_curve_path = os.path.join(local_artifacts_dir, "MaskF1_curve.png") if os.path.exists(mask_f1_curve_path): remote_mask_f1_curve_path = os.path.join(remote_images_path, "MaskF1_curve.png") - tf_mask_f1_curve_info = api.file.upload(team_id, mask_f1_curve_path, remote_mask_f1_curve_path) + tf_mask_f1_curve_info = api.file.upload( + team_id, mask_f1_curve_path, remote_mask_f1_curve_path + ) additional_gallery.append(tf_mask_f1_curve_info.full_storage_url) # rename best checkpoint file results = pd.read_csv(watch_file) results.columns = [col.replace(" ", "") for col in results.columns] - results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + (0.9 * results["metrics/mAP50-95(B)"]) + results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + ( + 0.9 * results["metrics/mAP50-95(B)"] + ) print("Final results:") print(results) best_epoch = results["fitness"].idxmax() @@ -1751,7 +1854,10 @@ def train_batch_watcher_func(): # add geometry config to saved weights for pose estimation task if task_type == "pose estimation": for obj_class in project_meta.obj_classes: - if obj_class.geometry_type.geometry_name() == "graph" and obj_class.name in selected_classes: + if ( + obj_class.geometry_type.geometry_name() == "graph" + and obj_class.name in selected_classes + ): geometry_config = obj_class.geometry_config break weights_filepath = os.path.join(local_artifacts_dir, "weights", best_filename) diff --git a/train/src/metrics_watcher.py b/train/src/metrics_watcher.py index b8ece66..c1f314f 100644 --- a/train/src/metrics_watcher.py +++ b/train/src/metrics_watcher.py @@ -3,6 +3,7 @@ # import time import traceback +import supervisely as sly class Watcher(object): @@ -40,3 +41,5 @@ def watch(self): except Exception as e: print("Unhandled error:") print(traceback.format_exc()) + else: + sly.logger.debug("Watcher is stopped")