Skip to content

Commit

Permalink
Revert "Support encryption for triggers parameters (#36492)" (#38253)
Browse files Browse the repository at this point in the history
This reverts commit 8fb55f2.
  • Loading branch information
ephraimbuddy authored Mar 18, 2024
1 parent 94f6fcc commit 671ba75
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 108 deletions.
27 changes: 5 additions & 22 deletions airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@
from contextlib import suppress
from copy import copy
from queue import SimpleQueue
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING

from sqlalchemy import func, select

from airflow.configuration import conf
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.jobs.job import perform_heartbeat
from airflow.models.trigger import ENCRYPTED_KWARGS_PREFIX, Trigger
from airflow.models.trigger import Trigger
from airflow.stats import Stats
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.triggers.base import TriggerEvent
from airflow.typing_compat import TypedDict
from airflow.utils import timezone
from airflow.utils.log.file_task_handler import FileTaskHandler
Expand All @@ -60,6 +60,7 @@

from airflow.jobs.job import Job
from airflow.models import TaskInstance
from airflow.triggers.base import BaseTrigger

HANDLER_SUPPORTS_TRIGGERER = False
"""
Expand Down Expand Up @@ -235,9 +236,6 @@ def setup_queue_listener():
return None


U = TypeVar("U", bound=BaseTrigger)


class TriggererJobRunner(BaseJobRunner, LoggingMixin):
"""
Run active triggers in asyncio and update their dependent tasks/DAGs once their events have fired.
Expand Down Expand Up @@ -675,7 +673,7 @@ def update_triggers(self, requested_trigger_ids: set[int]):
continue

try:
new_trigger_instance = self.trigger_row_to_trigger_instance(new_trigger_orm, trigger_class)
new_trigger_instance = trigger_class(**new_trigger_orm.kwargs)
except TypeError as err:
self.log.error("Trigger failed; message=%s", err)
self.failed_triggers.append((new_id, err))
Expand Down Expand Up @@ -710,18 +708,3 @@ def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]:
if classpath not in self.trigger_cache:
self.trigger_cache[classpath] = import_string(classpath)
return self.trigger_cache[classpath]

def trigger_row_to_trigger_instance(self, trigger_row: Trigger, trigger_class: type[U]) -> U:
"""Convert a Trigger row into a Trigger instance."""
from airflow.models.crypto import get_fernet

decrypted_kwargs = {}
fernet = get_fernet()
for k, v in trigger_row.kwargs.items():
if k.startswith(ENCRYPTED_KWARGS_PREFIX):
decrypted_kwargs[k[len(ENCRYPTED_KWARGS_PREFIX) :]] = fernet.decrypt(
v.encode("utf-8")
).decode("utf-8")
else:
decrypted_kwargs[k] = v
return trigger_class(**decrypted_kwargs)
13 changes: 1 addition & 12 deletions airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@

from airflow.triggers.base import BaseTrigger

ENCRYPTED_KWARGS_PREFIX = "encrypted__"


class Trigger(Base):
"""
Expand Down Expand Up @@ -92,17 +90,8 @@ def __init__(
@internal_api_call
def from_object(cls, trigger: BaseTrigger) -> Trigger:
"""Alternative constructor that creates a trigger row based directly off of a Trigger object."""
from airflow.models.crypto import get_fernet

classpath, kwargs = trigger.serialize()
secure_kwargs = {}
fernet = get_fernet()
for k, v in kwargs.items():
if k.startswith(ENCRYPTED_KWARGS_PREFIX):
secure_kwargs[k] = fernet.encrypt(v.encode("utf-8")).decode("utf-8")
else:
secure_kwargs[k] = v
return cls(classpath=classpath, kwargs=secure_kwargs)
return cls(classpath=classpath, kwargs=kwargs)

@classmethod
@internal_api_call
Expand Down
27 changes: 0 additions & 27 deletions docs/apache-airflow/authoring-and-scheduling/deferring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,33 +197,6 @@ Triggers can be as complex or as simple as you want, provided they meet the desi

If you are new to writing asynchronous Python, be very careful when writing your ``run()`` method. Python's async model means that code can block the entire process if it does not correctly ``await`` when it does a blocking operation. Airflow attempts to detect process blocking code and warn you in the triggerer logs when it happens. You can enable extra checks by Python by setting the variable ``PYTHONASYNCIODEBUG=1`` when you are writing your trigger to make sure you're writing non-blocking code. Be especially careful when doing filesystem calls, because if the underlying filesystem is network-backed, it can be blocking.

Sensitive information in triggers
'''''''''''''''''''''''''''''''''

Triggers are serialized and stored in the database, so they can be re-instantiated on any triggerer process. This means that any sensitive information you pass to a trigger will be stored in the database.
If you want to pass sensitive information to a trigger, you can encrypt it yourself before passing it in and decrypt it inside the trigger. Alternatively, prefix the argument name with ``encrypted__`` in the ``serialize`` method, and Airflow will automatically encrypt the argument before storing it in the database and decrypt it when it is read back.

.. code-block:: python
class MyTrigger(BaseTrigger):
def __init__(self, param, secret):
super().__init__()
self.param = param
self.secret = secret
def serialize(self):
return (
"airflow.triggers.MyTrigger",
{
"param": self.param,
"encrypted__secret": self.secret,
},
)
async def run(self):
        # self.secret will be decrypted here
...
High Availability
-----------------

Expand Down
49 changes: 2 additions & 47 deletions tests/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,18 @@
from __future__ import annotations

import datetime
from typing import Any, AsyncIterator

import pytest
import pytz
from cryptography.fernet import Fernet

from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner, TriggerRunner
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import TaskInstance, Trigger
from airflow.operators.empty import EmptyOperator
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.triggers.base import TriggerEvent
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from tests.test_utils.config import conf_vars

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -340,45 +337,3 @@ def test_get_sorted_triggers_different_priority_weights(session, create_task_ins
trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)

assert trigger_ids_query == [(2,), (1,)]


class SensitiveKwargsTrigger(BaseTrigger):
"""
A trigger that has sensitive kwargs.
"""

def __init__(self, param1: str, param2: str):
super().__init__()
self.param1 = param1
self.param2 = param2

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"tests.models.test_trigger.SensitiveKwargsTrigger",
{
"param1": self.param1,
"encrypted__param2": self.param2,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
yield TriggerEvent({})


@conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
def test_serialize_sensitive_kwargs():
"""
Tests that sensitive kwargs are encrypted.
"""
trigger_instance = SensitiveKwargsTrigger(param1="value1", param2="value2")
trigger_row: Trigger = Trigger.from_object(trigger_instance)

assert trigger_row.kwargs["param1"] == "value1"
assert "param2" not in trigger_row.kwargs
assert trigger_row.kwargs["encrypted__param2"] != "value2"

loaded_trigger: SensitiveKwargsTrigger = TriggerRunner().trigger_row_to_trigger_instance(
trigger_row, SensitiveKwargsTrigger
)
assert loaded_trigger.param1 == "value1"
assert loaded_trigger.param2 == "value2"

0 comments on commit 671ba75

Please sign in to comment.