Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TaskStatusDB.check_out_task: Use subquery for taskid instead of separate select #30

Merged
merged 3 commits into from
Aug 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 22 additions & 28 deletions exorcist/taskdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,12 @@ def add_task_network(self, taskid_network: nx.DiGraph, max_tries: int):

def _task_row_update_statement(
self,
taskid: str,
taskid: Union[str, SQLStatement],
status: Union[TaskStatus, SQLStatement],
*,
is_checkout: bool = False,
max_tries: Optional[int] = None,
old_status: Optional[TaskStatus] = None,
old_tries: Optional[int] = None
) -> SQLStatement:
"""
Parameters
Expand Down Expand Up @@ -417,9 +416,6 @@ def _task_row_update_statement(
if old_status is not None:
stmt = stmt.where(self.tasks_table.c.status == old_status.value)

if old_tries is not None:
stmt = stmt.where(self.tasks_table.c.tries == old_tries)

# create a dict of values to update
values = {
'status': status,
Expand Down Expand Up @@ -461,52 +457,50 @@ def check_out_task(self):
# TODO: may need to move this to a single-attempt function and wrap it
# in a while loop to catch NoStatusChange errors until we have a
# successful checkout

# TODO: separate selection so subclasses can easily override;
# something like `_select_task(conn: Connection) -> Row` (would allow
# us to do something smarter than "take the first available")
_logger.info("Checking out a new task")
sel_stmt = (
sqla.select(self.tasks_table)
subq = (
sqla.select(self.tasks_table.c.taskid)
.where(self.tasks_table.c.status == TaskStatus.AVAILABLE.value)
# .order_by(tasks_table.c.priority.desc()) # FUTURE: priority
.limit(1)
.scalar_subquery()
)
with self.engine.begin() as conn:
task_row = conn.execute(sel_stmt).first()
_logger.debug(f"Before claiming task: {task_row=}")

if task_row is None:
# no tasks are available
return None

update_stmt = self._task_row_update_statement(
task_row.taskid,
taskid=subq,
status=TaskStatus.IN_PROGRESS,
is_checkout=True,
old_status=TaskStatus.AVAILABLE,
old_tries=task_row.tries,
)
result = conn.execute(update_stmt)

self._validate_update_result(result)
).returning(self.tasks_table.c.taskid)
result = list(conn.execute(update_stmt))

if len(result) == 1:
taskid = result[0][0]
elif len(result) == 0:
_logger.info("Unable to select an available task")
return None # skip extra logging
else: # -no-cov-
raise RuntimeError(f"Received {len(result)} task IDs to check "
"out. Something went very weird.")

# log the changed row if we're doing DEBUG logging
if _logger.isEnabledFor(logging.DEBUG):
reselect = (
sqla.select(self.tasks_table)
.where(self.tasks_table.c.taskid == task_row.taskid)
.where(self.tasks_table.c.taskid == taskid)
)
# read-only; use connect() (no autocommit)
with self.engine.connect() as conn:
reloaded = list(conn.execute(reselect).all())

assert len(reloaded) == 1, \
f"Got {len(reloaded)} rows for '{task_row.taskid}'"
f"Got {len(reloaded)} rows for '{taskid}'"

claimed = reloaded[0]
_logger.debug(f"After claiming task: {claimed=}")

_logger.info(f"Selected task '{task_row.taskid}'")
return task_row.taskid
_logger.info(f"Selected task '{taskid}'")
return taskid

def _mark_task_completed_failure(self, taskid: str):
_logger.info(f"Marking try of {taskid} as failed.")
Expand Down