From 62f948cd309f4adeb6b15a2b634a66bfc87159cc Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 3 Apr 2024 06:00:35 +0800 Subject: [PATCH] Improve XComObjectStorageBackend implementation (#38608) Repeated configuration access is moved to cached functions so the string literals don't need to be written repeatedly. The path configuration is made mandatory since it more or less is; using this backend without a path configured is most likely an unintended user error. Various functions are rewritten to take advantage of early returns, and more localized try-except blocks to improve code quality. --- airflow/providers/common/io/xcom/backend.py | 99 +++++++++---------- .../providers/common/io/xcom/test_backend.py | 20 +++- 2 files changed, 64 insertions(+), 55 deletions(-) diff --git a/airflow/providers/common/io/xcom/backend.py b/airflow/providers/common/io/xcom/backend.py index b2416862ee34c..163e15e00d4db 100644 --- a/airflow/providers/common/io/xcom/backend.py +++ b/airflow/providers/common/io/xcom/backend.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import contextlib import json import uuid from typing import TYPE_CHECKING, Any, TypeVar @@ -23,6 +24,7 @@ import fsspec.utils +from airflow.compat.functools import cache from airflow.configuration import conf from airflow.io.path import ObjectStoragePath from airflow.models.xcom import BaseXCom @@ -65,6 +67,21 @@ def _get_compression_suffix(compression: str) -> str: raise ValueError(f"Compression {compression} is not supported. Make sure it is installed.") +@cache +def _get_base_path() -> ObjectStoragePath: + return ObjectStoragePath(conf.get_mandatory_value(SECTION, "xcom_objectstorage_path")) + + +@cache +def _get_compression() -> str | None: + return conf.get(SECTION, "xcom_objectstorage_compression", fallback=None) or None + + +@cache +def _get_threshold() -> int: + return conf.getint(SECTION, "xcom_objectstorage_threshold", fallback=-1) + + class XComObjectStorageBackend(BaseXCom): """XCom backend that stores data in an object store or database depending on the size of the data. @@ -75,30 +92,24 @@ class XComObjectStorageBackend(BaseXCom): """ @staticmethod - def _get_key(data: str) -> str: - """Get the key from the url and normalizes it to be relative to the configured path. + def _get_full_path(data: str) -> ObjectStoragePath: + """Get the path from stored value. :raises ValueError: if the key is not relative to the configured path :raises TypeError: if the url is not a valid url or cannot be split """ - path = conf.get(SECTION, "xcom_objectstorage_path", fallback="") - p = ObjectStoragePath(path) + p = _get_base_path() # normalize the path - path = str(p) - try: url = urlsplit(data) except AttributeError: - raise TypeError(f"Not a valid url: {data}") + raise TypeError(f"Not a valid url: {data}") from None if url.scheme: - k = ObjectStoragePath(data) - - if _is_relative_to(k, p) is False: + if not _is_relative_to(ObjectStoragePath(data), p): raise ValueError(f"Invalid key: {data}") - else: - return data.replace(path, "", 1).lstrip("/") + return p / data.replace(str(p), "", 1).lstrip("/") raise ValueError(f"Not a valid url: {data}") @@ -115,61 +126,47 @@ def serialize_value( # we will always serialize ourselves and not by BaseXCom as the deserialize method # from BaseXCom accepts only XCom objects and not the value directly s_val = json.dumps(value, cls=XComEncoder).encode("utf-8") - path = conf.get(SECTION, "xcom_objectstorage_path", fallback="") - compression = conf.get(SECTION, "xcom_objectstorage_compression", fallback=None) - if compression: - suffix = "." + _get_compression_suffix(compression) + if compression := _get_compression(): + suffix = f".{_get_compression_suffix(compression)}" else: suffix = "" - compression = None - threshold = conf.getint(SECTION, "xcom_objectstorage_threshold", fallback=-1) - - if path and -1 < threshold < len(s_val): - # safeguard against collisions - while True: - p = ObjectStoragePath(path) / f"{dag_id}/{run_id}/{task_id}/{str(uuid.uuid4())}{suffix}" - if not p.exists(): - break + threshold = _get_threshold() + if threshold < 0 or len(s_val) < threshold: # Either no threshold or value is small enough. + return s_val - if not p.parent.exists(): - p.parent.mkdir(parents=True, exist_ok=True) + base_path = _get_base_path() + while True: # Safeguard against collisions. + p = base_path.joinpath(dag_id, run_id, task_id, f"{uuid.uuid4()}{suffix}") + if not p.exists(): + break + p.parent.mkdir(parents=True, exist_ok=True) - with p.open(mode="wb", compression=compression) as f: - f.write(s_val) - - return BaseXCom.serialize_value(str(p)) - else: - return s_val + with p.open(mode="wb", compression=compression) as f: + f.write(s_val) + return BaseXCom.serialize_value(str(p)) @staticmethod - def deserialize_value( - result: XCom, - ) -> Any: + def deserialize_value(result: XCom) -> Any: """Deserializes the value from the database or object storage. Compression is inferred from the file extension. """ data = BaseXCom.deserialize_value(result) - path = conf.get(SECTION, "xcom_objectstorage_path", fallback="") - try: - p = ObjectStoragePath(path) / XComObjectStorageBackend._get_key(data) - return json.load(p.open(mode="rb", compression="infer"), cls=XComDecoder) - except TypeError: + path = XComObjectStorageBackend._get_full_path(data) + except (TypeError, ValueError): # Likely value stored directly in the database. return data - except ValueError: + try: + with path.open(mode="rb", compression="infer") as f: + return json.load(f, cls=XComDecoder) + except (TypeError, ValueError): return data @staticmethod def purge(xcom: XCom, session: Session) -> None: - path = conf.get(SECTION, "xcom_objectstorage_path", fallback="") - if isinstance(xcom.value, str): - try: - p = ObjectStoragePath(path) / XComObjectStorageBackend._get_key(xcom.value) - p.unlink(missing_ok=True) - except TypeError: - pass - except ValueError: - pass + if not isinstance(xcom.value, str): + return + with contextlib.suppress(TypeError, ValueError): + XComObjectStorageBackend._get_full_path(xcom.value).unlink(missing_ok=True) diff --git a/tests/providers/common/io/xcom/test_backend.py b/tests/providers/common/io/xcom/test_backend.py index 008394f365c11..2da2d6fecd26b 100644 --- a/tests/providers/common/io/xcom/test_backend.py +++ b/tests/providers/common/io/xcom/test_backend.py @@ -20,7 +20,6 @@ import pytest import airflow.models.xcom -from airflow.io.path import ObjectStoragePath from airflow.models.xcom import BaseXCom, resolve_xcom_backend from airflow.operators.empty import EmptyOperator from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend @@ -42,6 +41,19 @@ def reset_db(): db.clear_db_xcom() +@pytest.fixture(autouse=True) +def reset_cache(): + from airflow.providers.common.io.xcom import backend + + backend._get_base_path.cache_clear() + backend._get_compression.cache_clear() + backend._get_threshold.cache_clear() + yield + backend._get_base_path.cache_clear() + backend._get_compression.cache_clear() + backend._get_threshold.cache_clear() + + @pytest.fixture def task_instance(create_task_instance_of_operator): return create_task_instance_of_operator( @@ -121,7 +133,7 @@ def test_value_storage(self, task_instance, session): ) data = BaseXCom.deserialize_value(res) - p = ObjectStoragePath(self.path) / XComObjectStorageBackend._get_key(data) + p = XComObjectStorageBackend._get_full_path(data) assert p.exists() is True value = XCom.get_value( @@ -166,7 +178,7 @@ def test_clear(self, task_instance, session): ) data = BaseXCom.deserialize_value(res) - p = ObjectStoragePath(self.path) / XComObjectStorageBackend._get_key(data) + p = XComObjectStorageBackend._get_full_path(data) assert p.exists() is True XCom.clear( @@ -205,7 +217,7 @@ def test_compression(self, task_instance, session): ) data = BaseXCom.deserialize_value(res) - p = ObjectStoragePath(self.path) / XComObjectStorageBackend._get_key(data) + p = XComObjectStorageBackend._get_full_path(data) assert p.exists() is True assert p.suffix == ".gz"