From 064adc7c623061bd9c38c900981051651e4b3a99 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 6 Oct 2023 21:47:26 +0200 Subject: [PATCH] feat(filters): support filter aliasing (PR #378) --- graphene_sqlalchemy/filters.py | 32 ++++++++++- graphene_sqlalchemy/tests/test_filters.py | 65 ++++++++++++++++------- graphene_sqlalchemy/types.py | 39 ++++++++------ 3 files changed, 97 insertions(+), 39 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 0e0de90..0af0a69 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,6 +1,7 @@ import re from typing import Any, Dict, List, Tuple, Type, TypeVar, Union +from graphql import Undefined from sqlalchemy import and_, not_, or_ from sqlalchemy.orm import Query, aliased # , selectinload @@ -15,6 +16,31 @@ "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer ) +class SQLAlchemyFilterInputField(graphene.InputField): + def __init__( + self, + type_, + model_attr, + name=None, + default_value=Undefined, + deprecation_reason=None, + description=None, + required=False, + _creation_counter=None, + **extra_args, + ): + super(SQLAlchemyFilterInputField, self).__init__( + type_, + name, + default_value, + deprecation_reason, + description, + required, + _creation_counter, + **extra_args, + ) + + self.model_attr = model_attr def _get_functions_by_regex( regex: str, subtract_regex: str, class_: Type @@ -138,7 +164,8 @@ def execute_filters( # Check with a profiler is required to determine necessity input_field = cls._meta.fields[field] if isinstance(input_field, graphene.Dynamic): - field_filter_type = input_field.get_type().type + input_field = input_field.get_type() + field_filter_type = input_field.type else: field_filter_type = cls._meta.fields[field].type # raise Exception @@ -155,7 +182,8 @@ def execute_filters( ) clauses.extend(_clauses) else: - model_field = getattr(model, field) + # Get the model attr from the inputfield in case the field is aliased in graphql + model_field = getattr(model, input_field.model_attr or field) if issubclass(field_filter_type, BaseTypeFilter): # Get the model to join on the Filter Query joined_model = field_filter_type._meta.model diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 6a93641..026247c 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -1,12 +1,8 @@ -import pytest -from sqlalchemy.sql.operators import is_ - import graphene +import pytest from graphene import Connection, relay +from sqlalchemy.sql.operators import is_ -from ..fields import SQLAlchemyConnectionField -from ..filters import FloatFilter -from ..types import ORMField, SQLAlchemyObjectType from .models import ( Article, Editor, @@ -20,6 +16,10 @@ Tag, ) from .utils import eventually_await_session, to_std_dicts +from ..fields import SQLAlchemyConnectionField +from ..filters import FloatFilter +from ..types import ORMField, SQLAlchemyObjectType + # TODO test that generated schema is correct for all examples with: # with open('schema.gql', 'w') as fp: @@ -110,26 +110,13 @@ class Meta: class Query(graphene.ObjectType): node = relay.Node.Field() - # # TODO how to create filterable singular field? - # article = graphene.Field(ArticleType) articles = SQLAlchemyConnectionField(ArticleType.connection) - # image = graphene.Field(ImageType) images = SQLAlchemyConnectionField(ImageType.connection) readers = SQLAlchemyConnectionField(ReaderType.connection) - # reporter = graphene.Field(ReporterType) reporters = SQLAlchemyConnectionField(ReporterType.connection) pets = SQLAlchemyConnectionField(PetType.connection) tags = SQLAlchemyConnectionField(TagType.connection) - # def resolve_article(self, _info): - # return session.query(Article).first() - - # def resolve_image(self, _info): - # return session.query(Image).first() - - # def resolve_reporter(self, _info): - # return session.query(Reporter).first() - return Query @@ -159,6 +146,44 @@ async def test_filter_simple(session): assert_and_raise_result(result, expected) +@pytest.mark.asyncio +async def test_filter_alias(session): + """ + Test aliasing of column names in the type + """ + await add_test_data(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + + lastNameAlias = ORMField(model_attr="last_name") + + class Query(graphene.ObjectType): + node = relay.Node.Field() + reporters = SQLAlchemyConnectionField(ReporterType.connection) + + query = """ + query { + reporters (filter: {lastNameAlias: {eq: "Roe", like: "%oe"}}) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # Test a custom filter type @pytest.mark.asyncio async def test_filter_custom_type(session): @@ -1084,7 +1109,7 @@ async def test_filter_hybrid_property(session): result = to_std_dicts(result.data) assert len(result["carts"]["edges"]) == 1 assert ( - len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 + len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 ) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 269f451..7a693a4 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -32,7 +32,7 @@ sort_argument_for_object_type, sort_enum_for_object_type, ) -from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter +from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter, SQLAlchemyFilterInputField from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( @@ -151,13 +151,13 @@ def filter_field_from_field( type_, registry: Registry, model_attr: Any, -) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + model_attr_name: str +) -> Optional[graphene.InputField]: # Field might be a SQLAlchemyObjectType, due to hybrid properties if issubclass(type_, SQLAlchemyObjectType): filter_class = registry.get_filter_for_base_type(type_) - return graphene.InputField(filter_class) # Enum Special Case - if issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty): + elif issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty): column = model_attr.columns[0] model_enum_type: Optional[sqlalchemy.types.Enum] = getattr(column, "type", None) if not getattr(model_enum_type, "enum_class", None): @@ -168,16 +168,16 @@ def filter_field_from_field( filter_class = registry.get_filter_for_scalar_type(type_) if not filter_class: warnings.warn( - f"No compatible filters found for {field.type}. Skipping field." + f"No compatible filters found for {field.type} with db name {model_attr_name}. Skipping field." ) return None - return graphene.InputField(filter_class) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) def resolve_dynamic_relationship_filter( field: graphene.Dynamic, registry: Registry, - model_attr: Any, + model_attr_name: str ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # Resolve Dynamic Type type_ = get_nullable_type(field.get_type()) @@ -200,9 +200,12 @@ def resolve_dynamic_relationship_filter( reg_res = None if not reg_res: + warnings.warn( + f"No compatible filters found for {field} with db name {model_attr_name}. Skipping field." + ) return None - return graphene.InputField(reg_res) + return SQLAlchemyFilterInputField(reg_res, model_attr_name) def filter_field_from_type_field( @@ -210,22 +213,18 @@ def filter_field_from_type_field( registry: Registry, filter_type: Optional[Type], model_attr: Any, + model_attr_name: str ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # If a custom filter type was set for this field, use it here if filter_type: - return graphene.InputField(filter_type) + return SQLAlchemyFilterInputField(filter_type, model_attr_name) elif issubclass(type(field), graphene.Scalar): filter_class = registry.get_filter_for_scalar_type(type(field)) - return graphene.InputField(filter_class) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) # If the generated field is Dynamic, it is always a relationship # (due to graphene-sqlalchemy's conversion mechanism). elif isinstance(field, graphene.Dynamic): - return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr)) - elif isinstance(field, graphene.Field): - if inspect.isfunction(field._type) or isinstance(field._type, partial): - return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr)) - else: - return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr) + return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr_name)) # Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List): # Pure lists are not yet supported @@ -233,6 +232,12 @@ def filter_field_from_type_field( elif isinstance(field._type, graphene.Dynamic): # Fields with nested dynamic Dynamic are not yet supported pass + # Order matters, this comes last as field._type == list also matches Field + elif isinstance(field, graphene.Field): + if inspect.isfunction(field._type) or isinstance(field._type, partial): + return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name)) + else: + return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name) def get_polymorphic_on(model): @@ -372,7 +377,7 @@ def construct_fields_and_filters( fields[orm_field_name] = field if filtering_enabled_for_field: filters[orm_field_name] = filter_field_from_type_field( - field, registry, filter_type, attr + field, registry, filter_type, attr, attr_name ) return fields, filters