diff --git a/packages/syft/src/syft/types/syft_metaclass.py b/packages/syft/src/syft/types/syft_metaclass.py index 8ab017eb074..ae09b1c03c8 100644 --- a/packages/syft/src/syft/types/syft_metaclass.py +++ b/packages/syft/src/syft/types/syft_metaclass.py @@ -1,6 +1,5 @@ # stdlib from typing import Any -from typing import TypeVar from typing import final # third party @@ -10,8 +9,6 @@ # relative from ..serde.serializable import serializable -_T = TypeVar("_T", bound=BaseModel) - class EmptyType(type): def __repr__(self) -> str: @@ -38,10 +35,12 @@ def __new__( ) -> type: cls = super().__new__(mcs, cls_name, bases, namespace, *args, **kwargs) - for field_info in cls.model_fields.values(): - if field_info.annotation is not None and field_info.is_required(): - field_info.annotation = field_info.annotation | EmptyType - field_info.default = Empty + if issubclass(cls, BaseModel): + for field_info in cls.model_fields.values(): + if field_info.annotation is not None and field_info.is_required(): + field_info.annotation = field_info.annotation | EmptyType + field_info.default = Empty + + cls.model_rebuild(force=True) - cls.model_rebuild(force=True) return cls