diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index becdf37..ffa10da 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,6 +1,7 @@ import re from typing import Any, Dict, List, Tuple, Type, TypeVar, Union +from graphql import Undefined from sqlalchemy import and_, not_, or_ from sqlalchemy.orm import Query, aliased # , selectinload @@ -35,6 +36,33 @@ def _get_functions_by_regex( return matching_functions +class SQLAlchemyFilterInputField(graphene.InputField): + def __init__( + self, + type_, + model_attr, + name=None, + default_value=Undefined, + deprecation_reason=None, + description=None, + required=False, + _creation_counter=None, + **extra_args, + ): + super(SQLAlchemyFilterInputField, self).__init__( + type_, + name, + default_value, + deprecation_reason, + description, + required, + _creation_counter, + **extra_args, + ) + + self.model_attr = model_attr + + class BaseTypeFilter(graphene.InputObjectType): @classmethod def __init_subclass_with_meta__( @@ -141,7 +169,8 @@ def execute_filters( # Check with a profiler is required to determine necessity input_field = cls._meta.fields[field] if isinstance(input_field, graphene.Dynamic): - field_filter_type = input_field.get_type().type + input_field = input_field.get_type() + field_filter_type = input_field.type else: field_filter_type = cls._meta.fields[field].type # raise Exception @@ -158,7 +187,9 @@ def execute_filters( ) clauses.extend(_clauses) else: - model_field = getattr(model, field) + model_field = getattr( + model, input_field.model_attr + ) # getattr(model, field) if issubclass(field_filter_type, BaseTypeFilter): # Get the model to join on the Filter Query joined_model = field_filter_type._meta.model @@ -193,6 +224,10 @@ def execute_filters( ) clauses.extend(_clauses) elif issubclass(field_filter_type, FieldFilter): + print("got", model_field) + print(repr(model_field)) + print(model_field == 1) + print("with input", field_filters) query, _clauses = field_filter_type.execute_filters( query, model_field, field_filters ) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index d29c4f6..2b3a660 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -9,8 +9,8 @@ if TYPE_CHECKING: # pragma: no_cover from graphene_sqlalchemy.filters import ( - FieldFilter, BaseTypeFilter, + FieldFilter, RelationshipFilter, ) @@ -165,9 +165,7 @@ def register_filter_for_base_type( raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) if not issubclass(filter_obj, BaseTypeFilter): - raise TypeError( - "Expected BaseTypeFilter, but got: {!r}".format(filter_obj) - ) + raise TypeError("Expected BaseTypeFilter, but got: {!r}".format(filter_obj)) self._registry_base_type_filters[base_type] = filter_obj def get_filter_for_base_type(self, base_type: Type[BaseType]): diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 25511ff..c11db7e 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -37,6 +37,7 @@ IdFilter, IntFilter, RelationshipFilter, + SQLAlchemyFilterInputField, StringFilter, ) from .registry import Registry, get_global_registry @@ -154,13 +155,14 @@ def filter_field_from_type_field( field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], registry: Registry, filter_type: Optional[Type], + model_attr_name: Any, ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # If a custom filter type was set for this field, use it here if filter_type: - return graphene.InputField(filter_type) + return SQLAlchemyFilterInputField(filter_type, model_attr_name) if issubclass(type(field), graphene.Scalar): filter_class = registry.get_filter_for_scalar_type(type(field)) - return graphene.InputField(filter_class) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) # If the field is Dynamic, we don't know its type yet and can't select the right filter if isinstance(field, graphene.Dynamic): @@ -179,7 +181,7 @@ def resolve_dynamic(): if not reg_res: print("filter class was none!!!") print(type_) - return graphene.InputField(reg_res) + return SQLAlchemyFilterInputField(reg_res, model_attr_name) elif isinstance(type_, Field): if isinstance(type_.type, graphene.List): inner_type = get_nullable_type(type_.type.of_type) @@ -187,10 +189,10 @@ def resolve_dynamic(): if not reg_res: print("filter class was none!!!") print(type_) - return graphene.InputField(reg_res) + return SQLAlchemyFilterInputField(reg_res, model_attr_name) reg_res = registry.get_filter_for_base_type(type_.type) - return graphene.InputField(reg_res) + return SQLAlchemyFilterInputField(reg_res, model_attr_name) else: warnings.warn(f"Unexpected Dynamic Type: {type_}") # Investigate # raise Exception(f"Unexpected Dynamic Type: {type_}") @@ -213,14 +215,14 @@ def resolve_dynamic(): # Field might be a SQLAlchemyObjectType, due to hybrid properties if issubclass(type_, SQLAlchemyObjectType): filter_class = registry.get_filter_for_base_type(type_) - return graphene.InputField(filter_class) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) filter_class = registry.get_filter_for_scalar_type(type_) if not filter_class: warnings.warn( f"No compatible filters found for {field.type}. Skipping field." ) return None - return graphene.InputField(filter_class) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) raise Exception(f"Expected a graphene.Field or graphene.Dynamic, but got: {field}") @@ -362,7 +364,7 @@ def construct_fields_and_filters( fields[orm_field_name] = field if filtering_enabled_for_field: filters[orm_field_name] = filter_field_from_type_field( - field, registry, filter_type + field, registry, filter_type, attr_name ) return fields, filters