Skip to content

Commit

Permalink
Ensured that sparse fields only applies when rendering not when parsi…
Browse files Browse the repository at this point in the history
…ng (#1221)

* Added missing name field to ForeignKeySourceSerializer

* Only extract attributes provided by serialized data

* Added changelog
  • Loading branch information
sliverc authored May 1, 2024
1 parent 0eabc39 commit 6c609f8
Show file tree
Hide file tree
Showing 12 changed files with 156 additions and 86 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ any parts of the framework not mentioned in the documentation should generally b
* Avoided that an empty attributes dict is rendered in case serializer does not
provide any attribute fields.
* Avoided shadowing of exception when rendering errors (regression since 4.3.0).
* Ensured that sparse fields only applies when rendering, not when parsing.
* Adjusted that sparse fields properly removes meta fields when not defined.

### Removed

Expand Down
4 changes: 2 additions & 2 deletions example/tests/integration/test_non_paginated_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_multiple_entries_no_pagination(multiple_entries, client):
expected = {
"data": [
{
"type": "posts",
"type": "entries",
"id": "1",
"attributes": {
"headline": multiple_entries[0].headline,
Expand Down Expand Up @@ -70,7 +70,7 @@ def test_multiple_entries_no_pagination(multiple_entries, client):
},
},
{
"type": "posts",
"type": "entries",
"id": "2",
"attributes": {
"headline": multiple_entries[1].headline,
Expand Down
2 changes: 1 addition & 1 deletion example/tests/integration/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_pagination_with_single_entry(single_entry, client):
expected = {
"data": [
{
"type": "posts",
"type": "entries",
"id": "1",
"attributes": {
"headline": single_entry.headline,
Expand Down
1 change: 1 addition & 0 deletions example/tests/integration/test_sparse_fieldsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_sparse_fieldset_valid_fields(client, entry):
entry = data[0]
assert entry["attributes"].keys() == {"headline"}
assert entry["relationships"].keys() == {"blog"}
assert "meta" not in entry


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion example/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def test_search_keywords(self):
expected_result = {
"data": [
{
"type": "posts",
"type": "entries",
"id": "7",
"attributes": {
"headline": "ANTH3868X",
Expand Down
4 changes: 2 additions & 2 deletions example/tests/unit/test_renderer_class_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def test_extract_attributes():
assert sorted(JSONRenderer.extract_attributes(fields, resource)) == sorted(
expected
), "Regular fields should be extracted"
assert sorted(JSONRenderer.extract_attributes(fields, {})) == sorted(
{"username": ""}
assert (
JSONRenderer.extract_attributes(fields, {}) == {}
), "Should not extract read_only fields on empty serializer"


Expand Down
5 changes: 4 additions & 1 deletion example/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ class BlogCustomViewSet(JsonApiViewSet):

class EntryViewSet(ModelViewSet):
queryset = Entry.objects.all()
resource_name = "posts"
# TODO it should not be supported to overwrite resource name
# of endpoints with serializers as includes and sparse fields
# cannot be aware of it
# resource_name = "posts"

def get_serializer_class(self):
return EntrySerializer
Expand Down
122 changes: 71 additions & 51 deletions rest_framework_json_api/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,26 @@
from rest_framework.settings import api_settings

import rest_framework_json_api
from rest_framework_json_api import utils
from rest_framework_json_api.relations import (
HyperlinkedMixin,
ManySerializerMethodResourceRelatedField,
ResourceRelatedField,
SkipDataMixin,
)
from rest_framework_json_api.utils import (
format_errors,
format_field_name,
format_field_names,
get_included_resources,
get_related_resource_type,
get_relation_instance,
get_resource_id,
get_resource_name,
get_resource_type_from_instance,
get_resource_type_from_serializer,
get_serializer_fields,
is_relationship_field,
)


class JSONRenderer(renderers.JSONRenderer):
Expand Down Expand Up @@ -57,31 +70,20 @@ class JSONRenderer(renderers.JSONRenderer):
def extract_attributes(cls, fields, resource):
"""
Builds the `attributes` object of the JSON:API resource object.
"""
data = {}
for field_name, field in iter(fields.items()):
# ID is always provided in the root of JSON:API so remove it from attributes
if field_name == "id":
continue
# don't output a key for write only fields
if fields[field_name].write_only:
continue
# Skip fields with relations
if utils.is_relationship_field(field):
continue
# Skip read_only attribute fields when `resource` is an empty
# serializer. Prevents the "Raw Data" form of the browsable API
# from rendering `"foo": null` for read only fields
try:
resource[field_name]
except KeyError:
if fields[field_name].read_only:
continue
Ensures that ID which is always provided in a JSON:API resource object
and relationships are not returned.
"""

data.update({field_name: resource.get(field_name)})
invalid_fields = {"id", api_settings.URL_FIELD_NAME}

return utils.format_field_names(data)
return {
format_field_name(field_name): value
for field_name, value in resource.items()
if field_name in fields
and field_name not in invalid_fields
and not is_relationship_field(fields[field_name])
}

@classmethod
def extract_relationships(cls, fields, resource, resource_instance):
Expand All @@ -107,14 +109,14 @@ def extract_relationships(cls, fields, resource, resource_instance):
continue

# Skip fields without relations
if not utils.is_relationship_field(field):
if not is_relationship_field(field):
continue

source = field.source
relation_type = utils.get_related_resource_type(field)
relation_type = get_related_resource_type(field)

if isinstance(field, relations.HyperlinkedIdentityField):
resolved, relation_instance = utils.get_relation_instance(
resolved, relation_instance = get_relation_instance(
resource_instance, source, field.parent
)
if not resolved:
Expand Down Expand Up @@ -166,7 +168,7 @@ def extract_relationships(cls, fields, resource, resource_instance):
field,
(relations.PrimaryKeyRelatedField, relations.HyperlinkedRelatedField),
):
resolved, relation = utils.get_relation_instance(
resolved, relation = get_relation_instance(
resource_instance, f"{source}_id", field.parent
)
if not resolved:
Expand All @@ -189,7 +191,7 @@ def extract_relationships(cls, fields, resource, resource_instance):
continue

if isinstance(field, relations.ManyRelatedField):
resolved, relation_instance = utils.get_relation_instance(
resolved, relation_instance = get_relation_instance(
resource_instance, source, field.parent
)
if not resolved:
Expand Down Expand Up @@ -222,9 +224,7 @@ def extract_relationships(cls, fields, resource, resource_instance):
for nested_resource_instance in relation_instance:
nested_resource_instance_type = (
relation_type
or utils.get_resource_type_from_instance(
nested_resource_instance
)
or get_resource_type_from_instance(nested_resource_instance)
)

relation_data.append(
Expand All @@ -243,7 +243,7 @@ def extract_relationships(cls, fields, resource, resource_instance):
)
continue

return utils.format_field_names(data)
return format_field_names(data)

@classmethod
def extract_relation_instance(cls, field, resource_instance):
Expand Down Expand Up @@ -289,7 +289,7 @@ def extract_included(
continue

# Skip fields without relations
if not utils.is_relationship_field(field):
if not is_relationship_field(field):
continue

try:
Expand Down Expand Up @@ -341,7 +341,7 @@ def extract_included(

if isinstance(field, ListSerializer):
serializer = field.child
relation_type = utils.get_resource_type_from_serializer(serializer)
relation_type = get_resource_type_from_serializer(serializer)
relation_queryset = list(relation_instance)

if serializer_data:
Expand All @@ -350,11 +350,9 @@ def extract_included(
nested_resource_instance = relation_queryset[position]
resource_type = (
relation_type
or utils.get_resource_type_from_instance(
nested_resource_instance
)
or get_resource_type_from_instance(nested_resource_instance)
)
serializer_fields = utils.get_serializer_fields(
serializer_fields = get_serializer_fields(
serializer.__class__(
nested_resource_instance, context=serializer.context
)
Expand All @@ -378,10 +376,10 @@ def extract_included(
)

if isinstance(field, Serializer):
relation_type = utils.get_resource_type_from_serializer(field)
relation_type = get_resource_type_from_serializer(field)

# Get the serializer fields
serializer_fields = utils.get_serializer_fields(field)
serializer_fields = get_serializer_fields(field)
if serializer_data:
new_item = cls.build_json_resource_obj(
serializer_fields,
Expand Down Expand Up @@ -414,7 +412,8 @@ def extract_meta(cls, serializer, resource):
meta_fields = getattr(meta, "meta_fields", [])
data = {}
for field_name in meta_fields:
data.update({field_name: resource.get(field_name)})
if field_name in resource:
data.update({field_name: resource[field_name]})
return data

@classmethod
Expand All @@ -434,6 +433,24 @@ def extract_root_meta(cls, serializer, resource):
data.update(json_api_meta)
return data

@classmethod
def _filter_sparse_fields(cls, serializer, fields, resource_name):
request = serializer.context.get("request")
if request:
sparse_fieldset_query_param = f"fields[{resource_name}]"
sparse_fieldset_value = request.query_params.get(
sparse_fieldset_query_param
)
if sparse_fieldset_value:
sparse_fields = sparse_fieldset_value.split(",")
return {
field_name: field
for field_name, field, in fields.items()
if field_name in sparse_fields
}

return fields

@classmethod
def build_json_resource_obj(
cls,
Expand All @@ -449,11 +466,15 @@ def build_json_resource_obj(
"""
# Determine type from the instance if the underlying model is polymorphic
if force_type_resolution:
resource_name = utils.get_resource_type_from_instance(resource_instance)
resource_name = get_resource_type_from_instance(resource_instance)
resource_data = {
"type": resource_name,
"id": utils.get_resource_id(resource_instance, resource),
"id": get_resource_id(resource_instance, resource),
}

# TODO remove this filter by rewriting extract_relationships
# so it uses the serialized data as a basis
fields = cls._filter_sparse_fields(serializer, fields, resource_name)
attributes = cls.extract_attributes(fields, resource)
if attributes:
resource_data["attributes"] = attributes
Expand All @@ -468,7 +489,7 @@ def build_json_resource_obj(

meta = cls.extract_meta(serializer, resource)
if meta:
resource_data["meta"] = utils.format_field_names(meta)
resource_data["meta"] = format_field_names(meta)

return resource_data

Expand All @@ -485,7 +506,7 @@ def render_relationship_view(

def render_errors(self, data, accepted_media_type=None, renderer_context=None):
return super().render(
utils.format_errors(data), accepted_media_type, renderer_context
format_errors(data), accepted_media_type, renderer_context
)

def render(self, data, accepted_media_type=None, renderer_context=None):
Expand All @@ -495,7 +516,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
request = renderer_context.get("request", None)

# Get the resource name.
resource_name = utils.get_resource_name(renderer_context)
resource_name = get_resource_name(renderer_context)

# If this is an error response, skip the rest.
if resource_name == "errors":
Expand Down Expand Up @@ -531,7 +552,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None):

serializer = getattr(serializer_data, "serializer", None)

included_resources = utils.get_included_resources(request, serializer)
included_resources = get_included_resources(request, serializer)

if serializer is not None:
# Extract root meta for any type of serializer
Expand All @@ -558,7 +579,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
else:
resource_serializer_class = serializer.child

fields = utils.get_serializer_fields(resource_serializer_class)
fields = get_serializer_fields(resource_serializer_class)
force_type_resolution = getattr(
resource_serializer_class, "_poly_force_type_resolution", False
)
Expand All @@ -581,7 +602,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
included_cache,
)
else:
fields = utils.get_serializer_fields(serializer)
fields = get_serializer_fields(serializer)
force_type_resolution = getattr(
serializer, "_poly_force_type_resolution", False
)
Expand Down Expand Up @@ -640,7 +661,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
)

if json_api_meta:
render_data["meta"] = utils.format_field_names(json_api_meta)
render_data["meta"] = format_field_names(json_api_meta)

return super().render(render_data, accepted_media_type, renderer_context)

Expand Down Expand Up @@ -690,7 +711,6 @@ def get_includes_form(self, view):
serializer_class = view.get_serializer_class()
except AttributeError:
return

if not hasattr(serializer_class, "included_serializers"):
return

Expand Down
Loading

0 comments on commit 6c609f8

Please sign in to comment.