diff --git a/exorcist/taskdb.py b/exorcist/taskdb.py index 3afed8f..f99cc81 100644 --- a/exorcist/taskdb.py +++ b/exorcist/taskdb.py @@ -369,13 +369,12 @@ def add_task_network(self, taskid_network: nx.DiGraph, max_tries: int): def _task_row_update_statement( self, - taskid: str, + taskid: Union[str, SQLStatement], status: Union[TaskStatus, SQLStatement], *, is_checkout: bool = False, max_tries: Optional[int] = None, old_status: Optional[TaskStatus] = None, - old_tries: Optional[int] = None ) -> SQLStatement: """ Parameters @@ -417,9 +416,6 @@ def _task_row_update_statement( if old_status is not None: stmt = stmt.where(self.tasks_table.c.status == old_status.value) - if old_tries is not None: - stmt = stmt.where(self.tasks_table.c.tries == old_tries) - # create a dict of values to update values = { 'status': status, @@ -461,52 +457,50 @@ def check_out_task(self): # TODO: may need move this to a single attempt function and wrap it # in while loop to catch NoStatusChange errors until we have a # successful checkout - - # TODO: separate selection so subclasses can easily override; - # something like `_select_task(conn: Connection) -> Row` (allow us - # to do something smarter than "take the first available") _logger.info("Checking out a new task") - sel_stmt = ( - sqla.select(self.tasks_table) + subq = ( + sqla.select(self.tasks_table.c.taskid) .where(self.tasks_table.c.status == TaskStatus.AVAILABLE.value) + # .order_by(tasks_table.c.priority.desc()) # FUTURE: priority + .limit(1) + .scalar_subquery() ) with self.engine.begin() as conn: - task_row = conn.execute(sel_stmt).first() - _logger.debug(f"Before claiming task: {task_row=}") - - if task_row is None: - # no tasks are available - return None - update_stmt = self._task_row_update_statement( - task_row.taskid, + taskid=subq, status=TaskStatus.IN_PROGRESS, is_checkout=True, old_status=TaskStatus.AVAILABLE, - old_tries=task_row.tries, - ) - result = conn.execute(update_stmt) - - self._validate_update_result(result) + ).returning(self.tasks_table.c.taskid) + result = list(conn.execute(update_stmt)) + + if len(result) == 1: + taskid = result[0][0] + elif len(result) == 0: + _logger.info("Unable to select an available task") + return None # skip extra logging + else: # -no-cov- + raise RuntimeError(f"Received {len(result)} task IDs to check " + "out. Something went very weird.") # log the changed row if we're doing DEBUG logging if _logger.isEnabledFor(logging.DEBUG): reselect = ( sqla.select(self.tasks_table) - .where(self.tasks_table.c.taskid == task_row.taskid) + .where(self.tasks_table.c.taskid == taskid) ) # read-only; use connect() (no autocommit) with self.engine.connect() as conn: reloaded = list(conn.execute(reselect).all()) assert len(reloaded) == 1, \ - f"Got {len(reloaded)} rows for '{task_row.taskid}'" + f"Got {len(reloaded)} rows for '{taskid}'" claimed = reloaded[0] _logger.debug(f"After claiming task: {claimed=}") - _logger.info(f"Selected task '{task_row.taskid}'") - return task_row.taskid + _logger.info(f"Selected task '{taskid}'") + return taskid def _mark_task_completed_failure(self, taskid: str): _logger.info(f"Marking try of {taskid} as failed.")