diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b05600b..186c58c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ any parts of the framework not mentioned in the documentation should generally b * Avoided that an empty attributes dict is rendered in case serializer does not provide any attribute fields. * Avoided shadowing of exception when rendering errors (regression since 4.3.0). +* Ensured that sparse fields only applies when rendering, not when parsing. +* Adjusted that sparse fields properly removes meta fields when not defined. ### Removed diff --git a/example/tests/integration/test_non_paginated_responses.py b/example/tests/integration/test_non_paginated_responses.py index 92d26de3..6434b7a7 100644 --- a/example/tests/integration/test_non_paginated_responses.py +++ b/example/tests/integration/test_non_paginated_responses.py @@ -14,7 +14,7 @@ def test_multiple_entries_no_pagination(multiple_entries, client): expected = { "data": [ { - "type": "posts", + "type": "entries", "id": "1", "attributes": { "headline": multiple_entries[0].headline, @@ -70,7 +70,7 @@ def test_multiple_entries_no_pagination(multiple_entries, client): }, }, { - "type": "posts", + "type": "entries", "id": "2", "attributes": { "headline": multiple_entries[1].headline, diff --git a/example/tests/integration/test_pagination.py b/example/tests/integration/test_pagination.py index 1a4bd056..0f5ac17e 100644 --- a/example/tests/integration/test_pagination.py +++ b/example/tests/integration/test_pagination.py @@ -14,7 +14,7 @@ def test_pagination_with_single_entry(single_entry, client): expected = { "data": [ { - "type": "posts", + "type": "entries", "id": "1", "attributes": { "headline": single_entry.headline, diff --git a/example/tests/integration/test_sparse_fieldsets.py b/example/tests/integration/test_sparse_fieldsets.py index 605d218d..cf9cee20 100644 --- a/example/tests/integration/test_sparse_fieldsets.py +++ b/example/tests/integration/test_sparse_fieldsets.py @@ -15,6 +15,7 @@ def test_sparse_fieldset_valid_fields(client, entry): entry = data[0] assert entry["attributes"].keys() == {"headline"} assert entry["relationships"].keys() == {"blog"} + assert "meta" not in entry @pytest.mark.parametrize( diff --git a/example/tests/test_filters.py b/example/tests/test_filters.py index 87f9d059..8e45ded1 100644 --- a/example/tests/test_filters.py +++ b/example/tests/test_filters.py @@ -470,7 +470,7 @@ def test_search_keywords(self): expected_result = { "data": [ { - "type": "posts", + "type": "entries", "id": "7", "attributes": { "headline": "ANTH3868X", diff --git a/example/tests/unit/test_renderer_class_methods.py b/example/tests/unit/test_renderer_class_methods.py index 6e7c9ea1..838f6819 100644 --- a/example/tests/unit/test_renderer_class_methods.py +++ b/example/tests/unit/test_renderer_class_methods.py @@ -102,8 +102,8 @@ def test_extract_attributes(): assert sorted(JSONRenderer.extract_attributes(fields, resource)) == sorted( expected ), "Regular fields should be extracted" - assert sorted(JSONRenderer.extract_attributes(fields, {})) == sorted( - {"username": ""} + assert ( + JSONRenderer.extract_attributes(fields, {}) == {} ), "Should not extract read_only fields on empty serializer" diff --git a/example/views.py b/example/views.py index 9c949684..da171698 100644 --- a/example/views.py +++ b/example/views.py @@ -112,7 +112,10 @@ class BlogCustomViewSet(JsonApiViewSet): class EntryViewSet(ModelViewSet): queryset = Entry.objects.all() - resource_name = "posts" + # TODO it should not be supported to overwrite resource name + # of endpoints with serializers as includes and sparse fields + # cannot be aware of it + # resource_name = "posts" def get_serializer_class(self): return EntrySerializer diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 639f0b11..8c632f6a 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -17,13 +17,26 @@ from rest_framework.settings import api_settings import rest_framework_json_api -from rest_framework_json_api import utils from rest_framework_json_api.relations import ( HyperlinkedMixin, ManySerializerMethodResourceRelatedField, ResourceRelatedField, SkipDataMixin, ) +from rest_framework_json_api.utils import ( + format_errors, + format_field_name, + format_field_names, + get_included_resources, + get_related_resource_type, + get_relation_instance, + get_resource_id, + get_resource_name, + get_resource_type_from_instance, + get_resource_type_from_serializer, + get_serializer_fields, + is_relationship_field, +) class JSONRenderer(renderers.JSONRenderer): @@ -57,31 +70,20 @@ class JSONRenderer(renderers.JSONRenderer): def extract_attributes(cls, fields, resource): """ Builds the `attributes` object of the JSON:API resource object. - """ - data = {} - for field_name, field in iter(fields.items()): - # ID is always provided in the root of JSON:API so remove it from attributes - if field_name == "id": - continue - # don't output a key for write only fields - if fields[field_name].write_only: - continue - # Skip fields with relations - if utils.is_relationship_field(field): - continue - # Skip read_only attribute fields when `resource` is an empty - # serializer. Prevents the "Raw Data" form of the browsable API - # from rendering `"foo": null` for read only fields - try: - resource[field_name] - except KeyError: - if fields[field_name].read_only: - continue + Ensures that ID which is always provided in a JSON:API resource object + and relationships are not returned. + """ - data.update({field_name: resource.get(field_name)}) + invalid_fields = {"id", api_settings.URL_FIELD_NAME} - return utils.format_field_names(data) + return { + format_field_name(field_name): value + for field_name, value in resource.items() + if field_name in fields + and field_name not in invalid_fields + and not is_relationship_field(fields[field_name]) + } @classmethod def extract_relationships(cls, fields, resource, resource_instance): @@ -107,14 +109,14 @@ def extract_relationships(cls, fields, resource, resource_instance): continue # Skip fields without relations - if not utils.is_relationship_field(field): + if not is_relationship_field(field): continue source = field.source - relation_type = utils.get_related_resource_type(field) + relation_type = get_related_resource_type(field) if isinstance(field, relations.HyperlinkedIdentityField): - resolved, relation_instance = utils.get_relation_instance( + resolved, relation_instance = get_relation_instance( resource_instance, source, field.parent ) if not resolved: @@ -166,7 +168,7 @@ def extract_relationships(cls, fields, resource, resource_instance): field, (relations.PrimaryKeyRelatedField, relations.HyperlinkedRelatedField), ): - resolved, relation = utils.get_relation_instance( + resolved, relation = get_relation_instance( resource_instance, f"{source}_id", field.parent ) if not resolved: @@ -189,7 +191,7 @@ def extract_relationships(cls, fields, resource, resource_instance): continue if isinstance(field, relations.ManyRelatedField): - resolved, relation_instance = utils.get_relation_instance( + resolved, relation_instance = get_relation_instance( resource_instance, source, field.parent ) if not resolved: @@ -222,9 +224,7 @@ def extract_relationships(cls, fields, resource, resource_instance): for nested_resource_instance in relation_instance: nested_resource_instance_type = ( relation_type - or utils.get_resource_type_from_instance( - nested_resource_instance - ) + or get_resource_type_from_instance(nested_resource_instance) ) relation_data.append( @@ -243,7 +243,7 @@ def extract_relationships(cls, fields, resource, resource_instance): ) continue - return utils.format_field_names(data) + return format_field_names(data) @classmethod def extract_relation_instance(cls, field, resource_instance): @@ -289,7 +289,7 @@ def extract_included( continue # Skip fields without relations - if not utils.is_relationship_field(field): + if not is_relationship_field(field): continue try: @@ -341,7 +341,7 @@ def extract_included( if isinstance(field, ListSerializer): serializer = field.child - relation_type = utils.get_resource_type_from_serializer(serializer) + relation_type = get_resource_type_from_serializer(serializer) relation_queryset = list(relation_instance) if serializer_data: @@ -350,11 +350,9 @@ def extract_included( nested_resource_instance = relation_queryset[position] resource_type = ( relation_type - or utils.get_resource_type_from_instance( - nested_resource_instance - ) + or get_resource_type_from_instance(nested_resource_instance) ) - serializer_fields = utils.get_serializer_fields( + serializer_fields = get_serializer_fields( serializer.__class__( nested_resource_instance, context=serializer.context ) @@ -378,10 +376,10 @@ def extract_included( ) if isinstance(field, Serializer): - relation_type = utils.get_resource_type_from_serializer(field) + relation_type = get_resource_type_from_serializer(field) # Get the serializer fields - serializer_fields = utils.get_serializer_fields(field) + serializer_fields = get_serializer_fields(field) if serializer_data: new_item = cls.build_json_resource_obj( serializer_fields, @@ -414,7 +412,8 @@ def extract_meta(cls, serializer, resource): meta_fields = getattr(meta, "meta_fields", []) data = {} for field_name in meta_fields: - data.update({field_name: resource.get(field_name)}) + if field_name in resource: + data.update({field_name: resource[field_name]}) return data @classmethod @@ -434,6 +433,24 @@ def extract_root_meta(cls, serializer, resource): data.update(json_api_meta) return data + @classmethod + def _filter_sparse_fields(cls, serializer, fields, resource_name): + request = serializer.context.get("request") + if request: + sparse_fieldset_query_param = f"fields[{resource_name}]" + sparse_fieldset_value = request.query_params.get( + sparse_fieldset_query_param + ) + if sparse_fieldset_value: + sparse_fields = sparse_fieldset_value.split(",") + return { + field_name: field + for field_name, field, in fields.items() + if field_name in sparse_fields + } + + return fields + @classmethod def build_json_resource_obj( cls, @@ -449,11 +466,15 @@ def build_json_resource_obj( """ # Determine type from the instance if the underlying model is polymorphic if force_type_resolution: - resource_name = utils.get_resource_type_from_instance(resource_instance) + resource_name = get_resource_type_from_instance(resource_instance) resource_data = { "type": resource_name, - "id": utils.get_resource_id(resource_instance, resource), + "id": get_resource_id(resource_instance, resource), } + + # TODO remove this filter by rewriting extract_relationships + # so it uses the serialized data as a basis + fields = cls._filter_sparse_fields(serializer, fields, resource_name) attributes = cls.extract_attributes(fields, resource) if attributes: resource_data["attributes"] = attributes @@ -468,7 +489,7 @@ def build_json_resource_obj( meta = cls.extract_meta(serializer, resource) if meta: - resource_data["meta"] = utils.format_field_names(meta) + resource_data["meta"] = format_field_names(meta) return resource_data @@ -485,7 +506,7 @@ def render_relationship_view( def render_errors(self, data, accepted_media_type=None, renderer_context=None): return super().render( - utils.format_errors(data), accepted_media_type, renderer_context + format_errors(data), accepted_media_type, renderer_context ) def render(self, data, accepted_media_type=None, renderer_context=None): @@ -495,7 +516,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): request = renderer_context.get("request", None) # Get the resource name. - resource_name = utils.get_resource_name(renderer_context) + resource_name = get_resource_name(renderer_context) # If this is an error response, skip the rest. if resource_name == "errors": @@ -531,7 +552,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): serializer = getattr(serializer_data, "serializer", None) - included_resources = utils.get_included_resources(request, serializer) + included_resources = get_included_resources(request, serializer) if serializer is not None: # Extract root meta for any type of serializer @@ -558,7 +579,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): else: resource_serializer_class = serializer.child - fields = utils.get_serializer_fields(resource_serializer_class) + fields = get_serializer_fields(resource_serializer_class) force_type_resolution = getattr( resource_serializer_class, "_poly_force_type_resolution", False ) @@ -581,7 +602,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): included_cache, ) else: - fields = utils.get_serializer_fields(serializer) + fields = get_serializer_fields(serializer) force_type_resolution = getattr( serializer, "_poly_force_type_resolution", False ) @@ -640,7 +661,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): ) if json_api_meta: - render_data["meta"] = utils.format_field_names(json_api_meta) + render_data["meta"] = format_field_names(json_api_meta) return super().render(render_data, accepted_media_type, renderer_context) @@ -690,7 +711,6 @@ def get_includes_form(self, view): serializer_class = view.get_serializer_class() except AttributeError: return - if not hasattr(serializer_class, "included_serializers"): return diff --git a/rest_framework_json_api/serializers.py b/rest_framework_json_api/serializers.py index c680f60a..66650caf 100644 --- a/rest_framework_json_api/serializers.py +++ b/rest_framework_json_api/serializers.py @@ -75,35 +75,32 @@ class SparseFieldsetsMixin: Specification: https://jsonapi.org/format/#fetching-sparse-fieldsets """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - context = kwargs.get("context") - request = context.get("request") if context else None + @property + def _readable_fields(self): + request = self.context.get("request") if self.context else None + readable_fields = super()._readable_fields if request: - sparse_fieldset_query_param = "fields[{}]".format( - get_resource_type_from_serializer(self) - ) try: - param_name = next( - key - for key in request.query_params - if sparse_fieldset_query_param == key + resource_type = get_resource_type_from_serializer(self) + sparse_fieldset_query_param = f"fields[{resource_type}]" + + sparse_fieldset_value = request.query_params.get( + sparse_fieldset_query_param ) - except StopIteration: + if sparse_fieldset_value: + sparse_fields = sparse_fieldset_value.split(",") + return ( + field + for field in readable_fields + if field.field_name in sparse_fields + or field.field_name == api_settings.URL_FIELD_NAME + ) + except AttributeError: + # no type on serializer, must be used only as only nested pass - else: - fieldset = request.query_params.get(param_name).split(",") - # iterate over a *copy* of self.fields' underlying dict, because we may - # modify the original during the iteration. - # self.fields is a `rest_framework.utils.serializer_helpers.BindingDict` - for field_name, _field in self.fields.fields.copy().items(): - if ( - field_name == api_settings.URL_FIELD_NAME - ): # leave self link there - continue - if field_name not in fieldset: - self.fields.pop(field_name) + + return readable_fields class IncludedResourcesValidationMixin: diff --git a/tests/serializers.py b/tests/serializers.py index 4159039d..ddf28f98 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -26,7 +26,10 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeySource - fields = ("target",) + fields = ( + "name", + "target", + ) class ManyToManyTargetSerializer(serializers.ModelSerializer): diff --git a/tests/test_relations.py b/tests/test_relations.py index ad4bebcb..5f8883be1 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -39,7 +39,9 @@ def test_serialize( settings.JSON_API_FORMAT_TYPES = format_type settings.JSON_API_PLURALIZE_TYPES = pluralize_type - serializer = ForeignKeySourceSerializer(instance={"target": foreign_key_target}) + serializer = ForeignKeySourceSerializer( + instance={"target": foreign_key_target, "name": "Test"} + ) expected = { "type": resource_type, "id": str(foreign_key_target.pk), @@ -85,7 +87,10 @@ def test_deserialize( settings.JSON_API_PLURALIZE_TYPES = pluralize_type serializer = ForeignKeySourceSerializer( - data={"target": {"type": resource_type, "id": str(foreign_key_target.pk)}} + data={ + "target": {"type": resource_type, "id": str(foreign_key_target.pk)}, + "name": "Test", + } ) assert serializer.is_valid() @@ -191,7 +196,9 @@ def test_deserialize_many_to_many_relation( ], ) def test_invalid_resource_id_object(self, resource_identifier, error): - serializer = ForeignKeySourceSerializer(data={"target": resource_identifier}) + serializer = ForeignKeySourceSerializer( + data={"target": resource_identifier, "name": "Test"} + ) assert not serializer.is_valid() assert serializer.errors == {"target": [error]} diff --git a/tests/test_views.py b/tests/test_views.py index acba7e66..468c2cbd 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -237,6 +237,43 @@ def test_delete(self, client, model): assert BasicModel.objects.count() == 0 assert len(response.rendered_content) == 0 + @pytest.mark.urls(__name__) + def test_create_with_sparse_fields(self, client, foreign_key_target): + url = reverse("foreign-key-source-list") + data = { + "data": { + "id": None, + "type": "ForeignKeySource", + "attributes": {"name": "Test"}, + "relationships": { + "target": { + "data": { + "id": str(foreign_key_target.pk), + "type": "ForeignKeyTarget", + } + } + }, + } + } + response = client.post(f"{url}?fields[ForeignKeySource]=target", data=data) + assert response.status_code == status.HTTP_201_CREATED + foreign_key_source = ForeignKeySource.objects.first() + assert foreign_key_source.name == "Test" + assert response.json() == { + "data": { + "id": str(foreign_key_source.pk), + "type": "ForeignKeySource", + "relationships": { + "target": { + "data": { + "id": str(foreign_key_target.pk), + "type": "ForeignKeyTarget", + } + } + }, + } + } + class TestReadonlyModelViewSet: @pytest.mark.parametrize(