diff --git a/django_filters/filterset.py b/django_filters/filterset.py index 23a88b45..415c8a7a 100644 --- a/django_filters/filterset.py +++ b/django_filters/filterset.py @@ -29,6 +29,13 @@ ) from .utils import get_all_model_fields, get_model_field, resolve_field, try_dbfield +try: + from django.db.models import GeneratedField +except ImportError: + DJANGO_50 = False +else: + DJANGO_50 = True + def remote_queryset(field): """ @@ -390,6 +397,13 @@ def handle_unrecognized_field(cls, field_name, message): def filter_for_field(cls, field, field_name, lookup_expr=None): if lookup_expr is None: lookup_expr = settings.DEFAULT_LOOKUP_EXPR + + # Handle GeneratedFields + if DJANGO_50 and isinstance(field, GeneratedField): + new_field = field.output_field + new_field.model = field.model + field = new_field + field, lookup_type = resolve_field(field, lookup_expr) default = { diff --git a/tests/models.py b/tests/models.py index e1a391b2..319ee8f8 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,7 +1,15 @@ from django import forms from django.db import models +from django.db.models import F from django.utils.translation import gettext_lazy as _ +try: + from django.db.models import GeneratedField +except ImportError: + DJANGO_50 = False +else: + DJANGO_50 = True + REGULAR = 0 MANAGER = 1 ADMIN = 2 @@ -210,3 +218,14 @@ class SpacewalkRecord(models.Model): astronaut = models.CharField(max_length=100) duration = models.DurationField() + + +if DJANGO_50: + class Rectangle(models.Model): + base = models.FloatField() + height = models.FloatField() + area = GeneratedField( + expression=F("base") * F("height"), + output_field=models.FloatField(), + db_persist=True, + ) diff --git a/tests/test_filterset.py b/tests/test_filterset.py index 77fafee5..c3fd6344 100644 --- a/tests/test_filterset.py +++ b/tests/test_filterset.py @@ -49,6 +49,13 @@ ) from .utils import MockQuerySet +try: + from .models import Rectangle +except ImportError: + DJANGO_50 = False +else: + DJANGO_50 = True + class HelperMethodsTests(TestCase): @unittest.skip("todo") @@ -119,6 +126,14 @@ def test_filter_found_for_uuidfield(self): self.assertIsInstance(result, UUIDFilter) self.assertEqual(result.field_name, "uuid") + def test_filter_found_for_generatedfield(self): + if not DJANGO_50: + return + f = Rectangle._meta.get_field("area") + result = FilterSet.filter_for_field(f, "area") + self.assertIsInstance(result, NumberFilter) + self.assertEqual(result.field_name, "area") + def test_filter_found_for_autofield(self): f = User._meta.get_field("id") result = FilterSet.filter_for_field(f, "id") @@ -256,7 +271,6 @@ def test_unknown_field_ignore_behavior(self): def test_unknown_field_invalid_initial_behavior(self): # Creation of new custom FilterSet to set initial field behavior with self.assertRaises(ValueError) as excinfo: - class InvalidBehaviorFilterSet(FilterSet): class Meta: model = NetworkSetting @@ -421,7 +435,6 @@ class Meta: def test_model_no_fields_or_exclude(self): with self.assertRaises(AssertionError) as excinfo: - class F(FilterSet): class Meta: model = Book @@ -548,7 +561,6 @@ def test_meta_fields_list_containing_unknown_fields(self): msg = "'Meta.fields' must not contain non-model field names: " "other, another" with self.assertRaisesMessage(TypeError, msg): - class F(FilterSet): username = CharFilter() @@ -560,7 +572,6 @@ def test_meta_fields_dict_containing_unknown_fields(self): msg = "'Meta.fields' must not contain non-model field names: other" with self.assertRaisesMessage(TypeError, msg): - class F(FilterSet): class Meta: model = Book @@ -575,7 +586,6 @@ def test_meta_fields_dict_containing_declarative_alias(self): msg = "'Meta.fields' must not contain non-model field names: other" with self.assertRaisesMessage(TypeError, msg): - class F(FilterSet): other = CharFilter() @@ -593,7 +603,6 @@ def test_meta_fields_invalid_lookup(self): msg = "Unsupported lookup 'flub' for field 'tests.User.username'." with self.assertRaisesMessage(FieldLookupError, msg): - class F(FilterSet): class Meta: model = User @@ -822,6 +831,7 @@ def test_filterset_factory_base_filter_meta_fields(self): class FilterSetBase(FilterSet): class Meta: fields = ["name"] + f1 = CharFilter() f2 = CharFilter() @@ -832,6 +842,7 @@ def test_filterset_factory_base_filter_fields_and_meta_fields(self): class FilterSetBase(FilterSet): class Meta: fields = ["name"] + f1 = CharFilter() f2 = CharFilter()