diff --git a/example/tests/integration/test_includes.py b/example/tests/integration/test_includes.py index f8fe8604..d32ff7e3 100644 --- a/example/tests/integration/test_includes.py +++ b/example/tests/integration/test_includes.py @@ -4,71 +4,6 @@ pytestmark = pytest.mark.django_db -def test_included_data_on_list(multiple_entries, client): - response = client.get( - reverse("entry-list"), data={"include": "comments", "page[size]": 5} - ) - included = response.json().get("included") - - assert len(response.json()["data"]) == len( - multiple_entries - ), "Incorrect entry count" - assert [x.get("type") for x in included] == [ - "comments", - "comments", - ], "List included types are incorrect" - - comment_count = len( - [resource for resource in included if resource["type"] == "comments"] - ) - expected_comment_count = sum(entry.comments.count() for entry in multiple_entries) - assert comment_count == expected_comment_count, "List comment count is incorrect" - - -def test_included_data_on_list_with_one_to_one_relations(multiple_entries, client): - response = client.get( - reverse("entry-list"), data={"include": "authors.bio.metadata", "page[size]": 5} - ) - included = response.json().get("included") - - assert len(response.json()["data"]) == len( - multiple_entries - ), "Incorrect entry count" - expected_include_types = [ - "authorBioMetadata", - "authorBioMetadata", - "authorBios", - "authorBios", - "authors", - "authors", - ] - include_types = [x.get("type") for x in included] - assert include_types == expected_include_types, "List included types are incorrect" - - -def test_default_included_data_on_detail(single_entry, client): - return test_included_data_on_detail( - single_entry=single_entry, client=client, query="" - ) - - -def test_included_data_on_detail(single_entry, client, query="?include=comments"): - response = client.get( - reverse("entry-detail", kwargs={"pk": single_entry.pk}) + query - ) - included = response.json().get("included") - - assert [x.get("type") for x in included] == [ - "comments" - ], "Detail included types are incorrect" - - comment_count = len( - [resource for resource in included if resource["type"] == "comments"] - ) - expected_comment_count = single_entry.comments.count() - assert comment_count == expected_comment_count, "Detail comment count is incorrect" - - def test_dynamic_related_data_is_included(single_entry, entry_factory, client): entry_factory() response = client.get( diff --git a/tests/conftest.py b/tests/conftest.py index ebdf5348..682d8342 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,11 @@ from tests.models import ( BasicModel, + ForeignKeySource, ForeignKeyTarget, ManyToManySource, ManyToManyTarget, + NestedRelatedSource, ) @@ -39,6 +41,11 @@ def foreign_key_target(db): return ForeignKeyTarget.objects.create(name="Target") +@pytest.fixture +def foreign_key_source(db, foreign_key_target): + return ForeignKeySource.objects.create(name="Source", target=foreign_key_target) + + @pytest.fixture def many_to_many_source(db, many_to_many_targets): source = ManyToManySource.objects.create(name="Source") @@ -54,6 +61,34 @@ def many_to_many_targets(db): ] +@pytest.fixture +def many_to_many_sources(db, many_to_many_targets): + source1 = ManyToManySource.objects.create(name="Source1") + source2 = ManyToManySource.objects.create(name="Source2") + + source1.targets.add(*many_to_many_targets) + source2.targets.add(*many_to_many_targets) + + return [source1, source2] + + +@pytest.fixture +def nested_related_source( + db, + foreign_key_source, + foreign_key_target, + many_to_many_targets, + many_to_many_sources, +): + source = NestedRelatedSource.objects.create( + fk_source=foreign_key_source, fk_target=foreign_key_target + ) + source.m2m_targets.add(*many_to_many_targets) + source.m2m_sources.add(*many_to_many_sources) + + return source + + @pytest.fixture def client(): return APIClient() diff --git a/tests/models.py b/tests/models.py index a483f2d0..812ee5bf 100644 --- a/tests/models.py +++ b/tests/models.py @@ -42,11 +42,11 @@ class ForeignKeySource(DJAModel): class NestedRelatedSource(DJAModel): - m2m_source = models.ManyToManyField(ManyToManySource, related_name="nested_source") + m2m_sources = models.ManyToManyField(ManyToManySource, related_name="nested_source") fk_source = models.ForeignKey( ForeignKeySource, related_name="nested_source", on_delete=models.CASCADE ) - m2m_target = models.ManyToManyField(ManyToManySource, related_name="nested_target") + m2m_targets = models.ManyToManyField(ManyToManyTarget, related_name="nested_target") fk_target = models.ForeignKey( - ForeignKeySource, related_name="nested_target", on_delete=models.CASCADE + ForeignKeyTarget, related_name="nested_target", on_delete=models.CASCADE ) diff --git a/tests/serializers.py b/tests/serializers.py index e0f7c03b..4159039d 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -1,11 +1,11 @@ from rest_framework_json_api import serializers -from rest_framework_json_api.relations import ResourceRelatedField from tests.models import ( BasicModel, ForeignKeySource, ForeignKeyTarget, ManyToManySource, ManyToManyTarget, + NestedRelatedSource, ) @@ -15,33 +15,51 @@ class Meta: model = BasicModel +class ForeignKeyTargetSerializer(serializers.ModelSerializer): + class Meta: + fields = ("name",) + model = ForeignKeyTarget + + class ForeignKeySourceSerializer(serializers.ModelSerializer): - target = ResourceRelatedField(queryset=ForeignKeyTarget.objects) + included_serializers = {"target": ForeignKeyTargetSerializer} class Meta: model = ForeignKeySource fields = ("target",) +class ManyToManyTargetSerializer(serializers.ModelSerializer): + class Meta: + fields = ("name",) + model = ManyToManyTarget + + class ManyToManySourceSerializer(serializers.ModelSerializer): - targets = ResourceRelatedField(many=True, queryset=ManyToManyTarget.objects) + included_serializers = {"targets": "tests.serializers.ManyToManyTargetSerializer"} class Meta: model = ManyToManySource fields = ("targets",) -class ManyToManyTargetSerializer(serializers.ModelSerializer): +class ManyToManySourceReadOnlySerializer(serializers.ModelSerializer): class Meta: - model = ManyToManyTarget + model = ManyToManySource + fields = ("targets",) -class ManyToManySourceReadOnlySerializer(serializers.ModelSerializer): - targets = ResourceRelatedField(many=True, read_only=True) +class NestedRelatedSourceSerializer(serializers.ModelSerializer): + included_serializers = { + "m2m_sources": ManyToManySourceSerializer, + "fk_source": ForeignKeySourceSerializer, + "m2m_targets": ManyToManyTargetSerializer, + "fk_target": ForeignKeyTargetSerializer, + } class Meta: - model = ManyToManySource - fields = ("targets",) + model = NestedRelatedSource + fields = ("m2m_sources", "fk_source", "m2m_targets", "fk_target") class CallableDefaultSerializer(serializers.Serializer): diff --git a/tests/test_utils.py b/tests/test_utils.py index f2a3d176..4e103ae2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -325,7 +325,7 @@ class Meta: {"many": True, "queryset": ManyToManyTarget.objects.all()}, ), ( - "m2m_target.sources.", + "m2m_target.sources", "ManyToManySource", {"many": True, "queryset": ManyToManySource.objects.all()}, ), diff --git a/tests/test_views.py b/tests/test_views.py index 47fec02a..de5d1b7a 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -12,9 +12,14 @@ from rest_framework_json_api.renderers import JSONRenderer from rest_framework_json_api.utils import format_link_segment from rest_framework_json_api.views import ModelViewSet, ReadOnlyModelViewSet -from tests.models import BasicModel -from tests.serializers import BasicModelSerializer -from tests.views import BasicModelViewSet +from tests.models import BasicModel, ForeignKeySource +from tests.serializers import BasicModelSerializer, ForeignKeyTargetSerializer +from tests.views import ( + BasicModelViewSet, + ForeignKeySourceViewSet, + ManyToManySourceViewSet, + NestedRelatedSourceViewSet, +) class TestModelViewSet: @@ -80,6 +85,90 @@ def test_list(self, client, model): "meta": {"pagination": {"count": 1, "page": 1, "pages": 1}}, } + @pytest.mark.urls(__name__) + def test_list_with_include_foreign_key(self, client, foreign_key_source): + url = reverse("foreign-key-source-list") + response = client.get(url, data={"include": "target"}) + assert response.status_code == status.HTTP_200_OK + result = response.json() + assert "included" in result + assert [ + { + "type": "ForeignKeyTarget", + "id": str(foreign_key_source.target.pk), + "attributes": {"name": foreign_key_source.target.name}, + } + ] == result["included"] + + @pytest.mark.urls(__name__) + def test_list_with_include_many_to_many_field( + self, client, many_to_many_source, many_to_many_targets + ): + url = reverse("many-to-many-source-list") + response = client.get(url, data={"include": "targets"}) + assert response.status_code == status.HTTP_200_OK + result = response.json() + assert "included" in result + assert [ + { + "type": "ManyToManyTarget", + "id": str(target.pk), + "attributes": {"name": target.name}, + } + for target in many_to_many_targets + ] == result["included"] + + @pytest.mark.urls(__name__) + def test_list_with_include_nested_related_field( + self, client, nested_related_source, many_to_many_sources, many_to_many_targets + ): + url = reverse("nested-related-source-list") + response = client.get(url, data={"include": "m2m_sources,m2m_sources.targets"}) + assert response.status_code == status.HTTP_200_OK + result = response.json() + assert "included" in result + + assert [ + { + "type": "ManyToManySource", + "id": str(source.pk), + "relationships": { + "targets": { + "data": [ + {"id": str(target.pk), "type": "ManyToManyTarget"} + for target in source.targets.all() + ], + "meta": {"count": source.targets.count()}, + } + }, + } + for source in many_to_many_sources + ] + [ + { + "type": "ManyToManyTarget", + "id": str(target.pk), + "attributes": {"name": target.name}, + } + for target in many_to_many_targets + ] == result[ + "included" + ] + + @pytest.mark.urls(__name__) + def test_list_with_default_included_resources(self, client, foreign_key_source): + url = reverse("default-included-resources-list") + response = client.get(url) + assert response.status_code == status.HTTP_200_OK + result = response.json() + assert "included" in result + assert [ + { + "type": "ForeignKeyTarget", + "id": str(foreign_key_source.target.pk), + "attributes": {"name": foreign_key_source.target.name}, + } + ] == result["included"] + @pytest.mark.urls(__name__) def test_retrieve(self, client, model): url = reverse("basic-model-detail", kwargs={"pk": model.pk}) @@ -93,6 +182,21 @@ def test_retrieve(self, client, model): } } + @pytest.mark.urls(__name__) + def test_retrieve_with_include_foreign_key(self, client, foreign_key_source): + url = reverse("foreign-key-source-detail", kwargs={"pk": foreign_key_source.pk}) + response = client.get(url, data={"include": "target"}) + assert response.status_code == status.HTTP_200_OK + result = response.json() + assert "included" in result + assert [ + { + "type": "ForeignKeyTarget", + "id": str(foreign_key_source.target.pk), + "attributes": {"name": foreign_key_source.target.name}, + } + ] == result["included"] + @pytest.mark.urls(__name__) def test_patch(self, client, model): data = { @@ -231,6 +335,23 @@ def test_patch_with_custom_id(self, client): # Routing setup +class DefaultIncludedResourcesSerializer(serializers.ModelSerializer): + included_serializers = {"target": ForeignKeyTargetSerializer} + + class Meta: + model = ForeignKeySource + fields = ("target",) + + class JSONAPIMeta: + included_resources = ["target"] + + +class DefaultIncludedResourcesViewSet(ModelViewSet): + serializer_class = DefaultIncludedResourcesSerializer + queryset = ForeignKeySource.objects.all() + ordering = ["id"] + + class CustomModel: def __init__(self, response_dict): for k, v in response_dict.items(): @@ -280,6 +401,22 @@ def patch(self, request, *args, **kwargs): router = SimpleRouter() router.register(r"basic_models", BasicModelViewSet, basename="basic-model") +router.register( + r"foreign_key_sources", ForeignKeySourceViewSet, basename="foreign-key-source" +) +router.register( + r"many_to_many_sources", ManyToManySourceViewSet, basename="many-to-many-source" +) +router.register( + r"nested_related_sources", + NestedRelatedSourceViewSet, + basename="nested-related-source", +) +router.register( + r"default_included_resources", + DefaultIncludedResourcesViewSet, + basename="default-included-resources", +) urlpatterns = [ path("custom", CustomAPIView.as_view(), name="custom"), diff --git a/tests/views.py b/tests/views.py index 42d8a0b0..72a7ea59 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,8 +1,37 @@ from rest_framework_json_api.views import ModelViewSet -from tests.models import BasicModel -from tests.serializers import BasicModelSerializer +from tests.models import ( + BasicModel, + ForeignKeySource, + ManyToManySource, + NestedRelatedSource, +) +from tests.serializers import ( + BasicModelSerializer, + ForeignKeySourceSerializer, + ManyToManySourceSerializer, + NestedRelatedSourceSerializer, +) class BasicModelViewSet(ModelViewSet): serializer_class = BasicModelSerializer queryset = BasicModel.objects.all() + ordering = ["text"] + + +class ForeignKeySourceViewSet(ModelViewSet): + serializer_class = ForeignKeySourceSerializer + queryset = ForeignKeySource.objects.all() + ordering = ["name"] + + +class ManyToManySourceViewSet(ModelViewSet): + serializer_class = ManyToManySourceSerializer + queryset = ManyToManySource.objects.all() + ordering = ["name"] + + +class NestedRelatedSourceViewSet(ModelViewSet): + serializer_class = NestedRelatedSourceSerializer + queryset = NestedRelatedSource.objects.all() + ordering = ["id"]