Skip to content

Commit

Permalink
Fix after filter with null
Browse files Browse the repository at this point in the history
  • Loading branch information
daanvdk committed Nov 15, 2024
1 parent fc40658 commit 90d07e0
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 41 deletions.
68 changes: 33 additions & 35 deletions binder/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from django.http.request import RawPostDataException
from django.http.multipartparser import MultiPartParser
from django.db import models, connections
from django.db.models import Q, F, Count
from django.db.models import Q, F, Count, Case, When
from django.db.models.lookups import Transform
from django.utils import timezone
from django.db import transaction
from django.db.models.expressions import BaseExpression, Value, CombinedExpression, OrderBy, ExpressionWrapper, Func
from django.db.models.expressions import BaseExpression, Value, CombinedExpression, OrderBy, ExpressionWrapper
from django.db.models.fields.reverse_related import ForeignObjectRel


Expand Down Expand Up @@ -53,17 +53,6 @@
}


class Tuple(Func):
template = '(%(expressions)s)'


class GreaterThan(Func):
template = '(%(expressions)s)'
arg_joiner = ' > '
arity = 2
output_field = models.BooleanField()


def get_joins_from_queryset(queryset):
"""
Given a queryset returns a set of lines that are used to determine which
Expand Down Expand Up @@ -1569,39 +1558,48 @@ def _after_expr(self, request, after_id, include_annotations):
raise BinderRequestError(f'invalid value for after_id: {after_id!r}')

# Now we will build up a comparison expr based on the order by
left_exprs = []
right_exprs = []
whens = []

for field in ordering:
# First we have to split of a leading '-' as indicating reverse
reverse = field.startswith('-')
if reverse:
field = field[1:]

# Then we build 2 exprs for the left hand side (objs in the query)
# and the right hand side (the object with the provided after id)
left_expr = F(field)
# Then we determine if nulls come last
if field.endswith('__nulls_last'):
field = field[:-12]
nulls_last = True
elif field.endswith('__nulls_first'):
field = field[:-13]
nulls_last = False
elif connections[self.model.objects.db].vendor == 'mysql':
# In MySQL null is considered to be the lowest possible value for ordering
nulls_last = reverse
else:
# In other databases null is considered to be the highest possible value for ordering
nulls_last = not reverse

right_expr = obj
# Then we determine what the value is for the obj we need to be after
value = obj
for attr in field.split('__'):
right_expr = getattr(right_expr, attr)
if isinstance(right_expr, models.Model):
right_expr = right_expr.pk
right_expr = Value(right_expr)

# To handle reverse we flip the expressions
if reverse:
left_exprs.append(right_expr)
right_exprs.append(left_expr)
value = getattr(value, attr)
if isinstance(value, models.Model):
value = value.pk

# Now we add some conditions for the comparison
if value is None:
# If the value is None, that means we have to add a condition for when the field is not None because only then it is different
# What the result should be in that case is determined by nulls last
whens.append(When(Q(**{field + '__isnull': False}), then=Value(not nulls_last)))
else:
left_exprs.append(left_expr)
right_exprs.append(right_expr)
# If the field is None we give a result based on nulls last
whens.append(When(Q(**{field: None}), then=Value(nulls_last)))
# Otherwise we check with comparisons, note that equality is intentionally left open with these two options so in that case we go on to the next field
whens.append(When(Q(**{field + '__lt': value}), then=Value(reverse)))
whens.append(When(Q(**{field + '__gt': value}), then=Value(not reverse)))

# Now we turn this into one big comparison
if len(ordering) == 1:
expr = GreaterThan(left_exprs[0], right_exprs[0])
else:
expr = GreaterThan(Tuple(*left_exprs), Tuple(*right_exprs))
expr = Case(*whens, default=Value(False))

return expr, required_annotations

Expand Down
38 changes: 32 additions & 6 deletions tests/test_after.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from django.contrib.auth.models import User
from django.test import TestCase

Expand All @@ -12,12 +14,12 @@ def setUp(self):
self.mapping = {}

zoo1 = Zoo.objects.create(name='Zoo 2')
self.mapping[Animal.objects.create(name='Animal F', zoo=zoo1).id] = 'f'
self.mapping[Animal.objects.create(name='Animal F', zoo=zoo1, birth_date='1997-03-19').id] = 'f'
self.mapping[Animal.objects.create(name='Animal E', zoo=zoo1).id] = 'e'
self.mapping[Animal.objects.create(name='Animal D', zoo=zoo1).id] = 'd'

zoo2 = Zoo.objects.create(name='Zoo 1')
self.mapping[Animal.objects.create(name='Animal C', zoo=zoo2).id] = 'c'
self.mapping[Animal.objects.create(name='Animal C', zoo=zoo2, birth_date='2000-08-05').id] = 'c'
self.mapping[Animal.objects.create(name='Animal B', zoo=zoo2).id] = 'b'
self.mapping[Animal.objects.create(name='Animal A', zoo=zoo2).id] = 'a'

Expand Down Expand Up @@ -48,17 +50,41 @@ def test_default(self):
self.assertEqual(self.get(after='d'), 'cba')

def test_ordered(self):
self.assertEqual(self.get('name', ), 'abcdef')
self.assertEqual(self.get('name'), 'abcdef')
self.assertEqual(self.get('name', after='c'), 'def')

def test_ordered_relation(self):
self.assertEqual(self.get('zoo,name', ), 'defabc')
self.assertEqual(self.get('zoo,name'), 'defabc')
self.assertEqual(self.get('zoo,name', after='f'), 'abc')

def test_ordered_reverse(self):
self.assertEqual(self.get('-name', ), 'fedcba')
self.assertEqual(self.get('-name'), 'fedcba')
self.assertEqual(self.get('-name', after='d'), 'cba')

def test_ordered_relation_field(self):
self.assertEqual(self.get('zoo.name', ), 'cbafed')
self.assertEqual(self.get('zoo.name'), 'cbafed')
self.assertEqual(self.get('zoo.name', after='a'), 'fed')

def test_ordered_with_null(self):
if os.environ.get('BINDER_TEST_MYSQL', '0') != '0':
# In MySQL null is considered to be the lowest possible value for ordering
self.assertEqual(self.get('birth_date'), 'edbafc')
self.assertEqual(self.get('birth_date', after='f'), 'c')
self.assertEqual(self.get('birth_date', after='e'), 'dbafc')
else:
# In other databases null is considered to be the highest possible value for ordering
self.assertEqual(self.get('birth_date'), 'fcedba')
self.assertEqual(self.get('birth_date', after='f'), 'cedba')
self.assertEqual(self.get('birth_date', after='e'), 'dba')

def test_ordered_with_null_reversed(self):
if os.environ.get('BINDER_TEST_MYSQL', '0') != '0':
# In MySQL null is considered to be the lowest possible value for ordering
self.assertEqual(self.get('-birth_date'), 'cfedba')
self.assertEqual(self.get('-birth_date', after='c'), 'fedba')
self.assertEqual(self.get('-birth_date', after='b'), 'a')
else:
# In other databases null is considered to be the highest possible value for ordering
self.assertEqual(self.get('-birth_date'), 'edbacf')
self.assertEqual(self.get('-birth_date', after='c'), 'f')
self.assertEqual(self.get('-birth_date', after='b'), 'acf')

0 comments on commit 90d07e0

Please sign in to comment.