Skip to content

Commit

Permalink
Merge pull request #30 from OpenFreeEnergy/new_update
Browse files Browse the repository at this point in the history
`TaskStatusDB.check_out_task`: Use subquery for taskid instead of separate select
  • Loading branch information
dwhswenson authored Aug 28, 2023
2 parents 3e84595 + c0722a5 commit c391a8e
Showing 1 changed file with 22 additions and 28 deletions.
50 changes: 22 additions & 28 deletions exorcist/taskdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,12 @@ def add_task_network(self, taskid_network: nx.DiGraph, max_tries: int):

def _task_row_update_statement(
self,
taskid: str,
taskid: Union[str, SQLStatement],
status: Union[TaskStatus, SQLStatement],
*,
is_checkout: bool = False,
max_tries: Optional[int] = None,
old_status: Optional[TaskStatus] = None,
old_tries: Optional[int] = None
) -> SQLStatement:
"""
Parameters
Expand Down Expand Up @@ -417,9 +416,6 @@ def _task_row_update_statement(
if old_status is not None:
stmt = stmt.where(self.tasks_table.c.status == old_status.value)

if old_tries is not None:
stmt = stmt.where(self.tasks_table.c.tries == old_tries)

# create a dict of values to update
values = {
'status': status,
Expand Down Expand Up @@ -461,52 +457,50 @@ def check_out_task(self):
# TODO: may need move this to a single attempt function and wrap it
# in while loop to catch NoStatusChange errors until we have a
# successful checkout

# TODO: separate selection so subclasses can easily override;
# something like `_select_task(conn: Connection) -> Row` (allow us
# to do something smarter than "take the first available")
_logger.info("Checking out a new task")
sel_stmt = (
sqla.select(self.tasks_table)
subq = (
sqla.select(self.tasks_table.c.taskid)
.where(self.tasks_table.c.status == TaskStatus.AVAILABLE.value)
# .order_by(tasks_table.c.priority.desc()) # FUTURE: priority
.limit(1)
.scalar_subquery()
)
with self.engine.begin() as conn:
task_row = conn.execute(sel_stmt).first()
_logger.debug(f"Before claiming task: {task_row=}")

if task_row is None:
# no tasks are available
return None

update_stmt = self._task_row_update_statement(
task_row.taskid,
taskid=subq,
status=TaskStatus.IN_PROGRESS,
is_checkout=True,
old_status=TaskStatus.AVAILABLE,
old_tries=task_row.tries,
)
result = conn.execute(update_stmt)

self._validate_update_result(result)
).returning(self.tasks_table.c.taskid)
result = list(conn.execute(update_stmt))

if len(result) == 1:
taskid = result[0][0]
elif len(result) == 0:
_logger.info("Unable to select an available task")
return None # skip extra logging
else: # -no-cov-
raise RuntimeError(f"Received {len(result)} task IDs to check "
"out. Something went very weird.")

# log the changed row if we're doing DEBUG logging
if _logger.isEnabledFor(logging.DEBUG):
reselect = (
sqla.select(self.tasks_table)
.where(self.tasks_table.c.taskid == task_row.taskid)
.where(self.tasks_table.c.taskid == taskid)
)
# read-only; use connect() (no autocommit)
with self.engine.connect() as conn:
reloaded = list(conn.execute(reselect).all())

assert len(reloaded) == 1, \
f"Got {len(reloaded)} rows for '{task_row.taskid}'"
f"Got {len(reloaded)} rows for '{taskid}'"

claimed = reloaded[0]
_logger.debug(f"After claiming task: {claimed=}")

_logger.info(f"Selected task '{task_row.taskid}'")
return task_row.taskid
_logger.info(f"Selected task '{taskid}'")
return taskid

def _mark_task_completed_failure(self, taskid: str):
_logger.info(f"Marking try of {taskid} as failed.")
Expand Down

0 comments on commit c391a8e

Please sign in to comment.