Skip to content

Commit

Permalink
Retry instance termination in case of errors (#2190)
Browse files Browse the repository at this point in the history
Retry instance termination every ~30 seconds for
~15 minutes until it is successfully terminated.
Mark as terminated and log an error if all
attempts failed.

Bonus: fix `AsyncioCancelledErrorFilter` for the
case when `exc_info` is `False`.
  • Loading branch information
jvstme authored Jan 15, 2025
1 parent d1f8cfe commit 02abf94
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 17 deletions.
49 changes: 40 additions & 9 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@
PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60)

TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20)

# Minimum delay between two termination attempts: a new attempt is skipped
# until `last_termination_retry_at` + this interval has passed.
TERMINATION_RETRY_TIMEOUT = timedelta(seconds=30)
# Length of the retry window, counted from `first_termination_retry_at`.
# Once the next retry would fall outside of this window, the failure is logged
# as an error and termination is not retried anymore.
TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15)
PROVISIONING_TIMEOUT_SECONDS = 10 * 60 # 10 minutes in seconds


Expand Down Expand Up @@ -765,6 +766,11 @@ def _instance_healthcheck(ports: Dict[int, int]) -> HealthStatus:


async def _terminate(instance: InstanceModel) -> None:
if (
instance.last_termination_retry_at is not None
and _next_termination_retry_at(instance) > get_current_datetime()
):
return
jpd = get_instance_provisioning_data(instance)
if jpd is not None:
if jpd.backend != BackendType.REMOTE:
Expand All @@ -786,16 +792,25 @@ async def _terminate(instance: InstanceModel) -> None:
jpd.region,
jpd.backend_data,
)
except BackendError as e:
except Exception as e:
if instance.first_termination_retry_at is None:
instance.first_termination_retry_at = get_current_datetime()
instance.last_termination_retry_at = get_current_datetime()
if _next_termination_retry_at(instance) < _get_termination_deadline(instance):
logger.warning(
"Failed to terminate instance %s. Will retry. Error: %r",
instance.name,
e,
exc_info=not isinstance(e, BackendError),
)
return
logger.error(
"Failed to terminate instance %s: %s",
instance.name,
repr(e),
)
except Exception:
logger.exception(
"Got exception when terminating instance %s",
"Failed all attempts to terminate instance %s."
" Please terminate the instance manually to avoid unexpected charges."
" Error: %r",
instance.name,
e,
exc_info=not isinstance(e, BackendError),
)

instance.deleted = True
Expand All @@ -812,6 +827,22 @@ async def _terminate(instance: InstanceModel) -> None:
)


def _next_termination_retry_at(instance: InstanceModel) -> datetime.datetime:
    """Return the earliest UTC time at which termination may be attempted again.

    The stored `last_termination_retry_at` is tagged as UTC and shifted by
    `TERMINATION_RETRY_TIMEOUT`. Must only be called once an attempt was recorded.
    """
    last_attempt = instance.last_termination_retry_at
    assert last_attempt is not None
    last_attempt_utc = last_attempt.replace(tzinfo=datetime.timezone.utc)
    return last_attempt_utc + TERMINATION_RETRY_TIMEOUT


def _get_termination_deadline(instance: InstanceModel) -> datetime.datetime:
    """Return the UTC time after which termination is no longer retried.

    The stored `first_termination_retry_at` is tagged as UTC and shifted by
    `TERMINATION_RETRY_MAX_DURATION`. Must only be called once an attempt was recorded.
    """
    first_attempt = instance.first_termination_retry_at
    assert first_attempt is not None
    first_attempt_utc = first_attempt.replace(tzinfo=datetime.timezone.utc)
    return first_attempt_utc + TERMINATION_RETRY_MAX_DURATION


def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
# Cluster cloud instances should wait for the first fleet instance to be provisioned
# so that they are provisioned in the same backend/region
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Add instance termination retries
Revision ID: c48df7985d57
Revises: 065588ec72b8
Create Date: 2025-01-14 13:33:17.722284
"""

import sqlalchemy as sa
from alembic import op

from dstack._internal.server.models import NaiveDateTime

# revision identifiers, used by Alembic.
revision = "c48df7985d57"
down_revision = "065588ec72b8"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add the nullable termination-retry timestamp columns to `instances`."""
    new_columns = ("first_termination_retry_at", "last_termination_retry_at")
    with op.batch_alter_table("instances", schema=None) as batch_op:
        for column_name in new_columns:
            batch_op.add_column(sa.Column(column_name, NaiveDateTime(), nullable=True))


def downgrade() -> None:
    """Drop the termination-retry timestamp columns from `instances`."""
    # Columns are removed in the reverse order of their creation.
    removed_columns = ("last_termination_retry_at", "first_termination_retry_at")
    with op.batch_alter_table("instances", schema=None) as batch_op:
        for column_name in removed_columns:
            batch_op.drop_column(column_name)
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ class InstanceModel(BaseModel):
termination_deadline: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
termination_reason: Mapped[Optional[str]] = mapped_column(String(4000))
health_status: Mapped[Optional[str]] = mapped_column(String(4000))
first_termination_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
last_termination_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)

# backend
backend: Mapped[Optional[BackendType]] = mapped_column(Enum(BackendType))
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class AsyncioCancelledErrorFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
if record.exc_info is None:
if not record.exc_info:
return True
if isinstance(record.exc_info[1], asyncio.CancelledError):
return False
Expand Down
121 changes: 114 additions & 7 deletions src/tests/_internal/server/background/tasks/test_process_instances.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import datetime as dt
from contextlib import contextmanager
from typing import Optional
from unittest.mock import Mock, patch

import pytest
from freezegun import freeze_time
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.errors import BackendError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import (
InstanceAvailability,
Expand Down Expand Up @@ -330,6 +334,20 @@ async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):


class TestTerminate:
@staticmethod
@contextmanager
def mock_terminate_in_backend(error: Optional[Exception] = None):
    """Patch the backend lookup and yield the mocked `terminate_instance`.

    While the context is active, `get_project_backend_by_type` (as referenced by
    `process_instances`) returns a mock backend of type DATACRUNCH. If `error` is
    given, the mocked `terminate_instance` raises it when called.
    """
    patch_target = (
        "dstack._internal.server.background.tasks.process_instances"
        ".backends_services.get_project_backend_by_type"
    )
    fake_backend = Mock()
    fake_backend.TYPE = BackendType.DATACRUNCH
    terminate_mock = fake_backend.compute.return_value.terminate_instance
    if error is not None:
        terminate_mock.side_effect = error
    with patch(patch_target, return_value=fake_backend):
        yield terminate_mock

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_terminate(self, test_db, session: AsyncSession):
Expand All @@ -343,14 +361,9 @@ async def test_terminate(self, test_db, session: AsyncSession):
instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)
await session.commit()

with patch(
"dstack._internal.server.background.tasks.process_instances.backends_services.get_project_backends"
) as get_backends:
backend = Mock()
backend.TYPE = BackendType.DATACRUNCH
backend.compute.return_value.terminate_instance.return_value = Mock()
get_backends.return_value = [backend]
with self.mock_terminate_in_backend() as mock:
await process_instances()
mock.assert_called_once()

await session.refresh(instance)

Expand All @@ -361,6 +374,100 @@ async def test_terminate(self, test_db, session: AsyncSession):
assert instance.deleted_at is not None
assert instance.finished_at is not None

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@pytest.mark.parametrize("error", [BackendError("err"), RuntimeError("err")])
async def test_terminate_retry(self, test_db, session: AsyncSession, error: Exception):
    # Timeline of the scenario: the instance was last used at `started_at`,
    # termination is attempted one and two minutes later.
    started_at = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc)
    first_attempt_at = started_at + dt.timedelta(minutes=1)
    second_attempt_at = started_at + dt.timedelta(minutes=2)

    project = await create_project(session=session)
    pool = await create_pool(session, project)
    instance = await create_instance(session, project, pool, status=InstanceStatus.TERMINATING)
    instance.termination_reason = "some reason"
    instance.last_job_processed_at = started_at
    await session.commit()

    # The backend raises on the first attempt, so the instance stays in TERMINATING
    with (
        freeze_time(first_attempt_at),
        self.mock_terminate_in_backend(error=error) as failing_terminate,
    ):
        await process_instances()
        failing_terminate.assert_called_once()
    await session.refresh(instance)
    assert instance.status == InstanceStatus.TERMINATING

    # A minute later the backend call succeeds and the instance becomes TERMINATED
    with (
        freeze_time(second_attempt_at),
        self.mock_terminate_in_backend(error=None) as working_terminate,
    ):
        await process_instances()
        working_terminate.assert_called_once()
    await session.refresh(instance)
    assert instance.status == InstanceStatus.TERMINATED

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_terminate_not_retries_if_too_early(self, test_db, session: AsyncSession):
    # Timeline of the scenario: the first attempt happens one minute after
    # `started_at`, the next processing pass only 3 seconds after that.
    started_at = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc)
    first_attempt_at = started_at + dt.timedelta(minutes=1)
    too_early_at = started_at + dt.timedelta(minutes=1, seconds=3)

    project = await create_project(session=session)
    pool = await create_pool(session, project)
    instance = await create_instance(session, project, pool, status=InstanceStatus.TERMINATING)
    instance.termination_reason = "some reason"
    instance.last_job_processed_at = started_at
    await session.commit()

    # The backend raises on the first attempt, so the instance stays in TERMINATING
    with (
        freeze_time(first_attempt_at),
        self.mock_terminate_in_backend(error=BackendError("err")) as failing_terminate,
    ):
        await process_instances()
        failing_terminate.assert_called_once()
    await session.refresh(instance)
    assert instance.status == InstanceStatus.TERMINATING

    # 3 seconds later is too early for a retry: the backend must not be called at all
    with (
        freeze_time(too_early_at),
        self.mock_terminate_in_backend(error=None) as untouched_terminate,
    ):
        await process_instances()
        untouched_terminate.assert_not_called()
    await session.refresh(instance)
    assert instance.status == InstanceStatus.TERMINATING

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_terminate_on_termination_deadline(self, test_db, session: AsyncSession):
    # Checks that the instance is marked as TERMINATED even though every
    # termination attempt fails, once no further retry fits before the deadline.
    project = await create_project(session=session)
    pool = await create_pool(session, project)
    instance = await create_instance(session, project, pool, status=InstanceStatus.TERMINATING)
    instance.termination_reason = "some reason"
    initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc)
    instance.last_job_processed_at = initial_time
    await session.commit()

    # First attempt fails
    with (
        freeze_time(initial_time + dt.timedelta(minutes=1)),
        self.mock_terminate_in_backend(error=BackendError("err")) as mock,
    ):
        await process_instances()
        mock.assert_called_once()
    await session.refresh(instance)
    assert instance.status == InstanceStatus.TERMINATING

    # Second attempt fails too, but it's the last attempt because the deadline is close:
    # the first attempt was at +1:00, so the deadline is +16:00, while the next retry
    # after this one (+15:55 plus the retry interval) would be at +16:25.
    # The backend mock must raise here; with a succeeding mock the instance would be
    # terminated regardless of the deadline and the test would not cover this path.
    with (
        freeze_time(initial_time + dt.timedelta(minutes=15, seconds=55)),
        self.mock_terminate_in_backend(error=BackendError("err")) as mock,
    ):
        await process_instances()
        mock.assert_called_once()
    await session.refresh(instance)
    assert instance.status == InstanceStatus.TERMINATED


class TestCreateInstance:
@pytest.mark.asyncio
Expand Down

0 comments on commit 02abf94

Please sign in to comment.