Skip to content

Commit

Permalink
add force cpu only option
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolaiPetukhov committed Feb 20, 2024
1 parent 5de7ac9 commit 6e74a26
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 38 deletions.
47 changes: 18 additions & 29 deletions agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,12 @@ def _start_net_client(docker_api=None):
]

if constants.SLY_EXTRA_CA_CERTS() and os.path.exists(constants.SLY_EXTRA_CA_CERTS()):
envs.append(f"{constants._SLY_EXTRA_CA_CERTS}={constants.SLY_EXTRA_CA_CERTS_FILEPATH()}")
volumes.append(f"{constants.SLY_EXTRA_CA_CERTS_VOLUME_NAME()}:{constants.SLY_EXTRA_CA_CERTS_DIR()}")
envs.append(
f"{constants._SLY_EXTRA_CA_CERTS}={constants.SLY_EXTRA_CA_CERTS_FILEPATH()}"
)
volumes.append(
f"{constants.SLY_EXTRA_CA_CERTS_VOLUME_NAME()}:{constants.SLY_EXTRA_CA_CERTS_DIR()}"
)

log_config = LogConfig(
type="local", config={"max-size": "1m", "max-file": "1", "compress": "false"}
Expand Down Expand Up @@ -178,26 +182,9 @@ def _start_net_client(docker_api=None):
raise e


def _nvidia_runtime_check():
def _is_runtime_changed(new_runtime):
container_info = get_container_info()
runtime = container_info["HostConfig"]["Runtime"]
if runtime == "nvidia":
return False
sly.logger.debug("NVIDIA runtime is not enabled. Checking if it can be enabled...")
docker_api = docker.from_env()
image = constants.DEFAULT_APP_DOCKER_IMAGE()
try:
docker_api.containers.run(
image,
command="nvidia-smi",
runtime="nvidia",
remove=True,
)
sly.logger.debug("NVIDIA runtime is available. Will restart Agent with NVIDIA runtime.")
return True
except Exception as e:
sly.logger.debug("NVIDIA runtime is not available.")
return False
return container_info["HostConfig"]["Runtime"] != new_runtime


def main(args):
Expand Down Expand Up @@ -235,21 +222,23 @@ def init_envs():
"Can not update agent options. Agent will be started with current options"
)
return
restart_with_nvidia_runtime = _nvidia_runtime_check()
if new_envs.get(constants._FORCE_CPU_ONLY, "false") == "true":
runtime = "runc"
runtime_changed = _is_runtime_changed(runtime)
else:
runtime = agent_utils.maybe_update_runtime()
runtime_changed = _is_runtime_changed(runtime)
envs_changes, volumes_changes, new_ca_cert_path = agent_utils.get_options_changes(
new_envs, new_volumes, ca_cert
)
if (
len(envs_changes) > 0
or len(volumes_changes) > 0
or restart_with_nvidia_runtime
or runtime_changed
or new_ca_cert_path is not None
):
docker_api = docker.from_env()
container_info = get_container_info()
runtime = (
"nvidia" if restart_with_nvidia_runtime else container_info["HostConfig"]["Runtime"]
)

# TODO: only set true if some NET_CLIENT envs changed
new_envs[constants._UPDATE_SLY_NET_AFTER_RESTART] = "true"
Expand All @@ -262,9 +251,9 @@ def init_envs():
for k, v in envs_changes.items()
},
"volumes_changes": volumes_changes,
"runtime_changes": {container_info["HostConfig"]["Runtime"]: runtime}
if restart_with_nvidia_runtime
else {},
"runtime_changes": (
{container_info["HostConfig"]["Runtime"]: runtime} if runtime_changed else {}
),
"ca_cert_changed": bool(new_ca_cert_path),
},
)
Expand Down
55 changes: 48 additions & 7 deletions agent/worker/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class AgentOptionsJsonFields:
NET_CLIENT_DOCKER_IMAGE = "dockerImage"
NET_SERVER_PORT = "netServerPort"
DOCKER_IMAGE = "dockerImage"
FORCE_CPU_ONLY = "forceCPUOnly"


def create_img_meta_str(img_size_bytes, width, height):
Expand Down Expand Up @@ -584,7 +585,7 @@ def get_agent_options(server_address=None, token=None, timeout=60) -> dict:

url = constants.PUBLIC_API_SERVER_ADDRESS() + "agents.options.info"
resp = requests.post(url=url, json={"token": token}, timeout=timeout)
if resp.status_code != requests.codes.ok: # pylint: disable=no-member
if resp.status_code != requests.codes.ok: # pylint: disable=no-member
try:
text = resp.text
except:
Expand All @@ -601,7 +602,7 @@ def get_instance_version(server_address=None, timeout=60):
server_address = constants.SERVER_ADDRESS()
url = constants.PUBLIC_API_SERVER_ADDRESS() + "instance.version"
resp = requests.get(url=url, timeout=timeout)
if resp.status_code != requests.codes.ok: # pylint: disable=no-member
if resp.status_code != requests.codes.ok: # pylint: disable=no-member
if resp.status_code in (400, 401, 403, 404):
return None
try:
Expand Down Expand Up @@ -699,9 +700,6 @@ def update_env_param(name, value, default=None):
)
update_env_param(constants._HTTPS_PROXY, http_proxy, optional_defaults[constants._HTTPS_PROXY])
update_env_param(constants._NO_PROXY, no_proxy, optional_defaults[constants._NO_PROXY])
# DOCKER_IMAGE
# maybe_update_env_param(constants._DOCKER_IMAGE, options.get(AgentOptionsJsonFields.DOCKER_IMAGE, None))

update_env_param(
constants._NET_CLIENT_DOCKER_IMAGE,
net_options.get(AgentOptionsJsonFields.NET_CLIENT_DOCKER_IMAGE, None),
Expand All @@ -715,6 +713,11 @@ def update_env_param(name, value, default=None):
update_env_param(
constants._DOCKER_IMAGE, options.get(AgentOptionsJsonFields.DOCKER_IMAGE, None)
)
update_env_param(
constants._FORCE_CPU_ONLY,
options.get(AgentOptionsJsonFields.FORCE_CPU_ONLY, "false"),
"false",
)

agent_host_dir = options.get(AgentOptionsJsonFields.AGENT_HOST_DIR, "").strip()
if agent_host_dir == "":
Expand Down Expand Up @@ -782,6 +785,7 @@ def _volumes_changes(volumes) -> dict:
changes[key] = value
return changes


def _is_bind_attached(container_info, bind_path):
vols = binds_to_volumes_dict(container_info.get("HostConfig", {}).get("Binds", []))

Expand All @@ -791,15 +795,17 @@ def _is_bind_attached(container_info, bind_path):

return False


def _copy_file_to_container(container, src, dst_dir: str):
stream = io.BytesIO()
with tarfile.open(fileobj=stream, mode='w|') as tar, open(src, 'rb') as f:
with tarfile.open(fileobj=stream, mode="w|") as tar, open(src, "rb") as f:
info = tar.gettarinfo(fileobj=f)
info.name = os.path.basename(src)
tar.addfile(info, f)

container.put_archive(dst_dir, stream.getvalue())


def _ca_cert_changed(ca_cert) -> str:
if ca_cert is None:
return None
Expand Down Expand Up @@ -832,7 +838,12 @@ def _ca_cert_changed(ca_cert) -> str:
tmp_container = docker_api.containers.create(
agent_image,
"",
volumes={constants.SLY_EXTRA_CA_CERTS_VOLUME_NAME(): {"bind": constants.SLY_EXTRA_CA_CERTS_DIR(), "mode": "rw"}},
volumes={
constants.SLY_EXTRA_CA_CERTS_VOLUME_NAME(): {
"bind": constants.SLY_EXTRA_CA_CERTS_DIR(),
"mode": "rw",
}
},
)

_copy_file_to_container(tmp_container, cert_path, constants.SLY_EXTRA_CA_CERTS_DIR())
Expand Down Expand Up @@ -958,3 +969,33 @@ def restart_agent(
"Docker container is spawned",
extra={"container_id": container.id, "container_name": container.name},
)


def nvidia_runtime_is_available():
    """Probe whether Docker can run a container with the ``nvidia`` runtime.

    Starts a throwaway container from ``constants.DEFAULT_APP_DOCKER_IMAGE()``
    that executes ``nvidia-smi`` with ``runtime="nvidia"`` and ``remove=True``.

    Returns True when the probe container run completes without raising,
    False when the run raises any exception.
    """
    docker_api = docker.from_env()
    image = constants.DEFAULT_APP_DOCKER_IMAGE()
    try:
        docker_api.containers.run(
            image,
            command="nvidia-smi",
            runtime="nvidia",
            remove=True,
        )
    except Exception as e:
        # Best-effort probe: any failure is treated as "runtime not available".
        # The reason is kept in the debug log so a failed probe can be diagnosed
        # instead of being silently discarded.
        sly.logger.debug(f"NVIDIA runtime probe failed: {e!r}")
        return False
    return True


def maybe_update_runtime():
    """Pick the Docker runtime the agent container should use.

    Reads the current runtime from ``get_container_info()``. If it is already
    ``"nvidia"`` it is returned unchanged. Otherwise the NVIDIA runtime is
    probed with ``nvidia_runtime_is_available()``: ``"nvidia"`` is returned
    when the probe succeeds, and the current runtime when it does not.
    """
    current_runtime = get_container_info()["HostConfig"]["Runtime"]
    if current_runtime == "nvidia":
        # Already on the NVIDIA runtime, no probe needed.
        return current_runtime

    sly.logger.debug("NVIDIA runtime is not enabled. Checking if it can be enabled...")
    if not nvidia_runtime_is_available():
        sly.logger.debug("NVIDIA runtime is not available.")
        return current_runtime

    sly.logger.debug("NVIDIA runtime is available. Will restart Agent with NVIDIA runtime.")
    return "nvidia"
13 changes: 11 additions & 2 deletions agent/worker/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import supervisely_lib as sly
import hashlib
import re
from supervisely_lib.io.docker_utils import PullPolicy # pylint: disable=import-error, no-name-in-module
from supervisely_lib.io.docker_utils import (
PullPolicy,
) # pylint: disable=import-error, no-name-in-module


_SERVER_ADDRESS = "SERVER_ADDRESS"
Expand Down Expand Up @@ -99,6 +101,7 @@ def TOKEN():
_UPDATE_SLY_NET_AFTER_RESTART = "UPDATE_SLY_NET_AFTER_RESTART"
_DOCKER_IMAGE = "DOCKER_IMAGE"
_CONTAINER_NAME = "CONTAINER_NAME"
_FORCE_CPU_ONLY = "FORCE_CPU_ONLY"

_NET_CLIENT_DOCKER_IMAGE = "NET_CLIENT_DOCKER_IMAGE"
_NET_SERVER_PORT = "NET_SERVER_PORT"
Expand Down Expand Up @@ -163,6 +166,7 @@ def TOKEN():
_AGENT_RESTART_COUNT: 0,
_SLY_EXTRA_CA_CERTS_DIR: "/sly_certs",
_SLY_EXTRA_CA_CERTS_VOLUME_NAME: f"supervisely-agent-ca-certs-{TOKEN()[:8]}",
_FORCE_CPU_ONLY: "false",
}


Expand Down Expand Up @@ -261,7 +265,7 @@ def AGENT_TASKS_DIR():

def AGENT_TASK_SHARED_DIR():
"""default: /sly_agent/tasks/task_shared"""
return os.path.join(AGENT_TASKS_DIR(), sly.task.paths.TASK_SHARED) # pylint: disable=no-member
return os.path.join(AGENT_TASKS_DIR(), sly.task.paths.TASK_SHARED) # pylint: disable=no-member


def AGENT_TMP_DIR():
Expand Down Expand Up @@ -658,6 +662,7 @@ def AGENT_RESTART_COUNT():
def SLY_EXTRA_CA_CERTS_DIR():
    """Return the optional setting stored under ``_SLY_EXTRA_CA_CERTS_DIR``."""
    certs_dir = read_optional_setting(_SLY_EXTRA_CA_CERTS_DIR)
    return certs_dir


def SLY_EXTRA_CA_CERTS_FILEPATH():
    """Return the path of ``instance_ca_chain.crt`` inside ``SLY_EXTRA_CA_CERTS_DIR()``."""
    certs_dir = SLY_EXTRA_CA_CERTS_DIR()
    return os.path.join(certs_dir, "instance_ca_chain.crt")

Expand All @@ -666,6 +671,10 @@ def SLY_EXTRA_CA_CERTS_VOLUME_NAME():
return read_optional_setting(_SLY_EXTRA_CA_CERTS_VOLUME_NAME)


def FORCE_CPU_ONLY():
    """Return the ``_FORCE_CPU_ONLY`` optional setting converted by ``sly.env.flag_from_env``."""
    raw_flag = read_optional_setting(_FORCE_CPU_ONLY)
    return sly.env.flag_from_env(raw_flag)


def init_constants():
sly.fs.mkdir(AGENT_LOG_DIR())
sly.fs.mkdir(AGENT_TASKS_DIR())
Expand Down
6 changes: 6 additions & 0 deletions agent/worker/task_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def task_main_func(self):
"Couldn't find sly-net-client attached to this agent. We'll try to deploy it during the agent restart"
)

if envs.get(constants._FORCE_CPU_ONLY, "false") == "true":
runtime = "runc"
else:
runtime = agent_utils.maybe_update_runtime()

# Stop current container
cur_container_id = container_info["Config"]["Hostname"]
envs[constants._REMOVE_OLD_AGENT] = cur_container_id
Expand All @@ -88,6 +93,7 @@ def task_main_func(self):
image=image,
envs=envs,
volumes=volumes,
runtime=runtime,
ca_cert_path=ca_cert_path,
docker_api=self._docker_api,
)
Expand Down

0 comments on commit 6e74a26

Please sign in to comment.