diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index b8068a09..1b3770ae 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -113,7 +113,6 @@ async def execute_task(self, task_id: str | UUID) -> Any: Any: The response from the endpoint. """ try: - print(f"Executing task {task_id}, type: {type(task_id)}") self.running_tasks.add(task_id) with get_session() as session: task_repo = TaskRepository(session) diff --git a/aana/configs/settings.py b/aana/configs/settings.py index ec161ac6..a2ef7a6b 100644 --- a/aana/configs/settings.py +++ b/aana/configs/settings.py @@ -28,12 +28,14 @@ class TaskQueueSettings(BaseModel): execution_timeout (int): The maximum execution time for a task in seconds. After this time, if the task is still running, it will be considered as stuck and will be reassign to another worker. + heartbeat_timeout (int): The maximum time between heartbeats in seconds. max_retries (int): The maximum number of retries for a task. """ enabled: bool = True num_workers: int = 4 execution_timeout: int = 600 + heartbeat_timeout: int = 60 max_retries: int = 3 diff --git a/aana/deployments/task_queue_deployment.py b/aana/deployments/task_queue_deployment.py index 392d889a..1d09ad42 100644 --- a/aana/deployments/task_queue_deployment.py +++ b/aana/deployments/task_queue_deployment.py @@ -137,10 +137,10 @@ async def loop(self): # noqa: C901 ) # Check for expired tasks - execution_timeout = aana_settings.task_queue.execution_timeout - max_retries = aana_settings.task_queue.max_retries expired_tasks = TaskRepository(session).update_expired_tasks( - execution_timeout=execution_timeout, max_retries=max_retries + execution_timeout=aana_settings.task_queue.execution_timeout, + heartbeat_timeout=aana_settings.task_queue.heartbeat_timeout, + max_retries=aana_settings.task_queue.max_retries, ) for task in expired_tasks: deployment_response = self.deployment_responses.get(task.id) diff --git a/aana/sdk.py b/aana/sdk.py index ea9554f1..f25a765f 100644 --- a/aana/sdk.py +++ b/aana/sdk.py @@ -23,7 +23,6 @@ DeploymentException, EmptyMigrationsException, FailedDeployment, - InferenceException, InsufficientResources, ) from aana.storage.op import run_alembic_migrations diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index 0705246f..670d00f6 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -223,7 +223,7 @@ def filter_incomplete_tasks(self, task_ids: list[str]) -> list[str]: return incomplete_task_ids def update_expired_tasks( - self, execution_timeout: float, max_retries: int + self, execution_timeout: float, heartbeat_timeout: float, max_retries: int ) -> list[TaskEntity]: """Fetches all tasks that are expired and updates their status. @@ -243,18 +243,23 @@ def update_expired_tasks( Args: execution_timeout (float): The maximum execution time for a task in seconds + heartbeat_timeout (float): The maximum time since the last heartbeat in seconds max_retries (int): The maximum number of retries for a task Returns: list[TaskEntity]: the expired tasks. """ - cutoff_time = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005 + timeout_cutoff = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005 + heartbeat_cutoff = datetime.now() - timedelta(seconds=heartbeat_timeout) # noqa: DTZ005 tasks = ( self.session.query(TaskEntity) .filter( and_( TaskEntity.status.in_([TaskStatus.RUNNING, TaskStatus.ASSIGNED]), - TaskEntity.updated_at <= cutoff_time, + or_( + TaskEntity.updated_at <= timeout_cutoff, + TaskEntity.updated_at <= heartbeat_cutoff, + ), ), ) .populate_existing() @@ -263,17 +268,28 @@ def update_expired_tasks( ) for task in tasks: if task.num_retries >= max_retries: - self.update_status( - task_id=task.id, - status=TaskStatus.FAILED, - progress=0, - result={ + if task.updated_at <= timeout_cutoff: + result = { "error": "TimeoutError", "message": ( f"Task execution timed out after {execution_timeout} seconds and " f"exceeded the maximum number of retries ({max_retries})" ), - }, + } + else: + result = { + "error": "HeartbeatTimeoutError", + "message": ( + f"The task has not received a heartbeat for {heartbeat_timeout} seconds and " + f"exceeded the maximum number of retries ({max_retries})" + ), + } + + self.update_status( + task_id=task.id, + status=TaskStatus.FAILED, + progress=0, + result=result, commit=False, ) else: @@ -292,7 +308,6 @@ def heartbeat(self, task_ids: list[str] | set[str]): Args: task_ids (list[str] | set[str]): List or set of task IDs to update """ - print(f"Heartbeat: {task_ids}") task_ids = [ UUID(task_id) if isinstance(task_id, str) else task_id for task_id in task_ids