Skip to content

Commit

Permalink
disable whatcher on shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
TheoLisin committed Nov 7, 2023
1 parent 38d7053 commit ca465e1
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions train/src/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from pathlib import Path
import numpy as np
Expand Down Expand Up @@ -638,7 +639,9 @@ def select_task(task_type):
@select_classes_button.click
def select_classes():
selected_classes = classes_table.get_selected_classes()
selected_shapes = [cls.geometry_type.geometry_name() for cls in project_meta.obj_classes if cls.name in selected_classes]
selected_shapes = [
cls.geometry_type.geometry_name() for cls in project_meta.obj_classes if cls.name in selected_classes
]
task_type = task_type_select.get_value()
if task_type == "pose estimation" and ("graph" not in selected_shapes or "rectangle" not in selected_shapes):
sly.app.show_dialog(
Expand Down Expand Up @@ -1110,6 +1113,11 @@ def on_results_file_changed(filepath, pbar):
def watcher_func():
watcher.watch()

def disable_watcher():
watcher.running = False

app.call_before_shutdown(disable_watcher)

threading.Thread(target=watcher_func, daemon=True).start()
if len(train_set) > 300:
n_train_batches = math.ceil(len(train_set) / batch_size_input.get_value())
Expand Down Expand Up @@ -1341,7 +1349,7 @@ def auto_train(request: Request):
stepper.set_active_step(curr_step)
card_train_val_split.unlock()
card_train_val_split.uncollapse()

train_val_split.disable()
unlabeled_images_select.disable()
split_data_button.hide()
Expand Down Expand Up @@ -1455,6 +1463,7 @@ def auto_train(request: Request):
progress_bar_convert_to_yolo,
task_type,
)

# download model
def download_monitor(monitor, api: sly.Api, progress: sly.Progress):
value = monitor
Expand Down Expand Up @@ -1766,9 +1775,7 @@ def train_batch_watcher_func():
print(app_url, file=text_file)

# upload training artifacts to team files
remote_artifacts_dir = os.path.join(
"/yolov8_train", task_type, project_info.name, str(g.app_session_id)
)
remote_artifacts_dir = os.path.join("/yolov8_train", task_type, project_info.name, str(g.app_session_id))

def upload_monitor(monitor, api: sly.Api, progress: sly.Progress):
value = monitor.bytes_read
Expand Down

0 comments on commit ca465e1

Please sign in to comment.