diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 2363e927b468f..c9cec771ea081 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -95,6 +95,11 @@ MIN_HEARTBEAT_INTERVAL: int = 5 MAX_FAILED_HEARTBEATS: int = 3 +# These are the task instance states that require some additional information to transition into. +# "Directly" here means that the PATCH API calls to transition into these states are +# made from _handle_request() itself and don't have to come all the way to wait(). +STATES_SENT_DIRECTLY = [IntermediateTIState.DEFERRED, IntermediateTIState.UP_FOR_RESCHEDULE] + @overload def mkpipe() -> tuple[socket, socket]: ... @@ -518,11 +523,11 @@ def wait(self) -> int: # If it hasn't, assume it's failed self._exit_code = self._exit_code if self._exit_code is not None else 1 - # If the process has finished in a terminal state, update the state of the TaskInstance - # to reflect the final state of the process. - # For states like `deferred`, the process will exit with 0, but the state will be updated + # If the process has finished non-directly patched state (directly means deferred, reschedule, etc.), + # update the state of the TaskInstance to reflect the final state of the process. + # For states like `deferred`, `up_for_reschedule`, the process will exit with 0, but the state will be updated # by the subprocess in the `handle_requests` method. - if self.final_state in TerminalTIState: + if self.final_state not in STATES_SENT_DIRECTLY: self.client.task_instances.finish( id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc) )