Skip to content

Commit

Permalink
Improve XComObjectStorageBackend implementation (#38608)
Browse files Browse the repository at this point in the history
Repeated configuration access is moved to cached functions so the string
literals don't need to be written repeatedly. The path configuration is
made mandatory since it more or less is; using this backend without a
path configured is most likely an unintended user error.

Various functions are rewritten to take advantage of early returns, and
more localized try-except blocks to improve code quality.
  • Loading branch information
uranusjr authored Apr 2, 2024
1 parent d212b11 commit 62f948c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 55 deletions.
99 changes: 48 additions & 51 deletions airflow/providers/common/io/xcom/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
# under the License.
from __future__ import annotations

import contextlib
import json
import uuid
from typing import TYPE_CHECKING, Any, TypeVar
from urllib.parse import urlsplit

import fsspec.utils

from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.io.path import ObjectStoragePath
from airflow.models.xcom import BaseXCom
Expand Down Expand Up @@ -65,6 +67,21 @@ def _get_compression_suffix(compression: str) -> str:
raise ValueError(f"Compression {compression} is not supported. Make sure it is installed.")


@cache
def _get_base_path() -> ObjectStoragePath:
return ObjectStoragePath(conf.get_mandatory_value(SECTION, "xcom_objectstorage_path"))


@cache
def _get_compression() -> str | None:
return conf.get(SECTION, "xcom_objectstorage_compression", fallback=None) or None


@cache
def _get_threshold() -> int:
return conf.getint(SECTION, "xcom_objectstorage_threshold", fallback=-1)


class XComObjectStorageBackend(BaseXCom):
"""XCom backend that stores data in an object store or database depending on the size of the data.
Expand All @@ -75,30 +92,24 @@ class XComObjectStorageBackend(BaseXCom):
"""

@staticmethod
def _get_key(data: str) -> str:
"""Get the key from the url and normalizes it to be relative to the configured path.
def _get_full_path(data: str) -> ObjectStoragePath:
"""Get the path from stored value.
:raises ValueError: if the key is not relative to the configured path
:raises TypeError: if the url is not a valid url or cannot be split
"""
path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
p = ObjectStoragePath(path)
p = _get_base_path()

# normalize the path
path = str(p)

try:
url = urlsplit(data)
except AttributeError:
raise TypeError(f"Not a valid url: {data}")
raise TypeError(f"Not a valid url: {data}") from None

if url.scheme:
k = ObjectStoragePath(data)

if _is_relative_to(k, p) is False:
if not _is_relative_to(ObjectStoragePath(data), p):
raise ValueError(f"Invalid key: {data}")
else:
return data.replace(path, "", 1).lstrip("/")
return p / data.replace(str(p), "", 1).lstrip("/")

raise ValueError(f"Not a valid url: {data}")

Expand All @@ -115,61 +126,47 @@ def serialize_value(
# we will always serialize ourselves and not by BaseXCom as the deserialize method
# from BaseXCom accepts only XCom objects and not the value directly
s_val = json.dumps(value, cls=XComEncoder).encode("utf-8")
path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
compression = conf.get(SECTION, "xcom_objectstorage_compression", fallback=None)

if compression:
suffix = "." + _get_compression_suffix(compression)
if compression := _get_compression():
suffix = f".{_get_compression_suffix(compression)}"
else:
suffix = ""
compression = None

threshold = conf.getint(SECTION, "xcom_objectstorage_threshold", fallback=-1)

if path and -1 < threshold < len(s_val):
# safeguard against collisions
while True:
p = ObjectStoragePath(path) / f"{dag_id}/{run_id}/{task_id}/{str(uuid.uuid4())}{suffix}"
if not p.exists():
break
threshold = _get_threshold()
if threshold < 0 or len(s_val) < threshold: # Either no threshold or value is small enough.
return s_val

if not p.parent.exists():
p.parent.mkdir(parents=True, exist_ok=True)
base_path = _get_base_path()
while True: # Safeguard against collisions.
p = base_path.joinpath(dag_id, run_id, task_id, f"{uuid.uuid4()}{suffix}")
if not p.exists():
break
p.parent.mkdir(parents=True, exist_ok=True)

with p.open(mode="wb", compression=compression) as f:
f.write(s_val)

return BaseXCom.serialize_value(str(p))
else:
return s_val
with p.open(mode="wb", compression=compression) as f:
f.write(s_val)
return BaseXCom.serialize_value(str(p))

@staticmethod
def deserialize_value(
result: XCom,
) -> Any:
def deserialize_value(result: XCom) -> Any:
"""Deserializes the value from the database or object storage.
Compression is inferred from the file extension.
"""
data = BaseXCom.deserialize_value(result)
path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")

try:
p = ObjectStoragePath(path) / XComObjectStorageBackend._get_key(data)
return json.load(p.open(mode="rb", compression="infer"), cls=XComDecoder)
except TypeError:
path = XComObjectStorageBackend._get_full_path(data)
except (TypeError, ValueError): # Likely value stored directly in the database.
return data
except ValueError:
try:
with path.open(mode="rb", compression="infer") as f:
return json.load(f, cls=XComDecoder)
except (TypeError, ValueError):
return data

@staticmethod
def purge(xcom: XCom, session: Session) -> None:
path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
if isinstance(xcom.value, str):
try:
p = ObjectStoragePath(path) / XComObjectStorageBackend._get_key(xcom.value)
p.unlink(missing_ok=True)
except TypeError:
pass
except ValueError:
pass
if not isinstance(xcom.value, str):
return
with contextlib.suppress(TypeError, ValueError):
XComObjectStorageBackend._get_full_path(xcom.value).unlink(missing_ok=True)
20 changes: 16 additions & 4 deletions tests/providers/common/io/xcom/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import pytest

import airflow.models.xcom
from airflow.io.path import ObjectStoragePath
from airflow.models.xcom import BaseXCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend
Expand All @@ -42,6 +41,19 @@ def reset_db():
db.clear_db_xcom()


@pytest.fixture(autouse=True)
def reset_cache():
from airflow.providers.common.io.xcom import backend

backend._get_base_path.cache_clear()
backend._get_compression.cache_clear()
backend._get_threshold.cache_clear()
yield
backend._get_base_path.cache_clear()
backend._get_compression.cache_clear()
backend._get_threshold.cache_clear()


@pytest.fixture
def task_instance(create_task_instance_of_operator):
return create_task_instance_of_operator(
Expand Down Expand Up @@ -121,7 +133,7 @@ def test_value_storage(self, task_instance, session):
)

data = BaseXCom.deserialize_value(res)
p = ObjectStoragePath(self.path) / XComObjectStorageBackend._get_key(data)
p = XComObjectStorageBackend._get_full_path(data)
assert p.exists() is True

value = XCom.get_value(
Expand Down Expand Up @@ -166,7 +178,7 @@ def test_clear(self, task_instance, session):
)

data = BaseXCom.deserialize_value(res)
p = ObjectStoragePath(self.path) / XComObjectStorageBackend._get_key(data)
p = XComObjectStorageBackend._get_full_path(data)
assert p.exists() is True

XCom.clear(
Expand Down Expand Up @@ -205,7 +217,7 @@ def test_compression(self, task_instance, session):
)

data = BaseXCom.deserialize_value(res)
p = ObjectStoragePath(self.path) / XComObjectStorageBackend._get_key(data)
p = XComObjectStorageBackend._get_full_path(data)
assert p.exists() is True
assert p.suffix == ".gz"

Expand Down

0 comments on commit 62f948c

Please sign in to comment.