diff --git a/lib/dl_api_lib/dl_api_lib/api_common/update_dataset_mutation_key.py b/lib/dl_api_lib/dl_api_lib/api_common/update_dataset_mutation_key.py
index 27cd2c39c..45be4f200 100644
--- a/lib/dl_api_lib/dl_api_lib/api_common/update_dataset_mutation_key.py
+++ b/lib/dl_api_lib/dl_api_lib/api_common/update_dataset_mutation_key.py
@@ -1,17 +1,13 @@
 from __future__ import annotations
 
 import hashlib
-import json
-from typing import (
-    Any,
-    List,
-)
 
 import attr
 
 from dl_api_lib.request_model.data import FieldAction
+from dl_constants.types import TJSONExt
 from dl_core.us_manager.mutation_cache.mutation_key_base import MutationKey
-from dl_model_tools.serialization import RedisDatalensDataJSONEncoder
+from dl_model_tools.serialization import hashable_dumps
 
 
 class MutationKeySerializationError(ValueError):
@@ -23,25 +19,23 @@ class UpdateDatasetMutationKey(MutationKey):
     _dumped: str
     _hash: str
 
-    def get_collision_tier_breaker(self) -> Any:
+    def get_collision_tier_breaker(self) -> str:
         return self._dumped
 
     def get_hash(self) -> str:
         return self._hash
 
+    @staticmethod
+    def _dumps(value: TJSONExt) -> str:
+        return hashable_dumps(value, sort_keys=True, ensure_ascii=True, check_circular=True)
+
     @classmethod
-    def create(cls, dataset_revision_id: str, updates: List[FieldAction]) -> UpdateDatasetMutationKey:
+    def create(cls, dataset_revision_id: str, updates: list[FieldAction]) -> UpdateDatasetMutationKey:
         try:
             serialized = [upd.serialized for upd in updates]
         except Exception as e:
             raise MutationKeySerializationError() from e
-        serialized.sort(key=lambda x: json.dumps(x, indent=None, sort_keys=True, cls=RedisDatalensDataJSONEncoder))
-        dumped = json.dumps(
-            dict(ds_rev=dataset_revision_id, mutation=serialized),
-            sort_keys=True,
-            indent=None,
-            separators=(",", ":"),
-            cls=RedisDatalensDataJSONEncoder,
-        )
+        serialized.sort(key=cls._dumps)
+        dumped = cls._dumps(dict(ds_rev=dataset_revision_id, mutation=serialized))
         hashed = hashlib.sha256(dumped.encode()).hexdigest()
         return UpdateDatasetMutationKey(dumped=dumped, hash=hashed)
diff --git a/lib/dl_core/dl_core/connection_executors/qe_serializer/schemas_common.py b/lib/dl_core/dl_core/connection_executors/qe_serializer/schemas_common.py
index 69f5fdb51..0f9c17671 100644
--- a/lib/dl_core/dl_core/connection_executors/qe_serializer/schemas_common.py
+++ b/lib/dl_core/dl_core/connection_executors/qe_serializer/schemas_common.py
@@ -38,8 +38,8 @@
 )
 from dl_core.enums import QueryExecutorMode
 from dl_model_tools.serialization import (
-    RedisDatalensDataJSONDecoder,
-    RedisDatalensDataJSONEncoder,
+    DataLensJSONDecoder,
+    DataLensJSONEncoder,
 )
 from dl_utils.utils import get_type_full_name
 
@@ -82,13 +82,13 @@ def dump_conn_params(self, dba_query: DBAdapterQuery) -> Optional[dict]:
         )
         if conn_params is not None:
             for k, v in conn_params.items():
-                conn_params[k] = json.dumps(v, cls=RedisDatalensDataJSONEncoder)
+                conn_params[k] = json.dumps(v, cls=DataLensJSONEncoder)
         return conn_params
 
     def load_conn_params(self, conn_params: Optional[dict]) -> Optional[Dict[str, Any]]:
         if conn_params is not None:
             for k, v in conn_params.items():
-                conn_params[k] = json.loads(v, cls=RedisDatalensDataJSONDecoder)
+                conn_params[k] = json.loads(v, cls=DataLensJSONDecoder)
         return conn_params
 
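Note on the UpdateDatasetMutationKey change above: routing both the per-update sort key and the final dump through the single `_dumps` helper keeps the cache key canonical, so logically identical update lists hash identically regardless of order. A minimal standalone sketch of that idea (plain `json` stands in for `hashable_dumps`, and the update payload shape here is made up for illustration):

    import hashlib
    import json
    from typing import Any

    def _dumps(value: Any) -> str:
        # stand-in for hashable_dumps(value, sort_keys=True, ensure_ascii=True, check_circular=True)
        return json.dumps(value, separators=(",", ":"), sort_keys=True, ensure_ascii=True)

    def make_key(ds_rev: str, updates: list[dict]) -> str:
        serialized = sorted(updates, key=_dumps)  # like serialized.sort(key=cls._dumps)
        dumped = _dumps(dict(ds_rev=ds_rev, mutation=serialized))
        return hashlib.sha256(dumped.encode()).hexdigest()

    upd_a = {"action": "add_field", "field": {"guid": "f1"}}  # hypothetical payloads
    upd_b = {"action": "delete_field", "field": {"guid": "f2"}}
    assert make_key("rev-1", [upd_a, upd_b]) == make_key("rev-1", [upd_b, upd_a])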
diff --git a/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_async.py b/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_async.py
index 98f81a29b..6b184d5a3 100644
--- a/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_async.py
+++ b/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_async.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import argparse
-from functools import partial
 import logging
 import pickle
 import sys
@@ -59,7 +58,7 @@
 from dl_core.logging_config import configure_logging
 from dl_dashsql.typed_query.query_serialization import get_typed_query_serializer
 from dl_dashsql.typed_query.result_serialization import get_typed_query_result_serializer
-from dl_model_tools.serialization import hashable_dumps
+from dl_model_tools.serialization import safe_dumps
 from dl_utils.aio import ContextVarExecutor
 
 
@@ -173,8 +172,7 @@ async def handle_non_stream_query_action(
 
         with GenericProfiler("async_qe_serialization"):
             if self.request.headers.get(HEADER_USE_JSON_SERIALIZER) == "1":
-                dumps = partial(hashable_dumps, sort_keys=False, check_circular=True)
-                response = web.json_response(events, dumps=dumps)
+                response = web.json_response(events, dumps=safe_dumps)
             else:
                 response = web.Response(body=pickle.dumps(events))
diff --git a/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_sync.py b/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_sync.py
index 5ca3ba099..6844ac96e 100644
--- a/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_sync.py
+++ b/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_sync.py
@@ -57,7 +57,7 @@
 from dl_dashsql.typed_query.result_serialization import get_typed_query_result_serializer
 from dl_model_tools.serialization import (
     common_loads,
-    hashable_dumps,
+    safe_dumps,
 )
 
 
@@ -298,7 +298,7 @@ def loads(self, s: str | bytes, **kwargs: Any) -> Any:
         return common_loads(s, **kwargs)
 
     def dumps(self, obj: Any, **kwargs: Any) -> str:
-        return hashable_dumps(obj, sort_keys=False, check_circular=True, **kwargs)
+        return safe_dumps(obj, **kwargs)
 
 
 def create_sync_app() -> flask.Flask:
diff --git a/lib/dl_core/dl_core/us_manager/mutation_cache/mutation_key_base.py b/lib/dl_core/dl_core/us_manager/mutation_cache/mutation_key_base.py
index fcdc2fd39..6bb7c7af8 100644
--- a/lib/dl_core/dl_core/us_manager/mutation_cache/mutation_key_base.py
+++ b/lib/dl_core/dl_core/us_manager/mutation_cache/mutation_key_base.py
@@ -9,7 +9,7 @@ def get_hash(self) -> str:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_collision_tier_breaker(self) -> Any:
+    def get_collision_tier_breaker(self) -> str:
         """Returns less collision-affected but serializable representation of key."""
         raise NotImplementedError()
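Note on the app_async.py change above: `web.json_response` accepts any `(obj) -> str` callable through its `dumps` parameter, so a module-level `safe_dumps` with suitable defaults replaces the `functools.partial` wrapper one-for-one. A hedged sketch of the pattern (the `safe_dumps` below is a simplified stand-in, not the real implementation from dl_model_tools):

    import json

    from aiohttp import web

    def safe_dumps(obj, **kwargs) -> str:
        # stand-in for dl_model_tools.serialization.safe_dumps:
        # never raises TypeError, emits null for unencodable values
        return json.dumps(obj, separators=(",", ":"), default=lambda o: None, **kwargs)

    async def handle_query(request: web.Request) -> web.Response:
        events = [{"event": "row", "data": [1, 2, 3]}, {"event": "cursor", "value": object()}]
        return web.json_response(events, dumps=safe_dumps)  # no functools.partial needed

    app = web.Application()
    app.add_routes([web.post("/query", handle_query)])  # hypothetical route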
""" @@ -12,6 +12,7 @@ import decimal import ipaddress import json +import logging from typing import ( Any, Callable, @@ -33,6 +34,9 @@ from dl_type_transformer.native_type_schema import OneOfNativeTypeSchema +LOGGER = logging.getLogger(__name__) + + _TS_TV = TypeVar("_TS_TV") @@ -264,6 +268,25 @@ def from_jsonable(value: TJSONLike) -> GenericNativeType: return NativeTypeSerializer.schema.load(value) +class UnsupportedSerializer(TypeSerializer[object]): + """ + Special serializer that logs warning and dumps null + instead of an unserializable value + """ + + typename = "unsupported" + + @staticmethod + def to_jsonable(value: object) -> TJSONLike: + LOGGER.warning(f"Value of type {value.__class__.__name__} is not JSON serializable, skipping serialization") + return None + + @staticmethod + def from_jsonable(value: TJSONLike) -> object: + assert value is None + return None + + COMMON_SERIALIZERS: list[Type[TypeSerializer]] = [ DateSerializer, DatetimeSerializer, @@ -279,27 +302,36 @@ def from_jsonable(value: TJSONLike) -> GenericNativeType: IPv4InterfaceSerializer, IPv6InterfaceSerializer, NativeTypeSerializer, + UnsupportedSerializer, ] assert len(set(cls.typename for cls in COMMON_SERIALIZERS)) == len(COMMON_SERIALIZERS), "uniqueness check" -class RedisDatalensDataJSONEncoder(json.JSONEncoder): +class DataLensJSONEncoder(json.JSONEncoder): JSONABLERS_MAP = {cls.typeobj(): cls for cls in COMMON_SERIALIZERS} + def _get_preprocessor(self, typeobj: type) -> Optional[Type[TypeSerializer]]: + if issubclass(typeobj, GenericNativeType): + return NativeTypeSerializer + return self.JSONABLERS_MAP.get(typeobj) + def default(self, obj: Any) -> Any: typeobj = type(obj) - preprocessor: Optional[Type[TypeSerializer]] - if issubclass(typeobj, GenericNativeType): - preprocessor = NativeTypeSerializer - else: - preprocessor = self.JSONABLERS_MAP.get(typeobj) + preprocessor = self._get_preprocessor(typeobj) if preprocessor is not None: return dict(__dl_type__=preprocessor.typename, value=preprocessor.to_jsonable(obj)) return super().default(obj) # effectively, raises `TypeError` -class RedisDatalensDataJSONDecoder(json.JSONDecoder): +class SafeDataLensJSONEncoder(DataLensJSONEncoder): + def _get_preprocessor(self, typeobj: type) -> Optional[Type[TypeSerializer]]: + if (preprocessor := super()._get_preprocessor(typeobj)) is not None: + return preprocessor + return UnsupportedSerializer # don't raise a TypeError and log warning + + +class DataLensJSONDecoder(json.JSONDecoder): DEJSONABLERS_MAP = {cls.typename: cls for cls in COMMON_SERIALIZERS} def __init__( @@ -336,7 +368,7 @@ def object_hook(self, obj: dict[str, Any]) -> Any: def common_dumps(value: TJSONExt, **kwargs: Any) -> bytes: return json.dumps( value, - cls=RedisDatalensDataJSONEncoder, + cls=DataLensJSONEncoder, separators=(",", ":"), ensure_ascii=False, check_circular=False, # dangerous but faster @@ -344,26 +376,40 @@ def common_dumps(value: TJSONExt, **kwargs: Any) -> bytes: ).encode("utf-8") -def hashable_dumps(value: TJSONExt, sort_keys: bool = True, check_circular: bool = False, **kwargs: Any) -> str: +def hashable_dumps( + value: TJSONExt, + sort_keys: bool = True, + check_circular: bool = False, + ensure_ascii: bool = False, + **kwargs: Any, +) -> str: return json.dumps( value, - cls=RedisDatalensDataJSONEncoder, + cls=DataLensJSONEncoder, separators=(",", ":"), - ensure_ascii=False, + ensure_ascii=ensure_ascii, check_circular=check_circular, sort_keys=sort_keys, **kwargs, ) -def common_loads(value: Union[bytes, str], **kwargs: Any) -> 
diff --git a/lib/dl_model_tools/dl_model_tools_tests/unit/test_json_serializer.py b/lib/dl_model_tools/dl_model_tools_tests/unit/test_json_serializer.py
index b4ef9aaa9..7a885ed0d 100644
--- a/lib/dl_model_tools/dl_model_tools_tests/unit/test_json_serializer.py
+++ b/lib/dl_model_tools/dl_model_tools_tests/unit/test_json_serializer.py
@@ -4,12 +4,17 @@
 import decimal
 import ipaddress
 import json
+from typing import Any
 import uuid
 
+import pytest
+
 from dl_model_tools.serialization import (
-    RedisDatalensDataJSONDecoder,
-    RedisDatalensDataJSONEncoder,
+    common_dumps,
+    common_loads,
+    safe_dumps,
 )
+from dl_testing.utils import get_log_record
 from dl_type_transformer.native_type import (
     ClickHouseDateTime64NativeType,
     ClickHouseDateTime64WithTZNativeType,
@@ -22,7 +27,7 @@
 TZINFO = datetime.timezone(datetime.timedelta(seconds=-1320))
 
 
-SAMPLE_DATA = dict(
+SAMPLE_DATA: dict[str, Any] = dict(
     # Scalars
     some_int=42,
     some_float=19.89,
@@ -71,7 +76,7 @@
 )
 
 
-EXPECTED_DUMP = dict(
+EXPECTED_DUMP: dict[str, Any] = dict(
     # Scalars are unchanged
     some_int=42,
     some_float=19.89,
@@ -161,17 +166,37 @@
 
 def test_json_serialization():
     data = SAMPLE_DATA
-    dumped = json.dumps(data, cls=RedisDatalensDataJSONEncoder)
+    dumped = common_dumps(data)
     dumped_dict = json.loads(dumped)
     assert dumped_dict == EXPECTED_DUMP
-    roundtrip = json.loads(dumped, cls=RedisDatalensDataJSONDecoder)
+    roundtrip = common_loads(dumped)
    assert roundtrip == data
 
 
 def test_json_tricky_serialization():
     tricky_data = dict(normal=SAMPLE_DATA, abnormal=EXPECTED_DUMP)
-    tricky_data_dumped = json.dumps(tricky_data, cls=RedisDatalensDataJSONEncoder)
-    tricky_roundtrip = json.loads(tricky_data_dumped, cls=RedisDatalensDataJSONDecoder)
+    tricky_data_dumped = common_dumps(tricky_data)
+    tricky_roundtrip = common_loads(tricky_data_dumped)
     assert tricky_roundtrip["normal"] == tricky_data["normal"], tricky_roundtrip
     # abnormal data contains __dl_type__ fields, so decoder considers them to be dumps of BI types and decodes them
     assert tricky_roundtrip["abnormal"] == tricky_data["normal"], tricky_roundtrip
+
+
+class CustomType:
+    pass
+
+
+def test_safe_json_serialization(caplog):
+    unserializable_data: dict[str, Any] = SAMPLE_DATA | dict(unserializable=CustomType())
+    with pytest.raises(TypeError, match="Object of type CustomType is not JSON serializable"):
+        common_dumps(unserializable_data)
+
+    safe_dumped = safe_dumps(unserializable_data)
+    roundtrip = common_loads(safe_dumped)
+    unserializable_value = roundtrip.pop("unserializable")
+    assert unserializable_value is None
+    assert roundtrip == SAMPLE_DATA
+
+    log_record = get_log_record(caplog, predicate=lambda r: r.funcName == "to_jsonable", single=True)
+    assert log_record.levelname == "WARNING"
+    assert log_record.msg == "Value of type CustomType is not JSON serializable, skipping serialization"
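Note on what the new test exercises: the lossy-but-safe roundtrip drops the unserializable value to None while everything else survives intact. Condensed from the test into an interpreter-style snippet (the expected dump string is inferred from UnsupportedSerializer plus safe_dumps' compact separators and unsorted keys, assuming dict insertion order is preserved):

    from dl_model_tools.serialization import (
        common_loads,
        safe_dumps,
    )

    class CustomType:
        pass

    dumped = safe_dumps({"x": 1, "bad": CustomType()})
    # dumped == '{"x":1,"bad":{"__dl_type__":"unsupported","value":null}}'
    assert common_loads(dumped) == {"x": 1, "bad": None}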