From ca465e1dd09e9d271b81f906c87f8d667a2c2e5d Mon Sep 17 00:00:00 2001 From: TheoLisin Date: Tue, 7 Nov 2023 14:54:41 +0000 Subject: [PATCH] disable whatcher on shutdown --- train/src/main.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/train/src/main.py b/train/src/main.py index 0b9f8fd..87795a8 100644 --- a/train/src/main.py +++ b/train/src/main.py @@ -1,4 +1,5 @@ import os + os.environ["CUDA_VISIBLE_DEVICES"] = "0" from pathlib import Path import numpy as np @@ -638,7 +639,9 @@ def select_task(task_type): @select_classes_button.click def select_classes(): selected_classes = classes_table.get_selected_classes() - selected_shapes = [cls.geometry_type.geometry_name() for cls in project_meta.obj_classes if cls.name in selected_classes] + selected_shapes = [ + cls.geometry_type.geometry_name() for cls in project_meta.obj_classes if cls.name in selected_classes + ] task_type = task_type_select.get_value() if task_type == "pose estimation" and ("graph" not in selected_shapes or "rectangle" not in selected_shapes): sly.app.show_dialog( @@ -1110,6 +1113,11 @@ def on_results_file_changed(filepath, pbar): def watcher_func(): watcher.watch() + def disable_watcher(): + watcher.running = False + + app.call_before_shutdown(disable_watcher) + threading.Thread(target=watcher_func, daemon=True).start() if len(train_set) > 300: n_train_batches = math.ceil(len(train_set) / batch_size_input.get_value()) @@ -1341,7 +1349,7 @@ def auto_train(request: Request): stepper.set_active_step(curr_step) card_train_val_split.unlock() card_train_val_split.uncollapse() - + train_val_split.disable() unlabeled_images_select.disable() split_data_button.hide() @@ -1455,6 +1463,7 @@ def auto_train(request: Request): progress_bar_convert_to_yolo, task_type, ) + # download model def download_monitor(monitor, api: sly.Api, progress: sly.Progress): value = monitor @@ -1766,9 +1775,7 @@ def train_batch_watcher_func(): print(app_url, file=text_file) # upload training artifacts to team files - remote_artifacts_dir = os.path.join( - "/yolov8_train", task_type, project_info.name, str(g.app_session_id) - ) + remote_artifacts_dir = os.path.join("/yolov8_train", task_type, project_info.name, str(g.app_session_id)) def upload_monitor(monitor, api: sly.Api, progress: sly.Progress): value = monitor.bytes_read