From 7cc36f000364734d8d5abf086e735eb92e94dae2 Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Fri, 24 Nov 2023 21:43:24 +0100 Subject: [PATCH] Fix Enum collision with same choices & varying labels #790 #1104 --- drf_spectacular/hooks.py | 43 +++++++++++++++++++++--- drf_spectacular/plumbing.py | 15 ++++++--- tests/test_plumbing.py | 18 +++++++++- tests/test_postprocessing.py | 65 ++++++++++++++++++++++++++++-------- 4 files changed, 118 insertions(+), 23 deletions(-) diff --git a/drf_spectacular/hooks.py b/drf_spectacular/hooks.py index 1ab6e742..e10e48ba 100644 --- a/drf_spectacular/hooks.py +++ b/drf_spectacular/hooks.py @@ -45,6 +45,16 @@ def create_enum_component(name, schema): generator.registry.register_on_missing(component) return component + def extract_hash(schema): + if 'x-spec-enum-id' in schema: + # try to use the injected enum hash first as it generated from (name, value) tuples, + # which prevents collisions on choice sets only differing in labels not values. + return schema['x-spec-enum-id'] + else: + # fall back to actual list hashing when we encounter enums not generated by us. + # remove blank/null entry for hashing. will be reconstructed in the last step + return list_hash([(i, i) for i in schema['enum'] if i not in ('', None)]) + schemas = result.get('components', {}).get('schemas', {}) overrides = load_enum_name_overrides() @@ -58,8 +68,8 @@ def create_enum_component(name, schema): prop_schema = prop_schema.get('items', {}) if 'enum' not in prop_schema: continue - # remove blank/null entry for hashing. will be reconstructed in the last step - prop_enum_cleaned_hash = list_hash([i for i in prop_schema['enum'] if i not in ['', None]]) + + prop_enum_cleaned_hash = extract_hash(prop_schema) prop_hash_mapping[prop_name].add(prop_enum_cleaned_hash) hash_name_mapping[prop_enum_cleaned_hash].add((component_name, prop_name)) @@ -110,14 +120,14 @@ def create_enum_component(name, schema): prop_enum_original_list = prop_schema['enum'] prop_schema['enum'] = [i for i in prop_schema['enum'] if i not in ['', None]] - prop_hash = list_hash(prop_schema['enum']) + prop_hash = extract_hash(prop_schema) # when choice sets are reused under multiple names, the generated name cannot be # resolved from the hash alone. fall back to prop_name and hash for resolution. enum_name = enum_name_mapping.get(prop_hash) or enum_name_mapping[prop_hash, prop_name] # split property into remaining property and enum component parts enum_schema = {k: v for k, v in prop_schema.items() if k in ['type', 'enum']} - prop_schema = {k: v for k, v in prop_schema.items() if k not in ['type', 'enum']} + prop_schema = {k: v for k, v in prop_schema.items() if k not in ['type', 'enum', 'x-spec-enum-id']} # separate actual description from name-value tuples if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION: @@ -148,6 +158,31 @@ def create_enum_component(name, schema): # sort again with additional components result['components'] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS) + + # remove remaining ids that were not part of this hook (operation parameters mainly) + postprocess_schema_enum_id_removal(result, generator) + + return result + + +def postprocess_schema_enum_id_removal(result, generator, **kwargs): + """ + Iterative modifying approach to scanning the whole schema and removing the + temporary helper ids that allowed us to distinguish similar enums. + """ + def clean(sub_result): + if isinstance(sub_result, dict): + for key in list(sub_result): + if key == 'x-spec-enum-id': + del sub_result['x-spec-enum-id'] + else: + clean(sub_result[key]) + elif isinstance(sub_result, (list, tuple)): + for item in sub_result: + clean(item) + + clean(result) + return result diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 95154759..ce40eaba 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -431,6 +431,8 @@ def build_choice_field(field): if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION: schema['description'] = build_choice_description_list(field.choices.items()) + schema['x-spec-enum-id'] = list_hash([(k, v) for k, v in field.choices.items() if k not in ('', None)]) + return schema @@ -499,10 +501,12 @@ def build_root_object(paths, components, version): def safe_ref(schema): """ ensure that $ref has its own context and does not remove potential sibling - entries when $ref is substituted. + entries when $ref is substituted. also remove useless singular "allOf" . """ if '$ref' in schema and len(schema) > 1: return {'allOf': [{'$ref': schema.pop('$ref')}], **schema} + if 'allOf' in schema and len(schema) == 1 and len(schema['allOf']) == 1: + return schema['allOf'][0] return schema @@ -815,11 +819,12 @@ def load_enum_name_overrides(): if inspect.isclass(choices) and issubclass(choices, Choices): choices = choices.choices if inspect.isclass(choices) and issubclass(choices, Enum): - choices = [c.value for c in choices] + choices = [(c.value, c.name) for c in choices] normalized_choices = [] for choice in choices: # Allow None values in the simple values list case if isinstance(choice, str) or choice is None: + # TODO warning normalized_choices.append((choice, choice)) # simple choice list elif isinstance(choice[1], (list, tuple)): normalized_choices.extend(choice[1]) # categorized nested choices @@ -828,7 +833,9 @@ def load_enum_name_overrides(): # Get all of choice values that should be used in the hash, blank and None values get excluded # in the post-processing hook for enum overrides, so we do the same here to ensure the hashes match - hashable_values = [value for value, _ in normalized_choices if value not in ['', None]] + hashable_values = [ + (value, label) for value, label in normalized_choices if value not in ['', None] + ] overrides[list_hash(hashable_values)] = name if len(spectacular_settings.ENUM_NAME_OVERRIDES) != len(overrides): @@ -840,7 +847,7 @@ def load_enum_name_overrides(): def list_hash(lst): - return hashlib.sha256(json.dumps(list(lst), sort_keys=True, cls=JSONEncoder).encode()).hexdigest() + return hashlib.sha256(json.dumps(list(lst), sort_keys=True, cls=JSONEncoder).encode()).hexdigest()[:16] def anchor_pattern(pattern: str) -> str: diff --git a/tests/test_plumbing.py b/tests/test_plumbing.py index 7a7020f5..bd748c67 100644 --- a/tests/test_plumbing.py +++ b/tests/test_plumbing.py @@ -22,7 +22,7 @@ from drf_spectacular.plumbing import ( analyze_named_regex_pattern, build_basic_type, build_choice_field, detype_pattern, follow_field_source, force_instance, get_list_serializer, is_field, is_serializer, - resolve_type_hint, + resolve_type_hint, safe_ref, ) from drf_spectacular.validation import validate_schema from tests import generate_schema @@ -377,3 +377,19 @@ def test_choicefield_choices_enum(): )) assert schema['enum'] == ['bluepill', 'redpill', '', None] assert 'type' not in schema + + +def test_safe_ref(): + schema = build_basic_type(str) + schema['$ref'] = '#/components/schemas/Foo' + + schema = safe_ref(schema) + assert schema == { + 'allOf': [{'$ref': '#/components/schemas/Foo'}], + 'type': 'string' + } + + del schema['type'] + schema = safe_ref(schema) + assert schema == {'$ref': '#/components/schemas/Foo'} + assert safe_ref(schema) == safe_ref(schema) diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index 3fcf61e4..0504a2de 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -9,9 +9,10 @@ from rest_framework.views import APIView try: - from django.db.models.enums import TextChoices + from django.db.models.enums import IntegerChoices, TextChoices except ImportError: TextChoices = object # type: ignore # django < 3.0 handling + IntegerChoices = object # type: ignore # django < 3.0 handling from drf_spectacular.plumbing import list_hash, load_enum_name_overrides from drf_spectacular.utils import OpenApiParameter, extend_schema @@ -244,35 +245,39 @@ def partial_update(self, request): def test_enum_override_variations(no_warnings): - enum_override_variations = ['language_list', 'LanguageEnum', 'LanguageStrEnum'] + enum_override_variations = [ + ('language_list', [('en', 'en')]), + ('LanguageEnum', [('en', 'EN')]), + ('LanguageStrEnum', [('en', 'EN')]), + ] if DJANGO_VERSION > '3': - enum_override_variations += ['LanguageChoices', 'LanguageChoices.choices'] + enum_override_variations += [ + ('LanguageChoices', [('en', 'En')]), + ('LanguageChoices.choices', [('en', 'En')]) + ] - for variation in enum_override_variations: + for variation, expected_hashed_keys in enum_override_variations: with mock.patch( 'drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES', {'LanguageEnum': f'tests.test_postprocessing.{variation}'} ): load_enum_name_overrides.cache_clear() - assert list_hash(['en']) in load_enum_name_overrides() + assert list_hash(expected_hashed_keys) in load_enum_name_overrides() def test_enum_override_variations_with_blank_and_null(no_warnings): enum_override_variations = [ - 'blank_null_language_list', - 'BlankNullLanguageEnum', - ('BlankNullLanguageStrEnum', ['en', 'None']) + ('blank_null_language_list', [('en', 'en')]), + ('BlankNullLanguageEnum', [('en', 'EN')]), + ('BlankNullLanguageStrEnum', [('en', 'EN'), ('None', 'NULL')]) ] if DJANGO_VERSION > '3': enum_override_variations += [ - ('BlankNullLanguageChoices', ['en', 'None']), - ('BlankNullLanguageChoices.choices', ['en', 'None']) + ('BlankNullLanguageChoices', [('en', 'En'), ('None', 'Null')]), + ('BlankNullLanguageChoices.choices', [('en', 'En'), ('None', 'Null')]) ] - for variation in enum_override_variations: - expected_hashed_keys = ['en'] - if isinstance(variation, (list, tuple, )): - variation, expected_hashed_keys = variation + for variation, expected_hashed_keys in enum_override_variations: with mock.patch( 'drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES', {'LanguageEnum': f'tests.test_postprocessing.{variation}'} @@ -340,3 +345,35 @@ def get(self, request): uuid.UUID('93d7527f-de3c-4a76-9cc2-5578675630d4'), uuid.UUID('47a4b873-409e-4e43-81d5-fafc3faeb849') ] + + +@pytest.mark.skipif(DJANGO_VERSION < '3', reason='Not available before Django 3.0') +def test_equal_choices_different_semantics(no_warnings): + + class Health(IntegerChoices): + OK = 0 + FAIL = 1 + + class Status(IntegerChoices): + GREEN = 0 + RED = 1 + + class XSerializer(serializers.Serializer): + some_health = serializers.ChoiceField(choices=Health.choices) + some_status = serializers.ChoiceField(choices=Status.choices) + + class XAPIView(APIView): + @extend_schema(responses=XSerializer) + def get(self, request): + pass # pragma: no cover + + # This should not generate a warning even though the enum list is identical + # in both Enums. We now also differentiate the Enums by their labels. + schema = generate_schema('x', view=XAPIView) + + assert schema['components']['schemas']['SomeHealthEnum'] == { + 'enum': [0, 1], 'type': 'integer', 'description': '* `0` - Ok\n* `1` - Fail' + } + assert schema['components']['schemas']['SomeStatusEnum'] == { + 'enum': [0, 1], 'type': 'integer', 'description': '* `0` - Green\n* `1` - Red' + }