Skip to content

Commit

Permalink
Return error when instance added to multiple fleets(#1699) (#1938)
Browse files Browse the repository at this point in the history
* Return error when instance added to multiple fleets(#1699)

* Reuse exception and extract check to function
  • Loading branch information
swsvc authored Nov 1, 2024
1 parent b8674ca commit 7fe2af6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import string
import uuid
from datetime import timezone
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, cast

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -31,6 +31,7 @@
InstanceConfiguration,
InstanceOfferWithAvailability,
InstanceStatus,
RemoteConnectionInfo,
SSHKey,
)
from dstack._internal.core.models.pools import Instance
Expand Down Expand Up @@ -63,6 +64,7 @@
get_locker,
string_to_lock_id,
)
from dstack._internal.server.services.pools import list_active_remote_instances
from dstack._internal.server.services.projects import get_member, get_member_permissions
from dstack._internal.utils import common as common_utils
from dstack._internal.utils import random_names
Expand Down Expand Up @@ -136,6 +138,7 @@ async def get_plan(
spec: FleetSpec,
) -> FleetPlan:
# TODO: refactor offers logic into a separate module to avoid depending on runs
await _check_ssh_hosts_not_yet_added(session, spec)

offers = []
if spec.configuration.ssh_config is None:
Expand Down Expand Up @@ -537,6 +540,30 @@ def _check_can_manage_ssh_fleets(user: UserModel, project: ProjectModel):
raise ForbiddenError()


async def _check_ssh_hosts_not_yet_added(session: AsyncSession, spec: FleetSpec):
if spec.configuration.ssh_config and spec.configuration.ssh_config.hosts:
# there are manually listed hosts, need to check them for existence
active_instances = await list_active_remote_instances(session=session)

existing_hosts = set()
for instance in active_instances:
instance_conn_info = RemoteConnectionInfo.parse_raw(
cast(str, instance.remote_connection_info)
)
existing_hosts.add(instance_conn_info.host)

instances_already_in_fleet = []
for new_instance in spec.configuration.ssh_config.hosts:
hostname = new_instance if isinstance(new_instance, str) else new_instance.hostname
if hostname in existing_hosts:
instances_already_in_fleet.append(hostname)

if instances_already_in_fleet:
raise ServerClientError(
msg=f"Instances [{', '.join(instances_already_in_fleet)}] are already assigned to a fleet."
)


def _remove_fleet_spec_sensitive_info(spec: FleetSpec):
if spec.configuration.ssh_config is not None:
spec.configuration.ssh_config.ssh_key = None
Expand Down
12 changes: 12 additions & 0 deletions src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,18 @@ async def list_user_pool_instances(
return instances


async def list_active_remote_instances(
session: AsyncSession,
) -> List[InstanceModel]:
filters: List = [InstanceModel.deleted == False, InstanceModel.backend == BackendType.REMOTE]

res = await session.execute(
select(InstanceModel).where(*filters).order_by(InstanceModel.created_at.asc())
)
instance_models = list(res.scalars().all())
return instance_models


async def create_instance_model(
session: AsyncSession,
project: ProjectModel,
Expand Down

0 comments on commit 7fe2af6

Please sign in to comment.