Skip to content

Commit

Permalink
Handle additional case for related primary key field
Browse files Browse the repository at this point in the history
  • Loading branch information
fjsj committed Oct 16, 2022
1 parent 0ba1589 commit 48f3d5c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 16 deletions.
8 changes: 5 additions & 3 deletions django_virtual_models/prefetch/serializer_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,17 +258,19 @@ def _maybe_handle_method_field(
block_queries=self.block_queries,
)

def _maybe_handle_relational_field(self, field: Field, **kwargs) -> List[str]:
def _maybe_handle_related_field(self, field: Field, **kwargs) -> List[str]:
lookup = None
if isinstance(field, serializers.PrimaryKeyRelatedField):
lookup = field.source
elif isinstance(field, serializers.ManyRelatedField) and isinstance(
field.child_relation, serializers.PrimaryKeyRelatedField
):
lookup = field.child_relation.source
if lookup == "":
lookup = field.field_name
if lookup is None:
raise self.CantHandleException
return [lookup]
return [lookup + "__id"]

def _maybe_handle_field_with_nested_source(self, field: Field, **kwargs) -> List[str]:
# TODO: Right now, those fields must always be defined in the Virtual Model,
Expand Down Expand Up @@ -386,7 +388,7 @@ def recursively_find_lookup_list(
self._maybe_handle_property_field,
self._maybe_handle_model_method_field,
self._maybe_handle_method_field,
self._maybe_handle_relational_field,
self._maybe_handle_related_field,
self._maybe_handle_field_with_nested_source,
self._maybe_handle_url_field,
self._maybe_handle_nested_serializer_field,
Expand Down
55 changes: 42 additions & 13 deletions tests/optimization/test_lookup_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class VirtualCourse(v.VirtualModel):
assignments = NestedAssignment(
manager=Assignment.objects,
)
assignees = v.VirtualModel(manager=User.objects)
lessons = v.VirtualModel(manager=Lesson.objects)
settings = v.NoOp()

Expand Down Expand Up @@ -200,22 +201,23 @@ def test_found_lookup_list(self):

assert sorted(lookup_list) == sorted(
[
"assignments",
"assignments__email",
"assignments__lessons_completed_total",
"assignments__lessons_total",
"created_by",
"created_by__email",
"facilitator_emails",
"lessons",
"name",
"created_by__id",
"created_by__email",
"description",
"facilitator_emails",
"user_assignment",
"user_assignment__completed_lessons",
"user_assignment__course",
"user_assignment__email",
"user_assignment__lessons_completed_total",
"user_assignment__lessons_total",
"user_assignment__lessons_completed_total",
"user_assignment__course",
"user_assignment__completed_lessons",
"assignments",
"assignments__email",
"assignments__lessons_total",
"assignments__lessons_completed_total",
"lessons__id",
"lessons",
]
)

Expand All @@ -233,6 +235,32 @@ def test_found_lookup_list_has_no_n_plus_one_queries(self):
course_list = list(optimized_qs)
assert len(course_list) == 3

def test_found_lookup_list_for_implicit_related_fields(self):
class CourseRelsSerializer(serializers.ModelSerializer):
class Meta:
model = Course
virtual_model = VirtualCourse
fields = ["id", "created_by", "assignees", "lessons"]

qs = Course.objects.all()
serializer_instance = CourseRelsSerializer(instance=qs, many=True)
virtual_model = VirtualCourse()

lookup_list = LookupFinder(
serializer_instance=serializer_instance,
virtual_model=virtual_model,
block_queries=False,
).recursively_find_lookup_list()

assert sorted(lookup_list) == sorted(
[
"id",
"created_by__id",
"assignees__id",
"lessons__id",
]
)

def test_ignored_nested_serializer_with_noop(self):
"""
Sometimes one needs a nested serializer generated dynamically.
Expand Down Expand Up @@ -292,9 +320,9 @@ def get_assignments(self, obj, serializer_cls):
assert sorted(lookup_list) == sorted(
[
"name",
"description",
"created_by",
"created_by__id",
"created_by__email",
"description",
"facilitator_emails",
"user_assignment",
"user_assignment__email",
Expand All @@ -306,6 +334,7 @@ def get_assignments(self, obj, serializer_cls):
"assignments__email",
"assignments__lessons_total",
"assignments__lessons_completed_total",
"lessons__id",
"lessons",
]
)
Expand Down

0 comments on commit 48f3d5c

Please sign in to comment.