Skip to content

Commit

Permalink
Refined schema typing
Browse files Browse the repository at this point in the history
  • Loading branch information
christiansandberg committed Sep 19, 2024
1 parent 4ae25d5 commit db29595
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 61 deletions.
3 changes: 1 addition & 2 deletions onedm/sdf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

class CommonQualities(BaseModel):
model_config = ConfigDict(
extra="allow", validate_assignment=True, alias_generator=to_camel
extra="allow", alias_generator=to_camel
)

label: str | None = None
description: str | None = None
ref: str | None = Field(None, alias="sdfRef")
required: list[str | bool] = Field(default_factory=list, alias="sdfRequired")

def get_extra(self) -> dict[str, Any]:
return self.__pydantic_extra__
79 changes: 51 additions & 28 deletions onedm/sdf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

from abc import ABC
from enum import Enum
import datetime
from typing import Annotated, Any, Literal, Union

from pydantic import Field, NonNegativeInt, model_validator
Expand All @@ -16,20 +16,11 @@
from .common import CommonQualities


class DataType(str, Enum):
BOOLEAN = "boolean"
NUMBER = "number"
INTEGER = "integer"
STRING = "string"
OBJECT = "object"
ARRAY = "array"


class DataQualities(CommonQualities, ABC):
"""Base class for all data qualities."""

type: DataType
sdf_type: str | None = None
type: Literal["boolean", "number", "integer", "string", "object", "array"]
sdf_type: str | None = Field(None, pattern=r"^[a-z][\-a-z0-9]*$")
nullable: bool = True
const: Any | None = None
default: Any | None = None
Expand Down Expand Up @@ -65,7 +56,7 @@ def validate(self, input: Any) -> Any:


class NumberData(DataQualities):
type: Literal[DataType.NUMBER]
type: Literal["number"]
unit: str | None = None
minimum: float | None = None
maximum: float | None = None
Expand All @@ -80,10 +71,33 @@ class NumberData(DataQualities):
@classmethod
def set_default_type(cls, data: Any):
if isinstance(data, dict):
data.setdefault("type", DataType.NUMBER)
data.setdefault("type", "number")
return data

def _get_base_schema(self) -> core_schema.FloatSchema:
def _get_base_schema(self) -> core_schema.FloatSchema | core_schema.DatetimeSchema:
if self.sdf_type == "unix-time":
return core_schema.datetime_schema(
ge=(
datetime.datetime.fromtimestamp(self.minimum)
if self.minimum is not None
else None
),
le=(
datetime.datetime.fromtimestamp(self.maximum)
if self.maximum is not None
else None
),
gt=(
datetime.datetime.fromtimestamp(self.exclusive_minimum)
if self.exclusive_minimum is not None
else None
),
lt=(
datetime.datetime.fromtimestamp(self.exclusive_maximum)
if self.exclusive_maximum is not None
else None
),
)
return core_schema.float_schema(
ge=self.minimum,
le=self.maximum,
Expand All @@ -92,19 +106,18 @@ def _get_base_schema(self) -> core_schema.FloatSchema:
multiple_of=self.multiple_of,
)

def validate(self, input: Any) -> int:
def validate(self, input: Any) -> float:
return super().validate(input)


class IntegerData(DataQualities):
type: Literal[DataType.INTEGER]
type: Literal["integer"]
unit: str | None = None
minimum: int | None = None
maximum: int | None = None
exclusive_minimum: int | None = None
exclusive_maximum: int | None = None
multiple_of: int | None = None
format: str | None = None
choices: dict[str, IntegerData] | None = Field(None, alias="sdfChoice")
const: int | None = None
default: int | None = None
Expand All @@ -113,7 +126,7 @@ class IntegerData(DataQualities):
@classmethod
def set_default_type(cls, data: Any):
if isinstance(data, dict):
data.setdefault("type", DataType.INTEGER)
data.setdefault("type", "integer")
return data

def _get_base_schema(self) -> core_schema.IntSchema:
Expand All @@ -130,15 +143,15 @@ def validate(self, input: Any) -> int:


class BooleanData(DataQualities):
type: Literal[DataType.BOOLEAN]
type: Literal["boolean"]
const: bool | None = None
default: bool | None = None

@model_validator(mode="before")
@classmethod
def set_default_type(cls, data: Any):
if isinstance(data, dict):
data.setdefault("type", DataType.BOOLEAN)
data.setdefault("type", "boolean")
return data

def _get_base_schema(self) -> core_schema.BoolSchema:
Expand All @@ -149,7 +162,7 @@ def validate(self, input: Any) -> bool:


class StringData(DataQualities):
type: Literal[DataType.STRING]
type: Literal["string"]
enum: list[str] | None = None
min_length: NonNegativeInt = 0
max_length: NonNegativeInt | None = None
Expand All @@ -164,7 +177,7 @@ class StringData(DataQualities):
@classmethod
def set_default_type(cls, data: Any):
if isinstance(data, dict):
data.setdefault("type", DataType.STRING)
data.setdefault("type", "string")
return data

def _get_base_schema(
Expand All @@ -176,6 +189,16 @@ def _get_base_schema(
return core_schema.bytes_schema(
min_length=self.min_length, max_length=self.max_length
)
if self.format == "uuid":
return core_schema.uuid_schema()
if self.format == "date-time":
return core_schema.datetime_schema()
if self.format == "date":
return core_schema.date_schema()
if self.format == "time":
return core_schema.time_schema()
if self.format == "uri":
return core_schema.url_schema()
return core_schema.str_schema(
min_length=self.min_length,
max_length=self.max_length,
Expand All @@ -187,7 +210,7 @@ def validate(self, input: Any) -> str | bytes:


class ArrayData(DataQualities):
type: Literal[DataType.ARRAY]
type: Literal["array"]
min_items: NonNegativeInt = 0
max_items: NonNegativeInt | None = None
unique_items: bool = False
Expand All @@ -199,7 +222,7 @@ class ArrayData(DataQualities):
@classmethod
def set_default_type(cls, data: Any):
if isinstance(data, dict):
data.setdefault("type", DataType.ARRAY)
data.setdefault("type", "array")
return data

def _get_base_schema(self) -> core_schema.ListSchema | core_schema.SetSchema:
Expand All @@ -220,8 +243,8 @@ def validate(self, input: Any) -> list | set:


class ObjectData(DataQualities):
type: Literal[DataType.OBJECT]
required: list[str] | None = None
type: Literal["object"]
required: list[str] = Field(default_factory=list)
properties: dict[str, Data] | None = None
const: dict[str, Any] | None = None
default: dict[str, Any] | None = None
Expand All @@ -230,7 +253,7 @@ class ObjectData(DataQualities):
@classmethod
def set_default_type(cls, data: Any):
if isinstance(data, dict):
data.setdefault("type", DataType.OBJECT)
data.setdefault("type", "object")
return data

def _get_base_schema(self) -> core_schema.TypedDictSchema:
Expand Down
63 changes: 40 additions & 23 deletions onedm/sdf/definitions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Annotated, Union
from typing import Annotated, Literal, Union

from pydantic import Field
from pydantic import Field, NonNegativeInt

from .common import CommonQualities
from .data import (
Expand All @@ -16,38 +16,53 @@
)


class PropertyCommon:
class NumberProperty(NumberData):
observable: bool = True
readable: bool = True
writable: bool = True
required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired")


class NumberProperty(NumberData, PropertyCommon):
pass


class IntegerProperty(IntegerData, PropertyCommon):
pass
class IntegerProperty(IntegerData):
observable: bool = True
readable: bool = True
writable: bool = True
required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired")


class BooleanProperty(BooleanData, PropertyCommon):
pass
class BooleanProperty(BooleanData):
observable: bool = True
readable: bool = True
writable: bool = True
required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired")


class StringProperty(StringData, PropertyCommon):
pass
class StringProperty(StringData):
observable: bool = True
readable: bool = True
writable: bool = True
required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired")


class ArrayProperty(ArrayData, PropertyCommon):
pass
class ArrayProperty(ArrayData):
observable: bool = True
readable: bool = True
writable: bool = True
required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired")


class ObjectProperty(ObjectData, PropertyCommon):
pass
class ObjectProperty(ObjectData):
observable: bool = True
readable: bool = True
writable: bool = True
required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired")


class AnyProperty(AnyData, PropertyCommon):
pass
class AnyProperty(AnyData):
observable: bool = True
readable: bool = True
writable: bool = True
required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired")


Property = Union[
Expand Down Expand Up @@ -78,9 +93,10 @@ class Object(CommonQualities):
actions: dict[str, Action] = Field(default_factory=dict, alias="sdfAction")
events: dict[str, Event] = Field(default_factory=dict, alias="sdfEvent")
data: dict[str, Data] = Field(default_factory=dict, alias="sdfData")
required: list[str] = Field(default_factory=list, alias="sdfRequired")
# If array of objects
min_items: int | None = None
max_items: int | None = None
min_items: NonNegativeInt | None = None
max_items: NonNegativeInt | None = None


class Thing(CommonQualities):
Expand All @@ -90,6 +106,7 @@ class Thing(CommonQualities):
actions: dict[str, Action] = Field(default_factory=dict, alias="sdfAction")
events: dict[str, Event] = Field(default_factory=dict, alias="sdfEvent")
data: dict[str, Data] = Field(default_factory=dict, alias="sdfData")
required: list[str] = Field(default_factory=list, alias="sdfRequired")
# If array of things
min_items: int | None = None
max_items: int | None = None
min_items: NonNegativeInt | None = None
max_items: NonNegativeInt | None = None
20 changes: 12 additions & 8 deletions tests/sdf/test_value_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from onedm import sdf


def test_integer_validation(test_model: sdf.SDF):
assert test_model.data["Integer"].validate(2) == 2
def test_integer_validation():
integer = sdf.IntegerData(maximum=2)
assert integer.validate(2) == 2
with pytest.raises(ValueError):
test_model.data["Integer"].validate(True)
integer.validate(1.5)
# Out of range
with pytest.raises(ValueError):
test_model.data["Integer"].validate(3)
integer.validate(3)


def test_number_validation(test_model: sdf.SDF):
Expand Down Expand Up @@ -40,9 +41,12 @@ def test_string_validation(test_model: sdf.SDF):
test_model.data["Number"].validate(["0123456789"])


def test_nullable_validation(test_model: sdf.SDF):
assert test_model.data["NullableInteger"].validate(None) == None
def test_nullable_validation():
nullable_integer = sdf.IntegerData(nullable=True)
assert nullable_integer.validate(None) == None

# Not nullable

def test_non_nullable_validation():
integer = sdf.IntegerData(nullable=False)
with pytest.raises(ValueError):
test_model.data["Integer"].validate(None)
integer.validate(None)

0 comments on commit db29595

Please sign in to comment.