From 39172a3e79871224fa55cfcbd403d01f4bc9bfc9 Mon Sep 17 00:00:00 2001 From: almaz Date: Fri, 13 Dec 2024 16:20:57 +0100 Subject: [PATCH] fix bg class name and color. Use local checkpoint --- serve/src/mmsegm_model.py | 9 +++++---- train/src/ui/monitoring.py | 9 +++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/serve/src/mmsegm_model.py b/serve/src/mmsegm_model.py index e3e57e5..6533b86 100644 --- a/serve/src/mmsegm_model.py +++ b/serve/src/mmsegm_model.py @@ -167,10 +167,11 @@ def load_model( ) local_config_path = os.path.join(root_source_path, config_url) else: - self.download( - src_path=checkpoint_url, - dst_path=local_weights_path, - ) + if not sly.fs.file_exists(local_weights_path): + self.download( + src_path=checkpoint_url, + dst_path=local_weights_path, + ) local_config_path = os.path.join(configs_dir, "custom", "config.py") if sly.fs.file_exists(local_config_path): silent_remove(local_config_path) diff --git a/train/src/ui/monitoring.py b/train/src/ui/monitoring.py index ce48915..53c21cb 100644 --- a/train/src/ui/monitoring.py +++ b/train/src/ui/monitoring.py @@ -319,6 +319,13 @@ def init_class_charts_series(state): def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes=None): target_classes = target_classes or state["selectedClasses"] temp_project_seg_dir = g.project_seg_dir + "_temp" + bg_name = get_bg_class_name(target_classes) or "__bg__" + bg_color = (0, 0, 0) + if bg_name in target_classes: + try: + bg_color = palette[target_classes.index(bg_name)] + except: + pass project = sly.Project(g.project_dir, sly.OpenMode.READ) with TqdmProgress( @@ -330,6 +337,8 @@ def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes=N temp_project_seg_dir, target_classes=target_classes, progress_cb=p.update, + bg_color=bg_color, + bg_name=bg_name, ) palette_lookup = np.zeros(256**3, dtype=np.int32)