From b7e603eea5eb19031171eaff79cf0c264533f2f7 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 14:22:49 +0000 Subject: [PATCH] model serializer type arg --- codeforlife/views/model.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 93ec7f5..3f6dd19 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -31,6 +31,9 @@ from ..user.models import User RequestUser = t.TypeVar("RequestUser", bound=User) + AnyBaseModelSerializer = t.TypeVar( + "AnyBaseModelSerializer", bound=BaseModelSerializer + ) # NOTE: This raises an error during runtime. # pylint: disable-next=too-few-public-methods @@ -39,6 +42,7 @@ class _ModelViewSet(DrfModelViewSet[AnyModel], t.Generic[AnyModel]): else: RequestUser = t.TypeVar("RequestUser") + AnyBaseModelSerializer = t.TypeVar("AnyBaseModelSerializer") # pylint: disable-next=too-many-ancestors class _ModelViewSet(DrfModelViewSet, t.Generic[AnyModel]): @@ -54,13 +58,11 @@ class _ModelViewSet(DrfModelViewSet, t.Generic[AnyModel]): class BaseModelViewSet( BaseAPIView[AnyBaseRequest], _ModelViewSet[AnyModel], - t.Generic[AnyBaseRequest, AnyModel], + t.Generic[AnyBaseRequest, AnyBaseModelSerializer, AnyModel], ): """Base model view set for all model view sets.""" - serializer_class: t.Optional[ - t.Type["BaseModelSerializer[AnyBaseRequest, AnyModel]"] - ] + serializer_class: t.Optional[t.Type[AnyBaseModelSerializer]] @classmethod def get_model_class(cls) -> t.Type[AnyModel]: @@ -160,16 +162,16 @@ def partial_update( # type: ignore[override] # pragma: no cover # pylint: disable-next=too-many-ancestors class ModelViewSet( - BaseModelViewSet[Request[RequestUser], AnyModel], + BaseModelViewSet[ + Request[RequestUser], + "ModelSerializer[RequestUser, AnyModel]", + AnyModel, + ], APIView[RequestUser], t.Generic[RequestUser, AnyModel], ): """Base model view set for all model view sets.""" - serializer_class: t.Optional[ - t.Type["ModelSerializer[RequestUser, AnyModel]"] - ] - def get_bulk_queryset(self, lookup_values: t.Collection): """Get the queryset for a bulk action.