diff --git a/caluma/core/filters.py b/caluma/core/filters.py index bd9ec963f..44693e73a 100644 --- a/caluma/core/filters.py +++ b/caluma/core/filters.py @@ -36,6 +36,7 @@ from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField from .relay import extract_global_id +from .types import DjangoConnectionField class GlobalIDFilter(Filter): @@ -414,7 +415,9 @@ class MetaFilterSet(FilterSet): meta_value = MetaValueFilter(field_name="meta") -class DjangoFilterConnectionField(filter.DjangoFilterConnectionField): +class DjangoFilterConnectionField( + filter.DjangoFilterConnectionField, DjangoConnectionField +): """ Django connection filter field with object type get_queryset support. diff --git a/caluma/core/pagination.py b/caluma/core/pagination.py new file mode 100644 index 000000000..0a6905439 --- /dev/null +++ b/caluma/core/pagination.py @@ -0,0 +1,94 @@ +from graphql_relay.connection.arrayconnection import ( + get_offset_with_default, + offset_to_cursor, +) +from graphql_relay.connection.connectiontypes import Connection, Edge, PageInfo + + +def connection_from_list(data, args=None, **kwargs): + """ + Replace graphql_relay.connection.arrayconnection.connection_from_list. + + This can be removed, when (or better if) + https://github.com/graphql-python/graphql-relay-py/issues/12 + is resolved. + + A simple function that accepts an array and connection arguments, and returns + a connection object for use in GraphQL. It uses array offsets as pagination, + so pagination will only work if the array is static. + """ + _len = len(data) + return connection_from_list_slice( + data, args, slice_start=0, list_length=_len, list_slice_length=_len, **kwargs + ) + + +def connection_from_list_slice( + list_slice, + args=None, + connection_type=None, + edge_type=None, + pageinfo_type=None, + slice_start=0, + list_length=0, + list_slice_length=None, +): + """ + Replace graphql_relay.connection.arrayconnection.connection_from_list_slice. + + This can be removed, when (or better if) + https://github.com/graphql-python/graphql-relay-py/issues/12 + is resolved. + + Given a slice (subset) of an array, returns a connection object for use in + GraphQL. + This function is similar to `connectionFromArray`, but is intended for use + cases where you know the cardinality of the connection, consider it too large + to materialize the entire array, and instead wish pass in a slice of the + total result large enough to cover the range specified in `args`. + """ + connection_type = connection_type or Connection + edge_type = edge_type or Edge + pageinfo_type = pageinfo_type or PageInfo + + args = args or {} + + before = args.get("before") + after = args.get("after") + first = args.get("first") + last = args.get("last") + if list_slice_length is None: # pragma: no cover + list_slice_length = len(list_slice) + slice_end = slice_start + list_slice_length + before_offset = get_offset_with_default(before, list_length) + after_offset = get_offset_with_default(after, -1) + + start_offset = max(slice_start - 1, after_offset, -1) + 1 + end_offset = min(slice_end, before_offset, list_length) + if isinstance(first, int): + end_offset = min(end_offset, start_offset + first) + if isinstance(last, int): + start_offset = max(start_offset, end_offset - last) + + # If supplied slice is too large, trim it down before mapping over it. + _slice = list_slice[ + max(start_offset - slice_start, 0) : list_slice_length + - (slice_end - end_offset) + ] + edges = [ + edge_type(node=node, cursor=offset_to_cursor(start_offset + i)) + for i, node in enumerate(_slice) + ] + + first_edge_cursor = edges[0].cursor if edges else None + last_edge_cursor = edges[-1].cursor if edges else None + + return connection_type( + edges=edges, + page_info=pageinfo_type( + start_cursor=first_edge_cursor, + end_cursor=last_edge_cursor, + has_previous_page=start_offset > 0, + has_next_page=end_offset < list_length, + ), + ) diff --git a/caluma/core/tests/test_pagination.py b/caluma/core/tests/test_pagination.py index 619952fec..ebf9d489f 100644 --- a/caluma/core/tests/test_pagination.py +++ b/caluma/core/tests/test_pagination.py @@ -1,3 +1,6 @@ +import pytest + + def test_offset_pagination(db, schema_executor, document_factory): document_factory(meta={"position": 0}) document_factory(meta={"position": 1}) @@ -26,3 +29,51 @@ def test_offset_pagination(db, schema_executor, document_factory): assert result.data["allDocuments"]["totalCount"] == 2 assert result.data["allDocuments"]["edges"][0]["node"]["meta"]["position"] == 2 assert result.data["allDocuments"]["edges"][1]["node"]["meta"]["position"] == 3 + + +@pytest.mark.parametrize( + "first,last,before,after,has_next,has_previous", + [ + (1, None, None, None, True, False), + (None, 1, None, None, False, True), + (None, None, None, None, False, False), + (None, None, None, "YXJyYXljb25uZWN0aW9uOjI=", False, True), + (None, None, "YXJyYXljb25uZWN0aW9uOjI=", None, True, False), + ], +) +def test_has_next_previous( + db, + first, + last, + before, + after, + has_next, + has_previous, + schema_executor, + document_factory, +): + document_factory.create_batch(5) + + query = """ + query AllDocumentsQuery ($first: Int, $last: Int, $before: String, $after: String) { + allDocuments(first: $first, last: $last, before: $before, after: $after) { + pageInfo { + hasNextPage + hasPreviousPage + } + edges { + node { + id + } + } + } + } + """ + + inp = {"first": first, "last": last, "before": before, "after": after} + + result = schema_executor(query, variables=inp) + + assert not result.errors + assert result.data["allDocuments"]["pageInfo"]["hasNextPage"] == has_next + assert result.data["allDocuments"]["pageInfo"]["hasPreviousPage"] == has_previous diff --git a/caluma/core/types.py b/caluma/core/types.py index 7a1d0bf75..6b17fc5cd 100644 --- a/caluma/core/types.py +++ b/caluma/core/types.py @@ -1,7 +1,15 @@ +from collections import Iterable + import graphene from django.core.exceptions import ImproperlyConfigured from django.db.models.query import QuerySet +from graphene.relay import PageInfo +from graphene.relay.connection import ConnectionField from graphene_django import types +from graphene_django.fields import DjangoConnectionField +from graphene_django.utils import maybe_queryset + +from .pagination import connection_from_list, connection_from_list_slice class Node(object): @@ -54,3 +62,68 @@ def resolve_total_count(self, info, **kwargs): if isinstance(self.iterable, QuerySet): return self.iterable.count() return len(self.iterable) + + +class DjangoConnectionField(DjangoConnectionField): + """ + Custom DjangoConnectionField with fix for hasNextPage/hasPreviousPage. + + This can be removed, when (or better if) + https://github.com/graphql-python/graphql-relay-py/issues/12 + is resolved. + """ + + @classmethod + def resolve_connection(cls, connection, default_manager, args, iterable): + if iterable is None: + iterable = default_manager + iterable = maybe_queryset(iterable) + if isinstance(iterable, QuerySet): + if iterable is not default_manager: + default_queryset = maybe_queryset(default_manager) + iterable = cls.merge_querysets(default_queryset, iterable) + _len = iterable.count() + else: # pragma: no cover + _len = len(iterable) + connection = connection_from_list_slice( + iterable, + args, + slice_start=0, + list_length=_len, + list_slice_length=_len, + connection_type=connection, + edge_type=connection.Edge, + pageinfo_type=PageInfo, + ) + connection.iterable = iterable + connection.length = _len + return connection + + +class ConnectionField(ConnectionField): + """ + Custom ConnectionField with fix for hasNextPage/hasPreviousPage. + + This can be removed, when (or better if) + https://github.com/graphql-python/graphql-relay-py/issues/12 + is resolved. + """ + + @classmethod + def resolve_connection(cls, connection_type, args, resolved): + if isinstance(resolved, connection_type): # pragma: no cover + return resolved + + assert isinstance(resolved, Iterable), ( + "Resolved value from the connection field have to be iterable or instance of {0}. " + 'Received "{1}"' + ).format(connection_type, resolved) + connection = connection_from_list( + resolved, + args, + connection_type=connection_type, + edge_type=connection_type.Edge, + pageinfo_type=PageInfo, + ) + connection.iterable = resolved + return connection diff --git a/caluma/form/schema.py b/caluma/form/schema.py index 96cdda750..f32250a68 100644 --- a/caluma/form/schema.py +++ b/caluma/form/schema.py @@ -1,12 +1,17 @@ import graphene -from graphene import ConnectionField, relay +from graphene import relay from graphene.types import ObjectType, generic from graphene_django.rest_framework import serializer_converter from ..core.filters import DjangoFilterConnectionField, DjangoFilterSetConnectionField from ..core.mutation import Mutation, UserDefinedPrimaryKeyMixin from ..core.relay import extract_global_id -from ..core.types import CountableConnectionBase, DjangoObjectType, Node +from ..core.types import ( + ConnectionField, + CountableConnectionBase, + DjangoObjectType, + Node, +) from ..data_source.data_source_handlers import get_data_source_data from ..data_source.schema import DataSourceDataConnection from . import filters, models, serializers @@ -168,7 +173,7 @@ class Meta: class TextQuestion(QuestionQuerysetMixin, FormDjangoObjectType): max_length = graphene.Int() placeholder = graphene.String() - format_validators = graphene.ConnectionField(FormatValidatorConnection) + format_validators = ConnectionField(FormatValidatorConnection) def resolve_format_validators(self, info): return get_format_validators(include=self.format_validators) @@ -192,7 +197,7 @@ class Meta: class TextareaQuestion(QuestionQuerysetMixin, FormDjangoObjectType): max_length = graphene.Int() placeholder = graphene.String() - format_validators = graphene.ConnectionField(FormatValidatorConnection) + format_validators = ConnectionField(FormatValidatorConnection) def resolve_format_validators(self, info): return get_format_validators(include=self.format_validators)