From c10b1fe28f1b1de6ffcf765b9f7fce6975a5bd50 Mon Sep 17 00:00:00 2001
From: Victor Skvortsov
Date: Wed, 4 Dec 2024 11:19:35 +0500
Subject: [PATCH] Support specifying internal_ip for SSH fleet hosts (#2056)

* Support specifying internal_ip for SSH fleet hosts

* Validate internal_ip

* Handle client backward compatibility

* Remove extra space
---
 src/dstack/_internal/core/models/fleets.py    | 30 ++++++++++++++-
 .../background/tasks/process_instances.py     | 30 +++++++++++---
 .../_internal/server/services/fleets.py       | 16 +++++++++
 src/dstack/_internal/server/services/pools.py |  3 +-
 src/dstack/_internal/utils/network.py         | 21 ++++++++++-
 src/dstack/api/server/_fleets.py              | 34 +++++++++++++------
 6 files changed, 115 insertions(+), 19 deletions(-)

diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py
index ce4f50532..dab8ac8c0 100644
--- a/src/dstack/_internal/core/models/fleets.py
+++ b/src/dstack/_internal/core/models/fleets.py
@@ -50,8 +50,30 @@ class SSHHostParams(CoreModel):
     identity_file: Annotated[
         Optional[str], Field(description="The private key to use for this host")
     ] = None
+    internal_ip: Annotated[
+        Optional[str],
+        Field(
+            description=(
+                "The internal IP of the host used for communication inside the cluster."
+                " If not specified, `dstack` will use the IP address from `network` or from the first found internal network."
+            )
+        ),
+    ] = None
     ssh_key: Optional[SSHKey] = None
 
+    @validator("internal_ip")
+    def validate_internal_ip(cls, value):
+        """Ensure `internal_ip`, when set, is a valid private IP address."""
+        if value is None:
+            return value
+        try:
+            internal_ip = ipaddress.ip_address(value)
+        except ValueError as e:
+            raise ValueError("Invalid IP address") from e
+        if not internal_ip.is_private:
+            raise ValueError("IP address is not private")
+        return value
+
 
 class SSHParams(CoreModel):
     user: Annotated[Optional[str], Field(description="The user to log in with on all hosts")] = (
@@ -70,7 +92,13 @@ class SSHParams(CoreModel):
     ]
     network: Annotated[
         Optional[str],
-        Field(description="The network address for cluster setup in the format `/`"),
+        Field(
+            description=(
+                "The network address for cluster setup in the format `/`."
+                " `dstack` will use IP addresses from this network for communication between hosts."
+                " If not specified, `dstack` will use IPs from the first found internal network."
+            )
+        ),
     ]
 
     @validator("network")
diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py
index 5c39ea1c6..018c3c8ce 100644
--- a/src/dstack/_internal/server/background/tasks/process_instances.py
+++ b/src/dstack/_internal/server/background/tasks/process_instances.py
@@ -92,7 +92,7 @@
 from dstack._internal.server.utils.common import run_async
 from dstack._internal.utils.common import get_current_datetime
 from dstack._internal.utils.logging import get_logger
-from dstack._internal.utils.network import get_ip_from_network
+from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses
 from dstack._internal.utils.ssh import (
     pkey_from_str,
 )
@@ -290,16 +290,21 @@ async def _add_remote(instance: InstanceModel) -> None:
     instance_type = host_info_to_instance_type(host_info)
 
     instance_network = None
+    internal_ip = None
     try:
         default_jpd = JobProvisioningData.__response__.parse_raw(instance.job_provisioning_data)
         instance_network = default_jpd.instance_network
+        internal_ip = default_jpd.internal_ip
     except ValidationError:
         pass
 
-    internal_ip = get_ip_from_network(
-        network=instance_network,
-        addresses=host_info.get("addresses", []),
-    )
+    host_network_addresses = host_info.get("addresses", [])
+    # An explicitly configured internal_ip takes precedence over auto-detection.
+    if internal_ip is None:
+        internal_ip = get_ip_from_network(
+            network=instance_network,
+            addresses=host_network_addresses,
+        )
     if instance_network is not None and internal_ip is None:
         instance.status = InstanceStatus.TERMINATED
         instance.termination_reason = "Failed to locate internal IP address on the given network"
@@ -312,6 +317,21 @@ async def _add_remote(instance: InstanceModel) -> None:
             },
         )
         return
+    if internal_ip is not None:
+        if not is_ip_among_addresses(ip_address=internal_ip, addresses=host_network_addresses):
+            instance.status = InstanceStatus.TERMINATED
+            instance.termination_reason = (
+                "Specified internal IP not found among instance interfaces"
+            )
+            logger.warning(
+                "Failed to add instance %s: specified internal IP not found among instance interfaces",
+                instance.name,
+                extra={
+                    "instance_name": instance.name,
+                    "instance_status": InstanceStatus.TERMINATED.value,
+                },
+            )
+            return
 
     region = instance.region
     jpd = JobProvisioningData(
diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py
index 9d1a9447e..9155b7d1b 100644
--- a/src/dstack/_internal/server/services/fleets.py
+++ b/src/dstack/_internal/server/services/fleets.py
@@ -402,11 +402,13 @@ async def create_fleet_ssh_instance_model(
         ssh_user = ssh_params.user
         ssh_key = ssh_params.ssh_key
         port = ssh_params.port
+        internal_ip = None
     else:
         hostname = host.hostname
         ssh_user = host.user or ssh_params.user
         ssh_key = host.ssh_key or ssh_params.ssh_key
         port = host.port or ssh_params.port
+        internal_ip = host.internal_ip
 
     if ssh_user is None or ssh_key is None:
         # This should not be reachable but checked by fleet spec validation
@@ -422,6 +424,7 @@ async def create_fleet_ssh_instance_model(
         ssh_user=ssh_user,
         ssh_keys=[ssh_key],
         env=env,
+        internal_ip=internal_ip,
         instance_network=ssh_params.network,
         port=port or 22,
     )
@@ -678,6 +681,7 @@ def _validate_fleet_spec(spec: FleetSpec):
         for host in spec.configuration.ssh_config.hosts:
             if is_core_model_instance(host, SSHHostParams) and host.ssh_key is not None:
                 _validate_ssh_key(host.ssh_key)
+        _validate_internal_ips(spec.configuration.ssh_config)
 
 
 def _validate_all_ssh_params_specified(ssh_config: SSHParams):
@@ -706,6 +710,18 @@ def _validate_ssh_key(ssh_key: SSHKey):
         )
 
 
+def _validate_internal_ips(ssh_config: SSHParams):
+    """Require `internal_ip` for all hosts or for none, and never together with `network`."""
+    internal_ips_num = 0
+    for host in ssh_config.hosts:
+        if not isinstance(host, str) and host.internal_ip is not None:
+            internal_ips_num += 1
+    if internal_ips_num != 0 and internal_ips_num != len(ssh_config.hosts):
+        raise ServerClientError("internal_ip must be specified for all hosts")
+    if internal_ips_num > 0 and ssh_config.network is not None:
+        raise ServerClientError("internal_ip is mutually exclusive with network")
+
+
 def _get_fleet_nodes_to_provision(spec: FleetSpec) -> int:
     if spec.configuration.nodes is None or spec.configuration.nodes.min is None:
         return 0
diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py
index dc0495b10..fed2c4e9b 100644
--- a/src/dstack/_internal/server/services/pools.py
+++ b/src/dstack/_internal/server/services/pools.py
@@ -656,6 +656,7 @@ async def create_ssh_instance_model(
     pool: PoolModel,
     instance_name: str,
     instance_num: int,
+    internal_ip: Optional[str],
     instance_network: Optional[str],
     region: Optional[str],
     host: str,
@@ -676,7 +677,7 @@ async def create_ssh_instance_model(
         instance_id=instance_name,
         hostname=host,
         region=host_region,
-        internal_ip=None,
+        internal_ip=internal_ip,
         instance_network=instance_network,
         price=0,
         username=ssh_user,
diff --git a/src/dstack/_internal/utils/network.py b/src/dstack/_internal/utils/network.py
index 517438025..355753a61 100644
--- a/src/dstack/_internal/utils/network.py
+++ b/src/dstack/_internal/utils/network.py
@@ -1,5 +1,5 @@
 import ipaddress
-from typing import Optional, Sequence
+from typing import List, Optional, Sequence
 
 
 def get_ip_from_network(network: Optional[str], addresses: Sequence[str]) -> Optional[str]:
@@ -32,3 +32,22 @@ def get_ip_from_network(network: Optional[str], addresses: Sequence[str]) -> Opt
     # return any ipv4
     internal_ip = str(ip_addresses[0]) if ip_addresses else None
     return internal_ip
+
+
+def is_ip_among_addresses(ip_address: str, addresses: Sequence[str]) -> bool:
+    """Return True if `ip_address` is the IPv4 address of one of the interface `addresses`."""
+    return ip_address in get_ips_from_addresses(addresses)
+
+
+def get_ips_from_addresses(addresses: Sequence[str]) -> List[str]:
+    """Extract IPv4 addresses from interface strings, skipping entries that do not parse."""
+    ip_addresses = []
+    for address in addresses:
+        try:
+            interface = ipaddress.IPv4Interface(address)
+        except ValueError:
+            # Covers both AddressValueError (bad or non-IPv4 address) and
+            # NetmaskValueError (bad prefix), so one malformed entry cannot abort the scan.
+            continue
+        ip_addresses.append(str(interface.ip))
+    return ip_addresses
diff --git a/src/dstack/api/server/_fleets.py b/src/dstack/api/server/_fleets.py
index cc118ddf7..2a067e99d 100644
--- a/src/dstack/api/server/_fleets.py
+++ b/src/dstack/api/server/_fleets.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 
 from pydantic import parse_obj_as
 
@@ -29,11 +29,7 @@ def get_plan(
         spec: FleetSpec,
     ) -> FleetPlan:
         body = GetFleetPlanRequest(spec=spec)
-        body_json = body.json()
-        if spec.configuration_path is None:
-            # Handle old server versions that do not accept configuration_path
-            # TODO: Can be removed in 0.19
-            body_json = body.json(exclude={"spec": {"configuration_path"}})
+        body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
         resp = self._request(f"/api/project/{project_name}/fleets/get_plan", body=body_json)
         return parse_obj_as(FleetPlan.__response__, resp.json())
 
@@ -43,11 +39,7 @@ def create(
         spec: FleetSpec,
     ) -> Fleet:
         body = CreateFleetRequest(spec=spec)
-        body_json = body.json()
-        if spec.configuration_path is None:
-            # Handle old server versions that do not accept configuration_path
-            # TODO: Can be removed in 0.19
-            body_json = body.json(exclude={"spec": {"configuration_path"}})
+        body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
         resp = self._request(f"/api/project/{project_name}/fleets/create", body=body_json)
         return parse_obj_as(Fleet.__response__, resp.json())
 
@@ -58,3 +50,23 @@ def delete(self, project_name: str, names: List[str]) -> None:
     def delete_instances(self, project_name: str, name: str, instance_nums: List[int]) -> None:
         body = DeleteFleetInstancesRequest(name=name, instance_nums=instance_nums)
         self._request(f"/api/project/{project_name}/fleets/delete_instances", body=body.json())
+
+
+def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[dict]:
+    """
+    Build the pydantic `exclude` argument that omits spec fields older servers reject.
+
+    Returns `None` when nothing has to be excluded.
+    """
+    spec_excludes: dict = {}
+    # TODO: Can be removed in 0.19
+    if fleet_spec.configuration_path is None:
+        spec_excludes["configuration_path"] = True
+    ssh_config = fleet_spec.configuration.ssh_config
+    if ssh_config is not None and all(
+        isinstance(h, str) or h.internal_ip is None for h in ssh_config.hosts
+    ):
+        # Added under its own key: assigning a fresh dict to "spec" here would
+        # drop the configuration_path exclusion collected above.
+        spec_excludes["configuration"] = {"ssh_config": {"hosts": {"__all__": {"internal_ip"}}}}
+    return {"spec": spec_excludes} if spec_excludes else None