fix: BI-5673 don't fail on invalid types in RQE JSON serialization (#599)

MCPN authored Sep 5, 2024
1 parent 274a5b8 commit d467d40
Showing 7 changed files with 113 additions and 50 deletions.
26 changes: 10 additions & 16 deletions lib/dl_api_lib/dl_api_lib/api_common/update_dataset_mutation_key.py
@@ -1,17 +1,13 @@
 from __future__ import annotations
 
 import hashlib
-import json
-from typing import (
-    Any,
-    List,
-)
 
 import attr
 
 from dl_api_lib.request_model.data import FieldAction
+from dl_constants.types import TJSONExt
 from dl_core.us_manager.mutation_cache.mutation_key_base import MutationKey
-from dl_model_tools.serialization import RedisDatalensDataJSONEncoder
+from dl_model_tools.serialization import hashable_dumps
 
 
 class MutationKeySerializationError(ValueError):
@@ -23,25 +19,23 @@ class UpdateDatasetMutationKey(MutationKey):
     _dumped: str
     _hash: str
 
-    def get_collision_tier_breaker(self) -> Any:
+    def get_collision_tier_breaker(self) -> str:
         return self._dumped
 
     def get_hash(self) -> str:
         return self._hash
 
+    @staticmethod
+    def _dumps(value: TJSONExt) -> str:
+        return hashable_dumps(value, sort_keys=True, ensure_ascii=True, check_circular=True)
+
     @classmethod
-    def create(cls, dataset_revision_id: str, updates: List[FieldAction]) -> UpdateDatasetMutationKey:
+    def create(cls, dataset_revision_id: str, updates: list[FieldAction]) -> UpdateDatasetMutationKey:
         try:
             serialized = [upd.serialized for upd in updates]
         except Exception as e:
             raise MutationKeySerializationError() from e
-        serialized.sort(key=lambda x: json.dumps(x, indent=None, sort_keys=True, cls=RedisDatalensDataJSONEncoder))
-        dumped = json.dumps(
-            dict(ds_rev=dataset_revision_id, mutation=serialized),
-            sort_keys=True,
-            indent=None,
-            separators=(",", ":"),
-            cls=RedisDatalensDataJSONEncoder,
-        )
+        serialized.sort(key=cls._dumps)
+        dumped = cls._dumps(dict(ds_rev=dataset_revision_id, mutation=serialized))
         hashed = hashlib.sha256(dumped.encode()).hexdigest()
         return UpdateDatasetMutationKey(dumped=dumped, hash=hashed)
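
For context, a minimal sketch of what the new `_dumps` helper buys: `hashable_dumps` sorts keys and emits compact ASCII-only JSON, so equal mutations always produce the same hash (assumes only that `dl_model_tools.serialization` is importable; the payload is made up):

import hashlib

from dl_model_tools.serialization import hashable_dumps

# Key order in the input no longer matters: sort_keys=True canonicalizes it.
payload = {"b": 2, "a": 1}
dumped = hashable_dumps(payload, sort_keys=True, ensure_ascii=True, check_circular=True)
assert dumped == '{"a":1,"b":2}'
print(hashlib.sha256(dumped.encode()).hexdigest())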
@@ -38,8 +38,8 @@
 )
 from dl_core.enums import QueryExecutorMode
 from dl_model_tools.serialization import (
-    RedisDatalensDataJSONDecoder,
-    RedisDatalensDataJSONEncoder,
+    DataLensJSONDecoder,
+    DataLensJSONEncoder,
 )
 from dl_utils.utils import get_type_full_name
 
@@ -82,13 +82,13 @@ def dump_conn_params(self, dba_query: DBAdapterQuery) -> Optional[dict]:
         )
         if conn_params is not None:
             for k, v in conn_params.items():
-                conn_params[k] = json.dumps(v, cls=RedisDatalensDataJSONEncoder)
+                conn_params[k] = json.dumps(v, cls=DataLensJSONEncoder)
         return conn_params
 
     def load_conn_params(self, conn_params: Optional[dict]) -> Optional[Dict[str, Any]]:
         if conn_params is not None:
             for k, v in conn_params.items():
-                conn_params[k] = json.loads(v, cls=RedisDatalensDataJSONDecoder)
+                conn_params[k] = json.loads(v, cls=DataLensJSONDecoder)
         return conn_params
 
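A minimal sketch of the per-key round trip these two helpers implement, using the encoder/decoder classes renamed in this commit (the sample params are hypothetical):

import json

from dl_model_tools.serialization import (
    DataLensJSONDecoder,
    DataLensJSONEncoder,
)

conn_params = {"host": "example.net", "connect_timeout": 0.5}
# Each value is JSON-encoded independently, so non-JSON types survive the trip.
dumped = {k: json.dumps(v, cls=DataLensJSONEncoder) for k, v in conn_params.items()}
restored = {k: json.loads(v, cls=DataLensJSONDecoder) for k, v in dumped.items()}
assert restored == conn_params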
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import argparse
-from functools import partial
 import logging
 import pickle
 import sys
@@ -59,7 +58,7 @@
 from dl_core.logging_config import configure_logging
 from dl_dashsql.typed_query.query_serialization import get_typed_query_serializer
 from dl_dashsql.typed_query.result_serialization import get_typed_query_result_serializer
-from dl_model_tools.serialization import hashable_dumps
+from dl_model_tools.serialization import safe_dumps
 from dl_utils.aio import ContextVarExecutor
 
@@ -173,8 +172,7 @@ async def handle_non_stream_query_action(
 
         with GenericProfiler("async_qe_serialization"):
             if self.request.headers.get(HEADER_USE_JSON_SERIALIZER) == "1":
-                dumps = partial(hashable_dumps, sort_keys=False, check_circular=True)
-                response = web.json_response(events, dumps=dumps)
+                response = web.json_response(events, dumps=safe_dumps)
             else:
                 response = web.Response(body=pickle.dumps(events))
 
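The handler change above swaps a `functools.partial` over `hashable_dumps` for `safe_dumps`, which already bakes in `sort_keys=False` and `check_circular=True`. A minimal sketch of the wiring, assuming aiohttp (the handler name and payload are made up):

from aiohttp import web

from dl_model_tools.serialization import safe_dumps

async def handle(request: web.Request) -> web.Response:
    events = [{"event": "row", "data": [1, 2, 3]}]  # hypothetical RQE payload
    # Unserializable values are dumped as null (with a warning) instead of raising.
    return web.json_response(events, dumps=safe_dumps)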
@@ -57,7 +57,7 @@
 from dl_dashsql.typed_query.result_serialization import get_typed_query_result_serializer
 from dl_model_tools.serialization import (
     common_loads,
-    hashable_dumps,
+    safe_dumps,
 )
 
 
@@ -298,7 +298,7 @@ def loads(self, s: str | bytes, **kwargs: Any) -> Any:
         return common_loads(s, **kwargs)
 
     def dumps(self, obj: Any, **kwargs: Any) -> str:
-        return hashable_dumps(obj, sort_keys=False, check_circular=True, **kwargs)
+        return safe_dumps(obj, **kwargs)
 
 
 def create_sync_app() -> flask.Flask:
@@ -9,7 +9,7 @@ def get_hash(self) -> str:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_collision_tier_breaker(self) -> Any:
+    def get_collision_tier_breaker(self) -> str:
         """Returns less collision-affected but serializable representation of key."""
         raise NotImplementedError()
 
76 changes: 61 additions & 15 deletions lib/dl_model_tools/dl_model_tools/serialization.py
@@ -1,5 +1,5 @@
 """
-RDL JSON - Redis DataLens JSON
+DataLens JSON serialization tools
 
 Serialization with support for custom objects like ``date`` & ``datetime``.
 """
@@ -12,6 +12,7 @@
 import decimal
 import ipaddress
 import json
+import logging
 from typing import (
     Any,
     Callable,
@@ -33,6 +34,9 @@
 from dl_type_transformer.native_type_schema import OneOfNativeTypeSchema
 
 
+LOGGER = logging.getLogger(__name__)
+
+
 _TS_TV = TypeVar("_TS_TV")
 
 
@@ -264,6 +268,25 @@ def from_jsonable(value: TJSONLike) -> GenericNativeType:
         return NativeTypeSerializer.schema.load(value)
 
 
+class UnsupportedSerializer(TypeSerializer[object]):
+    """
+    Special serializer that logs warning and dumps null
+    instead of an unserializable value
+    """
+
+    typename = "unsupported"
+
+    @staticmethod
+    def to_jsonable(value: object) -> TJSONLike:
+        LOGGER.warning(f"Value of type {value.__class__.__name__} is not JSON serializable, skipping serialization")
+        return None
+
+    @staticmethod
+    def from_jsonable(value: TJSONLike) -> object:
+        assert value is None
+        return None
+
+
 COMMON_SERIALIZERS: list[Type[TypeSerializer]] = [
     DateSerializer,
     DatetimeSerializer,
@@ -279,27 +302,36 @@ def from_jsonable(value: TJSONLike) -> GenericNativeType:
     IPv4InterfaceSerializer,
     IPv6InterfaceSerializer,
     NativeTypeSerializer,
+    UnsupportedSerializer,
 ]
 assert len(set(cls.typename for cls in COMMON_SERIALIZERS)) == len(COMMON_SERIALIZERS), "uniqueness check"
 
 
-class RedisDatalensDataJSONEncoder(json.JSONEncoder):
+class DataLensJSONEncoder(json.JSONEncoder):
     JSONABLERS_MAP = {cls.typeobj(): cls for cls in COMMON_SERIALIZERS}
 
+    def _get_preprocessor(self, typeobj: type) -> Optional[Type[TypeSerializer]]:
+        if issubclass(typeobj, GenericNativeType):
+            return NativeTypeSerializer
+        return self.JSONABLERS_MAP.get(typeobj)
+
     def default(self, obj: Any) -> Any:
         typeobj = type(obj)
-        preprocessor: Optional[Type[TypeSerializer]]
-        if issubclass(typeobj, GenericNativeType):
-            preprocessor = NativeTypeSerializer
-        else:
-            preprocessor = self.JSONABLERS_MAP.get(typeobj)
+        preprocessor = self._get_preprocessor(typeobj)
         if preprocessor is not None:
             return dict(__dl_type__=preprocessor.typename, value=preprocessor.to_jsonable(obj))
 
         return super().default(obj)  # effectively, raises `TypeError`
 
 
-class RedisDatalensDataJSONDecoder(json.JSONDecoder):
+class SafeDataLensJSONEncoder(DataLensJSONEncoder):
+    def _get_preprocessor(self, typeobj: type) -> Optional[Type[TypeSerializer]]:
+        if (preprocessor := super()._get_preprocessor(typeobj)) is not None:
+            return preprocessor
+        return UnsupportedSerializer  # don't raise a TypeError and log warning
+
+
+class DataLensJSONDecoder(json.JSONDecoder):
     DEJSONABLERS_MAP = {cls.typename: cls for cls in COMMON_SERIALIZERS}
 
     def __init__(
@@ -336,34 +368,48 @@ def object_hook(self, obj: dict[str, Any]) -> Any:
 def common_dumps(value: TJSONExt, **kwargs: Any) -> bytes:
     return json.dumps(
         value,
-        cls=RedisDatalensDataJSONEncoder,
+        cls=DataLensJSONEncoder,
         separators=(",", ":"),
         ensure_ascii=False,
         check_circular=False,  # dangerous but faster
         **kwargs,
     ).encode("utf-8")
 
 
-def hashable_dumps(value: TJSONExt, sort_keys: bool = True, check_circular: bool = False, **kwargs: Any) -> str:
+def hashable_dumps(
+    value: TJSONExt,
+    sort_keys: bool = True,
+    check_circular: bool = False,
+    ensure_ascii: bool = False,
+    **kwargs: Any,
+) -> str:
     return json.dumps(
         value,
-        cls=RedisDatalensDataJSONEncoder,
+        cls=DataLensJSONEncoder,
         separators=(",", ":"),
-        ensure_ascii=False,
+        ensure_ascii=ensure_ascii,
         check_circular=check_circular,
         sort_keys=sort_keys,
         **kwargs,
     )
 
 
-def common_loads(value: Union[bytes, str], **kwargs: Any) -> TJSONExt:
-    return json.loads(
+def safe_dumps(value: TJSONExt, **kwargs: Any) -> str:
+    return json.dumps(
         value,
-        cls=RedisDatalensDataJSONDecoder,
+        cls=SafeDataLensJSONEncoder,
+        separators=(",", ":"),
+        ensure_ascii=False,
+        check_circular=True,
+        sort_keys=False,
         **kwargs,
     )
 
 
+def common_loads(value: Union[bytes, str], **kwargs: Any) -> TJSONExt:
+    return json.loads(value, cls=DataLensJSONDecoder, **kwargs)
+
+
 class CacheMetadataSerialization:
     r"""
     Serialization/deserialization of metadata+data pair.
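
Taken together, a minimal sketch of the behavioral difference this file introduces (`CustomType` is a stand-in for any object without a registered serializer):

from dl_model_tools.serialization import common_dumps, common_loads, safe_dumps

class CustomType:
    pass

data = {"ok": 1, "bad": CustomType()}

# common_dumps still fails fast on unknown types:
#   TypeError: Object of type CustomType is not JSON serializable
# safe_dumps logs a warning and wraps the value as the "unsupported" type:
dumped = safe_dumps(data)  # '{"ok":1,"bad":{"__dl_type__":"unsupported","value":null}}'
assert common_loads(dumped) == {"ok": 1, "bad": None}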
@@ -4,12 +4,17 @@
 import decimal
 import ipaddress
 import json
+from typing import Any
 import uuid
 
 import pytest
 
 from dl_model_tools.serialization import (
-    RedisDatalensDataJSONDecoder,
-    RedisDatalensDataJSONEncoder,
+    common_dumps,
+    common_loads,
+    safe_dumps,
 )
+from dl_testing.utils import get_log_record
 from dl_type_transformer.native_type import (
     ClickHouseDateTime64NativeType,
     ClickHouseDateTime64WithTZNativeType,
@@ -22,7 +27,7 @@
 
 
 TZINFO = datetime.timezone(datetime.timedelta(seconds=-1320))
-SAMPLE_DATA = dict(
+SAMPLE_DATA: dict[str, Any] = dict(
     # Scalars
     some_int=42,
     some_float=19.89,
@@ -71,7 +76,7 @@
 )
 
 
-EXPECTED_DUMP = dict(
+EXPECTED_DUMP: dict[str, Any] = dict(
     # Scalars are unchanged
     some_int=42,
     some_float=19.89,
@@ -161,17 +166,37 @@
 
 def test_json_serialization():
     data = SAMPLE_DATA
-    dumped = json.dumps(data, cls=RedisDatalensDataJSONEncoder)
+    dumped = common_dumps(data)
     dumped_dict = json.loads(dumped)
     assert dumped_dict == EXPECTED_DUMP
-    roundtrip = json.loads(dumped, cls=RedisDatalensDataJSONDecoder)
+    roundtrip = common_loads(dumped)
     assert roundtrip == data
 
 
 def test_json_tricky_serialization():
     tricky_data = dict(normal=SAMPLE_DATA, abnormal=EXPECTED_DUMP)
-    tricky_data_dumped = json.dumps(tricky_data, cls=RedisDatalensDataJSONEncoder)
-    tricky_roundtrip = json.loads(tricky_data_dumped, cls=RedisDatalensDataJSONDecoder)
+    tricky_data_dumped = common_dumps(tricky_data)
+    tricky_roundtrip = common_loads(tricky_data_dumped)
     assert tricky_roundtrip["normal"] == tricky_data["normal"], tricky_roundtrip
     # abnormal data contains __dl_type__ fields, so decoder considers them to be dumps of BI types and decodes them
     assert tricky_roundtrip["abnormal"] == tricky_data["normal"], tricky_roundtrip
+
+
+class CustomType:
+    pass
+
+
+def test_safe_json_serialization(caplog):
+    unserializable_data: dict[str, Any] = SAMPLE_DATA | dict(unserializable=CustomType())
+    with pytest.raises(TypeError, match="Object of type CustomType is not JSON serializable"):
+        common_dumps(unserializable_data)
+
+    safe_dumped = safe_dumps(unserializable_data)
+    roundtrip = common_loads(safe_dumped)
+    unserializable_value = roundtrip.pop("unserializable")
+    assert unserializable_value is None
+    assert roundtrip == SAMPLE_DATA
+
+    log_record = get_log_record(caplog, predicate=lambda r: r.funcName == "to_jsonable", single=True)
+    assert log_record.levelname == "WARNING"
+    assert log_record.msg == "Value of type CustomType is not JSON serializable, skipping serialization"
