Skip to content

Commit

Permalink
Enhance model connection handling and add session validation warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
GoldenAnpu committed Dec 21, 2024
1 parent cc17ebf commit 84bda69
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
30 changes: 14 additions & 16 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
PREVIEW_IMAGES_INFOS = []
CURRENT_REF_IMAGE_INDEX = 0
REF_IMAGE_HISTORY = [CURRENT_REF_IMAGE_INDEX]
F_MODEL_DATA = {}
S_MODEL_DATA = {}
F_MODEL_DATA = {"session_id": None, "model_meta": None}
S_MODEL_DATA = {"session_id": None, "model_meta": None}


# fetching some images for preview
Expand Down Expand Up @@ -154,7 +154,6 @@ def download_data():
florence_set_model_type_button.text = "Select model"
sam2_set_model_type_button.disable()
sam2_set_model_type_button.text = "Select model"
set_prompts_button.disable()
set_prompts_button.text = "Set settings"
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
Expand Down Expand Up @@ -313,7 +312,6 @@ def set_florence_model_type():
toggle_cards(
["prompt_for_predictions_card", "preview_card", "apply_to_project_card"], enabled=False
)
set_prompts_button.disable()
set_prompts_button.text = "Set settings"
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
Expand All @@ -340,18 +338,17 @@ def set_florence_model_type():
florence_set_model_type_button.text = "Disconnect model"
florence_set_model_type_button._plain = True
florence_set_model_type_button.enable()
toggle_cards(["prompt_for_predictions_card"], enabled=True)
set_prompts_button.enable()
if sam2_set_model_type_button.text == "Disconnect model":
stepper.set_active_step(3)
toggle_cards(["prompt_for_predictions_card"], enabled=True)
prompt_for_predictions_card.uncollapse()
except Exception as e:
sly.app.show_dialog(
"Error",
f"Cannot to connect to model. Make sure that model is deployed and try again.",
status="error",
)
sly.logger.warn(f"Cannot to connect to model. {e}")
sly.logger.warning(f"Cannot to connect to model. {e}")
florence_set_model_type_button.enable()
florence_set_model_type_button.text = "Connect to model"
florence_model_set_done.hide()
Expand All @@ -361,7 +358,6 @@ def set_florence_model_type():
["prompt_for_predictions_card", "preview_card", "apply_to_project_card"],
enabled=False,
)
set_prompts_button.disable()
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
apply_to_project_button.disable()
Expand All @@ -382,7 +378,6 @@ def set_sam2_model_type():
toggle_cards(
["prompt_for_predictions_card", "preview_card", "apply_to_project_card"], enabled=False
)
set_prompts_button.disable()
set_prompts_button.text = "Set settings"
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
Expand Down Expand Up @@ -410,18 +405,17 @@ def set_sam2_model_type():
sam2_set_model_type_button.text = "Disconnect model"
sam2_set_model_type_button._plain = True
sam2_set_model_type_button.enable()
toggle_cards(["prompt_for_predictions_card"], enabled=True)
set_prompts_button.enable()
if florence_set_model_type_button.text == "Disconnect model":
stepper.set_active_step(3)
toggle_cards(["prompt_for_predictions_card"], enabled=True)
prompt_for_predictions_card.uncollapse()
except Exception as e:
sly.app.show_dialog(
"Error",
f"Cannot to connect to model. Make sure that model is deployed and try again.",
status="error",
)
sly.logger.warn(f"Cannot to connect to model. {e}")
sly.logger.warning(f"Cannot to connect to model. {e}")
sam2_set_model_type_button.enable()
sam2_set_model_type_button.text = "Connect to model"
sam2_model_set_done.hide()
Expand All @@ -431,7 +425,6 @@ def set_sam2_model_type():
["prompt_for_predictions_card", "preview_card", "apply_to_project_card"],
enabled=False,
)
set_prompts_button.disable()
new_random_images_preview_button.disable()
get_predictions_preview_button.disable()
apply_to_project_button.disable()
Expand Down Expand Up @@ -596,7 +589,13 @@ def update_images_preview():

@get_predictions_preview_button.click
def get_and_update_predictions_preview():

if F_MODEL_DATA.get("session_id") is None or S_MODEL_DATA.get("session_id") is None:
sly.app.show_dialog(
"Warning",
"Please connect to both models before getting predictions preview.",
status="warning",
)
return
new_random_images_preview_button.disable()
if g.force_common_tab or inference_prompt_types.get_active_tab() == common_tab_name:
temp_meta = sly.ProjectMeta()
Expand Down Expand Up @@ -734,7 +733,6 @@ def run_model():
)
g.save_bboxes = save_bbox_switch.is_on()
button_download_data.disable()
set_prompts_button.disable()
florence_set_model_type_button.disable()
sam2_set_model_type_button.disable()
new_random_images_preview_button.disable()
Expand Down Expand Up @@ -797,6 +795,7 @@ def set_card_state(card, state, elements=[]):
prompt_common_container,
inference_prompt_types,
prompt_common_input,
set_prompts_button,
# prompt_classes_mapping, # TODO enable after implementation
],
),
Expand Down Expand Up @@ -824,7 +823,6 @@ def set_card_state(card, state, elements=[]):
)
florence_set_model_type_button.disable()
sam2_set_model_type_button.disable()
set_prompts_button.disable()
apply_to_project_button.disable()

stepper = Stepper(
Expand Down
6 changes: 6 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def apply_to_project_event(
F_MODEL_DATA: dict,
S_MODEL_DATA: dict,
):
if F_MODEL_DATA.get("session_id") is None or S_MODEL_DATA.get("session_id") is None:
sly.app.show_dialog(
"Warning",
f"Please run both models before applying them to the project",
status="warning",
)

def update_proj_meta_classes(
ann: dict,
Expand Down

0 comments on commit 84bda69

Please sign in to comment.