diff --git a/dynamic_rest/filters.py b/dynamic_rest/filters.py index 495c6d62..14797899 100644 --- a/dynamic_rest/filters.py +++ b/dynamic_rest/filters.py @@ -3,12 +3,12 @@ from django.core.exceptions import ValidationError as InternalValidationError from django.core.exceptions import ImproperlyConfigured from django.db.models import Q, Prefetch, Manager +from django.db.models.expressions import RawSQL, OrderBy import six from rest_framework import serializers from rest_framework.exceptions import ValidationError -from rest_framework.fields import BooleanField, NullBooleanField +from rest_framework.fields import BooleanField, NullBooleanField, JSONField from rest_framework.filters import BaseFilterBackend, OrderingFilter - from dynamic_rest.utils import is_truthy from dynamic_rest.conf import settings from dynamic_rest.datastructures import TreeMap @@ -127,6 +127,15 @@ def generate_query_key(self, serializer): # Recurse into nested field s = getattr(field, 'serializer', None) + if isinstance(field, JSONField): + # If a json field is found, append any terms following + j = i+1 + while j < len(self.field): + rewritten.append(self.field[j]) + j += 1 + if self.operator: + rewritten.append(self.operator) + return ('__'.join(rewritten), field) if isinstance(s, serializers.ListSerializer): s = s.child if not s: @@ -294,33 +303,41 @@ def _filters_to_query(self, includes, excludes, serializer, q=None): q: Q() object (optional) Returns: - Q() instance or None if no inclusion or exclusion filters - were specified. + Tuple of: + * Q() instance or None if no inclusion or exclusion filters + were specified. + * dictionary of {(field,): (operator, value)} for any json fields """ def rewrite_filters(filters, serializer): out = {} + json_out = {} for k, node in six.iteritems(filters): filter_key, field = node.generate_query_key(serializer) if isinstance(field, (BooleanField, NullBooleanField)): node.value = is_truthy(node.value) - out[filter_key] = node.value - return out + if isinstance(field, JSONField): + json_out[tuple(node.field)] = (node.operator, node.value) + else: + out[filter_key] = node.value + return out, json_out q = q or Q() + json_extras = None + if not includes and not excludes: - return None + return None, None if includes: - includes = rewrite_filters(includes, serializer) + includes, json_extras = rewrite_filters(includes, serializer) q &= Q(**includes) if excludes: - excludes = rewrite_filters(excludes, serializer) + excludes, json_extras = rewrite_filters(excludes, serializer) for k, v in six.iteritems(excludes): q &= ~Q(**{k: v}) - return q + return q, json_extras def _create_prefetch(self, source, queryset): return Prefetch(source, queryset=queryset) @@ -569,7 +586,7 @@ def _build_queryset( queryset = queryset.only(*only) # add request filters - query = self._filters_to_query( + query, json_extras = self._filters_to_query( includes=filters.get('_include'), excludes=filters.get('_exclude'), serializer=serializer @@ -579,12 +596,16 @@ def _build_queryset( if extra_filters: query = extra_filters if not query else extra_filters & query - if query: + if query or json_extras: # Convert internal django ValidationError to # APIException-based one in order to resolve validation error # from 500 status code to 400. try: queryset = queryset.filter(query) + + if json_extras: + extra_queries = self._get_json_queries(json_extras) + queryset = queryset.extra(where=extra_queries) except InternalValidationError as e: raise ValidationError( dict(e) if hasattr(e, 'error_dict') else list(e) @@ -620,6 +641,52 @@ def _build_queryset( queryset._using_prefetches = prefetches return queryset + def _get_json_queries(self, json_extras): + extra_queries = [] + + for json_field_names, (operator, value) in six.iteritems(json_extras): + if not operator: + query_operator = '=' + value = "'{}'".format(value) + elif operator in ('startswith', 'istartswith'): + query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE' + value = "'{}%%'".format(value) + elif operator in ('endswith', 'iendswith'): + query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE' + value = "'%%{}'".format(value) + elif operator in ('contains', 'icontains'): + query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE' + value = "'%%{}%%'".format(value) + + else: + raise InternalValidationError( + f"""Unsupported filter operation for nested JSON fields: + {operator}""" + ) + + extra_query = [] + + for idx, k in enumerate(json_field_names): + if idx == 0: + extra_query.append(k) + else: + extra_query.append("'{}'".format(k)) + + if idx == len(json_field_names) - 1: + continue + # the ->> operator returns a raw value + elif idx == len(json_field_names) - 2: + extra_query.append('->>') + # the -> operator returns JSON + else: + extra_query.append('->') + + extra_query.append(query_operator) + extra_query.append(value) + extra_queries.append(' '.join(extra_query)) + + return extra_queries + class FastDynamicFilterBackend(DynamicFilterBackend): def _create_prefetch(self, source, queryset): @@ -665,7 +732,16 @@ def filter_queryset(self, request, queryset, view): """ self.ordering_param = view.SORT - ordering = self.get_ordering(request, queryset, view) + ordering, nested = self.get_ordering(request, queryset, view) + if ordering and nested: + ordering_str = ''.join(ordering) + if ordering_str.startswith('-'): + return queryset.order_by( + OrderBy(RawSQL('LOWER( %s )' % (ordering_str[1:]), nested), + descending=True)) + return queryset.order_by( + OrderBy(RawSQL('LOWER(%s)' % (ordering_str), nested), + descending=False)) if ordering: queryset = queryset.order_by(*ordering) if any(['__' in o for o in ordering]): @@ -681,11 +757,13 @@ def get_ordering(self, request, queryset, view): This method overwrites the DRF default so it can parse the array. """ params = view.get_request_feature(view.SORT) + nested = [] if params: fields = [param.strip() for param in params] - valid_ordering, invalid_ordering = self.remove_invalid_fields( - queryset, fields, view - ) + valid_ordering, invalid_ordering, nested = \ + self.remove_invalid_fields( + queryset, fields, view + ) # if any of the sort fields are invalid, throw an error. # else return the ordering @@ -694,10 +772,10 @@ def get_ordering(self, request, queryset, view): "Invalid filter field: %s" % invalid_ordering ) else: - return valid_ordering + return valid_ordering, nested # No sorting was included - return self.get_default_ordering(view) + return self.get_default_ordering(view), nested def remove_invalid_fields(self, queryset, fields, view): """Remove invalid fields from an ordering. @@ -715,14 +793,14 @@ def remove_invalid_fields(self, queryset, fields, view): stripped_term = term.lstrip('-') # add back the '-' add the end if necessary reverse_sort_term = '' if len(stripped_term) is len(term) else '-' - ordering = self.ordering_for(stripped_term, view) + ordering, nested = self.ordering_for(stripped_term, view) if ordering: valid_orderings.append(reverse_sort_term + ordering) else: invalid_orderings.append(term) - return valid_orderings, invalid_orderings + return valid_orderings, invalid_orderings, nested def ordering_for(self, term, view): """ @@ -732,7 +810,7 @@ def ordering_for(self, term, view): Raise ImproperlyConfigured if serializer_class not set on view """ if not self._is_allowed_term(term, view): - return None + return None, None serializer = self._get_serializer_class(view)() serializer_chain = term.split('.') @@ -742,9 +820,27 @@ def ordering_for(self, term, view): for segment in serializer_chain[:-1]: field = serializer.get_all_fields().get(segment) + # If its a JSONField, construct a RawSQL command in the form + # of 'jsonField->{}'.format('nestedField')' or + # 'jsonField->>{}->{}'.format('nested','doubleNested') + if field and isinstance(field, JSONField): + json_chain_start = str(segment) + json_chain = '' + nested = [] + first = True + for nterm in serializer_chain[1:]: + if first: + json_chain += '->>%s' + first = False + else: + json_chain = '->%s' + json_chain + nested.append(nterm) + json_chain = json_chain_start + json_chain + return json_chain, nested + if not (field and field.source != '*' and isinstance(field, DynamicRelationField)): - return None + return None, None model_chain.append(field.source or segment) @@ -754,11 +850,11 @@ def ordering_for(self, term, view): last_field = serializer.get_all_fields().get(last_segment) if not last_field or last_field.source == '*': - return None + return None, None model_chain.append(last_field.source or last_segment) - return '__'.join(model_chain) + return '__'.join(model_chain), None def _is_allowed_term(self, term, view): valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) diff --git a/tests/migrations/0007_recipe_model.py b/tests/migrations/0007_recipe_model.py new file mode 100644 index 00000000..8aba8f90 --- /dev/null +++ b/tests/migrations/0007_recipe_model.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import migrations, models +from django.contrib.postgres.fields import JSONField + + +class Migration(migrations.Migration): + + dependencies = [ + ('tests', '0006_auto_20210921_1026'), + ] + + operations = [ + migrations.CreateModel( + name='recipe', + fields=[ + ('name', models.CharField(max_length=60)), + ('ingredients', JSONField(null=True)) + ] + ), + ] diff --git a/tests/models.py b/tests/models.py index b3d4d076..2b7dc46f 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,5 +1,6 @@ from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType +from django.contrib.postgres.fields import JSONField from django.db import models @@ -137,3 +138,8 @@ class Part(models.Model): car = models.ForeignKey(Car, on_delete=models.CASCADE) name = models.CharField(max_length=60) country = models.ForeignKey(Country, on_delete=models.CASCADE) + + +class Recipe(models.Model): + name = models.CharField(max_length=60) + ingredients = JSONField(null=True) diff --git a/tests/serializers.py b/tests/serializers.py index ca41de24..8ac63e18 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -22,6 +22,7 @@ Part, Permission, Profile, + Recipe, User, Zebra, ) @@ -323,3 +324,9 @@ class Meta: model = Car fields = ('id', 'name', 'country', 'parts') deferred_fields = ('name', 'country', 'parts') + + +class RecipeSerializer(DynamicModelSerializer): + class Meta: + model = Recipe + fields = ('name', 'ingredients') diff --git a/tests/setup.py b/tests/setup.py index c7e69abb..7deeaa04 100644 --- a/tests/setup.py +++ b/tests/setup.py @@ -11,6 +11,7 @@ Location, Part, Permission, + Recipe, User, Zebra ) @@ -29,14 +30,14 @@ def create_fixture(): types = [ 'users', 'groups', 'locations', 'permissions', 'events', 'cats', 'dogs', 'horses', 'zebras', - 'cars', 'countries', 'parts', + 'cars', 'countries', 'parts', 'recipes', ] Fixture = namedtuple('Fixture', types) fixture = Fixture( users=[], groups=[], locations=[], permissions=[], events=[], cats=[], dogs=[], horses=[], zebras=[], - cars=[], countries=[], parts=[] + cars=[], countries=[], parts=[], recipes=[] ) for i in range(0, 4): @@ -109,19 +110,19 @@ def create_fixture(): }, { 'name': 'Event 2', 'status': 'current', - 'location': 1 + 'location': 1, }, { 'name': 'Event 3', 'status': 'current', - 'location': 1 + 'location': 1, }, { 'name': 'Event 4', 'status': 'archived', - 'location': 2 + 'location': 2, }, { 'name': 'Event 5', 'status': 'current', - 'location': 2 + 'location': 2, }] for dog in dogs: @@ -235,4 +236,48 @@ def create_fixture(): country_id=part.get('country') )) + recipes = [{ + 'name': 'muffin', + 'ingredients': { + 'dough': { + 'water': '100_g', + 'flour': '100_g' + }, + 'chocolate_chips': '0_g' + } + }, { + 'name': 'chocolate chip muffin', + 'ingredients': { + 'dough': { + 'water': '100_g', + 'flour': '100_g' + }, + 'chocolate_chips': '20_g' + } + }, { + 'name': 'scone', + 'ingredients': { + 'dough': { + 'water': '50_g', + 'flour': '100_g' + }, + 'chocolate_chips': '0_g' + } + }, { + 'name': 'chocolate chip scone', + 'ingredients': { + 'dough': { + 'water': '50_g', + 'flour': '100_g' + }, + 'chocolate_chips': '20_g' + } + }] + + for recipe in recipes: + fixture.recipes.append(Recipe.objects.create( + name=recipe.get('name'), + ingredients=recipe.get('ingredients') + )) + return fixture diff --git a/tests/test_json.py b/tests/test_json.py new file mode 100644 index 00000000..8b308b0e --- /dev/null +++ b/tests/test_json.py @@ -0,0 +1,62 @@ +import json +from django.test import TestCase +from tests.setup import create_fixture + + +class TestJSONFieldFiltering(TestCase): + + def setUp(self): + self.fixture = create_fixture() + + def test_filter_by_first_level(self): + url = ( + '/recipes/?filter{ingredients.chocolate_chips}=20_g' + ) + response = self.client.get(url) + self.assertEqual(200, response.status_code) + content = json.loads(response.content.decode('utf-8')) + + self.assertTrue('recipes' in content) + self.assertEqual(2, len(content['recipes'])) + + self.assertTrue('name' in content['recipes'][0]) + self.assertEqual( + 'chocolate chip muffin', + content['recipes'][0]['name'] + ) + + self.assertTrue('name' in content['recipes'][1]) + self.assertEqual('chocolate chip scone', content['recipes'][1]['name']) + + def test_filter_by_second_level(self): + url = ( + '/recipes/?filter{ingredients.dough.water}=50_g' + ) + response = self.client.get(url) + self.assertEqual(200, response.status_code) + content = json.loads(response.content.decode('utf-8')) + + self.assertTrue('recipes' in content) + self.assertEqual(2, len(content['recipes'])) + + self.assertTrue('name' in content['recipes'][0]) + self.assertEqual('scone', content['recipes'][0]['name']) + + self.assertTrue('name' in content['recipes'][1]) + self.assertEqual('chocolate chip scone', content['recipes'][1]['name']) + + def test_filter_by_multiple_criteria(self): + url = ( + '/recipes/?' + 'filter{ingredients.dough.water}=50_g' + '&filter{ingredients.chocolate_chips}=20_g' + ) + response = self.client.get(url) + self.assertEqual(200, response.status_code) + content = json.loads(response.content.decode('utf-8')) + + self.assertTrue('recipes' in content) + self.assertEqual(1, len(content['recipes'])) + + self.assertTrue('name' in content['recipes'][0]) + self.assertEqual('chocolate chip scone', content['recipes'][0]['name']) diff --git a/tests/urls.py b/tests/urls.py index 2c711d1a..fb4a6d3e 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -14,6 +14,7 @@ router.register_resource(viewsets.DogViewSet) router.register_resource(viewsets.HorseViewSet) router.register_resource(viewsets.PermissionViewSet) +router.register_resource(viewsets.RecipeViewSet) router.register(r'zebras', viewsets.ZebraViewSet) # not canonical router.register(r'user_locations', viewsets.UserLocationViewSet) router.register(r'alternate_locations', viewsets.AlternateLocationViewSet) diff --git a/tests/viewsets.py b/tests/viewsets.py index cd04c0dc..25258150 100644 --- a/tests/viewsets.py +++ b/tests/viewsets.py @@ -11,6 +11,7 @@ Location, Permission, Profile, + Recipe, User, Zebra ) @@ -23,6 +24,7 @@ LocationSerializer, PermissionSerializer, ProfileSerializer, + RecipeSerializer, UserLocationSerializer, UserSerializer, ZebraSerializer @@ -179,3 +181,8 @@ class PermissionViewSet(DynamicModelViewSet): class CarViewSet(DynamicModelViewSet): serializer_class = CarSerializer queryset = Car.objects.all() + + +class RecipeViewSet(DynamicModelViewSet): + serializer_class = RecipeSerializer + queryset = Recipe.objects.all()