Skip to content

Commit

Permalink
Update task expiration logic to use assigned_at timestamp for executi…
Browse files Browse the repository at this point in the history
…on timeout and add heartbeat timeout to tests
  • Loading branch information
Aleksandr Movchan committed Nov 1, 2024
1 parent b59272d commit bd54ad5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
4 changes: 2 additions & 2 deletions aana/storage/repository/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def update_expired_tasks(
and_(
TaskEntity.status.in_([TaskStatus.RUNNING, TaskStatus.ASSIGNED]),
or_(
TaskEntity.updated_at <= timeout_cutoff,
TaskEntity.assigned_at <= timeout_cutoff,
TaskEntity.updated_at <= heartbeat_cutoff,
),
),
Expand All @@ -268,7 +268,7 @@ def update_expired_tasks(
)
for task in tasks:
if task.num_retries >= max_retries:
if task.updated_at <= timeout_cutoff:
if task.assigned_at <= timeout_cutoff:
result = {
"error": "TimeoutError",
"message": (
Expand Down
28 changes: 22 additions & 6 deletions aana/tests/db/datastore/test_task_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,49 +293,65 @@ def test_update_expired_tasks(db_session):
# Set up current time and a cutoff time
current_time = datetime.now() # noqa: DTZ005
execution_timeout = 3600 # 1 hour in seconds
heartbeat_timeout = 60 # 1 minute in seconds

# Create tasks with different updated_at times and statuses
task1 = TaskEntity(
endpoint="/task1",
data={"test": "data1"},
status=TaskStatus.RUNNING,
updated_at=current_time - timedelta(hours=2),
assigned_at=current_time - timedelta(hours=2),
updated_at=current_time - timedelta(seconds=10),
)
task2 = TaskEntity(
endpoint="/task2",
data={"test": "data2"},
status=TaskStatus.ASSIGNED,
updated_at=current_time - timedelta(seconds=2),
assigned_at=current_time - timedelta(seconds=2),
updated_at=current_time - timedelta(seconds=5),
)
task3 = TaskEntity(
endpoint="/task3",
data={"test": "data3"},
status=TaskStatus.RUNNING,
assigned_at=current_time - timedelta(seconds=2),
updated_at=current_time,
)
task4 = TaskEntity(
endpoint="/task4",
data={"test": "data4"},
status=TaskStatus.COMPLETED,
assigned_at=current_time - timedelta(hours=1),
updated_at=current_time - timedelta(hours=2),
)
task5 = TaskEntity(
endpoint="/task5",
data={"test": "data5"},
status=TaskStatus.FAILED,
assigned_at=current_time - timedelta(minutes=1),
updated_at=current_time - timedelta(seconds=4),
)
task6 = TaskEntity(
endpoint="/task6",
data={"test": "data6"},
status=TaskStatus.RUNNING,
assigned_at=current_time - timedelta(minutes=3),
updated_at=current_time - timedelta(minutes=2),
)

db_session.add_all([task1, task2, task3, task4, task5])
db_session.add_all([task1, task2, task3, task4, task5, task6])
db_session.commit()

# Fetch expired tasks
expired_tasks = task_repo.update_expired_tasks(
execution_timeout=execution_timeout, max_retries=3
execution_timeout=execution_timeout,
heartbeat_timeout=heartbeat_timeout,
max_retries=3,
)

# Assert that only tasks with RUNNING or ASSIGNED status and an updated_at older than the cutoff are returned
expected_task_ids = {str(task1.id)}
# Assert that only tasks with RUNNING or ASSIGNED status and an assigned_at time older than the execution_timeout or
# heartbeat_timeout are returned
expected_task_ids = {str(task1.id), str(task6.id)}
returned_task_ids = {str(task.id) for task in expired_tasks}

assert returned_task_ids == expected_task_ids

0 comments on commit bd54ad5

Please sign in to comment.