From a66ce0999dc5c43279fbf480809fa8972b451ec1 Mon Sep 17 00:00:00 2001 From: tjeerddie Date: Tue, 4 Jun 2024 09:51:27 +0200 Subject: [PATCH] Improve subscription caching with domain model endpoint and graphql subscriptions (#663) * Improve subscription caching with get domain model and graphql subscriptions - change get_subscription_dict to a util function and add subscription to cache when not fetched from cache. - use get_subscription_dict for get domain model endpoint and in graphql get_subscription_details * Bump version to 2.3.0rc4 * Remove to_redis in get_subscription_dict and use _generate_etag * Add docstring to get_subscription_dict and add unit tests --------- Co-authored-by: Mark90 --- .bumpversion.cfg | 2 +- orchestrator/__init__.py | 2 +- .../api/api_v1/endpoints/subscriptions.py | 16 ++----- .../graphql/resolvers/subscription.py | 42 +++++++++++-------- orchestrator/graphql/schemas/product_block.py | 2 +- .../utils/get_subscription_product_blocks.py | 15 +------ orchestrator/utils/get_subscription_dict.py | 17 ++++++++ orchestrator/utils/redis.py | 12 +++--- test/unit_tests/api/test_subscriptions.py | 2 +- .../unit_tests/utils/get_subscription_dict.py | 36 ++++++++++++++++ 10 files changed, 95 insertions(+), 51 deletions(-) create mode 100644 orchestrator/utils/get_subscription_dict.py create mode 100644 test/unit_tests/utils/get_subscription_dict.py diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 6edc8d499..7725102fe 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 2.3.0rc3 +current_version = 2.3.0rc4 commit = False tag = False parse = (?P\d+)\.(?P\d+)\.(?P\d+)(rc(?P\d+))? diff --git a/orchestrator/__init__.py b/orchestrator/__init__.py index 287317cef..be974048b 100644 --- a/orchestrator/__init__.py +++ b/orchestrator/__init__.py @@ -13,7 +13,7 @@ """This is the orchestrator workflow engine.""" -__version__ = "2.3.0rc3" +__version__ = "2.3.0rc4" from orchestrator.app import OrchestratorCore from orchestrator.settings import app_settings diff --git a/orchestrator/api/api_v1/endpoints/subscriptions.py b/orchestrator/api/api_v1/endpoints/subscriptions.py index 97f0dc5f7..f8f49aff8 100644 --- a/orchestrator/api/api_v1/endpoints/subscriptions.py +++ b/orchestrator/api/api_v1/endpoints/subscriptions.py @@ -37,13 +37,10 @@ SubscriptionTable, db, ) -from orchestrator.domain import SubscriptionModel from orchestrator.schemas import SubscriptionWorkflowListsSchema from orchestrator.schemas.subscription import SubscriptionDomainModelSchema, SubscriptionWithMetadata from orchestrator.security import authenticate from orchestrator.services.subscriptions import ( - _generate_etag, - build_extended_domain_model, format_extended_domain_model, format_special_types, get_subscription, @@ -52,7 +49,7 @@ from orchestrator.settings import app_settings from orchestrator.types import SubscriptionLifecycle from orchestrator.utils.deprecation_logger import deprecated_endpoint -from orchestrator.utils.redis import from_redis +from orchestrator.utils.get_subscription_dict import get_subscription_dict router = APIRouter() @@ -106,7 +103,7 @@ def _filter_statuses(filter_statuses: str | None = None) -> list[str]: "/domain-model/{subscription_id}", response_model=SubscriptionDomainModelSchema | None, ) -def subscription_details_by_id_with_domain_model( +async def subscription_details_by_id_with_domain_model( request: Request, subscription_id: UUID, response: Response, filter_owner_relations: bool = True ) -> dict[str, Any] | None: def _build_response(model: dict, etag: str) -> dict[str, Any] | None: @@ -117,14 +114,9 @@ def _build_response(model: dict, etag: str) -> dict[str, Any] | None: filtered = format_extended_domain_model(model, filter_owner_relations=filter_owner_relations) return format_special_types(filtered) - if cache_response := from_redis(subscription_id): - return _build_response(*cache_response) - try: - subscription_model = SubscriptionModel.from_subscription(subscription_id) - extended_model = build_extended_domain_model(subscription_model) - etag = _generate_etag(extended_model) - return _build_response(extended_model, etag) + subscription, etag = await get_subscription_dict(subscription_id) + return _build_response(subscription, etag) except ValueError as e: if str(e) == f"Subscription with id: {subscription_id}, does not exist": raise_status(HTTPStatus.NOT_FOUND, f"Subscription with id: {subscription_id}, not found") diff --git a/orchestrator/graphql/resolvers/subscription.py b/orchestrator/graphql/resolvers/subscription.py index 6019e7a77..4692babcb 100644 --- a/orchestrator/graphql/resolvers/subscription.py +++ b/orchestrator/graphql/resolvers/subscription.py @@ -20,6 +20,7 @@ from sqlalchemy import Select, func, select from strawberry.experimental.pydantic.conversion_types import StrawberryTypeFromPydantic +from nwastdlib.asyncio import gather_nice from orchestrator.db import ProductTable, SubscriptionTable, db from orchestrator.db.filters import Filter from orchestrator.db.filters.subscription import ( @@ -33,7 +34,6 @@ sort_subscriptions, subscription_sort_fields, ) -from orchestrator.domain.base import SubscriptionModel from orchestrator.graphql.pagination import Connection from orchestrator.graphql.schemas.product import ProductModelGraphql from orchestrator.graphql.schemas.subscription import SubscriptionInterface @@ -48,7 +48,7 @@ is_querying_page_data, to_graphql_result_page, ) -from orchestrator.types import SubscriptionLifecycle +from orchestrator.utils.get_subscription_dict import get_subscription_dict logger = structlog.get_logger(__name__) # Note: we can make this more fancy by adding metadata to the field annotation that indicates if a resolver @@ -65,26 +65,34 @@ def get_subscription_graphql_type(info: OrchestratorInfo, subscription_name: str) -> StrawberryTypeFromPydantic: subscription_graphql_type = info.context.graphql_models.get(subscription_name) if not subscription_graphql_type: - raise GraphQLError(message=f"No graphql type found for {subscription_name}") + logger.warning(message=f"No graphql type found for {subscription_name}") + base_type = info.context.graphql_models.get("subscription") + if not base_type: + raise GraphQLError("No subscription base type found") + return base_type return subscription_graphql_type -def get_subscription_details(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface: +async def get_subscription_details(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface: + from orchestrator.domain import SUBSCRIPTION_MODEL_REGISTRY from orchestrator.graphql.autoregistration import graphql_subscription_name - subscription_details = SubscriptionModel.from_subscription(subscription.subscription_id) - base_model = subscription_details.__base_type__ if subscription_details.__base_type__ else subscription_details - base_subscription_details = base_model.from_other_lifecycle( # type: ignore - subscription_details, SubscriptionLifecycle.INITIAL, skip_validation=True - ) - base_subscription_details.status = subscription_details.status - strawberry_type = get_subscription_graphql_type(info, graphql_subscription_name(base_model.__name__)) # type: ignore - return strawberry_type.from_pydantic(base_subscription_details) # type:ignore + subscription_dict_data, _ = await get_subscription_dict(subscription.subscription_id) + + domain_model_type = SUBSCRIPTION_MODEL_REGISTRY[subscription.product.name] + base_model = domain_model_type.__base_type__ or domain_model_type + + subscription_name = graphql_subscription_name(base_model.__name__) + subscription_details = base_model.model_validate(subscription_dict_data, strict=False) + subscription_details._db_model = subscription + + strawberry_type = get_subscription_graphql_type(info, subscription_name) + return strawberry_type.from_pydantic(subscription_details) # type: ignore -def format_subscription(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface: +async def format_subscription(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface: if _is_subscription_detailed(info): - return get_subscription_details(info, subscription) + return await get_subscription_details(info, subscription) strawberry_type = get_subscription_graphql_type(info, "subscription") return strawberry_type.from_pydantic(subscription) # type:ignore @@ -94,7 +102,7 @@ async def resolve_subscription(info: OrchestratorInfo, id: UUID) -> Subscription stmt = select(SubscriptionTable).where(SubscriptionTable.subscription_id == id) if subscription := db.session.scalar(stmt): - return format_subscription(info, subscription) + return await format_subscription(info, subscription) return None @@ -127,10 +135,10 @@ async def resolve_subscriptions( total = db.session.scalar(select(func.count()).select_from(stmt.subquery())) stmt = apply_range_to_statement(stmt, after, after + first + 1) - graphql_subscriptions = [] + graphql_subscriptions: list[SubscriptionInterface] = [] if is_querying_page_data(info): subscriptions = db.session.scalars(stmt).all() - graphql_subscriptions = [format_subscription(info, p) for p in subscriptions] + graphql_subscriptions = list(await gather_nice((format_subscription(info, p) for p in subscriptions))) logger.info("Resolve subscriptions", filter_by=filter_by, total=graphql_subscriptions) return to_graphql_result_page( diff --git a/orchestrator/graphql/schemas/product_block.py b/orchestrator/graphql/schemas/product_block.py index 8b54db9b9..8a5e65ce6 100644 --- a/orchestrator/graphql/schemas/product_block.py +++ b/orchestrator/graphql/schemas/product_block.py @@ -42,7 +42,7 @@ async def owner_subscription_resolver( stmt = select(SubscriptionTable).where(SubscriptionTable.subscription_id == root.owner_subscription_id) if subscription := db.session.scalar(stmt): - return format_subscription(info, subscription) + return await format_subscription(info, subscription) return None diff --git a/orchestrator/graphql/utils/get_subscription_product_blocks.py b/orchestrator/graphql/utils/get_subscription_product_blocks.py index 78bcfab92..842afead8 100644 --- a/orchestrator/graphql/utils/get_subscription_product_blocks.py +++ b/orchestrator/graphql/utils/get_subscription_product_blocks.py @@ -7,10 +7,8 @@ from pydantic.alias_generators import to_camel as to_lower_camel from strawberry.scalars import JSON -from orchestrator.domain.base import SubscriptionModel from orchestrator.graphql.schemas.product_block import owner_subscription_resolver -from orchestrator.services.subscriptions import build_extended_domain_model -from orchestrator.utils.redis import from_redis +from orchestrator.utils.get_subscription_dict import get_subscription_dict if TYPE_CHECKING: from orchestrator.graphql.schemas.subscription import SubscriptionInterface @@ -62,19 +60,10 @@ def new_product_block(item: dict[str, Any]) -> Generator: pb_instance_property_keys = ("id", "parent", "owner_subscription_id", "subscription_instance_id", "in_use_by_relations") -async def get_subscription_dict(subscription_id: UUID) -> dict: - if cached_model := from_redis(subscription_id): - subscription, _ = cached_model - else: - subscription_model = SubscriptionModel.from_subscription(subscription_id) - subscription = build_extended_domain_model(subscription_model) - return subscription - - async def get_subscription_product_blocks( subscription_id: UUID, tags: list[str] | None = None, product_block_instance_values: list[str] | None = None ) -> list[ProductBlockInstance]: - subscription = await get_subscription_dict(subscription_id) + subscription, _ = await get_subscription_dict(subscription_id) def to_product_block(product_block: dict[str, Any]) -> ProductBlockInstance: def is_resource_type(candidate: Any) -> bool: diff --git a/orchestrator/utils/get_subscription_dict.py b/orchestrator/utils/get_subscription_dict.py new file mode 100644 index 000000000..f5f67c84e --- /dev/null +++ b/orchestrator/utils/get_subscription_dict.py @@ -0,0 +1,17 @@ +from uuid import UUID + +from orchestrator.domain.base import SubscriptionModel +from orchestrator.services.subscriptions import _generate_etag, build_extended_domain_model +from orchestrator.utils.redis import from_redis + + +async def get_subscription_dict(subscription_id: UUID) -> tuple[dict, str]: + """Helper function to get subscription dict by uuid from db or cache.""" + + if cached_model := from_redis(subscription_id): + return cached_model # type: ignore + + subscription_model = SubscriptionModel.from_subscription(subscription_id) + subscription = build_extended_domain_model(subscription_model) + etag = _generate_etag(subscription) + return subscription, etag diff --git a/orchestrator/utils/redis.py b/orchestrator/utils/redis.py index 87cdd6b84..cbcf4d1ab 100644 --- a/orchestrator/utils/redis.py +++ b/orchestrator/utils/redis.py @@ -24,7 +24,7 @@ from orchestrator.services.subscriptions import _generate_etag from orchestrator.settings import app_settings -from orchestrator.utils.json import json_dumps, json_loads +from orchestrator.utils.json import PY_JSON_TYPES, json_dumps, json_loads logger = get_logger(__name__) @@ -37,17 +37,19 @@ def caching_models_enabled() -> bool: return getenv("AIOCACHE_DISABLE", "0") == "0" and app_settings.CACHE_DOMAIN_MODELS -def to_redis(subscription: dict[str, Any]) -> None: +def to_redis(subscription: dict[str, Any]) -> str | None: if caching_models_enabled(): logger.info("Setting cache for subscription", subscription=subscription["subscription_id"]) etag = _generate_etag(subscription) cache.set(f"domain:{subscription['subscription_id']}", json_dumps(subscription), ex=ONE_WEEK) cache.set(f"domain:etag:{subscription['subscription_id']}", etag, ex=ONE_WEEK) - else: - logger.warning("Caching disabled, not caching subscription", subscription=subscription["subscription_id"]) + return etag + + logger.warning("Caching disabled, not caching subscription", subscription=subscription["subscription_id"]) + return None -def from_redis(subscription_id: UUID) -> tuple[Any, str] | None: +def from_redis(subscription_id: UUID) -> tuple[PY_JSON_TYPES, str] | None: log = logger.bind(subscription_id=subscription_id) if caching_models_enabled(): log.debug("Try to retrieve subscription from cache") diff --git a/test/unit_tests/api/test_subscriptions.py b/test/unit_tests/api/test_subscriptions.py index b5164daf5..797d1b474 100644 --- a/test/unit_tests/api/test_subscriptions.py +++ b/test/unit_tests/api/test_subscriptions.py @@ -737,7 +737,7 @@ def test_subscription_detail_with_in_use_by_ids_filtered_self(test_client, produ assert not response.json()["block"]["sub_block"]["in_use_by_ids"] -@mock.patch("orchestrator.api.api_v1.endpoints.subscriptions.from_redis") +@mock.patch("orchestrator.api.api_v1.endpoints.subscriptions.get_subscription_dict") def test_subscription_detail_special_fields(mock_from_redis, test_client): """Test that a subscription with special field types is correctly serialized by Pydantic. diff --git a/test/unit_tests/utils/get_subscription_dict.py b/test/unit_tests/utils/get_subscription_dict.py new file mode 100644 index 000000000..c3298a0ab --- /dev/null +++ b/test/unit_tests/utils/get_subscription_dict.py @@ -0,0 +1,36 @@ +from os import getenv +from unittest import mock +from unittest.mock import Mock + +import pytest + +from orchestrator import app_settings +from orchestrator.domain.base import SubscriptionModel +from orchestrator.services.subscriptions import build_extended_domain_model +from orchestrator.utils.get_subscription_dict import get_subscription_dict +from orchestrator.utils.redis import to_redis + + +@mock.patch.object(app_settings, "CACHE_DOMAIN_MODELS", False) +@mock.patch("orchestrator.utils.get_subscription_dict._generate_etag") +async def test_get_subscription_dict_db(generate_etag, generic_subscription_1): + generate_etag.side_effect = Mock(return_value="etag-mock") + await get_subscription_dict(generic_subscription_1) + assert generate_etag.called + + +@pytest.mark.skipif( + not getenv("AIOCACHE_DISABLE", "0") == "0", reason="AIOCACHE must be enabled for this test to do anything" +) +@mock.patch("orchestrator.utils.get_subscription_dict._generate_etag") +async def test_get_subscription_dict_cache(generate_etag, generic_subscription_1, cache_fixture): + subscription = SubscriptionModel.from_subscription(generic_subscription_1) + extended_model = build_extended_domain_model(subscription) + + # Add domainmodel to cache + to_redis(extended_model) + cache_fixture.extend([f"domain:{generic_subscription_1}", f"domain:etag:{generic_subscription_1}"]) + + generate_etag.side_effect = Mock(return_value="etag-mock") + await get_subscription_dict(generic_subscription_1) + assert not generate_etag.called