From aa99a84d8c13eaec0f1af5739200b2404f0aac54 Mon Sep 17 00:00:00 2001 From: LazeringDeath <94755334+LazeringDeath@users.noreply.github.com> Date: Wed, 24 Jul 2024 14:59:46 -0500 Subject: [PATCH] Replaced internal encoders with message encoder calling public API (#767) * Serialized all scalar data types for message_serializer with all tests passing. * Added array data types to message_serializer * Added enums and sub functions to message_serializer * Added messages and refactored message_serializer * Modified message_serializer to pass all the tests * Switched current serializer with message_serializer and removed sub functions calling it * Replaced test_message_serialzier with test_serializer * Fixed 'Mypy statis analysis' in message_serializer * Changed file names corresponding to it's functionality * Fixed naming issue * Changed 'test_serializer' to 'test_decoder' * Implemented encoder to reuse message types, renamed and reorder files * Fixed docstrings and reordered encoder. * [DRAFT] Message decoder (#780) Implemented decoder and moved helper functions to serialization_strategy. * Creates 2 messages per service, renamed message/fields, and reordered helper functions. * Deleted serialization_strategy with default_value * Deleted test_serializer * Deleted _message.py * Fixed type errors in test encoder/decoder and docstring * Renamed and cleaned helper functions, added initalize() in metadata. * Fixed sytleguide and type assignment. * Pass enum type in initialize() and moved it in ParameterMetadata * Changed EnumType to Enum * Changed isinstance() to type() in _create_enum_type_class * Add correct type hint to enum_type passing in initialize() and added helper functions for enums --- .../_internal/grpc_servicer.py | 56 ++-- .../_internal/parameter/_get_type.py | 36 +++ .../_internal/parameter/_message.py | 180 ------------ .../_internal/parameter/_serializer_types.py | 26 -- .../_internal/parameter/decoder.py | 66 +++++ .../_internal/parameter/encoder.py | 84 ++++++ .../_internal/parameter/metadata.py | 59 +++- .../parameter/serialization_descriptors.py | 122 ++++++++ .../parameter/serialization_strategy.py | 251 ----------------- .../_internal/parameter/serializer.py | 220 --------------- .../_internal/service_manager.py | 21 +- .../measurement/service.py | 6 +- .../{test_serializer.py => test_decoder.py} | 242 +++++----------- .../service/tests/unit/test_default_value.py | 30 ++ packages/service/tests/unit/test_encoder.py | 261 ++++++++++++++++++ .../tests/unit/test_serialization_strategy.py | 89 ------ 16 files changed, 778 insertions(+), 971 deletions(-) create mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py delete mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py delete mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py create mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py create mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py create mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py delete mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py delete mode 100644 packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py rename packages/service/tests/unit/{test_serializer.py => test_decoder.py} (68%) create mode 100644 packages/service/tests/unit/test_default_value.py create mode 100644 packages/service/tests/unit/test_encoder.py delete mode 100644 packages/service/tests/unit/test_serialization_strategy.py diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py index ae29db78e..ac78cbe3c 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/grpc_servicer.py @@ -13,8 +13,10 @@ import grpc from google.protobuf import any_pb2 -from ni_measurement_plugin_sdk_service._internal.parameter import serializer -from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ParameterMetadata +from ni_measurement_plugin_sdk_service._internal.parameter import decoder, encoder +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.measurement.v1 import ( measurement_service_pb2 as v1_measurement_service_pb2, measurement_service_pb2_grpc as v1_measurement_service_pb2_grpc, @@ -23,7 +25,10 @@ measurement_service_pb2 as v2_measurement_service_pb2, measurement_service_pb2_grpc as v2_measurement_service_pb2_grpc, ) -from ni_measurement_plugin_sdk_service.measurement.info import MeasurementInfo +from ni_measurement_plugin_sdk_service.measurement.info import ( + MeasurementInfo, + ServiceInfo, +) from ni_measurement_plugin_sdk_service.session_management import PinMapContext @@ -131,9 +136,13 @@ def _get_mapping_by_parameter_name( return mapping_by_variable_name -def _serialize_outputs(output_metadata: Dict[int, ParameterMetadata], outputs: Any) -> any_pb2.Any: +def _serialize_outputs( + output_metadata: Dict[int, ParameterMetadata], outputs: Any, service_name: str +) -> any_pb2.Any: if isinstance(outputs, collections.abc.Sequence): - return any_pb2.Any(value=serializer.serialize_parameters(output_metadata, outputs)) + return any_pb2.Any( + value=encoder.serialize_parameters(output_metadata, outputs, service_name) + ) elif outputs is None: raise ValueError(f"Measurement function returned None") else: @@ -161,6 +170,7 @@ def __init__( output_parameter_list: List[ParameterMetadata], measure_function: Callable, owner: object, + service_info: ServiceInfo, ) -> None: """Initialize the measurement v1 servicer.""" super().__init__() @@ -169,6 +179,7 @@ def __init__( self._measurement_info = measurement_info self._measure_function = measure_function self._owner = weakref.ref(owner) if owner is not None else None # avoid reference cycle + self._service_info = service_info def GetMetadata( # noqa: N802 - function name should be lowercase self, request: v1_measurement_service_pb2.GetMetadataRequest, context: grpc.ServicerContext @@ -193,8 +204,8 @@ def GetMetadata( # noqa: N802 - function name should be lowercase ) measurement_signature.configuration_parameters.append(configuration_parameter) - measurement_signature.configuration_defaults.value = serializer.serialize_default_values( - self._configuration_metadata + measurement_signature.configuration_defaults.value = encoder.serialize_default_values( + self._configuration_metadata, self._service_info.service_class + ".Configurations" ) for field_number, output_metadata in self._output_metadata.items(): @@ -224,8 +235,10 @@ def Measure( # noqa: N802 - function name should be lowercase self, request: v1_measurement_service_pb2.MeasureRequest, context: grpc.ServicerContext ) -> v1_measurement_service_pb2.MeasureResponse: """RPC API that executes the registered measurement method.""" - mapping_by_id = serializer.deserialize_parameters( - self._configuration_metadata, request.configuration_parameters.value + mapping_by_id = decoder.deserialize_parameters( + self._configuration_metadata, + request.configuration_parameters.value, + self._service_info.service_class + ".Configurations", ) mapping_by_variable_name = _get_mapping_by_parameter_name( mapping_by_id, self._measure_function @@ -252,9 +265,14 @@ def Measure( # noqa: N802 - function name should be lowercase measurement_service_context.get().mark_complete() measurement_service_context.reset(token) - def _serialize_response(self, outputs: Any) -> v1_measurement_service_pb2.MeasureResponse: + def _serialize_response( + self, + outputs: Any, + ) -> v1_measurement_service_pb2.MeasureResponse: return v1_measurement_service_pb2.MeasureResponse( - outputs=_serialize_outputs(self._output_metadata, outputs) + outputs=_serialize_outputs( + self._output_metadata, outputs, self._service_info.service_class + ".Outputs" + ) ) @@ -268,6 +286,7 @@ def __init__( output_parameter_list: List[ParameterMetadata], measure_function: Callable, owner: object, + service_info: ServiceInfo, ) -> None: """Initialize the measurement v2 servicer.""" super().__init__() @@ -276,6 +295,7 @@ def __init__( self._measurement_info = measurement_info self._measure_function = measure_function self._owner = weakref.ref(owner) if owner is not None else None # avoid reference cycle + self._service_info = service_info def GetMetadata( # noqa: N802 - function name should be lowercase self, request: v2_measurement_service_pb2.GetMetadataRequest, context: grpc.ServicerContext @@ -301,8 +321,8 @@ def GetMetadata( # noqa: N802 - function name should be lowercase ) measurement_signature.configuration_parameters.append(configuration_parameter) - measurement_signature.configuration_defaults.value = serializer.serialize_default_values( - self._configuration_metadata + measurement_signature.configuration_defaults.value = encoder.serialize_default_values( + self._configuration_metadata, self._service_info.service_class + ".Configurations" ) for field_number, output_metadata in self._output_metadata.items(): @@ -334,8 +354,10 @@ def Measure( # noqa: N802 - function name should be lowercase self, request: v2_measurement_service_pb2.MeasureRequest, context: grpc.ServicerContext ) -> Generator[v2_measurement_service_pb2.MeasureResponse, None, None]: """RPC API that executes the registered measurement method.""" - mapping_by_id = serializer.deserialize_parameters( - self._configuration_metadata, request.configuration_parameters.value + mapping_by_id = decoder.deserialize_parameters( + self._configuration_metadata, + request.configuration_parameters.value, + self._service_info.service_class + ".Configurations", ) mapping_by_variable_name = _get_mapping_by_parameter_name( mapping_by_id, self._measure_function @@ -363,5 +385,7 @@ def Measure( # noqa: N802 - function name should be lowercase def _serialize_response(self, outputs: Any) -> v2_measurement_service_pb2.MeasureResponse: return v2_measurement_service_pb2.MeasureResponse( - outputs=_serialize_outputs(self._output_metadata, outputs) + outputs=_serialize_outputs( + self._output_metadata, outputs, self._service_info.service_class + ".Outputs" + ) ) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py new file mode 100644 index 000000000..916680889 --- /dev/null +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_get_type.py @@ -0,0 +1,36 @@ +from typing import Any + +from google.protobuf.descriptor_pb2 import FieldDescriptorProto +from google.protobuf.type_pb2 import Field + +_TYPE_DEFAULT_MAPPING = { + Field.TYPE_FLOAT: float(), + Field.TYPE_DOUBLE: float(), + Field.TYPE_INT32: int(), + Field.TYPE_INT64: int(), + Field.TYPE_UINT32: int(), + Field.TYPE_UINT64: int(), + Field.TYPE_BOOL: bool(), + Field.TYPE_STRING: str(), + Field.TYPE_ENUM: int(), +} + +TYPE_FIELD_MAPPING = { + Field.TYPE_FLOAT: FieldDescriptorProto.TYPE_FLOAT, + Field.TYPE_DOUBLE: FieldDescriptorProto.TYPE_DOUBLE, + Field.TYPE_INT32: FieldDescriptorProto.TYPE_INT32, + Field.TYPE_INT64: FieldDescriptorProto.TYPE_INT64, + Field.TYPE_UINT32: FieldDescriptorProto.TYPE_UINT32, + Field.TYPE_UINT64: FieldDescriptorProto.TYPE_UINT64, + Field.TYPE_BOOL: FieldDescriptorProto.TYPE_BOOL, + Field.TYPE_STRING: FieldDescriptorProto.TYPE_STRING, + Field.TYPE_ENUM: FieldDescriptorProto.TYPE_ENUM, + Field.TYPE_MESSAGE: FieldDescriptorProto.TYPE_MESSAGE, +} + + +def get_type_default(type: Field.Kind.ValueType, repeated: bool) -> Any: + """Get the default value for the give type.""" + if repeated: + return list() + return _TYPE_DEFAULT_MAPPING.get(type) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py deleted file mode 100644 index dcc04b146..000000000 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_message.py +++ /dev/null @@ -1,180 +0,0 @@ -import struct -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from google.protobuf.internal import encoder, wire_format -from google.protobuf.message import Message - -from ni_measurement_plugin_sdk_service._internal.parameter._serializer_types import ( - Decoder, - Key, - NewDefault, - WriteFunction, -) - - -def _message_encoder_constructor( - field_index: int, is_repeated: bool, is_packed: bool -) -> Callable[[WriteFunction, Union[Message, List[Message]], bool], int]: - """Mimics google.protobuf.internal.MessageEncoder. - - This function was forked in order to call SerializeToString instead of _InternalSerialize. - - _InternalSerialize is only defined for the pure-Python protobuf implementation. Our child - messages (like DoubleXYData) are defined in .proto files, so they use whichever protobuf - implementation that google.protobuf.internal.api_implementation chooses (usually upb). - """ - tag = encoder.TagBytes(field_index, wire_format.WIRETYPE_LENGTH_DELIMITED) - encode_varint = _varint_encoder() - - if is_repeated: - - def _encode_repeated_message( - write: WriteFunction, value: Union[Message, List[Message]], deterministic: bool - ) -> int: - bytes_written = 0 - for element in cast(List[Message], value): - write(tag) - bytes = element.SerializeToString() - encode_varint(write, len(bytes), deterministic) - bytes_written += write(bytes) - return bytes_written - - return _encode_repeated_message - else: - - def _encode_message( - write: WriteFunction, value: Union[Message, List[Message]], deterministic: bool - ) -> int: - write(tag) - bytes = cast(Message, value).SerializeToString() - encode_varint(write, len(bytes), deterministic) - return write(bytes) - - return _encode_message - - -def _varint_encoder() -> Callable[[WriteFunction, int, Optional[bool]], int]: - """Return an encoder for a basic varint value (does not include tag). - - From google.protobuf.internal.encoder.py _VarintEncoder - """ - local_int2byte = struct.Struct(">B").pack - - def encode_varint( - write: WriteFunction, value: int, unused_deterministic: Optional[bool] = None - ) -> int: - bits = value & 0x7F - value >>= 7 - while value: - write(local_int2byte(0x80 | bits)) - bits = value & 0x7F - value >>= 7 - return write(local_int2byte(bits)) - - return encode_varint - - -def _message_decoder_constructor( - field_index: int, is_repeated: bool, is_packed: bool, key: Key, new_default: NewDefault -) -> Decoder: - """Mimics google.protobuf.internal.MessageDecoder. - - This function was forked in order to call ParseFromString instead of _InternalParse. - - _InternalParse is only defined for the pure-Python protobuf implementation. Our child messages - (like DoubleXYData) are defined in .proto files, so they use whichever protobuf implementation - that google.protobuf.internal.api_implementation chooses (usually upb). - """ - if is_repeated: - tag_bytes = encoder.TagBytes(field_index, wire_format.WIRETYPE_LENGTH_DELIMITED) - tag_len = len(tag_bytes) - - def _decode_repeated_message( - buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any] - ) -> int: - decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int) - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, []) - while 1: - parsed_value = new_default(message) - # Read length. - (size, pos) = decode_varint(buffer, pos) - new_pos = pos + size - if new_pos > end: - raise ValueError("Error decoding a message. Message is truncated.") - parsed_bytes = parsed_value.ParseFromString(buffer[pos:new_pos]) - if parsed_bytes != size: - raise ValueError("Parsed incorrect number of bytes.") - value.append(parsed_value) - # Predict that the next tag is another copy of the same repeated field. - pos = new_pos + tag_len - if buffer[new_pos:pos] != tag_bytes or new_pos == end: - # Prediction failed. Return. - return new_pos - - return _decode_repeated_message - else: - - def _decode_message( - buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any] - ) -> int: - decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int) - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, new_default(message)) - # Read length. - (size, pos) = decode_varint(buffer, pos) - new_pos = pos + size - if new_pos > end: - raise ValueError("Error decoding a message. Message is truncated.") - parsed_bytes = value.ParseFromString(buffer[pos:new_pos]) - if parsed_bytes != size: - raise ValueError("Parsed incorrect number of bytes.") - return new_pos - - return _decode_message - - -T = TypeVar("T", bound="int") - - -def _varint_decoder(mask: int, result_type: Type[T]) -> Callable[[memoryview, int], Tuple[T, int]]: - """Return an encoder for a basic varint value (does not include tag). - - Decoded values will be bitwise-anded with the given mask before being - returned, e.g. to limit them to 32 bits. The returned decoder does not - take the usual "end" parameter -- the caller is expected to do bounds checking - after the fact (often the caller can defer such checking until later). The - decoder returns a (value, new_pos) pair. - - From google.protobuf.internal.decoder.py _VarintDecoder - """ - - def decode_varint(buffer: memoryview, pos: int) -> Tuple[T, int]: - result = 0 - shift = 0 - while 1: - b = buffer[pos] - result |= (b & 0x7F) << shift - pos += 1 - if not (b & 0x80): - result &= mask - result = result_type(result) - return (result, pos) - shift += 7 - if shift >= 64: - raise ValueError("Too many bytes when decoding varint: {shift}") - - return decode_varint diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py deleted file mode 100644 index a34d66352..000000000 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/_serializer_types.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -import sys -import typing -from typing import Any, Callable, Dict - -from google.protobuf.descriptor import FieldDescriptor -from google.protobuf.message import Message - -if typing.TYPE_CHECKING: - if sys.version_info >= (3, 10): - from typing import TypeAlias - else: - from typing_extensions import TypeAlias - - -Key: TypeAlias = FieldDescriptor -WriteFunction: TypeAlias = Callable[[bytes], int] -Encoder: TypeAlias = Callable[[WriteFunction, bytes, bool], int] -PartialEncoderConstructor: TypeAlias = Callable[[int], Encoder] -EncoderConstructor: TypeAlias = Callable[[int, bool, bool], Encoder] - -Decoder: TypeAlias = Callable[[memoryview, int, int, Message, Dict[Key, Any]], int] -PartialDecoderConstructor: TypeAlias = Callable[[int, Key], Decoder] -NewDefault: TypeAlias = Callable[[Message], Message] -DecoderConstructor: TypeAlias = Callable[[int, bool, bool, Key, NewDefault], Decoder] diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py new file mode 100644 index 000000000..b11c7cb24 --- /dev/null +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/decoder.py @@ -0,0 +1,66 @@ +"""Parameter Serializer.""" + +from typing import Any, Dict + +from google.protobuf import descriptor_pool, message_factory +from google.protobuf.descriptor_pb2 import FieldDescriptorProto + +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) +from ni_measurement_plugin_sdk_service._internal.parameter.serialization_descriptors import ( + is_protobuf, +) + + +def deserialize_parameters( + parameter_metadata_dict: Dict[int, ParameterMetadata], + parameter_bytes: bytes, + service_name: str, +) -> Dict[int, Any]: + """Deserialize the bytes of the parameter based on the metadata. + + Args: + parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. + + parameter_bytes (bytes): Byte string to deserialize. + + service_name (str): Unique service name. + + Returns: + Dict[int, Any]: Deserialized parameters by ID. + """ + pool = descriptor_pool.Default() + message_proto = pool.FindMessageTypeByName(service_name) + message_instance = message_factory.GetMessageClass(message_proto)() + parameter_values = {} + + message_instance.ParseFromString(parameter_bytes) + for i in message_proto.fields_by_number.keys(): + parameter_metadata = parameter_metadata_dict[i] + field_name = parameter_metadata.field_name + value = getattr(message_instance, field_name) + + if parameter_metadata.type == FieldDescriptorProto.TYPE_ENUM: + parameter_values[i] = _deserialize_enum_parameter(value, parameter_metadata) + elif ( + parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE + and not parameter_metadata.repeated + and value.ByteSize() == 0 + ): + parameter_values[i] = None + else: + parameter_values[i] = value + return parameter_values + + +def _deserialize_enum_parameter(field_value: Any, metadata: ParameterMetadata) -> Any: + """Convert enum into their user defined enum type.""" + enum_type = metadata.enum_type + if is_protobuf(enum_type): + return field_value + + assert enum_type is not None + if metadata.repeated: + return [enum_type(value) for value in field_value] + return enum_type(field_value) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py new file mode 100644 index 000000000..0fd12ead9 --- /dev/null +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/encoder.py @@ -0,0 +1,84 @@ +"""Parameter Serializer.""" + +from enum import Enum +from typing import Any, Dict, Sequence + +from google.protobuf import descriptor_pool, message_factory +from google.protobuf.descriptor_pb2 import FieldDescriptorProto + +from ni_measurement_plugin_sdk_service._internal.parameter._get_type import ( + get_type_default, +) +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) + + +def serialize_parameters( + parameter_metadata_dict: Dict[int, ParameterMetadata], + parameter_values: Sequence[Any], + service_name: str, +) -> bytes: + """Serialize the parameter values in same order based on the metadata_dict. + + Args: + parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. + + parameter_values (Sequence[Any]): Parameter values to serialize. + + service_name (str): Unique service name. + + Returns: + bytes: Serialized byte string containing parameter values. + """ + pool = descriptor_pool.Default() + message_proto = pool.FindMessageTypeByName(service_name) + message_instance = message_factory.GetMessageClass(message_proto)() + + for i, parameter in enumerate(parameter_values, start=1): + parameter_metadata = parameter_metadata_dict[i] + field_name = parameter_metadata.field_name + parameter = _get_enum_values(param=parameter) + type_default_value = get_type_default(parameter_metadata.type, parameter_metadata.repeated) + + # Doesn't assign default values or None values to fields + if parameter != type_default_value and parameter is not None: + if parameter_metadata.repeated: + getattr(message_instance, field_name).extend(parameter) + elif parameter_metadata.type == FieldDescriptorProto.TYPE_MESSAGE: + getattr(message_instance, field_name).CopyFrom(parameter) + else: + setattr(message_instance, field_name, parameter) + return message_instance.SerializeToString() + + +def serialize_default_values( + parameter_metadata_dict: Dict[int, ParameterMetadata], service_name: str +) -> bytes: + """Serialize the Default values in the Metadata. + + Args: + parameter_metadata_dict (Dict[int, ParameterMetadata]): Configuration metadata. + + service_name (str): Unique service name. + + Returns: + bytes: Serialized byte string containing default values. + """ + default_value_parameter_array = [ + parameter.default_value for parameter in parameter_metadata_dict.values() + ] + return serialize_parameters( + parameter_metadata_dict, default_value_parameter_array, service_name + ) + + +def _get_enum_values(param: Any) -> Any: + """Get's value of an enum.""" + if param == []: + return param + if isinstance(param, list) and isinstance(param[0], Enum): + return [x.value for x in param] + elif isinstance(param, Enum): + return param.value + return param diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py index f6aff41c8..1abdadeaf 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/metadata.py @@ -1,8 +1,10 @@ """Contains classes that represents metadata.""" +from __future__ import annotations + import json from enum import Enum -from typing import Any, Dict, Iterable, NamedTuple +from typing import Any, Dict, Iterable, NamedTuple, Union, Type, Optional, TYPE_CHECKING from google.protobuf import type_pb2 @@ -10,10 +12,18 @@ ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter import serialization_strategy +from ni_measurement_plugin_sdk_service._internal.parameter._get_type import ( + get_type_default, +) from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization +if TYPE_CHECKING: + from google.protobuf.internal.enum_type_wrapper import _EnumTypeWrapper + + SupportedEnumType = Union[Type[Enum], _EnumTypeWrapper] + + class ParameterMetadata(NamedTuple): """Class that represents the metadata of parameters.""" @@ -40,6 +50,41 @@ class ParameterMetadata(NamedTuple): Required when 'type' is Kind.TypeMessage. Ignored for any other 'type'. """ + field_name: str = "" + """display_name in snake_case format.""" + + enum_type: Optional[SupportedEnumType] = None + """Enum type of parameter""" + + @staticmethod + def initialize( + display_name: str, + type: type_pb2.Field.Kind.ValueType, + repeated: bool, + default_value: Any, + annotations: Dict[str, str], + message_type: str = "", + enum_type: Optional[SupportedEnumType] = None, + ) -> "ParameterMetadata": + """Initialize ParameterMetadata with field_name.""" + underscore_display_name = display_name.replace(" ", "_") + if all(char.isalnum() or char == "_" for char in underscore_display_name): + field_name = underscore_display_name + else: + field_name = "".join( + char for char in underscore_display_name if char.isalnum() or char == "_" + ) + return ParameterMetadata( + display_name, + type, + repeated, + default_value, + annotations, + message_type, + field_name, + enum_type, + ) + def validate_default_value_type(parameter_metadata: ParameterMetadata) -> None: """Validate and raise exception if the default value does not match the type info. @@ -54,18 +99,12 @@ def validate_default_value_type(parameter_metadata: ParameterMetadata) -> None: if default_value is None: return None - expected_type = type( - serialization_strategy.get_type_default( - parameter_metadata.type, parameter_metadata.repeated - ) - ) + expected_type = type(get_type_default(parameter_metadata.type, parameter_metadata.repeated)) display_name = parameter_metadata.display_name enum_values_annotation = get_enum_values_annotation(parameter_metadata) if parameter_metadata.repeated: - expected_element_type = type( - serialization_strategy.get_type_default(parameter_metadata.type, False) - ) + expected_element_type = type(get_type_default(parameter_metadata.type, False)) _validate_default_value_type_for_repeated_type( default_value, expected_type, diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py new file mode 100644 index 000000000..9800f1199 --- /dev/null +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_descriptors.py @@ -0,0 +1,122 @@ +"""Serialization Descriptors.""" + +from __future__ import annotations + +from enum import Enum, EnumMeta +from json import loads +from typing import TYPE_CHECKING, List, Type, Union, Optional + +from google.protobuf.descriptor_pb2 import ( + DescriptorProto, + FieldDescriptorProto, + FileDescriptorProto, +) +from google.protobuf.descriptor_pool import DescriptorPool + +from ni_measurement_plugin_sdk_service._annotations import ENUM_VALUES_KEY +from ni_measurement_plugin_sdk_service._internal.parameter._get_type import ( + TYPE_FIELD_MAPPING, +) +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) + +if TYPE_CHECKING: + from google.protobuf.internal.enum_type_wrapper import _EnumTypeWrapper + + SupportedEnumType = Union[Type[Enum], _EnumTypeWrapper] + + +def is_protobuf(enum_type: Optional[SupportedEnumType]) -> bool: + """Finds if 'enum_type' is a protobuf or a python enum.""" + return hasattr(enum_type, "ValueType") + + +def _get_enum_type_name(metadata: ParameterMetadata) -> str: + """Get's enum type name from a 'parameter_metadata'.""" + enum_type = metadata.enum_type + if enum_type is None: + raise ValueError("Enum type cannot be None in ParameterMetadata.") + + if is_protobuf(enum_type) and not isinstance(enum_type, EnumMeta): + return enum_type.DESCRIPTOR.name + return enum_type.__name__ + + +def _create_enum_type_class( + file_descriptor: FileDescriptorProto, + metadata: ParameterMetadata, + field_descriptor: FieldDescriptorProto, +) -> None: + """Implement a enum class in 'file_descriptor'.""" + enum_dict = loads(metadata.annotations[ENUM_VALUES_KEY]) + enum_type_name = _get_enum_type_name(metadata) + + if enum_type_name not in [file_enum.name for file_enum in file_descriptor.enum_type]: + enum_descriptor = file_descriptor.enum_type.add() + enum_descriptor.name = enum_type_name + for name, number in enum_dict.items(): + enum_value_descriptor = enum_descriptor.value.add() + enum_value_descriptor.name = f"{enum_type_name}_{name}" + enum_value_descriptor.number = number + field_descriptor.type_name = enum_type_name + + +def _create_field( + message_proto: DescriptorProto, metadata: ParameterMetadata, index: int +) -> FieldDescriptorProto: + """Implement a field in 'message_proto'.""" + field_descriptor = message_proto.field.add() + field_descriptor.number = index + field_descriptor.name = metadata.field_name + field_descriptor.type = TYPE_FIELD_MAPPING[metadata.type] + + if metadata.repeated: + field_descriptor.label = FieldDescriptorProto.LABEL_REPEATED + field_descriptor.options.packed = True + else: + field_descriptor.label = FieldDescriptorProto.LABEL_OPTIONAL + + if metadata.type == FieldDescriptorProto.TYPE_MESSAGE: + field_descriptor.type_name = metadata.message_type + return field_descriptor + + +def _create_message_type( + parameter_metadata: List[ParameterMetadata], + message_name: str, + file_descriptor: FileDescriptorProto, +) -> None: + """Creates a message type with fields intialized in 'file_descriptor'.""" + message_proto = file_descriptor.message_type.add() + message_proto.name = message_name + + # Initialize the message with fields defined + for i, metadata in enumerate(parameter_metadata): + field_descriptor = _create_field( + message_proto=message_proto, metadata=metadata, index=i + 1 + ) + if metadata.type == FieldDescriptorProto.TYPE_ENUM: + _create_enum_type_class( + file_descriptor=file_descriptor, + metadata=metadata, + field_descriptor=field_descriptor, + ) + + +def create_file_descriptor( + service_name: str, + output_metadata: List[ParameterMetadata], + input_metadata: List[ParameterMetadata], + pool: DescriptorPool, +) -> None: + """Creates two message types in one file descriptor proto.""" + try: + pool.FindFileByName(service_name) + except KeyError: + file_descriptor = FileDescriptorProto() + file_descriptor.name = service_name + file_descriptor.package = service_name + _create_message_type(input_metadata, "Configurations", file_descriptor) + _create_message_type(output_metadata, "Outputs", file_descriptor) + pool.Add(file_descriptor) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py deleted file mode 100644 index 88fc5cd84..000000000 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serialization_strategy.py +++ /dev/null @@ -1,251 +0,0 @@ -"""Serialization Strategy.""" - -from __future__ import annotations - -from typing import Any, Optional, cast - -from google.protobuf import type_pb2 -from google.protobuf.internal import decoder, encoder -from google.protobuf.message import Message - -from ni_measurement_plugin_sdk_service._internal.parameter import _message -from ni_measurement_plugin_sdk_service._internal.parameter._serializer_types import ( - Decoder, - DecoderConstructor, - Encoder, - EncoderConstructor, - Key, - PartialDecoderConstructor, - PartialEncoderConstructor, -) -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 - - -def _scalar_encoder(encoder: EncoderConstructor) -> PartialEncoderConstructor: - """Constructs a scalar encoder constructor. - - Takes a field index and returns an Encoder. - - This class returns the Encoder with is_repeated set to False - and is_packed set to False. - """ - - def scalar_encoder(field_index: int) -> Encoder: - is_repeated = False - is_packed = False - return encoder(field_index, is_repeated, is_packed) - - return scalar_encoder - - -def _vector_encoder( - encoder: EncoderConstructor, is_packed: bool = True -) -> PartialEncoderConstructor: - """Constructs a vector (array) encoder constructor. - - Takes a field index and returns an Encoder. - - This class returns the Encoder with is_repeated set to True - and is_packed defaulting to True. - """ - - def vector_encoder(field_index: int) -> Encoder: - is_repeated = True - return encoder(field_index, is_repeated, is_packed) - - return vector_encoder - - -def _scalar_decoder(decoder: DecoderConstructor) -> PartialDecoderConstructor: - """Constructs a scalar decoder constructor. - - Takes a field index and a key and returns a Decoder. - - This class returns the Decoder with is_repeated set to False - and is_packed set to False. - """ - - def _unsupported_new_default(message: Optional[Message]) -> Any: - raise NotImplementedError( - "This function should not be called. Verify that you are using up-to-date and compatible versions of the ni-measurement-plugin-sdk-service and protobuf packages." - ) - - def scalar_decoder(field_index: int, key: Key) -> Decoder: - is_repeated = False - is_packed = False - return decoder(field_index, is_repeated, is_packed, key, _unsupported_new_default) - - return scalar_decoder - - -def _vector_decoder( - decoder: DecoderConstructor, is_packed: bool = True -) -> PartialDecoderConstructor: - """Constructs a vector (array) decoder constructor. - - Takes a field index and a key and returns a Decoder. - - This class returns the Decoder with is_repeated set to True - and is_packed defaulting to True. - """ - - def _new_default(unused_message: Optional[Message] = None) -> Any: - return [] - - def vector_decoder(field_index: int, key: Key) -> Decoder: - is_repeated = True - return decoder(field_index, is_repeated, is_packed, key, _new_default) - - return vector_decoder - - -def _double_xy_data_decoder( - decoder: DecoderConstructor, is_repeated: bool -) -> PartialDecoderConstructor: - """Constructs a DoubleXYData decoder constructor. - - Takes a field index and a key and returns a Decoder for DoubleXYData. - """ - - def _new_default(unused_message: Optional[Message] = None) -> Any: - return xydata_pb2.DoubleXYData() - - def message_decoder(field_index: int, key: Key) -> Decoder: - is_packed = True - return decoder(field_index, is_repeated, is_packed, key, _new_default) - - return message_decoder - - -# Cast works around this issue in typeshed -# https://github.com/python/typeshed/issues/10695 -FloatEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.FloatEncoder)) -DoubleEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.DoubleEncoder)) -IntEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.Int32Encoder)) -UIntEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.UInt32Encoder)) -BoolEncoder = _scalar_encoder(encoder.BoolEncoder) -StringEncoder = _scalar_encoder(encoder.StringEncoder) -MessageEncoder = _scalar_encoder(cast(EncoderConstructor, _message._message_encoder_constructor)) - -FloatArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.FloatEncoder)) -DoubleArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.DoubleEncoder)) -IntArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.Int32Encoder)) -UIntArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.UInt32Encoder)) -BoolArrayEncoder = _vector_encoder(encoder.BoolEncoder) -StringArrayEncoder = _vector_encoder(encoder.StringEncoder, is_packed=False) -MessageArrayEncoder = _vector_encoder( - cast(EncoderConstructor, _message._message_encoder_constructor) -) - -# Cast works around this issue in typeshed -# https://github.com/python/typeshed/issues/10697 -FloatDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.FloatDecoder)) -DoubleDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.DoubleDecoder)) -Int32Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.Int32Decoder)) -UInt32Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.UInt32Decoder)) -Int64Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.Int64Decoder)) -UInt64Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.UInt64Decoder)) -BoolDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.BoolDecoder)) -StringDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.StringDecoder)) -XYDataDecoder = _double_xy_data_decoder(_message._message_decoder_constructor, is_repeated=False) - -FloatArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.FloatDecoder)) -DoubleArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.DoubleDecoder)) -Int32ArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.Int32Decoder)) -UInt32ArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.UInt32Decoder)) -Int64ArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.Int64Decoder)) -UInt64ArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.UInt64Decoder)) -BoolArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.BoolDecoder)) -StringArrayDecoder = _vector_decoder( - cast(DecoderConstructor, decoder.StringDecoder), is_packed=False -) -XYDataArrayDecoder = _double_xy_data_decoder( - _message._message_decoder_constructor, is_repeated=True -) - - -_FIELD_TYPE_TO_ENCODER_MAPPING = { - type_pb2.Field.TYPE_FLOAT: (FloatEncoder, FloatArrayEncoder), - type_pb2.Field.TYPE_DOUBLE: (DoubleEncoder, DoubleArrayEncoder), - type_pb2.Field.TYPE_INT32: (IntEncoder, IntArrayEncoder), - type_pb2.Field.TYPE_INT64: (IntEncoder, IntArrayEncoder), - type_pb2.Field.TYPE_UINT32: (UIntEncoder, UIntArrayEncoder), - type_pb2.Field.TYPE_UINT64: (UIntEncoder, UIntArrayEncoder), - type_pb2.Field.TYPE_BOOL: (BoolEncoder, BoolArrayEncoder), - type_pb2.Field.TYPE_STRING: (StringEncoder, StringArrayEncoder), - type_pb2.Field.TYPE_ENUM: (IntEncoder, IntArrayEncoder), - type_pb2.Field.TYPE_MESSAGE: (MessageEncoder, MessageArrayEncoder), -} - -_FIELD_TYPE_TO_DECODER_MAPPING = { - type_pb2.Field.TYPE_FLOAT: (FloatDecoder, FloatArrayDecoder), - type_pb2.Field.TYPE_DOUBLE: (DoubleDecoder, DoubleArrayDecoder), - type_pb2.Field.TYPE_INT32: (Int32Decoder, Int32ArrayDecoder), - type_pb2.Field.TYPE_INT64: (Int64Decoder, Int64ArrayDecoder), - type_pb2.Field.TYPE_UINT32: (UInt32Decoder, UInt32ArrayDecoder), - type_pb2.Field.TYPE_UINT64: (UInt64Decoder, UInt64ArrayDecoder), - type_pb2.Field.TYPE_BOOL: (BoolDecoder, BoolArrayDecoder), - type_pb2.Field.TYPE_STRING: (StringDecoder, StringArrayDecoder), - type_pb2.Field.TYPE_ENUM: (Int32Decoder, Int32ArrayDecoder), -} - -_TYPE_DEFAULT_MAPPING = { - type_pb2.Field.TYPE_FLOAT: float(), - type_pb2.Field.TYPE_DOUBLE: float(), - type_pb2.Field.TYPE_INT32: int(), - type_pb2.Field.TYPE_INT64: int(), - type_pb2.Field.TYPE_UINT32: int(), - type_pb2.Field.TYPE_UINT64: int(), - type_pb2.Field.TYPE_BOOL: bool(), - type_pb2.Field.TYPE_STRING: str(), - type_pb2.Field.TYPE_ENUM: int(), -} - -_MESSAGE_TYPE_TO_DECODER = { - xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataDecoder, -} - -_ARRAY_MESSAGE_TYPE_TO_DECODER = { - xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataArrayDecoder, -} - - -def get_encoder(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> PartialEncoderConstructor: - """Get the appropriate partial encoder constructor for the specified type. - - A scalar or vector constructor is returned based on the 'repeated' parameter. - """ - if type not in _FIELD_TYPE_TO_ENCODER_MAPPING: - raise ValueError(f"Error can not encode type '{type}'") - scalar, array = _FIELD_TYPE_TO_ENCODER_MAPPING[type] - if repeated: - return array - return scalar - - -def get_decoder( - type: type_pb2.Field.Kind.ValueType, repeated: bool, message_type: str = "" -) -> PartialDecoderConstructor: - """Get the appropriate partial decoder constructor for the specified type.""" - decoder_mapping = _FIELD_TYPE_TO_DECODER_MAPPING.get(type) - if decoder_mapping is not None: - scalar_decoder, array_decoder = decoder_mapping - return array_decoder if repeated else scalar_decoder - elif type == type_pb2.Field.Kind.TYPE_MESSAGE: - if repeated: - decoder = _ARRAY_MESSAGE_TYPE_TO_DECODER.get(message_type) - else: - decoder = _MESSAGE_TYPE_TO_DECODER.get(message_type) - if decoder is None: - raise ValueError(f"Unknown message type '{message_type}'") - return decoder - else: - raise ValueError(f"Error can not decode type '{type}'") - - -def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any: - """Get the default value for the give type.""" - if repeated: - return list() - type_default_value = _TYPE_DEFAULT_MAPPING.get(type) - return type_default_value diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py deleted file mode 100644 index 71ef39568..000000000 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/parameter/serializer.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Parameter Serializer.""" - -from enum import Enum -from io import BytesIO -from typing import Any, Dict, Sequence, cast - -from google.protobuf.descriptor import FieldDescriptor -from google.protobuf.internal.decoder import ( # type: ignore[attr-defined] - _DecodeSignedVarint32, -) -from google.protobuf.message import Message - -from ni_measurement_plugin_sdk_service._annotations import ( - TYPE_SPECIALIZATION_KEY, -) -from ni_measurement_plugin_sdk_service._internal.parameter import serialization_strategy -from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( - ParameterMetadata, - get_enum_values_annotation, -) -from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization - -_GRPC_WIRE_TYPE_BIT_WIDTH = 3 - - -def deserialize_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_bytes: bytes -) -> Dict[int, Any]: - """Deserialize the bytes of the parameter based on the metadata. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. - - parameter_bytes (bytes): Byte string to deserialize. - - Returns: - Dict[int, Any]: Deserialized parameters by ID - """ - # Getting overlapping parameters - overlapping_parameter_by_id = _get_overlapping_parameters( - parameter_metadata_dict, parameter_bytes - ) - - # Deserialization enum parameters to their user-defined type - _deserialize_enum_parameters(parameter_metadata_dict, overlapping_parameter_by_id) - - # Adding missing parameters with type defaults - missing_parameters = _get_missing_parameters( - parameter_metadata_dict, overlapping_parameter_by_id - ) - overlapping_parameter_by_id.update(missing_parameters) - return overlapping_parameter_by_id - - -def serialize_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], - parameter_values: Sequence[Any], -) -> bytes: - """Serialize the parameter values in same order based on the metadata_dict. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. - - parameter_value (Sequence[Any]): Parameter values to serialize. - - Returns: - bytes: Serialized byte string containing parameter values. - """ - serialize_buffer = BytesIO() # inner_encoder updates the serialize_buffer - for i, parameter in enumerate(parameter_values): - parameter_metadata = parameter_metadata_dict[i + 1] - encoder = serialization_strategy.get_encoder( - parameter_metadata.type, - parameter_metadata.repeated, - ) - type_default_value = serialization_strategy.get_type_default( - parameter_metadata.type, - parameter_metadata.repeated, - ) - # Convert enum parameters to their underlying value if necessary. - if ( - parameter_metadata.annotations.get(TYPE_SPECIALIZATION_KEY) - == TypeSpecialization.Enum.value - ): - parameter = _get_enum_value(parameter, parameter_metadata.repeated) - # Skipping serialization if the value is None or if its a type default value. - if parameter is not None and parameter != type_default_value: - inner_encoder = encoder(i + 1) - inner_encoder(serialize_buffer.write, parameter, False) - return serialize_buffer.getvalue() - - -def serialize_default_values(parameter_metadata_dict: Dict[int, ParameterMetadata]) -> bytes: - """Serialize the Default values in the Metadata. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Configuration metadata. - - Returns: - bytes: Serialized byte string containing default values. - """ - default_value_parameter_array = list() - default_value_parameter_array = [ - parameter.default_value for parameter in parameter_metadata_dict.values() - ] - return serialize_parameters(parameter_metadata_dict, default_value_parameter_array) - - -def _get_overlapping_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_bytes: bytes -) -> Dict[int, Any]: - """Get the parameters present in both `parameter_metadata_dict` and `parameter_bytes`. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by ID. - - parameter_bytes (bytes): bytes of Parameter that need to be deserialized. - - Raises: - Exception: If the protobuf filed index is invalid. - - Returns: - Dict[int, Any]: Overlapping Parameters by ID. - """ - # inner_decoder update the overlapping_parameters - overlapping_parameters_by_id: Dict[int, Any] = {} - position = 0 - parameter_bytes_memory_view = memoryview(parameter_bytes) - while position < len(parameter_bytes): - (tag, position) = _DecodeSignedVarint32(parameter_bytes_memory_view, position) - field_index = tag >> _GRPC_WIRE_TYPE_BIT_WIDTH - if field_index not in parameter_metadata_dict: - raise Exception( - f"Error occurred while reading the parameter - given protobuf index '{field_index}' is invalid." - ) - field_metadata = parameter_metadata_dict[field_index] - decoder = serialization_strategy.get_decoder( - field_metadata.type, field_metadata.repeated, field_metadata.message_type - ) - inner_decoder = decoder(field_index, cast(FieldDescriptor, field_index)) - position = inner_decoder( - parameter_bytes_memory_view, - position, - len(parameter_bytes), - cast(Message, None), # unused - See serialization_strategy._vector_decoder._new_default - cast(Dict[FieldDescriptor, Any], overlapping_parameters_by_id), - ) - return overlapping_parameters_by_id - - -def _get_missing_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_by_id: Dict[int, Any] -) -> Dict[int, Any]: - """Get the Parameters defined in `parameter_metadata_dict` but not in `parameter_by_id`. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by id. - - parameter_by_id (Dict[int, Any]): Parameters by ID to compare the metadata with. - - Returns: - Dict[int, Any]: Missing parameter(as type defaults) by ID. - """ - missing_parameters = {} - for key, value in parameter_metadata_dict.items(): - if key not in parameter_by_id: - enum_annotations = get_enum_values_annotation(value) - if enum_annotations and not value.repeated: - enum_type = _get_enum_type(value) - missing_parameters[key] = enum_type(0) - else: - missing_parameters[key] = serialization_strategy.get_type_default( - value.type, value.repeated - ) - return missing_parameters - - -def _deserialize_enum_parameters( - parameter_metadata_dict: Dict[int, ParameterMetadata], parameter_by_id: Dict[int, Any] -) -> None: - """Converts all enums in `parameter_by_id` to the user defined enum type. - - Args: - parameter_metadata_dict (Dict[int, ParameterMetadata]): Parameter metadata by id. - - parameter_by_id (Dict[int, Any]): Parameters by ID to compare the metadata with. - """ - for id, value in parameter_by_id.items(): - parameter_metadata = parameter_metadata_dict[id] - if get_enum_values_annotation(parameter_metadata): - enum_type = _get_enum_type(parameter_metadata) - is_protobuf_enum = enum_type is int - if parameter_metadata.repeated: - for index, member_value in enumerate(value): - if is_protobuf_enum: - parameter_by_id[id][index] = member_value - else: - parameter_by_id[id][index] = enum_type(member_value) - else: - if is_protobuf_enum: - parameter_by_id[id] = value - else: - parameter_by_id[id] = enum_type(value) - - -def _get_enum_value(parameter: Any, repeated: bool) -> Any: - if repeated: - if len(parameter) > 0 and isinstance(parameter[0], Enum): - return [x.value for x in parameter] - else: - if isinstance(parameter, Enum): - return parameter.value - return parameter - - -def _get_enum_type(parameter_metadata: ParameterMetadata) -> type: - if parameter_metadata.repeated and len(parameter_metadata.default_value) > 0: - return type(parameter_metadata.default_value[0]) - else: - return type(parameter_metadata.default_value) diff --git a/packages/service/ni_measurement_plugin_sdk_service/_internal/service_manager.py b/packages/service/ni_measurement_plugin_sdk_service/_internal/service_manager.py index e6e5deebe..775c37542 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/_internal/service_manager.py +++ b/packages/service/ni_measurement_plugin_sdk_service/_internal/service_manager.py @@ -3,13 +3,19 @@ import grpc from deprecation import deprecated +from google.protobuf import descriptor_pool from grpc.framework.foundation import logging_pool from ni_measurement_plugin_sdk_service._internal.grpc_servicer import ( MeasurementServiceServicerV1, MeasurementServiceServicerV2, ) -from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ParameterMetadata +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) +from ni_measurement_plugin_sdk_service._internal.parameter.serialization_descriptors import ( + create_file_descriptor, +) from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.measurement.v1 import ( measurement_service_pb2_grpc as v1_measurement_service_pb2_grpc, ) @@ -18,7 +24,10 @@ ) from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient, ServiceLocation from ni_measurement_plugin_sdk_service.grpc.loggers import ServerLogger -from ni_measurement_plugin_sdk_service.measurement.info import MeasurementInfo, ServiceInfo +from ni_measurement_plugin_sdk_service.measurement.info import ( + MeasurementInfo, + ServiceInfo, +) _logger = logging.getLogger(__name__) _V1_INTERFACE = "ni.measurementlink.measurement.v1.MeasurementService" @@ -94,6 +103,12 @@ def start( ("grpc.max_send_message_length", -1), ], ) + create_file_descriptor( + service_name=service_info.service_class, + output_metadata=output_parameter_list, + input_metadata=configuration_parameter_list, + pool=descriptor_pool.Default(), + ) for interface in service_info.provided_interfaces: if interface == _V1_INTERFACE: servicer_v1 = MeasurementServiceServicerV1( @@ -102,6 +117,7 @@ def start( output_parameter_list, measure_function, owner, + service_info, ) v1_measurement_service_pb2_grpc.add_MeasurementServiceServicer_to_server( servicer_v1, self._server @@ -113,6 +129,7 @@ def start( output_parameter_list, measure_function, owner, + service_info, ) v2_measurement_service_pb2_grpc.add_MeasurementServiceServicer_to_server( servicer_v2, self._server diff --git a/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py b/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py index ab8718563..ff78e6b1f 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py +++ b/packages/service/ni_measurement_plugin_sdk_service/measurement/service.py @@ -415,13 +415,14 @@ def configuration( annotations = self._make_annotations_dict( data_type_info.type_specialization, instrument_type=instrument_type, enum_type=enum_type ) - parameter = parameter_metadata.ParameterMetadata( + parameter = parameter_metadata.ParameterMetadata.initialize( display_name, data_type_info.grpc_field_type, data_type_info.repeated, default_value, annotations, data_type_info.message_type, + enum_type, ) parameter_metadata.validate_default_value_type(parameter) self._configuration_parameter_list.append(parameter) @@ -475,13 +476,14 @@ def output( annotations = self._make_annotations_dict( data_type_info.type_specialization, enum_type=enum_type ) - parameter = parameter_metadata.ParameterMetadata( + parameter = parameter_metadata.ParameterMetadata.initialize( display_name, data_type_info.grpc_field_type, data_type_info.repeated, None, annotations, data_type_info.message_type, + enum_type, ) self._output_parameter_list.append(parameter) diff --git a/packages/service/tests/unit/test_serializer.py b/packages/service/tests/unit/test_decoder.py similarity index 68% rename from packages/service/tests/unit/test_serializer.py rename to packages/service/tests/unit/test_decoder.py index bb3655939..9a69ae8e9 100644 --- a/packages/service/tests/unit/test_serializer.py +++ b/packages/service/tests/unit/test_decoder.py @@ -1,21 +1,26 @@ """Contains tests to validate serializer.py.""" from enum import Enum, IntEnum -from typing import Dict, Sequence +from typing import Dict, List, Sequence import pytest -from google.protobuf import any_pb2, type_pb2 +from google.protobuf import any_pb2, descriptor_pb2, descriptor_pool, type_pb2 from ni_measurement_plugin_sdk_service._annotations import ( ENUM_VALUES_KEY, TYPE_SPECIALIZATION_KEY, ) -from ni_measurement_plugin_sdk_service._internal.parameter import serializer +from ni_measurement_plugin_sdk_service._internal.parameter import ( + decoder, + serialization_descriptors, +) from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( ParameterMetadata, TypeSpecialization, ) -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 +from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( + xydata_pb2, +) from tests.utilities.stubs.serialization import test_pb2 from tests.utilities.stubs.serialization.bigmessage_pb2 import BigMessage @@ -52,131 +57,6 @@ class Countries(IntEnum): BIG_MESSAGE_SIZE = 100 -@pytest.mark.parametrize( - "test_values", - [ - [ - 2.0, - 19.2, - 3, - 1, - 2, - 2, - True, - "TestString", - [5.5, 3.3, 1], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ], - [ - -0.9999, - -0.9999, - -13, - 1, - 1000, - 2, - True, - "////", - [5.5, -13.3, 1, 0.0, -99.9999], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ], - ], -) -def test___serializer___serialize_parameter___successful_serialization(test_values): - default_values = test_values - parameter = _get_test_parameter_by_id(default_values) - - # Custom Serialization - custom_serialized_bytes = serializer.serialize_parameters(parameter, test_values) - - _validate_serialized_bytes(custom_serialized_bytes, test_values) - - -@pytest.mark.parametrize( - "default_values", - [ - [ - 2.0, - 19.2, - 3, - 1, - 2, - 2, - True, - "TestString", - [5.5, 3.3, 1], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ], - [ - -0.9999, - -0.9999, - -13, - 1, - 1000, - 2, - True, - "////", - [5.5, -13.3, 1, 0.0, -99.9999], - [5.5, 3.3, 1], - [1, 2, 3, 4], - [0, 1, 399], - [1, 2, 3, 4], - [0, 1, 399], - [True, False, True], - ["String1, String2"], - DifferentColor.ORANGE, - [DifferentColor.TEAL, DifferentColor.BROWN], - Countries.AUSTRALIA, - [Countries.AUSTRALIA, Countries.CANADA], - double_xy_data, - double_xy_data_array, - ], - ], -) -def test___serializer___serialize_default_parameter___successful_serialization(default_values): - parameter = _get_test_parameter_by_id(default_values) - - # Custom Serialization - custom_serialized_bytes = serializer.serialize_default_values(parameter) - - _validate_serialized_bytes(custom_serialized_bytes, default_values) - - @pytest.mark.parametrize( "values", [ @@ -209,9 +89,13 @@ def test___serializer___serialize_default_parameter___successful_serialization(d def test___serializer___deserialize_parameter___successful_deserialization(values): parameter = _get_test_parameter_by_id(values) grpc_serialized_data = _get_grpc_serialized_data(values) + service_name = _test_create_file_descriptor(list(parameter.values()), "deserialize_parameter") - parameter_value_by_id = serializer.deserialize_parameters(parameter, grpc_serialized_data) - + parameter_value_by_id = decoder.deserialize_parameters( + parameter, + grpc_serialized_data, + service_name=service_name, + ) assert list(parameter_value_by_id.values()) == values @@ -243,7 +127,10 @@ def test___empty_buffer___deserialize_parameters___returns_zero_or_empty(): double_xy_data_array, ] parameter = _get_test_parameter_by_id(nonzero_defaults) - parameter_value_by_id = serializer.deserialize_parameters(parameter, bytes()) + service_name = _test_create_file_descriptor(list(parameter.values()), "empty_buffer") + parameter_value_by_id = decoder.deserialize_parameters( + parameter, bytes(), service_name=service_name + ) for key, value in parameter_value_by_id.items(): parameter_metadata = parameter[key] @@ -265,31 +152,19 @@ def test___big_message___deserialize_parameters___returns_parameter_value_by_id( message = _get_big_message(values) serialized_data = message.SerializeToString() expected_parameter_value_by_id = {i + 1: value for (i, value) in enumerate(values)} + service_name = _test_create_file_descriptor( + list(parameter_metadata_by_id.values()), "big_message" + ) - parameter_value_by_id = serializer.deserialize_parameters( - parameter_metadata_by_id, serialized_data + parameter_value_by_id = decoder.deserialize_parameters( + parameter_metadata_by_id, + serialized_data, + service_name=service_name, ) assert parameter_value_by_id == pytest.approx(expected_parameter_value_by_id) -def test___big_message___serialize_parameters___returns_serialized_data() -> None: - parameter_metadata_by_id = _get_big_message_metadata_by_id() - values = [123.456 + i for i in range(BIG_MESSAGE_SIZE)] - expected_message = _get_big_message(values) - - serialized_data = serializer.serialize_parameters(parameter_metadata_by_id, values) - - message = BigMessage.FromString(serialized_data) - assert message.ListFields() == pytest.approx(expected_message.ListFields()) - - -def _validate_serialized_bytes(custom_serialized_bytes, values): - # Serialization using gRPC Any - grpc_serialized_data = _get_grpc_serialized_data(values) - assert grpc_serialized_data == custom_serialized_bytes - - def _get_grpc_serialized_data(values): grpc_parameter = _get_test_grpc_message(values) parameter_any = any_pb2.Any() @@ -300,119 +175,119 @@ def _get_grpc_serialized_data(values): def _get_test_parameter_by_id(default_values): parameter_by_id = { - 1: ParameterMetadata( - display_name="float_data", + 1: ParameterMetadata.initialize( + display_name="float_data!", type=type_pb2.Field.TYPE_FLOAT, repeated=False, default_value=default_values[0], annotations={}, ), - 2: ParameterMetadata( + 2: ParameterMetadata.initialize( display_name="double_data", type=type_pb2.Field.TYPE_DOUBLE, repeated=False, default_value=default_values[1], annotations={}, ), - 3: ParameterMetadata( + 3: ParameterMetadata.initialize( display_name="int32_data", type=type_pb2.Field.TYPE_INT32, repeated=False, default_value=default_values[2], annotations={}, ), - 4: ParameterMetadata( + 4: ParameterMetadata.initialize( display_name="uint32_data", type=type_pb2.Field.TYPE_INT64, repeated=False, default_value=default_values[3], annotations={}, ), - 5: ParameterMetadata( + 5: ParameterMetadata.initialize( display_name="int64_data", type=type_pb2.Field.TYPE_UINT32, repeated=False, default_value=default_values[4], annotations={}, ), - 6: ParameterMetadata( + 6: ParameterMetadata.initialize( display_name="uint64_data", type=type_pb2.Field.TYPE_UINT64, repeated=False, default_value=default_values[5], annotations={}, ), - 7: ParameterMetadata( + 7: ParameterMetadata.initialize( display_name="bool_data", type=type_pb2.Field.TYPE_BOOL, repeated=False, default_value=default_values[6], annotations={}, ), - 8: ParameterMetadata( + 8: ParameterMetadata.initialize( display_name="string_data", type=type_pb2.Field.TYPE_STRING, repeated=False, default_value=default_values[7], annotations={}, ), - 9: ParameterMetadata( + 9: ParameterMetadata.initialize( display_name="double_array_data", type=type_pb2.Field.TYPE_DOUBLE, repeated=True, default_value=default_values[8], annotations={}, ), - 10: ParameterMetadata( + 10: ParameterMetadata.initialize( display_name="float_array_data", type=type_pb2.Field.TYPE_FLOAT, repeated=True, default_value=default_values[9], annotations={}, ), - 11: ParameterMetadata( + 11: ParameterMetadata.initialize( display_name="int32_array_data", type=type_pb2.Field.TYPE_INT32, repeated=True, default_value=default_values[10], annotations={}, ), - 12: ParameterMetadata( + 12: ParameterMetadata.initialize( display_name="uint32_array_data", type=type_pb2.Field.TYPE_UINT32, repeated=True, default_value=default_values[11], annotations={}, ), - 13: ParameterMetadata( + 13: ParameterMetadata.initialize( display_name="int64_array_data", type=type_pb2.Field.TYPE_INT64, repeated=True, default_value=default_values[12], annotations={}, ), - 14: ParameterMetadata( + 14: ParameterMetadata.initialize( display_name="uint64_array_data", type=type_pb2.Field.TYPE_UINT64, repeated=True, default_value=default_values[13], annotations={}, ), - 15: ParameterMetadata( + 15: ParameterMetadata.initialize( display_name="bool_array_data", type=type_pb2.Field.TYPE_BOOL, repeated=True, default_value=default_values[14], annotations={}, ), - 16: ParameterMetadata( + 16: ParameterMetadata.initialize( display_name="string_array_data", type=type_pb2.Field.TYPE_STRING, repeated=True, default_value=default_values[15], annotations={}, ), - 17: ParameterMetadata( + 17: ParameterMetadata.initialize( display_name="enum_data", type=type_pb2.Field.TYPE_ENUM, repeated=False, @@ -421,8 +296,9 @@ def _get_test_parameter_by_id(default_values): TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, ENUM_VALUES_KEY: '{"PURPLE": 0, "ORANGE": 1, "TEAL": 2, "BROWN": 3}', }, + enum_type=DifferentColor, ), - 18: ParameterMetadata( + 18: ParameterMetadata.initialize( display_name="enum_array_data", type=type_pb2.Field.TYPE_ENUM, repeated=True, @@ -431,8 +307,9 @@ def _get_test_parameter_by_id(default_values): TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, ENUM_VALUES_KEY: '{"PURPLE": 0, "ORANGE": 1, "TEAL": 2, "BROWN": 3}', }, + enum_type=DifferentColor, ), - 19: ParameterMetadata( + 19: ParameterMetadata.initialize( display_name="int_enum_data", type=type_pb2.Field.TYPE_ENUM, repeated=False, @@ -441,8 +318,9 @@ def _get_test_parameter_by_id(default_values): TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, ENUM_VALUES_KEY: '{"AMERICA": 0, "TAIWAN": 1, "AUSTRALIA": 2, "CANADA": 3}', }, + enum_type=Countries, ), - 20: ParameterMetadata( + 20: ParameterMetadata.initialize( display_name="int_enum_array_data", type=type_pb2.Field.TYPE_ENUM, repeated=True, @@ -451,8 +329,9 @@ def _get_test_parameter_by_id(default_values): TYPE_SPECIALIZATION_KEY: TypeSpecialization.Enum.value, ENUM_VALUES_KEY: '{"AMERICA": 0, "TAIWAN": 1, "AUSTRALIA": 2, "CANADA": 3}', }, + enum_type=Countries, ), - 21: ParameterMetadata( + 21: ParameterMetadata.initialize( display_name="xy_data", type=type_pb2.Field.TYPE_MESSAGE, repeated=False, @@ -460,7 +339,7 @@ def _get_test_parameter_by_id(default_values): annotations={}, message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, ), - 22: ParameterMetadata( + 22: ParameterMetadata.initialize( display_name="xy_data_array", type=type_pb2.Field.TYPE_MESSAGE, repeated=True, @@ -503,7 +382,7 @@ def _get_test_grpc_message(test_values): def _get_big_message_metadata_by_id() -> Dict[int, ParameterMetadata]: return { i - + 1: ParameterMetadata( + + 1: ParameterMetadata.initialize( display_name=f"field{i + 1}", type=type_pb2.Field.TYPE_DOUBLE, repeated=False, @@ -517,3 +396,16 @@ def _get_big_message_metadata_by_id() -> Dict[int, ParameterMetadata]: def _get_big_message(values: Sequence[float]) -> BigMessage: assert len(values) == BIG_MESSAGE_SIZE return BigMessage(**{f"field{i + 1}": value for (i, value) in enumerate(values)}) + + +def _test_create_file_descriptor(metadata: List[ParameterMetadata], file_name: str) -> str: + pool = descriptor_pool.Default() + try: + pool.FindMessageTypeByName(f"{file_name}.test") + except KeyError: + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.name = file_name + file_descriptor.package = file_name + serialization_descriptors._create_message_type(metadata, "test", file_descriptor) + pool.Add(file_descriptor) + return file_name + ".test" diff --git a/packages/service/tests/unit/test_default_value.py b/packages/service/tests/unit/test_default_value.py new file mode 100644 index 000000000..1be45f278 --- /dev/null +++ b/packages/service/tests/unit/test_default_value.py @@ -0,0 +1,30 @@ +"""Contains tests to validate the serializationstrategy.py. """ + +import pytest +from google.protobuf import type_pb2 + +from ni_measurement_plugin_sdk_service._internal.parameter import ( + _get_type, +) + + +@pytest.mark.parametrize( + "type,is_repeated,expected_default_value", + [ + (type_pb2.Field.TYPE_FLOAT, False, 0.0), + (type_pb2.Field.TYPE_DOUBLE, False, 0.0), + (type_pb2.Field.TYPE_INT32, False, 0), + (type_pb2.Field.TYPE_INT64, False, 0), + (type_pb2.Field.TYPE_UINT32, False, 0), + (type_pb2.Field.TYPE_UINT64, False, 0), + (type_pb2.Field.TYPE_BOOL, False, False), + (type_pb2.Field.TYPE_STRING, False, ""), + (type_pb2.Field.TYPE_ENUM, False, 0), + (type_pb2.Field.TYPE_MESSAGE, False, None), + (type_pb2.Field.TYPE_MESSAGE, True, []), + ], +) +def test___get_default_value___returns_type_defaults(type, is_repeated, expected_default_value): + test_default_value = _get_type.get_type_default(type, is_repeated) + + assert test_default_value == expected_default_value diff --git a/packages/service/tests/unit/test_encoder.py b/packages/service/tests/unit/test_encoder.py new file mode 100644 index 000000000..5ab159d25 --- /dev/null +++ b/packages/service/tests/unit/test_encoder.py @@ -0,0 +1,261 @@ +"""Contains tests to validate serializer.py.""" + +from enum import Enum, IntEnum +from typing import List + +import pytest +from google.protobuf import descriptor_pb2, descriptor_pool + +from ni_measurement_plugin_sdk_service._internal.parameter import ( + encoder, + metadata, + serialization_descriptors, +) +from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import ( + xydata_pb2, +) +from tests.unit.test_decoder import ( + _get_big_message, + _get_big_message_metadata_by_id, + _get_grpc_serialized_data, + _get_test_parameter_by_id, +) +from tests.utilities.stubs.serialization.bigmessage_pb2 import BigMessage + + +class DifferentColor(Enum): + """Non-primary colors used for testing enum-typed config and output.""" + + PURPLE = 0 + ORANGE = 1 + TEAL = 2 + BROWN = 3 + + +class Countries(IntEnum): + """Countries enum used for testing enum-typed config and output.""" + + AMERICA = 0 + TAIWAN = 1 + AUSTRALIA = 2 + CANADA = 3 + + +double_xy_data = xydata_pb2.DoubleXYData() +double_xy_data.x_data.append(4) +double_xy_data.y_data.append(6) + +double_xy_data2 = xydata_pb2.DoubleXYData() +double_xy_data2.x_data.append(8) +double_xy_data2.y_data.append(10) + +double_xy_data_array = [double_xy_data, double_xy_data2] + +# This should match the number of fields in bigmessage.proto. +BIG_MESSAGE_SIZE = 100 + + +@pytest.mark.parametrize( + "test_values", + [ + [ + 2.0, + 19.2, + 3, + 1, + 2, + 2, + True, + "TestString", + [5.5, 3.3, 1], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + [ + -0.9999, + -0.9999, + -13, + 1, + 1000, + 2, + True, + "", + [5.5, -13.3, 1, 0.0, -99.9999], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + ], +) +def test___serializer___serialize_parameter___successful_serialization(test_values): + default_values = test_values + parameter = _get_test_parameter_by_id(default_values) + service_name = _test_create_file_descriptor(list(parameter.values()), "serialize_parameter") + + # Custom Serialization + custom_serialized_bytes = encoder.serialize_parameters( + parameter, + test_values, + service_name=service_name, + ) + + _validate_serialized_bytes(custom_serialized_bytes, test_values) + + +@pytest.mark.parametrize( + "default_values", + [ + [ + 2.0, + 19.2, + 3, + 1, + 2, + 2, + True, + "TestString", + [5.5, 3.3, 1], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + [ + -0.9999, + -0.9999, + -13, + 1, + 1000, + 2, + False, + "////", + [5.5, -13.3, 1, 0.0, -99.9999], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + ], +) +def test___serializer___serialize_default_parameter___successful_serialization(default_values): + parameter = _get_test_parameter_by_id(default_values) + service_name = _test_create_file_descriptor(list(parameter.values()), "default_serialize") + + # Custom Serialization + custom_serialized_bytes = encoder.serialize_default_values(parameter, service_name=service_name) + + _validate_serialized_bytes(custom_serialized_bytes, default_values) + + +def test___big_message___serialize_parameters___returns_serialized_data() -> None: + parameter_metadata_by_id = _get_big_message_metadata_by_id() + values = [123.456 + i for i in range(BIG_MESSAGE_SIZE)] + expected_message = _get_big_message(values) + service_name = _test_create_file_descriptor( + list(parameter_metadata_by_id.values()), "big_message" + ) + + serialized_data = encoder.serialize_parameters( + parameter_metadata_by_id, + values, + service_name=service_name, + ) + + message = BigMessage.FromString(serialized_data) + assert message.ListFields() == pytest.approx(expected_message.ListFields()) + + +@pytest.mark.parametrize( + "test_values", + [ + [ + 2.0, + 19.2, + 3, + 1, + 2, + 2, + True, + "TestString", + [5.5, 3.3, 1], + [5.5, 3.3, 1], + [1, 2, 3, 4], + [0, 1, 399], + [1, 2, 3, 4], + [0, 1, 399], + [True, False, True], + ["String1, String2"], + DifferentColor.ORANGE, + [DifferentColor.TEAL, DifferentColor.BROWN], + Countries.AUSTRALIA, + [Countries.AUSTRALIA, Countries.CANADA], + double_xy_data, + double_xy_data_array, + ], + ], +) +def test___serialize_parameter_multiple_times___returns_one_message_type(test_values): + for i in range(100): + test___serializer___serialize_parameter___successful_serialization(test_values) + pool = descriptor_pool.Default() + file_descriptor = pool.FindFileByName("serialize_parameter") + message_dict = file_descriptor.message_types_by_name + assert len(message_dict) == 1 + + +def _validate_serialized_bytes(custom_serialized_bytes, values): + # Serialization using gRPC Any + grpc_serialized_data = _get_grpc_serialized_data(values) + assert grpc_serialized_data == custom_serialized_bytes + + +def _test_create_file_descriptor(metadata: List[metadata.ParameterMetadata], file_name: str) -> str: + pool = descriptor_pool.Default() + try: + pool.FindMessageTypeByName(f"{file_name}.test") + except KeyError: + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.name = file_name + file_descriptor.package = file_name + serialization_descriptors._create_message_type(metadata, "test", file_descriptor) + pool.Add(file_descriptor) + return file_name + ".test" diff --git a/packages/service/tests/unit/test_serialization_strategy.py b/packages/service/tests/unit/test_serialization_strategy.py deleted file mode 100644 index 41b2c4a81..000000000 --- a/packages/service/tests/unit/test_serialization_strategy.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Contains tests to validate the serializationstrategy.py. """ - -import pytest -from google.protobuf import type_pb2 - -from ni_measurement_plugin_sdk_service._internal.parameter import serialization_strategy -from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types import xydata_pb2 - - -@pytest.mark.parametrize( - "type,is_repeated,expected_encoder", - [ - (type_pb2.Field.TYPE_FLOAT, False, serialization_strategy.FloatEncoder), - (type_pb2.Field.TYPE_DOUBLE, False, serialization_strategy.DoubleEncoder), - (type_pb2.Field.TYPE_INT32, False, serialization_strategy.IntEncoder), - (type_pb2.Field.TYPE_INT64, False, serialization_strategy.IntEncoder), - (type_pb2.Field.TYPE_UINT32, False, serialization_strategy.UIntEncoder), - (type_pb2.Field.TYPE_UINT64, False, serialization_strategy.UIntEncoder), - (type_pb2.Field.TYPE_BOOL, False, serialization_strategy.BoolEncoder), - (type_pb2.Field.TYPE_STRING, False, serialization_strategy.StringEncoder), - (type_pb2.Field.TYPE_ENUM, False, serialization_strategy.IntEncoder), - (type_pb2.Field.TYPE_MESSAGE, False, serialization_strategy.MessageEncoder), - (type_pb2.Field.TYPE_MESSAGE, True, serialization_strategy.MessageArrayEncoder), - ], -) -def test___serialization_strategy___get_encoder___returns_expected_encoder( - type, is_repeated, expected_encoder -): - encoder = serialization_strategy.get_encoder(type, is_repeated) - - assert encoder == expected_encoder - - -@pytest.mark.parametrize( - "type,is_repeated,message_type,expected_decoder", - [ - (type_pb2.Field.TYPE_FLOAT, False, "", serialization_strategy.FloatDecoder), - (type_pb2.Field.TYPE_DOUBLE, False, "", serialization_strategy.DoubleDecoder), - (type_pb2.Field.TYPE_INT32, False, "", serialization_strategy.Int32Decoder), - (type_pb2.Field.TYPE_INT64, False, "", serialization_strategy.Int64Decoder), - (type_pb2.Field.TYPE_UINT32, False, "", serialization_strategy.UInt32Decoder), - (type_pb2.Field.TYPE_UINT64, False, "", serialization_strategy.UInt64Decoder), - (type_pb2.Field.TYPE_BOOL, False, "", serialization_strategy.BoolDecoder), - (type_pb2.Field.TYPE_STRING, False, "", serialization_strategy.StringDecoder), - (type_pb2.Field.TYPE_ENUM, False, "", serialization_strategy.Int32Decoder), - ( - type_pb2.Field.TYPE_MESSAGE, - False, - xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, - serialization_strategy.XYDataDecoder, - ), - ( - type_pb2.Field.TYPE_MESSAGE, - True, - xydata_pb2.DoubleXYData.DESCRIPTOR.full_name, - serialization_strategy.XYDataArrayDecoder, - ), - ], -) -def test___serialization_strategy___get_decoder___returns_expected_decoder( - type, is_repeated, message_type, expected_decoder -): - decoder = serialization_strategy.get_decoder(type, is_repeated, message_type) - - assert decoder == expected_decoder - - -@pytest.mark.parametrize( - "type,is_repeated,expected_default_value", - [ - (type_pb2.Field.TYPE_FLOAT, False, 0.0), - (type_pb2.Field.TYPE_DOUBLE, False, 0.0), - (type_pb2.Field.TYPE_INT32, False, 0), - (type_pb2.Field.TYPE_INT64, False, 0), - (type_pb2.Field.TYPE_UINT32, False, 0), - (type_pb2.Field.TYPE_UINT64, False, 0), - (type_pb2.Field.TYPE_BOOL, False, False), - (type_pb2.Field.TYPE_STRING, False, ""), - (type_pb2.Field.TYPE_ENUM, False, 0), - (type_pb2.Field.TYPE_MESSAGE, False, None), - (type_pb2.Field.TYPE_MESSAGE, True, []), - ], -) -def test___serialization_strategy___get_default_value___returns_type_defaults( - type, is_repeated, expected_default_value -): - default_value = serialization_strategy.get_type_default(type, is_repeated) - - assert default_value == expected_default_value