Skip to content

Commit

Permalink
feat(filters): support filter aliasing (PR #378)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikwrede committed Oct 6, 2023
1 parent c38ebb3 commit 064adc7
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 39 deletions.
32 changes: 30 additions & 2 deletions graphene_sqlalchemy/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from typing import Any, Dict, List, Tuple, Type, TypeVar, Union

from graphql import Undefined
from sqlalchemy import and_, not_, or_
from sqlalchemy.orm import Query, aliased # , selectinload

Expand All @@ -15,6 +16,31 @@
"BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer
)

class SQLAlchemyFilterInputField(graphene.InputField):
def __init__(
self,
type_,
model_attr,
name=None,
default_value=Undefined,
deprecation_reason=None,
description=None,
required=False,
_creation_counter=None,
**extra_args,
):
super(SQLAlchemyFilterInputField, self).__init__(
type_,
name,
default_value,
deprecation_reason,
description,
required,
_creation_counter,
**extra_args,
)

self.model_attr = model_attr

def _get_functions_by_regex(
regex: str, subtract_regex: str, class_: Type
Expand Down Expand Up @@ -138,7 +164,8 @@ def execute_filters(
# Check with a profiler is required to determine necessity
input_field = cls._meta.fields[field]
if isinstance(input_field, graphene.Dynamic):
field_filter_type = input_field.get_type().type
input_field = input_field.get_type()
field_filter_type = input_field.type
else:
field_filter_type = cls._meta.fields[field].type
# raise Exception
Expand All @@ -155,7 +182,8 @@ def execute_filters(
)
clauses.extend(_clauses)
else:
model_field = getattr(model, field)
# Get the model attr from the inputfield in case the field is aliased in graphql
model_field = getattr(model, input_field.model_attr or field)
if issubclass(field_filter_type, BaseTypeFilter):
# Get the model to join on the Filter Query
joined_model = field_filter_type._meta.model
Expand Down
65 changes: 45 additions & 20 deletions graphene_sqlalchemy/tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import pytest
from sqlalchemy.sql.operators import is_

import graphene
import pytest
from graphene import Connection, relay
from sqlalchemy.sql.operators import is_

from ..fields import SQLAlchemyConnectionField
from ..filters import FloatFilter
from ..types import ORMField, SQLAlchemyObjectType
from .models import (
Article,
Editor,
Expand All @@ -20,6 +16,10 @@
Tag,
)
from .utils import eventually_await_session, to_std_dicts
from ..fields import SQLAlchemyConnectionField
from ..filters import FloatFilter
from ..types import ORMField, SQLAlchemyObjectType


# TODO test that generated schema is correct for all examples with:
# with open('schema.gql', 'w') as fp:
Expand Down Expand Up @@ -110,26 +110,13 @@ class Meta:

class Query(graphene.ObjectType):
node = relay.Node.Field()
# # TODO how to create filterable singular field?
# article = graphene.Field(ArticleType)
articles = SQLAlchemyConnectionField(ArticleType.connection)
# image = graphene.Field(ImageType)
images = SQLAlchemyConnectionField(ImageType.connection)
readers = SQLAlchemyConnectionField(ReaderType.connection)
# reporter = graphene.Field(ReporterType)
reporters = SQLAlchemyConnectionField(ReporterType.connection)
pets = SQLAlchemyConnectionField(PetType.connection)
tags = SQLAlchemyConnectionField(TagType.connection)

# def resolve_article(self, _info):
# return session.query(Article).first()

# def resolve_image(self, _info):
# return session.query(Image).first()

# def resolve_reporter(self, _info):
# return session.query(Reporter).first()

return Query


Expand Down Expand Up @@ -159,6 +146,44 @@ async def test_filter_simple(session):
assert_and_raise_result(result, expected)


@pytest.mark.asyncio
async def test_filter_alias(session):
"""
Test aliasing of column names in the type
"""
await add_test_data(session)

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
name = "Reporter"
interfaces = (relay.Node,)

lastNameAlias = ORMField(model_attr="last_name")

class Query(graphene.ObjectType):
node = relay.Node.Field()
reporters = SQLAlchemyConnectionField(ReporterType.connection)

query = """
query {
reporters (filter: {lastNameAlias: {eq: "Roe", like: "%oe"}}) {
edges {
node {
firstName
}
}
}
}
"""
expected = {
"reporters": {"edges": [{"node": {"firstName": "Jane"}}]},
}
schema = graphene.Schema(query=Query)
result = await schema.execute_async(query, context_value={"session": session})
assert_and_raise_result(result, expected)


# Test a custom filter type
@pytest.mark.asyncio
async def test_filter_custom_type(session):
Expand Down Expand Up @@ -1084,7 +1109,7 @@ async def test_filter_hybrid_property(session):
result = to_std_dicts(result.data)
assert len(result["carts"]["edges"]) == 1
assert (
len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2
len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2
)


Expand Down
39 changes: 22 additions & 17 deletions graphene_sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
sort_argument_for_object_type,
sort_enum_for_object_type,
)
from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter
from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter, SQLAlchemyFilterInputField
from .registry import Registry, get_global_registry
from .resolvers import get_attr_resolver, get_custom_resolver
from .utils import (
Expand Down Expand Up @@ -151,13 +151,13 @@ def filter_field_from_field(
type_,
registry: Registry,
model_attr: Any,
) -> Optional[Union[graphene.InputField, graphene.Dynamic]]:
model_attr_name: str
) -> Optional[graphene.InputField]:
# Field might be a SQLAlchemyObjectType, due to hybrid properties
if issubclass(type_, SQLAlchemyObjectType):
filter_class = registry.get_filter_for_base_type(type_)
return graphene.InputField(filter_class)
# Enum Special Case
if issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty):
elif issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty):
column = model_attr.columns[0]
model_enum_type: Optional[sqlalchemy.types.Enum] = getattr(column, "type", None)
if not getattr(model_enum_type, "enum_class", None):
Expand All @@ -168,16 +168,16 @@ def filter_field_from_field(
filter_class = registry.get_filter_for_scalar_type(type_)
if not filter_class:
warnings.warn(
f"No compatible filters found for {field.type}. Skipping field."
f"No compatible filters found for {field.type} with db name {model_attr_name}. Skipping field."
)
return None
return graphene.InputField(filter_class)
return SQLAlchemyFilterInputField(filter_class, model_attr_name)


def resolve_dynamic_relationship_filter(
field: graphene.Dynamic,
registry: Registry,
model_attr: Any,
model_attr_name: str
) -> Optional[Union[graphene.InputField, graphene.Dynamic]]:
# Resolve Dynamic Type
type_ = get_nullable_type(field.get_type())
Expand All @@ -200,39 +200,44 @@ def resolve_dynamic_relationship_filter(
reg_res = None

if not reg_res:
warnings.warn(
f"No compatible filters found for {field} with db name {model_attr_name}. Skipping field."
)
return None

return graphene.InputField(reg_res)
return SQLAlchemyFilterInputField(reg_res, model_attr_name)


def filter_field_from_type_field(
field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]],
registry: Registry,
filter_type: Optional[Type],
model_attr: Any,
model_attr_name: str
) -> Optional[Union[graphene.InputField, graphene.Dynamic]]:
# If a custom filter type was set for this field, use it here
if filter_type:
return graphene.InputField(filter_type)
return SQLAlchemyFilterInputField(filter_type, model_attr_name)
elif issubclass(type(field), graphene.Scalar):
filter_class = registry.get_filter_for_scalar_type(type(field))
return graphene.InputField(filter_class)
return SQLAlchemyFilterInputField(filter_class, model_attr_name)
# If the generated field is Dynamic, it is always a relationship
# (due to graphene-sqlalchemy's conversion mechanism).
elif isinstance(field, graphene.Dynamic):
return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr))
elif isinstance(field, graphene.Field):
if inspect.isfunction(field._type) or isinstance(field._type, partial):
return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr))
else:
return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr)
return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr_name))
# Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them
elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List):
# Pure lists are not yet supported
pass
elif isinstance(field._type, graphene.Dynamic):
# Fields with nested dynamic Dynamic are not yet supported
pass
# Order matters, this comes last as field._type == list also matches Field
elif isinstance(field, graphene.Field):
if inspect.isfunction(field._type) or isinstance(field._type, partial):
return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name))
else:
return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name)


def get_polymorphic_on(model):
Expand Down Expand Up @@ -372,7 +377,7 @@ def construct_fields_and_filters(
fields[orm_field_name] = field
if filtering_enabled_for_field:
filters[orm_field_name] = filter_field_from_type_field(
field, registry, filter_type, attr
field, registry, filter_type, attr, attr_name
)

return fields, filters
Expand Down

0 comments on commit 064adc7

Please sign in to comment.