diff --git a/train/src/sly_functions.py b/train/src/sly_functions.py
index d82ea74..d724385 100644
--- a/train/src/sly_functions.py
+++ b/train/src/sly_functions.py
@@ -13,5 +13,5 @@ def get_eval_results_dir_name(api, task_id, project_info):
     task_info = api.task.get_info_by_id(task_id)
     task_dir = f"{task_id}_{task_info['meta']['app']['name']}"
     eval_res_dir = f"/model-benchmark/{project_info.id}_{project_info.name}/{task_dir}/"
-    eval_res_dir = api.storage.get_free_dir_name(sly.env.team_id(), eval_res_dir)
+    eval_res_dir = api.file.get_free_dir_name(sly.env.team_id(), eval_res_dir)
     return eval_res_dir
diff --git a/train/src/ui/monitoring.html b/train/src/ui/monitoring.html
index 04b84fc..36d9406 100644
--- a/train/src/ui/monitoring.html
+++ b/train/src/ui/monitoring.html
@@ -125,6 +125,13 @@
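+      <!-- generic tqdm progress block; its data.progress*Tqdm fields are set by the TqdmProgress callbacks in monitoring.py -->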
+      <div v-if="data.progressTqdm" class="mt10">
+        <div style="color: #20a0ff">
+          {{data.progressTqdm}}: {{data.progressCurrentTqdm}} /
+          {{data.progressTotalTqdm}}
+        </div>
+        <el-progress :percentage="data.progressPercentTqdm"></el-progress>
+      </div>
       <div v-if="data.progressEpoch" class="mt10">
         <div style="color: #20a0ff">
           {{data.progressEpoch}}: {{data.progressCurrentEpoch}} /
diff --git a/train/src/ui/monitoring.py b/train/src/ui/monitoring.py
index 7dfbd06..6893523 100644
--- a/train/src/ui/monitoring.py
+++ b/train/src/ui/monitoring.py
@@ -25,17 +25,27 @@ import sly_logger_hook


-def external_callback(progress: sly.tqdm_sly):
+def external_update_callback(progress: sly.tqdm_sly, progress_name: str):
     percent = math.floor(progress.n / progress.total * 100)
     fields = []
     if hasattr(progress, "desc"):
-        fields.append({"field": f"data.progressBenchmark", "payload": progress.desc})
+        fields.append({"field": f"data.progress{progress_name}", "payload": progress.desc})
     elif hasattr(progress, "message"):
-        fields.append({"field": f"data.progressBenchmark", "payload": progress.message})
+        fields.append({"field": f"data.progress{progress_name}", "payload": progress.message})
     fields += [
-        {"field": f"data.progressCurrentBenchmark", "payload": progress.n},
-        {"field": f"data.progressTotalBenchmark", "payload": progress.total},
-        {"field": f"data.progressPercentBenchmark", "payload": percent},
+        {"field": f"data.progressCurrent{progress_name}", "payload": progress.n},
+        {"field": f"data.progressTotal{progress_name}", "payload": progress.total},
+        {"field": f"data.progressPercent{progress_name}", "payload": percent},
+    ]
+    g.api.app.set_fields(g.task_id, fields)
+
+
+def external_close_callback(progress: sly.tqdm_sly, progress_name: str):
+    fields = [
+        {"field": f"data.progress{progress_name}", "payload": None},
+        {"field": f"data.progressCurrent{progress_name}", "payload": None},
+        {"field": f"data.progressTotal{progress_name}", "payload": None},
+        {"field": f"data.progressPercent{progress_name}", "payload": None},
     ]
     g.api.app.set_fields(g.task_id, fields)


@@ -43,7 +53,21 @@ def external_callback(progress: sly.tqdm_sly):
 class TqdmBenchmark(sly.tqdm_sly):
     def update(self, n=1):
         super().update(n)
-        external_callback(self)
+        external_update_callback(self, "Benchmark")
+
+    def close(self):
+        super().close()
+        external_close_callback(self, "Benchmark")
+
+
+class TqdmProgress(sly.tqdm_sly):
+    def update(self, n=1):
+        super().update(n)
+        external_update_callback(self, "Tqdm")
+
+    def close(self):
+        super().close()
+        external_close_callback(self, "Tqdm")


 _open_lnk_name = "open_app.lnk"
@@ -55,6 +79,7 @@ def init(data, state):
     init_progress("Iter", data)
     init_progress("UploadDir", data)
     init_progress("Benchmark", data)
+    init_progress("Tqdm", data)

     data["eta"] = None
     state["isValidation"] = False
@@ -294,37 +319,59 @@ def init_class_charts_series(state):
 def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes=None):
     target_classes = target_classes or state["selectedClasses"]
     temp_project_seg_dir = g.project_seg_dir + "_temp"
-    sly.Project.to_segmentation_task(
-        g.project_dir, temp_project_seg_dir, target_classes=target_classes
-    )
+
+    project = sly.Project(g.project_dir, sly.OpenMode.READ)
+    with TqdmProgress(
+        message="Converting project to segmentation task",
+        total=project.total_items,
+    ) as p:
+        sly.Project.to_segmentation_task(
+            g.project_dir,
+            temp_project_seg_dir,
+            target_classes=target_classes,
+            progress_cb=p.update,
+        )
+
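+    # LUT over packed 24-bit RGB keys (R<<16 | G<<8 | B) -> class index; colors not in the palette map to 0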
+    palette_lookup = np.zeros(256**3, dtype=np.int32)
+    for idx, color in enumerate(palette, 1):
+        key = (color[0] << 16) | (color[1] << 8) | color[2]
+        palette_lookup[key] = idx

     datasets = os.listdir(temp_project_seg_dir)
     os.makedirs(os.path.join(g.project_seg_dir, img_dir), exist_ok=True)
     os.makedirs(os.path.join(g.project_seg_dir, ann_dir), exist_ok=True)
-    for dataset in datasets:
-        if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)):
-            if dataset == "meta.json":
-                shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir)
-            continue
-        # convert masks to required format and save to general ann_dir
-        mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir))
-        for mask_file in mask_files:
-            mask = cv2.imread(os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file))[
-                :, :, ::-1
-            ]
-            result = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int32)
-            # human masks to machine masks
-            for color_idx, color in enumerate(palette):
-                colormap = np.where(np.all(mask == color, axis=-1))
-                result[colormap] = color_idx
-            cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result)
-
-        imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir))
-        for filename in imgfiles_to_move:
-            shutil.move(
-                os.path.join(temp_project_seg_dir, dataset, img_dir, filename),
-                os.path.join(g.project_seg_dir, img_dir),
-            )
+
+    with TqdmProgress(
+        message="Converting masks to required format",
+        total=project.total_items,
+    ) as p:
+        for dataset in datasets:
+            if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)):
+                if dataset == "meta.json":
+                    shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir)
+                continue
+            # convert masks to required format and save to general ann_dir
+            mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir))
+            for mask_file in mask_files:
+                path = os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file)
+                mask = cv2.imread(path)[:, :, ::-1]
+
+                mask_keys = (
+                    (mask[:, :, 0].astype(np.int32) << 16)
+                    | (mask[:, :, 1].astype(np.int32) << 8)
+                    | mask[:, :, 2].astype(np.int32)
+                )
+                result = palette_lookup[mask_keys]
+                cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result)
+
+                p.update(1)
+
+            imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir))
+            for filename in imgfiles_to_move:
+                shutil.move(
+                    os.path.join(temp_project_seg_dir, dataset, img_dir, filename),
+                    os.path.join(g.project_seg_dir, img_dir),
+                )

     shutil.rmtree(temp_project_seg_dir)
     g.api.app.set_field(g.task_id, "state.preparingData", False)
@@ -342,145 +389,150 @@ def run_benchmark(api: sly.Api, task_id, classes, cfg, state, remote_dir):
     import asyncio

     dataset_infos = api.dataset.get_list(g.project_id, recursive=True)
-    # creating_report.show()
-
-    # 0. Find the best checkpoint
-    best_filename = None
-    best_checkpoints = []
-    latest_checkpoint = None
-    other_checkpoints = []
-    for root, dirs, files in os.walk(g.checkpoints_dir):
-        for file_name in files:
-            path = os.path.join(root, file_name)
-            if file_name.endswith(".pth"):
-                if file_name.startswith("best_"):
-                    best_checkpoints.append(path)
-                elif file_name == "latest.pth":
-                    latest_checkpoint = path
-                elif file_name.startswith("epoch_"):
-                    other_checkpoints.append(path)
-
-    if len(best_checkpoints) > 1:
-        best_checkpoints = sorted(best_checkpoints, key=lambda x: x, reverse=True)
-    elif len(best_checkpoints) == 0:
-        sly.logger.info("Best model checkpoint not found in the checkpoints directory.")
-        if latest_checkpoint is not None:
-            best_checkpoints = [latest_checkpoint]
-            sly.logger.info(f"Using latest checkpoint for evaluation: {latest_checkpoint!r}")
-        elif len(other_checkpoints) > 0:
-            parse_epoch = lambda x: int(x.split("_")[-1].split(".")[0])
-            best_checkpoints = sorted(other_checkpoints, key=parse_epoch, reverse=True)
-            sly.logger.info(
-                f"Using the last epoch checkpoint for evaluation: {best_checkpoints[0]!r}"
-            )
-    if len(best_checkpoints) == 0:
-        raise ValueError("No checkpoints found for evaluation.")
-    best_checkpoint = Path(best_checkpoints[0])
-    sly.logger.info(f"Starting model benchmark with the checkpoint: {best_checkpoint!r}")
-    best_filename = best_checkpoint.name
-    workdir = best_checkpoint.parent
-
-    # 1. Serve trained model
-    m = MMSegmentationModelBench(model_dir=str(workdir), use_gui=False)
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    sly.logger.info(f"Using device: {device}")
-
-    checkpoint_path = g.sly_mmseg.get_weights_path(remote_dir) + "/" + best_filename
-    config_path = g.sly_mmseg.get_config_path(remote_dir)
-    sly.logger.info(f"Checkpoint path: {checkpoint_path}")
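+    # one-step progress bar so the UI shows activity while the trained model is prepared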
+    dummy_pbar = TqdmProgress
+    with dummy_pbar(message="Preparing trained model for benchmark", total=1) as p:
+        # 0. Find the best checkpoint
+        best_filename = None
+        best_checkpoints = []
+        latest_checkpoint = None
+        other_checkpoints = []
+        for root, dirs, files in os.walk(g.checkpoints_dir):
+            for file_name in files:
+                path = os.path.join(root, file_name)
+                if file_name.endswith(".pth"):
+                    if file_name.startswith("best_"):
+                        best_checkpoints.append(path)
+                    elif file_name == "latest.pth":
+                        latest_checkpoint = path
+                    elif file_name.startswith("epoch_"):
+                        other_checkpoints.append(path)
+
+        if len(best_checkpoints) > 1:
+            best_checkpoints = sorted(best_checkpoints, key=lambda x: x, reverse=True)
+        elif len(best_checkpoints) == 0:
+            sly.logger.info("Best model checkpoint not found in the checkpoints directory.")
+            if latest_checkpoint is not None:
+                best_checkpoints = [latest_checkpoint]
+                sly.logger.info(
+                    f"Using latest checkpoint for evaluation: {latest_checkpoint!r}"
+                )
+            elif len(other_checkpoints) > 0:
+                parse_epoch = lambda x: int(x.split("_")[-1].split(".")[0])
+                best_checkpoints = sorted(other_checkpoints, key=parse_epoch, reverse=True)
+                sly.logger.info(
+                    f"Using the last epoch checkpoint for evaluation: {best_checkpoints[0]!r}"
+                )
-    try:
-        arch_type = cfg.model.backbone.type
-    except Exception as e:
-        arch_type = "unknown"
-
-    sly.logger.info(f"Model architecture: {arch_type}")
-
-    deploy_params = dict(
-        device=device,
-        model_source="Custom models",
-        task_type=sly.nn.TaskType.SEMANTIC_SEGMENTATION,
-        checkpoint_name=best_filename,
-        checkpoint_url=checkpoint_path,
-        config_url=config_path,
-        arch_type=arch_type,
-    )
-    m._load_model(deploy_params)
-    asyncio.set_event_loop(asyncio.new_event_loop())
-    m.serve()
+
+        if len(best_checkpoints) == 0:
+            raise ValueError("No checkpoints found for evaluation.")
+        best_checkpoint = Path(best_checkpoints[0])
+        sly.logger.info(f"Starting model benchmark with the checkpoint: {best_checkpoint!r}")
+        best_filename = best_checkpoint.name
+        workdir = best_checkpoint.parent
-    import requests
-    import uvicorn
-    import time
-    from threading import Thread
+
+        # 1. Serve trained model
+        m = MMSegmentationModelBench(model_dir=str(workdir), use_gui=False)
-    def run_app():
-        uvicorn.run(m.app, host="localhost", port=8000)
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        sly.logger.info(f"Using device: {device}")
-    thread = Thread(target=run_app, daemon=True)
-    thread.start()
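+
+        # remote paths of the uploaded checkpoint and training config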
+        checkpoint_path = g.sly_mmseg.get_weights_path(remote_dir) + "/" + best_filename
+        config_path = g.sly_mmseg.get_config_path(remote_dir)
+        sly.logger.info(f"Checkpoint path: {checkpoint_path}")
-    while True:
-        try:
-            requests.get("http://localhost:8000")
-            print("✅ Local server is ready")
-            break
-        except requests.exceptions.ConnectionError:
-            print("Waiting for the server to be ready")
-            time.sleep(0.1)
-
-    session = SessionJSON(api, session_url="http://localhost:8000")
-    if sly.fs.dir_exists(g.data_dir + "/benchmark"):
-        sly.fs.remove_dir(g.data_dir + "/benchmark")
-
-    # 1. Init benchmark (todo: auto-detect task type)
-    benchmark_dataset_ids = None
-    benchmark_images_ids = None
-    train_dataset_ids = None
-    train_images_ids = None
-
-    split_method = state["splitMethod"]
-
-    if split_method == "datasets":
-        train_datasets = state["trainDatasets"]
-        val_datasets = state["valDatasets"]
-        benchmark_dataset_ids = [ds.id for ds in dataset_infos if ds.name in val_datasets]
-        train_dataset_ids = [ds.id for ds in dataset_infos if ds.name in train_datasets]
-    else:
-
-        def get_image_infos_by_split(split: list):
-            ds_infos_dict = {ds_info.name: ds_info for ds_info in dataset_infos}
-            image_names_per_dataset = {}
-            for item in split:
-                name = item.name
-                if name[1] == "_":
-                    name = name[2:]
-                image_names_per_dataset.setdefault(item.dataset_name, []).append(name)
-            image_infos = []
-            for dataset_name, image_names in image_names_per_dataset.items():
-                ds_info = ds_infos_dict[dataset_name]
-                image_infos.extend(
-                    api.image.get_list(
-                        ds_info.id,
-                        filters=[
-                            {
-                                "field": "name",
-                                "operator": "in",
-                                "value": image_names,
-                            }
-                        ],
+
+        try:
+            arch_type = cfg.model.backbone.type
+        except Exception as e:
+            arch_type = "unknown"
+
+        sly.logger.info(f"Model architecture: {arch_type}")
+
+        deploy_params = dict(
+            device=device,
+            model_source="Custom models",
+            task_type=sly.nn.TaskType.SEMANTIC_SEGMENTATION,
+            checkpoint_name=best_filename,
+            checkpoint_url=checkpoint_path,
+            config_url=config_path,
+            arch_type=arch_type,
+        )
+        m._load_model(deploy_params)
+        asyncio.set_event_loop(asyncio.new_event_loop())
+        m.serve()
+
+        import requests
+        import uvicorn
+        import time
+        from threading import Thread
+
+        def run_app():
+            uvicorn.run(m.app, host="localhost", port=8000)
+
+        thread = Thread(target=run_app, daemon=True)
+        thread.start()
+
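+        # poll the local endpoint until the inference server starts responding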
+        while True:
+            try:
+                requests.get("http://localhost:8000")
+                print("✅ Local server is ready")
+                break
+            except requests.exceptions.ConnectionError:
+                print("Waiting for the server to be ready")
+                time.sleep(0.1)
+
+        session = SessionJSON(api, session_url="http://localhost:8000")
+        if sly.fs.dir_exists(g.data_dir + "/benchmark"):
+            sly.fs.remove_dir(g.data_dir + "/benchmark")
+
+        # 1. Init benchmark (todo: auto-detect task type)
+        benchmark_dataset_ids = None
+        benchmark_images_ids = None
+        train_dataset_ids = None
+        train_images_ids = None
+
+        split_method = state["splitMethod"]
+
+        if split_method == "datasets":
+            train_datasets = state["trainDatasets"]
+            val_datasets = state["valDatasets"]
+            benchmark_dataset_ids = [ds.id for ds in dataset_infos if ds.name in val_datasets]
+            train_dataset_ids = [ds.id for ds in dataset_infos if ds.name in train_datasets]
+        else:
+
+            def get_image_infos_by_split(split: list):
+                ds_infos_dict = {ds_info.name: ds_info for ds_info in dataset_infos}
+                image_names_per_dataset = {}
+                for item in split:
+                    name = item.name
+                    if name[1] == "_":
+                        name = name[2:]
+                    image_names_per_dataset.setdefault(item.dataset_name, []).append(name)
+                image_infos = []
+                for dataset_name, image_names in image_names_per_dataset.items():
+                    ds_info = ds_infos_dict[dataset_name]
+                    image_infos.extend(
+                        api.image.get_list(
+                            ds_info.id,
+                            filters=[
+                                {
+                                    "field": "name",
+                                    "operator": "in",
+                                    "value": image_names,
+                                }
+                            ],
+                        )
+                    )
-                    )
-                )
-            return image_infos
+                return image_infos
+
+            train_set, val_set = get_train_val_sets(g.project_dir, state)
-    train_set, val_set = get_train_val_sets(g.project_dir, state)
-
-    val_image_infos = get_image_infos_by_split(val_set)
-    train_image_infos = get_image_infos_by_split(train_set)
-    benchmark_images_ids = [img_info.id for img_info in val_image_infos]
-    train_images_ids = [img_info.id for img_info in train_image_infos]
+            val_image_infos = get_image_infos_by_split(val_set)
+            train_image_infos = get_image_infos_by_split(train_set)
+            benchmark_images_ids = [img_info.id for img_info in val_image_infos]
+            train_images_ids = [img_info.id for img_info in train_image_infos]
+
+        p.update(1)

     pbar = TqdmBenchmark
     bm = sly.nn.benchmark.SemanticSegmentationBenchmark(
@@ -540,7 +592,6 @@ def get_image_infos_by_split(split: list):
     benchmark_report_template = bm.report

     fields = [
-        {"field": f"data.progressBenchmark", "payload": False},
         {"field": f"state.benchmarkInProgress", "payload": False},
         {"field": f"data.benchmarkUrl", "payload": bm.get_report_link()},
     ]
@@ -550,7 +601,6 @@ def get_image_infos_by_split(split: list):
         )
     except Exception as e:
         sly.logger.error(f"Model benchmark failed. {repr(e)}", exc_info=True)
-        api.app.set_field(task_id, "data.progressBenchmark", False)
         try:
             if bm.dt_project_info:
                 api.project.remove(bm.dt_project_info.id)