Skip to content

Commit

Permalink
add deploy via api
Browse files Browse the repository at this point in the history
  • Loading branch information
cxnt committed Dec 25, 2023
1 parent c981d36 commit 5deddd7
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 25 deletions.
4 changes: 3 additions & 1 deletion serve/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
local.env
__pycache__/
__pycache__/
.venv
supervisely
5 changes: 2 additions & 3 deletions serve/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
supervisely==6.72.143
# git+https://github.com/supervisely/supervisely.git@some-test-branch

# supervisely==6.72.143
git+https://github.com/supervisely/supervisely.git@data-nodes-deploy-yolov8
ultralytics==8.0.112
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.10.1+cu113
Expand Down
66 changes: 45 additions & 21 deletions serve/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,33 +110,57 @@ def load_on_device(
self,
model_dir,
device: Literal["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"] = "cpu",
started_via_api: bool = False,
deploy_params: Dict[str, Any] = None,
):
model_source = self.gui.get_model_source()
if model_source == "Pretrained models":
self.task_type = self.task_type_select.get_value()
selected_model = self.gui.get_checkpoint_info()["Model"]
if selected_model.endswith("det"):
selected_model = selected_model[:-4]
model_filename = selected_model.lower() + ".pt"
weights_url = (
f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}"
)
weights_dst_path = os.path.join(model_dir, model_filename)
if not sly.fs.file_exists(weights_dst_path):
self.download(
src_path=weights_url,
dst_path=weights_dst_path,
)
elif model_source == "Custom models":
self.task_type = self.custom_task_type_select.get_value()
custom_link = self.gui.get_custom_link()
weights_file_name = os.path.basename(custom_link)
if not started_via_api:
model_source = self.gui.get_model_source()
if model_source == "Pretrained models":
self.task_type = self.task_type_select.get_value()
selected_model = self.gui.get_checkpoint_info()["Model"]
if selected_model.endswith("det"):
selected_model = selected_model[:-4]
model_filename = selected_model.lower() + ".pt"
weights_url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}"
weights_dst_path = os.path.join(model_dir, model_filename)
if not sly.fs.file_exists(weights_dst_path):
self.download(
src_path=weights_url,
dst_path=weights_dst_path,
)
elif model_source == "Custom models":
self.task_type = self.custom_task_type_select.get_value()
custom_link = self.gui.get_custom_link()
weights_file_name = os.path.basename(custom_link)
weights_dst_path = os.path.join(model_dir, weights_file_name)
if not sly.fs.file_exists(weights_dst_path):
self.download(
src_path=custom_link,
dst_path=weights_dst_path,
)
else:
# ------------------------ #
model_source = deploy_params["model_source"] # Pretrained models or Custom models
self.task_type = deploy_params["task_type"]
weights_file_name = deploy_params["weights_name"]
weights_dst_path = os.path.join(model_dir, weights_file_name)

if model_source == "Pretrained models":
weights_url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{weights_file_name}"
if not sly.fs.file_exists(weights_dst_path):
self.download(
src_path=weights_url,
dst_path=weights_dst_path,
)

if not sly.fs.file_exists(weights_dst_path):
custom_weights_path = deploy_params["custom_weights_path"]
self.download(
src_path=custom_link,
src_path=custom_weights_path,
dst_path=weights_dst_path,
)
# ------------------------ #

self.model = YOLO(weights_dst_path)
self.class_names = list(self.model.names.values())
if device.startswith("cuda"):
Expand Down

0 comments on commit 5deddd7

Please sign in to comment.