Skip to content

Commit

Permalink
Remove custom _get_serializer_class. Use another strategy to handle R…
Browse files Browse the repository at this point in the history
…ecursionError between get_serializer_class, get_read_serializer, and get_write_serializer.
  • Loading branch information
pamella committed Jun 5, 2024
1 parent e74ef84 commit ca742bc
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 32 deletions.
57 changes: 31 additions & 26 deletions drf_rw_serializers/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,19 @@


class GenericAPIView(generics.GenericAPIView):
def _get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
assert (
self.serializer_class is not None
or getattr(self, "read_serializer_class", None) is not None
), (
"'%s' should either include one of `serializer_class` and `read_serializer_class` "
"attribute, or override one of the `get_serializer_class()`, "
"`get_read_serializer_class()` method." % self.__class__.__name__
)

return self.serializer_class

def get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
If the request method is GET, it tries to use `self.read_serializer_class`.
If the request method is not GET, it tries to use `self.write_serializer_class`.
If the specific serializer class for the request method is not set, it falls back to
`self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
if hasattr(self, "request"):
Expand All @@ -52,7 +36,8 @@ def get_serializer_class(self):
"attribute, or override the `get_read_serializer_class()` or "
"`get_serializer_class()` method." % self.__class__.__name__
)
return self.get_read_serializer_class()
# `use_serializer_class` is used to prevent a `RecursionError`
return self.get_read_serializer_class(use_serializer_class=True)

if self.request.method in ["POST", "PUT", "PATCH", "DELETE"]:
assert (
Expand All @@ -63,9 +48,19 @@ def get_serializer_class(self):
"attribute, or override the `get_write_serializer_class()` or "
"`get_serializer_class()` method." % self.__class__.__name__
)
return self.get_write_serializer_class()
# `use_serializer_class` is used to prevent a `RecursionError`
return self.get_write_serializer_class(use_serializer_class=True)

assert (
self.serializer_class is not None
or getattr(self, "read_serializer_class", None) is not None
), (
"'%s' should either include one of `serializer_class` and `read_serializer_class` "
"attribute, or override one of the `get_serializer_class()`, "
"`get_read_serializer_class()` method." % self.__class__.__name__
)

return self._get_serializer_class()
return self.serializer_class

def get_read_serializer(self, *args, **kwargs):
"""
Expand All @@ -75,16 +70,21 @@ def get_read_serializer(self, *args, **kwargs):
kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs)

def get_read_serializer_class(self):
def get_read_serializer_class(self, use_serializer_class: bool = False):
"""
Return the class to use for the serializer.
Defaults to using `self.read_serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
if getattr(self, "read_serializer_class", None) is None:
return self._get_serializer_class()
if use_serializer_class:
return self.serializer_class

return self.get_serializer_class()

return self.read_serializer_class

Expand All @@ -97,16 +97,21 @@ def get_write_serializer(self, *args, **kwargs):
kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs)

def get_write_serializer_class(self):
def get_write_serializer_class(self, use_serializer_class: bool = False):
"""
Return the class to use for the serializer.
Defaults to using `self.write_serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins can send extra fields, others cannot)
"""
if getattr(self, "write_serializer_class", None) is None:
return self._get_serializer_class()
if use_serializer_class:
return self.serializer_class

return self.get_serializer_class()

return self.write_serializer_class

Expand Down
96 changes: 90 additions & 6 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ def test_serializer_class_not_provided(self):
),
)

def test_get_serializer_class_override_provided(self):
class GetSerializerClassView(generics.GenericAPIView):
def get_serializer_class(self):
return OrderedMealDetailsSerializer

self.assertEqual(
GetSerializerClassView().get_serializer_class(), OrderedMealDetailsSerializer
)
self.assertEqual(
GetSerializerClassView().get_read_serializer_class(), OrderedMealDetailsSerializer
)
self.assertEqual(
GetSerializerClassView().get_write_serializer_class(), OrderedMealDetailsSerializer
)

def test_no_request_provided(self):
# Return serializer_class over read_serializer_class and write_serializer_class
self.assertEqual(
Expand Down Expand Up @@ -108,18 +123,37 @@ def test_non_read_write_request_method_provided(self):
self.FullSerializerView().get_serializer_class(), OrderedMealDetailsSerializer
)

def test_all_get_serializer_class_override_provided(self):
class GetSerializerClassView(generics.GenericAPIView):
def get_serializer_class(self):
return OrderedMealDetailsSerializer

def get_read_serializer_class(self, use_serializer_class: bool = False):
return OrderListSerializer

def get_write_serializer_class(self, use_serializer_class: bool = False):
return OrderCreateSerializer

self.assertEqual(
GetSerializerClassView().get_serializer_class(), OrderedMealDetailsSerializer
)
self.assertEqual(GetSerializerClassView().get_read_serializer_class(), OrderListSerializer)
self.assertEqual(
GetSerializerClassView().get_write_serializer_class(), OrderCreateSerializer
)


class GenericAPIViewGetReadSerializerClassTests(BaseTestCase):
def test_read_serializer_class_not_provided(self):
class NoReadSerializerView(generics.GenericAPIView):
pass

with mock.patch.object(
NoReadSerializerView, "_get_serializer_class"
) as mock__get_serializer_class:
NoReadSerializerView, "get_serializer_class"
) as mock_get_serializer_class:
NoReadSerializerView().get_read_serializer_class()

mock__get_serializer_class.assert_called_once()
mock_get_serializer_class.assert_called_once()

def test_read_serializer_class_provided(self):
class ReadSerializerClassProvided(generics.GenericAPIView):
Expand All @@ -130,18 +164,43 @@ class ReadSerializerClassProvided(generics.GenericAPIView):
OrderListSerializer,
)

def test_use_serializer_class_fallback(self):
class SerializerClassView(generics.GenericAPIView):
serializer_class = OrderedMealDetailsSerializer

self.assertEqual(
SerializerClassView().get_read_serializer_class(use_serializer_class=True),
OrderedMealDetailsSerializer,
)

with mock.patch.object(
SerializerClassView, "get_serializer_class"
) as mock_get_serializer_class:
SerializerClassView().get_read_serializer_class(use_serializer_class=False)

mock_get_serializer_class.assert_called_once()

def test_get_read_serializer_class_override_provided(self):
class GetReadSerializerClassView(generics.GenericAPIView):
def get_read_serializer_class(self, use_serializer_class: bool = False):
return OrderListSerializer

self.assertEqual(
GetReadSerializerClassView().get_read_serializer_class(), OrderListSerializer
)


class GenericAPIViewGetWriteSerializerClassTests(BaseTestCase):
def test_write_serializer_class_not_provided(self):
class NoWriteSerializerView(generics.GenericAPIView):
pass

with mock.patch.object(
NoWriteSerializerView, "_get_serializer_class"
) as mock__get_serializer_class:
NoWriteSerializerView, "get_serializer_class"
) as mock_get_serializer_class:
NoWriteSerializerView().get_write_serializer_class()

mock__get_serializer_class.assert_called_once()
mock_get_serializer_class.assert_called_once()

def test_write_serializer_class_provided(self):
class WriteSerializerClassProvided(generics.GenericAPIView):
Expand All @@ -152,6 +211,31 @@ class WriteSerializerClassProvided(generics.GenericAPIView):
OrderCreateSerializer,
)

def test_use_serializer_class_fallback(self):
class SerializerClassView(generics.GenericAPIView):
serializer_class = OrderedMealDetailsSerializer

self.assertEqual(
SerializerClassView().get_write_serializer_class(use_serializer_class=True),
OrderedMealDetailsSerializer,
)

with mock.patch.object(
SerializerClassView, "get_serializer_class"
) as mock_get_serializer_class:
SerializerClassView().get_write_serializer_class(use_serializer_class=False)

mock_get_serializer_class.assert_called_once()

def test_get_write_serializer_class_override_provided(self):
class GetWriteSerializerClassView(generics.GenericAPIView):
def get_write_serializer_class(self, use_serializer_class: bool = False):
return OrderCreateSerializer

self.assertEqual(
GetWriteSerializerClassView().get_write_serializer_class(), OrderCreateSerializer
)


class OrderListCreateEndpointTests(BaseTestCase, TestListRequestSuccess, TestCreateRequestSuccess):
def setUp(self):
Expand Down

0 comments on commit ca742bc

Please sign in to comment.