Skip to content

Commit

Permalink
Add heartbeat timeout to task management and update expired task logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Nov 1, 2024
1 parent 83fd94c commit b59272d
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 15 deletions.
1 change: 0 additions & 1 deletion aana/api/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ async def execute_task(self, task_id: str | UUID) -> Any:
Any: The response from the endpoint.
"""
try:
print(f"Executing task {task_id}, type: {type(task_id)}")
self.running_tasks.add(task_id)
with get_session() as session:
task_repo = TaskRepository(session)
Expand Down
2 changes: 2 additions & 0 deletions aana/configs/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ class TaskQueueSettings(BaseModel):
execution_timeout (int): The maximum execution time for a task in seconds.
After this time, if the task is still running,
it will be considered as stuck and will be reassign to another worker.
heartbeat_timeout (int): The maximum time between heartbeats in seconds.
max_retries (int): The maximum number of retries for a task.
"""

enabled: bool = True
num_workers: int = 4
execution_timeout: int = 600
heartbeat_timeout: int = 60
max_retries: int = 3


Expand Down
6 changes: 3 additions & 3 deletions aana/deployments/task_queue_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ async def loop(self): # noqa: C901
)

# Check for expired tasks
execution_timeout = aana_settings.task_queue.execution_timeout
max_retries = aana_settings.task_queue.max_retries
expired_tasks = TaskRepository(session).update_expired_tasks(
execution_timeout=execution_timeout, max_retries=max_retries
execution_timeout=aana_settings.task_queue.execution_timeout,
heartbeat_timeout=aana_settings.task_queue.heartbeat_timeout,
max_retries=aana_settings.task_queue.max_retries,
)
for task in expired_tasks:
deployment_response = self.deployment_responses.get(task.id)
Expand Down
1 change: 0 additions & 1 deletion aana/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
DeploymentException,
EmptyMigrationsException,
FailedDeployment,
InferenceException,
InsufficientResources,
)
from aana.storage.op import run_alembic_migrations
Expand Down
35 changes: 25 additions & 10 deletions aana/storage/repository/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def filter_incomplete_tasks(self, task_ids: list[str]) -> list[str]:
return incomplete_task_ids

def update_expired_tasks(
self, execution_timeout: float, max_retries: int
self, execution_timeout: float, heartbeat_timeout: float, max_retries: int
) -> list[TaskEntity]:
"""Fetches all tasks that are expired and updates their status.
Expand All @@ -243,18 +243,23 @@ def update_expired_tasks(
Args:
execution_timeout (float): The maximum execution time for a task in seconds
heartbeat_timeout (float): The maximum time since the last heartbeat in seconds
max_retries (int): The maximum number of retries for a task
Returns:
list[TaskEntity]: the expired tasks.
"""
cutoff_time = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005
timeout_cutoff = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005
heartbeat_cutoff = datetime.now() - timedelta(seconds=heartbeat_timeout) # noqa: DTZ005
tasks = (
self.session.query(TaskEntity)
.filter(
and_(
TaskEntity.status.in_([TaskStatus.RUNNING, TaskStatus.ASSIGNED]),
TaskEntity.updated_at <= cutoff_time,
or_(
TaskEntity.updated_at <= timeout_cutoff,
TaskEntity.updated_at <= heartbeat_cutoff,
),
),
)
.populate_existing()
Expand All @@ -263,17 +268,28 @@ def update_expired_tasks(
)
for task in tasks:
if task.num_retries >= max_retries:
self.update_status(
task_id=task.id,
status=TaskStatus.FAILED,
progress=0,
result={
if task.updated_at <= timeout_cutoff:
result = {
"error": "TimeoutError",
"message": (
f"Task execution timed out after {execution_timeout} seconds and "
f"exceeded the maximum number of retries ({max_retries})"
),
},
}
else:
result = {
"error": "HeartbeatTimeoutError",
"message": (
f"The task has not received a heartbeat for {heartbeat_timeout} seconds and "
f"exceeded the maximum number of retries ({max_retries})"
),
}

self.update_status(
task_id=task.id,
status=TaskStatus.FAILED,
progress=0,
result=result,
commit=False,
)
else:
Expand All @@ -292,7 +308,6 @@ def heartbeat(self, task_ids: list[str] | set[str]):
Args:
task_ids (list[str] | set[str]): List or set of task IDs to update
"""
print(f"Heartbeat: {task_ids}")
task_ids = [
UUID(task_id) if isinstance(task_id, str) else task_id
for task_id in task_ids
Expand Down

0 comments on commit b59272d

Please sign in to comment.