diff --git a/airflow/jobs/triggerer_job_runner.py b/airflow/jobs/triggerer_job_runner.py
index b25ce6a7d3790..bb151b32cc87e 100644
--- a/airflow/jobs/triggerer_job_runner.py
+++ b/airflow/jobs/triggerer_job_runner.py
@@ -28,16 +28,16 @@
 from contextlib import suppress
 from copy import copy
 from queue import SimpleQueue
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING
 
 from sqlalchemy import func, select
 
 from airflow.configuration import conf
 from airflow.jobs.base_job_runner import BaseJobRunner
 from airflow.jobs.job import perform_heartbeat
-from airflow.models.trigger import ENCRYPTED_KWARGS_PREFIX, Trigger
+from airflow.models.trigger import Trigger
 from airflow.stats import Stats
-from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.triggers.base import TriggerEvent
 from airflow.typing_compat import TypedDict
 from airflow.utils import timezone
 from airflow.utils.log.file_task_handler import FileTaskHandler
@@ -60,6 +60,7 @@
 
     from airflow.jobs.job import Job
     from airflow.models import TaskInstance
+    from airflow.triggers.base import BaseTrigger
 
 HANDLER_SUPPORTS_TRIGGERER = False
 """
@@ -235,9 +236,6 @@ def setup_queue_listener():
     return None
 
 
-U = TypeVar("U", bound=BaseTrigger)
-
-
 class TriggererJobRunner(BaseJobRunner, LoggingMixin):
     """
     Run active triggers in asyncio and update their dependent tests/DAGs once their events have fired.
@@ -675,7 +673,7 @@ def update_triggers(self, requested_trigger_ids: set[int]):
                 continue
 
             try:
-                new_trigger_instance = self.trigger_row_to_trigger_instance(new_trigger_orm, trigger_class)
+                new_trigger_instance = trigger_class(**new_trigger_orm.kwargs)
             except TypeError as err:
                 self.log.error("Trigger failed; message=%s", err)
                 self.failed_triggers.append((new_id, err))
@@ -710,18 +708,3 @@ def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]:
         if classpath not in self.trigger_cache:
             self.trigger_cache[classpath] = import_string(classpath)
         return self.trigger_cache[classpath]
-
-    def trigger_row_to_trigger_instance(self, trigger_row: Trigger, trigger_class: type[U]) -> U:
-        """Convert a Trigger row into a Trigger instance."""
-        from airflow.models.crypto import get_fernet
-
-        decrypted_kwargs = {}
-        fernet = get_fernet()
-        for k, v in trigger_row.kwargs.items():
-            if k.startswith(ENCRYPTED_KWARGS_PREFIX):
-                decrypted_kwargs[k[len(ENCRYPTED_KWARGS_PREFIX) :]] = fernet.decrypt(
-                    v.encode("utf-8")
-                ).decode("utf-8")
-            else:
-                decrypted_kwargs[k] = v
-        return trigger_class(**decrypted_kwargs)
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index 2a00ab3856789..4ad42b17b8fc7 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -38,8 +38,6 @@
 
     from airflow.triggers.base import BaseTrigger
 
-ENCRYPTED_KWARGS_PREFIX = "encrypted__"
-
 
 class Trigger(Base):
     """
@@ -92,17 +90,8 @@ def __init__(
     @internal_api_call
     def from_object(cls, trigger: BaseTrigger) -> Trigger:
         """Alternative constructor that creates a trigger row based directly off of a Trigger object."""
-        from airflow.models.crypto import get_fernet
-
         classpath, kwargs = trigger.serialize()
-        secure_kwargs = {}
-        fernet = get_fernet()
-        for k, v in kwargs.items():
-            if k.startswith(ENCRYPTED_KWARGS_PREFIX):
-                secure_kwargs[k] = fernet.encrypt(v.encode("utf-8")).decode("utf-8")
-            else:
-                secure_kwargs[k] = v
-        return cls(classpath=classpath, kwargs=secure_kwargs)
+        return cls(classpath=classpath, kwargs=kwargs)
 
     @classmethod
     @internal_api_call
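Note for reviewers of this revert: below is a minimal, self-contained sketch of the round trip the removed code performed, i.e. what Trigger.from_object() did on the write path and trigger_row_to_trigger_instance() did on the read path. It calls cryptography.fernet directly with a throwaway key instead of Airflow's get_fernet() (which reads the [core] fernet_key setting); both substitutions are assumptions made purely for illustration.

    from cryptography.fernet import Fernet

    ENCRYPTED_KWARGS_PREFIX = "encrypted__"
    fernet = Fernet(Fernet.generate_key())  # stand-in for airflow.models.crypto.get_fernet()

    # Write path, mirroring the removed Trigger.from_object(): any kwarg whose
    # serialized name carries the prefix is encrypted before the row is stored.
    kwargs = {"param": "plain", "encrypted__secret": "sensitive-value"}
    secure_kwargs = {}
    for k, v in kwargs.items():
        if k.startswith(ENCRYPTED_KWARGS_PREFIX):
            secure_kwargs[k] = fernet.encrypt(v.encode("utf-8")).decode("utf-8")
        else:
            secure_kwargs[k] = v

    # Read path, mirroring the removed trigger_row_to_trigger_instance(): the
    # prefix is stripped and the value decrypted, so the trigger class is
    # instantiated with plaintext kwargs under the unprefixed names.
    decrypted_kwargs = {}
    for k, v in secure_kwargs.items():
        if k.startswith(ENCRYPTED_KWARGS_PREFIX):
            decrypted_kwargs[k[len(ENCRYPTED_KWARGS_PREFIX):]] = fernet.decrypt(v.encode("utf-8")).decode("utf-8")
        else:
            decrypted_kwargs[k] = v

    assert decrypted_kwargs == {"param": "plain", "secret": "sensitive-value"}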
diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
index 218eb04a36268..88b6f548ec561 100644
--- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst
+++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
@@ -197,33 +197,6 @@ Triggers can be as complex or as simple as you want, provided they meet the desi
 If you are new to writing asynchronous Python, be very careful when writing your ``run()`` method. Python's async model means that code can block the entire process if it does not correctly ``await`` when it does a blocking operation. Airflow attempts to detect process blocking code and warn you in the triggerer logs when it happens. You can enable extra checks by Python by setting the variable ``PYTHONASYNCIODEBUG=1`` when you are writing your trigger to make sure you're writing non-blocking code. Be especially careful when doing filesystem calls, because if the underlying filesystem is network-backed, it can be blocking.
 
-Sensitive information in triggers
-'''''''''''''''''''''''''''''''''
-
-Triggers are serialized and stored in the database, so they can be re-instantiated on any triggerer process. This means that any sensitive information you pass to a trigger will be stored in the database.
-If you want to pass sensitive information to a trigger, you can encrypt it before passing it to the trigger, and decrypt it inside the trigger, or update the argument name in the ``serialize`` method by adding ``encrypted__`` as a prefix, and Airflow will automatically encrypt the argument before storing it in the database, and decrypt it when it is read from the database.
-
-.. code-block:: python
-
-    class MyTrigger(BaseTrigger):
-        def __init__(self, param, secret):
-            super().__init__()
-            self.param = param
-            self.secret = secret
-
-        def serialize(self):
-            return (
-                "airflow.triggers.MyTrigger",
-                {
-                    "param": self.param,
-                    "encrypted__secret": self.secret,
-                },
-            )
-
-        async def run(self):
-            # self.my_secret will be decrypted here
-            ...
-
 
 High Availability
 -----------------
 
diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py
index ce1858443ef3a..8463edf831b02 100644
--- a/tests/models/test_trigger.py
+++ b/tests/models/test_trigger.py
@@ -17,21 +17,18 @@
 from __future__ import annotations
 
 import datetime
-from typing import Any, AsyncIterator
 
 import pytest
 import pytz
-from cryptography.fernet import Fernet
 
 from airflow.jobs.job import Job
-from airflow.jobs.triggerer_job_runner import TriggererJobRunner, TriggerRunner
+from airflow.jobs.triggerer_job_runner import TriggererJobRunner
 from airflow.models import TaskInstance, Trigger
 from airflow.operators.empty import EmptyOperator
-from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.triggers.base import TriggerEvent
 from airflow.utils import timezone
 from airflow.utils.session import create_session
 from airflow.utils.state import State
-from tests.test_utils.config import conf_vars
 
 pytestmark = pytest.mark.db_test
 
@@ -340,45 +337,3 @@ def test_get_sorted_triggers_different_priority_weights(session, create_task_ins
     trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)
 
     assert trigger_ids_query == [(2,), (1,)]
-
-
-class SensitiveKwargsTrigger(BaseTrigger):
-    """
-    A trigger that has sensitive kwargs.
- """ - - def __init__(self, param1: str, param2: str): - super().__init__() - self.param1 = param1 - self.param2 = param2 - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "tests.models.test_trigger.SensitiveKwargsTrigger", - { - "param1": self.param1, - "encrypted__param2": self.param2, - }, - ) - - async def run(self) -> AsyncIterator[TriggerEvent]: - yield TriggerEvent({}) - - -@conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()}) -def test_serialize_sensitive_kwargs(): - """ - Tests that sensitive kwargs are encrypted. - """ - trigger_instance = SensitiveKwargsTrigger(param1="value1", param2="value2") - trigger_row: Trigger = Trigger.from_object(trigger_instance) - - assert trigger_row.kwargs["param1"] == "value1" - assert "param2" not in trigger_row.kwargs - assert trigger_row.kwargs["encrypted__param2"] != "value2" - - loaded_trigger: SensitiveKwargsTrigger = TriggerRunner().trigger_row_to_trigger_instance( - trigger_row, SensitiveKwargsTrigger - ) - assert loaded_trigger.param1 == "value1" - assert loaded_trigger.param2 == "value2"