
Check can get json_schema for np arrays. Handle underscore attributes
Signed-off-by: Nijat Khanbabayev <[email protected]>
NeejWeej committed Jan 24, 2025
1 parent c941c98 commit 9075a07
Showing 3 changed files with 125 additions and 28 deletions.
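For orientation, a hedged sketch of the end-to-end behavior these changes target — this is not code from the commit, and the struct name, field names, and the Numpy1DArray import path are illustrative assumptions:

import csp
from csp.typing import Numpy1DArray  # assumed import path for csp's numpy array annotation
from pydantic import TypeAdapter


class MyStruct(csp.Struct):
    value: float
    _internal: str = "hidden"  # leading underscore: accepted on input, excluded from dumps
    arr: Numpy1DArray[float]


adapter = TypeAdapter(MyStruct)
s = adapter.validate_python({"value": 1.0, "_internal": "note", "arr": [1.0, 2.0]})
assert "_internal" not in adapter.dump_python(s)  # underscore fields are skipped when serializing
adapter.json_schema()  # schema generation should now also work with numpy-array fields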
39 changes: 17 additions & 22 deletions csp/impl/struct.py
@@ -91,8 +91,9 @@ def _get_pydantic_core_schema(cls, _source_type, handler):
from pydantic_core import core_schema

fields = {}

for field_name, field_type in cls.__full_metadata_typed__.items():
if field_name.startswith("_"):
continue # we skip fields with underscore, like pydantic does
try:
field_schema = handler.generate_schema(field_type)
except PydanticSchemaGenerationError: # for classes we dont have a schema for
@@ -102,43 +103,37 @@ def _get_pydantic_core_schema(cls, _source_type, handler):
field_schema = core_schema.with_default_schema(
schema=field_schema, default=cls.__defaults__[field_name]
)

fields[field_name] = core_schema.typed_dict_field(
schema=field_schema,
required=False, # Make all fields optional
)

# Schema for dictionary inputs
fields_schema = core_schema.typed_dict_schema(
fields=fields,
total=False, # Allow missing fields
)
# Schema for direct class instances
instance_schema = core_schema.is_instance_schema(cls)
# Use union schema to handle both cases
schema = core_schema.union_schema(
[
instance_schema,
fields_schema,
]
extra_behavior="allow", # let csp catch extra attributes, allows underscore fields to pass through
)

def create_instance(validated_data):
def create_instance(raw_data, validator):
# We choose to not revalidate, this is the default behavior in pydantic
if isinstance(validated_data, cls):
return validated_data

data_dict = validated_data[0] if isinstance(validated_data, tuple) else validated_data
return cls(**data_dict)
if isinstance(raw_data, cls):
return raw_data
try:
return cls(**validator(raw_data))
except AttributeError as e:
# Pydantic can't use AttributeError to check other classes, like in Union annotations
raise ValueError(str(e)) from None

def serializer(val, handler):
# We don't use 'to_dict' since that works recursively
new_val = {k: getattr(val, k) for k in val.__full_metadata_typed__ if hasattr(val, k)}
# We don't use 'to_dict' since that works recursively, we ignore underscore leading fields
new_val = {
k: getattr(val, k) for k in val.__full_metadata_typed__ if not k.startswith("_") and hasattr(val, k)
}
return handler(new_val)

return core_schema.no_info_after_validator_function(
return core_schema.no_info_wrap_validator_function(
function=create_instance,
schema=schema,
schema=fields_schema,
serialization=core_schema.wrap_serializer_function_ser_schema(
function=serializer, schema=fields_schema, when_used="always"
),
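For context on the validator change above (no_info_after_validator_function to no_info_wrap_validator_function): a wrap validator receives the raw input plus a handler, so existing instances can be returned without revalidation while dict inputs are still run through the typed-dict schema. A minimal, self-contained sketch of that pattern, using a hypothetical Point class rather than csp code:

from pydantic import TypeAdapter
from pydantic_core import core_schema


class Point:
    def __init__(self, x: int = 0, y: int = 0):
        self.x = x
        self.y = y

    @classmethod
    def __get_pydantic_core_schema__(cls, _source_type, _handler):
        # Dict inputs are validated against this typed-dict schema
        fields_schema = core_schema.typed_dict_schema(
            {
                "x": core_schema.typed_dict_field(core_schema.int_schema(), required=False),
                "y": core_schema.typed_dict_field(core_schema.int_schema(), required=False),
            },
            total=False,
        )

        def create_instance(raw, validator):
            if isinstance(raw, cls):
                return raw  # short-circuit: do not revalidate existing instances
            return cls(**validator(raw))  # run the typed-dict schema, then construct

        return core_schema.no_info_wrap_validator_function(create_instance, fields_schema)


adapter = TypeAdapter(Point)
p = adapter.validate_python({"x": 1, "y": 2})  # dict -> validated -> Point
assert adapter.validate_python(p) is p         # existing instance passes through untouched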
108 changes: 104 additions & 4 deletions csp/tests/impl/test_struct.py
@@ -3016,10 +3016,8 @@ class SimpleStruct(csp.Struct):

invalid_data = valid_data.copy()
invalid_data["missing"] = False
result_extra_attr = TypeAdapter(SimpleStruct).validate_python(
invalid_data
) # this passes since we drop extra fields
self.assertEqual(result, result_extra_attr)
with self.assertRaises(ValidationError):
TypeAdapter(SimpleStruct).validate_python(invalid_data) # extra fields throw an error

# Test that we can validate existing structs
existing = SimpleStruct(value=1, scores=[1])
@@ -3732,6 +3730,10 @@ class NestedStruct(csp.Struct):
self.assertEqual(python_native, python_pydantic)
self.assertEqual(enum_as_enum.name, enum_as_str)

self.assertEqual(
nested, TypeAdapter(NestedStruct).validate_python(TypeAdapter(NestedStruct).dump_python(nested))
)

json_native = nested.to_json()
json_pydantic = TypeAdapter(NestedStruct).dump_json(nested).decode()
self.assertEqual(json.loads(json_native), json.loads(json_pydantic))
@@ -3751,6 +3753,104 @@ class NPStruct(csp.Struct):
with self.assertRaises(ValidationError) as exc_info:
TypeAdapter(NPStruct).validate_python(dict(arr=[1, 3, "ab"]))
self.assertIn("could not convert string to float", str(exc_info.exception))
# We should be able to generate the json_schema
TypeAdapter(NPStruct).json_schema()

def test_struct_with_private_fields(self):
"""Test CSP Struct with private (_) fields to ensure they're validated but excluded from serialization"""

class BaseMetric(csp.Struct):
_base_id: str
value: float

class MetricMetadata(BaseMetric):
_internal_id: str
public_tag: str
_inherited: bool = True

class MetricStruct(csp.Struct):
value: float
_confidence: float
metadata: MetricMetadata

class EventStruct(csp.Struct):
name: str
timestamp: datetime
_source: str = "system"

class DataPoint(csp.Struct):
id: str
data: Union[MetricStruct, EventStruct]
_last_updated: datetime

# Test validation with private fields
metric_data = {
"id": "metric-1",
"_last_updated": datetime(2023, 1, 1, 12, 0), # not validated
"data": {
"value": 42.5,
"_confidence": 0.95,
"metadata": {
"_base_id": "base123",
"value": 99.9,
"_internal_id": "internal123",
"public_tag": "temperature",
"_inherited": False,
},
},
}

result = TypeAdapter(DataPoint).validate_python(metric_data)

# Verify private fields are properly set including inherited ones
self.assertEqual(result._last_updated, datetime(2023, 1, 1, 12, 0))
self.assertEqual(result.data._confidence, 0.95)
self.assertEqual(result.data.metadata._base_id, "base123")
self.assertEqual(result.data.metadata._internal_id, "internal123")
self.assertEqual(result.data.metadata._inherited, False)
self.assertEqual(result.data.metadata.value, 99.9)

# Test serialization - private fields should be excluded, including inherited ones
serialized = TypeAdapter(DataPoint).dump_python(result)
self.assertNotIn("_last_updated", serialized)
self.assertNotIn("_confidence", serialized["data"])
self.assertNotIn("_base_id", serialized["data"]["metadata"])
self.assertNotIn("_internal_id", serialized["data"]["metadata"])
self.assertNotIn("_inherited", serialized["data"]["metadata"])
self.assertEqual(serialized["data"]["metadata"]["value"], 99.9)

# Verify JSON serialization also excludes private fields
json_data = json.loads(TypeAdapter(DataPoint).dump_json(result))
self.assertNotIn("_last_updated", json_data)
self.assertNotIn("_confidence", json_data["data"])
self.assertNotIn("_base_id", json_data["data"]["metadata"])
self.assertNotIn("_internal_id", json_data["data"]["metadata"])
self.assertNotIn("_inherited", json_data["data"]["metadata"])
self.assertEqual(json_data["data"]["metadata"]["value"], 99.9)

# Test that public fields are still included
self.assertEqual(json_data["data"]["metadata"]["public_tag"], "temperature")

# Test with event data
event_data = {
"id": "event-1",
"_last_updated": datetime(2023, 1, 1, 12, 0), # not validated
"data": {
"name": "system_start",
"timestamp": "2023-01-01T12:00:00", # validated
"_source": "automated_test",
},
}

result = TypeAdapter(DataPoint).validate_python(event_data)

# Verify private fields are set but excluded from serialization
self.assertEqual(result._last_updated, datetime(2023, 1, 1, 12, 0))
self.assertEqual(result.data._source, "automated_test")

json_data = json.loads(TypeAdapter(DataPoint).dump_json(result))
self.assertNotIn("_last_updated", json_data)
self.assertNotIn("_source", json_data["data"])


if __name__ == "__main__":
6 changes: 4 additions & 2 deletions csp/typing.py
@@ -40,8 +40,9 @@ def _validate(v):
raise ValueError(f"dtype of array must be a subdtype of {source_args[0]}")
return v

return core_schema.no_info_plain_validator_function(
return core_schema.no_info_before_validator_function(
_validate,
core_schema.any_schema(),
serialization=core_schema.wrap_serializer_function_ser_schema(
lambda val, handler: handler(val if val is None else val.tolist()),
info_arg=False,
@@ -71,8 +72,9 @@ def _validate(v):
raise ValueError("array must be one dimensional")
return v

return core_schema.no_info_plain_validator_function(
return core_schema.no_info_before_validator_function(
_validate,
core_schema.any_schema(),
serialization=core_schema.wrap_serializer_function_ser_schema(
lambda val, handler: handler(val if val is None else val.tolist()),
info_arg=False,
