diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index 3e2728cf..1b3770ae 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -1,7 +1,7 @@ import json import time from typing import Annotated, Any -from uuid import uuid4 +from uuid import UUID, uuid4 import orjson import ray @@ -68,6 +68,7 @@ def __init__( app.openapi = self.custom_openapi self.ready = True + self.running_tasks = set() def custom_openapi(self) -> dict[str, Any]: """Returns OpenAPI schema, generating it if necessary.""" @@ -95,16 +96,24 @@ async def is_ready(self): """ return AanaJSONResponse(content={"ready": self.ready}) - async def execute_task(self, task_id: str) -> Any: + async def check_health(self): + """Check the health of the application.""" + # Heartbeat for the running tasks + with get_session() as session: + task_repo = TaskRepository(session) + task_repo.heartbeat(self.running_tasks) + + async def execute_task(self, task_id: str | UUID) -> Any: """Execute a task. Args: - task_id (str): The task ID. + task_id (str | UUID): The ID of the task. Returns: Any: The response from the endpoint. """ try: + self.running_tasks.add(task_id) with get_session() as session: task_repo = TaskRepository(session) task = task_repo.read(task_id) @@ -139,8 +148,9 @@ async def execute_task(self, task_id: str) -> Any: TaskRepository(session).update_status( task_id, TaskStatus.FAILED, 0, error ) - else: - return out + finally: + self.running_tasks.remove(task_id) + return out @app.get( "/tasks/get/{task_id}", diff --git a/aana/configs/settings.py b/aana/configs/settings.py index ec161ac6..a2ef7a6b 100644 --- a/aana/configs/settings.py +++ b/aana/configs/settings.py @@ -28,12 +28,14 @@ class TaskQueueSettings(BaseModel): execution_timeout (int): The maximum execution time for a task in seconds. After this time, if the task is still running, it will be considered as stuck and will be reassign to another worker. + heartbeat_timeout (int): The maximum time between heartbeats in seconds. max_retries (int): The maximum number of retries for a task. """ enabled: bool = True num_workers: int = 4 execution_timeout: int = 600 + heartbeat_timeout: int = 60 max_retries: int = 3 diff --git a/aana/deployments/task_queue_deployment.py b/aana/deployments/task_queue_deployment.py index 392d889a..1d09ad42 100644 --- a/aana/deployments/task_queue_deployment.py +++ b/aana/deployments/task_queue_deployment.py @@ -137,10 +137,10 @@ async def loop(self): # noqa: C901 ) # Check for expired tasks - execution_timeout = aana_settings.task_queue.execution_timeout - max_retries = aana_settings.task_queue.max_retries expired_tasks = TaskRepository(session).update_expired_tasks( - execution_timeout=execution_timeout, max_retries=max_retries + execution_timeout=aana_settings.task_queue.execution_timeout, + heartbeat_timeout=aana_settings.task_queue.heartbeat_timeout, + max_retries=aana_settings.task_queue.max_retries, ) for task in expired_tasks: deployment_response = self.deployment_responses.get(task.id) diff --git a/aana/sdk.py b/aana/sdk.py index ea9554f1..f25a765f 100644 --- a/aana/sdk.py +++ b/aana/sdk.py @@ -23,7 +23,6 @@ DeploymentException, EmptyMigrationsException, FailedDeployment, - InferenceException, InsufficientResources, ) from aana.storage.op import run_alembic_migrations diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index e3d3bf07..74971013 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -223,7 +223,7 @@ def filter_incomplete_tasks(self, task_ids: list[str]) -> list[str]: return incomplete_task_ids def update_expired_tasks( - self, execution_timeout: float, max_retries: int + self, execution_timeout: float, heartbeat_timeout: float, max_retries: int ) -> list[TaskEntity]: """Fetches all tasks that are expired and updates their status. @@ -243,18 +243,23 @@ def update_expired_tasks( Args: execution_timeout (float): The maximum execution time for a task in seconds + heartbeat_timeout (float): The maximum time since the last heartbeat in seconds max_retries (int): The maximum number of retries for a task Returns: list[TaskEntity]: the expired tasks. """ - cutoff_time = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005 + timeout_cutoff = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005 + heartbeat_cutoff = datetime.now() - timedelta(seconds=heartbeat_timeout) # noqa: DTZ005 tasks = ( self.session.query(TaskEntity) .filter( and_( TaskEntity.status.in_([TaskStatus.RUNNING, TaskStatus.ASSIGNED]), - TaskEntity.updated_at <= cutoff_time, + or_( + TaskEntity.assigned_at <= timeout_cutoff, + TaskEntity.updated_at <= heartbeat_cutoff, + ), ), ) .populate_existing() @@ -263,17 +268,28 @@ def update_expired_tasks( ) for task in tasks: if task.num_retries >= max_retries: - self.update_status( - task_id=task.id, - status=TaskStatus.FAILED, - progress=0, - result={ + if task.assigned_at <= timeout_cutoff: + result = { "error": "TimeoutError", "message": ( f"Task execution timed out after {execution_timeout} seconds and " f"exceeded the maximum number of retries ({max_retries})" ), - }, + } + else: + result = { + "error": "HeartbeatTimeoutError", + "message": ( + f"The task has not received a heartbeat for {heartbeat_timeout} seconds and " + f"exceeded the maximum number of retries ({max_retries})" + ), + } + + self.update_status( + task_id=task.id, + status=TaskStatus.FAILED, + progress=0, + result=result, commit=False, ) else: @@ -285,3 +301,19 @@ def update_expired_tasks( ) self.session.commit() return tasks + + def heartbeat(self, task_ids: list[str] | set[str]): + """Updates the updated_at timestamp for multiple tasks. + + Args: + task_ids (list[str] | set[str]): List or set of task IDs to update + """ + task_ids = [ + UUID(task_id) if isinstance(task_id, str) else task_id + for task_id in task_ids + ] + self.session.query(TaskEntity).filter(TaskEntity.id.in_(task_ids)).update( + {TaskEntity.updated_at: datetime.now()}, # noqa: DTZ005 + synchronize_session=False, + ) + self.session.commit() diff --git a/aana/tests/db/datastore/test_task_repo.py b/aana/tests/db/datastore/test_task_repo.py index 521b0dee..f8ab11ab 100644 --- a/aana/tests/db/datastore/test_task_repo.py +++ b/aana/tests/db/datastore/test_task_repo.py @@ -293,49 +293,65 @@ def test_update_expired_tasks(db_session): # Set up current time and a cutoff time current_time = datetime.now() # noqa: DTZ005 execution_timeout = 3600 # 1 hour in seconds + heartbeat_timeout = 60 # 1 minute in seconds # Create tasks with different updated_at times and statuses task1 = TaskEntity( endpoint="/task1", data={"test": "data1"}, status=TaskStatus.RUNNING, - updated_at=current_time - timedelta(hours=2), + assigned_at=current_time - timedelta(hours=2), + updated_at=current_time - timedelta(seconds=10), ) task2 = TaskEntity( endpoint="/task2", data={"test": "data2"}, status=TaskStatus.ASSIGNED, - updated_at=current_time - timedelta(seconds=2), + assigned_at=current_time - timedelta(seconds=2), + updated_at=current_time - timedelta(seconds=5), ) task3 = TaskEntity( endpoint="/task3", data={"test": "data3"}, status=TaskStatus.RUNNING, + assigned_at=current_time - timedelta(seconds=2), updated_at=current_time, ) task4 = TaskEntity( endpoint="/task4", data={"test": "data4"}, status=TaskStatus.COMPLETED, + assigned_at=current_time - timedelta(hours=1), updated_at=current_time - timedelta(hours=2), ) task5 = TaskEntity( endpoint="/task5", data={"test": "data5"}, status=TaskStatus.FAILED, + assigned_at=current_time - timedelta(minutes=1), updated_at=current_time - timedelta(seconds=4), ) + task6 = TaskEntity( + endpoint="/task6", + data={"test": "data6"}, + status=TaskStatus.RUNNING, + assigned_at=current_time - timedelta(minutes=3), + updated_at=current_time - timedelta(minutes=2), + ) - db_session.add_all([task1, task2, task3, task4, task5]) + db_session.add_all([task1, task2, task3, task4, task5, task6]) db_session.commit() # Fetch expired tasks expired_tasks = task_repo.update_expired_tasks( - execution_timeout=execution_timeout, max_retries=3 + execution_timeout=execution_timeout, + heartbeat_timeout=heartbeat_timeout, + max_retries=3, ) - # Assert that only tasks with RUNNING or ASSIGNED status and an updated_at older than the cutoff are returned - expected_task_ids = {str(task1.id)} + # Assert that only tasks with RUNNING or ASSIGNED status and an assigned_at time older than the execution_timeout or + # heartbeat_timeout are returned + expected_task_ids = {str(task1.id), str(task6.id)} returned_task_ids = {str(task.id) for task in expired_tasks} assert returned_task_ids == expected_task_ids