Skip to content

Commit

Permalink
Fix Enum collision with same choices & varying labels #790 #1104
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Nov 25, 2023
1 parent 82c00f8 commit 7cc36f0
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 23 deletions.
43 changes: 39 additions & 4 deletions drf_spectacular/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ def create_enum_component(name, schema):
generator.registry.register_on_missing(component)
return component

def extract_hash(schema):
if 'x-spec-enum-id' in schema:
# try to use the injected enum hash first as it generated from (name, value) tuples,
# which prevents collisions on choice sets only differing in labels not values.
return schema['x-spec-enum-id']
else:
# fall back to actual list hashing when we encounter enums not generated by us.
# remove blank/null entry for hashing. will be reconstructed in the last step
return list_hash([(i, i) for i in schema['enum'] if i not in ('', None)])

schemas = result.get('components', {}).get('schemas', {})

overrides = load_enum_name_overrides()
Expand All @@ -58,8 +68,8 @@ def create_enum_component(name, schema):
prop_schema = prop_schema.get('items', {})
if 'enum' not in prop_schema:
continue
# remove blank/null entry for hashing. will be reconstructed in the last step
prop_enum_cleaned_hash = list_hash([i for i in prop_schema['enum'] if i not in ['', None]])

prop_enum_cleaned_hash = extract_hash(prop_schema)
prop_hash_mapping[prop_name].add(prop_enum_cleaned_hash)
hash_name_mapping[prop_enum_cleaned_hash].add((component_name, prop_name))

Expand Down Expand Up @@ -110,14 +120,14 @@ def create_enum_component(name, schema):

prop_enum_original_list = prop_schema['enum']
prop_schema['enum'] = [i for i in prop_schema['enum'] if i not in ['', None]]
prop_hash = list_hash(prop_schema['enum'])
prop_hash = extract_hash(prop_schema)
# when choice sets are reused under multiple names, the generated name cannot be
# resolved from the hash alone. fall back to prop_name and hash for resolution.
enum_name = enum_name_mapping.get(prop_hash) or enum_name_mapping[prop_hash, prop_name]

# split property into remaining property and enum component parts
enum_schema = {k: v for k, v in prop_schema.items() if k in ['type', 'enum']}
prop_schema = {k: v for k, v in prop_schema.items() if k not in ['type', 'enum']}
prop_schema = {k: v for k, v in prop_schema.items() if k not in ['type', 'enum', 'x-spec-enum-id']}

# separate actual description from name-value tuples
if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION:
Expand Down Expand Up @@ -148,6 +158,31 @@ def create_enum_component(name, schema):

# sort again with additional components
result['components'] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS)

# remove remaining ids that were not part of this hook (operation parameters mainly)
postprocess_schema_enum_id_removal(result, generator)

return result


def postprocess_schema_enum_id_removal(result, generator, **kwargs):
"""
Iterative modifying approach to scanning the whole schema and removing the
temporary helper ids that allowed us to distinguish similar enums.
"""
def clean(sub_result):
if isinstance(sub_result, dict):
for key in list(sub_result):
if key == 'x-spec-enum-id':
del sub_result['x-spec-enum-id']
else:
clean(sub_result[key])
elif isinstance(sub_result, (list, tuple)):
for item in sub_result:
clean(item)

clean(result)

return result


Expand Down
15 changes: 11 additions & 4 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ def build_choice_field(field):
if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION:
schema['description'] = build_choice_description_list(field.choices.items())

schema['x-spec-enum-id'] = list_hash([(k, v) for k, v in field.choices.items() if k not in ('', None)])

return schema


Expand Down Expand Up @@ -499,10 +501,12 @@ def build_root_object(paths, components, version):
def safe_ref(schema):
"""
ensure that $ref has its own context and does not remove potential sibling
entries when $ref is substituted.
entries when $ref is substituted. also remove useless singular "allOf" .
"""
if '$ref' in schema and len(schema) > 1:
return {'allOf': [{'$ref': schema.pop('$ref')}], **schema}
if 'allOf' in schema and len(schema) == 1 and len(schema['allOf']) == 1:
return schema['allOf'][0]
return schema


Expand Down Expand Up @@ -815,11 +819,12 @@ def load_enum_name_overrides():
if inspect.isclass(choices) and issubclass(choices, Choices):
choices = choices.choices
if inspect.isclass(choices) and issubclass(choices, Enum):
choices = [c.value for c in choices]
choices = [(c.value, c.name) for c in choices]
normalized_choices = []
for choice in choices:
# Allow None values in the simple values list case
if isinstance(choice, str) or choice is None:
# TODO warning
normalized_choices.append((choice, choice)) # simple choice list
elif isinstance(choice[1], (list, tuple)):
normalized_choices.extend(choice[1]) # categorized nested choices
Expand All @@ -828,7 +833,9 @@ def load_enum_name_overrides():

# Get all of choice values that should be used in the hash, blank and None values get excluded
# in the post-processing hook for enum overrides, so we do the same here to ensure the hashes match
hashable_values = [value for value, _ in normalized_choices if value not in ['', None]]
hashable_values = [
(value, label) for value, label in normalized_choices if value not in ['', None]
]
overrides[list_hash(hashable_values)] = name

if len(spectacular_settings.ENUM_NAME_OVERRIDES) != len(overrides):
Expand All @@ -840,7 +847,7 @@ def load_enum_name_overrides():


def list_hash(lst):
return hashlib.sha256(json.dumps(list(lst), sort_keys=True, cls=JSONEncoder).encode()).hexdigest()
return hashlib.sha256(json.dumps(list(lst), sort_keys=True, cls=JSONEncoder).encode()).hexdigest()[:16]


def anchor_pattern(pattern: str) -> str:
Expand Down
18 changes: 17 additions & 1 deletion tests/test_plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from drf_spectacular.plumbing import (
analyze_named_regex_pattern, build_basic_type, build_choice_field, detype_pattern,
follow_field_source, force_instance, get_list_serializer, is_field, is_serializer,
resolve_type_hint,
resolve_type_hint, safe_ref,
)
from drf_spectacular.validation import validate_schema
from tests import generate_schema
Expand Down Expand Up @@ -377,3 +377,19 @@ def test_choicefield_choices_enum():
))
assert schema['enum'] == ['bluepill', 'redpill', '', None]
assert 'type' not in schema


def test_safe_ref():
schema = build_basic_type(str)
schema['$ref'] = '#/components/schemas/Foo'

schema = safe_ref(schema)
assert schema == {
'allOf': [{'$ref': '#/components/schemas/Foo'}],
'type': 'string'
}

del schema['type']
schema = safe_ref(schema)
assert schema == {'$ref': '#/components/schemas/Foo'}
assert safe_ref(schema) == safe_ref(schema)
65 changes: 51 additions & 14 deletions tests/test_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from rest_framework.views import APIView

try:
from django.db.models.enums import TextChoices
from django.db.models.enums import IntegerChoices, TextChoices
except ImportError:
TextChoices = object # type: ignore # django < 3.0 handling
IntegerChoices = object # type: ignore # django < 3.0 handling

from drf_spectacular.plumbing import list_hash, load_enum_name_overrides
from drf_spectacular.utils import OpenApiParameter, extend_schema
Expand Down Expand Up @@ -244,35 +245,39 @@ def partial_update(self, request):


def test_enum_override_variations(no_warnings):
enum_override_variations = ['language_list', 'LanguageEnum', 'LanguageStrEnum']
enum_override_variations = [
('language_list', [('en', 'en')]),
('LanguageEnum', [('en', 'EN')]),
('LanguageStrEnum', [('en', 'EN')]),
]
if DJANGO_VERSION > '3':
enum_override_variations += ['LanguageChoices', 'LanguageChoices.choices']
enum_override_variations += [
('LanguageChoices', [('en', 'En')]),
('LanguageChoices.choices', [('en', 'En')])
]

for variation in enum_override_variations:
for variation, expected_hashed_keys in enum_override_variations:
with mock.patch(
'drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES',
{'LanguageEnum': f'tests.test_postprocessing.{variation}'}
):
load_enum_name_overrides.cache_clear()
assert list_hash(['en']) in load_enum_name_overrides()
assert list_hash(expected_hashed_keys) in load_enum_name_overrides()


def test_enum_override_variations_with_blank_and_null(no_warnings):
enum_override_variations = [
'blank_null_language_list',
'BlankNullLanguageEnum',
('BlankNullLanguageStrEnum', ['en', 'None'])
('blank_null_language_list', [('en', 'en')]),
('BlankNullLanguageEnum', [('en', 'EN')]),
('BlankNullLanguageStrEnum', [('en', 'EN'), ('None', 'NULL')])
]
if DJANGO_VERSION > '3':
enum_override_variations += [
('BlankNullLanguageChoices', ['en', 'None']),
('BlankNullLanguageChoices.choices', ['en', 'None'])
('BlankNullLanguageChoices', [('en', 'En'), ('None', 'Null')]),
('BlankNullLanguageChoices.choices', [('en', 'En'), ('None', 'Null')])
]

for variation in enum_override_variations:
expected_hashed_keys = ['en']
if isinstance(variation, (list, tuple, )):
variation, expected_hashed_keys = variation
for variation, expected_hashed_keys in enum_override_variations:
with mock.patch(
'drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES',
{'LanguageEnum': f'tests.test_postprocessing.{variation}'}
Expand Down Expand Up @@ -340,3 +345,35 @@ def get(self, request):
uuid.UUID('93d7527f-de3c-4a76-9cc2-5578675630d4'),
uuid.UUID('47a4b873-409e-4e43-81d5-fafc3faeb849')
]


@pytest.mark.skipif(DJANGO_VERSION < '3', reason='Not available before Django 3.0')
def test_equal_choices_different_semantics(no_warnings):

class Health(IntegerChoices):
OK = 0
FAIL = 1

class Status(IntegerChoices):
GREEN = 0
RED = 1

class XSerializer(serializers.Serializer):
some_health = serializers.ChoiceField(choices=Health.choices)
some_status = serializers.ChoiceField(choices=Status.choices)

class XAPIView(APIView):
@extend_schema(responses=XSerializer)
def get(self, request):
pass # pragma: no cover

# This should not generate a warning even though the enum list is identical
# in both Enums. We now also differentiate the Enums by their labels.
schema = generate_schema('x', view=XAPIView)

assert schema['components']['schemas']['SomeHealthEnum'] == {
'enum': [0, 1], 'type': 'integer', 'description': '* `0` - Ok\n* `1` - Fail'
}
assert schema['components']['schemas']['SomeStatusEnum'] == {
'enum': [0, 1], 'type': 'integer', 'description': '* `0` - Green\n* `1` - Red'
}

0 comments on commit 7cc36f0

Please sign in to comment.