From 5e0e08b021e76376590dbed63971a18c694125fb Mon Sep 17 00:00:00 2001 From: Daan van der Kallen Date: Wed, 22 May 2024 11:57:51 +0200 Subject: [PATCH] Make id filter work for combined endpoint --- binder/plugins/views/combined.py | 40 +++++++++++++++++++++++++++----- tests/test_combined.py | 15 ++++++++++++ 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/binder/plugins/views/combined.py b/binder/plugins/views/combined.py index a76850ad..7801186e 100644 --- a/binder/plugins/views/combined.py +++ b/binder/plugins/views/combined.py @@ -75,7 +75,7 @@ def combined_view(request, router, names): # Get filtered & annotated querysets per name querysets = {} - for name in names: + for i, name in enumerate(names): view = views[name] queryset = view.get_queryset(request) @@ -92,11 +92,39 @@ def combined_view(request, router, names): queryset = annotate(queryset, request, sub_include_annotations.get('')) # filters - filters = { - 'id' if k == f'.{name}' else k[len(name) + 2:]: v - for k, v in request.GET.lists() - if k == f'.{name}' or k.startswith(f'.{name}.') - } + filters = {} + for k, v in request.GET.lists(): + if k == '.id' or k.startswith('.id:'): + values = [] + for value in v: + ids = [] + for id in value.split(','): + try: + id = int(id) + except ValueError: + # leave invalid values for the detailed view + pass + else: + if id % len(names) == i: + # this is a combined id that matches this view, + # so we can convert it to an id for the model + # itself + id = str(id // len(names)) + else: + # this id does not match this view, so we + # convert it to an id that never matches any + # model + id = '-1' + ids.append(id) + values.append(','.join(ids)) + filters[k[1:]] = values + + elif k == f'.{name}' or k.startswith(f'.{name}:'): + filters['id' + k[len(name) + 1:]] = v + + elif k.startswith(f'.{name}.'): + filters[k[len(name) + 2:]] = v + for field, values in filters.items(): for v in values: q, distinct = view._parse_filter(field, v, request, sub_include_annotations) diff --git a/tests/test_combined.py b/tests/test_combined.py index 873bf85e..6e18271e 100644 --- a/tests/test_combined.py +++ b/tests/test_combined.py @@ -152,3 +152,18 @@ def test_combined_order_by_multi_field(self): animal2.id * 2 + 1, # Harambe zoo1.id * 2, # Apenheul ]) + + def test_combined_id_filter(self): + zoo1 = Zoo.objects.create(name='Apenheul', founding_date='1980-01-01') + zoo2 = Zoo.objects.create(name='Emmen', founding_date='1990-01-01') + animal1 = Animal.objects.create(zoo=zoo1, name='Bokito', birth_date='1995-01-01') + animal2 = Animal.objects.create(zoo=zoo2, name='Harambe', birth_date='1985-01-01') + + res = self.client.get(f'/combined/zoo/animal/?.id:in={zoo1.id * 2},{animal2.id * 2 + 1}') + self.assertEqual(res.status_code, 200) + data = json.loads(res.content) + ids = {obj['id'] for obj in data['data']} + self.assertEqual(ids, { + zoo1.id * 2, # Apenheul + animal2.id * 2 + 1, # Harambe + })