Add Robust Support for Callbacks at Task and TaskGroup Level #322

Merged
merged 32 commits into from
Jan 9, 2025
5ea61fd
Unit-tests added for Tasks/TaskGroups
jroach-astronomer Dec 7, 2024
a4d3456
Refactored unit-tests, ready to go
jroach-astronomer Dec 17, 2024
46b58b3
Stashing changes for syncing fork
jroach-astronomer Dec 17, 2024
484058d
Updating example_callbacks.yml file
jroach-astronomer Dec 17, 2024
25baecc
Removing placeholders
jroach-astronomer Dec 17, 2024
81b8c7d
Resolving MC
jroach-astronomer Dec 17, 2024
a1d15a2
Resolving merge conflict
jroach-astronomer Dec 17, 2024
cb6ad85
Fixing unit tests
jroach-astronomer Dec 17, 2024
1635b1b
Pre-commit checks
jroach-astronomer Dec 17, 2024
5cc8a87
Stashing changes for the night
jroach-astronomer Dec 17, 2024
8c31635
on_skipped_callback removing from default_args
jroach-astronomer Dec 17, 2024
32a7400
Version management for Slack
jroach-astronomer Dec 17, 2024
f626458
Adding un-added files
jroach-astronomer Dec 17, 2024
7df1861
Version check for on_skipped_callback
jroach-astronomer Dec 17, 2024
4c9ab09
Updating unit-tests
jroach-astronomer Dec 18, 2024
8688408
Updating unit-tests
jroach-astronomer Dec 18, 2024
9a7d058
Updating unit-tests
jroach-astronomer Dec 18, 2024
d5d8e68
Updating unit-tests
jroach-astronomer Dec 18, 2024
fc26000
Updating unit-tests
jroach-astronomer Dec 18, 2024
05db194
Updating unit-tests
jroach-astronomer Dec 18, 2024
975efe0
Updating unit-tests
jroach-astronomer Dec 18, 2024
7710f31
Updating unit-tests
jroach-astronomer Dec 18, 2024
7a79fd5
Updating unit-tests
jroach-astronomer Dec 18, 2024
ca9d1cb
Updating unit-tests
jroach-astronomer Dec 18, 2024
196420f
Linting, adding version-checking
jroach-astronomer Dec 18, 2024
add2786
Adding action-items to tests for versioning
jroach-astronomer Dec 18, 2024
c4a2342
Added test-coverage for #253
jroach-astronomer Dec 30, 2024
197e8de
Updated example DAG, unit tests
jroach-astronomer Jan 6, 2025
a0e0bca
Merge branch 'main' into issue-253
jroach-astronomer Jan 6, 2025
27b336b
Merge branch 'main' into issue-253
jroach-astronomer Jan 7, 2025
060b96a
Adjusted example DAG to ensure that callback parameters were successf…
jroach-astronomer Jan 7, 2025
317b35e
Merge branch 'main' into issue-253
pankajkoti Jan 9, 2025
197 changes: 110 additions & 87 deletions dagfactory/dagbuilder.py
Expand Up @@ -165,78 +165,56 @@ def get_dag_params(self) -> Dict[str, Any]:
dag_params["default_args"]["sla"]: timedelta = timedelta(seconds=dag_params["default_args"]["sla_secs"])
del dag_params["default_args"]["sla_secs"]

if utils.check_dict_key(dag_params["default_args"], "sla_miss_callback"):
if isinstance(dag_params["default_args"]["sla_miss_callback"], str):
dag_params["default_args"]["sla_miss_callback"] = import_string(
dag_params["default_args"]["sla_miss_callback"]
)

if utils.check_dict_key(dag_params["default_args"], "on_execute_callback"):
if isinstance(dag_params["default_args"]["on_execute_callback"], str):
dag_params["default_args"]["on_execute_callback"] = import_string(
dag_params["default_args"]["on_execute_callback"]
)

if utils.check_dict_key(dag_params["default_args"], "on_success_callback"):
if isinstance(dag_params["default_args"]["on_success_callback"], str):
dag_params["default_args"]["on_success_callback"]: Callable = import_string(
dag_params["default_args"]["on_success_callback"]
)

if utils.check_dict_key(dag_params["default_args"], "on_failure_callback"):
dag_params["default_args"]["on_failure_callback"]: Callable = self.set_callback(
parameters=dag_params["default_args"], callback_type="on_failure_callback"
)

if utils.check_dict_key(dag_params["default_args"], "on_retry_callback"):
if isinstance(dag_params["default_args"]["on_retry_callback"], str):
dag_params["default_args"]["on_retry_callback"]: Callable = import_string(
dag_params["default_args"]["on_retry_callback"]
)

if utils.check_dict_key(dag_params, "sla_miss_callback"):
if isinstance(dag_params["sla_miss_callback"], str):
dag_params["sla_miss_callback"]: Callable = import_string(dag_params["sla_miss_callback"])

if utils.check_dict_key(dag_params, "on_success_callback"):
if isinstance(dag_params["on_success_callback"], str):
dag_params["on_success_callback"]: Callable = import_string(dag_params["on_success_callback"])

if utils.check_dict_key(dag_params, "on_failure_callback"):
dag_params["on_failure_callback"]: Callable = self.set_callback(
parameters=dag_params, callback_type="on_failure_callback"
)

if utils.check_dict_key(dag_params, "on_success_callback_name") and utils.check_dict_key(
dag_params, "on_success_callback_file"
):
dag_params["on_success_callback"]: Callable = utils.get_python_callable(
dag_params["on_success_callback_name"],
dag_params["on_success_callback_file"],
)

if utils.check_dict_key(dag_params, "on_failure_callback_name") and utils.check_dict_key(
dag_params, "on_failure_callback_file"
):
dag_params["on_failure_callback"] = self.set_callback(
parameters=dag_params, callback_type="on_failure_callback", has_name_and_file=True
)
# Parse callbacks at the DAG-level and at the Task-level, configured in default_args. Note that the version
# check has gone into the set_callback method
for callback_type in [
"on_execute_callback",
"on_success_callback",
"on_failure_callback",
"on_retry_callback", # Not applicable at the DAG-level
"on_skipped_callback", # Not applicable at the DAG-level
"sla_miss_callback", # Not applicable at the default_args level
]:
# Here, we are parsing both the DAG-level params and default_args for callbacks. Previously, this was
# copy-and-pasted for each callback type and each configuration option (via a string import, a function
# defined via YAML, or a file path and name).

# First, check at the DAG-level for just the single field (via a string or via a provider callback that
# takes parameters). Since "on_retry_callback" and "on_skipped_callback" are only applicable at the
# Task-level, we are skipping those callback types here.
if callback_type not in ("on_retry_callback", "on_skipped_callback"):
if utils.check_dict_key(dag_params, callback_type):
dag_params[callback_type]: Callable = self.set_callback(
parameters=dag_params, callback_type=callback_type
)

if utils.check_dict_key(dag_params["default_args"], "on_success_callback_name") and utils.check_dict_key(
dag_params["default_args"], "on_success_callback_file"
):
# Then, check at the DAG-level for a file path and name
if utils.check_dict_key(dag_params, f"{callback_type}_name") and utils.check_dict_key(

I’m wondering if we should continue supporting the {callback_type}_name + {callback_type}_file approach for callbacks. Are there scenarios where simply specifying {callback_type} might be insufficient or lack necessary support?

Unifying the approach to a single, straightforward method seems more maintainable and less confusing for users. Additionally, what happens if a user specifies both on_success_callback and on_success_callback_name + on_success_callback_file? Currently, it appears that on_success_callback_name + on_success_callback_file will override the on_success_callback, which isn’t transparent to the end user and could cause confusion.

While this isn’t part of your changes and has been a pre-existing issue, it might be worth discussing whether we could drop support for the name+file approach altogether, or at least raise an error when users try to provide callbacks using both approaches. What do you think?
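A minimal sketch of the "raise an error" option from this comment. The helper name `check_callback_conflict` and the exception class `CallbackConfigError` are hypothetical and are not part of dag-factory; the sketch only assumes the key-naming convention visible in the diff (`{callback_type}`, `{callback_type}_name`, `{callback_type}_file`).

```python
from typing import Any, Dict


class CallbackConfigError(ValueError):
    """Hypothetical error type for ambiguous callback configuration."""


def check_callback_conflict(params: Dict[str, Any], callback_type: str) -> None:
    """Raise if a callback is configured both directly and via name + file.

    `params` stands in for dag_params, default_args, or task_params.
    """
    has_direct = params.get(callback_type) is not None
    has_name_or_file = (
        params.get(f"{callback_type}_name") is not None
        or params.get(f"{callback_type}_file") is not None
    )
    if has_direct and has_name_or_file:
        raise CallbackConfigError(
            f"'{callback_type}' is set together with '{callback_type}_name' / "
            f"'{callback_type}_file'; configure only one of the two forms."
        )


# Only one form is present, so this passes silently.
check_callback_conflict(
    {"on_success_callback": "print_hello.print_hello_from_callback"},
    "on_success_callback",
)
```

Calling such a check before either `set_callback` branch runs would make the override explicit instead of silent.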

dag_params, f"{callback_type}_file"
):
dag_params[callback_type] = self.set_callback(
parameters=dag_params, callback_type=callback_type, has_name_and_file=True
)

dag_params["default_args"]["on_success_callback"]: Callable = utils.get_python_callable(
dag_params["default_args"]["on_success_callback_name"],
dag_params["default_args"]["on_success_callback_file"],
)
# SLAs are defined at the DAG-level, and will be applied to every task.
# https://www.astronomer.io/docs/learn/error-notifications-in-airflow/. Here, we are not going to add
# callbacks for sla_miss_callback, or on_skipped_callback if the Airflow version is less than 2.7.0
if (callback_type != "sla_miss_callback") or not (
callback_type == "on_skipped_callback" and version.parse(AIRFLOW_VERSION) < version.parse("2.7.0")
):
# Next, check for a callback at the Task-level using default_args
if utils.check_dict_key(dag_params["default_args"], callback_type):
dag_params["default_args"][callback_type]: Callable = self.set_callback(
parameters=dag_params["default_args"], callback_type=callback_type
)

if utils.check_dict_key(dag_params["default_args"], "on_failure_callback_name") and utils.check_dict_key(
dag_params["default_args"], "on_failure_callback_file"
):
dag_params["default_args"]["on_failure_callback"] = self.set_callback(
parameters=dag_params["default_args"], callback_type="on_failure_callback", has_name_and_file=True
)
# Finally, check for file path and name at the Task-level using default_args
if utils.check_dict_key(dag_params["default_args"], f"{callback_type}_name") and utils.check_dict_key(
dag_params["default_args"], f"{callback_type}_file"
):
dag_params["default_args"][callback_type] = self.set_callback(
parameters=dag_params["default_args"], callback_type=callback_type, has_name_and_file=True
)
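A quick truth-table check of the guard in this hunk, `(callback_type != "sla_miss_callback") or not (callback_type == "on_skipped_callback" and <Airflow older than 2.7.0>)`. As written, the `or` appears to make the guard true for every callback type, including the two the comment says should be skipped. The `intended_condition` function below is only one reading of that comment, not code from the project.

```python
from itertools import product

CALLBACK_TYPES = [
    "on_execute_callback",
    "on_success_callback",
    "on_failure_callback",
    "on_retry_callback",
    "on_skipped_callback",
    "sla_miss_callback",
]


def merged_condition(callback_type: str, airflow_below_2_7: bool) -> bool:
    # Same boolean shape as the guard in the diff.
    return (callback_type != "sla_miss_callback") or not (
        callback_type == "on_skipped_callback" and airflow_below_2_7
    )


def intended_condition(callback_type: str, airflow_below_2_7: bool) -> bool:
    # Assumption: skip sla_miss_callback always, and on_skipped_callback on old Airflow.
    return callback_type != "sla_miss_callback" and not (
        callback_type == "on_skipped_callback" and airflow_below_2_7
    )


rows = [
    (ct, old, merged_condition(ct, old), intended_condition(ct, old))
    for ct, old in product(CALLBACK_TYPES, (False, True))
]

for callback_type, below_2_7, merged, intended in rows:
    print(f"{callback_type:22} below_2.7={below_2_7!s:5} merged={merged!s:5} intended={intended}")
```

If the version check inside `set_callback` already covers `on_skipped_callback`, the practical effect may be limited to `sla_miss_callback` being parsed from `default_args`.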

if utils.check_dict_key(dag_params, "template_searchpath"):
if isinstance(dag_params["template_searchpath"], (list, str)) and utils.check_template_searchpath(
Expand Down Expand Up @@ -451,6 +429,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
if task_params.get("init_containers") is not None
else None
)

DagBuilder.adjust_general_task_params(task_params)

expand_kwargs: Dict[str, Union[Dict[str, Any], Any]] = {}
Expand Down Expand Up @@ -495,23 +474,49 @@ def make_task_groups(task_groups: Dict[str, Any], dag: DAG) -> Dict[str, "TaskGr

@staticmethod
def _init_task_group_callback_param(task_group_conf):
"""
_init_task_group_callback_param

Handle configuring callbacks for TaskGroups in this helper method.

:param task_group_conf: dict containing the configuration of the TaskGroup
"""
# The Airflow version needs to be at least 2.2.0, and default args must be present. Basically saying here: if
# it's not the case that we're using at least Airflow 2.2.0 and default_args are present, then return the
# TaskGroup configuration without doing anything
if not (
version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0")
and isinstance(task_group_conf.get("default_args"), dict)
):
return task_group_conf

default_args = task_group_conf["default_args"]
callback_keys = [
"on_success_callback",
# Check the callback types that can be in the default_args of the TaskGroup
for callback_type in [
"on_execute_callback",
"on_success_callback",
"on_failure_callback",
"on_retry_callback",
]
"on_skipped_callback", # This is only available AIRFLOW_VERSION >= 2.7.0
]:
# on_skipped_callback can only be added to the default_args of a TaskGroup for AIRFLOW_VERSION >= 2.7.0
if callback_type == "on_skipped_callback" and version.parse(AIRFLOW_VERSION) < version.parse("2.7.0"):
continue

# First, check for a str, str with params, or provider callback
if utils.check_dict_key(task_group_conf["default_args"], callback_type):
task_group_conf["default_args"][callback_type]: Callable = DagBuilder.set_callback(
parameters=task_group_conf["default_args"], callback_type=callback_type
)

for key in callback_keys:
if key in default_args and isinstance(default_args[key], str):
default_args[key]: Callable = import_string(default_args[key])
# Then, check for a file path and name
if utils.check_dict_key(task_group_conf["default_args"], f"{callback_type}_name") and utils.check_dict_key(
task_group_conf["default_args"], f"{callback_type}_file"
):
task_group_conf["default_args"][callback_type] = DagBuilder.set_callback(
parameters=task_group_conf["default_args"],
callback_type=callback_type,
has_name_and_file=True,
)

return task_group_conf
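The version gating in `_init_task_group_callback_param` can be sketched in isolation. The diff calls `version.parse`; to stay standard-library only, this sketch substitutes a simplified tuple parser (`parse_major_minor_patch`, a hypothetical stand-in that ignores pre-release suffixes). The function name `supported_task_group_callbacks` is likewise an assumption for illustration.

```python
import re
from typing import List, Tuple


def parse_major_minor_patch(raw: str) -> Tuple[int, int, int]:
    """Simplified stand-in for a real version parser: reads 'X.Y.Z' only."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", raw)
    if match is None:
        raise ValueError(f"Unrecognised version string: {raw!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch


def supported_task_group_callbacks(airflow_version: str) -> List[str]:
    """Return the callback types a TaskGroup's default_args may carry."""
    parsed = parse_major_minor_patch(airflow_version)
    if parsed < (2, 2, 0):
        # Mirrors the early return: nothing is parsed below Airflow 2.2.0.
        return []
    callbacks = [
        "on_execute_callback",
        "on_success_callback",
        "on_failure_callback",
        "on_retry_callback",
    ]
    if parsed >= (2, 7, 0):
        # on_skipped_callback is only handled from Airflow 2.7.0 onward.
        callbacks.append("on_skipped_callback")
    return callbacks


print(supported_task_group_callbacks("2.6.3"))
print(supported_task_group_callbacks("2.10.1"))
```

Comparing integer tuples rather than raw strings avoids treating "2.10.1" as older than "2.7.0".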

Expand Down Expand Up @@ -866,17 +871,25 @@ def adjust_general_task_params(task_params: dict(str, Any)):
del task_params["execution_date_fn_file"]

# on_execute_callback is an Airflow 2.0 feature
if utils.check_dict_key(task_params, "on_execute_callback"):
task_params["on_execute_callback"]: Callable = import_string(task_params["on_execute_callback"])

if utils.check_dict_key(task_params, "on_failure_callback"):
task_params["on_failure_callback"]: Callable = import_string(task_params["on_failure_callback"])

if utils.check_dict_key(task_params, "on_success_callback"):
task_params["on_success_callback"]: Callable = import_string(task_params["on_success_callback"])
for callback_type in [
"on_execute_callback",
"on_success_callback",
"on_failure_callback",
"on_retry_callback",
"on_skipped_callback",
]:
if utils.check_dict_key(task_params, callback_type):
task_params[callback_type]: Callable = DagBuilder.set_callback(
parameters=task_params, callback_type=callback_type
)

if utils.check_dict_key(task_params, "on_retry_callback"):
task_params["on_retry_callback"]: Callable = import_string(task_params["on_retry_callback"])
# Check for file path and name
if utils.check_dict_key(task_params, f"{callback_type}_name") and utils.check_dict_key(
task_params, f"{callback_type}_file"
):
task_params[callback_type] = DagBuilder.set_callback(
parameters=task_params, callback_type=callback_type, has_name_and_file=True
)

# use variables as arguments on operator
if utils.check_dict_key(task_params, "variables_as_arguments"):
Expand Down Expand Up @@ -991,15 +1004,25 @@ def set_callback(parameters: Union[dict, str], callback_type: str, has_name_and_
:param has_name_and_file:
:returns: Callable
"""
# Check Airflow version, raise an exception otherwise
if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
raise DagFactoryException("Cannot parse callbacks with an Airflow version less than 2.0.0.")

# There is a scenario where a callback is passed in via a file and a name. For the most part, this will be a
# Python callable that is treated similarly to a Python callable that the PythonOperator may leverage. That
# being said, what if this is not a Python callable? What if this is another type?
if has_name_and_file:
return utils.get_python_callable(
on_state_callback_callable: Callable = utils.get_python_callable(
python_callable_name=parameters[f"{callback_type}_name"],
python_callable_file=parameters[f"{callback_type}_file"],
)

# Delete the callback_type name and file
del parameters[f"{callback_type}_name"]
del parameters[f"{callback_type}_file"]

return on_state_callback_callable

# If the value stored at parameters[callback_type] is a string, it should be imported under the assumption that
# it is a function that is "ready to be called". If not returning the function, something like this could be
# used to update the config parameters[callback_type] = import_string(parameters[callback_type])
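A rough sketch of the name + file branch above: load the callable, then drop the `_name` and `_file` keys from the parameters. `load_callable_from_file` is a hypothetical stand-in for `utils.get_python_callable`, whose implementation is not shown in this diff, and the contents of the temporary `print_hello.py` are invented for the example.

```python
import importlib.util
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict


def load_callable_from_file(name: str, file_path: str) -> Callable:
    """Hypothetical stand-in: import `name` from the Python file at `file_path`."""
    spec = importlib.util.spec_from_file_location(Path(file_path).stem, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {file_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, name)


def resolve_name_and_file(parameters: Dict[str, Any], callback_type: str) -> Callable:
    """Resolve '<callback_type>_name' / '<callback_type>_file' and remove both keys."""
    resolved = load_callable_from_file(
        parameters[f"{callback_type}_name"],
        parameters[f"{callback_type}_file"],
    )
    del parameters[f"{callback_type}_name"]
    del parameters[f"{callback_type}_file"]
    return resolved


with tempfile.TemporaryDirectory() as tmp_dir:
    callback_file = Path(tmp_dir) / "print_hello.py"
    # Invented file contents, for illustration only.
    callback_file.write_text(
        "def print_hello_from_callback(context):\n"
        "    return 'hello from ' + context['task']\n"
    )
    task_params = {
        "bash_command": "echo 2",
        "on_success_callback_name": "print_hello_from_callback",
        "on_success_callback_file": str(callback_file),
    }
    task_params["on_success_callback"] = resolve_name_and_file(task_params, "on_success_callback")
    message = task_params["on_success_callback"]({"task": "task_2"})

print(message)
print(sorted(task_params))
```

Removing the two helper keys presumably keeps them from reaching the operator constructor as unexpected keyword arguments; that motivation is an inference and is not stated in the diff.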
45 changes: 37 additions & 8 deletions dev/dags/example_callbacks.yml
@@ -1,6 +1,9 @@
example_callbacks:
default_args:
start_date: "2024-01-01"
# Callbacks can be set at the default_args level. These callbacks are then passed to each Task. Fun fact:
# default_args can be overridden within a Task
on_retry_callback: print_hello.print_hello_from_callback
on_failure_callback:
callback: airflow.providers.slack.notifications.slack.send_slack_notification
slack_conn_id: slack_conn_id
Expand All @@ -11,18 +14,44 @@ example_callbacks:
channel: "#channel"
schedule_interval: "@daily"
catchup: False
on_failure_callback:
# These callbacks are set at the DAG-level, vs. the callbacks set above in default_args that are passed onto each
# Task. Previously, the same "on_success_callback" configuration was set as part of task_2
on_execute_callback_name: print_hello_from_callback
on_execute_callback_file: $CONFIG_ROOT_DIR/print_hello.py
on_success_callback:
callback: customized.callbacks.custom_callbacks.output_message
param1: param1
param2: param2
task_groups:
task_group_1:
default_args:
on_success_callback: print_hello.print_hello_from_callback
dependencies: [task_1, task_2]
tasks:
start:
operator: airflow.operators.python.PythonOperator
python_callable_file: $CONFIG_ROOT_DIR/customized/callables/python.py
python_callable_name: succeeding_task
operator: airflow.operators.empty.EmptyOperator
on_success_callback_name: print_hello_from_callback
on_success_callback_file: $CONFIG_ROOT_DIR/print_hello.py
task_1:
operator: airflow.operators.bash_operator.BashOperator
bash_command: "echo 1"
on_success_callback:
callback: customized.callbacks.custom_callbacks.output_message
param1: param1
param2: param2
dependencies: [start]
task_2:
operator: airflow.operators.bash_operator.BashOperator
bash_command: "echo 2"
on_success_callback_name: print_hello_from_callback
on_success_callback_file: $CONFIG_ROOT_DIR/print_hello.py
dependencies: [start]
task_3:
operator: airflow.operators.bash_operator.BashOperator
bash_command: "echo 3"
task_group_name: task_group_1
end:
operator: airflow.operators.python.PythonOperator
python_callable_file: $CONFIG_ROOT_DIR/customized/callables/python.py
python_callable_name: failing_task
operator: airflow.operators.bash_operator.BashOperator
bash_command: "echo -1"
dependencies:
- start
- task_group_1
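The example DAG above configures callbacks in three shapes: a dotted import string, a mapping with a `callback` key plus extra parameters, and a name + file pair. The sketch below covers the first two. It assumes the extra keys (`param1`, `param2` in the YAML) are bound as keyword arguments with `functools.partial`; the part of `set_callback` that handles mappings is not shown in this diff, so that binding is an assumption. `import_dotted_path` and `resolve_callback` are hypothetical names, and `json.dumps` is used purely as a stand-in callable. Notifier-style provider callbacks may need different handling, which is not attempted here.

```python
import importlib
from functools import partial
from typing import Any, Callable, Dict, Union


def import_dotted_path(path: str) -> Any:
    """Import 'package.module.attr' and return the attribute."""
    module_path, _, attribute = path.rpartition(".")
    if not module_path:
        raise ImportError(f"{path!r} is not a dotted path")
    return getattr(importlib.import_module(module_path), attribute)


def resolve_callback(value: Union[str, Dict[str, Any], Callable]) -> Callable:
    """Turn a YAML-style callback value into something callable."""
    if callable(value):
        return value
    if isinstance(value, str):
        return import_dotted_path(value)
    if isinstance(value, dict):
        params = dict(value)  # copy so the caller's config is not mutated
        target = import_dotted_path(params.pop("callback"))
        # Assumption: remaining keys become keyword arguments of the callback.
        return partial(target, **params)
    raise TypeError(f"Unsupported callback configuration: {type(value).__name__}")


# Stand-in for: on_success_callback: {callback: <dotted path>, param1: ..., param2: ...}
bound = resolve_callback({"callback": "json.dumps", "indent": 2, "sort_keys": True})
rendered = bound({"b": 1, "a": 2})
print(rendered)

plain = resolve_callback("json.dumps")
```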