Add serialization and type coercion to csp numpy array types
Signed-off-by: Nijat Khanbabayev <[email protected]>
NeejWeej committed Jan 22, 2025
1 parent c1acca8 commit c941c98
Showing 3 changed files with 66 additions and 7 deletions.
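As a quick illustration of what this change enables, here is a minimal sketch pieced together from the tests below (the NPStruct name mirrors the struct defined in test_pydantic_np_arr and is otherwise illustrative): Numpy1DArray and NumpyNDArray fields now coerce compatible inputs to the declared dtype during pydantic validation, and serialize to plain JSON lists.

import json

import numpy as np
from pydantic import TypeAdapter

import csp
from csp.typing import Numpy1DArray


class NPStruct(csp.Struct):
    arr: Numpy1DArray[float] = np.array([])


ta = TypeAdapter(NPStruct)

# Validation coerces compatible inputs (plain lists, int or float32 arrays)
# to the declared float dtype instead of rejecting them.
val = ta.validate_python({"arr": [1, 2]})
assert val.arr.dtype == np.float64

# Serialization emits the array as a JSON list.
assert json.loads(ta.dump_json(val)) == {"arr": [1.0, 2.0]}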
24 changes: 24 additions & 0 deletions csp/tests/impl/test_struct.py
@@ -13,6 +13,7 @@
import csp
from csp.impl.struct import define_nested_struct, define_struct, defineNestedStruct, defineStruct
from csp.impl.types.typing_utils import FastList
from csp.typing import Numpy1DArray


class MyEnum(csp.Enum):
@@ -3013,6 +3014,13 @@ class SimpleStruct(csp.Struct):
self.assertEqual(result.name, "ya")
self.assertEqual(result.scores, [1.1, 2.2, 3.3])

invalid_data = valid_data.copy()
invalid_data["missing"] = False
result_extra_attr = TypeAdapter(SimpleStruct).validate_python(
invalid_data
) # this passes since we drop extra fields
self.assertEqual(result, result_extra_attr)

# Test that we can validate existing structs
existing = SimpleStruct(value=1, scores=[1])
new = TypeAdapter(SimpleStruct).validate_python(existing)
@@ -3728,6 +3736,22 @@ class NestedStruct(csp.Struct):
json_pydantic = TypeAdapter(NestedStruct).dump_json(nested).decode()
self.assertEqual(json.loads(json_native), json.loads(json_pydantic))

def test_pydantic_np_arr(self):
class NPStruct(csp.Struct):
arr: Numpy1DArray[float] = np.array([])

val = NPStruct(arr=np.array([1, 2]))
json_val = TypeAdapter(NPStruct).dump_json(val)
# We serialize as a list
self.assertEqual(json.loads(json_val), dict(arr=[1, 2]))
revived_val = TypeAdapter(NPStruct).validate_json(json_val)
self.assertTrue(np.all(val.arr == revived_val.arr))

NPStruct(arr=np.array([1, 3, "ab"])) # No error, even though the types are wrong
with self.assertRaises(ValidationError) as exc_info:
TypeAdapter(NPStruct).validate_python(dict(arr=[1, 3, "ab"]))
self.assertIn("could not convert string to float", str(exc_info.exception))


if __name__ == "__main__":
unittest.main()
8 changes: 4 additions & 4 deletions csp/tests/test_typing.py
@@ -12,8 +12,8 @@ def test_Numpy1DArray(self):
ta.validate_python(np.array([1.0, 2.0], dtype=np.float64))
self.assertRaises(Exception, ta.validate_python, np.array([[1.0]]))
self.assertRaises(Exception, ta.validate_python, np.array(["foo"]))
self.assertRaises(Exception, ta.validate_python, np.array([1, 2]))
self.assertRaises(Exception, ta.validate_python, np.array([1.0, 2.0], dtype=np.float32))
ta.validate_python(np.array([1, 2])) # gets coerced to correct type
ta.validate_python(np.array([1.0, 2.0], dtype=np.float32)) # gets coerced to correct type

def test_NumpyNDArray(self):
ta = TypeAdapter(NumpyNDArray[float])
@@ -22,5 +22,5 @@ def test_NumpyNDArray(self):
ta.validate_python(np.array([[1.0, 2.0]]))
ta.validate_python(np.array([[1.0, 2.0]], dtype=np.float64))
self.assertRaises(Exception, ta.validate_python, np.array(["foo"]))
self.assertRaises(Exception, ta.validate_python, np.array([1, 2]))
self.assertRaises(Exception, ta.validate_python, np.array([1.0, 2.0], dtype=np.float32))
ta.validate_python(np.array([1, 2])) # gets coerced to correct type
ta.validate_python(np.array([1.0, 2.0], dtype=np.float32)) # gets coerced to correct type
41 changes: 38 additions & 3 deletions csp/typing.py
@@ -1,9 +1,25 @@
import numpy
from typing import Generic, TypeVar, get_args
from typing import Any, Generic, TypeVar, get_args

T = TypeVar("T")


def _get_validator_np(source_type):
# Given a source type, return a validator that coerces values to a numpy array of the declared dtype
def _validate(v):
subtypes = get_args(source_type)
dtype = subtypes[0] if subtypes and subtypes[0] != Any else None
try:
if dtype:
return numpy.asarray(v, dtype=dtype)
return numpy.asarray(v)

except TypeError:
raise ValueError(f"Unable to convert {v} to an array.")

return _validate


class NumpyNDArray(Generic[T], numpy.ndarray):
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
@@ -14,14 +30,24 @@ def __get_pydantic_core_schema__(cls, source_type, handler):
if not source_args:
raise TypeError(f"Must provide a single generic argument to {cls}")

validate_func = _get_validator_np(source_type=source_type)

def _validate(v):
v = validate_func(v)
if not isinstance(v, numpy.ndarray):
raise ValueError("value must be an instance of numpy.ndarray")
if not numpy.issubdtype(v.dtype, source_args[0]):
raise ValueError(f"dtype of array must be a subdtype of {source_args[0]}")
return v

return core_schema.no_info_plain_validator_function(_validate)
return core_schema.no_info_plain_validator_function(
_validate,
serialization=core_schema.wrap_serializer_function_ser_schema(
lambda val, handler: handler(val if val is None else val.tolist()),
info_arg=False,
return_schema=core_schema.list_schema(),
),
)


class Numpy1DArray(NumpyNDArray[T]):
@@ -33,8 +59,10 @@ def __get_pydantic_core_schema__(cls, source_type, handler):
source_args = get_args(source_type)
if not source_args:
raise TypeError(f"Must provide a single generic argument to {cls}")
validate_func = _get_validator_np(source_type=source_type)

def _validate(v):
v = validate_func(v)
if not isinstance(v, numpy.ndarray):
raise ValueError("value must be an instance of numpy.ndarray")
if not numpy.issubdtype(v.dtype, source_args[0]):
@@ -43,4 +71,11 @@ def _validate(v):
raise ValueError("array must be one dimensional")
return v

return core_schema.no_info_plain_validator_function(_validate)
return core_schema.no_info_plain_validator_function(
_validate,
serialization=core_schema.wrap_serializer_function_ser_schema(
lambda val, handler: handler(val if val is None else val.tolist()),
info_arg=False,
return_schema=core_schema.list_schema(),
),
)
