Skip to content

Commit

Permalink
AIP-72: Handling SystemExit from task sdk (apache#45282)
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh authored Dec 30, 2024
1 parent f09bd4e commit d22684c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
7 changes: 6 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,12 @@ def run(ti: RuntimeTaskInstance, log: Logger):
)
# TODO: Run task failure callbacks here
except SystemExit:
...
# SystemExit needs to be retried if they are eligible.
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
# TODO: Run task failure callbacks here
except BaseException:
# TODO: Run task failure callbacks here
msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
Expand Down
38 changes: 38 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,44 @@ def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context,
)


def test_run_raises_system_exit(time_machine, mocked_parse, make_ti_context, mock_supervisor_comms):
"""Test running a basic task that exits with SystemExit exception."""
from airflow.providers.standard.operators.python import PythonOperator

task = PythonOperator(
task_id="system_exit_task",
python_callable=lambda: exit(10),
)

what = StartupDetails(
ti=TaskInstance(
id=uuid7(),
task_id="system_exit_task",
dag_id="basic_dag_system_exit",
run_id="c",
try_number=1,
),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)

ti = mocked_parse(what, "basic_dag_system_exit", task)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(
state=TerminalTIState.FAILED,
end_date=instant,
),
log=mock.ANY,
)


def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms):
"""Test running a DAG with templated task."""
from airflow.providers.standard.operators.bash import BashOperator
Expand Down

0 comments on commit d22684c

Please sign in to comment.