From 9075a075b331addad2781e50ae624cf17a9fe15e Mon Sep 17 00:00:00 2001
From: Nijat Khanbabayev
Date: Thu, 23 Jan 2025 21:01:54 -0500
Subject: [PATCH] Check can get json_schema for np arrays. Handle underscore
 attributes

Signed-off-by: Nijat Khanbabayev
---
 csp/impl/struct.py            |  39 ++++++------
 csp/tests/impl/test_struct.py | 108 ++++++++++++++++++++++++++++++++--
 csp/typing.py                 |   6 +-
 3 files changed, 125 insertions(+), 28 deletions(-)

diff --git a/csp/impl/struct.py b/csp/impl/struct.py
index d24e4e02..a85c765f 100644
--- a/csp/impl/struct.py
+++ b/csp/impl/struct.py
@@ -91,8 +91,9 @@ def _get_pydantic_core_schema(cls, _source_type, handler):
         from pydantic_core import core_schema
 
         fields = {}
-
         for field_name, field_type in cls.__full_metadata_typed__.items():
+            if field_name.startswith("_"):
+                continue  # we skip fields with underscore, like pydantic does
             try:
                 field_schema = handler.generate_schema(field_type)
             except PydanticSchemaGenerationError:  # for classes we dont have a schema for
@@ -102,43 +103,37 @@ def _get_pydantic_core_schema(cls, _source_type, handler):
                 field_schema = core_schema.with_default_schema(
                     schema=field_schema, default=cls.__defaults__[field_name]
                 )
-
             fields[field_name] = core_schema.typed_dict_field(
                 schema=field_schema,
                 required=False,  # Make all fields optional
             )
-
         # Schema for dictionary inputs
         fields_schema = core_schema.typed_dict_schema(
             fields=fields,
             total=False,  # Allow missing fields
-        )
-        # Schema for direct class instances
-        instance_schema = core_schema.is_instance_schema(cls)
-        # Use union schema to handle both cases
-        schema = core_schema.union_schema(
-            [
-                instance_schema,
-                fields_schema,
-            ]
+            extra_behavior="allow",  # let csp catch extra attributes, allows underscore fields to pass through
         )
 
-        def create_instance(validated_data):
+        def create_instance(raw_data, validator):
             # We choose to not revalidate, this is the default behavior in pydantic
-            if isinstance(validated_data, cls):
-                return validated_data
-
-            data_dict = validated_data[0] if isinstance(validated_data, tuple) else validated_data
-            return cls(**data_dict)
+            if isinstance(raw_data, cls):
+                return raw_data
+            try:
+                return cls(**validator(raw_data))
+            except AttributeError as e:
+                # Pydantic can't use AttributeError to check other classes, like in Union annotations
+                raise ValueError(str(e)) from None
 
         def serializer(val, handler):
-            # We don't use 'to_dict' since that works recursively
-            new_val = {k: getattr(val, k) for k in val.__full_metadata_typed__ if hasattr(val, k)}
+            # We don't use 'to_dict' since that works recursively, we ignore underscore leading fields
+            new_val = {
+                k: getattr(val, k) for k in val.__full_metadata_typed__ if not k.startswith("_") and hasattr(val, k)
+            }
             return handler(new_val)
 
-        return core_schema.no_info_after_validator_function(
+        return core_schema.no_info_wrap_validator_function(
             function=create_instance,
-            schema=schema,
+            schema=fields_schema,
             serialization=core_schema.wrap_serializer_function_ser_schema(
                 function=serializer, schema=fields_schema, when_used="always"
             ),
diff --git a/csp/tests/impl/test_struct.py b/csp/tests/impl/test_struct.py
index fbad755e..eba7ad94 100644
--- a/csp/tests/impl/test_struct.py
+++ b/csp/tests/impl/test_struct.py
@@ -3016,10 +3016,8 @@ class SimpleStruct(csp.Struct):
 
         invalid_data = valid_data.copy()
         invalid_data["missing"] = False
-        result_extra_attr = TypeAdapter(SimpleStruct).validate_python(
-            invalid_data
-        )  # this passes since we drop extra fields
-        self.assertEqual(result, result_extra_attr)
+        with self.assertRaises(ValidationError):
+            TypeAdapter(SimpleStruct).validate_python(invalid_data)  # extra fields throw an error
 
         # Test that we can validate existing structs
         existing = SimpleStruct(value=1, scores=[1])
@@ -3732,6 +3730,10 @@ class NestedStruct(csp.Struct):
         self.assertEqual(python_native, python_pydantic)
         self.assertEqual(enum_as_enum.name, enum_as_str)
 
+        self.assertEqual(
+            nested, TypeAdapter(NestedStruct).validate_python(TypeAdapter(NestedStruct).dump_python(nested))
+        )
+
         json_native = nested.to_json()
         json_pydantic = TypeAdapter(NestedStruct).dump_json(nested).decode()
         self.assertEqual(json.loads(json_native), json.loads(json_pydantic))
@@ -3751,6 +3753,104 @@ class NPStruct(csp.Struct):
         with self.assertRaises(ValidationError) as exc_info:
             TypeAdapter(NPStruct).validate_python(dict(arr=[1, 3, "ab"]))
         self.assertIn("could not convert string to float", str(exc_info.exception))
+        # We should be able to generate the json_schema
+        TypeAdapter(NPStruct).json_schema()
+
+    def test_struct_with_private_fields(self):
+        """Test CSP Struct with private (_) fields to ensure they're validated but excluded from serialization"""
+
+        class BaseMetric(csp.Struct):
+            _base_id: str
+            value: float
+
+        class MetricMetadata(BaseMetric):
+            _internal_id: str
+            public_tag: str
+            _inherited: bool = True
+
+        class MetricStruct(csp.Struct):
+            value: float
+            _confidence: float
+            metadata: MetricMetadata
+
+        class EventStruct(csp.Struct):
+            name: str
+            timestamp: datetime
+            _source: str = "system"
+
+        class DataPoint(csp.Struct):
+            id: str
+            data: Union[MetricStruct, EventStruct]
+            _last_updated: datetime
+
+        # Test validation with private fields
+        metric_data = {
+            "id": "metric-1",
+            "_last_updated": datetime(2023, 1, 1, 12, 0),  # not validated
+            "data": {
+                "value": 42.5,
+                "_confidence": 0.95,
+                "metadata": {
+                    "_base_id": "base123",
+                    "value": 99.9,
+                    "_internal_id": "internal123",
+                    "public_tag": "temperature",
+                    "_inherited": False,
+                },
+            },
+        }
+
+        result = TypeAdapter(DataPoint).validate_python(metric_data)
+
+        # Verify private fields are properly set including inherited ones
+        self.assertEqual(result._last_updated, datetime(2023, 1, 1, 12, 0))
+        self.assertEqual(result.data._confidence, 0.95)
+        self.assertEqual(result.data.metadata._base_id, "base123")
+        self.assertEqual(result.data.metadata._internal_id, "internal123")
+        self.assertEqual(result.data.metadata._inherited, False)
+        self.assertEqual(result.data.metadata.value, 99.9)
+
+        # Test serialization - private fields should be excluded, including inherited ones
+        serialized = TypeAdapter(DataPoint).dump_python(result)
+        self.assertNotIn("_last_updated", serialized)
+        self.assertNotIn("_confidence", serialized["data"])
+        self.assertNotIn("_base_id", serialized["data"]["metadata"])
+        self.assertNotIn("_internal_id", serialized["data"]["metadata"])
+        self.assertNotIn("_inherited", serialized["data"]["metadata"])
+        self.assertEqual(serialized["data"]["metadata"]["value"], 99.9)
+
+        # Verify JSON serialization also excludes private fields
+        json_data = json.loads(TypeAdapter(DataPoint).dump_json(result))
+        self.assertNotIn("_last_updated", json_data)
+        self.assertNotIn("_confidence", json_data["data"])
+        self.assertNotIn("_base_id", json_data["data"]["metadata"])
+        self.assertNotIn("_internal_id", json_data["data"]["metadata"])
+        self.assertNotIn("_inherited", json_data["data"]["metadata"])
+        self.assertEqual(json_data["data"]["metadata"]["value"], 99.9)
+
+        # Test that public fields are still included
+        self.assertEqual(json_data["data"]["metadata"]["public_tag"], "temperature")
+
+        # Test with event data
+        event_data = {
+            "id": "event-1",
+            "_last_updated": datetime(2023, 1, 1, 12, 0),  # not validated
+            "data": {
+                "name": "system_start",
+                "timestamp": "2023-01-01T12:00:00",  # validated
+                "_source": "automated_test",
+            },
+        }
+
+        result = TypeAdapter(DataPoint).validate_python(event_data)
+
+        # Verify private fields are set but excluded from serialization
+        self.assertEqual(result._last_updated, datetime(2023, 1, 1, 12, 0))
+        self.assertEqual(result.data._source, "automated_test")
+
+        json_data = json.loads(TypeAdapter(DataPoint).dump_json(result))
+        self.assertNotIn("_last_updated", json_data)
+        self.assertNotIn("_source", json_data["data"])
 
 
 if __name__ == "__main__":
diff --git a/csp/typing.py b/csp/typing.py
index e9bb70db..0ae89187 100644
--- a/csp/typing.py
+++ b/csp/typing.py
@@ -40,8 +40,9 @@ def _validate(v):
                 raise ValueError(f"dtype of array must be a subdtype of {source_args[0]}")
             return v
 
-        return core_schema.no_info_plain_validator_function(
+        return core_schema.no_info_before_validator_function(
             _validate,
+            core_schema.any_schema(),
             serialization=core_schema.wrap_serializer_function_ser_schema(
                 lambda val, handler: handler(val if val is None else val.tolist()),
                 info_arg=False,
@@ -71,8 +72,9 @@ def _validate(v):
                 raise ValueError("array must be one dimensional")
             return v
 
-        return core_schema.no_info_plain_validator_function(
+        return core_schema.no_info_before_validator_function(
             _validate,
+            core_schema.any_schema(),
             serialization=core_schema.wrap_serializer_function_ser_schema(
                 lambda val, handler: handler(val if val is None else val.tolist()),
                 info_arg=False,