Skip to content

Commit

Permalink
add a little bit of typing
Browse files Browse the repository at this point in the history
  • Loading branch information
TheoLisin committed Sep 26, 2023
1 parent e955ef5 commit 4c40267
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
4 changes: 3 additions & 1 deletion agent/worker/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding: utf-8

import time
from typing import Dict
import docker
import json
import threading
Expand All @@ -20,6 +21,7 @@
import torch

from worker import constants
from worker.task_sly import TaskSly
from worker.task_factory import create_task, is_task_type
from worker.logs_to_rpc import add_task_handler
from worker.agent_utils import LogQueue
Expand All @@ -45,7 +47,7 @@ def __init__(self):
self.logger.info("Agent comes back...")

self.task_pool_lock = threading.Lock()
self.task_pool = {} # task_id -> task_manager (process_id)
self.task_pool: Dict[int, TaskSly] = {} # task_id -> task_manager (process_id)

self.thread_pool = ThreadPoolExecutor(max_workers=10)
self.thread_list = []
Expand Down
67 changes: 36 additions & 31 deletions agent/worker/task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import base64
import json
import time
from typing import Dict, Optional, Type

import supervisely_lib as sly

from docker import DockerClient

from worker import constants
from worker.task_sly import TaskSly
from worker.task_dockerized import TaskDockerized
from worker.task_dtl import TaskDTL
from worker.task_import import TaskImport
Expand All @@ -26,65 +30,66 @@
from worker.task_app import TaskApp


_task_class_mapping = {
'export': TaskDTL,
'import': TaskImport,
'upload_model': TaskUploadNN,
'train': TaskTrain,
'inference': TaskInference,
'cleanup': TaskCleanNode,
'smarttool': TaskInferenceRPC, # for compatibility
'infer_rpc': TaskInferenceRPC,
'upload_images': TaskUploadImages,
'import_agent': TaskImportLocal,
'custom': TaskCustom,
'update_agent': TaskUpdate,
'python': TaskPython,
'general_plugin': TaskPlugin,
'general_plugin_import_agent': TaskPluginImportLocal,
'app': TaskApp
_task_class_mapping: Dict[str, Type[TaskSly]] = {
"export": TaskDTL,
"import": TaskImport,
"upload_model": TaskUploadNN,
"train": TaskTrain,
"inference": TaskInference,
"cleanup": TaskCleanNode,
"smarttool": TaskInferenceRPC, # for compatibility
"infer_rpc": TaskInferenceRPC,
"upload_images": TaskUploadImages,
"import_agent": TaskImportLocal,
"custom": TaskCustom,
"update_agent": TaskUpdate,
"python": TaskPython,
"general_plugin": TaskPlugin,
"general_plugin_import_agent": TaskPluginImportLocal,
"app": TaskApp,
}


def create_task(task_msg, docker_api):
task_id = task_msg.get('task_id', None)
def create_task(task_msg, docker_api: DockerClient) -> TaskSly:
task_id = task_msg.get("task_id", None)
task_type = get_run_mode(docker_api, task_msg)
task_cls = _task_class_mapping.get(task_type, None)
if task_cls is None:
sly.logger.critical('unknown task type', extra={'task_msg': task_msg})
raise RuntimeError('unknown task type')
sly.logger.critical("unknown task type", extra={"task_msg": task_msg})
raise RuntimeError("unknown task type")
task_obj = task_cls(task_msg)
if issubclass(task_cls, TaskDockerized) or (task_msg['task_type'] == 'update_agent'):
if issubclass(task_cls, TaskDockerized) or (task_msg["task_type"] == "update_agent"):
task_obj.docker_api = docker_api
return task_obj


def get_run_mode(docker_api, task_msg):
if "docker_image" not in task_msg:
return task_msg['task_type']
return task_msg["task_type"]

temp_msg = {**task_msg, 'pull_policy': constants.PULL_POLICY()}
temp_msg = {**task_msg, "pull_policy": constants.PULL_POLICY()}
task_pull = TaskPullDockerImage(temp_msg)
task_pull.docker_api = docker_api
task_pull.start()
while task_pull.is_alive():
time.sleep(1)
#@TODO: check later
#task_pull.join(timeout=20)
#task_pull.terminate()
# @TODO: check later
# task_pull.join(timeout=20)
# task_pull.terminate()

image_info = docker_api.images.get(task_msg["docker_image"])
try:
plugin_info = json.loads(base64.b64decode(image_info.labels["INFO"]).decode("utf-8"))
except Exception as e:
plugin_info = {}

result = plugin_info.get("run_mode", task_msg['task_type'])
result = plugin_info.get("run_mode", task_msg["task_type"])

if result == 'general_plugin' and task_msg['task_type'] == "import_agent":
return 'general_plugin_import_agent'
if result == "general_plugin" and task_msg["task_type"] == "import_agent":
return "general_plugin_import_agent"

return result


def is_task_type(task_obj, task_name):
return type(task_obj) is _task_class_mapping[task_name]
return type(task_obj) is _task_class_mapping[task_name]

0 comments on commit 4c40267

Please sign in to comment.