Skip to content

Commit

Permalink
Merge pull request #1 from supervisely-ecosystem/fix-session-id-key-e…
Browse files Browse the repository at this point in the history
…rror

Enhance model connection handling and add session validation warnings
  • Loading branch information
GoldenAnpu authored Dec 21, 2024
2 parents d98ff10 + fe00f26 commit aabb367
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
39 changes: 22 additions & 17 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
PREVIEW_IMAGES_INFOS = []
CURRENT_REF_IMAGE_INDEX = 0
REF_IMAGE_HISTORY = [CURRENT_REF_IMAGE_INDEX]
F_MODEL_DATA = {}
S_MODEL_DATA = {}
F_MODEL_DATA = {"session_id": None, "model_meta": None}
S_MODEL_DATA = {"session_id": None, "model_meta": None}


# fetching some images for preview
Expand Down Expand Up @@ -154,7 +154,6 @@ def download_data():
florence_set_model_type_button.text = "Select model"
sam2_set_model_type_button.disable()
sam2_set_model_type_button.text = "Select model"
set_prompts_button.disable()
set_prompts_button.text = "Set settings"
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
Expand Down Expand Up @@ -313,7 +312,6 @@ def set_florence_model_type():
toggle_cards(
["prompt_for_predictions_card", "preview_card", "apply_to_project_card"], enabled=False
)
set_prompts_button.disable()
set_prompts_button.text = "Set settings"
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
Expand All @@ -340,18 +338,17 @@ def set_florence_model_type():
florence_set_model_type_button.text = "Disconnect model"
florence_set_model_type_button._plain = True
florence_set_model_type_button.enable()
toggle_cards(["prompt_for_predictions_card"], enabled=True)
set_prompts_button.enable()
if sam2_set_model_type_button.text == "Disconnect model":
stepper.set_active_step(3)
toggle_cards(["prompt_for_predictions_card"], enabled=True)
prompt_for_predictions_card.uncollapse()
except Exception as e:
sly.app.show_dialog(
"Error",
f"Cannot to connect to model. Make sure that model is deployed and try again.",
status="error",
)
sly.logger.warn(f"Cannot to connect to model. {e}")
sly.logger.warning(f"Cannot to connect to model. {e}")
florence_set_model_type_button.enable()
florence_set_model_type_button.text = "Connect to model"
florence_model_set_done.hide()
Expand All @@ -361,7 +358,6 @@ def set_florence_model_type():
["prompt_for_predictions_card", "preview_card", "apply_to_project_card"],
enabled=False,
)
set_prompts_button.disable()
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
apply_to_project_button.disable()
Expand All @@ -382,7 +378,6 @@ def set_sam2_model_type():
toggle_cards(
["prompt_for_predictions_card", "preview_card", "apply_to_project_card"], enabled=False
)
set_prompts_button.disable()
set_prompts_button.text = "Set settings"
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
Expand Down Expand Up @@ -410,18 +405,17 @@ def set_sam2_model_type():
sam2_set_model_type_button.text = "Disconnect model"
sam2_set_model_type_button._plain = True
sam2_set_model_type_button.enable()
toggle_cards(["prompt_for_predictions_card"], enabled=True)
set_prompts_button.enable()
if florence_set_model_type_button.text == "Disconnect model":
stepper.set_active_step(3)
toggle_cards(["prompt_for_predictions_card"], enabled=True)
prompt_for_predictions_card.uncollapse()
except Exception as e:
sly.app.show_dialog(
"Error",
f"Cannot to connect to model. Make sure that model is deployed and try again.",
status="error",
)
sly.logger.warn(f"Cannot to connect to model. {e}")
sly.logger.warning(f"Cannot to connect to model. {e}")
sam2_set_model_type_button.enable()
sam2_set_model_type_button.text = "Connect to model"
sam2_model_set_done.hide()
Expand All @@ -431,7 +425,6 @@ def set_sam2_model_type():
["prompt_for_predictions_card", "preview_card", "apply_to_project_card"],
enabled=False,
)
set_prompts_button.disable()
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
apply_to_project_button.disable()
Expand Down Expand Up @@ -596,7 +589,14 @@ def update_images_preview():

@get_predictions_preview_button.click
def get_and_update_predictions_preview():

if F_MODEL_DATA.get("session_id") is None or S_MODEL_DATA.get("session_id") is None:
sly.app.show_dialog(
"Warning",
"Please connect to both models before getting predictions preview.",
status="warning",
)
sly.logger.warning("Please connect to both models before getting predictions preview.")
return
new_random_images_preview_button.disable()
if g.force_common_tab or inference_prompt_types.get_active_tab() == common_tab_name:
temp_meta = sly.ProjectMeta()
Expand Down Expand Up @@ -734,7 +734,6 @@ def run_model():
)
g.save_bboxes = save_bbox_switch.is_on()
button_download_data.disable()
set_prompts_button.disable()
florence_set_model_type_button.disable()
sam2_set_model_type_button.disable()
new_random_images_preview_button.disable()
Expand All @@ -757,7 +756,13 @@ def get_inference_settings():
output_project_info = apply_to_project_event(
destination_project, get_inference_settings(), F_MODEL_DATA, S_MODEL_DATA
)

if output_project_info is None:
sly.logger.warning(
"Something went wrong during applying models to project."
"Project thumbnail will not be shown. "
"Check logs for more information."
)
return
output_project_thmb.set(output_project_info)
output_project_thmb.show()
sly.logger.info("Project was successfully labeled")
Expand Down Expand Up @@ -797,6 +802,7 @@ def set_card_state(card, state, elements=[]):
prompt_common_container,
inference_prompt_types,
prompt_common_input,
set_prompts_button,
# prompt_classes_mapping, # TODO enable after implementation
],
),
Expand Down Expand Up @@ -824,7 +830,6 @@ def set_card_state(card, state, elements=[]):
)
florence_set_model_type_button.disable()
sam2_set_model_type_button.disable()
set_prompts_button.disable()
apply_to_project_button.disable()

stepper = Stepper(
Expand Down
8 changes: 8 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ def apply_to_project_event(
F_MODEL_DATA: dict,
S_MODEL_DATA: dict,
):
if F_MODEL_DATA.get("session_id") is None or S_MODEL_DATA.get("session_id") is None:
sly.app.show_dialog(
"Warning",
f"Please connect to both models before applying them to the project",
status="warning",
)
sly.logger.warning("Please connect to both models before applying them to the project")
return

def update_proj_meta_classes(
ann: dict,
Expand Down

0 comments on commit aabb367

Please sign in to comment.