forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
80ed3ac
commit af987e6
Showing
9 changed files
with
166 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import supervisely as sly | ||
|
||
|
||
def check_compatibility(func): | ||
def wrapper(self, *args, **kwargs): | ||
if self.is_compatible is None: | ||
self.is_compatible = self.check_instance_ver_compatibility() | ||
if not self.is_compatible: | ||
return | ||
return func(self, *args, **kwargs) | ||
|
||
return wrapper | ||
|
||
|
||
class Workflow: | ||
def __init__(self, api: sly.Api, min_instance_version: str = None): | ||
self.is_compatible = None | ||
self.api = api | ||
self._min_instance_version = ( | ||
"6.9.31" if min_instance_version is None else min_instance_version | ||
) | ||
|
||
def check_instance_ver_compatibility(self): | ||
if not self.api.is_version_supported(self._min_instance_version): | ||
sly.logger.info( | ||
f"Supervisely instance version {self.api.instance_version} does not support workflow features." | ||
) | ||
if not sly.is_community(): | ||
sly.logger.info( | ||
f"To use them, please update your instance to version {self._min_instance_version} or higher." | ||
) | ||
return False | ||
return True | ||
|
||
@check_compatibility | ||
def add_input(self, checkpoint_url: str): | ||
meta = {"customNodeSettings": {"title": "<h4>Serve Custom Model</h4>"}} | ||
sly.logger.debug(f"Workflow Input: Checkpoint URL - {checkpoint_url}") | ||
if checkpoint_url and self.api.file.exists(sly.env.team_id(), checkpoint_url): | ||
self.api.app.workflow.add_input_file(checkpoint_url, model_weight=True, meta=meta) | ||
else: | ||
sly.logger.debug(f"Checkpoint {checkpoint_url} not found in Team Files. Cannot set workflow input") | ||
|
||
@check_compatibility | ||
def add_output(self): | ||
raise NotImplementedError("add_output is not implemented in this workflow") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import supervisely as sly | ||
import os | ||
|
||
api = sly.Api.from_env() | ||
api.task_id = 62688 | ||
os.environ.setdefault("TEAM_ID", "451") | ||
from workflow import Workflow | ||
|
||
workflow = Workflow(api) | ||
|
||
state = {"weightsInitialization": "custom"} | ||
|
||
team_files_dir = "/yolov5_train/Train dataset - Eschikon Wheat Segmentation (EWS)/62688" | ||
|
||
workflow.add_output(state, team_files_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Description: This file contains versioning features and the Workflow class that is used to add input and output to the workflow. | ||
|
||
import supervisely as sly | ||
import os | ||
|
||
|
||
def check_compatibility(func): | ||
def wrapper(self, *args, **kwargs): | ||
if self.is_compatible is None: | ||
self.is_compatible = self.check_instance_ver_compatibility() | ||
if not self.is_compatible: | ||
return | ||
return func(self, *args, **kwargs) | ||
|
||
return wrapper | ||
|
||
|
||
class Workflow: | ||
def __init__(self, api: sly.Api, min_instance_version: str = None): | ||
self.is_compatible = None | ||
self.api = api | ||
self._min_instance_version = ( | ||
"6.9.31" if min_instance_version is None else min_instance_version | ||
) | ||
|
||
def check_instance_ver_compatibility(self): | ||
if not self.api.is_version_supported(self._min_instance_version): | ||
sly.logger.info( | ||
f"Supervisely instance version {self.api.instance_version} does not support workflow and versioning features." | ||
) | ||
if not sly.is_community(): | ||
sly.logger.info( | ||
f"To use them, please update your instance to version {self._min_instance_version} or higher." | ||
) | ||
return False | ||
return True | ||
|
||
@check_compatibility | ||
def add_input(self, project_info: sly.ProjectInfo, state: dict): | ||
project_version_id = self.api.project.version.create( | ||
project_info, "Train YOLO v5", f"This backup was created automatically by Supervisely before the Train YOLO task with ID: {self.api.task_id}" | ||
) | ||
if project_version_id is None: | ||
project_version_id = project_info.version.get("id", None) if project_info.version else None | ||
self.api.app.workflow.add_input_project(project_info.id, version_id=project_version_id) | ||
if state["weightsInitialization"] is not None and state["weightsInitialization"] == "custom": | ||
file_info = self.api.file.get_info_by_path(sly.env.team_id(), state["_weightsPath"]) | ||
self.api.app.workflow.add_input_file(file_info, model_weight=True) | ||
sly.logger.debug(f"Workflow Input: Project ID - {project_info.id}, Project Version ID - {project_version_id}, Input File - {True if file_info else False}") | ||
|
||
@check_compatibility | ||
def add_output(self, state: dict, team_files_dir: str): | ||
weights_dir_in_team_files = os.path.join(team_files_dir, "weights") | ||
files_info = self.api.file.list(sly.env.team_id(), weights_dir_in_team_files, return_type="fileinfo") | ||
best_filename_info = None | ||
for file_info in files_info: | ||
if "best" in file_info.name: | ||
best_filename_info = file_info | ||
break | ||
if best_filename_info: | ||
module_id = self.api.task.get_info_by_id(self.api.task_id).get("meta", {}).get("app", {}).get("id") | ||
if state["weightsInitialization"] is not None and state["weightsInitialization"] == "custom": | ||
model_name = "Custom Model" | ||
else: | ||
model_name = "YOLOv5" | ||
|
||
meta = { | ||
"customNodeSettings": { | ||
"title": f"<h4>Train {model_name}</h4>", | ||
"mainLink": { | ||
"url": f"/apps/{module_id}/sessions/{self.api.task_id}" if module_id else f"apps/sessions/{self.api.task_id}", | ||
"title": "Show Results" | ||
} | ||
}, | ||
"customRelationSettings": { | ||
"icon": { | ||
"icon": "zmdi-folder", | ||
"color": "#FFA500", | ||
"backgroundColor": "#FFE8BE" | ||
}, | ||
"title": "<h4>Checkpoints</h4>", | ||
"mainLink": {"url": f"/files/{best_filename_info.id}/true", "title": "Open Folder"} | ||
} | ||
} | ||
sly.logger.debug(f"Workflow Output: Team Files dir - {team_files_dir}, Best filename - {best_filename_info.name}") | ||
sly.logger.debug(f"Workflow Output: meta \n {meta}") | ||
self.api.app.workflow.add_output_file(best_filename_info, model_weight=True, meta=meta) | ||
else: | ||
sly.logger.debug(f"File with the best weighs not found in Team Files. Cannot set workflow output.") | ||
|