[Feature] Refactor and add support for schedule conditions in DAG configuration: (#320)

### Description
This feature introduces an enhancement to DAG scheduling in Airflow,
adding support for dynamic schedules based on dataset conditions. By
combining dataset filters with logical conditions, users can now create
more flexible and precise scheduling rules tailored to their workflows.

**Key Features**:

- Condition-Based Scheduling: Allows schedules to be defined as logical
conditions between datasets (e.g., ('dataset_1' & 'dataset_2') |
'dataset_3'), enabling workflows to trigger dynamically based on dataset
availability.

- Dynamic Dataset Processing: Introduces the process_file_with_datasets
function to evaluate and process dataset URIs from external files,
supporting both simple and condition-based schedules.

- Improved Dataset Evaluation: Adds the evaluate_condition_with_datasets
function to transform dataset URIs into valid variable names and evaluate
logical conditions securely (see the sketch after this list).

**Workflow Example**:
The same scheduling condition can be expressed in any of the following equivalent configurations:
```yaml
example_custom_config_condition_dataset_consumer_dag:
  description: "Example DAG consumer custom config condition datasets"
  schedule:
    file: $CONFIG_ROOT_DIR/datasets/example_config_datasets.yml
    datasets:  "((dataset_custom_1 & dataset_custom_2) | dataset_custom_3)"
  tasks:
    task_1:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 'consumer datasets'"
```
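
The first variant resolves dataset aliases through the referenced file. The exact schema of example_config_datasets.yml is not shown in this description, so the following is only a hypothetical sketch of what such a mapping file could look like:

```yaml
# Hypothetical sketch of example_config_datasets.yml (schema assumed, not shown in this PR).
datasets:
  - name: dataset_custom_1
    uri: s3://bucket-cjmm/raw/dataset_custom_1
  - name: dataset_custom_2
    uri: s3://bucket-cjmm/raw/dataset_custom_2
  - name: dataset_custom_3
    uri: s3://bucket-cjmm/raw/dataset_custom_3
```
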
```yaml
example_without_custom_config_condition_dataset_consumer_dag:
  description: "Example DAG consumer custom config condition datasets"
  schedule:
    datasets: "((s3://bucket-cjmm/raw/dataset_custom_1 & s3://bucket-cjmm/raw/dataset_custom_2) | s3://bucket-cjmm/raw/dataset_custom_3)"
  tasks:
    task_1:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 'consumer datasets'"
```

```yaml
example_without_custom_config_condition_dataset_consumer_dag:
  description: "Example DAG consumer custom config condition datasets"
  schedule:
    datasets: 
      !or  
        - !and  
          - "s3://bucket-cjmm/raw/dataset_custom_1"  
          - "s3://bucket-cjmm/raw/dataset_custom_2"
        - "s3://bucket-cjmm/raw/dataset_custom_3"
  tasks:
    task_1:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 'consumer datasets'"
```
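
In the third variant, the !or and !and tags are resolved by custom YAML constructors (added to dagfactory/dagfactory.py in this commit) into a single condition string, roughly:

```text
((s3://bucket-cjmm/raw/dataset_custom_1) & (s3://bucket-cjmm/raw/dataset_custom_2)) | (s3://bucket-cjmm/raw/dataset_custom_3)
```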

The system evaluates the dataset references, ensures they are valid, and
schedules the DAG dynamically once the condition resolves to True.
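
For comparison, here is a sketch of what the resulting schedule is roughly equivalent to when written directly against the Airflow 2.9+ dataset API (illustrative only; the DAG id and start date are placeholders):

```python
# Roughly equivalent schedule expressed directly in Airflow 2.9+ (illustrative).
from datetime import datetime

from airflow import DAG
from airflow.datasets import Dataset

with DAG(
    dag_id="example_condition_dataset_consumer_dag",  # placeholder id
    start_date=datetime(2024, 1, 1),                  # placeholder start date
    schedule=(
        (Dataset("s3://bucket-cjmm/raw/dataset_custom_1")
         & Dataset("s3://bucket-cjmm/raw/dataset_custom_2"))
        | Dataset("s3://bucket-cjmm/raw/dataset_custom_3")
    ),
):
    pass
```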

**Example Use Case**:
Consider a data pipeline that processes files only when multiple
interdependent datasets are updated. With this feature, users can create
dynamic DAG schedules that automatically adjust based on dataset
availability and conditions, optimizing resource allocation and
execution timing.


Images:
![Screenshot 2024-12-16 181059](https://github.com/user-attachments/assets/e591538f-3f39-44a4-9503-dac45b972e64)
![Screenshot 2024-12-16 181103](https://github.com/user-attachments/assets/11a2cdca-5cae-4075-bc22-5b257b5d6b00)
![Screenshot 2024-12-16 181131](https://github.com/user-attachments/assets/9b40f176-91d5-455c-9812-ee4c0ca50912)

---------

Co-authored-by: ErickSeo <[email protected]>
ErickSeo and ErickSeo authored Jan 10, 2025
1 parent 48e5575 commit 4f2a57f
Showing 7 changed files with 394 additions and 53 deletions.
184 changes: 141 additions & 43 deletions dagfactory/dagbuilder.py
@@ -2,14 +2,16 @@

from __future__ import annotations

import ast

# pylint: disable=ungrouped-imports
import inspect
import os
import re
from copy import deepcopy
from datetime import datetime, timedelta
from functools import partial
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Tuple, Union

from airflow import DAG, configuration
from airflow.models import BaseOperator, Variable
@@ -83,7 +85,7 @@
from airflow.utils.task_group import TaskGroup
from kubernetes.client.models import V1Container, V1Pod

from dagfactory import utils
from dagfactory import parsers, utils
from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException

# TimeTable is introduced in Airflow 2.2.0
@@ -293,8 +295,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
)
if not task_params.get("python_callable"):
task_params["python_callable"]: Callable = utils.get_python_callable(
task_params["python_callable_name"],
task_params["python_callable_file"],
task_params["python_callable_name"], task_params["python_callable_file"]
)
# remove dag-factory specific parameters
# Airflow 2.0 doesn't allow these to be passed to operator
@@ -312,8 +313,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
# Success checks
if task_params.get("success_check_file") and task_params.get("success_check_name"):
task_params["success"]: Callable = utils.get_python_callable(
task_params["success_check_name"],
task_params["success_check_file"],
task_params["success_check_name"], task_params["success_check_file"]
)
del task_params["success_check_name"]
del task_params["success_check_file"]
@@ -325,8 +325,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
# Failure checks
if task_params.get("failure_check_file") and task_params.get("failure_check_name"):
task_params["failure"]: Callable = utils.get_python_callable(
task_params["failure_check_name"],
task_params["failure_check_file"],
task_params["failure_check_name"], task_params["failure_check_file"]
)
del task_params["failure_check_name"]
del task_params["failure_check_file"]
@@ -347,8 +346,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
)
if task_params.get("response_check_file"):
task_params["response_check"]: Callable = utils.get_python_callable(
task_params["response_check_name"],
task_params["response_check_file"],
task_params["response_check_name"], task_params["response_check_file"]
)
# remove dag-factory specific parameters
# Airflow 2.0 doesn't allow these to be passed to operator
@@ -438,11 +436,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial")
) and version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"):
# Getting expand and partial kwargs from task_params
(
task_params,
expand_kwargs,
partial_kwargs,
) = utils.get_expand_partial_kwargs(task_params)
(task_params, expand_kwargs, partial_kwargs) = utils.get_expand_partial_kwargs(task_params)

# If there are partial_kwargs we should merge them with existing task_params
if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params):
@@ -626,6 +620,132 @@ def replace_expand_values(task_conf: Dict, tasks_dict: Dict[str, BaseOperator]):
task_conf["expand"][expand_key] = tasks_dict[task_id].output
return task_conf

@staticmethod
def safe_eval(condition_string: str, dataset_map: dict) -> Any:
"""
Safely evaluates a condition string using the provided dataset map.
:param condition_string: A string representing the condition to evaluate.
Example: "(dataset_custom_1 & dataset_custom_2) | dataset_custom_3".
:type condition_string: str
:param dataset_map: A dictionary where keys are valid variable names (dataset aliases),
and values are Dataset objects.
:type dataset_map: dict
:returns: The result of evaluating the condition.
:rtype: Any
"""
tree = ast.parse(condition_string, mode="eval")
evaluator = parsers.SafeEvalVisitor(dataset_map)
return evaluator.evaluate(tree)

@staticmethod
def _extract_and_transform_datasets(datasets_conditions: str) -> Tuple[str, Dict[str, Any]]:
"""
Extracts dataset names and storage paths from the conditions string and transforms them into valid variable names.
:param datasets_conditions: A condition string containing dataset URIs to be evaluated.
:type datasets_conditions: str
:returns: A tuple containing the transformed conditions string and the dataset map.
:rtype: Tuple[str, Dict[str, Any]]
"""
dataset_map = {}
datasets_filter: List[str] = utils.extract_dataset_names(datasets_conditions) + utils.extract_storage_names(
datasets_conditions
)

for uri in datasets_filter:
valid_variable_name = utils.make_valid_variable_name(uri)
datasets_conditions = datasets_conditions.replace(uri, valid_variable_name)
dataset_map[valid_variable_name] = Dataset(uri)

return datasets_conditions, dataset_map

@staticmethod
def evaluate_condition_with_datasets(datasets_conditions: str) -> Any:
"""
Evaluates a condition using the dataset filter, transforming URIs into valid variable names.
:param datasets_conditions: A condition string containing dataset URIs to be evaluated.
:type datasets_conditions: str
:returns: The result of the logical condition evaluation with URIs replaced by valid variable names.
:rtype: Any
"""
datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions)
evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map)
return evaluated_condition

@staticmethod
def process_file_with_datasets(file: str, datasets_conditions: str) -> Any:
"""
Processes datasets from a file and evaluates conditions if provided.
:param file: The file path containing dataset information in a YAML or other structured format.
:type file: str
:param datasets_conditions: A string of dataset conditions to filter and process.
:type datasets_conditions: str
:returns: The evaluated dataset condition on Airflow >= 2.9, otherwise a list of `Dataset` objects.
:rtype: Any
"""
is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0")
datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions)

if is_airflow_version_at_least_2_9:
map_datasets = utils.get_datasets_map_uri_yaml_file(file, list(dataset_map.keys()))
dataset_map = {alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items()}
evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map)
return evaluated_condition
else:
datasets_uri = utils.get_datasets_uri_yaml_file(file, list(dataset_map.keys()))
return [Dataset(uri) for uri in datasets_uri]

@staticmethod
def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) -> None:
"""
Configures the schedule for the DAG based on parameters and the Airflow version.
:param dag_params: A dictionary containing DAG parameters, including scheduling configuration.
Example: {"schedule": {"file": "datasets.yaml", "datasets": ["dataset_1"], "conditions": "dataset_1 & dataset_2"}}
:type dag_params: Dict[str, Any]
:param dag_kwargs: A dictionary for setting the resulting schedule configuration for the DAG.
:type dag_kwargs: Dict[str, Any]
:raises KeyError: If required keys like "schedule" or "datasets" are missing in the parameters.
:returns: None. The function updates `dag_kwargs` in-place.
"""
is_airflow_version_at_least_2_4 = version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0")
is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0")
has_schedule_attr = utils.check_dict_key(dag_params, "schedule")
has_schedule_interval_attr = utils.check_dict_key(dag_params, "schedule_interval")

if has_schedule_attr and not has_schedule_interval_attr and is_airflow_version_at_least_2_4:
schedule: Dict[str, Any] = dag_params.get("schedule")

has_file_attr = utils.check_dict_key(schedule, "file")
has_datasets_attr = utils.check_dict_key(schedule, "datasets")

if has_file_attr and has_datasets_attr:
file = schedule.get("file")
datasets: Union[List[str], str] = schedule.get("datasets")
datasets_conditions: str = utils.parse_list_datasets(datasets)
dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_conditions)

elif has_datasets_attr and is_airflow_version_at_least_2_9:
datasets = schedule["datasets"]
datasets_conditions: str = utils.parse_list_datasets(datasets)
dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets(datasets_conditions)

else:
dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule]

if has_file_attr:
schedule.pop("file")
if has_datasets_attr:
schedule.pop("datasets")

# pylint: disable=too-many-locals
def build(self) -> Dict[str, Union[str, DAG]]:
"""
@@ -649,8 +769,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:

if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"):
dag_kwargs["max_active_tasks"] = dag_params.get(
"max_active_tasks",
configuration.conf.getint("core", "max_active_tasks_per_dag"),
"max_active_tasks", configuration.conf.getint("core", "max_active_tasks_per_dag")
)

if dag_params.get("timetable"):
@@ -668,8 +787,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:
)

dag_kwargs["max_active_runs"] = dag_params.get(
"max_active_runs",
configuration.conf.getint("core", "max_active_runs_per_dag"),
"max_active_runs", configuration.conf.getint("core", "max_active_runs_per_dag")
)

dag_kwargs["dagrun_timeout"] = dag_params.get("dagrun_timeout", None)
@@ -702,24 +820,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:

dag_kwargs["is_paused_upon_creation"] = dag_params.get("is_paused_upon_creation", None)

if (
utils.check_dict_key(dag_params, "schedule")
and not utils.check_dict_key(dag_params, "schedule_interval")
and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0")
):
if utils.check_dict_key(dag_params["schedule"], "file") and utils.check_dict_key(
dag_params["schedule"], "datasets"
):
file = dag_params["schedule"]["file"]
datasets_filter = dag_params["schedule"]["datasets"]
datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)

del dag_params["schedule"]["file"]
del dag_params["schedule"]["datasets"]
else:
datasets_uri = dag_params["schedule"]

dag_kwargs["schedule"] = [Dataset(uri) for uri in datasets_uri]
DagBuilder.configure_schedule(dag_params, dag_kwargs)

dag_kwargs["params"] = dag_params.get("params", None)

@@ -734,8 +835,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:

if dag_params.get("doc_md_python_callable_file") and dag_params.get("doc_md_python_callable_name"):
doc_md_callable = utils.get_python_callable(
dag_params.get("doc_md_python_callable_name"),
dag_params.get("doc_md_python_callable_file"),
dag_params.get("doc_md_python_callable_name"), dag_params.get("doc_md_python_callable_file")
)
dag.doc_md = doc_md_callable(**dag_params.get("doc_md_python_arguments", {}))

@@ -872,8 +972,7 @@ def adjust_general_task_params(task_params: dict(str, Any)):
task_params, "execution_date_fn_file"
):
task_params["execution_date_fn"]: Callable = utils.get_python_callable(
task_params["execution_date_fn_name"],
task_params["execution_date_fn_file"],
task_params["execution_date_fn_name"], task_params["execution_date_fn_file"]
)
del task_params["execution_date_fn_name"]
del task_params["execution_date_fn_file"]
@@ -937,8 +1036,7 @@ def make_decorator(
# Fetch the Python callable
if set(mandatory_keys_set1).issubset(task_params):
python_callable: Callable = utils.get_python_callable(
task_params["python_callable_name"],
task_params["python_callable_file"],
task_params["python_callable_name"], task_params["python_callable_file"]
)
# Remove dag-factory specific parameters since Airflow 2.0 doesn't allow these to be passed to operator
del task_params["python_callable_name"]
16 changes: 11 additions & 5 deletions dagfactory/dagfactory.py
@@ -92,15 +92,21 @@ def __join(loader: yaml.FullLoader, node: yaml.Node) -> str:
seq = loader.construct_sequence(node)
return "".join([str(i) for i in seq])

def __or(loader: yaml.FullLoader, node: yaml.Node) -> str:
seq = loader.construct_sequence(node)
return " | ".join([f"({str(i)})" for i in seq])

def __and(loader: yaml.FullLoader, node: yaml.Node) -> str:
seq = loader.construct_sequence(node)
return " & ".join([f"({str(i)})" for i in seq])

yaml.add_constructor("!join", __join, yaml.FullLoader)
yaml.add_constructor("!or", __or, yaml.FullLoader)
yaml.add_constructor("!and", __and, yaml.FullLoader)

with open(config_filepath, "r", encoding="utf-8") as fp:
yaml.add_constructor("!join", __join, yaml.FullLoader)
config_with_env = os.path.expandvars(fp.read())
config: Dict[str, Any] = yaml.load(
stream=config_with_env,
Loader=yaml.FullLoader,
)
config: Dict[str, Any] = yaml.load(stream=config_with_env, Loader=yaml.FullLoader)
except Exception as err:
raise DagFactoryConfigException("Invalid DAG Factory config file") from err
return config
34 changes: 34 additions & 0 deletions dagfactory/parsers.py
@@ -0,0 +1,34 @@
import ast


class SafeEvalVisitor(ast.NodeVisitor):
def __init__(self, dataset_map):
self.dataset_map = dataset_map

def evaluate(self, tree):
return self.visit(tree)

def visit_Expression(self, node):
return self.visit(node.body)

def visit_BinOp(self, node):
left = self.visit(node.left)
right = self.visit(node.right)

if isinstance(node.op, ast.BitAnd):
return left & right
elif isinstance(node.op, ast.BitOr):
return left | right
else:
raise ValueError(f"Unsupported binary operation: {type(node.op).__name__}")

def visit_Name(self, node):
if node.id in self.dataset_map:
return self.dataset_map[node.id]
raise NameError(f"Undefined variable: {node.id}")

def visit_Constant(self, node):
return node.value

def generic_visit(self, node):
raise ValueError(f"Unsupported syntax: {type(node).__name__}")
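
For context, a minimal standalone sketch of exercising SafeEvalVisitor (illustrative only; plain Python sets stand in for Dataset objects, since both support the & and | operators):

```python
# Illustrative usage of SafeEvalVisitor (not part of the diff).
import ast

from dagfactory.parsers import SafeEvalVisitor

dataset_map = {"a": {1, 2}, "b": {2, 3}, "c": {4}}  # set stand-ins for Dataset objects
tree = ast.parse("(a & b) | c", mode="eval")
print(SafeEvalVisitor(dataset_map).evaluate(tree))  # {2, 4}
```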