From 48f3d5c0272c93da06d7940c943090a33776766e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Sat, 15 Oct 2022 21:52:28 -0300 Subject: [PATCH] Handle additional case for related primary key field --- .../prefetch/serializer_optimization.py | 8 ++- tests/optimization/test_lookup_finder.py | 55 ++++++++++++++----- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/django_virtual_models/prefetch/serializer_optimization.py b/django_virtual_models/prefetch/serializer_optimization.py index 3d0085f..1e0b04e 100644 --- a/django_virtual_models/prefetch/serializer_optimization.py +++ b/django_virtual_models/prefetch/serializer_optimization.py @@ -258,7 +258,7 @@ def _maybe_handle_method_field( block_queries=self.block_queries, ) - def _maybe_handle_relational_field(self, field: Field, **kwargs) -> List[str]: + def _maybe_handle_related_field(self, field: Field, **kwargs) -> List[str]: lookup = None if isinstance(field, serializers.PrimaryKeyRelatedField): lookup = field.source @@ -266,9 +266,11 @@ def _maybe_handle_relational_field(self, field: Field, **kwargs) -> List[str]: field.child_relation, serializers.PrimaryKeyRelatedField ): lookup = field.child_relation.source + if lookup == "": + lookup = field.field_name if lookup is None: raise self.CantHandleException - return [lookup] + return [lookup + "__id"] def _maybe_handle_field_with_nested_source(self, field: Field, **kwargs) -> List[str]: # TODO: Right now, those fields must always be defined in the Virtual Model, @@ -386,7 +388,7 @@ def recursively_find_lookup_list( self._maybe_handle_property_field, self._maybe_handle_model_method_field, self._maybe_handle_method_field, - self._maybe_handle_relational_field, + self._maybe_handle_related_field, self._maybe_handle_field_with_nested_source, self._maybe_handle_url_field, self._maybe_handle_nested_serializer_field, diff --git a/tests/optimization/test_lookup_finder.py b/tests/optimization/test_lookup_finder.py index fd03ec1..0344fab 100644 --- a/tests/optimization/test_lookup_finder.py +++ b/tests/optimization/test_lookup_finder.py @@ -52,6 +52,7 @@ class VirtualCourse(v.VirtualModel): assignments = NestedAssignment( manager=Assignment.objects, ) + assignees = v.VirtualModel(manager=User.objects) lessons = v.VirtualModel(manager=Lesson.objects) settings = v.NoOp() @@ -200,22 +201,23 @@ def test_found_lookup_list(self): assert sorted(lookup_list) == sorted( [ - "assignments", - "assignments__email", - "assignments__lessons_completed_total", - "assignments__lessons_total", - "created_by", - "created_by__email", - "facilitator_emails", - "lessons", "name", + "created_by__id", + "created_by__email", "description", + "facilitator_emails", "user_assignment", - "user_assignment__completed_lessons", - "user_assignment__course", "user_assignment__email", - "user_assignment__lessons_completed_total", "user_assignment__lessons_total", + "user_assignment__lessons_completed_total", + "user_assignment__course", + "user_assignment__completed_lessons", + "assignments", + "assignments__email", + "assignments__lessons_total", + "assignments__lessons_completed_total", + "lessons__id", + "lessons", ] ) @@ -233,6 +235,32 @@ def test_found_lookup_list_has_no_n_plus_one_queries(self): course_list = list(optimized_qs) assert len(course_list) == 3 + def test_found_lookup_list_for_implicit_related_fields(self): + class CourseRelsSerializer(serializers.ModelSerializer): + class Meta: + model = Course + virtual_model = VirtualCourse + fields = ["id", "created_by", "assignees", "lessons"] + + qs = Course.objects.all() + serializer_instance = CourseRelsSerializer(instance=qs, many=True) + virtual_model = VirtualCourse() + + lookup_list = LookupFinder( + serializer_instance=serializer_instance, + virtual_model=virtual_model, + block_queries=False, + ).recursively_find_lookup_list() + + assert sorted(lookup_list) == sorted( + [ + "id", + "created_by__id", + "assignees__id", + "lessons__id", + ] + ) + def test_ignored_nested_serializer_with_noop(self): """ Sometimes one needs a nested serializer generated dynamically. @@ -292,9 +320,9 @@ def get_assignments(self, obj, serializer_cls): assert sorted(lookup_list) == sorted( [ "name", - "description", - "created_by", + "created_by__id", "created_by__email", + "description", "facilitator_emails", "user_assignment", "user_assignment__email", @@ -306,6 +334,7 @@ def get_assignments(self, obj, serializer_cls): "assignments__email", "assignments__lessons_total", "assignments__lessons_completed_total", + "lessons__id", "lessons", ] )