Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Heartbeat Implementation for Task Queue #199

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions aana/api/request_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import time
from typing import Annotated, Any
from uuid import uuid4
from uuid import UUID, uuid4

import orjson
import ray
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(

app.openapi = self.custom_openapi
self.ready = True
self.running_tasks = set()

def custom_openapi(self) -> dict[str, Any]:
"""Returns OpenAPI schema, generating it if necessary."""
Expand Down Expand Up @@ -95,16 +96,24 @@ async def is_ready(self):
"""
return AanaJSONResponse(content={"ready": self.ready})

async def execute_task(self, task_id: str) -> Any:
async def check_health(self):
    """Check the health of the application.

    Sends a heartbeat for every task ID currently held in
    ``self.running_tasks`` by refreshing those tasks through
    ``TaskRepository.heartbeat``.

    NOTE(review): presumably invoked periodically by the serving framework's
    health check, which is what turns this into a recurring heartbeat —
    confirm against the deployment configuration.
    """
    if not self.running_tasks:
        # Idle replica: there is nothing to refresh, so avoid opening a
        # database session on every health check.
        return

    # Heartbeat for the running tasks
    with get_session() as session:
        task_repo = TaskRepository(session)
        task_repo.heartbeat(self.running_tasks)
HRashidi marked this conversation as resolved.
Show resolved Hide resolved

async def execute_task(self, task_id: str | UUID) -> Any:
"""Execute a task.

Args:
task_id (str): The task ID.
task_id (str | UUID): The ID of the task.

Returns:
Any: The response from the endpoint.
"""
try:
self.running_tasks.add(task_id)
with get_session() as session:
task_repo = TaskRepository(session)
task = task_repo.read(task_id)
Expand Down Expand Up @@ -139,8 +148,9 @@ async def execute_task(self, task_id: str) -> Any:
TaskRepository(session).update_status(
task_id, TaskStatus.FAILED, 0, error
)
else:
return out
finally:
self.running_tasks.remove(task_id)
return out

@app.get(
"/tasks/get/{task_id}",
Expand Down
2 changes: 2 additions & 0 deletions aana/configs/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ class TaskQueueSettings(BaseModel):
execution_timeout (int): The maximum execution time for a task in seconds.
After this time, if the task is still running,
it will be considered stuck and will be reassigned to another worker.
heartbeat_timeout (int): The maximum time between heartbeats in seconds.
max_retries (int): The maximum number of retries for a task.
"""

enabled: bool = True
num_workers: int = 4
execution_timeout: int = 600
heartbeat_timeout: int = 60
max_retries: int = 3


Expand Down
6 changes: 3 additions & 3 deletions aana/deployments/task_queue_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ async def loop(self): # noqa: C901
)

# Check for expired tasks
execution_timeout = aana_settings.task_queue.execution_timeout
max_retries = aana_settings.task_queue.max_retries
expired_tasks = TaskRepository(session).update_expired_tasks(
execution_timeout=execution_timeout, max_retries=max_retries
execution_timeout=aana_settings.task_queue.execution_timeout,
heartbeat_timeout=aana_settings.task_queue.heartbeat_timeout,
max_retries=aana_settings.task_queue.max_retries,
)
for task in expired_tasks:
deployment_response = self.deployment_responses.get(task.id)
Expand Down
1 change: 0 additions & 1 deletion aana/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
DeploymentException,
EmptyMigrationsException,
FailedDeployment,
InferenceException,
InsufficientResources,
)
from aana.storage.op import run_alembic_migrations
Expand Down
50 changes: 41 additions & 9 deletions aana/storage/repository/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def filter_incomplete_tasks(self, task_ids: list[str]) -> list[str]:
return incomplete_task_ids

def update_expired_tasks(
self, execution_timeout: float, max_retries: int
self, execution_timeout: float, heartbeat_timeout: float, max_retries: int
) -> list[TaskEntity]:
"""Fetches all tasks that are expired and updates their status.

Expand All @@ -243,18 +243,23 @@ def update_expired_tasks(

Args:
execution_timeout (float): The maximum execution time for a task in seconds
heartbeat_timeout (float): The maximum time since the last heartbeat in seconds
max_retries (int): The maximum number of retries for a task

Returns:
list[TaskEntity]: the expired tasks.
"""
cutoff_time = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005
timeout_cutoff = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005
heartbeat_cutoff = datetime.now() - timedelta(seconds=heartbeat_timeout) # noqa: DTZ005
tasks = (
self.session.query(TaskEntity)
.filter(
and_(
TaskEntity.status.in_([TaskStatus.RUNNING, TaskStatus.ASSIGNED]),
TaskEntity.updated_at <= cutoff_time,
or_(
TaskEntity.assigned_at <= timeout_cutoff,
TaskEntity.updated_at <= heartbeat_cutoff,
),
),
)
.populate_existing()
Expand All @@ -263,17 +268,28 @@ def update_expired_tasks(
)
for task in tasks:
if task.num_retries >= max_retries:
self.update_status(
task_id=task.id,
status=TaskStatus.FAILED,
progress=0,
result={
if task.assigned_at <= timeout_cutoff:
result = {
"error": "TimeoutError",
"message": (
f"Task execution timed out after {execution_timeout} seconds and "
f"exceeded the maximum number of retries ({max_retries})"
),
},
}
else:
result = {
"error": "HeartbeatTimeoutError",
"message": (
f"The task has not received a heartbeat for {heartbeat_timeout} seconds and "
f"exceeded the maximum number of retries ({max_retries})"
),
}

self.update_status(
task_id=task.id,
status=TaskStatus.FAILED,
progress=0,
result=result,
commit=False,
)
else:
Expand All @@ -285,3 +301,19 @@ def update_expired_tasks(
)
self.session.commit()
return tasks

def heartbeat(self, task_ids: list[str] | set[str]):
"""Updates the updated_at timestamp for multiple tasks.

Args:
task_ids (list[str] | set[str]): List or set of task IDs to update
"""
task_ids = [
UUID(task_id) if isinstance(task_id, str) else task_id
for task_id in task_ids
]
self.session.query(TaskEntity).filter(TaskEntity.id.in_(task_ids)).update(
{TaskEntity.updated_at: datetime.now()}, # noqa: DTZ005
synchronize_session=False,
)
self.session.commit()
28 changes: 22 additions & 6 deletions aana/tests/db/datastore/test_task_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,49 +293,65 @@ def test_update_expired_tasks(db_session):
# Set up current time and a cutoff time
current_time = datetime.now() # noqa: DTZ005
execution_timeout = 3600 # 1 hour in seconds
heartbeat_timeout = 60 # 1 minute in seconds

# Create tasks with different updated_at times and statuses
task1 = TaskEntity(
endpoint="/task1",
data={"test": "data1"},
status=TaskStatus.RUNNING,
updated_at=current_time - timedelta(hours=2),
assigned_at=current_time - timedelta(hours=2),
updated_at=current_time - timedelta(seconds=10),
)
task2 = TaskEntity(
endpoint="/task2",
data={"test": "data2"},
status=TaskStatus.ASSIGNED,
updated_at=current_time - timedelta(seconds=2),
assigned_at=current_time - timedelta(seconds=2),
updated_at=current_time - timedelta(seconds=5),
)
task3 = TaskEntity(
endpoint="/task3",
data={"test": "data3"},
status=TaskStatus.RUNNING,
assigned_at=current_time - timedelta(seconds=2),
updated_at=current_time,
)
task4 = TaskEntity(
endpoint="/task4",
data={"test": "data4"},
status=TaskStatus.COMPLETED,
assigned_at=current_time - timedelta(hours=1),
updated_at=current_time - timedelta(hours=2),
)
task5 = TaskEntity(
endpoint="/task5",
data={"test": "data5"},
status=TaskStatus.FAILED,
assigned_at=current_time - timedelta(minutes=1),
updated_at=current_time - timedelta(seconds=4),
)
task6 = TaskEntity(
endpoint="/task6",
data={"test": "data6"},
status=TaskStatus.RUNNING,
assigned_at=current_time - timedelta(minutes=3),
updated_at=current_time - timedelta(minutes=2),
)

db_session.add_all([task1, task2, task3, task4, task5])
db_session.add_all([task1, task2, task3, task4, task5, task6])
db_session.commit()

# Fetch expired tasks
expired_tasks = task_repo.update_expired_tasks(
execution_timeout=execution_timeout, max_retries=3
execution_timeout=execution_timeout,
heartbeat_timeout=heartbeat_timeout,
max_retries=3,
)

# Assert that only tasks with RUNNING or ASSIGNED status and an updated_at older than the cutoff are returned
expected_task_ids = {str(task1.id)}
# Assert that only tasks with RUNNING or ASSIGNED status whose assigned_at is older than the execution_timeout or
# whose updated_at is older than the heartbeat_timeout are returned
expected_task_ids = {str(task1.id), str(task6.id)}
returned_task_ids = {str(task.id) for task in expired_tasks}

assert returned_task_ids == expected_task_ids
Loading