Skip to content

Commit

Permalink
Improve subscription caching with domain model endpoint and graphql s…
Browse files Browse the repository at this point in the history
…ubscriptions (#663)

* Improve subscription caching with get domain model and graphql subscriptions

- change get_subscription_dict to a util function and add subscription to cache when not fetched from cache.
- use get_subscription_dict for get domain model endpoint and in graphql get_subscription_details

* Bump version to 2.3.0rc4

* Remove to_redis in get_subscription_dict and use _generate_etag

* Add docstring to get_subscription_dict and add unit tests

---------

Co-authored-by: Mark90 <[email protected]>
  • Loading branch information
tjeerddie and Mark90 authored Jun 4, 2024
1 parent 3bf1278 commit a66ce09
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 2.3.0rc3
current_version = 2.3.0rc4
commit = False
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(rc(?P<build>\d+))?
Expand Down
2 changes: 1 addition & 1 deletion orchestrator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

"""This is the orchestrator workflow engine."""

__version__ = "2.3.0rc3"
__version__ = "2.3.0rc4"

from orchestrator.app import OrchestratorCore
from orchestrator.settings import app_settings
Expand Down
16 changes: 4 additions & 12 deletions orchestrator/api/api_v1/endpoints/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,10 @@
SubscriptionTable,
db,
)
from orchestrator.domain import SubscriptionModel
from orchestrator.schemas import SubscriptionWorkflowListsSchema
from orchestrator.schemas.subscription import SubscriptionDomainModelSchema, SubscriptionWithMetadata
from orchestrator.security import authenticate
from orchestrator.services.subscriptions import (
_generate_etag,
build_extended_domain_model,
format_extended_domain_model,
format_special_types,
get_subscription,
Expand All @@ -52,7 +49,7 @@
from orchestrator.settings import app_settings
from orchestrator.types import SubscriptionLifecycle
from orchestrator.utils.deprecation_logger import deprecated_endpoint
from orchestrator.utils.redis import from_redis
from orchestrator.utils.get_subscription_dict import get_subscription_dict

router = APIRouter()

Expand Down Expand Up @@ -106,7 +103,7 @@ def _filter_statuses(filter_statuses: str | None = None) -> list[str]:
"/domain-model/{subscription_id}",
response_model=SubscriptionDomainModelSchema | None,
)
def subscription_details_by_id_with_domain_model(
async def subscription_details_by_id_with_domain_model(
request: Request, subscription_id: UUID, response: Response, filter_owner_relations: bool = True
) -> dict[str, Any] | None:
def _build_response(model: dict, etag: str) -> dict[str, Any] | None:
Expand All @@ -117,14 +114,9 @@ def _build_response(model: dict, etag: str) -> dict[str, Any] | None:
filtered = format_extended_domain_model(model, filter_owner_relations=filter_owner_relations)
return format_special_types(filtered)

if cache_response := from_redis(subscription_id):
return _build_response(*cache_response)

try:
subscription_model = SubscriptionModel.from_subscription(subscription_id)
extended_model = build_extended_domain_model(subscription_model)
etag = _generate_etag(extended_model)
return _build_response(extended_model, etag)
subscription, etag = await get_subscription_dict(subscription_id)
return _build_response(subscription, etag)
except ValueError as e:
if str(e) == f"Subscription with id: {subscription_id}, does not exist":
raise_status(HTTPStatus.NOT_FOUND, f"Subscription with id: {subscription_id}, not found")
Expand Down
42 changes: 25 additions & 17 deletions orchestrator/graphql/resolvers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sqlalchemy import Select, func, select
from strawberry.experimental.pydantic.conversion_types import StrawberryTypeFromPydantic

from nwastdlib.asyncio import gather_nice
from orchestrator.db import ProductTable, SubscriptionTable, db
from orchestrator.db.filters import Filter
from orchestrator.db.filters.subscription import (
Expand All @@ -33,7 +34,6 @@
sort_subscriptions,
subscription_sort_fields,
)
from orchestrator.domain.base import SubscriptionModel
from orchestrator.graphql.pagination import Connection
from orchestrator.graphql.schemas.product import ProductModelGraphql
from orchestrator.graphql.schemas.subscription import SubscriptionInterface
Expand All @@ -48,7 +48,7 @@
is_querying_page_data,
to_graphql_result_page,
)
from orchestrator.types import SubscriptionLifecycle
from orchestrator.utils.get_subscription_dict import get_subscription_dict

logger = structlog.get_logger(__name__)
# Note: we can make this more fancy by adding metadata to the field annotation that indicates if a resolver
Expand All @@ -65,26 +65,34 @@
def get_subscription_graphql_type(info: OrchestratorInfo, subscription_name: str) -> StrawberryTypeFromPydantic:
subscription_graphql_type = info.context.graphql_models.get(subscription_name)
if not subscription_graphql_type:
raise GraphQLError(message=f"No graphql type found for {subscription_name}")
logger.warning(message=f"No graphql type found for {subscription_name}")
base_type = info.context.graphql_models.get("subscription")
if not base_type:
raise GraphQLError("No subscription base type found")
return base_type
return subscription_graphql_type


def get_subscription_details(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface:
async def get_subscription_details(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface:
from orchestrator.domain import SUBSCRIPTION_MODEL_REGISTRY
from orchestrator.graphql.autoregistration import graphql_subscription_name

subscription_details = SubscriptionModel.from_subscription(subscription.subscription_id)
base_model = subscription_details.__base_type__ if subscription_details.__base_type__ else subscription_details
base_subscription_details = base_model.from_other_lifecycle( # type: ignore
subscription_details, SubscriptionLifecycle.INITIAL, skip_validation=True
)
base_subscription_details.status = subscription_details.status
strawberry_type = get_subscription_graphql_type(info, graphql_subscription_name(base_model.__name__)) # type: ignore
return strawberry_type.from_pydantic(base_subscription_details) # type:ignore
subscription_dict_data, _ = await get_subscription_dict(subscription.subscription_id)

domain_model_type = SUBSCRIPTION_MODEL_REGISTRY[subscription.product.name]
base_model = domain_model_type.__base_type__ or domain_model_type

subscription_name = graphql_subscription_name(base_model.__name__)
subscription_details = base_model.model_validate(subscription_dict_data, strict=False)
subscription_details._db_model = subscription

strawberry_type = get_subscription_graphql_type(info, subscription_name)
return strawberry_type.from_pydantic(subscription_details) # type: ignore


def format_subscription(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface:
async def format_subscription(info: OrchestratorInfo, subscription: SubscriptionTable) -> SubscriptionInterface:
if _is_subscription_detailed(info):
return get_subscription_details(info, subscription)
return await get_subscription_details(info, subscription)

strawberry_type = get_subscription_graphql_type(info, "subscription")
return strawberry_type.from_pydantic(subscription) # type:ignore
Expand All @@ -94,7 +102,7 @@ async def resolve_subscription(info: OrchestratorInfo, id: UUID) -> Subscription
stmt = select(SubscriptionTable).where(SubscriptionTable.subscription_id == id)

if subscription := db.session.scalar(stmt):
return format_subscription(info, subscription)
return await format_subscription(info, subscription)
return None


Expand Down Expand Up @@ -127,10 +135,10 @@ async def resolve_subscriptions(
total = db.session.scalar(select(func.count()).select_from(stmt.subquery()))
stmt = apply_range_to_statement(stmt, after, after + first + 1)

graphql_subscriptions = []
graphql_subscriptions: list[SubscriptionInterface] = []
if is_querying_page_data(info):
subscriptions = db.session.scalars(stmt).all()
graphql_subscriptions = [format_subscription(info, p) for p in subscriptions]
graphql_subscriptions = list(await gather_nice((format_subscription(info, p) for p in subscriptions)))
logger.info("Resolve subscriptions", filter_by=filter_by, total=graphql_subscriptions)

return to_graphql_result_page(
Expand Down
2 changes: 1 addition & 1 deletion orchestrator/graphql/schemas/product_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def owner_subscription_resolver(
stmt = select(SubscriptionTable).where(SubscriptionTable.subscription_id == root.owner_subscription_id)

if subscription := db.session.scalar(stmt):
return format_subscription(info, subscription)
return await format_subscription(info, subscription)
return None


Expand Down
15 changes: 2 additions & 13 deletions orchestrator/graphql/utils/get_subscription_product_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
from pydantic.alias_generators import to_camel as to_lower_camel
from strawberry.scalars import JSON

from orchestrator.domain.base import SubscriptionModel
from orchestrator.graphql.schemas.product_block import owner_subscription_resolver
from orchestrator.services.subscriptions import build_extended_domain_model
from orchestrator.utils.redis import from_redis
from orchestrator.utils.get_subscription_dict import get_subscription_dict

if TYPE_CHECKING:
from orchestrator.graphql.schemas.subscription import SubscriptionInterface
Expand Down Expand Up @@ -62,19 +60,10 @@ def new_product_block(item: dict[str, Any]) -> Generator:
pb_instance_property_keys = ("id", "parent", "owner_subscription_id", "subscription_instance_id", "in_use_by_relations")


async def get_subscription_dict(subscription_id: UUID) -> dict:
if cached_model := from_redis(subscription_id):
subscription, _ = cached_model
else:
subscription_model = SubscriptionModel.from_subscription(subscription_id)
subscription = build_extended_domain_model(subscription_model)
return subscription


async def get_subscription_product_blocks(
subscription_id: UUID, tags: list[str] | None = None, product_block_instance_values: list[str] | None = None
) -> list[ProductBlockInstance]:
subscription = await get_subscription_dict(subscription_id)
subscription, _ = await get_subscription_dict(subscription_id)

def to_product_block(product_block: dict[str, Any]) -> ProductBlockInstance:
def is_resource_type(candidate: Any) -> bool:
Expand Down
17 changes: 17 additions & 0 deletions orchestrator/utils/get_subscription_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from uuid import UUID

from orchestrator.domain.base import SubscriptionModel
from orchestrator.services.subscriptions import _generate_etag, build_extended_domain_model
from orchestrator.utils.redis import from_redis


async def get_subscription_dict(subscription_id: UUID) -> tuple[dict, str]:
"""Helper function to get subscription dict by uuid from db or cache."""

if cached_model := from_redis(subscription_id):
return cached_model # type: ignore

subscription_model = SubscriptionModel.from_subscription(subscription_id)
subscription = build_extended_domain_model(subscription_model)
etag = _generate_etag(subscription)
return subscription, etag
12 changes: 7 additions & 5 deletions orchestrator/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from orchestrator.services.subscriptions import _generate_etag
from orchestrator.settings import app_settings
from orchestrator.utils.json import json_dumps, json_loads
from orchestrator.utils.json import PY_JSON_TYPES, json_dumps, json_loads

logger = get_logger(__name__)

Expand All @@ -37,17 +37,19 @@ def caching_models_enabled() -> bool:
return getenv("AIOCACHE_DISABLE", "0") == "0" and app_settings.CACHE_DOMAIN_MODELS


def to_redis(subscription: dict[str, Any]) -> None:
def to_redis(subscription: dict[str, Any]) -> str | None:
if caching_models_enabled():
logger.info("Setting cache for subscription", subscription=subscription["subscription_id"])
etag = _generate_etag(subscription)
cache.set(f"domain:{subscription['subscription_id']}", json_dumps(subscription), ex=ONE_WEEK)
cache.set(f"domain:etag:{subscription['subscription_id']}", etag, ex=ONE_WEEK)
else:
logger.warning("Caching disabled, not caching subscription", subscription=subscription["subscription_id"])
return etag

logger.warning("Caching disabled, not caching subscription", subscription=subscription["subscription_id"])
return None


def from_redis(subscription_id: UUID) -> tuple[Any, str] | None:
def from_redis(subscription_id: UUID) -> tuple[PY_JSON_TYPES, str] | None:
log = logger.bind(subscription_id=subscription_id)
if caching_models_enabled():
log.debug("Try to retrieve subscription from cache")
Expand Down
2 changes: 1 addition & 1 deletion test/unit_tests/api/test_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def test_subscription_detail_with_in_use_by_ids_filtered_self(test_client, produ
assert not response.json()["block"]["sub_block"]["in_use_by_ids"]


@mock.patch("orchestrator.api.api_v1.endpoints.subscriptions.from_redis")
@mock.patch("orchestrator.api.api_v1.endpoints.subscriptions.get_subscription_dict")
def test_subscription_detail_special_fields(mock_from_redis, test_client):
"""Test that a subscription with special field types is correctly serialized by Pydantic.
Expand Down
36 changes: 36 additions & 0 deletions test/unit_tests/utils/get_subscription_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from os import getenv
from unittest import mock
from unittest.mock import Mock

import pytest

from orchestrator import app_settings
from orchestrator.domain.base import SubscriptionModel
from orchestrator.services.subscriptions import build_extended_domain_model
from orchestrator.utils.get_subscription_dict import get_subscription_dict
from orchestrator.utils.redis import to_redis


@mock.patch.object(app_settings, "CACHE_DOMAIN_MODELS", False)
@mock.patch("orchestrator.utils.get_subscription_dict._generate_etag")
async def test_get_subscription_dict_db(generate_etag, generic_subscription_1):
generate_etag.side_effect = Mock(return_value="etag-mock")
await get_subscription_dict(generic_subscription_1)
assert generate_etag.called


@pytest.mark.skipif(
not getenv("AIOCACHE_DISABLE", "0") == "0", reason="AIOCACHE must be enabled for this test to do anything"
)
@mock.patch("orchestrator.utils.get_subscription_dict._generate_etag")
async def test_get_subscription_dict_cache(generate_etag, generic_subscription_1, cache_fixture):
subscription = SubscriptionModel.from_subscription(generic_subscription_1)
extended_model = build_extended_domain_model(subscription)

# Add domainmodel to cache
to_redis(extended_model)
cache_fixture.extend([f"domain:{generic_subscription_1}", f"domain:etag:{generic_subscription_1}"])

generate_etag.side_effect = Mock(return_value="etag-mock")
await get_subscription_dict(generic_subscription_1)
assert not generate_etag.called

0 comments on commit a66ce09

Please sign in to comment.