Skip to content

Commit

Permalink
Update workflow_input
Browse files Browse the repository at this point in the history
  • Loading branch information
GoldenAnpu committed Sep 6, 2024
1 parent 3975d41 commit e330e80
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions serve/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ def load_model(
f"Config file not found: {config_url}. "
"Config should be placed in the same directory as the checkpoint file."
)

try:
cfg = Config.fromfile(local_config_path)
cfg.model.pretrained = None
Expand All @@ -216,6 +215,15 @@ def load_model(

self.model.cfg = cfg # save the config in the model for convenience
self.model.to(device)
# -------------------------------------- Add Workflow Input -------------------------------------- #
sly.logger.debug("Workflow: Start processing Input")
if model_source == "Custom models":
sly.logger.debug("Workflow: Custom model detected")
w.workflow_input(api, checkpoint_url)
else:
sly.logger.debug("Workflow: Pretrained model detected. No need to set Input")
sly.logger.debug("Workflow: Finish processing Input")
# ----------------------------------------------- - ---------------------------------------------- #
self.model.eval()
self.model = revert_sync_batchnorm(self.model)

Expand Down Expand Up @@ -373,13 +381,6 @@ def predict(
# this code block is running on Supervisely platform in production
# just ignore it during development
m.serve()
sly.logger.debug("Workflow: Start processing Input")
if m.model_source_tabs.get_active_tab() == "Custom models":
sly.logger.debug("Workflow: Custom model detected")
w.workflow_input(api, m.get_params_from_gui()["checkpoint_url"])
else:
sly.logger.debug("Workflow: Pretrained model detected. No need to set Input")
sly.logger.debug("Workflow: Finish processing Input")
else:
# for local development and debugging without GUI
models = m.get_models(add_links=True)
Expand Down

0 comments on commit e330e80

Please sign in to comment.