diff --git a/django_filters/rest_framework/backends.py b/django_filters/rest_framework/backends.py index 835a67f9a..1d228073c 100644 --- a/django_filters/rest_framework/backends.py +++ b/django_filters/rest_framework/backends.py @@ -113,6 +113,29 @@ def get_coreschema_field(self, field): description=str(field.extra.get('help_text', '')) ) + def build_coreapi_field(self, name, field): + return compat.coreapi.Field( + name=name, + required=field.extra['required'], + location='query', + schema=self.get_coreschema_field(field), + ) + + def get_schema_field_names(self, field_name, field): + """ + Get the corresponding schema field names required to generate the openAPI schema + by referencing the widget suffixes if available. + """ + try: + suffixes = field.field_class.widget.suffixes + except AttributeError: + return [field_name] + else: + return [field_name] if not suffixes else [ + '{}_{}'.format(field_name, suffix) + for suffix in suffixes if suffix + ] + def get_schema_fields(self, view): # This is not compatible with widgets where the query param differs from the # filter's attribute name. Notably, this includes `MultiWidget`, where query @@ -130,14 +153,32 @@ def get_schema_fields(self, view): filterset_class = self.get_filterset_class(view, queryset) - return [] if not filterset_class else [ - compat.coreapi.Field( - name=field_name, - required=field.extra['required'], - location='query', - schema=self.get_coreschema_field(field) - ) for field_name, field in filterset_class.base_filters.items() - ] + return self.build_fields(filterset_class, self.build_coreapi_field) + + def build_fields(self, filterset_class, build_field_method): + if not filterset_class: + return [] + + return [build_field_method(schema_field_name, field) + for field_name, field in filterset_class.base_filters.items() + for schema_field_name in self.get_schema_field_names(field_name, field) + ] + + def build_openapi_field(self, field_name, field): + openapi_field = { + 'name': field_name, + 'required': field.extra['required'], + 'in': 'query', + 'description': field.label if field.label is not None else field_name, + 'schema': { + 'type': 'string', + }, + } + + if field.extra and 'choices' in field.extra: + openapi_field['schema']['enum'] = [c[0] for c in field.extra['choices']] + + return openapi_field def get_schema_operation_parameters(self, view): try: @@ -150,21 +191,4 @@ def get_schema_operation_parameters(self, view): filterset_class = self.get_filterset_class(view, queryset) - if not filterset_class: - return [] - - parameters = [] - for field_name, field in filterset_class.base_filters.items(): - parameter = { - 'name': field_name, - 'required': field.extra['required'], - 'in': 'query', - 'description': field.label if field.label is not None else field_name, - 'schema': { - 'type': 'string', - }, - } - if field.extra and 'choices' in field.extra: - parameter['schema']['enum'] = [c[0] for c in field.extra['choices']] - parameters.append(parameter) - return parameters + return self.build_fields(filterset_class, self.build_openapi_field) diff --git a/tests/rest_framework/test_backends.py b/tests/rest_framework/test_backends.py index 66fecf490..e638c6f03 100644 --- a/tests/rest_framework/test_backends.py +++ b/tests/rest_framework/test_backends.py @@ -65,6 +65,14 @@ class CategoryItemView(generics.ListCreateAPIView): filterset_fields = ["category"] +class FilterClassWithDateRangeFilter(SeveralFieldsFilter): + date = filters.DateFromToRangeFilter() + + +class ViewWithDateRangeFilter(FilterClassRootView): + filterset_class = FilterClassWithDateRangeFilter + + class GetFilterClassTests(TestCase): def test_filterset_class(self): @@ -241,6 +249,40 @@ class View(FilterClassRootView): self.assertEqual(fields, ['text', 'decimal', 'date', 'f']) + def test_fields_with_range_type_filter_default_suffixes(self): + + view = ViewWithDateRangeFilter() + backend = DjangoFilterBackend() + fields = backend.get_schema_fields(view) + field_names = [f.name for f in fields] + + self.assertIn("date_after", field_names) + self.assertIn("date_before", field_names) + + def test_fields_with_range_type_filter_with_custom_suffixes(self): + custom_suffixes = ['previous', 'later'] + + mock_widget = mock.Mock(suffixes=custom_suffixes) + mock_field_class = mock.Mock(widget=mock_widget) + + class CustomDateFromtoRangeFilter(filters.DateFromToRangeFilter): + field_class = mock_field_class + + class F(SeveralFieldsFilter): + date = CustomDateFromtoRangeFilter() + + class View(FilterClassRootView): + filterset_class = F + + view = View() + + backend = DjangoFilterBackend() + fields = backend.get_schema_fields(view) + field_names = [f.name for f in fields] + + for custom_suffix in custom_suffixes: + self.assertIn('date_{}'.format(custom_suffix), field_names) + class GetSchemaOperationParametersTests(TestCase): def test_get_operation_parameters_with_filterset_fields_list(self): @@ -269,6 +311,39 @@ def test_get_operation_parameters_with_filterset_fields_list_with_choices(self): }] ) + def test_get_operation_parameters_with_range_type_filter_with_default_suffixes(self): + view = ViewWithDateRangeFilter() + backend = DjangoFilterBackend() + + fields = backend.get_schema_operation_parameters(view) + fields = [f['name'] for f in fields] + + self.assertEqual(fields, ['text', 'decimal', 'date_after', 'date_before']) + + def test_get_operation_parameters_with_range_type_filter_with_custom_suffixes(self): + custom_suffixes = ['previous', 'later'] + + mock_widget = mock.Mock(suffixes=custom_suffixes) + mock_field_class = mock.Mock(widget=mock_widget) + + class CustomDateFromtoRangeFilter(filters.DateFromToRangeFilter): + field_class = mock_field_class + + class F(SeveralFieldsFilter): + date = CustomDateFromtoRangeFilter() + + class View(FilterClassRootView): + filterset_class = F + + view = View() + backend = DjangoFilterBackend() + + fields = backend.get_schema_operation_parameters(view) + field_names = [f['name'] for f in fields] + + for custom_suffix in custom_suffixes: + self.assertIn('date_{}'.format(custom_suffix), field_names) + class TemplateTests(TestCase): def test_backend_output(self):