Skip to content

Commit

Permalink
Multiple attachment improvements #747 (#748)
Browse files Browse the repository at this point in the history
* Multiple attachment improvements #747

- Allow invoking `run.attach` without errors when the tunnel is already established by another process
- Allow invoking `run.attach` on a run obtained via `client.runs.get` and have attached logs
- Support `-a` (`--attach`) in `dstack logs`
- Automatically detach from the run on the program's exit

* Multiple attachment improvements #747

- Addressing feedback
  • Loading branch information
peterschmidt85 authored Oct 30, 2023
1 parent f68efea commit 02ab3df
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 94 deletions.
30 changes: 26 additions & 4 deletions src/dstack/_internal/cli/commands/logs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import argparse
import sys
from pathlib import Path

from dstack._internal.cli.commands import APIBaseCommand
from dstack._internal.cli.utils.common import confirm_ask
from dstack._internal.core.errors import CLIError
from dstack._internal.core.services.ssh.ports import PortUsedError


class LogsCommand(APIBaseCommand):
Expand All @@ -13,14 +14,35 @@ class LogsCommand(APIBaseCommand):
def _register(self):
    """Register the arguments of the logs command on top of the base command's."""
    super()._register()
    parser = self._parser

    # Flags.
    parser.add_argument("-d", "--diagnose", action="store_true")
    parser.add_argument(
        "-a",
        "--attach",
        help="Set up an SSH tunnel, and print logs as they follow.",
        action="store_true",
    )

    # Optional private key; stored on the namespace as `ssh_identity_file`.
    parser.add_argument(
        "--ssh-identity",
        dest="ssh_identity_file",
        type=Path,
        metavar="SSH_PRIVATE_KEY",
        help="A path to the private SSH key file for SSH tunneling",
    )

    # Positional: the name of the run whose logs are requested.
    parser.add_argument("run_name")

def _command(self, args: argparse.Namespace):
# Write the logs of the run named `args.run_name` to stdout as raw bytes.
#
# Raises CLIError when the run cannot be found, or when `--attach` is requested
# (without `--diagnose`) for a run whose status reports it is finished.
#
# NOTE(review): this span is a rendered diff with indentation stripped; the
# unguarded `for log in logs` loop appears to be the removed version and the
# `try`-wrapped loop right after it its replacement -- confirm against the repo.
super()._command(args)
run = self.api.runs.get(args.run_name)
if run is None:
raise CLIError(f"Run {args.run_name} not found")
# Attaching is only attempted for regular (non-diagnose) logs.
if not args.diagnose and args.attach:
if run.status.is_finished():
raise CLIError(f"Run {args.run_name} is finished")
else:
# `ssh_identity_file` is the dest of the `--ssh-identity` option.
run.attach(args.ssh_identity_file)
logs = run.logs(diagnose=args.diagnose)
for log in logs:
sys.stdout.buffer.write(log)
sys.stdout.buffer.flush()
try:
for log in logs:
# Log items are written through the binary buffer, bypassing text encoding.
sys.stdout.buffer.write(log)
sys.stdout.buffer.flush()
except KeyboardInterrupt:
# Ctrl-C ends the output quietly instead of surfacing a traceback.
pass
32 changes: 30 additions & 2 deletions src/dstack/_internal/core/services/ssh/attach.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,38 @@
import atexit
import re
import subprocess
import time
from typing import Optional, Tuple

from dstack._internal.core.errors import SSHError
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.core.services.ssh.ports import PortsLock
from dstack._internal.core.services.ssh.tunnel import ClientTunnel
from dstack._internal.utils.path import PathLike
from dstack._internal.utils.ssh import include_ssh_config, update_ssh_config
from dstack._internal.utils.ssh import get_ssh_config, include_ssh_config, update_ssh_config


class SSHAttach:
@staticmethod
def reuse_control_sock_path_and_port_locks(run_name: str) -> Optional[Tuple[str, PortsLock]]:
    """
    Look for an already running SSH tunnel of the run so that it can be reused.

    The `ControlPath` recorded for `run_name` in the dstack SSH config is searched for
    in the command lines of all running processes. The port forwardings (`-L` options)
    of the first matching command are parsed into a `PortsLock`.

    Args:
        run_name: the host entry to look up in the dstack SSH config.

    Returns:
        A `(control_sock_path, ports_lock)` tuple, or `None` when the config has no
        `ControlPath` for the run or no running process mentions that path.

    Raises:
        subprocess.CalledProcessError: if `ps` itself fails.
    """
    ssh_config_path = str(ConfigManager().dstack_ssh_config_path)
    host_config = get_ssh_config(ssh_config_path, run_name)
    control_sock_path = host_config.get("ControlPath") if host_config else None
    if not control_sock_path:
        return None

    # The process list is filtered here rather than piped into `grep`: grep exits
    # with status 1 when nothing matches, which `check_output` reports as
    # CalledProcessError -- i.e. the ordinary "no tunnel yet" case used to crash.
    # It also makes the match a plain substring test instead of a regex.
    output = subprocess.check_output(("ps", "-A", "-o", "command"))
    commands = [
        command
        for command in (
            line.strip() for line in output.decode(errors="replace").splitlines()
        )
        # Keep ignoring `grep` processes that merely mention the socket path.
        if control_sock_path in command and not command.startswith("grep")
    ]
    if not commands:
        return None

    port_pattern = r"-L (\d+):localhost:(\d+)"
    matches = re.findall(port_pattern, commands[0])
    # NOTE(review): `PortsLock.dict()` iterates its restrictions as
    # remote_port -> local_port; confirm that the key/value order built here
    # matches what the `PortsLock` constructor expects.
    return control_sock_path, PortsLock(
        {int(local_port): int(target_port) for local_port, target_port in matches}
    )

def __init__(
self,
hostname: str,
Expand All @@ -18,11 +42,14 @@ def __init__(
ports_lock: PortsLock,
run_name: str,
dockerized: bool,
control_sock_path: Optional[str] = None,
):
self._ports_lock = ports_lock
self.ports = ports_lock.dict()
self.run_name = run_name
self.tunnel = ClientTunnel(run_name, self.ports, id_rsa_path=id_rsa_path)
self.tunnel = ClientTunnel(
run_name, self.ports, id_rsa_path=id_rsa_path, control_sock_path=control_sock_path
)
self.host_config = {
"HostName": hostname,
"Port": ssh_port,
Expand Down Expand Up @@ -61,6 +88,7 @@ def attach(self):
for i in range(max_retries):
try:
self.tunnel.open()
atexit.register(self.detach)
break
except SSHError:
if i < max_retries - 1:
Expand Down
8 changes: 7 additions & 1 deletion src/dstack/_internal/core/services/ssh/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ def release(self) -> Dict[int, int]:
return mapping

def dict(self) -> Dict[int, int]:
# Return the port mapping as a plain dict keyed by remote port.
#
# NOTE(review): this span is a rendered diff with indentation stripped; the
# one-line `return {...}` directly below appears to be the removed
# implementation and the loop after it its replacement -- confirm against the repo.
return {remote_port: sock.getsockname()[1] for remote_port, sock in self.sockets.items()}
d = {}
for remote_port, local_port in self.restrictions.items():
if local_port:
# A truthy restriction is used as the local port as-is.
d[remote_port] = local_port
else:
# Otherwise the port number is read from the socket held for this remote port.
d[remote_port] = self.sockets[remote_port].getsockname()[1]
return d

@staticmethod
def _listen(port: int) -> Optional[socket.socket]:
Expand Down
14 changes: 11 additions & 3 deletions src/dstack/_internal/core/services/ssh/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,20 @@ class ClientTunnel(SSHTunnel):
ClientTunnel connects to the host from ssh config
"""

# NOTE(review): this span is a rendered diff with indentation stripped; the
# single-line signature and the unconditional `TemporaryDirectory()` below
# appear to be the removed version, followed by the new multi-line signature
# that adds `control_sock_path` -- confirm against the repo.
def __init__(self, host: str, ports: Dict[int, int], id_rsa_path: PathLike):
self.temp_dir = tempfile.TemporaryDirectory()
def __init__(
self,
host: str,
ports: Dict[int, int],
id_rsa_path: PathLike,
control_sock_path: Optional[str] = None,
):
# A temporary directory is only created when no control socket path is supplied;
# otherwise `temp_dir` is None.
self.temp_dir = tempfile.TemporaryDirectory() if not control_sock_path else None
super().__init__(
host=host,
id_rsa_path=id_rsa_path,
ports=ports,
# Of the two `control_sock_path=` arguments below, the first looks like the
# removed line. The second uses `control.sock` inside the temporary directory
# unless a path was supplied; the conditional expression never touches
# `self.temp_dir` when it is None.
control_sock_path=os.path.join(self.temp_dir.name, "control.sock"),
control_sock_path=os.path.join(self.temp_dir.name, "control.sock")
if not control_sock_path
else control_sock_path,
options={},
)
25 changes: 24 additions & 1 deletion src/dstack/_internal/utils/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import subprocess
import sys
from pathlib import Path
from typing import Dict
from typing import Dict, Optional

from filelock import FileLock
from paramiko.config import SSHConfig
Expand Down Expand Up @@ -53,6 +53,29 @@ def include_ssh_config(path: PathLike, ssh_config_path: PathLike = default_ssh_c
f.write(include + content)


def get_ssh_config(path: PathLike, host: str) -> Optional[Dict[str, str]]:
    """
    Read the options declared for `host` in the SSH config file at `path`.

    Only a minimal subset of the format is understood: `Host <name>` lines open a
    section and every following `<Keyword> <value>` line is stored in it. Blank
    lines and `#` comments are ignored, as are keywords without a value and options
    that appear before the first `Host` line. When a `Host` line lists several
    patterns, only the first one is used as the section name.

    Args:
        path: location of the SSH config file.
        host: the section name to look up (matched literally).

    Returns:
        The options of the section as a dict, or `None` if the file does not exist
        or has no section for `host`.
    """
    try:
        f = open(path, "r")
    except FileNotFoundError:
        return None

    config: Dict[str, Dict[str, str]] = {}
    current_host: Optional[str] = None
    with f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            # Split on any run of whitespace so tabs and repeated spaces work.
            parts = line.split(maxsplit=1)
            if len(parts) < 2:
                # A bare keyword carries no value to record.
                continue
            key, value = parts

            if key == "Host":
                current_host = value.split()[0]
                config[current_host] = {}
            elif current_host is not None:
                # Options outside of any `Host` section belong to no host.
                config[current_host][key] = value
    return config.get(host)


def update_ssh_config(path: PathLike, host: str, options: Dict[str, str]):
Path(path).parent.mkdir(parents=True, exist_ok=True)
with FileLock(str(path) + ".lock"):
Expand Down
Loading

0 comments on commit 02ab3df

Please sign in to comment.