From ac57fd4ba78a786f36aa0c948596a66766c33ec7 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Sun, 24 Jul 2022 01:05:31 +0200 Subject: [PATCH 01/81] Enable sorting when batching is enabled --- graphene_sqlalchemy/batching.py | 48 ++- graphene_sqlalchemy/fields.py | 27 +- graphene_sqlalchemy/tests/models.py | 18 + graphene_sqlalchemy/tests/test_batching.py | 396 ++++++++++++++------- 4 files changed, 336 insertions(+), 153 deletions(-) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 85cc8855..7daf1c07 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -1,3 +1,5 @@ +from asyncio import get_event_loop + import aiodataloader import sqlalchemy from sqlalchemy.orm import Session, strategies @@ -5,16 +7,22 @@ from .utils import is_sqlalchemy_version_less_than +# Cache this across `batch_load_fn` calls +# This is so SQL string generation is cached under-the-hood via `bakery` +# Caching the relationship loader for each relationship prop. +RELATIONSHIP_LOADERS_CACHE = {} -def get_batch_resolver(relationship_prop): - # Cache this across `batch_load_fn` calls - # This is so SQL string generation is cached under-the-hood via `bakery` - selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) +def get_batch_resolver(relationship_prop): class RelationshipLoader(aiodataloader.DataLoader): cache = False + def __init__(self, relationship_prop, selectin_loader): + super().__init__() + self.relationship_prop = relationship_prop + self.selectin_loader = selectin_loader + async def batch_load_fn(self, parents): """ Batch loads the relationships of all the parents as one SQL statement. @@ -38,8 +46,8 @@ async def batch_load_fn(self, parents): SQLAlchemy's main maitainer suggestion. See https://git.io/JewQ7 """ - child_mapper = relationship_prop.mapper - parent_mapper = relationship_prop.parent + child_mapper = self.relationship_prop.mapper + parent_mapper = self.relationship_prop.parent session = Session.object_session(parents[0]) # These issues are very unlikely to happen in practice... @@ -62,7 +70,7 @@ async def batch_load_fn(self, parents): query_context = parent_mapper_query._compile_context() if is_sqlalchemy_version_less_than('1.4'): - selectin_loader._load_for_path( + self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, states, @@ -70,7 +78,7 @@ async def batch_load_fn(self, parents): child_mapper ) else: - selectin_loader._load_for_path( + self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, states, @@ -78,10 +86,26 @@ async def batch_load_fn(self, parents): child_mapper, None ) - - return [getattr(parent, relationship_prop.key) for parent in parents] - - loader = RelationshipLoader() + return [getattr(parent, self.relationship_prop.key) for parent in parents] + + def _get_loader(relationship_prop): + """Retrieve the cached loader of the given relationship.""" + loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None) + if loader is None: + selectin_loader = strategies.SelectInLoader( + relationship_prop, + (('lazy', 'selectin'),) + ) + loader = RelationshipLoader( + relationship_prop=relationship_prop, + selectin_loader=selectin_loader + ) + RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader + else: + loader.loop = get_event_loop() + return loader + + loader = _get_loader(relationship_prop) async def resolve(root, info, **args): return await loader.load(root) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index d7a83392..b7650684 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -129,20 +129,31 @@ def get_query(cls, model, info, sort=None, **args): return query -class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): +class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField): """ This is currently experimental. The API and behavior may change in future versions. Use at your own risk. """ - def wrap_resolve(self, parent_resolver): - return partial( - self.connection_resolver, - self.resolver, - get_nullable_type(self.type), - self.model, - ) + @classmethod + def connection_resolver(cls, resolver, connection_type, model, root, info, **args): + if root is None: + resolved = resolver(root, info, **args) + on_resolve = partial(cls.resolve_connection, connection_type, model, info, args) + else: + relationship_prop = None + for relationship in root.__class__.__mapper__.relationships: + if relationship.mapper.class_ == model: + relationship_prop = relationship + break + resolved = get_batch_resolver(relationship_prop)(root, info, **args) + on_resolve = partial(cls.resolve_connection, connection_type, root, info, args) + + if is_thenable(resolved): + return Promise.resolve(resolved).then(on_resolve) + + return on_resolve(resolved) @classmethod def from_relationship(cls, relationship, registry, **field_kwargs): diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index dc399ee0..c7a1d664 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -110,6 +110,24 @@ class Article(Base): headline = Column(String(100)) pub_date = Column(Date()) reporter_id = Column(Integer(), ForeignKey("reporters.id")) + readers = relationship( + "Reader", secondary="articles_readers", back_populates="articles" + ) + + +class Reader(Base): + __tablename__ = "readers" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + articles = relationship( + "Article", secondary="articles_readers", back_populates="readers" + ) + + +class ArticleReader(Base): + __tablename__ = "articles_readers" + article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True) + reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) class ReflectedEditor(type): diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 1896900b..d7064c92 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -5,13 +5,13 @@ import pytest import graphene -from graphene import relay +from graphene import Connection, relay from ..fields import (BatchSQLAlchemyConnectionField, default_connection_field_factory) from ..types import ORMField, SQLAlchemyObjectType from ..utils import is_sqlalchemy_version_less_than -from .models import Article, HairKind, Pet, Reporter +from .models import Article, HairKind, Pet, Reader, Reporter from .utils import remove_cache_miss_stat, to_std_dicts @@ -73,6 +73,40 @@ def resolve_reporters(self, info): return graphene.Schema(query=Query) +def get_full_relay_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class ReaderType(SQLAlchemyObjectType): + class Meta: + model = Reader + name = "Reader" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = BatchSQLAlchemyConnectionField(ArticleType.connection) + reporters = BatchSQLAlchemyConnectionField(ReporterType.connection) + readers = BatchSQLAlchemyConnectionField(ReaderType.connection) + + return graphene.Schema(query=Query) + + if is_sqlalchemy_version_less_than('1.2'): pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) @@ -82,11 +116,11 @@ async def test_many_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -138,20 +172,20 @@ async def test_many_to_one(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "articles": [ - { - "headline": "Article_1", - "reporter": { - "firstName": "Reporter_1", - }, - }, - { - "headline": "Article_2", - "reporter": { - "firstName": "Reporter_2", - }, - }, - ], + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], } @@ -160,11 +194,11 @@ async def test_one_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -185,14 +219,14 @@ async def test_one_to_one(session_factory): # Starts new session to fully reset the engine / connection logging level session = session_factory() result = await schema.execute_async(""" - query { - reporters { - firstName - favoriteArticle { - headline - } + query { + reporters { + firstName + favoriteArticle { + headline + } + } } - } """, context_value={"session": session}) messages = sqlalchemy_logging_handler.messages @@ -216,20 +250,20 @@ async def test_one_to_one(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "favoriteArticle": { - "headline": "Article_1", - }, - }, - { - "firstName": "Reporter_2", - "favoriteArticle": { - "headline": "Article_2", - }, - }, - ], + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": "Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], } @@ -238,11 +272,11 @@ async def test_one_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -271,18 +305,18 @@ async def test_one_to_many(session_factory): # Starts new session to fully reset the engine / connection logging level session = session_factory() result = await schema.execute_async(""" - query { - reporters { - firstName - articles(first: 2) { - edges { - node { - headline - } + query { + reporters { + firstName + articles(first: 2) { + edges { + node { + headline + } + } + } } - } } - } """, context_value={"session": session}) messages = sqlalchemy_logging_handler.messages @@ -306,42 +340,42 @@ async def test_one_to_many(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "articles": { - "edges": [ - { - "node": { - "headline": "Article_1", - }, - }, - { - "node": { - "headline": "Article_2", + "reporters": [ + { + "firstName": "Reporter_1", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_1", + }, + }, + { + "node": { + "headline": "Article_2", + }, + }, + ], }, - }, - ], - }, - }, - { - "firstName": "Reporter_2", - "articles": { - "edges": [ - { - "node": { - "headline": "Article_3", + }, + { + "firstName": "Reporter_2", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_3", + }, + }, + { + "node": { + "headline": "Article_4", + }, + }, + ], }, - }, - { - "node": { - "headline": "Article_4", - }, - }, - ], - }, - }, - ], + }, + ], } @@ -350,11 +384,11 @@ async def test_many_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -385,18 +419,18 @@ async def test_many_to_many(session_factory): # Starts new session to fully reset the engine / connection logging level session = session_factory() result = await schema.execute_async(""" - query { - reporters { - firstName - pets(first: 2) { - edges { - node { - name - } + query { + reporters { + firstName + pets(first: 2) { + edges { + node { + name + } + } + } } - } } - } """, context_value={"session": session}) messages = sqlalchemy_logging_handler.messages @@ -420,42 +454,42 @@ async def test_many_to_many(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "pets": { - "edges": [ - { - "node": { - "name": "Pet_1", + "reporters": [ + { + "firstName": "Reporter_1", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_1", + }, + }, + { + "node": { + "name": "Pet_2", + }, + }, + ], }, - }, - { - "node": { - "name": "Pet_2", + }, + { + "firstName": "Reporter_2", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_3", + }, + }, + { + "node": { + "name": "Pet_4", + }, + }, + ], }, - }, - ], - }, - }, - { - "firstName": "Reporter_2", - "pets": { - "edges": [ - { - "node": { - "name": "Pet_3", - }, - }, - { - "node": { - "name": "Pet_4", - }, - }, - ], - }, - }, - ], + }, + ], } @@ -642,3 +676,99 @@ def resolve_reporters(self, info): select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] assert len(select_statements) == 2 + + +@pytest.mark.asyncio +async def test_batching_across_nested_relay_schema(session_factory): + session = session_factory() + + for first_name in "fgerbhjikzutzxsdfdqqa": + reporter = Reporter( + first_name=first_name, + ) + session.add(reporter) + article = Article(headline='Article') + article.reporter = reporter + session.add(article) + reader = Reader(name='Reader') + reader.articles = [article] + session.add(reader) + + session.commit() + session.close() + + schema = get_full_relay_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = await schema.execute_async(""" + query { + reporters { + edges { + node { + firstName + articles { + edges { + node { + id + readers { + edges { + node { + name + } + } + } + } + } + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + result = to_std_dicts(result.data) + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 + + +@pytest.mark.asyncio +async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_factory): + session = session_factory() + + for first_name, email in zip("cadbbb", "aaabac"): + reporter_1 = Reporter( + first_name=first_name, + email=email + ) + session.add(reporter_1) + article_1 = Article(headline="headline") + article_1.reporter = reporter_1 + session.add(article_1) + + session.commit() + session.close() + + schema = get_full_relay_schema() + + session = session_factory() + result = await schema.execute_async(""" + query { + reporters(sort: [FIRST_NAME_ASC, EMAIL_ASC]) { + edges { + node { + firstName + email + } + } + } + } + """, context_value={"session": session}) + + result = to_std_dicts(result.data) + assert [ + r["node"]["firstName"] + r["node"]["email"] + for r in result["reporters"]["edges"] + ] == ['aa', 'ba', 'bb', 'bc', 'ca', 'da'] From 6417061998823fb3531def8fee394bd81c6f4c24 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Sun, 31 Jul 2022 17:41:50 +0200 Subject: [PATCH 02/81] Deprecate UnsortedSQLAlchemyConnectionField and resetting RelationshipLoader between queries --- graphene_sqlalchemy/batching.py | 166 +++++++++++++++++--------------- graphene_sqlalchemy/fields.py | 83 ++++++++-------- 2 files changed, 130 insertions(+), 119 deletions(-) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 7daf1c07..ee479346 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -1,4 +1,6 @@ +"""The dataloader uses "select in loading" strategy to load related entities.""" from asyncio import get_event_loop +from typing import Dict import aiodataloader import sqlalchemy @@ -7,102 +9,106 @@ from .utils import is_sqlalchemy_version_less_than + +class RelationshipLoader(aiodataloader.DataLoader): + cache = False + + def __init__(self, relationship_prop, selectin_loader): + super().__init__() + self.relationship_prop = relationship_prop + self.selectin_loader = selectin_loader + + async def batch_load_fn(self, parents): + """ + Batch loads the relationships of all the parents as one SQL statement. + + There is no way to do this out-of-the-box with SQLAlchemy but + we can piggyback on some internal APIs of the `selectin` + eager loading strategy. It's a bit hacky but it's preferable + than re-implementing and maintainnig a big chunk of the `selectin` + loader logic ourselves. + + The approach here is to build a regular query that + selects the parent and `selectin` load the relationship. + But instead of having the query emits 2 `SELECT` statements + when callling `all()`, we skip the first `SELECT` statement + and jump right before the `selectin` loader is called. + To accomplish this, we have to construct objects that are + normally built in the first part of the query in order + to call directly `SelectInLoader._load_for_path`. + + TODO Move this logic to a util in the SQLAlchemy repo as per + SQLAlchemy's main maitainer suggestion. + See https://git.io/JewQ7 + """ + child_mapper = self.relationship_prop.mapper + parent_mapper = self.relationship_prop.parent + session = Session.object_session(parents[0]) + + # These issues are very unlikely to happen in practice... + for parent in parents: + # assert parent.__mapper__ is parent_mapper + # All instances must share the same session + assert session is Session.object_session(parent) + # The behavior of `selectin` is undefined if the parent is dirty + assert parent not in session.dirty + + # Should the boolean be set to False? Does it matter for our purposes? + states = [(sqlalchemy.inspect(parent), True) for parent in parents] + + # For our purposes, the query_context will only used to get the session + query_context = None + if is_sqlalchemy_version_less_than('1.4'): + query_context = QueryContext(session.query(parent_mapper.entity)) + else: + parent_mapper_query = session.query(parent_mapper.entity) + query_context = parent_mapper_query._compile_context() + + if is_sqlalchemy_version_less_than('1.4'): + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + ) + else: + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + None, + ) + return [ + getattr(parent, self.relationship_prop.key) for parent in parents + ] + + # Cache this across `batch_load_fn` calls # This is so SQL string generation is cached under-the-hood via `bakery` # Caching the relationship loader for each relationship prop. -RELATIONSHIP_LOADERS_CACHE = {} +RELATIONSHIP_LOADERS_CACHE: Dict[ + sqlalchemy.orm.relationships.RelationshipProperty, RelationshipLoader +] = {} def get_batch_resolver(relationship_prop): - - class RelationshipLoader(aiodataloader.DataLoader): - cache = False - - def __init__(self, relationship_prop, selectin_loader): - super().__init__() - self.relationship_prop = relationship_prop - self.selectin_loader = selectin_loader - - async def batch_load_fn(self, parents): - """ - Batch loads the relationships of all the parents as one SQL statement. - - There is no way to do this out-of-the-box with SQLAlchemy but - we can piggyback on some internal APIs of the `selectin` - eager loading strategy. It's a bit hacky but it's preferable - than re-implementing and maintainnig a big chunk of the `selectin` - loader logic ourselves. - - The approach here is to build a regular query that - selects the parent and `selectin` load the relationship. - But instead of having the query emits 2 `SELECT` statements - when callling `all()`, we skip the first `SELECT` statement - and jump right before the `selectin` loader is called. - To accomplish this, we have to construct objects that are - normally built in the first part of the query in order - to call directly `SelectInLoader._load_for_path`. - - TODO Move this logic to a util in the SQLAlchemy repo as per - SQLAlchemy's main maitainer suggestion. - See https://git.io/JewQ7 - """ - child_mapper = self.relationship_prop.mapper - parent_mapper = self.relationship_prop.parent - session = Session.object_session(parents[0]) - - # These issues are very unlikely to happen in practice... - for parent in parents: - # assert parent.__mapper__ is parent_mapper - # All instances must share the same session - assert session is Session.object_session(parent) - # The behavior of `selectin` is undefined if the parent is dirty - assert parent not in session.dirty - - # Should the boolean be set to False? Does it matter for our purposes? - states = [(sqlalchemy.inspect(parent), True) for parent in parents] - - # For our purposes, the query_context will only used to get the session - query_context = None - if is_sqlalchemy_version_less_than('1.4'): - query_context = QueryContext(session.query(parent_mapper.entity)) - else: - parent_mapper_query = session.query(parent_mapper.entity) - query_context = parent_mapper_query._compile_context() - - if is_sqlalchemy_version_less_than('1.4'): - self.selectin_loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper - ) - else: - self.selectin_loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper, - None - ) - return [getattr(parent, self.relationship_prop.key) for parent in parents] + """get the resolve function for the given relationship.""" def _get_loader(relationship_prop): """Retrieve the cached loader of the given relationship.""" loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None) - if loader is None: + if loader is None or loader.loop != get_event_loop(): selectin_loader = strategies.SelectInLoader( - relationship_prop, - (('lazy', 'selectin'),) + relationship_prop, (('lazy', 'selectin'),) ) loader = RelationshipLoader( relationship_prop=relationship_prop, - selectin_loader=selectin_loader + selectin_loader=selectin_loader, ) RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader - else: - loader.loop = get_event_loop() return loader loader = _get_loader(relationship_prop) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index b7650684..905ad415 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -14,7 +14,7 @@ from .utils import EnumValue, get_query -class UnsortedSQLAlchemyConnectionField(ConnectionField): +class SQLAlchemyConnectionField(ConnectionField): @property def type(self): from .types import SQLAlchemyObjectType @@ -37,13 +37,45 @@ def type(self): ) return nullable_type.connection + def __init__(self, type_, *args, **kwargs): + nullable_type = get_nullable_type(type_) + if "sort" not in kwargs and nullable_type and issubclass(nullable_type, Connection): + # Let super class raise if type is not a Connection + try: + kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) + except (AttributeError, TypeError): + raise TypeError( + 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' + " to None to disabling the creation of the sort query argument".format( + nullable_type.__name__ + ) + ) + elif "sort" in kwargs and kwargs["sort"] is None: + del kwargs["sort"] + super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) + @property def model(self): return get_nullable_type(self.type)._meta.node._meta.model @classmethod - def get_query(cls, model, info, **args): - return get_query(model, info.context) + def get_query(cls, model, info, sort=None, **args): + query = get_query(model, info.context) + if sort is not None: + if not isinstance(sort, list): + sort = [sort] + sort_args = [] + # ensure consistent handling of graphene Enums, enum values and + # plain strings + for item in sort: + if isinstance(item, enum.Enum): + sort_args.append(item.value.value) + elif isinstance(item, EnumValue): + sort_args.append(item.value) + else: + sort_args.append(item) + query = query.order_by(*sort_args) + return query @classmethod def resolve_connection(cls, connection_type, model, info, args, resolved): @@ -90,43 +122,16 @@ def wrap_resolve(self, parent_resolver): ) -# TODO Rename this to SortableSQLAlchemyConnectionField -class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): +# TODO Remove in next major version +class UnsortedSQLAlchemyConnectionField(SQLAlchemyConnectionField): def __init__(self, type_, *args, **kwargs): - nullable_type = get_nullable_type(type_) - if "sort" not in kwargs and issubclass(nullable_type, Connection): - # Let super class raise if type is not a Connection - try: - kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) - except (AttributeError, TypeError): - raise TypeError( - 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' - " to None to disabling the creation of the sort query argument".format( - nullable_type.__name__ - ) - ) - elif "sort" in kwargs and kwargs["sort"] is None: - del kwargs["sort"] - super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) - - @classmethod - def get_query(cls, model, info, sort=None, **args): - query = get_query(model, info.context) - if sort is not None: - if not isinstance(sort, list): - sort = [sort] - sort_args = [] - # ensure consistent handling of graphene Enums, enum values and - # plain strings - for item in sort: - if isinstance(item, enum.Enum): - sort_args.append(item.value.value) - elif isinstance(item, EnumValue): - sort_args.append(item.value) - else: - sort_args.append(item) - query = query.order_by(*sort_args) - return query + super(UnsortedSQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) + warnings.warn( + "UnsortedSQLAlchemyConnectionField is deprecated and will be removed in the next " + "major version. Use SQLAlchemyConnectionField instead and set `sort = None` " + "if you want to disable sorting.", + DeprecationWarning, + ) class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField): From 535afbe183c109ce2fd5c7d0763f00f60cccfcc1 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Sun, 31 Jul 2022 17:52:49 +0200 Subject: [PATCH 03/81] Use field_name instead of column.key to build sort enum names to ensure the enum will get the actula field_name --- graphene_sqlalchemy/enums.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index a2ed17ad..19f40b7f 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -144,9 +144,9 @@ def sort_enum_for_object_type( column = orm_field.columns[0] if only_indexed and not (column.primary_key or column.index): continue - asc_name = get_name(column.key, True) + asc_name = get_name(field_name, True) asc_value = EnumValue(asc_name, column.asc()) - desc_name = get_name(column.key, False) + desc_name = get_name(field_name, False) desc_value = EnumValue(desc_name, column.desc()) if column.primary_key: default.append(asc_value) From b91fede916ff405079be47efedff67289fea8d73 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Sat, 6 Aug 2022 12:08:56 +0200 Subject: [PATCH 04/81] Adjust batching test to honor different selet in query structure in sqla1.2 --- graphene_sqlalchemy/tests/test_batching.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index d7064c92..7615b114 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -730,8 +730,15 @@ async def test_batching_across_nested_relay_schema(session_factory): messages = sqlalchemy_logging_handler.messages result = to_std_dicts(result.data) - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] - assert len(select_statements) == 2 + select_statements = [message for message in messages if 'SELECT' in message] + assert len(select_statements) == 4 + assert select_statements[-1].startswith("SELECT articles_1.id") + if is_sqlalchemy_version_less_than('1.3'): + assert select_statements[-2].startswith("SELECT reporters_1.id") + assert "WHERE reporters_1.id IN" in select_statements[-2] + else: + assert select_statements[-2].startswith("SELECT articles.reporter_id") + assert "WHERE articles.reporter_id IN" in select_statements[-2] @pytest.mark.asyncio From 5b8d0680e94ac39aa1d57c37e9bcdc14f127b3d8 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 25 Jul 2022 11:36:00 -0400 Subject: [PATCH 05/81] add filter tests for discussion --- graphene_sqlalchemy/tests/test_filters.py | 317 ++++++++++++++++++++++ 1 file changed, 317 insertions(+) create mode 100644 graphene_sqlalchemy/tests/test_filters.py diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py new file mode 100644 index 00000000..f298bdca --- /dev/null +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -0,0 +1,317 @@ +import graphene + +from ..fields import SQLAlchemyConnectionField +from ..filters import FloatFilter +from ..types import SQLAlchemyObjectType +from .models import Article, Editor, HairKind, Image, Pet, Reporter, Tag +from .utils import to_std_dicts + + +def add_test_data(session): + reporter = Reporter( + first_name='John', last_name='Doe', favorite_pet_kind='cat') + session.add(reporter) + pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) + pet.reporters = reporter + session.add(pet) + pet = Pet(name='Snoopy', pet_kind='dog', hair_kind=HairKind.SHORT) + pet.reporters = reporter + session.add(pet) + reporter = Reporter( + first_name='John', last_name='Woe', favorite_pet_kind='cat') + session.add(reporter) + article = Article(headline='Hi!') + article.reporter = reporter + session.add(article) + reporter = Reporter( + first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + session.add(reporter) + pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + pet.reporters.append(reporter) + session.add(pet) + editor = Editor(name="Jack") + session.add(editor) + session.commit() + + +def create_schema(session): + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + + class ImageType(SQLAlchemyObjectType): + class Meta: + model = Image + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + article = graphene.Field(ArticleType) + articles = graphene.List(ArticleType) + image = graphene.Field(ImageType) + images = graphene.List(ImageType) + reporter = graphene.Field(ReporterType) + reporters = graphene.List(ReporterType) + + def resolve_article(self, _info): + return session.query(Article).first() + + def resolve_articles(self, _info): + return session.query(Article) + + def resolve_image(self, _info): + return session.query(Image).first() + + def resolve_images(self, _info): + return session.query(Image) + + def resolve_reporter(self, _info): + return session.query(Reporter).first() + + def resolve_reporters(self, _info): + return session.query(Reporter) + + return Query + + +# Test a simple example of filtering +def test_filter_simple(session): + add_test_data(session) + Query = create_schema(session) + + query = """ + query { + reporters(filters: {firstName: "John"}) { + firstName + } + } + """ + expected = { + "reporters": [{"firstName": "John"}], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +# Test a custom filter type +def test_filter_custom_type(session): + add_test_data(session) + Query = create_schema(session) + + class MathFilter(FloatFilter): + def divisibleBy(dividend, divisor): + return dividend % divisor == 0 + + class ExtraQuery: + pets = SQLAlchemyConnectionField(Pet, filters=MathFilter()) + + class CustomQuery(Query, ExtraQuery): + pass + + query = """ + query { + pets (filters: { + legs: {divisibleBy: 2} + }) { + name + } + } + """ + expected = { + "pets": [{"name": "Garfield"}, {"name": "Lassie"}], + } + schema = graphene.Schema(query=CustomQuery) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + +# Test a 1:1 relationship +def test_filter_relationship_one_to_one(session): + article = Article(headline='Hi!') + image = Image(external_id=1, description="A beautiful image.") + article.image = image + session.add(article) + session.add(image) + session.commit() + + Query = create_schema(session) + + query = """ + query { + article (filters: { + image: {description: "A beautiful image."} + }) { + firstName + } + } + """ + expected = { + "article": [{"headline": "Hi!"}], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +# Test a 1:n relationship +def test_filter_relationship_one_to_many(session): + add_test_data(session) + Query = create_schema(session) + + query = """ + query { + reporter (filters: { + pets: { + name: {in: ["Garfield", "Snoopy"]} + } + }) { + firstName + lastName + } + } + """ + expected = { + "reporter": [{"firstName": "John"}, {"lastName": "Doe"}], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +# Test a n:m relationship +def test_filter_relationship_many_to_many(session): + article1 = Article(headline='Article! Look!') + article2 = Article(headline='Woah! Another!') + tag1 = Tag(name="sensational") + tag2 = Tag(name="eye-grabbing") + article1.tags.append(tag1) + article2.tags.append([tag1, tag2]) + session.add(article1) + session.add(article2) + session.add(tag1) + session.add(tag2) + session.commit() + + Query = create_schema(session) + + query = """ + query { + articles (filters: { + tags: { name: { in: ["sensational", "eye-grabbing"] } } + }) { + headline + } + } + """ + expected = { + "articles": [{"headline": "Woah! Another!"}], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +# Test connecting filters with "and" +def test_filter_logic_and(session): + add_test_data(session) + + Query = create_schema(session) + + query = """ + query { + reporters (filters: { + and: [ + {firstName: "John"}, + {favoritePetKind: "cat"}, + ] + }) { + lastName + } + } + """ + expected = { + "reporters": [{"lastName": "Doe"}, {"lastName": "Woe"}], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +# Test connecting filters with "or" +def test_filter_logic_or(session): + add_test_data(session) + Query = create_schema(session) + + query = """ + query { + reporters (filters: { + or: [ + {lastName: "Woe"}, + {favoritePetKind: "dog"}, + ] + }) { + firstName + lastName + } + } + """ + expected = { + "reporters": [ + {"firstName": "John", "lastName": "Woe"}, + {"firstName": "Jane", "lastName": "Roe"}, + ], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +# Test connecting filters with "and" and "or" together +def test_filter_logic_and_or(session): + add_test_data(session) + Query = create_schema(session) + + query = """ + query { + reporters (filters: { + and: [ + {firstName: "John"}, + or : [ + {lastName: "Doe"}, + {favoritePetKind: "cat"}, + ] + ] + }) { + firstName + } + } + """ + expected = { + "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +# TODO hybrid property +def test_filter_hybrid_property(session): + raise NotImplementedError From 861613e6ce92b8d09d3f9b12e6a93b6b0a63980e Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Tue, 2 Aug 2022 13:30:14 -0400 Subject: [PATCH 06/81] add typing for custom filter --- graphene_sqlalchemy/tests/test_filters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index f298bdca..13b58427 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -104,8 +104,8 @@ def test_filter_custom_type(session): Query = create_schema(session) class MathFilter(FloatFilter): - def divisibleBy(dividend, divisor): - return dividend % divisor == 0 + def divisibleBy(dividend: float, divisor: float) -> float: + return dividend % divisor == 0. class ExtraQuery: pets = SQLAlchemyConnectionField(Pet, filters=MathFilter()) From 1f47f92ef6763197f4e757daf760b8d26f998020 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Tue, 2 Aug 2022 15:03:37 -0400 Subject: [PATCH 07/81] update 1:n and n:m filter tests to use RelationshipFilter syntax --- graphene_sqlalchemy/tests/test_filters.py | 106 +++++++++++++++++++++- 1 file changed, 104 insertions(+), 2 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 13b58427..584bde76 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -166,11 +166,38 @@ def test_filter_relationship_one_to_many(session): add_test_data(session) Query = create_schema(session) + # test contains query = """ query { reporter (filters: { pets: { - name: {in: ["Garfield", "Snoopy"]} + contains: { + name: {in: ["Garfield", "Lassie"]} + } + } + }) { + lastName + } + } + """ + expected = { + "reporter": [{"lastName": "Doe"}, {"lastName": "Roe"}], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test containsAllOf + query = """ + query { + reporter (filters: { + pets: { + containsAllOf: [ + name: {eq: "Garfield"}, + name: {eq: "Snoopy"}, + ] } }) { firstName @@ -187,6 +214,28 @@ def test_filter_relationship_one_to_many(session): result = to_std_dicts(result.data) assert result == expected + # test containsExactly + query = """ + query { + reporter (filters: { + pets: { + containsExactly: [ + name: {eq: "Garfield"} + ] + } + }) { + firstName + } + } + """ + expected = { + "reporter": [], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected # Test a n:m relationship def test_filter_relationship_many_to_many(session): @@ -204,10 +253,42 @@ def test_filter_relationship_many_to_many(session): Query = create_schema(session) + # test contains + query = """ + query { + articles (filters: { + tags: { + contains: { + name: { in: ["sensational", "eye-grabbing"] } + } + } + }) { + headline + } + } + """ + expected = { + "articles": [ + {"headline": "Woah! Another!"}, + {"headline": "Article! Look!"}, + ], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test containsAllOf query = """ query { articles (filters: { - tags: { name: { in: ["sensational", "eye-grabbing"] } } + tags: { + containsAllOf: [ + { tag: { name: { eq: "eye-grabbing" } } }, + { tag: { name: { eq: "sensational" } } }, + ] + } }) { headline } @@ -222,6 +303,27 @@ def test_filter_relationship_many_to_many(session): result = to_std_dicts(result.data) assert result == expected + # test containsExactly + query = """ + query { + articles (filters: { + containsExactly: [ + { tag: { name: { eq: "sensational" } } } + ] + }) { + headline + } + } + """ + expected = { + "articles": [{"headline": "Article! Look!"}], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + # Test connecting filters with "and" def test_filter_logic_and(session): From e0cd4656ca82eb31ef1922b137906839e98b2114 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Tue, 2 Aug 2022 15:41:30 -0400 Subject: [PATCH 08/81] add models and mark tests to fail --- graphene_sqlalchemy/tests/models.py | 31 ++++++++++++++++++++++- graphene_sqlalchemy/tests/test_filters.py | 10 ++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index c7a1d664..5654834a 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -9,7 +9,7 @@ String, Table, func, select) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import column_property, composite, mapper, relationship +from sqlalchemy.orm import backref, column_property, composite, mapper, relationship PetKind = Enum("cat", "dog", name="pet_kind") @@ -42,6 +42,7 @@ class Pet(Base): pet_kind = Column(PetKind, nullable=False) hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) reporter_id = Column(Integer(), ForeignKey("reporters.id")) + legs = Column(Integer(), default=4) class CompositeFullName(object): @@ -104,6 +105,14 @@ def hybrid_prop_list(self) -> List[int]: composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") +articles_tags_table = Table( + "articles_tags", + Base.metadata, + Column("article_id", ForeignKey("article.id")), + Column("imgae_id", ForeignKey("image.id")), +) + + class Article(Base): __tablename__ = "articles" id = Column(Integer(), primary_key=True) @@ -129,6 +138,26 @@ class ArticleReader(Base): article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True) reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) + # one-to-one relationship with image + image_id = Column(Integer(), ForeignKey('image.id'), unique=True) + image = relationship("Image", backref=backref("articles", uselist=False)) + + # many-to-many relationship with tags + tags = relationship("Tag", secondary=articles_tags_table, backref="articles") + + +class Image(Base): + __tablename__ = "images" + id = Column(Integer(), primary_key=True) + external_id = Column(Integer()) + description = Column(String(30)) + + +class Tag(Base): + __tablename__ = "tags" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + class ReflectedEditor(type): """Same as Editor, but using reflected table.""" diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 584bde76..720f84ed 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -1,4 +1,5 @@ import graphene +import pytest from ..fields import SQLAlchemyConnectionField from ..filters import FloatFilter @@ -77,6 +78,7 @@ def resolve_reporters(self, _info): # Test a simple example of filtering +@pytest.mark.xfail def test_filter_simple(session): add_test_data(session) Query = create_schema(session) @@ -99,6 +101,7 @@ def test_filter_simple(session): # Test a custom filter type +@pytest.mark.xfail def test_filter_custom_type(session): add_test_data(session) Query = create_schema(session) @@ -132,6 +135,7 @@ class CustomQuery(Query, ExtraQuery): assert result == expected # Test a 1:1 relationship +@pytest.mark.xfail def test_filter_relationship_one_to_one(session): article = Article(headline='Hi!') image = Image(external_id=1, description="A beautiful image.") @@ -162,6 +166,7 @@ def test_filter_relationship_one_to_one(session): # Test a 1:n relationship +@pytest.mark.xfail def test_filter_relationship_one_to_many(session): add_test_data(session) Query = create_schema(session) @@ -238,6 +243,7 @@ def test_filter_relationship_one_to_many(session): assert result == expected # Test a n:m relationship +@pytest.mark.xfail def test_filter_relationship_many_to_many(session): article1 = Article(headline='Article! Look!') article2 = Article(headline='Woah! Another!') @@ -326,6 +332,7 @@ def test_filter_relationship_many_to_many(session): # Test connecting filters with "and" +@pytest.mark.xfail def test_filter_logic_and(session): add_test_data(session) @@ -354,6 +361,7 @@ def test_filter_logic_and(session): # Test connecting filters with "or" +@pytest.mark.xfail def test_filter_logic_or(session): add_test_data(session) Query = create_schema(session) @@ -385,6 +393,7 @@ def test_filter_logic_or(session): # Test connecting filters with "and" and "or" together +@pytest.mark.xfail def test_filter_logic_and_or(session): add_test_data(session) Query = create_schema(session) @@ -415,5 +424,6 @@ def test_filter_logic_and_or(session): # TODO hybrid property +@pytest.mark.xfail def test_filter_hybrid_property(session): raise NotImplementedError From 32254b68f58a2ea8131acd58021f275e3a735bc7 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Tue, 2 Aug 2022 16:02:35 -0400 Subject: [PATCH 09/81] make filter tests run --- graphene_sqlalchemy/filters.py | 2 ++ graphene_sqlalchemy/tests/models.py | 32 ++++++++++---------- graphene_sqlalchemy/tests/test_filters.py | 4 +-- graphene_sqlalchemy/tests/test_sort_enums.py | 10 +++++- 4 files changed, 29 insertions(+), 19 deletions(-) create mode 100644 graphene_sqlalchemy/filters.py diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py new file mode 100644 index 00000000..6bfe7c96 --- /dev/null +++ b/graphene_sqlalchemy/filters.py @@ -0,0 +1,2 @@ +class FloatFilter: + pass diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 5654834a..2fd2b01f 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -108,11 +108,24 @@ def hybrid_prop_list(self) -> List[int]: articles_tags_table = Table( "articles_tags", Base.metadata, - Column("article_id", ForeignKey("article.id")), - Column("imgae_id", ForeignKey("image.id")), + Column("article_id", ForeignKey("articles.id")), + Column("tag_id", ForeignKey("tags.id")), ) +class Image(Base): + __tablename__ = "images" + id = Column(Integer(), primary_key=True) + external_id = Column(Integer()) + description = Column(String(30)) + + +class Tag(Base): + __tablename__ = "tags" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + + class Article(Base): __tablename__ = "articles" id = Column(Integer(), primary_key=True) @@ -139,26 +152,13 @@ class ArticleReader(Base): reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) # one-to-one relationship with image - image_id = Column(Integer(), ForeignKey('image.id'), unique=True) + image_id = Column(Integer(), ForeignKey('images.id'), unique=True) image = relationship("Image", backref=backref("articles", uselist=False)) # many-to-many relationship with tags tags = relationship("Tag", secondary=articles_tags_table, backref="articles") -class Image(Base): - __tablename__ = "images" - id = Column(Integer(), primary_key=True) - external_id = Column(Integer()) - description = Column(String(30)) - - -class Tag(Base): - __tablename__ = "tags" - id = Column(Integer(), primary_key=True) - name = Column(String(30)) - - class ReflectedEditor(type): """Same as Editor, but using reflected table.""" diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 720f84ed..f44b139c 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -51,8 +51,8 @@ class Meta: class Query(graphene.ObjectType): article = graphene.Field(ArticleType) articles = graphene.List(ArticleType) - image = graphene.Field(ImageType) - images = graphene.List(ImageType) + # image = graphene.Field(ImageType) + # images = graphene.List(ImageType) reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index e2510abc..73f0f59a 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -40,6 +40,8 @@ class Meta: "HAIR_KIND_DESC", "REPORTER_ID_ASC", "REPORTER_ID_DESC", + "LEGS_ASC", + "LEGS_DESC", ] assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" @@ -94,6 +96,8 @@ class Meta: "PET_KIND_DESC", "HAIR_KIND_ASC", "HAIR_KIND_DESC", + "LEGS_ASC", + "LEGS_DESC", ] @@ -134,6 +138,8 @@ class Meta: "HAIR_KIND_DESC", "REPORTER_ID_ASC", "REPORTER_ID_DESC", + "LEGS_ASC", + "LEGS_DESC", ] assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" @@ -148,7 +154,7 @@ def test_sort_argument_with_excluded_fields_in_object_type(): class PetType(SQLAlchemyObjectType): class Meta: model = Pet - exclude_fields = ["hair_kind", "reporter_id"] + exclude_fields = ["hair_kind", "reporter_id", "legs"] sort_arg = PetType.sort_argument() sort_enum = sort_arg.type._of_type @@ -237,6 +243,8 @@ def get_symbol_name(column_name, sort_asc=True): "HairKindDown", "ReporterIdUp", "ReporterIdDown", + "LegsUp", + "LegsDown", ] assert sort_arg.default_value == ["IdUp"] From 172985c697b4cb19c7b9cb8bb773a570f895f6ea Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 11 Aug 2022 18:21:46 +0200 Subject: [PATCH 10/81] Added draft methods & classes for filter registry Signed-off-by: Erik Wrede --- graphene_sqlalchemy/filters.py | 41 ++++++++++++++++++++++++++++++++- graphene_sqlalchemy/registry.py | 26 ++++++++++++++++++++- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 6bfe7c96..dd5a7c1c 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,2 +1,41 @@ -class FloatFilter: +import graphene + + +class ObjectTypeFilter(graphene.InputObjectType): pass + + +class RelationshipFilter(graphene.InputObjectType): + pass + + +class ScalarFilter(graphene.InputObjectType): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + @classmethod + def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): + print(type) # The type from the Meta Class + super(ScalarFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + # TODO: Make this dynamic based on Meta.Type (see FloatFilter) + eq = graphene.Dynamic(None) + + +class StringFilter(ScalarFilter): + class Meta: + type = graphene.String + + +class NumberFilter(ScalarFilter): + """Intermediate Filter class since all Numbers are in an order relationship (support <, > etc)""" + pass + + +class FloatFilter(NumberFilter): + """Cooncrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + type = graphene.Float diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 80470d9b..3db19266 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -5,6 +5,8 @@ import graphene from graphene import Enum +from graphene_sqlalchemy.filters import (ObjectTypeFilter, RelationshipFilter, + ScalarFilter) class Registry(object): @@ -103,9 +105,31 @@ def register_union_type(self, union: graphene.Union, obj_types: List[Type[graphe self._registry_unions[frozenset(obj_types)] = union - def get_union_for_object_types(self, obj_types : List[Type[graphene.ObjectType]]): + def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]): return self._registry_unions.get(frozenset(obj_types)) + # Filter Scalar Fields of Object Types + def register_filter_for_scalar_type(self, scalar_type: graphene.Scalar, filter: ScalarFilter): + pass + + def get_filter_for_scalar_type(self, scalar_type: graphene.Scalar) -> ScalarFilter: + pass + + # Filter Object Types + def register_filter_for_object_type(self, object_type: graphene.ObjectType, filter: ObjectTypeFilter): + pass + + def get_filter_for_object_type(self, object_type: graphene.ObjectType): + pass + + # Filter Relationships between object types + def register_relationship_filter_for_object_type(self, object_type: graphene.ObjectType, + filter: RelationshipFilter): + pass + + def get_relationship_filter_for_object_type(self, object_type: graphene.ObjectType) -> RelationshipFilter: + pass + registry = None From 63192783329dafb44d02f54aa474af828c7f7063 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 11 Aug 2022 19:55:06 +0200 Subject: [PATCH 11/81] Drafted abstract filters, generation of filter fields from methods. Signed-off-by: Erik Wrede --- graphene_sqlalchemy/filters.py | 66 +++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index dd5a7c1c..ac53b7d3 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,4 +1,7 @@ +import re + import graphene +from graphene.types.inputobjecttype import InputObjectTypeOptions class ObjectTypeFilter(graphene.InputObjectType): @@ -9,6 +12,11 @@ class RelationshipFilter(graphene.InputObjectType): pass +class AbstractType: + """Dummy class for generic filters""" + pass + + class ScalarFilter(graphene.InputObjectType): """Basic Filter for Scalars in Graphene. We want this filter to use Dynamic fields so it provides the base @@ -18,10 +26,53 @@ class ScalarFilter(graphene.InputObjectType): @classmethod def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): print(type) # The type from the Meta Class + + # get all filter functions + filter_function_regex = re.compile(".+_filter$") + + filter_functions = [] + + # Search the entire class for functions matching the filter regex + for func in dir(cls): + func_attr = getattr(cls, func) + # Check if attribute is a function + if callable(func_attr) and filter_function_regex.match(func): + # add function and attribute name to the list + filter_functions.append((func.removesuffix("_filter"), func_attr.__annotations__)) + + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + + new_filter_fields = {} + + # Generate Graphene Fields from the filter functions based on type hints + for field_name, annotations in filter_functions: + assert "val" in annotations, "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + if annotations["val"] == AbstractType: + # TODO Maybe there is an existing class or a more elegant way to solve this + # One option would be to only annotate non-abstract filters + new_filter_fields.update({field_name: graphene.InputField(type)}) + else: + # TODO this is a place holder, we need to convert the type of val to a valid graphene + # type that we can pass to the InputField. We could re-use converter.convert_hybrid_property_return_type + new_filter_fields.update({field_name: graphene.InputField(graphene.String)}) + + # Add all fields to the meta options. graphene.InputbjectType will take care of the rest + if _meta.fields: + _meta.fields.update(new_filter_fields) + else: + _meta.fields = new_filter_fields + + # Pass modified meta to the super class super(ScalarFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) - # TODO: Make this dynamic based on Meta.Type (see FloatFilter) - eq = graphene.Dynamic(None) + # Abstract methods can be marked using AbstractType. See comment on the init method + @classmethod + def eq_filter(cls, val: AbstractType) -> bool: + # TBD filtering magic + pass class StringFilter(ScalarFilter): @@ -31,11 +82,18 @@ class Meta: class NumberFilter(ScalarFilter): """Intermediate Filter class since all Numbers are in an order relationship (support <, > etc)""" - pass + + class Meta: + abstract = True + + @classmethod + def gt_filter(cls, val: str) -> bool: + # TBD filtering magic + pass class FloatFilter(NumberFilter): - """Cooncrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" class Meta: type = graphene.Float From 69680a94e6001b3734d0285ae7288c6aab96b5fb Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 12 Aug 2022 15:50:54 +0200 Subject: [PATCH 12/81] Prototype: Filter Schema generation Signed-off-by: Erik Wrede --- graphene_sqlalchemy/converter.py | 4 +- graphene_sqlalchemy/fields.py | 44 ++++++++------ graphene_sqlalchemy/filters.py | 101 +++++++++++++++++++++++++++---- graphene_sqlalchemy/registry.py | 69 ++++++++++++++++----- graphene_sqlalchemy/types.py | 81 +++++++++++++++++++++++-- graphene_sqlalchemy/utils.py | 15 +++++ 6 files changed, 262 insertions(+), 52 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 1e7846eb..0c39df16 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -19,7 +19,7 @@ default_connection_field_factory) from .registry import get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import (DummyImport, registry_sqlalchemy_model_from_str, +from .utils import (DummyImport, is_list, registry_sqlalchemy_model_from_str, safe_isinstance, singledispatchbymatchfunction, value_equals) @@ -420,7 +420,7 @@ def convert_sqlalchemy_hybrid_property_union(arg): get_global_registry()) -@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) in [list, typing.List]) +@convert_sqlalchemy_hybrid_property_type.register(is_list) def convert_sqlalchemy_hybrid_property_type_list_t(arg): # type is either list[T] or List[T], generic argument at __args__[0] internal_type = arg.__args__[0] diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 905ad415..7a75c1b5 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -5,13 +5,12 @@ from promise import Promise, is_thenable from sqlalchemy.orm.query import Query -from graphene import NonNull from graphene.relay import Connection, ConnectionField from graphene.relay.connection import connection_adapter, page_info_adapter from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import EnumValue, get_query +from .utils import EnumValue, get_nullable_type, get_query class SQLAlchemyConnectionField(ConnectionField): @@ -39,19 +38,30 @@ def type(self): def __init__(self, type_, *args, **kwargs): nullable_type = get_nullable_type(type_) - if "sort" not in kwargs and nullable_type and issubclass(nullable_type, Connection): - # Let super class raise if type is not a Connection - try: - kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) - except (AttributeError, TypeError): - raise TypeError( - 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' - " to None to disabling the creation of the sort query argument".format( - nullable_type.__name__ + # Handle Sorting and Filtering + if nullable_type and issubclass(nullable_type, Connection): + if "sort" not in kwargs: + # Let super class raise if type is not a Connection + try: + kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) + except (AttributeError, TypeError): + raise TypeError( + 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' + " to None to disabling the creation of the sort query argument".format( + nullable_type.__name__ + ) ) - ) - elif "sort" in kwargs and kwargs["sort"] is None: - del kwargs["sort"] + elif "sort" in kwargs and kwargs["sort"] is None: + del kwargs["sort"] + + if "filter" not in kwargs: + # Only add filtering if a filter argument exists on the object type + filter_argument = nullable_type.Edge.node._type.get_filter_argument() + if filter_argument: + kwargs.setdefault("filter", filter_argument) + elif "filter" in kwargs and kwargs["filter"] is None: + del kwargs["filter"] + super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) @property @@ -204,9 +214,3 @@ def unregisterConnectionFieldFactory(): ) global __connectionFactory __connectionFactory = UnsortedSQLAlchemyConnectionField - - -def get_nullable_type(_type): - if isinstance(_type, NonNull): - return _type.of_type - return _type diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index ac53b7d3..62d9b6b9 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,23 +1,86 @@ import re +from typing import List import graphene from graphene.types.inputobjecttype import InputObjectTypeOptions +from graphene_sqlalchemy.utils import is_list -class ObjectTypeFilter(graphene.InputObjectType): +class AbstractType: + """Dummy class for generic filters""" pass +class ObjectTypeFilter(graphene.InputObjectType): + @classmethod + def __init_subclass_with_meta__(cls, filter_fields=None, _meta=None, **options): + + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + if _meta.fields: + _meta.fields.update(filter_fields) + else: + _meta.fields = filter_fields + + super(ObjectTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + @classmethod + def and_logic(cls, val: list["ObjectTypeFilter"]): + # TODO + pass + + class RelationshipFilter(graphene.InputObjectType): - pass + @classmethod + def __init_subclass_with_meta__(cls, object_type_filter=None, _meta=None, **options): + if not object_type_filter: + raise Exception("Relationship Filters must be specific to an object type") + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + # get all filter functions + filter_function_regex = re.compile(".+_filter$") -class AbstractType: - """Dummy class for generic filters""" - pass + filter_functions = [] + + # Search the entire class for functions matching the filter regex + for func in dir(cls): + func_attr = getattr(cls, func) + # Check if attribute is a function + if callable(func_attr) and filter_function_regex.match(func): + # add function and attribute name to the list + filter_functions.append((func.removesuffix("_filter"), func_attr.__annotations__)) + + relationship_filters = {} + + # Generate Graphene Fields from the filter functions based on type hints + for field_name, annotations in filter_functions: + assert "val" in annotations, "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + if is_list(annotations["val"]): + relationship_filters.update({field_name: graphene.InputField(graphene.List(object_type_filter))}) + else: + relationship_filters.update({field_name: graphene.InputField(object_type_filter)}) + + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + if _meta.fields: + _meta.fields.update(relationship_filters) + else: + _meta.fields = relationship_filters + super(RelationshipFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) -class ScalarFilter(graphene.InputObjectType): + @classmethod + def contains_filter(cls, val: List["RelationshipFilter"]): + # TODO + pass + + +class FieldFilter(graphene.InputObjectType): """Basic Filter for Scalars in Graphene. We want this filter to use Dynamic fields so it provides the base filtering methods ("eq, nEq") for different types of scalars. @@ -25,7 +88,6 @@ class ScalarFilter(graphene.InputObjectType): @classmethod def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): - print(type) # The type from the Meta Class # get all filter functions filter_function_regex = re.compile(".+_filter$") @@ -66,7 +128,7 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): _meta.fields = new_filter_fields # Pass modified meta to the super class - super(ScalarFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + super(FieldFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) # Abstract methods can be marked using AbstractType. See comment on the init method @classmethod @@ -75,19 +137,24 @@ def eq_filter(cls, val: AbstractType) -> bool: pass -class StringFilter(ScalarFilter): +class StringFilter(FieldFilter): class Meta: type = graphene.String -class NumberFilter(ScalarFilter): +class BooleanFilter(FieldFilter): + class Meta: + type = graphene.Boolean + + +class NumberFilter(FieldFilter): """Intermediate Filter class since all Numbers are in an order relationship (support <, > etc)""" class Meta: abstract = True @classmethod - def gt_filter(cls, val: str) -> bool: + def gt_filter(cls, val: AbstractType) -> bool: # TBD filtering magic pass @@ -97,3 +164,15 @@ class FloatFilter(NumberFilter): class Meta: type = graphene.Float + + +class IntFilter(NumberFilter): + class Meta: + type = graphene.Int + + +class DateFilter(NumberFilter): + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + type = graphene.Date diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 3db19266..3eca2178 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -5,11 +5,14 @@ import graphene from graphene import Enum -from graphene_sqlalchemy.filters import (ObjectTypeFilter, RelationshipFilter, - ScalarFilter) +from graphene_sqlalchemy.filters import (BooleanFilter, FieldFilter, + FloatFilter, IntFilter, StringFilter) class Registry(object): + from graphene_sqlalchemy.filters import (FieldFilter, ObjectTypeFilter, + RelationshipFilter) + def __init__(self): self._registry = {} self._registry_models = {} @@ -18,6 +21,9 @@ def __init__(self): self._registry_enums = {} self._registry_sort_enums = {} self._registry_unions = {} + self._registry_scalar_filters = {} + self._registry_object_type_filters = {} + self._registry_relationship_filters = {} def register(self, obj_type): @@ -109,26 +115,55 @@ def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]) return self._registry_unions.get(frozenset(obj_types)) # Filter Scalar Fields of Object Types - def register_filter_for_scalar_type(self, scalar_type: graphene.Scalar, filter: ScalarFilter): - pass + def register_filter_for_scalar_type(self, scalar_type: Type[graphene.Scalar], filter_obj: Type[FieldFilter]): + if not isinstance(scalar_type, type(graphene.Scalar)): + raise TypeError( + "Expected Scalar, but got: {!r}".format(scalar_type) + ) + + if not isinstance(filter_obj, type(FieldFilter)): + raise TypeError( + "Expected ScalarFilter, but got: {!r}".format(filter_obj) + ) + self._registry_scalar_filters[scalar_type] = filter_obj + + def get_filter_for_scalar_type(self, scalar_type: Type[graphene.Scalar]) -> Type[FieldFilter]: - def get_filter_for_scalar_type(self, scalar_type: graphene.Scalar) -> ScalarFilter: - pass + return self._registry_scalar_filters.get(scalar_type) # Filter Object Types - def register_filter_for_object_type(self, object_type: graphene.ObjectType, filter: ObjectTypeFilter): - pass + def register_filter_for_object_type(self, object_type: Type[graphene.ObjectType], + filter_obj: Type[ObjectTypeFilter]): + if not isinstance(object_type, type(graphene.ObjectType)): + raise TypeError( + "Expected Object Type, but got: {!r}".format(object_type) + ) - def get_filter_for_object_type(self, object_type: graphene.ObjectType): - pass + if not isinstance(filter_obj, type(FieldFilter)): + raise TypeError( + "Expected ObjectTypeFilter, but got: {!r}".format(filter_obj) + ) + self._registry_object_type_filters[object_type] = filter_obj + + def get_filter_for_object_type(self, object_type: Type[graphene.ObjectType]): + return self._registry_object_type_filters.get(object_type) # Filter Relationships between object types def register_relationship_filter_for_object_type(self, object_type: graphene.ObjectType, - filter: RelationshipFilter): - pass + filter_obj: RelationshipFilter): + if not isinstance(object_type, type(graphene.ObjectType)): + raise TypeError( + "Expected Object Type, but got: {!r}".format(object_type) + ) - def get_relationship_filter_for_object_type(self, object_type: graphene.ObjectType) -> RelationshipFilter: - pass + if not isinstance(filter_obj, type(FieldFilter)): + raise TypeError( + "Expected RelationshipFilter, but got: {!r}".format(filter_obj) + ) + self._registry_relationship_filters[object_type] = filter_obj + + def get_relationship_filter_for_object_type(self, object_type: Type[graphene.ObjectType]) -> RelationshipFilter: + return self._registry_relationship_filters.get(object_type) registry = None @@ -144,3 +179,9 @@ def get_global_registry(): def reset_global_registry(): global registry registry = None + + +get_global_registry().register_filter_for_scalar_type(graphene.Float, FloatFilter) +get_global_registry().register_filter_for_scalar_type(graphene.Float, IntFilter) +get_global_registry().register_filter_for_scalar_type(graphene.String, StringFilter) +get_global_registry().register_filter_for_scalar_type(graphene.Boolean, BooleanFilter) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index e6c3d14c..2891a3f5 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,4 +1,6 @@ +import warnings from collections import OrderedDict +from typing import Type, Union import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property @@ -6,6 +8,7 @@ RelationshipProperty) from sqlalchemy.orm.exc import NoResultFound +import graphene from graphene import Field from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions @@ -18,9 +21,11 @@ convert_sqlalchemy_relationship) from .enums import (enum_for_field, sort_argument_for_object_type, sort_enum_for_object_type) +from .filters import ObjectTypeFilter, RelationshipFilter from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import get_query, is_mapped_class, is_mapped_instance +from .utils import (get_nullable_type, get_query, is_mapped_class, + is_mapped_instance) class ORMField(OrderedType): @@ -88,6 +93,52 @@ class Meta: self.kwargs.update(common_kwargs) +def get_or_create_relationship_filter(obj_type: Type[ObjectType], registry: Registry) -> Type[RelationshipFilter]: + relationship_filter = registry.get_relationship_filter_for_object_type(obj_type) + + if not relationship_filter: + object_type_filter = registry.get_filter_for_object_type(obj_type) + relationship_filter = RelationshipFilter.create_type(f"{obj_type.__name__}RelationshipFilter", + object_type_filter=object_type_filter) + + return relationship_filter + + +def filter_field_from_type_field(field: Union[graphene.Field, graphene.Dynamic], + registry: Registry) -> Union[graphene.InputField, graphene.Dynamic]: + if isinstance(field.type, graphene.List): + pass + elif isinstance(field.type, graphene.Dynamic): + pass + # If the field is Dynamic, we don't know its type yet and can't select the right filter + elif isinstance(field, graphene.Dynamic): + def resolve_dynamic(): + # Resolve Dynamic Type + type_ = get_nullable_type(field.get_type()) + from graphene_sqlalchemy import SQLAlchemyConnectionField + if isinstance(type_, SQLAlchemyConnectionField): + inner_type = get_nullable_type(type_.type.Edge.node._type) + return graphene.InputField(get_or_create_relationship_filter(inner_type, registry)) + elif isinstance(type_, Field): + reg_res = registry.get_filter_for_object_type(type_.type) + return graphene.InputField(reg_res) + else: + warnings.warn(f"Unexpected Dynamic Type: {type_}") # Investigate + # raise Exception(f"Unexpected Dynamic Type: {type_}") + + return graphene.Dynamic(resolve_dynamic) + + elif isinstance(field, graphene.Field): + type_ = get_nullable_type(field.type) + filter_class = registry.get_filter_for_scalar_type(type_) + if not filter_class: + warnings.warn(f"No compatible filters found for {field.type}. Skipping field.") + return None + return graphene.InputField(filter_class) + else: + raise Exception(f"Expected a graphene.Field or graphene.Dynamic, but got: {field}") + + def construct_fields( obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory ): @@ -138,9 +189,9 @@ def construct_fields( attr_name = orm_field.kwargs.get('model_attr', orm_field_name) if attr_name not in all_model_attrs: raise ValueError(( - "Cannot map ORMField to a model attribute.\n" - "Field: '{}.{}'" - ).format(obj_type.__name__, orm_field_name,)) + "Cannot map ORMField to a model attribute.\n" + "Field: '{}.{}'" + ).format(obj_type.__name__, orm_field_name, )) orm_field.kwargs['model_attr'] = attr_name # Merge automatic fields with custom ORM fields @@ -174,7 +225,6 @@ def construct_fields( field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) else: raise Exception('Property type is not supported') # Should never happen - registry.register_orm_field(obj_type, orm_field_name, attr) fields[orm_field_name] = field @@ -186,6 +236,7 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): registry = None # type: sqlalchemy.Registry connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] id = None # type: str + filter_class: Type[ObjectTypeFilter] = None class SQLAlchemyObjectType(ObjectType): @@ -199,6 +250,7 @@ def __init_subclass_with_meta__( exclude_fields=(), connection=None, connection_class=None, + filter_base_class=None, use_connection=None, interfaces=(), id=None, @@ -268,6 +320,19 @@ def __init_subclass_with_meta__( else: _meta.fields = sqla_fields + # Save Generated filter class in Meta Class + if not _meta.filter_class: + filters = OrderedDict() + # Map graphene fields to filters + # TODO we might need to pass the ORMFields containing the SQLAlchemy models + # to the scalar filters here (to generate expressions from the model) + for fieldname, field in sqla_fields.items(): + field_filter = filter_field_from_type_field(field, registry) + if field_filter: + filters[fieldname] = field_filter + _meta.filter_class = ObjectTypeFilter.create_type(f"{cls.__name__}Filter", filter_fields=filters) + registry.register_filter_for_object_type(cls, _meta.filter_class) + _meta.connection = connection _meta.id = id or "id" @@ -309,6 +374,12 @@ def resolve_id(self, info): def enum_for_field(cls, field_name): return enum_for_field(cls, field_name) + @classmethod + def get_filter_argument(cls): + if cls._meta.filter_class: + return graphene.Argument(cls._meta.filter_class) + return None + sort_enum = classmethod(sort_enum_for_object_type) sort_argument = classmethod(sort_argument_for_object_type) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index f6ee9b62..4f5ebdb1 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,4 +1,5 @@ import re +import typing import warnings from collections import OrderedDict from typing import Any, Callable, Dict, Optional @@ -8,6 +9,14 @@ from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene import NonNull + + +def get_nullable_type(_type): + if isinstance(_type, NonNull): + return _type.of_type + return _type + def get_session(context): return context.get("session") @@ -197,6 +206,7 @@ def safe_isinstance_checker(arg): return isinstance(arg, cls) except TypeError: pass + return safe_isinstance_checker @@ -208,7 +218,12 @@ def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: pass +def is_list(x): + return getattr(x, '__origin__', None) in [list, typing.List] + + class DummyImport: """The dummy module returns 'object' for a query for any member""" + def __getattr__(self, name): return object From c31bbac03f922397001e70d17b622e07186d9e89 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Fri, 12 Aug 2022 11:44:54 -0400 Subject: [PATCH 13/81] use re sub instead of removesuffix for python<3.9 support --- graphene_sqlalchemy/filters.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 62d9b6b9..692535f4 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -100,7 +100,9 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): # Check if attribute is a function if callable(func_attr) and filter_function_regex.match(func): # add function and attribute name to the list - filter_functions.append((func.removesuffix("_filter"), func_attr.__annotations__)) + filter_functions.append(( + re.sub("\_filter$", "", func), func_attr.__annotations__) + ) # Init meta options class if it doesn't exist already if not _meta: From f5328572cd82bef31e39a0b63b603545475e60bb Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Fri, 12 Aug 2022 12:00:11 -0400 Subject: [PATCH 14/81] fix syntax for python<3.9 support --- graphene_sqlalchemy/filters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 692535f4..8848a290 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,3 +1,4 @@ +from __future__ import annotations import re from typing import List @@ -53,7 +54,7 @@ def __init_subclass_with_meta__(cls, object_type_filter=None, _meta=None, **opti # Check if attribute is a function if callable(func_attr) and filter_function_regex.match(func): # add function and attribute name to the list - filter_functions.append((func.removesuffix("_filter"), func_attr.__annotations__)) + filter_functions.append((re.sub("\_filter$", "", func), func_attr.__annotations__)) relationship_filters = {} From 7e2bf6eba84f3a9430785cbf9d770f2694e5c4d1 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Fri, 12 Aug 2022 17:46:22 -0400 Subject: [PATCH 15/81] field-level filtering working for equals and not equals --- graphene_sqlalchemy/fields.py | 14 +++++++++++++- graphene_sqlalchemy/filters.py | 15 ++++++++++----- graphene_sqlalchemy/types.py | 2 +- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 7a75c1b5..f253aabe 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -69,7 +69,7 @@ def model(self): return get_nullable_type(self.type)._meta.node._meta.model @classmethod - def get_query(cls, model, info, sort=None, **args): + def get_query(cls, model, info, sort=None, filter=None, **args): query = get_query(model, info.context) if sort is not None: if not isinstance(sort, list): @@ -85,6 +85,18 @@ def get_query(cls, model, info, sort=None, **args): else: sort_args.append(item) query = query.order_by(*sort_args) + + if filter is not None: + assert isinstance(filter, dict) + filter_type = type(filter) + for field, filt_dict in filter.items(): + model = filter_type._meta.model + field_filter_type = filter_type._meta.fields[field]._type + for filt, val in filt_dict.items(): + query = getattr( + field_filter_type, filt + "_filter" + )(query, getattr(model, field), val) + return query @classmethod diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 8848a290..4ea04d34 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -5,7 +5,7 @@ import graphene from graphene.types.inputobjecttype import InputObjectTypeOptions from graphene_sqlalchemy.utils import is_list - +from sqlalchemy import not_ class AbstractType: """Dummy class for generic filters""" @@ -14,7 +14,7 @@ class AbstractType: class ObjectTypeFilter(graphene.InputObjectType): @classmethod - def __init_subclass_with_meta__(cls, filter_fields=None, _meta=None, **options): + def __init_subclass_with_meta__(cls, filter_fields=None, model=None, _meta=None, **options): # Init meta options class if it doesn't exist already if not _meta: @@ -26,6 +26,8 @@ def __init_subclass_with_meta__(cls, filter_fields=None, _meta=None, **options): else: _meta.fields = filter_fields + _meta.model = model + super(ObjectTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) @classmethod @@ -135,9 +137,12 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): # Abstract methods can be marked using AbstractType. See comment on the init method @classmethod - def eq_filter(cls, val: AbstractType) -> bool: - # TBD filtering magic - pass + def eq_filter(cls, query, field, val: AbstractType) -> bool: + return query.filter(field == val) + + @classmethod + def n_eq_filter(cls, query, field, val: AbstractType) -> bool: + return query.filter(not_(field == val)) class StringFilter(FieldFilter): diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 2891a3f5..e0c56d12 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -330,7 +330,7 @@ def __init_subclass_with_meta__( field_filter = filter_field_from_type_field(field, registry) if field_filter: filters[fieldname] = field_filter - _meta.filter_class = ObjectTypeFilter.create_type(f"{cls.__name__}Filter", filter_fields=filters) + _meta.filter_class = ObjectTypeFilter.create_type(f"{cls.__name__}Filter", filter_fields=filters, model=model) registry.register_filter_for_object_type(cls, _meta.filter_class) _meta.connection = connection From b21e44990af24a4ed34028a9640a97ad9baac2fa Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sat, 13 Aug 2022 12:11:48 +0200 Subject: [PATCH 16/81] Made filter query generation modular & fix flake8 Signed-off-by: Erik Wrede --- graphene_sqlalchemy/fields.py | 13 ++--- graphene_sqlalchemy/filters.py | 93 +++++++++++++++++++++++++-------- graphene_sqlalchemy/registry.py | 2 +- graphene_sqlalchemy/types.py | 3 +- 4 files changed, 78 insertions(+), 33 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index f253aabe..1f344208 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -10,6 +10,7 @@ from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver +from .filters import ObjectTypeFilter from .utils import EnumValue, get_nullable_type, get_query @@ -88,14 +89,10 @@ def get_query(cls, model, info, sort=None, filter=None, **args): if filter is not None: assert isinstance(filter, dict) - filter_type = type(filter) - for field, filt_dict in filter.items(): - model = filter_type._meta.model - field_filter_type = filter_type._meta.fields[field]._type - for filt, val in filt_dict.items(): - query = getattr( - field_filter_type, filt + "_filter" - )(query, getattr(model, field), val) + filter_type : ObjectTypeFilter = type(filter) + query, clauses = filter_type.execute_filters(query, filter) + + query = query.filter(*clauses) return query diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 4ea04d34..d30d4824 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,11 +1,15 @@ from __future__ import annotations + import re -from typing import List +from typing import Any, Dict, List, Tuple, Type, TypeVar, Union + +from sqlalchemy import not_ +from sqlalchemy.orm import Query import graphene from graphene.types.inputobjecttype import InputObjectTypeOptions from graphene_sqlalchemy.utils import is_list -from sqlalchemy import not_ + class AbstractType: """Dummy class for generic filters""" @@ -31,10 +35,22 @@ def __init_subclass_with_meta__(cls, filter_fields=None, model=None, _meta=None, super(ObjectTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) @classmethod - def and_logic(cls, val: list["ObjectTypeFilter"]): + def and_logic(cls, query, field, val: list["ObjectTypeFilter"]): # TODO pass + @classmethod + def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict) -> Tuple[Query, List[Any]]: + clauses = [] + for field, filt_dict in filter_dict.items(): + model = cls._meta.model + field_filter_type: FieldFilter = cls._meta.fields[field]._type + model_field = getattr(model, field) + query, clauses = field_filter_type.execute_filters(query, model_field, filt_dict) + clauses.extend(clauses) + + return query, clauses + class RelationshipFilter(graphene.InputObjectType): @classmethod @@ -56,15 +72,15 @@ def __init_subclass_with_meta__(cls, object_type_filter=None, _meta=None, **opti # Check if attribute is a function if callable(func_attr) and filter_function_regex.match(func): # add function and attribute name to the list - filter_functions.append((re.sub("\_filter$", "", func), func_attr.__annotations__)) + filter_functions.append((re.sub("_filter$", "", func), func_attr.__annotations__)) relationship_filters = {} # Generate Graphene Fields from the filter functions based on type hints - for field_name, annotations in filter_functions: - assert "val" in annotations, "Each filter method must have a value field with valid type annotations" + for field_name, _annotations in filter_functions: + assert "val" in _annotations, "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class - if is_list(annotations["val"]): + if is_list(_annotations["val"]): relationship_filters.update({field_name: graphene.InputField(graphene.List(object_type_filter))}) else: relationship_filters.update({field_name: graphene.InputField(object_type_filter)}) @@ -83,6 +99,9 @@ def contains_filter(cls, val: List["RelationshipFilter"]): pass +any_field_filter = TypeVar('any_field_filter', bound="FieldFilter") + + class FieldFilter(graphene.InputObjectType): """Basic Filter for Scalars in Graphene. We want this filter to use Dynamic fields so it provides the base @@ -104,7 +123,7 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): if callable(func_attr) and filter_function_regex.match(func): # add function and attribute name to the list filter_functions.append(( - re.sub("\_filter$", "", func), func_attr.__annotations__) + re.sub("_filter$", "", func), func_attr.__annotations__) ) # Init meta options class if it doesn't exist already @@ -112,12 +131,13 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): _meta = InputObjectTypeOptions(cls) new_filter_fields = {} - + print(f"Geenerating Fields for {cls.__name__} with type {type} ") # Generate Graphene Fields from the filter functions based on type hints - for field_name, annotations in filter_functions: - assert "val" in annotations, "Each filter method must have a value field with valid type annotations" + for field_name, _annotations in filter_functions: + assert "val" in _annotations, "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class - if annotations["val"] == AbstractType: + print(f"Field: {field_name} with annotation {_annotations['val']}") + if _annotations["val"] == "AbstractType": # TODO Maybe there is an existing class or a more elegant way to solve this # One option would be to only annotate non-abstract filters new_filter_fields.update({field_name: graphene.InputField(type)}) @@ -137,12 +157,23 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): # Abstract methods can be marked using AbstractType. See comment on the init method @classmethod - def eq_filter(cls, query, field, val: AbstractType) -> bool: - return query.filter(field == val) + def eq_filter(cls, query, field, val: AbstractType) -> Union[Tuple[Query, Any], Any]: + return field == val @classmethod - def n_eq_filter(cls, query, field, val: AbstractType) -> bool: - return query.filter(not_(field == val)) + def n_eq_filter(cls, query, field, val: AbstractType) -> Union[Tuple[Query, Any], Any]: + return not_(field == val) + + @classmethod + def execute_filters(cls: Type[FieldFilter], query, field, filter_dict: any_field_filter) -> Tuple[Query, List[Any]]: + clauses = [] + for filt, val in filter_dict.items(): + clause = getattr(cls, filt + "_filter")(query, field, val) + if isinstance(clause, tuple): + query, clause = clause + clauses.append(clause) + + return query, clauses class StringFilter(FieldFilter): @@ -155,16 +186,32 @@ class Meta: type = graphene.Boolean -class NumberFilter(FieldFilter): - """Intermediate Filter class since all Numbers are in an order relationship (support <, > etc)""" - +class OrderedFilter(FieldFilter): class Meta: abstract = True @classmethod - def gt_filter(cls, val: AbstractType) -> bool: - # TBD filtering magic - pass + def gt_filter(cls, query, field, val: AbstractType) -> bool: + return field > val + + @classmethod + def gte_filter(cls, query, field, val: AbstractType) -> bool: + return field >= val + + @classmethod + def lt_filter(cls, query, field, val: AbstractType) -> bool: + return field < val + + @classmethod + def lte_filter(cls, query, field, val: AbstractType) -> bool: + return field <= val + + +class NumberFilter(OrderedFilter): + """Intermediate Filter class since all Numbers are in an order relationship (support <, > etc)""" + + class Meta: + abstract = True class FloatFilter(NumberFilter): @@ -179,7 +226,7 @@ class Meta: type = graphene.Int -class DateFilter(NumberFilter): +class DateFilter(OrderedFilter): """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" class Meta: diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 3eca2178..20d7f2d8 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -182,6 +182,6 @@ def reset_global_registry(): get_global_registry().register_filter_for_scalar_type(graphene.Float, FloatFilter) -get_global_registry().register_filter_for_scalar_type(graphene.Float, IntFilter) +get_global_registry().register_filter_for_scalar_type(graphene.Int, IntFilter) get_global_registry().register_filter_for_scalar_type(graphene.String, StringFilter) get_global_registry().register_filter_for_scalar_type(graphene.Boolean, BooleanFilter) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index e0c56d12..95b0f012 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -330,7 +330,8 @@ def __init_subclass_with_meta__( field_filter = filter_field_from_type_field(field, registry) if field_filter: filters[fieldname] = field_filter - _meta.filter_class = ObjectTypeFilter.create_type(f"{cls.__name__}Filter", filter_fields=filters, model=model) + _meta.filter_class = ObjectTypeFilter \ + .create_type(f"{cls.__name__}Filter", filter_fields=filters, model=model) registry.register_filter_for_object_type(cls, _meta.filter_class) _meta.connection = connection From 4fbbe8176540d62e98292866a0962b17c9aa2567 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sat, 13 Aug 2022 12:47:59 +0200 Subject: [PATCH 17/81] Fixed variable name in execute_filters Signed-off-by: Erik Wrede --- graphene_sqlalchemy/filters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index d30d4824..0cfbde00 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -46,8 +46,8 @@ def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict) -> Tuple[Q model = cls._meta.model field_filter_type: FieldFilter = cls._meta.fields[field]._type model_field = getattr(model, field) - query, clauses = field_filter_type.execute_filters(query, model_field, filt_dict) - clauses.extend(clauses) + query, _clauses = field_filter_type.execute_filters(query, model_field, filt_dict) + clauses.extend(_clauses) return query, clauses From 675a264a702ed7740d9df313bbd5b218bf1d70ab Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sat, 13 Aug 2022 13:26:57 +0200 Subject: [PATCH 18/81] Drafted :1 and :many relationship filter construction Signed-off-by: Erik Wrede --- graphene_sqlalchemy/filters.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 0cfbde00..72e6e468 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -44,10 +44,29 @@ def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict) -> Tuple[Q clauses = [] for field, filt_dict in filter_dict.items(): model = cls._meta.model - field_filter_type: FieldFilter = cls._meta.fields[field]._type - model_field = getattr(model, field) - query, _clauses = field_filter_type.execute_filters(query, model_field, filt_dict) - clauses.extend(_clauses) + # Relationships are Dynamic, we need to resolve them fist + # Maybe we can cache these dynamics to improve efficiency + # Check with a profiler is required to determine necessity + input_field = cls._meta.fields[field] + if isinstance(input_field, graphene.Dynamic): + field_filter_type = input_field.get_type().type + else: + field_filter_type = cls._meta.fields[field].type + # TODO we need to save the relationship props in the meta fields array + # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) + if issubclass(field_filter_type, ObjectTypeFilter): + # TODO see above; not yet working + query, _clauses = field_filter_type.execute_filters(query, filt_dict) + clauses.extend(_clauses) + if issubclass(field_filter_type, RelationshipFilter): + # TODO see above; not yet working + relationship_prop = None + query, _clauses = field_filter_type.execute_filters(query, filt_dict, relationship_prop) + clauses.extend(_clauses) + elif issubclass(field_filter_type, FieldFilter): + model_field = getattr(model, field) + query, _clauses = field_filter_type.execute_filters(query, model_field, filt_dict) + clauses.extend(_clauses) return query, clauses @@ -98,6 +117,12 @@ def contains_filter(cls, val: List["RelationshipFilter"]): # TODO pass + @classmethod + def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict, relationship_prop) -> Tuple[Query, List[Any]]: + query, clauses = (query, []) + # TODO + return query, clauses + any_field_filter = TypeVar('any_field_filter', bound="FieldFilter") From 60bfd3b97e5a0063a708a05d73d5103b6c6753f5 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sat, 13 Aug 2022 16:38:29 +0200 Subject: [PATCH 19/81] Prototype: :1 relationship filtering is working Signed-off-by: Erik Wrede --- graphene_sqlalchemy/filters.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 72e6e468..ff7a136f 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Tuple, Type, TypeVar, Union from sqlalchemy import not_ -from sqlalchemy.orm import Query +from sqlalchemy.orm import Query, aliased import graphene from graphene.types.inputobjecttype import InputObjectTypeOptions @@ -40,10 +40,13 @@ def and_logic(cls, query, field, val: list["ObjectTypeFilter"]): pass @classmethod - def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict) -> Tuple[Query, List[Any]]: + def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict, model_alias=None) -> Tuple[Query, List[Any]]: + model = cls._meta.model + if model_alias: + model = model_alias + clauses = [] for field, filt_dict in filter_dict.items(): - model = cls._meta.model # Relationships are Dynamic, we need to resolve them fist # Maybe we can cache these dynamics to improve efficiency # Check with a profiler is required to determine necessity @@ -54,9 +57,18 @@ def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict) -> Tuple[Q field_filter_type = cls._meta.fields[field].type # TODO we need to save the relationship props in the meta fields array # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) + model_field = getattr(model, field) if issubclass(field_filter_type, ObjectTypeFilter): - # TODO see above; not yet working - query, _clauses = field_filter_type.execute_filters(query, filt_dict) + # Get the model to join on the Filter Query + joined_model = field_filter_type._meta.model + # Always alias the model + joined_model_alias = aliased(joined_model) + + # Join the aliased model onto the query + query = query.join(model_field.of_type(joined_model_alias)) + + # Pass the joined query down to the next object type filter for processing + query, _clauses = field_filter_type.execute_filters(query, filt_dict, model_alias=joined_model_alias) clauses.extend(_clauses) if issubclass(field_filter_type, RelationshipFilter): # TODO see above; not yet working @@ -64,7 +76,6 @@ def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict) -> Tuple[Q query, _clauses = field_filter_type.execute_filters(query, filt_dict, relationship_prop) clauses.extend(_clauses) elif issubclass(field_filter_type, FieldFilter): - model_field = getattr(model, field) query, _clauses = field_filter_type.execute_filters(query, model_field, filt_dict) clauses.extend(_clauses) From a30d77d7334cd487028656159ef1f9ae8858c8ab Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 15 Aug 2022 20:22:22 -0400 Subject: [PATCH 20/81] error on simple filter test --- graphene_sqlalchemy/tests/test_filters.py | 38 ++++++++++++----------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index f44b139c..c082d832 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -1,6 +1,8 @@ import graphene import pytest +from graphene import Connection, relay + from ..fields import SQLAlchemyConnectionField from ..filters import FloatFilter from ..types import SQLAlchemyObjectType @@ -13,10 +15,10 @@ def add_test_data(session): first_name='John', last_name='Doe', favorite_pet_kind='cat') session.add(reporter) pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) - pet.reporters = reporter + pet.reporter = reporter session.add(pet) pet = Pet(name='Snoopy', pet_kind='dog', hair_kind=HairKind.SHORT) - pet.reporters = reporter + pet.reporter = reporter session.add(pet) reporter = Reporter( first_name='John', last_name='Woe', favorite_pet_kind='cat') @@ -28,7 +30,7 @@ def add_test_data(session): first_name='Jane', last_name='Roe', favorite_pet_kind='dog') session.add(reporter) pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) - pet.reporters.append(reporter) + pet.reporter = reporter session.add(pet) editor = Editor(name="Jack") session.add(editor) @@ -47,45 +49,44 @@ class Meta: class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + connection_class = Connection class Query(graphene.ObjectType): + node = relay.Node.Field() + # TODO how to create filterable singular field? article = graphene.Field(ArticleType) - articles = graphene.List(ArticleType) - # image = graphene.Field(ImageType) - # images = graphene.List(ImageType) + articles = SQLAlchemyConnectionField(ArticleType.connection) + image = graphene.Field(ImageType) + images = SQLAlchemyConnectionField(ImageType.connection) reporter = graphene.Field(ReporterType) - reporters = graphene.List(ReporterType) + reporters = SQLAlchemyConnectionField(ReporterType.connection) def resolve_article(self, _info): return session.query(Article).first() - def resolve_articles(self, _info): - return session.query(Article) - def resolve_image(self, _info): return session.query(Image).first() - def resolve_images(self, _info): - return session.query(Image) - def resolve_reporter(self, _info): return session.query(Reporter).first() - def resolve_reporters(self, _info): - return session.query(Reporter) - return Query # Test a simple example of filtering -@pytest.mark.xfail def test_filter_simple(session): add_test_data(session) Query = create_schema(session) + # TODO test singular field filter + # reporter(filter: {firstName: "John"}) { + # firstName + # } query = """ query { - reporters(filters: {firstName: "John"}) { + reporters(filter: {firstName: "John"}) { firstName } } @@ -95,6 +96,7 @@ def test_filter_simple(session): } schema = graphene.Schema(query=Query) result = schema.execute(query) + print(result) assert not result.errors result = to_std_dicts(result.data) assert result == expected From fcce1f79f5cfed034ba2728ff3a94395a6a41f58 Mon Sep 17 00:00:00 2001 From: sabard Date: Wed, 5 Oct 2022 00:01:56 -0400 Subject: [PATCH 21/81] get simple and 1:1 filter tests passing --- graphene_sqlalchemy/filters.py | 7 +- graphene_sqlalchemy/registry.py | 28 +++-- graphene_sqlalchemy/tests/models.py | 22 ++-- graphene_sqlalchemy/tests/test_filters.py | 145 ++++++++++++++-------- graphene_sqlalchemy/types.py | 9 +- tox.ini | 3 + 6 files changed, 139 insertions(+), 75 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index ff7a136f..425e6af4 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -167,7 +167,7 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): _meta = InputObjectTypeOptions(cls) new_filter_fields = {} - print(f"Geenerating Fields for {cls.__name__} with type {type} ") + print(f"Generating Fields for {cls.__name__} with type {type} ") # Generate Graphene Fields from the filter functions based on type hints for field_name, _annotations in filter_functions: assert "val" in _annotations, "Each filter method must have a value field with valid type annotations" @@ -267,3 +267,8 @@ class DateFilter(OrderedFilter): class Meta: type = graphene.Date + + +class IdFilter(FieldFilter): + class Meta: + type = graphene.ID diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 20d7f2d8..365e88ac 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -5,13 +5,11 @@ import graphene from graphene import Enum -from graphene_sqlalchemy.filters import (BooleanFilter, FieldFilter, - FloatFilter, IntFilter, StringFilter) +from graphene_sqlalchemy.filters import (FieldFilter, ObjectTypeFilter, + RelationshipFilter) class Registry(object): - from graphene_sqlalchemy.filters import (FieldFilter, ObjectTypeFilter, - RelationshipFilter) def __init__(self): self._registry = {} @@ -131,6 +129,22 @@ def get_filter_for_scalar_type(self, scalar_type: Type[graphene.Scalar]) -> Type return self._registry_scalar_filters.get(scalar_type) + # TODO register enums automatically + def register_filter_for_enum_type(self, enum_type: Type[graphene.Enum], filter_obj: Type[FieldFilter]): + if not isinstance(enum_type, type(graphene.Enum)): + raise TypeError( + "Expected Enum, but got: {!r}".format(enum_type) + ) + + if not isinstance(filter_obj, type(FieldFilter)): + raise TypeError( + "Expected FieldFilter, but got: {!r}".format(filter_obj) + ) + self._registry_scalar_filters[enum_type] = filter_obj + + def get_filter_for_enum_type(self, enum_type: Type[graphene.Enum]) -> Type[FieldFilter]: + return self._registry_enum_type_filters.get(enum_type) + # Filter Object Types def register_filter_for_object_type(self, object_type: Type[graphene.ObjectType], filter_obj: Type[ObjectTypeFilter]): @@ -179,9 +193,3 @@ def get_global_registry(): def reset_global_registry(): global registry registry = None - - -get_global_registry().register_filter_for_scalar_type(graphene.Float, FloatFilter) -get_global_registry().register_filter_for_scalar_type(graphene.Int, IntFilter) -get_global_registry().register_filter_for_scalar_type(graphene.String, StringFilter) -get_global_registry().register_filter_for_scalar_type(graphene.Boolean, BooleanFilter) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 2fd2b01f..8090eaa6 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -9,7 +9,8 @@ String, Table, func, select) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import backref, column_property, composite, mapper, relationship +from sqlalchemy.orm import (backref, column_property, composite, mapper, + relationship) PetKind = Enum("cat", "dog", name="pet_kind") @@ -99,7 +100,8 @@ def hybrid_prop_list(self) -> List[int]: return [1, 2, 3] column_prop = column_property( - select([func.cast(func.count(id), Integer)]), doc="Column property" + select([func.cast(func.count(id), Integer)]).scalar_subquery(), + doc="Column property" ) composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") @@ -110,7 +112,7 @@ def hybrid_prop_list(self) -> List[int]: Base.metadata, Column("article_id", ForeignKey("articles.id")), Column("tag_id", ForeignKey("tags.id")), -) +) class Image(Base): @@ -136,6 +138,13 @@ class Article(Base): "Reader", secondary="articles_readers", back_populates="articles" ) + # one-to-one relationship with image + image_id = Column(Integer(), ForeignKey('images.id'), unique=True) + image = relationship("Image", backref=backref("articles", uselist=False)) + + # many-to-many relationship with tags + tags = relationship("Tag", secondary=articles_tags_table, backref="articles") + class Reader(Base): __tablename__ = "readers" @@ -151,13 +160,6 @@ class ArticleReader(Base): article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True) reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) - # one-to-one relationship with image - image_id = Column(Integer(), ForeignKey('images.id'), unique=True) - image = relationship("Image", backref=backref("articles", uselist=False)) - - # many-to-many relationship with tags - tags = relationship("Tag", secondary=articles_tags_table, backref="articles") - class ReflectedEditor(type): """Same as Editor, but using reflected table.""" diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index c082d832..701faa73 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -1,6 +1,6 @@ -import graphene import pytest +import graphene from graphene import Connection, relay from ..fields import SQLAlchemyConnectionField @@ -41,10 +41,23 @@ def create_schema(session): class ArticleType(SQLAlchemyObjectType): class Meta: model = Article + name = "Article" + interfaces = (relay.Node,) + connection_class = Connection class ImageType(SQLAlchemyObjectType): class Meta: model = Image + name = "Image" + interfaces = (relay.Node,) + connection_class = Connection + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + connection_class = Connection class ReporterType(SQLAlchemyObjectType): class Meta: @@ -55,22 +68,22 @@ class Meta: class Query(graphene.ObjectType): node = relay.Node.Field() - # TODO how to create filterable singular field? - article = graphene.Field(ArticleType) + # # TODO how to create filterable singular field? + # article = graphene.Field(ArticleType) articles = SQLAlchemyConnectionField(ArticleType.connection) - image = graphene.Field(ImageType) + # image = graphene.Field(ImageType) images = SQLAlchemyConnectionField(ImageType.connection) - reporter = graphene.Field(ReporterType) + # reporter = graphene.Field(ReporterType) reporters = SQLAlchemyConnectionField(ReporterType.connection) - def resolve_article(self, _info): - return session.query(Article).first() + # def resolve_article(self, _info): + # return session.query(Article).first() - def resolve_image(self, _info): - return session.query(Image).first() + # def resolve_image(self, _info): + # return session.query(Image).first() - def resolve_reporter(self, _info): - return session.query(Reporter).first() + # def resolve_reporter(self, _info): + # return session.query(Reporter).first() return Query @@ -78,25 +91,25 @@ def resolve_reporter(self, _info): # Test a simple example of filtering def test_filter_simple(session): add_test_data(session) + Query = create_schema(session) - # TODO test singular field filter - # reporter(filter: {firstName: "John"}) { - # firstName - # } query = """ query { - reporters(filter: {firstName: "John"}) { - firstName + reporters (filter: {lastName: {eq: "Roe"}}) { + edges { + node { + firstName + } + } } } """ expected = { - "reporters": [{"firstName": "John"}], + "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query) - print(result) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -109,18 +122,31 @@ def test_filter_custom_type(session): Query = create_schema(session) class MathFilter(FloatFilter): - def divisibleBy(dividend: float, divisor: float) -> float: - return dividend % divisor == 0. + class Meta: + type = graphene.Float + + @classmethod + def divisible_by(cls, query, field, val: graphene.Float) -> bool: + return field % val == 0. + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + connection_class = Connection class ExtraQuery: - pets = SQLAlchemyConnectionField(Pet, filters=MathFilter()) + pets = SQLAlchemyConnectionField( + PetType.connection, filter=MathFilter() + ) class CustomQuery(Query, ExtraQuery): pass query = """ query { - pets (filters: { + pets (filter: { legs: {divisibleBy: 2} }) { name @@ -131,13 +157,13 @@ class CustomQuery(Query, ExtraQuery): "pets": [{"name": "Garfield"}, {"name": "Lassie"}], } schema = graphene.Schema(query=CustomQuery) - result = schema.execute(query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected + # Test a 1:1 relationship -@pytest.mark.xfail def test_filter_relationship_one_to_one(session): article = Article(headline='Hi!') image = Image(external_id=1, description="A beautiful image.") @@ -150,18 +176,22 @@ def test_filter_relationship_one_to_one(session): query = """ query { - article (filters: { - image: {description: "A beautiful image."} + articles (filter: { + image: {description: {eq: "A beautiful image."}} }) { - firstName + edges { + node { + headline + } + } } } """ expected = { - "article": [{"headline": "Hi!"}], + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -176,14 +206,18 @@ def test_filter_relationship_one_to_many(session): # test contains query = """ query { - reporter (filters: { + reporters (filter: { pets: { contains: { name: {in: ["Garfield", "Lassie"]} } } }) { - lastName + edges { + node { + lastName + } + } } } """ @@ -191,7 +225,7 @@ def test_filter_relationship_one_to_many(session): "reporter": [{"lastName": "Doe"}, {"lastName": "Roe"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -199,7 +233,7 @@ def test_filter_relationship_one_to_many(session): # test containsAllOf query = """ query { - reporter (filters: { + reporters (filter: { pets: { containsAllOf: [ name: {eq: "Garfield"}, @@ -207,8 +241,12 @@ def test_filter_relationship_one_to_many(session): ] } }) { - firstName - lastName + edges { + node { + firstName + lastName + } + } } } """ @@ -216,7 +254,7 @@ def test_filter_relationship_one_to_many(session): "reporter": [{"firstName": "John"}, {"lastName": "Doe"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -224,7 +262,7 @@ def test_filter_relationship_one_to_many(session): # test containsExactly query = """ query { - reporter (filters: { + reporter (filter: { pets: { containsExactly: [ name: {eq: "Garfield"} @@ -244,6 +282,7 @@ def test_filter_relationship_one_to_many(session): result = to_std_dicts(result.data) assert result == expected + # Test a n:m relationship @pytest.mark.xfail def test_filter_relationship_many_to_many(session): @@ -264,10 +303,10 @@ def test_filter_relationship_many_to_many(session): # test contains query = """ query { - articles (filters: { + articles (filter: { tags: { contains: { - name: { in: ["sensational", "eye-grabbing"] } + name: { in: ["sensational", "eye-grabbing"] } } } }) { @@ -277,7 +316,7 @@ def test_filter_relationship_many_to_many(session): """ expected = { "articles": [ - {"headline": "Woah! Another!"}, + {"headline": "Woah! Another!"}, {"headline": "Article! Look!"}, ], } @@ -290,7 +329,7 @@ def test_filter_relationship_many_to_many(session): # test containsAllOf query = """ query { - articles (filters: { + articles (filter: { tags: { containsAllOf: [ { tag: { name: { eq: "eye-grabbing" } } }, @@ -314,7 +353,7 @@ def test_filter_relationship_many_to_many(session): # test containsExactly query = """ query { - articles (filters: { + articles (filter: { containsExactly: [ { tag: { name: { eq: "sensational" } } } ] @@ -342,10 +381,10 @@ def test_filter_logic_and(session): query = """ query { - reporters (filters: { + reporters (filter: { and: [ {firstName: "John"}, - {favoritePetKind: "cat"}, + {favoritePetKind: "cat"}, ] }) { lastName @@ -362,7 +401,7 @@ def test_filter_logic_and(session): assert result == expected -# Test connecting filters with "or" +# Test connecting filters with "or" @pytest.mark.xfail def test_filter_logic_or(session): add_test_data(session) @@ -370,10 +409,10 @@ def test_filter_logic_or(session): query = """ query { - reporters (filters: { + reporters (filter: { or: [ {lastName: "Woe"}, - {favoritePetKind: "dog"}, + {favoritePetKind: "dog"}, ] }) { firstName @@ -383,7 +422,7 @@ def test_filter_logic_or(session): """ expected = { "reporters": [ - {"firstName": "John", "lastName": "Woe"}, + {"firstName": "John", "lastName": "Woe"}, {"firstName": "Jane", "lastName": "Roe"}, ], } @@ -402,12 +441,12 @@ def test_filter_logic_and_or(session): query = """ query { - reporters (filters: { + reporters (filter: { and: [ {firstName: "John"}, - or : [ + or : [ {lastName: "Doe"}, - {favoritePetKind: "cat"}, + {favoritePetKind: "cat"}, ] ] }) { diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 95b0f012..218706e5 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -21,7 +21,8 @@ convert_sqlalchemy_relationship) from .enums import (enum_for_field, sort_argument_for_object_type, sort_enum_for_object_type) -from .filters import ObjectTypeFilter, RelationshipFilter +from .filters import (BooleanFilter, FloatFilter, IdFilter, IntFilter, + ObjectTypeFilter, RelationshipFilter, StringFilter) from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import (get_nullable_type, get_query, is_mapped_class, @@ -267,6 +268,12 @@ def __init_subclass_with_meta__( if not registry: registry = get_global_registry() + # TODO way of doing this automatically? + get_global_registry().register_filter_for_scalar_type(graphene.Float, FloatFilter) + get_global_registry().register_filter_for_scalar_type(graphene.Int, IntFilter) + get_global_registry().register_filter_for_scalar_type(graphene.String, StringFilter) + get_global_registry().register_filter_for_scalar_type(graphene.Boolean, BooleanFilter) + get_global_registry().register_filter_for_scalar_type(graphene.ID, IdFilter) assert isinstance(registry, Registry), ( "The attribute registry in {} needs to be an instance of " diff --git a/tox.ini b/tox.ini index 2802dee0..27be21f2 100644 --- a/tox.ini +++ b/tox.ini @@ -38,3 +38,6 @@ basepython = python3.10 deps = -e.[dev] commands = flake8 --exclude setup.py,docs,examples,tests,.tox --max-line-length 120 + +[pytest] +asyncio_mode = auto From 486273764d1dfb75e4527b99cca6422e2f5416eb Mon Sep 17 00:00:00 2001 From: sabard Date: Mon, 24 Oct 2022 03:29:31 -0400 Subject: [PATCH 22/81] fix a couple errors to get basic, incomplete 1:n test working --- graphene_sqlalchemy/filters.py | 41 ++++++++++++++------ graphene_sqlalchemy/tests/test_filters.py | 47 ++++++----------------- graphene_sqlalchemy/types.py | 8 +++- 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 425e6af4..d998165a 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import Any, Dict, List, Tuple, Type, TypeVar, Union +from typing import Any, Dict, List, Tuple, Type, TypeVar, Union, get_type_hints from sqlalchemy import not_ from sqlalchemy.orm import Query, aliased @@ -55,6 +55,7 @@ def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict, model_alia field_filter_type = input_field.get_type().type else: field_filter_type = cls._meta.fields[field].type + # raise Exception # TODO we need to save the relationship props in the meta fields array # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) model_field = getattr(model, field) @@ -72,8 +73,8 @@ def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict, model_alia clauses.extend(_clauses) if issubclass(field_filter_type, RelationshipFilter): # TODO see above; not yet working - relationship_prop = None - query, _clauses = field_filter_type.execute_filters(query, filt_dict, relationship_prop) + relationship_prop = field_filter_type._meta.model + query, _clauses = field_filter_type.execute_filters(query, model_field, filt_dict, relationship_prop) clauses.extend(_clauses) elif issubclass(field_filter_type, FieldFilter): query, _clauses = field_filter_type.execute_filters(query, model_field, filt_dict) @@ -84,7 +85,7 @@ def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict, model_alia class RelationshipFilter(graphene.InputObjectType): @classmethod - def __init_subclass_with_meta__(cls, object_type_filter=None, _meta=None, **options): + def __init_subclass_with_meta__(cls, object_type_filter=None, model=None, _meta=None, **options): if not object_type_filter: raise Exception("Relationship Filters must be specific to an object type") # Init meta options class if it doesn't exist already @@ -102,7 +103,7 @@ def __init_subclass_with_meta__(cls, object_type_filter=None, _meta=None, **opti # Check if attribute is a function if callable(func_attr) and filter_function_regex.match(func): # add function and attribute name to the list - filter_functions.append((re.sub("_filter$", "", func), func_attr.__annotations__)) + filter_functions.append((re.sub("_filter$", "", func), get_type_hints(func_attr))) relationship_filters = {} @@ -121,18 +122,36 @@ def __init_subclass_with_meta__(cls, object_type_filter=None, _meta=None, **opti else: _meta.fields = relationship_filters + _meta.model = model + super(RelationshipFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) @classmethod - def contains_filter(cls, val: List["RelationshipFilter"]): - # TODO - pass + def contains_filter(cls, query, field, val: List[AbstractType]): + clauses = [] + for v in val: + query, clauses = v.execute_filters(query, dict(v)) + clauses += clauses + return clauses @classmethod - def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict, relationship_prop) -> Tuple[Query, List[Any]]: + def contains_exactly_filter(cls, query, field, val: List[AbstractType]): + clauses = [] + for v in val: + query, clauses = v.execute_filters(query, dict(v)) + clauses += clauses + return clauses + + @classmethod + def execute_filters( + cls: Type[FieldFilter], query, field, filter_dict: Dict, relationship_prop + ) -> Tuple[Query, List[Any]]: query, clauses = (query, []) - # TODO - return query, clauses + + for filt, val in filter_dict.items(): + clauses += getattr(cls, filt + "_filter")(query, field, val) + + return query.join(field), clauses any_field_filter = TypeVar('any_field_filter', bound="FieldFilter") diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 701faa73..75786f6d 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -26,6 +26,9 @@ def add_test_data(session): article = Article(headline='Hi!') article.reporter = reporter session.add(article) + article = Article(headline='Hello!') + article.reporter = reporter + session.add(article) reporter = Reporter( first_name='Jane', last_name='Roe', favorite_pet_kind='dog') session.add(reporter) @@ -198,7 +201,6 @@ def test_filter_relationship_one_to_one(session): # Test a 1:n relationship -@pytest.mark.xfail def test_filter_relationship_one_to_many(session): add_test_data(session) Query = create_schema(session) @@ -207,10 +209,8 @@ def test_filter_relationship_one_to_many(session): query = """ query { reporters (filter: { - pets: { - contains: { - name: {in: ["Garfield", "Lassie"]} - } + articles: { + contains: [{headline: {eq: "Hi!"}}], } }) { edges { @@ -222,7 +222,7 @@ def test_filter_relationship_one_to_many(session): } """ expected = { - "reporter": [{"lastName": "Doe"}, {"lastName": "Roe"}], + "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={'session': session}) @@ -230,14 +230,14 @@ def test_filter_relationship_one_to_many(session): result = to_std_dicts(result.data) assert result == expected - # test containsAllOf + # test containsExactly query = """ query { reporters (filter: { - pets: { - containsAllOf: [ - name: {eq: "Garfield"}, - name: {eq: "Snoopy"}, + articles: { + containsExactly: [ + {headline: {eq: "Hi!"}} + {headline: {eq: "Hello!"}} ] } }) { @@ -251,7 +251,7 @@ def test_filter_relationship_one_to_many(session): } """ expected = { - "reporter": [{"firstName": "John"}, {"lastName": "Doe"}], + "reporters": {"edges": [{"node": {"firstName": "John", "lastName": "Woe"}}]} } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={'session': session}) @@ -259,29 +259,6 @@ def test_filter_relationship_one_to_many(session): result = to_std_dicts(result.data) assert result == expected - # test containsExactly - query = """ - query { - reporter (filter: { - pets: { - containsExactly: [ - name: {eq: "Garfield"} - ] - } - }) { - firstName - } - } - """ - expected = { - "reporter": [], - } - schema = graphene.Schema(query=Query) - result = schema.execute(query) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected - # Test a n:m relationship @pytest.mark.xfail diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 218706e5..62324d30 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -100,7 +100,9 @@ def get_or_create_relationship_filter(obj_type: Type[ObjectType], registry: Regi if not relationship_filter: object_type_filter = registry.get_filter_for_object_type(obj_type) relationship_filter = RelationshipFilter.create_type(f"{obj_type.__name__}RelationshipFilter", - object_type_filter=object_type_filter) + object_type_filter=object_type_filter, + model=obj_type._meta.model) + registry.register_relationship_filter_for_object_type(obj_type, relationship_filter) return relationship_filter @@ -117,7 +119,9 @@ def resolve_dynamic(): # Resolve Dynamic Type type_ = get_nullable_type(field.get_type()) from graphene_sqlalchemy import SQLAlchemyConnectionField - if isinstance(type_, SQLAlchemyConnectionField): + + from .fields import UnsortedSQLAlchemyConnectionField + if isinstance(type_, SQLAlchemyConnectionField) or isinstance(type_, UnsortedSQLAlchemyConnectionField): inner_type = get_nullable_type(type_.type.Edge.node._type) return graphene.InputField(get_or_create_relationship_filter(inner_type, registry)) elif isinstance(type_, Field): From 15cfbd754c627f9da6f0f1e885bf65b0311ee858 Mon Sep 17 00:00:00 2001 From: sabard Date: Mon, 24 Oct 2022 12:28:31 -0400 Subject: [PATCH 23/81] revert scalar_subquery to as_scalar for sqlalchemy<1.4 tests --- graphene_sqlalchemy/tests/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 8090eaa6..590f1655 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -100,7 +100,8 @@ def hybrid_prop_list(self) -> List[int]: return [1, 2, 3] column_prop = column_property( - select([func.cast(func.count(id), Integer)]).scalar_subquery(), + # TODO scalar_subquery replaced as_scalar in sqlalchemy 1.4 + select([func.cast(func.count(id), Integer)]).as_scalar(), doc="Column property" ) From 122d1828ca08828a0367bb0916435af8ca759d13 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 21 Nov 2022 13:24:43 -0500 Subject: [PATCH 24/81] initial implementation of and/or logic --- graphene_sqlalchemy/filters.py | 177 ++++++++++++++++------ graphene_sqlalchemy/tests/test_filters.py | 82 ++++++---- 2 files changed, 189 insertions(+), 70 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index d998165a..b7f275f5 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -3,7 +3,7 @@ import re from typing import Any, Dict, List, Tuple, Type, TypeVar, Union, get_type_hints -from sqlalchemy import not_ +from sqlalchemy import and_, not_, or_ from sqlalchemy.orm import Query, aliased import graphene @@ -13,17 +13,24 @@ class AbstractType: """Dummy class for generic filters""" + pass class ObjectTypeFilter(graphene.InputObjectType): @classmethod - def __init_subclass_with_meta__(cls, filter_fields=None, model=None, _meta=None, **options): + def __init_subclass_with_meta__( + cls, filter_fields=None, model=None, _meta=None, **options + ): # Init meta options class if it doesn't exist already if not _meta: _meta = InputObjectTypeOptions(cls) + # TODO do this dynamically based off the field name, but also value type + filter_fields["and"] = graphene.InputField(graphene.List(cls)) + filter_fields["or"] = graphene.InputField(graphene.List(cls)) + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest if _meta.fields: _meta.fields.update(filter_fields) @@ -35,12 +42,57 @@ def __init_subclass_with_meta__(cls, filter_fields=None, model=None, _meta=None, super(ObjectTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) @classmethod - def and_logic(cls, query, field, val: list["ObjectTypeFilter"]): - # TODO - pass + def and_logic( + cls, + query, + filter_type: ObjectTypeFilter, + vals: graphene.List["ObjectTypeFilter"], + ): + # # Get the model to join on the Filter Query + # joined_model = filter_type._meta.model + # # Always alias the model + # joined_model_alias = aliased(joined_model) + + clauses = [] + for val in vals: + # # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + + query, _clauses = filter_type.execute_filters( + query, val + ) # , model_alias=joined_model_alias) + clauses += _clauses + + return query, [and_(*clauses)] + + @classmethod + def or_logic( + cls, + query, + filter_type: ObjectTypeFilter, + vals: graphene.List["ObjectTypeFilter"], + ): + # # Get the model to join on the Filter Query + # joined_model = filter_type._meta.model + # # Always alias the model + # joined_model_alias = aliased(joined_model) + + clauses = [] + for val in vals: + # # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + + query, _clauses = filter_type.execute_filters( + query, val + ) # , model_alias=joined_model_alias) + clauses += _clauses + + return query, [or_(*clauses)] @classmethod - def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict, model_alias=None) -> Tuple[Query, List[Any]]: + def execute_filters( + cls: Type[FieldFilter], query, filter_dict: Dict, model_alias=None + ) -> Tuple[Query, List[Any]]: model = cls._meta.model if model_alias: model = model_alias @@ -58,34 +110,53 @@ def execute_filters(cls: Type[FieldFilter], query, filter_dict: Dict, model_alia # raise Exception # TODO we need to save the relationship props in the meta fields array # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) - model_field = getattr(model, field) - if issubclass(field_filter_type, ObjectTypeFilter): - # Get the model to join on the Filter Query - joined_model = field_filter_type._meta.model - # Always alias the model - joined_model_alias = aliased(joined_model) - - # Join the aliased model onto the query - query = query.join(model_field.of_type(joined_model_alias)) - - # Pass the joined query down to the next object type filter for processing - query, _clauses = field_filter_type.execute_filters(query, filt_dict, model_alias=joined_model_alias) - clauses.extend(_clauses) - if issubclass(field_filter_type, RelationshipFilter): - # TODO see above; not yet working - relationship_prop = field_filter_type._meta.model - query, _clauses = field_filter_type.execute_filters(query, model_field, filt_dict, relationship_prop) + if field == "and": + query, _clauses = cls.and_logic( + query, field_filter_type.of_type, filt_dict + ) clauses.extend(_clauses) - elif issubclass(field_filter_type, FieldFilter): - query, _clauses = field_filter_type.execute_filters(query, model_field, filt_dict) + elif field == "or": + query, _clauses = cls.or_logic( + query, field_filter_type.of_type, filt_dict + ) clauses.extend(_clauses) + else: + model_field = getattr(model, field) + if issubclass(field_filter_type, ObjectTypeFilter): + # Get the model to join on the Filter Query + joined_model = field_filter_type._meta.model + # Always alias the model + joined_model_alias = aliased(joined_model) + + # Join the aliased model onto the query + query = query.join(model_field.of_type(joined_model_alias)) + + # Pass the joined query down to the next object type filter for processing + query, _clauses = field_filter_type.execute_filters( + query, filt_dict, model_alias=joined_model_alias + ) + clauses.extend(_clauses) + if issubclass(field_filter_type, RelationshipFilter): + # TODO see above; not yet working + relationship_prop = field_filter_type._meta.model + query, _clauses = field_filter_type.execute_filters( + query, model_field, filt_dict, relationship_prop + ) + clauses.extend(_clauses) + elif issubclass(field_filter_type, FieldFilter): + query, _clauses = field_filter_type.execute_filters( + query, model_field, filt_dict + ) + clauses.extend(_clauses) return query, clauses class RelationshipFilter(graphene.InputObjectType): @classmethod - def __init_subclass_with_meta__(cls, object_type_filter=None, model=None, _meta=None, **options): + def __init_subclass_with_meta__( + cls, object_type_filter=None, model=None, _meta=None, **options + ): if not object_type_filter: raise Exception("Relationship Filters must be specific to an object type") # Init meta options class if it doesn't exist already @@ -103,18 +174,26 @@ def __init_subclass_with_meta__(cls, object_type_filter=None, model=None, _meta= # Check if attribute is a function if callable(func_attr) and filter_function_regex.match(func): # add function and attribute name to the list - filter_functions.append((re.sub("_filter$", "", func), get_type_hints(func_attr))) + filter_functions.append( + (re.sub("_filter$", "", func), get_type_hints(func_attr)) + ) relationship_filters = {} # Generate Graphene Fields from the filter functions based on type hints for field_name, _annotations in filter_functions: - assert "val" in _annotations, "Each filter method must have a value field with valid type annotations" + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class if is_list(_annotations["val"]): - relationship_filters.update({field_name: graphene.InputField(graphene.List(object_type_filter))}) + relationship_filters.update( + {field_name: graphene.InputField(graphene.List(object_type_filter))} + ) else: - relationship_filters.update({field_name: graphene.InputField(object_type_filter)}) + relationship_filters.update( + {field_name: graphene.InputField(object_type_filter)} + ) # Add all fields to the meta options. graphene.InputObjectType will take care of the rest if _meta.fields: @@ -124,22 +203,24 @@ def __init_subclass_with_meta__(cls, object_type_filter=None, model=None, _meta= _meta.model = model - super(RelationshipFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + super(RelationshipFilter, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) @classmethod def contains_filter(cls, query, field, val: List[AbstractType]): clauses = [] for v in val: - query, clauses = v.execute_filters(query, dict(v)) - clauses += clauses + query, _clauses = v.execute_filters(query, dict(v)) + clauses += _clauses return clauses @classmethod def contains_exactly_filter(cls, query, field, val: List[AbstractType]): clauses = [] for v in val: - query, clauses = v.execute_filters(query, dict(v)) - clauses += clauses + query, _clauses = v.execute_filters(query, dict(v)) + clauses += _clauses return clauses @classmethod @@ -154,7 +235,7 @@ def execute_filters( return query.join(field), clauses -any_field_filter = TypeVar('any_field_filter', bound="FieldFilter") +any_field_filter = TypeVar("any_field_filter", bound="FieldFilter") class FieldFilter(graphene.InputObjectType): @@ -177,8 +258,8 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): # Check if attribute is a function if callable(func_attr) and filter_function_regex.match(func): # add function and attribute name to the list - filter_functions.append(( - re.sub("_filter$", "", func), func_attr.__annotations__) + filter_functions.append( + (re.sub("_filter$", "", func), func_attr.__annotations__) ) # Init meta options class if it doesn't exist already @@ -189,7 +270,9 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): print(f"Generating Fields for {cls.__name__} with type {type} ") # Generate Graphene Fields from the filter functions based on type hints for field_name, _annotations in filter_functions: - assert "val" in _annotations, "Each filter method must have a value field with valid type annotations" + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class print(f"Field: {field_name} with annotation {_annotations['val']}") if _annotations["val"] == "AbstractType": @@ -199,7 +282,9 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): else: # TODO this is a place holder, we need to convert the type of val to a valid graphene # type that we can pass to the InputField. We could re-use converter.convert_hybrid_property_return_type - new_filter_fields.update({field_name: graphene.InputField(graphene.String)}) + new_filter_fields.update( + {field_name: graphene.InputField(graphene.String)} + ) # Add all fields to the meta options. graphene.InputbjectType will take care of the rest if _meta.fields: @@ -212,15 +297,21 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): # Abstract methods can be marked using AbstractType. See comment on the init method @classmethod - def eq_filter(cls, query, field, val: AbstractType) -> Union[Tuple[Query, Any], Any]: + def eq_filter( + cls, query, field, val: AbstractType + ) -> Union[Tuple[Query, Any], Any]: return field == val @classmethod - def n_eq_filter(cls, query, field, val: AbstractType) -> Union[Tuple[Query, Any], Any]: + def n_eq_filter( + cls, query, field, val: AbstractType + ) -> Union[Tuple[Query, Any], Any]: return not_(field == val) @classmethod - def execute_filters(cls: Type[FieldFilter], query, field, filter_dict: any_field_filter) -> Tuple[Query, List[Any]]: + def execute_filters( + cls: Type[FieldFilter], query, field, filter_dict: any_field_filter + ) -> Tuple[Query, List[Any]]: clauses = [] for filt, val in filter_dict.items(): clause = getattr(cls, filt + "_filter")(query, field, val) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 75786f6d..78a18ec4 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -9,6 +9,10 @@ from .models import Article, Editor, HairKind, Image, Pet, Reporter, Tag from .utils import to_std_dicts +# TODO test that generated schema is correct for all examples with: +# with open('schema.gql', 'w') as fp: +# fp.write(str(schema)) + def add_test_data(session): reporter = Reporter( @@ -298,7 +302,7 @@ def test_filter_relationship_many_to_many(session): ], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -343,14 +347,13 @@ def test_filter_relationship_many_to_many(session): "articles": [{"headline": "Article! Look!"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test connecting filters with "and" -@pytest.mark.xfail def test_filter_logic_and(session): add_test_data(session) @@ -360,26 +363,32 @@ def test_filter_logic_and(session): query { reporters (filter: { and: [ - {firstName: "John"}, - {favoritePetKind: "cat"}, + { firstName: { eq: "John" } }, + # TODO get enums working for filters + # { favoritePetKind: { eq: "cat" } }, ] }) { - lastName + edges { + node { + lastName + } + } } } """ expected = { - "reporters": [{"lastName": "Doe"}, {"lastName": "Woe"}], + "reporters": {"edges": [{"node": {"lastName": "Doe"}}, {"node": {"lastName": "Woe"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query) + with open('schema.gql', 'w') as fp: + fp.write(str(schema)) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test connecting filters with "or" -@pytest.mark.xfail def test_filter_logic_or(session): add_test_data(session) Query = create_schema(session) @@ -388,30 +397,37 @@ def test_filter_logic_or(session): query { reporters (filter: { or: [ - {lastName: "Woe"}, - {favoritePetKind: "dog"}, + { lastName: { eq: "Woe" } }, + # TODO get enums working for filters + #{ favoritePetKind: { eq: "dog" } }, ] }) { - firstName - lastName + edges { + node { + firstName + lastName + } + } } } """ expected = { - "reporters": [ - {"firstName": "John", "lastName": "Woe"}, - {"firstName": "Jane", "lastName": "Roe"}, - ], + "reporters": { + "edges": [ + {"node": {"firstName": "John", "lastName": "Woe"}}, + # TODO get enums working for filters + # {"node": {"firstName": "Jane", "lastName": "Roe"}}, + ] + } } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test connecting filters with "and" and "or" together -@pytest.mark.xfail def test_filter_logic_and_or(session): add_test_data(session) Query = create_schema(session) @@ -420,22 +436,34 @@ def test_filter_logic_and_or(session): query { reporters (filter: { and: [ - {firstName: "John"}, - or : [ - {lastName: "Doe"}, - {favoritePetKind: "cat"}, - ] + { firstName: { eq: "John" } }, + { + or: [ + { lastName: { eq: "Doe" } }, + # TODO get enums working for filters + # { favoritePetKind: { eq: "cat" } }, + ] + } ] }) { - firstName + edges { + node { + firstName + } + } } } """ expected = { - "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], + "reporters": { + "edges": [ + {"node": {"firstName": "John"}}, + # {"node": {"firstName": "Jane"}}, + ], + } } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected From 11f7d91e84e183a33e8b03dd00d885e94d8f3a86 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 21 Nov 2022 10:50:13 -0800 Subject: [PATCH 25/81] revert filter logic so 1:n test passes --- graphene_sqlalchemy/filters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index b7f275f5..8bc98ffd 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -211,16 +211,16 @@ def __init_subclass_with_meta__( def contains_filter(cls, query, field, val: List[AbstractType]): clauses = [] for v in val: - query, _clauses = v.execute_filters(query, dict(v)) - clauses += _clauses + query, clauses = v.execute_filters(query, dict(v)) + clauses += clauses return clauses @classmethod def contains_exactly_filter(cls, query, field, val: List[AbstractType]): clauses = [] for v in val: - query, _clauses = v.execute_filters(query, dict(v)) - clauses += _clauses + query, clauses = v.execute_filters(query, dict(v)) + clauses += clauses return clauses @classmethod From 6263540911f8b0b3d367e5672de368ca375abb69 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 1 Dec 2022 16:47:30 +0100 Subject: [PATCH 26/81] partial: support custom filter fields, custom types of filter inputs --- graphene_sqlalchemy/converter.py | 56 ++-- graphene_sqlalchemy/fields.py | 51 ++-- graphene_sqlalchemy/filters.py | 327 ++++++++++++---------- graphene_sqlalchemy/registry.py | 54 ++-- graphene_sqlalchemy/tests/test_filters.py | 93 +++--- graphene_sqlalchemy/types.py | 87 ++++-- 6 files changed, 393 insertions(+), 275 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index d3ae8123..8a3369bd 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -16,7 +16,6 @@ from .batching import get_batch_resolver from .enums import enum_for_sa_enum -from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from .registry import get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( @@ -141,10 +140,11 @@ def _convert_o2m_or_m2m_relationship( :param dict field_kwargs: :rtype: Field """ + from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory + child_type = obj_type._meta.registry.get_type_for_model( relationship_prop.mapper.entity ) - if not child_type._meta.connection: return graphene.Field(graphene.List(child_type), **field_kwargs) @@ -337,7 +337,9 @@ def convert_variant_to_impl_type(type, column, registry=None): @singledispatchbymatchfunction -def convert_sqlalchemy_hybrid_property_type(arg: Any): +def convert_sqlalchemy_hybrid_property_type( + arg: Any, replace_type_vars: typing.Dict[str, Any] = None +): existing_graphql_type = get_global_registry().get_type_for_model(arg) if existing_graphql_type: return existing_graphql_type @@ -345,9 +347,15 @@ def convert_sqlalchemy_hybrid_property_type(arg: Any): if isinstance(arg, type(graphene.ObjectType)): return arg + if isinstance(arg, type(graphene.InputObjectType)): + return arg + if isinstance(arg, type(graphene.Scalar)): return arg + if replace_type_vars and arg in replace_type_vars: + return replace_type_vars[arg] + # No valid type found, warn and fall back to graphene.String warnings.warn( f'I don\'t know how to generate a GraphQL type out of a "{arg}" type.' @@ -357,22 +365,22 @@ def convert_sqlalchemy_hybrid_property_type(arg: Any): @convert_sqlalchemy_hybrid_property_type.register(value_equals(str)) -def convert_sqlalchemy_hybrid_property_type_str(arg): +def convert_sqlalchemy_hybrid_property_type_str(arg, *args, **kwargs): return graphene.String @convert_sqlalchemy_hybrid_property_type.register(value_equals(int)) -def convert_sqlalchemy_hybrid_property_type_int(arg): +def convert_sqlalchemy_hybrid_property_type_int(arg, *args, **kwargs): return graphene.Int @convert_sqlalchemy_hybrid_property_type.register(value_equals(float)) -def convert_sqlalchemy_hybrid_property_type_float(arg): +def convert_sqlalchemy_hybrid_property_type_float(arg, *args, **kwargs): return graphene.Float @convert_sqlalchemy_hybrid_property_type.register(value_equals(Decimal)) -def convert_sqlalchemy_hybrid_property_type_decimal(arg): +def convert_sqlalchemy_hybrid_property_type_decimal(arg, *args, **kwargs): # The reason Decimal should be serialized as a String is because this is a # base10 type used in things like money, and string allows it to not # lose precision (which would happen if we downcasted to a Float, for example) @@ -380,27 +388,27 @@ def convert_sqlalchemy_hybrid_property_type_decimal(arg): @convert_sqlalchemy_hybrid_property_type.register(value_equals(bool)) -def convert_sqlalchemy_hybrid_property_type_bool(arg): +def convert_sqlalchemy_hybrid_property_type_bool(arg, *args, **kwargs): return graphene.Boolean @convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.datetime)) -def convert_sqlalchemy_hybrid_property_type_datetime(arg): +def convert_sqlalchemy_hybrid_property_type_datetime(arg, *args, **kwargs): return graphene.DateTime @convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.date)) -def convert_sqlalchemy_hybrid_property_type_date(arg): +def convert_sqlalchemy_hybrid_property_type_date(arg, *args, **kwargs): return graphene.Date @convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.time)) -def convert_sqlalchemy_hybrid_property_type_time(arg): +def convert_sqlalchemy_hybrid_property_type_time(arg, *args, **kwargs): return graphene.Time @convert_sqlalchemy_hybrid_property_type.register(value_equals(uuid.UUID)) -def convert_sqlalchemy_hybrid_property_type_uuid(arg): +def convert_sqlalchemy_hybrid_property_type_uuid(arg, *args, **kwargs): return graphene.UUID @@ -428,7 +436,7 @@ def graphene_union_for_py_union( @convert_sqlalchemy_hybrid_property_type.register(is_union) -def convert_sqlalchemy_hybrid_property_union(arg): +def convert_sqlalchemy_hybrid_property_union(arg, *args, **kwargs): """ Converts Unions (Union[X,Y], or X | Y for python > 3.10) to the corresponding graphene schema object. Since Optionals are internally represented as Union[T, ], they are handled here as well. @@ -446,6 +454,7 @@ def convert_sqlalchemy_hybrid_property_union(arg): # Just get the T out of the list of arguments by filtering out the NoneType nested_types = list(filter(lambda x: not type(None) == x, arg.__args__)) + # TODO redo this for , *args, **kwargs # Map the graphene types to the nested types. # We use convert_sqlalchemy_hybrid_property_type instead of the registry to account for ForwardRefs, Lists,... graphene_types = list(map(convert_sqlalchemy_hybrid_property_type, nested_types)) @@ -471,20 +480,24 @@ def convert_sqlalchemy_hybrid_property_union(arg): ) -@convert_sqlalchemy_hybrid_property_type.register( - lambda x: getattr(x, "__origin__", None) in [list, typing.List] -) -def convert_sqlalchemy_hybrid_property_type_list_t(arg): +def is_list(x): + return getattr(x, "__origin__", None) in [list, typing.List] + + +@convert_sqlalchemy_hybrid_property_type.register(is_list) +def convert_sqlalchemy_hybrid_property_type_list_t(arg, *args, **kwargs): # type is either list[T] or List[T], generic argument at __args__[0] internal_type = arg.__args__[0] - graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + graphql_internal_type = convert_sqlalchemy_hybrid_property_type( + internal_type, *args, **kwargs + ) return graphene.List(graphql_internal_type) @convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(ForwardRef)) -def convert_sqlalchemy_hybrid_property_forwardref(arg): +def convert_sqlalchemy_hybrid_property_forwardref(arg, *args, **kwargs): """ Generate a lambda that will resolve the type at runtime This takes care of self-references @@ -502,12 +515,11 @@ def forward_reference_solver(): @convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(str)) -def convert_sqlalchemy_hybrid_property_bare_str(arg): +def convert_sqlalchemy_hybrid_property_bare_str(arg, *args, **kwargs): """ Convert Bare String into a ForwardRef """ - - return convert_sqlalchemy_hybrid_property_type(ForwardRef(arg)) + return convert_sqlalchemy_hybrid_property_type(ForwardRef(arg), *args, **kwargs) def convert_hybrid_property_return_type(hybrid_prop): diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index a97e59ed..8c5a275f 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -38,30 +38,35 @@ def type(self): def __init__(self, type_, *args, **kwargs): nullable_type = get_nullable_type(type_) # Handle Sorting and Filtering - if nullable_type and issubclass(nullable_type, Connection): - if "sort" not in kwargs: - # Let super class raise if type is not a Connection - try: - kwargs.setdefault( - "sort", nullable_type.Edge.node._type.sort_argument() + if ( + nullable_type + and issubclass(nullable_type, Connection) + and "sort" not in kwargs + ): + # Let super class raise if type is not a Connection + try: + kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) + except (AttributeError, TypeError): + raise TypeError( + 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' + " to None to disabling the creation of the sort query argument".format( + nullable_type.__name__ ) - except (AttributeError, TypeError): - raise TypeError( - 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' - " to None to disabling the creation of the sort query argument".format( - nullable_type.__name__ - ) - ) - elif "sort" in kwargs and kwargs["sort"] is None: - del kwargs["sort"] - - if "filter" not in kwargs: - # Only add filtering if a filter argument exists on the object type - filter_argument = nullable_type.Edge.node._type.get_filter_argument() - if filter_argument: - kwargs.setdefault("filter", filter_argument) - elif "filter" in kwargs and kwargs["filter"] is None: - del kwargs["filter"] + ) + elif "sort" in kwargs and kwargs["sort"] is None: + del kwargs["sort"] + + if ( + nullable_type + and issubclass(nullable_type, Connection) + and "filter" not in kwargs + ): + # Only add filtering if a filter argument exists on the object type + filter_argument = nullable_type.Edge.node._type.get_filter_argument() + if filter_argument: + kwargs.setdefault("filter", filter_argument) + elif "filter" in kwargs and kwargs["filter"] is None: + del kwargs["filter"] super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 8bc98ffd..2f014f94 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import re from typing import Any, Dict, List, Tuple, Type, TypeVar, Union, get_type_hints @@ -7,14 +5,34 @@ from sqlalchemy.orm import Query, aliased import graphene -from graphene.types.inputobjecttype import InputObjectTypeOptions +from graphene.types.inputobjecttype import ( + InputObjectTypeContainer, + InputObjectTypeOptions, +) from graphene_sqlalchemy.utils import is_list +ObjectTypeFilterSelf = TypeVar( + "ObjectTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer +) + -class AbstractType: - """Dummy class for generic filters""" +def _get_functions_by_regex( + regex: str, subtract_regex: str, class_: Type +) -> list[Tuple[str, dict[str, Any]]]: + function_regex = re.compile(regex) - pass + matching_functions = [] + + # Search the entire class for functions matching the filter regex + for func in dir(class_): + func_attr = getattr(class_, func) + # Check if attribute is a function + if callable(func_attr) and function_regex.match(func): + # add function and attribute name to the list + matching_functions.append( + (re.sub(subtract_regex, "", func), func_attr.__annotations__) + ) + return matching_functions class ObjectTypeFilter(graphene.InputObjectType): @@ -22,20 +40,38 @@ class ObjectTypeFilter(graphene.InputObjectType): def __init_subclass_with_meta__( cls, filter_fields=None, model=None, _meta=None, **options ): + from graphene_sqlalchemy.converter import ( + convert_sqlalchemy_hybrid_property_type, + ) # Init meta options class if it doesn't exist already if not _meta: _meta = InputObjectTypeOptions(cls) - # TODO do this dynamically based off the field name, but also value type - filter_fields["and"] = graphene.InputField(graphene.List(cls)) - filter_fields["or"] = graphene.InputField(graphene.List(cls)) + logic_functions = _get_functions_by_regex(".+_logic$", "_logic$", cls) + new_filter_fields = {} + print(f"Generating Filter for {cls.__name__} with model {model} ") + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in logic_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + + replace_type_vars = {ObjectTypeFilterSelf: cls} + + field_type = convert_sqlalchemy_hybrid_property_type( + _annotations.get("val", str), replace_type_vars=replace_type_vars + ) + new_filter_fields.update({field_name: graphene.InputField(field_type)}) # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + if _meta.fields: _meta.fields.update(filter_fields) else: _meta.fields = filter_fields + _meta.fields.update(new_filter_fields) _meta.model = model @@ -45,8 +81,8 @@ def __init_subclass_with_meta__( def and_logic( cls, query, - filter_type: ObjectTypeFilter, - vals: graphene.List["ObjectTypeFilter"], + filter_type: "ObjectTypeFilter", + val: List[ObjectTypeFilterSelf], ): # # Get the model to join on the Filter Query # joined_model = filter_type._meta.model @@ -54,12 +90,12 @@ def and_logic( # joined_model_alias = aliased(joined_model) clauses = [] - for val in vals: + for value in val: # # Join the aliased model onto the query # query = query.join(model_field.of_type(joined_model_alias)) query, _clauses = filter_type.execute_filters( - query, val + query, value ) # , model_alias=joined_model_alias) clauses += _clauses @@ -69,8 +105,8 @@ def and_logic( def or_logic( cls, query, - filter_type: ObjectTypeFilter, - vals: graphene.List["ObjectTypeFilter"], + filter_type: "ObjectTypeFilter", + val: List[ObjectTypeFilterSelf], ): # # Get the model to join on the Filter Query # joined_model = filter_type._meta.model @@ -78,12 +114,12 @@ def or_logic( # joined_model_alias = aliased(joined_model) clauses = [] - for val in vals: + for value in val: # # Join the aliased model onto the query # query = query.join(model_field.of_type(joined_model_alias)) query, _clauses = filter_type.execute_filters( - query, val + query, value ) # , model_alias=joined_model_alias) clauses += _clauses @@ -91,14 +127,15 @@ def or_logic( @classmethod def execute_filters( - cls: Type[FieldFilter], query, filter_dict: Dict, model_alias=None + cls, query, filter_dict: Dict[str, Any], model_alias=None ) -> Tuple[Query, List[Any]]: model = cls._meta.model if model_alias: model = model_alias clauses = [] - for field, filt_dict in filter_dict.items(): + + for field, field_filters in filter_dict.items(): # Relationships are Dynamic, we need to resolve them fist # Maybe we can cache these dynamics to improve efficiency # Check with a profiler is required to determine necessity @@ -112,12 +149,12 @@ def execute_filters( # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) if field == "and": query, _clauses = cls.and_logic( - query, field_filter_type.of_type, filt_dict + query, field_filter_type.of_type, field_filters ) clauses.extend(_clauses) elif field == "or": query, _clauses = cls.or_logic( - query, field_filter_type.of_type, filt_dict + query, field_filter_type.of_type, field_filters ) clauses.extend(_clauses) else: @@ -133,109 +170,30 @@ def execute_filters( # Pass the joined query down to the next object type filter for processing query, _clauses = field_filter_type.execute_filters( - query, filt_dict, model_alias=joined_model_alias + query, field_filters, model_alias=joined_model_alias ) clauses.extend(_clauses) if issubclass(field_filter_type, RelationshipFilter): # TODO see above; not yet working relationship_prop = field_filter_type._meta.model query, _clauses = field_filter_type.execute_filters( - query, model_field, filt_dict, relationship_prop + query, model_field, field_filters, relationship_prop ) clauses.extend(_clauses) elif issubclass(field_filter_type, FieldFilter): query, _clauses = field_filter_type.execute_filters( - query, model_field, filt_dict + query, model_field, field_filters ) clauses.extend(_clauses) return query, clauses -class RelationshipFilter(graphene.InputObjectType): - @classmethod - def __init_subclass_with_meta__( - cls, object_type_filter=None, model=None, _meta=None, **options - ): - if not object_type_filter: - raise Exception("Relationship Filters must be specific to an object type") - # Init meta options class if it doesn't exist already - if not _meta: - _meta = InputObjectTypeOptions(cls) +ScalarFilterInputType = TypeVar("ScalarFilterInputType") - # get all filter functions - filter_function_regex = re.compile(".+_filter$") - - filter_functions = [] - # Search the entire class for functions matching the filter regex - for func in dir(cls): - func_attr = getattr(cls, func) - # Check if attribute is a function - if callable(func_attr) and filter_function_regex.match(func): - # add function and attribute name to the list - filter_functions.append( - (re.sub("_filter$", "", func), get_type_hints(func_attr)) - ) - - relationship_filters = {} - - # Generate Graphene Fields from the filter functions based on type hints - for field_name, _annotations in filter_functions: - assert ( - "val" in _annotations - ), "Each filter method must have a value field with valid type annotations" - # If type is generic, replace with actual type of filter class - if is_list(_annotations["val"]): - relationship_filters.update( - {field_name: graphene.InputField(graphene.List(object_type_filter))} - ) - else: - relationship_filters.update( - {field_name: graphene.InputField(object_type_filter)} - ) - - # Add all fields to the meta options. graphene.InputObjectType will take care of the rest - if _meta.fields: - _meta.fields.update(relationship_filters) - else: - _meta.fields = relationship_filters - - _meta.model = model - - super(RelationshipFilter, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) - - @classmethod - def contains_filter(cls, query, field, val: List[AbstractType]): - clauses = [] - for v in val: - query, clauses = v.execute_filters(query, dict(v)) - clauses += clauses - return clauses - - @classmethod - def contains_exactly_filter(cls, query, field, val: List[AbstractType]): - clauses = [] - for v in val: - query, clauses = v.execute_filters(query, dict(v)) - clauses += clauses - return clauses - - @classmethod - def execute_filters( - cls: Type[FieldFilter], query, field, filter_dict: Dict, relationship_prop - ) -> Tuple[Query, List[Any]]: - query, clauses = (query, []) - - for filt, val in filter_dict.items(): - clauses += getattr(cls, filt + "_filter")(query, field, val) - - return query.join(field), clauses - - -any_field_filter = TypeVar("any_field_filter", bound="FieldFilter") +class FieldFilterOptions(InputObjectTypeOptions): + graphene_type: Type = None class FieldFilter(graphene.InputObjectType): @@ -245,46 +203,32 @@ class FieldFilter(graphene.InputObjectType): The Dynamic fields will resolve to Meta.filtered_type""" @classmethod - def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): + def __init_subclass_with_meta__(cls, graphene_type=None, _meta=None, **options): + from .converter import convert_sqlalchemy_hybrid_property_type # get all filter functions - filter_function_regex = re.compile(".+_filter$") - filter_functions = [] - - # Search the entire class for functions matching the filter regex - for func in dir(cls): - func_attr = getattr(cls, func) - # Check if attribute is a function - if callable(func_attr) and filter_function_regex.match(func): - # add function and attribute name to the list - filter_functions.append( - (re.sub("_filter$", "", func), func_attr.__annotations__) - ) + filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls) # Init meta options class if it doesn't exist already if not _meta: - _meta = InputObjectTypeOptions(cls) + _meta = FieldFilterOptions(cls) + + if not _meta.graphene_type: + _meta.graphene_type = graphene_type new_filter_fields = {} - print(f"Generating Fields for {cls.__name__} with type {type} ") # Generate Graphene Fields from the filter functions based on type hints for field_name, _annotations in filter_functions: assert ( "val" in _annotations ), "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class - print(f"Field: {field_name} with annotation {_annotations['val']}") - if _annotations["val"] == "AbstractType": - # TODO Maybe there is an existing class or a more elegant way to solve this - # One option would be to only annotate non-abstract filters - new_filter_fields.update({field_name: graphene.InputField(type)}) - else: - # TODO this is a place holder, we need to convert the type of val to a valid graphene - # type that we can pass to the InputField. We could re-use converter.convert_hybrid_property_return_type - new_filter_fields.update( - {field_name: graphene.InputField(graphene.String)} - ) + replace_type_vars = {ScalarFilterInputType: _meta.graphene_type} + field_type = convert_sqlalchemy_hybrid_property_type( + _annotations.get("val", str), replace_type_vars=replace_type_vars + ) + new_filter_fields.update({field_name: graphene.InputField(field_type)}) # Add all fields to the meta options. graphene.InputbjectType will take care of the rest if _meta.fields: @@ -295,22 +239,30 @@ def __init_subclass_with_meta__(cls, type=None, _meta=None, **options): # Pass modified meta to the super class super(FieldFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) - # Abstract methods can be marked using AbstractType. See comment on the init method + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method @classmethod def eq_filter( - cls, query, field, val: AbstractType + cls, query, field, val: ScalarFilterInputType ) -> Union[Tuple[Query, Any], Any]: return field == val @classmethod def n_eq_filter( - cls, query, field, val: AbstractType + cls, query, field, val: ScalarFilterInputType ) -> Union[Tuple[Query, Any], Any]: return not_(field == val) + @classmethod + def in_filter(cls, query, field, val: List[ScalarFilterInputType]): + return field.in_(val) + + @classmethod + def not_in_filter(cls, query, field, val: List[ScalarFilterInputType]): + return field.not_int(val) + @classmethod def execute_filters( - cls: Type[FieldFilter], query, field, filter_dict: any_field_filter + cls, query, field, filter_dict: dict[str, any] ) -> Tuple[Query, List[Any]]: clauses = [] for filt, val in filter_dict.items(): @@ -324,12 +276,12 @@ def execute_filters( class StringFilter(FieldFilter): class Meta: - type = graphene.String + graphene_type = graphene.String class BooleanFilter(FieldFilter): class Meta: - type = graphene.Boolean + graphene_type = graphene.Boolean class OrderedFilter(FieldFilter): @@ -337,19 +289,19 @@ class Meta: abstract = True @classmethod - def gt_filter(cls, query, field, val: AbstractType) -> bool: + def gt_filter(cls, query, field, val: ScalarFilterInputType) -> bool: return field > val @classmethod - def gte_filter(cls, query, field, val: AbstractType) -> bool: + def gte_filter(cls, query, field, val: ScalarFilterInputType) -> bool: return field >= val @classmethod - def lt_filter(cls, query, field, val: AbstractType) -> bool: + def lt_filter(cls, query, field, val: ScalarFilterInputType) -> bool: return field < val @classmethod - def lte_filter(cls, query, field, val: AbstractType) -> bool: + def lte_filter(cls, query, field, val: ScalarFilterInputType) -> bool: return field <= val @@ -364,21 +316,104 @@ class FloatFilter(NumberFilter): """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" class Meta: - type = graphene.Float + graphene_type = graphene.Float class IntFilter(NumberFilter): class Meta: - type = graphene.Int + graphene_type = graphene.Int class DateFilter(OrderedFilter): """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" class Meta: - type = graphene.Date + graphene_type = graphene.Date class IdFilter(FieldFilter): class Meta: - type = graphene.ID + graphene_type = graphene.ID + + +class RelationshipFilter(graphene.InputObjectType): + @classmethod + def __init_subclass_with_meta__( + cls, object_type_filter=None, model=None, _meta=None, **options + ): + if not object_type_filter: + raise Exception("Relationship Filters must be specific to an object type") + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + + # get all filter functions + filter_function_regex = re.compile(".+_filter$") + + filter_functions = [] + + # Search the entire class for functions matching the filter regex + for func in dir(cls): + func_attr = getattr(cls, func) + # Check if attribute is a function + if callable(func_attr) and filter_function_regex.match(func): + # add function and attribute name to the list + filter_functions.append( + (re.sub("_filter$", "", func), get_type_hints(func_attr)) + ) + + relationship_filters = {} + + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in filter_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + if is_list(_annotations["val"]): + relationship_filters.update( + {field_name: graphene.InputField(graphene.List(object_type_filter))} + ) + else: + relationship_filters.update( + {field_name: graphene.InputField(object_type_filter)} + ) + + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + if _meta.fields: + _meta.fields.update(relationship_filters) + else: + _meta.fields = relationship_filters + + _meta.model = model + + super(RelationshipFilter, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + @classmethod + def contains_filter(cls, query, field, val: List[ScalarFilterInputType]): + clauses = [] + for v in val: + query, clauses = v.execute_filters(query, dict(v)) + clauses += clauses + return clauses + + @classmethod + def contains_exactly_filter(cls, query, field, val: List[ScalarFilterInputType]): + clauses = [] + for v in val: + query, clauses = v.execute_filters(query, dict(v)) + clauses += clauses + return clauses + + @classmethod + def execute_filters( + cls: Type[FieldFilter], query, field, filter_dict: Dict, relationship_prop + ) -> Tuple[Query, List[Any]]: + query, clauses = (query, []) + + for filt, val in filter_dict.items(): + clauses += getattr(cls, filt + "_filter")(query, field, val) + + return query.join(field), clauses diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 30e9e647..b6a4c201 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,15 +1,17 @@ from collections import defaultdict -from typing import List, Type +from typing import TYPE_CHECKING, List, Type from sqlalchemy.types import Enum as SQLAlchemyEnumType import graphene from graphene import Enum -from graphene_sqlalchemy.filters import ( - FieldFilter, - ObjectTypeFilter, - RelationshipFilter, -) + +if TYPE_CHECKING: + from graphene_sqlalchemy.filters import ( + FieldFilter, + ObjectTypeFilter, + RelationshipFilter, + ) class Registry(object): @@ -118,45 +120,59 @@ def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]) # Filter Scalar Fields of Object Types def register_filter_for_scalar_type( - self, scalar_type: Type[graphene.Scalar], filter_obj: Type[FieldFilter] + self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] ): + from .filters import FieldFilter + if not isinstance(scalar_type, type(graphene.Scalar)): raise TypeError("Expected Scalar, but got: {!r}".format(scalar_type)) - if not isinstance(filter_obj, type(FieldFilter)): + if not issubclass(filter_obj, FieldFilter): raise TypeError("Expected ScalarFilter, but got: {!r}".format(filter_obj)) self._registry_scalar_filters[scalar_type] = filter_obj def get_filter_for_scalar_type( self, scalar_type: Type[graphene.Scalar] - ) -> Type[FieldFilter]: + ) -> Type["FieldFilter"]: + from .filters import FieldFilter - return self._registry_scalar_filters.get(scalar_type) + filter_type = self._registry_scalar_filters.get(scalar_type) + if not filter_type: + return FieldFilter.create_type( + f"Default{scalar_type.__name__}ScalarFilter", graphene_type=scalar_type + ) + return filter_type # TODO register enums automatically def register_filter_for_enum_type( - self, enum_type: Type[graphene.Enum], filter_obj: Type[FieldFilter] + self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] ): + from .filters import FieldFilter + if not isinstance(enum_type, type(graphene.Enum)): raise TypeError("Expected Enum, but got: {!r}".format(enum_type)) - if not isinstance(filter_obj, type(FieldFilter)): + if not issubclass(filter_obj, FieldFilter): raise TypeError("Expected FieldFilter, but got: {!r}".format(filter_obj)) self._registry_scalar_filters[enum_type] = filter_obj def get_filter_for_enum_type( self, enum_type: Type[graphene.Enum] - ) -> Type[FieldFilter]: + ) -> Type["FieldFilter"]: return self._registry_enum_type_filters.get(enum_type) # Filter Object Types def register_filter_for_object_type( - self, object_type: Type[graphene.ObjectType], filter_obj: Type[ObjectTypeFilter] + self, + object_type: Type[graphene.ObjectType], + filter_obj: Type["ObjectTypeFilter"], ): + from .filters import ObjectTypeFilter + if not isinstance(object_type, type(graphene.ObjectType)): raise TypeError("Expected Object Type, but got: {!r}".format(object_type)) - if not isinstance(filter_obj, type(FieldFilter)): + if not issubclass(filter_obj, ObjectTypeFilter): raise TypeError( "Expected ObjectTypeFilter, but got: {!r}".format(filter_obj) ) @@ -167,12 +183,14 @@ def get_filter_for_object_type(self, object_type: Type[graphene.ObjectType]): # Filter Relationships between object types def register_relationship_filter_for_object_type( - self, object_type: graphene.ObjectType, filter_obj: RelationshipFilter + self, object_type: graphene.ObjectType, filter_obj: Type["RelationshipFilter"] ): + from .filters import RelationshipFilter + if not isinstance(object_type, type(graphene.ObjectType)): raise TypeError("Expected Object Type, but got: {!r}".format(object_type)) - if not isinstance(filter_obj, type(FieldFilter)): + if not issubclass(filter_obj, RelationshipFilter): raise TypeError( "Expected RelationshipFilter, but got: {!r}".format(filter_obj) ) @@ -180,7 +198,7 @@ def register_relationship_filter_for_object_type( def get_relationship_filter_for_object_type( self, object_type: Type[graphene.ObjectType] - ) -> RelationshipFilter: + ) -> "RelationshipFilter": return self._registry_relationship_filters.get(object_type) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 78a18ec4..8b647db0 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -1,11 +1,12 @@ import pytest +from sqlalchemy.sql.operators import is_ import graphene from graphene import Connection, relay from ..fields import SQLAlchemyConnectionField from ..filters import FloatFilter -from ..types import SQLAlchemyObjectType +from ..types import ORMField, SQLAlchemyObjectType from .models import Article, Editor, HairKind, Image, Pet, Reporter, Tag from .utils import to_std_dicts @@ -15,32 +16,39 @@ def add_test_data(session): - reporter = Reporter( - first_name='John', last_name='Doe', favorite_pet_kind='cat') + + reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) - pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) + + pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT) pet.reporter = reporter session.add(pet) - pet = Pet(name='Snoopy', pet_kind='dog', hair_kind=HairKind.SHORT) + + pet = Pet(name="Snoopy", pet_kind="dog", hair_kind=HairKind.SHORT, legs=3) pet.reporter = reporter session.add(pet) - reporter = Reporter( - first_name='John', last_name='Woe', favorite_pet_kind='cat') + + reporter = Reporter(first_name="John", last_name="Woe", favorite_pet_kind="cat") session.add(reporter) - article = Article(headline='Hi!') + + article = Article(headline="Hi!") article.reporter = reporter session.add(article) - article = Article(headline='Hello!') + + article = Article(headline="Hello!") article.reporter = reporter session.add(article) - reporter = Reporter( - first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + + reporter = Reporter(first_name="Jane", last_name="Roe", favorite_pet_kind="dog") session.add(reporter) - pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG) pet.reporter = reporter session.add(pet) + editor = Editor(name="Jack") session.add(editor) + session.commit() @@ -116,25 +124,23 @@ def test_filter_simple(session): "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test a custom filter type -@pytest.mark.xfail def test_filter_custom_type(session): add_test_data(session) - Query = create_schema(session) class MathFilter(FloatFilter): class Meta: - type = graphene.Float + graphene_type = graphene.Float @classmethod - def divisible_by(cls, query, field, val: graphene.Float) -> bool: - return field % val == 0. + def divisible_by_filter(cls, query, field, val: int) -> bool: + return is_(field % val, 0) class PetType(SQLAlchemyObjectType): class Meta: @@ -143,28 +149,31 @@ class Meta: interfaces = (relay.Node,) connection_class = Connection - class ExtraQuery: - pets = SQLAlchemyConnectionField( - PetType.connection, filter=MathFilter() - ) + legs = ORMField(filter_type=MathFilter) - class CustomQuery(Query, ExtraQuery): - pass + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) query = """ query { pets (filter: { legs: {divisibleBy: 2} }) { - name + edges { + node { + name + } + } } } """ expected = { - "pets": [{"name": "Garfield"}, {"name": "Lassie"}], + "pets": { + "edges": [{"node": {"name": "Garfield"}}, {"node": {"name": "Lassie"}}] + }, } - schema = graphene.Schema(query=CustomQuery) - result = schema.execute(query, context_value={'session': session}) + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -172,7 +181,7 @@ class CustomQuery(Query, ExtraQuery): # Test a 1:1 relationship def test_filter_relationship_one_to_one(session): - article = Article(headline='Hi!') + article = Article(headline="Hi!") image = Image(external_id=1, description="A beautiful image.") article.image = image session.add(article) @@ -198,7 +207,7 @@ def test_filter_relationship_one_to_one(session): "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -229,7 +238,7 @@ def test_filter_relationship_one_to_many(session): "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -258,7 +267,7 @@ def test_filter_relationship_one_to_many(session): "reporters": {"edges": [{"node": {"firstName": "John", "lastName": "Woe"}}]} } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -267,8 +276,8 @@ def test_filter_relationship_one_to_many(session): # Test a n:m relationship @pytest.mark.xfail def test_filter_relationship_many_to_many(session): - article1 = Article(headline='Article! Look!') - article2 = Article(headline='Woah! Another!') + article1 = Article(headline="Article! Look!") + article2 = Article(headline="Woah! Another!") tag1 = Tag(name="sensational") tag2 = Tag(name="eye-grabbing") article1.tags.append(tag1) @@ -302,7 +311,7 @@ def test_filter_relationship_many_to_many(session): ], } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -347,7 +356,7 @@ def test_filter_relationship_many_to_many(session): "articles": [{"headline": "Article! Look!"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -377,12 +386,14 @@ def test_filter_logic_and(session): } """ expected = { - "reporters": {"edges": [{"node": {"lastName": "Doe"}}, {"node": {"lastName": "Woe"}}]}, + "reporters": { + "edges": [{"node": {"lastName": "Doe"}}, {"node": {"lastName": "Woe"}}] + }, } schema = graphene.Schema(query=Query) - with open('schema.gql', 'w') as fp: + with open("schema.gql", "w") as fp: fp.write(str(schema)) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -421,7 +432,7 @@ def test_filter_logic_or(session): } } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -463,7 +474,7 @@ def test_filter_logic_and_or(session): } } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 77b82333..01479701 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,6 +1,6 @@ import warnings from collections import OrderedDict -from typing import Type, Union +from typing import Optional, Type, Union import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property @@ -8,9 +8,10 @@ from sqlalchemy.orm.exc import NoResultFound import graphene -from graphene import Field +from graphene import Field, InputField from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions +from graphene.types.unmountedtype import UnmountedType from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType @@ -48,6 +49,8 @@ def __init__( description=None, deprecation_reason=None, batching=None, + create_filter=None, + filter_type: Optional[Type] = None, _creation_counter=None, **field_kwargs, ): @@ -86,6 +89,12 @@ class Meta: Same behavior as in graphene.Field. Defaults to None. :param bool batching: Toggle SQL batching. Defaults to None, that is `SQLAlchemyObjectType.meta.batching`. + :param bool create_filter: + Create a filter for this field. Defaults to True. + :param Type filter_type: + Override for the filter of this field with a custom filter type. + Default behavior is to get a matching filter type for this field from the registry. + Create_filter needs to be true :param int _creation_counter: Same behavior as in graphene.Field. """ @@ -97,6 +106,8 @@ class Meta: "required": required, "description": description, "deprecation_reason": deprecation_reason, + "create_filter": create_filter, + "filter_type": filter_type, "batching": batching, } common_kwargs = { @@ -126,9 +137,20 @@ def get_or_create_relationship_filter( def filter_field_from_type_field( - field: Union[graphene.Field, graphene.Dynamic], registry: Registry -) -> Union[graphene.InputField, graphene.Dynamic]: - if isinstance(field.type, graphene.List): + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + registry: Registry, + filter_type: Optional[Type], +) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + # If a custom filter type was set for this field, use it here + print(field) + if filter_type: + return graphene.InputField(filter_type) + # fixme one test case fails where, find out why + if issubclass(type(field), graphene.Scalar): + filter_class = registry.get_filter_for_scalar_type(type(field)) + return graphene.InputField(filter_class) + + elif isinstance(field.type, graphene.List): pass elif isinstance(field.type, graphene.Dynamic): pass @@ -146,11 +168,14 @@ def resolve_dynamic(): type_, UnsortedSQLAlchemyConnectionField ): inner_type = get_nullable_type(type_.type.Edge.node._type) - return graphene.InputField( - get_or_create_relationship_filter(inner_type, registry) - ) + reg_res = get_or_create_relationship_filter(inner_type, registry) + if not reg_res: + print("filter class was none!!!") + print(type_) + return graphene.InputField(reg_res) elif isinstance(type_, Field): reg_res = registry.get_filter_for_object_type(type_.type) + return graphene.InputField(reg_res) else: warnings.warn(f"Unexpected Dynamic Type: {type_}") # Investigate @@ -173,13 +198,14 @@ def resolve_dynamic(): ) -def construct_fields( +def construct_fields_and_filters( obj_type, model, registry, only_fields, exclude_fields, batching, + create_filters, connection_field_factory, ): """ @@ -195,6 +221,7 @@ def construct_fields( :param tuple[string] only_fields: :param tuple[string] exclude_fields: :param bool batching: + :param bool create_filters: Enable filter generation for this type :param function|None connection_field_factory: :rtype: OrderedDict[str, graphene.Field] """ @@ -250,7 +277,12 @@ def construct_fields( # Build all the field dictionary fields = OrderedDict() + filters = OrderedDict() for orm_field_name, orm_field in orm_fields.items(): + filtering_enabled_for_field = orm_field.kwargs.pop( + "create_filter", create_filters + ) + filter_type = orm_field.kwargs.pop("filter_type", None) attr_name = orm_field.kwargs.pop("model_attr") attr = all_model_attrs[attr_name] resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver( @@ -286,8 +318,12 @@ def construct_fields( registry.register_orm_field(obj_type, orm_field_name, attr) fields[orm_field_name] = field + if filtering_enabled_for_field: + filters[orm_field_name] = filter_field_from_type_field( + field, registry, filter_type + ) - return fields + return fields, filters class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): @@ -351,16 +387,19 @@ def __init_subclass_with_meta__( "The options 'only_fields' and 'exclude_fields' cannot be both set on the same type." ) + fields, filters = construct_fields_and_filters( + obj_type=cls, + model=model, + registry=registry, + only_fields=only_fields, + exclude_fields=exclude_fields, + batching=batching, + create_filters=True, + connection_field_factory=connection_field_factory, + ) + sqla_fields = yank_fields_from_attrs( - construct_fields( - obj_type=cls, - model=model, - registry=registry, - only_fields=only_fields, - exclude_fields=exclude_fields, - batching=batching, - connection_field_factory=connection_field_factory, - ), + fields, _as=Field, sort=False, ) @@ -397,16 +436,14 @@ def __init_subclass_with_meta__( # Save Generated filter class in Meta Class if not _meta.filter_class: - filters = OrderedDict() # Map graphene fields to filters # TODO we might need to pass the ORMFields containing the SQLAlchemy models # to the scalar filters here (to generate expressions from the model) - for fieldname, field in sqla_fields.items(): - field_filter = filter_field_from_type_field(field, registry) - if field_filter: - filters[fieldname] = field_filter + + filter_fields = yank_fields_from_attrs(filters, _as=InputField, sort=False) + _meta.filter_class = ObjectTypeFilter.create_type( - f"{cls.__name__}Filter", filter_fields=filters, model=model + f"{cls.__name__}Filter", filter_fields=filter_fields, model=model ) registry.register_filter_for_object_type(cls, _meta.filter_class) From 4c7efb89f0bb1839aa06cd2469cdd56c25ab57be Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 1 Dec 2022 16:58:51 +0100 Subject: [PATCH 27/81] chore: add missing workflow setting --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 428eca1d..1013d16b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,12 +1,12 @@ name: Tests -on: +on: push: branches: - 'master' pull_request: branches: - - '*' + - '**' jobs: test: runs-on: ubuntu-latest From c16ccf64ba3d90c98c2c5f301762a1588e6cf670 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 1 Dec 2022 17:00:45 +0100 Subject: [PATCH 28/81] chore: change workflows back to the original setting --- .github/workflows/tests.yml | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1013d16b..de78190d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,12 +1,7 @@ name: Tests -on: - push: - branches: - - 'master' - pull_request: - branches: - - '**' +on: [push, pull_request] + jobs: test: runs-on: ubuntu-latest From 9b863325320b2b42808e38d357c6a202f14cac69 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 1 Dec 2022 17:06:24 +0100 Subject: [PATCH 29/81] fix: 3.7 type hints --- graphene_sqlalchemy/filters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 2f014f94..90e35564 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -18,7 +18,7 @@ def _get_functions_by_regex( regex: str, subtract_regex: str, class_: Type -) -> list[Tuple[str, dict[str, Any]]]: +) -> List[Tuple[str, Dict[str, Any]]]: function_regex = re.compile(regex) matching_functions = [] @@ -262,7 +262,7 @@ def not_in_filter(cls, query, field, val: List[ScalarFilterInputType]): @classmethod def execute_filters( - cls, query, field, filter_dict: dict[str, any] + cls, query, field, filter_dict: Dict[str, any] ) -> Tuple[Query, List[Any]]: clauses = [] for filt, val in filter_dict.items(): From ae28fe630798619b110e41e45f22fb652036eabd Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 19 Dec 2022 11:15:40 +0100 Subject: [PATCH 30/81] fix: relationship contains filters for n:m working Signed-off-by: Erik Wrede --- graphene_sqlalchemy/fields.py | 6 ++-- graphene_sqlalchemy/filters.py | 60 ++++++++++++++++++++++++++-------- graphene_sqlalchemy/types.py | 17 +++++++++- 3 files changed, 67 insertions(+), 16 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 8c5a275f..d9523c45 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -96,9 +96,11 @@ def get_query(cls, model, info, sort=None, filter=None, **args): assert isinstance(filter, dict) filter_type: ObjectTypeFilter = type(filter) query, clauses = filter_type.execute_filters(query, filter) - + print("Query before filter") + print(query) + print([str(cla) for cla in clauses]) query = query.filter(*clauses) - + print(query) return query @classmethod diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 90e35564..923b622c 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -60,7 +60,6 @@ def __init_subclass_with_meta__( # If type is generic, replace with actual type of filter class replace_type_vars = {ObjectTypeFilterSelf: cls} - field_type = convert_sqlalchemy_hybrid_property_type( _annotations.get("val", str), replace_type_vars=replace_type_vars ) @@ -73,6 +72,9 @@ def __init_subclass_with_meta__( _meta.fields = filter_fields _meta.fields.update(new_filter_fields) + for field in _meta.fields: + print(f"Added field {field} of type {_meta.fields[field].type}") + _meta.model = model super(ObjectTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) @@ -164,10 +166,15 @@ def execute_filters( joined_model = field_filter_type._meta.model # Always alias the model joined_model_alias = aliased(joined_model) - # Join the aliased model onto the query query = query.join(model_field.of_type(joined_model_alias)) + if model_alias: + print("=======================") + print( + f"joining model {joined_model} on {model_alias} with alias {joined_model_alias}" + ) + print(str(query)) # Pass the joined query down to the next object type filter for processing query, _clauses = field_filter_type.execute_filters( query, field_filters, model_alias=joined_model_alias @@ -176,6 +183,13 @@ def execute_filters( if issubclass(field_filter_type, RelationshipFilter): # TODO see above; not yet working relationship_prop = field_filter_type._meta.model + # Always alias the model + # joined_model_alias = aliased(relationship_prop) + + # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + # todo should we use selectinload here instead of join for large lists? + query, _clauses = field_filter_type.execute_filters( query, model_field, field_filters, relationship_prop ) @@ -184,6 +198,7 @@ def execute_filters( query, _clauses = field_filter_type.execute_filters( query, model_field, field_filters ) + print([str(cla) for cla in _clauses]) clauses.extend(_clauses) return query, clauses @@ -386,26 +401,42 @@ def __init_subclass_with_meta__( _meta.fields = relationship_filters _meta.model = model - + _meta.object_type_filter = object_type_filter super(RelationshipFilter, cls).__init_subclass_with_meta__( _meta=_meta, **options ) @classmethod - def contains_filter(cls, query, field, val: List[ScalarFilterInputType]): + def contains_filter( + cls, query, field, relationship_prop, val: List[ScalarFilterInputType] + ): clauses = [] for v in val: - query, clauses = v.execute_filters(query, dict(v)) - clauses += clauses - return clauses + print("executing contains filter", v) + # Always alias the model + joined_model_alias = aliased(relationship_prop) + + # Join the aliased model onto the query + query = query.join(field.of_type(joined_model_alias)) + print("Joined model", relationship_prop) + print(query) + # pass the alias so group can join group + query, _clauses = cls._meta.object_type_filter.execute_filters( + query, v, model_alias=joined_model_alias + ) + # print(query) + clauses += _clauses + return query, clauses @classmethod - def contains_exactly_filter(cls, query, field, val: List[ScalarFilterInputType]): + def contains_exactly_filter( + cls, query, field, relationship_prop, val: List[ScalarFilterInputType] + ): clauses = [] for v in val: - query, clauses = v.execute_filters(query, dict(v)) - clauses += clauses - return clauses + query, _clauses = v.execute_filters(query, dict(v)) + clauses += _clauses + return query, clauses @classmethod def execute_filters( @@ -414,6 +445,9 @@ def execute_filters( query, clauses = (query, []) for filt, val in filter_dict.items(): - clauses += getattr(cls, filt + "_filter")(query, field, val) + query, _clauses = getattr(cls, filt + "_filter")( + query, field, relationship_prop, val + ) + clauses += _clauses - return query.join(field), clauses + return query, clauses diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 01479701..376b025c 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -142,7 +142,7 @@ def filter_field_from_type_field( filter_type: Optional[Type], ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # If a custom filter type was set for this field, use it here - print(field) + # print(field) if filter_type: return graphene.InputField(filter_type) # fixme one test case fails where, find out why @@ -151,6 +151,10 @@ def filter_field_from_type_field( return graphene.InputField(filter_class) elif isinstance(field.type, graphene.List): + print("got field with list type") + pass + elif isinstance(field, graphene.List): + print("Got list") pass elif isinstance(field.type, graphene.Dynamic): pass @@ -174,6 +178,13 @@ def resolve_dynamic(): print(type_) return graphene.InputField(reg_res) elif isinstance(type_, Field): + if isinstance(type_.type, graphene.List): + inner_type = get_nullable_type(type_.type.of_type) + reg_res = get_or_create_relationship_filter(inner_type, registry) + if not reg_res: + print("filter class was none!!!") + print(type_) + return graphene.InputField(reg_res) reg_res = registry.get_filter_for_object_type(type_.type) return graphene.InputField(reg_res) @@ -185,6 +196,10 @@ def resolve_dynamic(): elif isinstance(field, graphene.Field): type_ = get_nullable_type(field.type) + # Field might be a SQLAlchemyObjectType, due to hybrid properties + if issubclass(type_, SQLAlchemyObjectType): + filter_class = registry.get_filter_for_object_type(type_) + return graphene.InputField(filter_class) filter_class = registry.get_filter_for_scalar_type(type_) if not filter_class: warnings.warn( From d3345936d76f8121ff1811b156b7bb8fe671daf9 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 19 Dec 2022 12:59:42 -0500 Subject: [PATCH 31/81] test: add nested filter test and reverse relationship tests --- .gitignore | 3 + graphene_sqlalchemy/filters.py | 32 +- graphene_sqlalchemy/tests/test_filters.py | 405 +++++++++++++++++++--- graphene_sqlalchemy/types.py | 1 - 4 files changed, 382 insertions(+), 59 deletions(-) diff --git a/.gitignore b/.gitignore index c4a735fe..6250579b 100644 --- a/.gitignore +++ b/.gitignore @@ -70,5 +70,8 @@ target/ *.sqlite3 .vscode +# Schema +*.gql + # mypy cache .mypy_cache/ diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 923b622c..193a71b0 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, List, Tuple, Type, TypeVar, Union, get_type_hints +from typing import Any, Dict, List, Tuple, Type, TypeVar, Union from sqlalchemy import and_, not_, or_ from sqlalchemy.orm import Query, aliased @@ -182,6 +182,8 @@ def execute_filters( clauses.extend(_clauses) if issubclass(field_filter_type, RelationshipFilter): # TODO see above; not yet working + print("ObjectType execute_filters: ", query, field_filters) + print(model_field, field_filter_type) relationship_prop = field_filter_type._meta.model # Always alias the model # joined_model_alias = aliased(relationship_prop) @@ -363,19 +365,7 @@ def __init_subclass_with_meta__( _meta = InputObjectTypeOptions(cls) # get all filter functions - filter_function_regex = re.compile(".+_filter$") - - filter_functions = [] - - # Search the entire class for functions matching the filter regex - for func in dir(cls): - func_attr = getattr(cls, func) - # Check if attribute is a function - if callable(func_attr) and filter_function_regex.match(func): - # add function and attribute name to the list - filter_functions.append( - (re.sub("_filter$", "", func), get_type_hints(func_attr)) - ) + filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls) relationship_filters = {} @@ -424,7 +414,6 @@ def contains_filter( query, _clauses = cls._meta.object_type_filter.execute_filters( query, v, model_alias=joined_model_alias ) - # print(query) clauses += _clauses return query, clauses @@ -433,9 +422,18 @@ def contains_exactly_filter( cls, query, field, relationship_prop, val: List[ScalarFilterInputType] ): clauses = [] + print("Contains exactly: ", query, val) + # query, clauses = v.execute_filters(query, all_(val(items))) + # vals = [] for v in val: - query, _clauses = v.execute_filters(query, dict(v)) - clauses += _clauses + # vals.append(dict(v)) + # print(dict(v)) + query, clauses = v.execute_filters(query, dict(v)) + clauses += clauses + # query, clauses = v.execute_filters(query, all_(vals)) + # clauses = [or_(*clauses)] + # print(query) + # print(clauses) return query, clauses @classmethod diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 8b647db0..b191e1f6 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -7,7 +7,7 @@ from ..fields import SQLAlchemyConnectionField from ..filters import FloatFilter from ..types import ORMField, SQLAlchemyObjectType -from .models import Article, Editor, HairKind, Image, Pet, Reporter, Tag +from .models import Article, Editor, HairKind, Image, Pet, Reader, Reporter, Tag from .utils import to_std_dicts # TODO test that generated schema is correct for all examples with: @@ -16,7 +16,6 @@ def add_test_data(session): - reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) @@ -52,6 +51,37 @@ def add_test_data(session): session.commit() +def add_n2m_test_data(session): + # create objects + reader1 = Reader(name="Ada") + reader2 = Reader(name="Bip") + article1 = Article(headline="Article! Look!") + article2 = Article(headline="Woah! Another!") + tag1 = Tag(name="sensational") + tag2 = Tag(name="eye-grabbing") + image1 = Image(description="article 1") + image2 = Image(description="article 2") + + # set relationships + article1.tags = [tag1] + article2.tags = [tag1, tag2] + article1.image = image1 + article2.image = image2 + reader1.articles = [article1] + reader2.articles = [article1, article2] + + # save + session.add(image1) + session.add(image2) + session.add(tag1) + session.add(tag2) + session.add(article1) + session.add(article2) + session.add(reader1) + session.add(reader2) + session.commit() + + def create_schema(session): class ArticleType(SQLAlchemyObjectType): class Meta: @@ -74,6 +104,13 @@ class Meta: interfaces = (relay.Node,) connection_class = Connection + class ReaderType(SQLAlchemyObjectType): + class Meta: + model = Reader + name = "Reader" + interfaces = (relay.Node,) + connection_class = Connection + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -81,6 +118,13 @@ class Meta: interfaces = (relay.Node,) connection_class = Connection + class TagType(SQLAlchemyObjectType): + class Meta: + model = Tag + name = "Tag" + interfaces = (relay.Node,) + connection_class = Connection + class Query(graphene.ObjectType): node = relay.Node.Field() # # TODO how to create filterable singular field? @@ -88,8 +132,10 @@ class Query(graphene.ObjectType): articles = SQLAlchemyConnectionField(ArticleType.connection) # image = graphene.Field(ImageType) images = SQLAlchemyConnectionField(ImageType.connection) + readers = SQLAlchemyConnectionField(ReaderType.connection) # reporter = graphene.Field(ReporterType) reporters = SQLAlchemyConnectionField(ReporterType.connection) + tags = SQLAlchemyConnectionField(TagType.connection) # def resolve_article(self, _info): # return session.query(Article).first() @@ -214,6 +260,8 @@ def test_filter_relationship_one_to_one(session): # Test a 1:n relationship +# TODO implement containsExactly +@pytest.mark.xfail def test_filter_relationship_one_to_many(session): add_test_data(session) Query = create_schema(session) @@ -273,42 +321,98 @@ def test_filter_relationship_one_to_many(session): assert result == expected -# Test a n:m relationship -@pytest.mark.xfail -def test_filter_relationship_many_to_many(session): - article1 = Article(headline="Article! Look!") - article2 = Article(headline="Woah! Another!") - tag1 = Tag(name="sensational") - tag2 = Tag(name="eye-grabbing") - article1.tags.append(tag1) - article2.tags.append([tag1, tag2]) - session.add(article1) - session.add(article2) - session.add(tag1) - session.add(tag2) - session.commit() - +# Test n:m relationship contains +def test_filter_relationship_many_to_many_contains(session): + add_n2m_test_data(session) Query = create_schema(session) - # test contains + # test contains 1 query = """ - query { - articles (filter: { - tags: { - contains: { - name: { in: ["sensational", "eye-grabbing"] } + query { + articles (filter: { + tags: { + contains: [ + { name: { in: ["sensational", "eye-grabbing"] } }, + ] + } + }) { + edges { + node { + headline + } + } } } - }) { - headline - } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Article! Look!"}}, + {"node": {"headline": "Woah! Another!"}}, + ], + }, } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test contains 2 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { headline: { eq: "Article! Look!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } """ expected = { - "articles": [ - {"headline": "Woah! Another!"}, - {"headline": "Article! Look!"}, - ], + "tags": { + "edges": [ + {"node": {"name": "sensational"}}, + ], + }, } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) @@ -316,31 +420,43 @@ def test_filter_relationship_many_to_many(session): result = to_std_dicts(result.data) assert result == expected - # test containsAllOf + +# Test n:m relationship containsExactly +# TODO implement containsExactly +@pytest.mark.xfail +def test_filter_relationship_many_to_many_contains_exactly(session): + add_n2m_test_data(session) + Query = create_schema(session) + + # test containsExactly 1 query = """ query { articles (filter: { tags: { - containsAllOf: [ - { tag: { name: { eq: "eye-grabbing" } } }, - { tag: { name: { eq: "sensational" } } }, + containsExactly: [ + { name: { eq: "eye-grabbing" } }, + { name: { eq: "sensational" } }, ] } }) { - headline + edges { + node { + headline + } + } } } """ expected = { - "articles": [{"headline": "Woah! Another!"}], + "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected - # test containsExactly + # test containsExactly 2 query = """ query { articles (filter: { @@ -353,7 +469,216 @@ def test_filter_relationship_many_to_many(session): } """ expected = { - "articles": [{"headline": "Article! Look!"}], + "articles": { + "edges": [ + {"node": {"headline": "Article! Look!"}}, + {"node": {"headline": "Woah! Another!"}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test reverse + query = """ + query { + tags (filter: { + articles: { + containsExactly: [ + { headline: { eq: "Article! Look!" } }, + { headline: { eq: "Woah! Another!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": { + "edges": [ + {"node": {"name": "eye-grabbing"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +# Test n:m relationship both contains and containsExactly +# TODO implement containsExactly +@pytest.mark.xfail +def test_filter_relationship_many_to_many_contains_and_contains_exactly(session): + add_n2m_test_data(session) + Query = create_schema(session) + + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + containsExactly: [ + { name: { eq: "eye-grabbing" } }, + { name: { eq: "sensational" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +# Test n:m nested relationship +# TODO add containsExactly +def test_filter_relationship_many_to_many_nested(session): + add_n2m_test_data(session) + Query = create_schema(session) + + # test readers->articles relationship + query = """ + query { + readers (filter: { + articles: { + contains: [ + { headline: { eq: "Woah! Another!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test nested readers->articles->tags + query = """ + query { + readers (filter: { + articles: { + contains: [ + { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test nested reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { + readers: { + contains: [ + { name: { eq: "Ada" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": {"edges": [{"node": {"name": "sensational"}}]}, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test filter on both levels of nesting + query = """ + query { + readers (filter: { + articles: { + contains: [ + { headline: { eq: "Woah! Another!" } }, + { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) @@ -391,8 +716,6 @@ def test_filter_logic_and(session): }, } schema = graphene.Schema(query=Query) - with open("schema.gql", "w") as fp: - fp.write(str(schema)) result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 376b025c..f58a3302 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -142,7 +142,6 @@ def filter_field_from_type_field( filter_type: Optional[Type], ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # If a custom filter type was set for this field, use it here - # print(field) if filter_type: return graphene.InputField(filter_type) # fixme one test case fails where, find out why From a4621e277f3e0d8229142b044b7892a93df042ac Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 19 Dec 2022 15:38:12 -0500 Subject: [PATCH 32/81] fix: chain or/and in contains --- graphene_sqlalchemy/filters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 193a71b0..98fb831c 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -414,8 +414,8 @@ def contains_filter( query, _clauses = cls._meta.object_type_filter.execute_filters( query, v, model_alias=joined_model_alias ) - clauses += _clauses - return query, clauses + clauses.append(and_(*_clauses)) + return query, [or_(*clauses)] @classmethod def contains_exactly_filter( @@ -428,8 +428,8 @@ def contains_exactly_filter( for v in val: # vals.append(dict(v)) # print(dict(v)) - query, clauses = v.execute_filters(query, dict(v)) - clauses += clauses + query, _clauses = v.execute_filters(query, dict(v)) + clauses += _clauses # query, clauses = v.execute_filters(query, all_(vals)) # clauses = [or_(*clauses)] # print(query) From 0f34c054907c139f38cfb05ee52a242c2a8570fa Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Sat, 31 Dec 2022 21:03:29 -0500 Subject: [PATCH 33/81] test: try filtering ids on each containsExactly subquery --- graphene_sqlalchemy/filters.py | 60 ++++++++++++++++------- graphene_sqlalchemy/tests/test_filters.py | 13 +---- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 98fb831c..56c85958 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,8 +1,8 @@ import re from typing import Any, Dict, List, Tuple, Type, TypeVar, Union -from sqlalchemy import and_, not_, or_ -from sqlalchemy.orm import Query, aliased +from sqlalchemy import and_, func, not_, or_ +from sqlalchemy.orm import Query, aliased, selectinload import graphene from graphene.types.inputobjecttype import ( @@ -182,8 +182,6 @@ def execute_filters( clauses.extend(_clauses) if issubclass(field_filter_type, RelationshipFilter): # TODO see above; not yet working - print("ObjectType execute_filters: ", query, field_filters) - print(model_field, field_filter_type) relationship_prop = field_filter_type._meta.model # Always alias the model # joined_model_alias = aliased(relationship_prop) @@ -421,20 +419,48 @@ def contains_filter( def contains_exactly_filter( cls, query, field, relationship_prop, val: List[ScalarFilterInputType] ): - clauses = [] - print("Contains exactly: ", query, val) - # query, clauses = v.execute_filters(query, all_(val(items))) - # vals = [] + print("Contains exactly called: ", query, val) + session = query.session + child_model_ids = [] for v in val: - # vals.append(dict(v)) - # print(dict(v)) - query, _clauses = v.execute_filters(query, dict(v)) - clauses += _clauses - # query, clauses = v.execute_filters(query, all_(vals)) - # clauses = [or_(*clauses)] - # print(query) - # print(clauses) - return query, clauses + print("Contains exactly loop: ", v) + + # Always alias the model + joined_model_alias = aliased(relationship_prop) + + # Store list of child IDs to filter per attribute + # attr = field.of_type(joined_model_alias) + # if not child_model_ids.get(str(attr), None): + # child_model_ids[str(attr)] = [] + + subquery = session.query(joined_model_alias.id) + subquery, _clauses = cls._meta.object_type_filter.execute_filters( + subquery, v, model_alias=joined_model_alias + ) + subquery_ids = [s_id[0] for s_id in subquery.filter(and_(*_clauses)).all()] + + child_model_ids.extend(subquery_ids) + + # Join the aliased model onto the query + query = query.join(field.of_type(joined_model_alias)) + query = ( + query.filter(relationship_prop.id.in_(subquery_ids)) + .group_by(joined_model_alias) + .having(func.count(joined_model_alias.id) == len(subquery_ids)) + ) + + # Define new query? + # query = session.query(cls) + + # Construct clauses from child_model_ids + # from .tests.models import Reporter + import pdb + + pdb.set_trace() + # query = query.filter(relationship_prop.id.in_(child_model_ids)).group_by(Reporter).having(func.count(relationship_prop.id)==len(child_model_ids)) + # query = query.filter(relationship_prop.id.in_(child_model_ids)).group_by(relationship_prop).having(func.count(field)==len(child_model_ids)) + + return query, [] @classmethod def execute_filters( diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index b191e1f6..590fbe19 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -260,8 +260,6 @@ def test_filter_relationship_one_to_one(session): # Test a 1:n relationship -# TODO implement containsExactly -@pytest.mark.xfail def test_filter_relationship_one_to_many(session): add_test_data(session) Query = create_schema(session) @@ -422,8 +420,6 @@ def test_filter_relationship_many_to_many_contains(session): # Test n:m relationship containsExactly -# TODO implement containsExactly -@pytest.mark.xfail def test_filter_relationship_many_to_many_contains_exactly(session): add_n2m_test_data(session) Query = create_schema(session) @@ -469,12 +465,7 @@ def test_filter_relationship_many_to_many_contains_exactly(session): } """ expected = { - "articles": { - "edges": [ - {"node": {"headline": "Article! Look!"}}, - {"node": {"headline": "Woah! Another!"}}, - ] - }, + "articles": {"edges": [{"node": {"headline": "Article! Look!"}}]}, } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) @@ -516,8 +507,6 @@ def test_filter_relationship_many_to_many_contains_exactly(session): # Test n:m relationship both contains and containsExactly -# TODO implement containsExactly -@pytest.mark.xfail def test_filter_relationship_many_to_many_contains_and_contains_exactly(session): add_n2m_test_data(session) Query = create_schema(session) From 7f7e98a32ed2bb87cf301d235636b5662370af05 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Sat, 31 Dec 2022 21:06:02 -0500 Subject: [PATCH 34/81] test: working group_by/having clause for specific case --- graphene_sqlalchemy/filters.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 56c85958..8ceec566 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -441,23 +441,26 @@ def contains_exactly_filter( child_model_ids.extend(subquery_ids) - # Join the aliased model onto the query - query = query.join(field.of_type(joined_model_alias)) - query = ( - query.filter(relationship_prop.id.in_(subquery_ids)) - .group_by(joined_model_alias) - .having(func.count(joined_model_alias.id) == len(subquery_ids)) - ) + # query = ( + # query.filter(relationship_prop.id.in_(subquery_ids)) + # .group_by(joined_model_alias) + # .having(func.count(joined_model_alias.id) == len(subquery_ids)) + # ) + + # Join the aliased model onto the query + query = query.join(relationship_prop) # Define new query? # query = session.query(cls) # Construct clauses from child_model_ids # from .tests.models import Reporter - import pdb - - pdb.set_trace() - # query = query.filter(relationship_prop.id.in_(child_model_ids)).group_by(Reporter).having(func.count(relationship_prop.id)==len(child_model_ids)) + import pdb; pdb.set_trace() + query = ( + query.filter(relationship_prop.id.in_(child_model_ids)) + .group_by(Reporter) + .having(func.count(relationship_prop.id)==len(child_model_ids)) + ) # query = query.filter(relationship_prop.id.in_(child_model_ids)).group_by(relationship_prop).having(func.count(field)==len(child_model_ids)) return query, [] From 98db652439af85050b145d9593e861dd90238e82 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Sun, 1 Jan 2023 06:10:54 -0500 Subject: [PATCH 35/81] fix: group_by/having working for all tests but containsExactly 2 --- graphene_sqlalchemy/filters.py | 55 +++++++++++++++-------- graphene_sqlalchemy/tests/test_filters.py | 16 ++++--- 2 files changed, 46 insertions(+), 25 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 8ceec566..310f6cb7 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple, Type, TypeVar, Union from sqlalchemy import and_, func, not_, or_ -from sqlalchemy.orm import Query, aliased, selectinload +from sqlalchemy.orm import Query, aliased # , selectinload import graphene from graphene.types.inputobjecttype import ( @@ -24,13 +24,13 @@ def _get_functions_by_regex( matching_functions = [] # Search the entire class for functions matching the filter regex - for func in dir(class_): - func_attr = getattr(class_, func) + for fn in dir(class_): + func_attr = getattr(class_, fn) # Check if attribute is a function - if callable(func_attr) and function_regex.match(func): + if callable(func_attr) and function_regex.match(fn): # add function and attribute name to the list matching_functions.append( - (re.sub(subtract_regex, "", func), func_attr.__annotations__) + (re.sub(subtract_regex, "", fn), func_attr.__annotations__) ) return matching_functions @@ -191,7 +191,7 @@ def execute_filters( # todo should we use selectinload here instead of join for large lists? query, _clauses = field_filter_type.execute_filters( - query, model_field, field_filters, relationship_prop + query, model, model_field, field_filters, relationship_prop ) clauses.extend(_clauses) elif issubclass(field_filter_type, FieldFilter): @@ -396,7 +396,12 @@ def __init_subclass_with_meta__( @classmethod def contains_filter( - cls, query, field, relationship_prop, val: List[ScalarFilterInputType] + cls, + query, + parent_model, + field, + relationship_prop, + val: List[ScalarFilterInputType], ): clauses = [] for v in val: @@ -417,7 +422,12 @@ def contains_filter( @classmethod def contains_exactly_filter( - cls, query, field, relationship_prop, val: List[ScalarFilterInputType] + cls, + query, + parent_model, + field, + relationship_prop, + val: List[ScalarFilterInputType], ): print("Contains exactly called: ", query, val) session = query.session @@ -447,33 +457,42 @@ def contains_exactly_filter( # .having(func.count(joined_model_alias.id) == len(subquery_ids)) # ) - # Join the aliased model onto the query - query = query.join(relationship_prop) + # Join the relationship onto the query + # import pdb; pdb.set_trace() + joined_model_alias = aliased(relationship_prop) + query = query.join(field.of_type(joined_model_alias)) # Define new query? # query = session.query(cls) # Construct clauses from child_model_ids - # from .tests.models import Reporter - import pdb; pdb.set_trace() query = ( - query.filter(relationship_prop.id.in_(child_model_ids)) - .group_by(Reporter) - .having(func.count(relationship_prop.id)==len(child_model_ids)) + query.filter(joined_model_alias.id.in_(child_model_ids)) + .group_by(parent_model) + .having(func.count(joined_model_alias.id) == len(child_model_ids)) ) - # query = query.filter(relationship_prop.id.in_(child_model_ids)).group_by(relationship_prop).having(func.count(field)==len(child_model_ids)) + # query = ( + # query.filter(relationship_prop.id.in_(child_model_ids)) + # .group_by(relationship_prop) + # .having(func.count(field)==len(child_model_ids) + # ) return query, [] @classmethod def execute_filters( - cls: Type[FieldFilter], query, field, filter_dict: Dict, relationship_prop + cls: Type[FieldFilter], + query, + parent_model, + field, + filter_dict: Dict, + relationship_prop, ) -> Tuple[Query, List[Any]]: query, clauses = (query, []) for filt, val in filter_dict.items(): query, _clauses = getattr(cls, filt + "_filter")( - query, field, relationship_prop, val + query, parent_model, field, relationship_prop, val ) clauses += _clauses diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 590fbe19..b2355785 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -456,11 +456,17 @@ def test_filter_relationship_many_to_many_contains_exactly(session): query = """ query { articles (filter: { + tags: { containsExactly: [ - { tag: { name: { eq: "sensational" } } } + { name: { eq: "sensational" } } ] + } }) { - headline + edges { + node { + headline + } + } } } """ @@ -493,11 +499,7 @@ def test_filter_relationship_many_to_many_contains_exactly(session): } """ expected = { - "tags": { - "edges": [ - {"node": {"name": "eye-grabbing"}}, - ], - }, + "tags": {"edges": [{"node": {"name": "eye-grabbing"}}]}, } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) From ca1b4988eef8846b430eed84df7618076b5c6969 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Sun, 1 Jan 2023 22:03:30 -0500 Subject: [PATCH 36/81] chore: remove some debugging statements --- graphene_sqlalchemy/fields.py | 4 ---- graphene_sqlalchemy/filters.py | 28 ++++------------------- graphene_sqlalchemy/tests/test_filters.py | 1 + 3 files changed, 6 insertions(+), 27 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index d9523c45..85182c5b 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -96,11 +96,7 @@ def get_query(cls, model, info, sort=None, filter=None, **args): assert isinstance(filter, dict) filter_type: ObjectTypeFilter = type(filter) query, clauses = filter_type.execute_filters(query, filter) - print("Query before filter") - print(query) - print([str(cla) for cla in clauses]) query = query.filter(*clauses) - print(query) return query @classmethod diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 310f6cb7..05db1ce4 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -438,44 +438,26 @@ def contains_exactly_filter( # Always alias the model joined_model_alias = aliased(relationship_prop) - # Store list of child IDs to filter per attribute - # attr = field.of_type(joined_model_alias) - # if not child_model_ids.get(str(attr), None): - # child_model_ids[str(attr)] = [] - subquery = session.query(joined_model_alias.id) subquery, _clauses = cls._meta.object_type_filter.execute_filters( subquery, v, model_alias=joined_model_alias ) subquery_ids = [s_id[0] for s_id in subquery.filter(and_(*_clauses)).all()] - child_model_ids.extend(subquery_ids) - # query = ( - # query.filter(relationship_prop.id.in_(subquery_ids)) - # .group_by(joined_model_alias) - # .having(func.count(joined_model_alias.id) == len(subquery_ids)) - # ) - # Join the relationship onto the query - # import pdb; pdb.set_trace() joined_model_alias = aliased(relationship_prop) - query = query.join(field.of_type(joined_model_alias)) - - # Define new query? - # query = session.query(cls) + joined_field = field.of_type(joined_model_alias) + query = query.join(joined_field) # Construct clauses from child_model_ids query = ( query.filter(joined_model_alias.id.in_(child_model_ids)) .group_by(parent_model) - .having(func.count(joined_model_alias.id) == len(child_model_ids)) + .having(func.count(str(field)) == len(child_model_ids)) + # TODO should filter on aliased field + # .having(func.count(joined_field) == len(child_model_ids)) ) - # query = ( - # query.filter(relationship_prop.id.in_(child_model_ids)) - # .group_by(relationship_prop) - # .having(func.count(field)==len(child_model_ids) - # ) return query, [] diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index b2355785..a5e34223 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -420,6 +420,7 @@ def test_filter_relationship_many_to_many_contains(session): # Test n:m relationship containsExactly +@pytest.mark.xfail def test_filter_relationship_many_to_many_contains_exactly(session): add_n2m_test_data(session) Query = create_schema(session) From 8f012f491356578a5bd84bd54b88d23492af34a0 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 2 Jan 2023 00:46:41 -0500 Subject: [PATCH 37/81] test: cover additional filter types --- graphene_sqlalchemy/filters.py | 2 +- graphene_sqlalchemy/registry.py | 2 +- graphene_sqlalchemy/tests/test_filters.py | 71 ++++++++++++++++++++++- 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 05db1ce4..23a8fbbc 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -273,7 +273,7 @@ def in_filter(cls, query, field, val: List[ScalarFilterInputType]): @classmethod def not_in_filter(cls, query, field, val: List[ScalarFilterInputType]): - return field.not_int(val) + return field.not_in(val) @classmethod def execute_filters( diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index b6a4c201..337e5628 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -6,7 +6,7 @@ import graphene from graphene import Enum -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no_cover from graphene_sqlalchemy.filters import ( FieldFilter, ObjectTypeFilter, diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index a5e34223..32b137f3 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -19,7 +19,7 @@ def add_test_data(session): reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) - pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT) + pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT, legs=4) pet.reporter = reporter session.add(pet) @@ -135,6 +135,7 @@ class Query(graphene.ObjectType): readers = SQLAlchemyConnectionField(ReaderType.connection) # reporter = graphene.Field(ReporterType) reporters = SQLAlchemyConnectionField(ReporterType.connection) + pets = SQLAlchemyConnectionField(PetType.connection) tags = SQLAlchemyConnectionField(TagType.connection) # def resolve_article(self, _info): @@ -799,3 +800,71 @@ def test_filter_logic_and_or(session): @pytest.mark.xfail def test_filter_hybrid_property(session): raise NotImplementedError + + +# Test edge cases to improve test coverage +def test_filter_edge_cases(session): + add_test_data(session) + + # test disabling filtering + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = SQLAlchemyConnectionField(ArticleType.connection, filter=None) + + schema = graphene.Schema(query=Query) + assert not hasattr(schema, "ArticleTypeFilter") + + +# Test additional filter types to improve test coverage +def test_additional_filters(session): + add_test_data(session) + Query = create_schema(session) + + # test n_eq and not_in filters + query = """ + query { + reporters (filter: {firstName: {nEq: "Jane"}, lastName: {notIn: "Doe"}}) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test gt, lt, gte, and lte filters + query = """ + query { + pets (filter: {legs: {gt: 2, lt: 4, gte: 3, lte: 3}}) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": {"edges": [{"node": {"name": "Snoopy"}}]}, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected From 14c657cb4cea8f8f052a2b7fefe1b05cc4180665 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 2 Jan 2023 00:48:47 -0500 Subject: [PATCH 38/81] fix: use notin_ --- graphene_sqlalchemy/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 23a8fbbc..b1d4fd70 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -273,7 +273,7 @@ def in_filter(cls, query, field, val: List[ScalarFilterInputType]): @classmethod def not_in_filter(cls, query, field, val: List[ScalarFilterInputType]): - return field.not_in(val) + return field.notin_(val) @classmethod def execute_filters( From 2e321e92f14613ce97c1293168da49d8c77a463e Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 4 Dec 2022 21:13:29 +0100 Subject: [PATCH 39/81] chore: prepare for sqlalchemy2.0 adjustments --- .github/workflows/tests.yml | 13 ++++--------- graphene_sqlalchemy/tests/conftest.py | 7 +++++++ tox.ini | 6 +++++- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7632fd38..d8239457 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,20 +1,15 @@ name: Tests -on: - push: - branches: - - 'master' - pull_request: - branches: - - '*' +on: [ push, pull_request ] + jobs: test: runs-on: ubuntu-latest strategy: max-parallel: 10 matrix: - sql-alchemy: ["1.2", "1.3", "1.4"] - python-version: ["3.7", "3.8", "3.9", "3.10"] + sql-alchemy: [ "1.2", "1.3", "1.4", "2.0" ] + python-version: [ "3.7", "3.8", "3.9", "3.10" ] steps: - uses: actions/checkout@v3 diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 89b357a4..d3fcedc9 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -3,6 +3,13 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +# fmt: off +# Fixme remove when https://github.com/kvesteri/sqlalchemy-utils/pull/644 is released #noqa +import sqlalchemy # noqa # isort:skip +if sqlalchemy.__version__ == "2.0.0b3": # noqa # isort:skip + sqlalchemy.__version__ = "2.0.0" # noqa # isort:skip +# fmt: on + import graphene from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 diff --git a/tox.ini b/tox.ini index 2802dee0..f7b5f973 100644 --- a/tox.ini +++ b/tox.ini @@ -15,6 +15,7 @@ SQLALCHEMY = 1.2: sql12 1.3: sql13 1.4: sql14 + 2.0: sql20 [testenv] passenv = GITHUB_* @@ -23,8 +24,11 @@ deps = sql12: sqlalchemy>=1.2,<1.3 sql13: sqlalchemy>=1.3,<1.4 sql14: sqlalchemy>=1.4,<1.5 + sql20: sqlalchemy>=2.0.0b3,<2.1 +setenv = + SQLALCHEMY_WARN_20 = 1 commands = - pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs} + python -W always -m pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs} [testenv:pre-commit] basepython=python3.10 From e4b1a7f08bfd23cba56e2f456a15e0e651ba8488 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 4 Dec 2022 21:18:59 +0100 Subject: [PATCH 40/81] update envlist for tox,reduce number of python versions --- .github/workflows/tests.yml | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d8239457..402de29f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,7 @@ jobs: max-parallel: 10 matrix: sql-alchemy: [ "1.2", "1.3", "1.4", "2.0" ] - python-version: [ "3.7", "3.8", "3.9", "3.10" ] + python-version: ["3.9", "3.10" ] steps: - uses: actions/checkout@v3 diff --git a/tox.ini b/tox.ini index f7b5f973..1841cb1a 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = pre-commit,py{37,38,39,310}-sql{12,13,14} +envlist = pre-commit,py{37,38,39,310}-sql{12,13,14,20} skipsdist = true minversion = 3.7.0 From 812b28bfdbb10a787f3528216ea4262fb704e543 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 4 Dec 2022 21:22:33 +0100 Subject: [PATCH 41/81] fix: set sqlalchemy max version to 2.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9122baf2..9650e6d2 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ # To keep things simple, we only support newer versions of Graphene "graphene>=3.0.0b7", "promise>=2.3", - "SQLAlchemy>=1.1,<2", + "SQLAlchemy>=1.1,<2.1", "aiodataloader>=0.2.0,<1.0", ] From 728811776ce413beece177f8006796a5755026ca Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 2 Jan 2023 15:14:37 +0100 Subject: [PATCH 42/81] fix: all unit tests running Signed-off-by: Erik Wrede --- graphene_sqlalchemy/batching.py | 12 ++++++++++ graphene_sqlalchemy/tests/models.py | 20 ++++++++++++---- graphene_sqlalchemy/tests/models_batching.py | 13 +++++++--- graphene_sqlalchemy/tests/test_converter.py | 25 ++++++++++++++++++++ graphene_sqlalchemy/utils.py | 6 +++++ 5 files changed, 69 insertions(+), 7 deletions(-) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 23b6712e..d7feb4fb 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -5,6 +5,7 @@ import sqlalchemy from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext +from sqlalchemy.util import immutabledict from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than @@ -77,6 +78,17 @@ async def batch_load_fn(self, parents): else: query_context = QueryContext(session.query(parent_mapper.entity)) if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + None, + None, # recursion depth can be none + immutabledict(), # default value for selectinload->lazyload + ) + elif SQL_VERSION_HIGHER_EQUAL_THAN_1_4: self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index ee286585..a2ccd82f 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -22,6 +22,8 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, column_property, composite, mapper, relationship +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 + PetKind = Enum("cat", "dog", name="pet_kind") @@ -116,9 +118,15 @@ def hybrid_prop_bool(self) -> bool: def hybrid_prop_list(self) -> List[int]: return [1, 2, 3] - column_prop = column_property( - select([func.cast(func.count(id), Integer)]), doc="Column property" - ) + # TODO Remove when switching min sqlalchemy version to SQLAlchemy 1.4 + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + column_prop = column_property( + select(func.cast(func.count(id), Integer)), doc="Column property" + ) + else: + column_prop = column_property( + select([func.cast(func.count(id), Integer)]), doc="Column property" + ) composite_prop = composite( CompositeFullName, first_name, last_name, doc="Composite" @@ -161,7 +169,11 @@ def __subclasses__(cls): editor_table = Table("editors", Base.metadata, autoload=True) -mapper(ReflectedEditor, editor_table) +# TODO Remove when switching min sqlalchemy version to SQLAlchemy 1.4 +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + Base.registry.map_imperatively(ReflectedEditor, editor_table) +else: + mapper(ReflectedEditor, editor_table) ############################################ diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py index 6f1c42ff..dde6d45c 100644 --- a/graphene_sqlalchemy/tests/models_batching.py +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -16,6 +16,8 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import column_property, relationship +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 + PetKind = Enum("cat", "dog", name="pet_kind") @@ -60,9 +62,14 @@ class Reporter(Base): articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) - column_prop = column_property( - select([func.cast(func.count(id), Integer)]), doc="Column property" - ) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + column_prop = column_property( + select(func.cast(func.count(id), Integer)), doc="Column property" + ) + else: + column_prop = column_property( + select([func.cast(func.count(id), Integer)]), doc="Column property" + ) class Article(Base): diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index b9a1c152..bfd3ee66 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -24,6 +24,7 @@ from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory from ..registry import Registry, get_global_registry from ..types import ORMField, SQLAlchemyObjectType +from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_sqlalchemy_version_less_than from .models import ( Article, CompositeFullName, @@ -336,6 +337,22 @@ class TestEnum(enum.IntEnum): assert graphene_type._meta.enum.__members__["two"].value == 2 +@pytest.mark.skipif( + not SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + reason="SQLAlchemy <1.4 does not support this", +) +def test_should_columproperty_convert_sqa_20(): + field = get_field_from_column( + column_property(select(func.sum(func.cast(id, types.Integer))).where(id == 1)) + ) + + assert field.type == graphene.Int + + +@pytest.mark.skipif( + not is_sqlalchemy_version_less_than("2.0.0b1"), + reason="SQLAlchemy >=2.0 does not support this syntax, see convert_sqa_20", +) def test_should_columproperty_convert(): field = get_field_from_column( column_property(select([func.sum(func.cast(id, types.Integer))]).where(id == 1)) @@ -355,10 +372,18 @@ def test_should_jsontype_convert_jsonstring(): assert get_field(types.JSON).type == graphene.JSONString +@pytest.mark.skipif( + (not is_sqlalchemy_version_less_than("2.0.0b1")), + reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy", +) def test_should_variant_int_convert_int(): assert get_field(types.Variant(types.Integer(), {})).type == graphene.Int +@pytest.mark.skipif( + (not is_sqlalchemy_version_less_than("2.0.0b1")), + reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy", +) def test_should_variant_string_convert_string(): assert get_field(types.Variant(types.String(), {})).type == graphene.String diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 62c71d8d..86ebcd79 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -32,6 +32,12 @@ def is_graphene_version_less_than(version_string): # pragma: no cover SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True +SQL_VERSION_HIGHER_EQUAL_THAN_2 = False + +if not is_sqlalchemy_version_less_than("2.0.0b1"): + SQL_VERSION_HIGHER_EQUAL_THAN_2 = True + + def get_session(context): return context.get("session") From e562cc2aecd84121203ef999c294322c700579ec Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 2 Jan 2023 15:19:01 +0100 Subject: [PATCH 43/81] fix: corrected sql version check for batching Signed-off-by: Erik Wrede --- graphene_sqlalchemy/batching.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index d7feb4fb..87974ce8 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -7,7 +7,11 @@ from sqlalchemy.orm.query import QueryContext from sqlalchemy.util import immutabledict -from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + SQL_VERSION_HIGHER_EQUAL_THAN_2, + is_graphene_version_less_than, +) def get_data_loader_impl() -> Any: # pragma: no cover @@ -77,7 +81,7 @@ async def batch_load_fn(self, parents): query_context = parent_mapper_query._compile_context() else: query_context = QueryContext(session.query(parent_mapper.entity)) - if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + if SQL_VERSION_HIGHER_EQUAL_THAN_2: self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, From e525a8db462bf229e0d5f1577520412f6fbd556e Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 2 Jan 2023 15:25:59 +0100 Subject: [PATCH 44/81] fix: added pragma no cover to version checks Signed-off-by: Erik Wrede --- graphene_sqlalchemy/batching.py | 2 +- graphene_sqlalchemy/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 87974ce8..a5804516 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -81,7 +81,7 @@ async def batch_load_fn(self, parents): query_context = parent_mapper_query._compile_context() else: query_context = QueryContext(session.query(parent_mapper.entity)) - if SQL_VERSION_HIGHER_EQUAL_THAN_2: + if SQL_VERSION_HIGHER_EQUAL_THAN_2: # pragma: no cover self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 86ebcd79..381164a7 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -26,7 +26,7 @@ def is_graphene_version_less_than(version_string): # pragma: no cover SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False -if not is_sqlalchemy_version_less_than("1.4"): +if not is_sqlalchemy_version_less_than("1.4"): # pragma: no cover from sqlalchemy.ext.asyncio import AsyncSession SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True @@ -34,7 +34,7 @@ def is_graphene_version_less_than(version_string): # pragma: no cover SQL_VERSION_HIGHER_EQUAL_THAN_2 = False -if not is_sqlalchemy_version_less_than("2.0.0b1"): +if not is_sqlalchemy_version_less_than("2.0.0b1"): # pragma: no cover SQL_VERSION_HIGHER_EQUAL_THAN_2 = True From b84aa9fa31238e7c3bc8f74db157175ee576e5d2 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 2 Jan 2023 15:27:17 +0100 Subject: [PATCH 45/81] chore: test with all python versions Signed-off-by: Erik Wrede --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 402de29f..ad45c81b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,8 +8,8 @@ jobs: strategy: max-parallel: 10 matrix: - sql-alchemy: [ "1.2", "1.3", "1.4", "2.0" ] - python-version: ["3.9", "3.10" ] + sql-alchemy: [ "1.2", "1.3", "1.4","2.0" ] + python-version: [ "3.7", "3.8", "3.9", "3.10" ] steps: - uses: actions/checkout@v3 From b832fff247a3b69acf77810e71dd8b65e4d8c60f Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 2 Jan 2023 09:55:00 -0500 Subject: [PATCH 46/81] wip: test: add hybrid_prop tests --- graphene_sqlalchemy/tests/models.py | 11 +- graphene_sqlalchemy/tests/test_filters.py | 245 +++++++++++++++++++--- 2 files changed, 222 insertions(+), 34 deletions(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 3ac3a777..86092fb1 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -268,11 +268,20 @@ def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: ], ] - # Other SQLAlchemy Instances + # Other SQLAlchemy Instance @hybrid_property def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: return ShoppingCartItem(id=1) + # Other SQLAlchemy Instance with expression + @hybrid_property + def hybrid_prop_first_shopping_cart_item_expression(self) -> ShoppingCartItem: + return ShoppingCartItem(id=1) + + @hybrid_prop_first_shopping_cart_item_expression.expression + def hybrid_prop_first_shopping_cart_item_expression(cls): + return ShoppingCartItem + # Other SQLAlchemy Instances @hybrid_property def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 32b137f3..f560a17e 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -7,7 +7,18 @@ from ..fields import SQLAlchemyConnectionField from ..filters import FloatFilter from ..types import ORMField, SQLAlchemyObjectType -from .models import Article, Editor, HairKind, Image, Pet, Reader, Reporter, Tag +from .models import ( + Article, + Editor, + HairKind, + Image, + Pet, + Reader, + Reporter, + ShoppingCart, + ShoppingCartItem, + Tag, +) from .utils import to_std_dicts # TODO test that generated schema is correct for all examples with: @@ -51,37 +62,6 @@ def add_test_data(session): session.commit() -def add_n2m_test_data(session): - # create objects - reader1 = Reader(name="Ada") - reader2 = Reader(name="Bip") - article1 = Article(headline="Article! Look!") - article2 = Article(headline="Woah! Another!") - tag1 = Tag(name="sensational") - tag2 = Tag(name="eye-grabbing") - image1 = Image(description="article 1") - image2 = Image(description="article 2") - - # set relationships - article1.tags = [tag1] - article2.tags = [tag1, tag2] - article1.image = image1 - article2.image = image2 - reader1.articles = [article1] - reader2.articles = [article1, article2] - - # save - session.add(image1) - session.add(image2) - session.add(tag1) - session.add(tag2) - session.add(article1) - session.add(article2) - session.add(reader1) - session.add(reader2) - session.commit() - - def create_schema(session): class ArticleType(SQLAlchemyObjectType): class Meta: @@ -320,6 +300,37 @@ def test_filter_relationship_one_to_many(session): assert result == expected +def add_n2m_test_data(session): + # create objects + reader1 = Reader(name="Ada") + reader2 = Reader(name="Bip") + article1 = Article(headline="Article! Look!") + article2 = Article(headline="Woah! Another!") + tag1 = Tag(name="sensational") + tag2 = Tag(name="eye-grabbing") + image1 = Image(description="article 1") + image2 = Image(description="article 2") + + # set relationships + article1.tags = [tag1] + article2.tags = [tag1, tag2] + article1.image = image1 + article2.image = image2 + reader1.articles = [article1] + reader2.articles = [article1, article2] + + # save + session.add(image1) + session.add(image2) + session.add(tag1) + session.add(tag2) + session.add(article1) + session.add(article2) + session.add(reader1) + session.add(reader2) + session.commit() + + # Test n:m relationship contains def test_filter_relationship_many_to_many_contains(session): add_n2m_test_data(session) @@ -796,10 +807,178 @@ def test_filter_logic_and_or(session): assert result == expected +def add_hybrid_prop_test_data(session): + # create objects + cart = ShoppingCart() + + # set relationships + + # save + session.add(cart) + session.commit() + + +def create_hybrid_prop_schema(session): + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCart + name = "ShoppingCart" + interfaces = (relay.Node,) + connection_class = Connection + + class ShoppingCartItemType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + name = "ShoppingCartItem" + interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + carts = SQLAlchemyConnectionField(ShoppingCartType.connection) + + return Query + + # TODO hybrid property @pytest.mark.xfail def test_filter_hybrid_property(session): - raise NotImplementedError + add_hybrid_prop_test_data(session) + Query = create_hybrid_prop_schema(session) + + # test hybrid_prop_int + query = """ + query { + carts (filter: {hybridPropInt: {eq: 42}}) { + edges { + node { + hybridPropInt + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropInt": 42}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test hybrid_prop_float + query = """ + query { + carts (filter: {hybridPropFloat: {gt: 42}}) { + edges { + node { + hybridPropFloat + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropFloat": 42.3}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + # test hybrid_prop different model without expression + query = """ + query { + carts { + edges { + node { + hybridPropFirstShoppingCartItem + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropFirstShoppingCartItem": {"id": 1}}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + # cart = result["carts"]["edges"][0]["node"]["hybridPropFirstShoppingCartItem"] + # print(cart) + # print(type(cart)) + # TODO why is this str? + assert result == expected + + # test hybrid_prop different model with expression + query = """ + query { + carts { + edges { + node { + hybridPropFirstShoppingCartItemExpression + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropFirstShoppingCartItemExpression": {"id": 1}}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + # cart = result["carts"]["edges"][0]["node"]["hybridPropFirstShoppingCartItemExpression"] + # print(cart) + # print(type(cart)) + # TODO why is this str? + assert result == expected + + # test hybrid_prop list of models + query = """ + query { + carts { + edges { + node { + hybridPropShoppingCartItemList + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropShoppingCartItemList": {"id": 1}}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + # TODO why is this str? + assert result == expected # Test edge cases to improve test coverage From 5fc1be30c64e8de7b92aefad505362d7fd603747 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 2 Jan 2023 22:49:28 -0500 Subject: [PATCH 47/81] test: fix hybrid_prop converter test --- graphene_sqlalchemy/tests/test_converter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index b9a1c152..579fbfe7 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -666,6 +666,7 @@ class Meta: graphene.List(graphene.List(graphene.Int)) ), "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, + "hybrid_prop_first_shopping_cart_item_expression": ShoppingCartItemType, "hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType), "hybrid_prop_unsupported_type_tuple": graphene.String, # Self Referential List From 156fb68276e10b98131fe9b27295e39a1fd30ff5 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Tue, 3 Jan 2023 06:22:22 -0500 Subject: [PATCH 48/81] test: fix hybrid_prop test by using string typevars --- graphene_sqlalchemy/tests/models.py | 6 +-- graphene_sqlalchemy/tests/test_filters.py | 59 ++++++++--------------- 2 files changed, 22 insertions(+), 43 deletions(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 86092fb1..7af6676e 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -270,12 +270,12 @@ def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: # Other SQLAlchemy Instance @hybrid_property - def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: + def hybrid_prop_first_shopping_cart_item(self) -> "ShoppingCartItem": return ShoppingCartItem(id=1) # Other SQLAlchemy Instance with expression @hybrid_property - def hybrid_prop_first_shopping_cart_item_expression(self) -> ShoppingCartItem: + def hybrid_prop_first_shopping_cart_item_expression(self) -> "ShoppingCartItem": return ShoppingCartItem(id=1) @hybrid_prop_first_shopping_cart_item_expression.expression @@ -284,7 +284,7 @@ def hybrid_prop_first_shopping_cart_item_expression(cls): # Other SQLAlchemy Instances @hybrid_property - def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: + def hybrid_prop_shopping_cart_item_list(self) -> List["ShoppingCartItem"]: return [ShoppingCartItem(id=1), ShoppingCartItem(id=2)] # Unsupported Type diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index f560a17e..774dd0fe 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -841,7 +841,6 @@ class Query(graphene.ObjectType): # TODO hybrid property -@pytest.mark.xfail def test_filter_hybrid_property(session): add_hybrid_prop_test_data(session) Query = create_hybrid_prop_schema(session) @@ -901,84 +900,64 @@ def test_filter_hybrid_property(session): query { carts { edges { - node { - hybridPropFirstShoppingCartItem + node { + hybridPropFirstShoppingCartItem { + id } + } } } } """ - expected = { - "carts": { - "edges": [ - {"node": {"hybridPropFirstShoppingCartItem": {"id": 1}}}, - ] - }, - } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) - # cart = result["carts"]["edges"][0]["node"]["hybridPropFirstShoppingCartItem"] - # print(cart) - # print(type(cart)) - # TODO why is this str? - assert result == expected + assert len(result["carts"]["edges"]) == 1 # test hybrid_prop different model with expression query = """ query { carts { edges { - node { - hybridPropFirstShoppingCartItemExpression + node { + hybridPropFirstShoppingCartItemExpression { + id } + } } } } """ - expected = { - "carts": { - "edges": [ - {"node": {"hybridPropFirstShoppingCartItemExpression": {"id": 1}}}, - ] - }, - } + schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) - # cart = result["carts"]["edges"][0]["node"]["hybridPropFirstShoppingCartItemExpression"] - # print(cart) - # print(type(cart)) - # TODO why is this str? - assert result == expected + assert len(result["carts"]["edges"]) == 1 # test hybrid_prop list of models query = """ query { carts { edges { - node { - hybridPropShoppingCartItemList + node { + hybridPropShoppingCartItemList { + id } + } } } } """ - expected = { - "carts": { - "edges": [ - {"node": {"hybridPropShoppingCartItemList": {"id": 1}}}, - ] - }, - } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) - # TODO why is this str? - assert result == expected + assert len(result["carts"]["edges"]) == 1 + assert ( + len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 + ) # Test edge cases to improve test coverage From 5d819919b343698816fb8fc911ba4559168a47c7 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Wed, 4 Jan 2023 22:56:10 -0500 Subject: [PATCH 49/81] test: revert test models --- graphene_sqlalchemy/tests/models.py | 6 +++--- graphene_sqlalchemy/tests/test_filters.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 7af6676e..86092fb1 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -270,12 +270,12 @@ def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: # Other SQLAlchemy Instance @hybrid_property - def hybrid_prop_first_shopping_cart_item(self) -> "ShoppingCartItem": + def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: return ShoppingCartItem(id=1) # Other SQLAlchemy Instance with expression @hybrid_property - def hybrid_prop_first_shopping_cart_item_expression(self) -> "ShoppingCartItem": + def hybrid_prop_first_shopping_cart_item_expression(self) -> ShoppingCartItem: return ShoppingCartItem(id=1) @hybrid_prop_first_shopping_cart_item_expression.expression @@ -284,7 +284,7 @@ def hybrid_prop_first_shopping_cart_item_expression(cls): # Other SQLAlchemy Instances @hybrid_property - def hybrid_prop_shopping_cart_item_list(self) -> List["ShoppingCartItem"]: + def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: return [ShoppingCartItem(id=1), ShoppingCartItem(id=2)] # Unsupported Type diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 774dd0fe..2a233358 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -819,28 +819,29 @@ def add_hybrid_prop_test_data(session): def create_hybrid_prop_schema(session): - class ShoppingCartType(SQLAlchemyObjectType): + class ShoppingCartItemType(SQLAlchemyObjectType): class Meta: - model = ShoppingCart - name = "ShoppingCart" + model = ShoppingCartItem + name = "ShoppingCartItem" interfaces = (relay.Node,) connection_class = Connection - class ShoppingCartItemType(SQLAlchemyObjectType): + class ShoppingCartType(SQLAlchemyObjectType): class Meta: - model = ShoppingCartItem - name = "ShoppingCartItem" + model = ShoppingCart + name = "ShoppingCart" interfaces = (relay.Node,) connection_class = Connection class Query(graphene.ObjectType): node = relay.Node.Field() + items = SQLAlchemyConnectionField(ShoppingCartItemType.connection) carts = SQLAlchemyConnectionField(ShoppingCartType.connection) return Query -# TODO hybrid property +# Test filtering over and returning hybrid_property def test_filter_hybrid_property(session): add_hybrid_prop_test_data(session) Query = create_hybrid_prop_schema(session) From ec62082c9e36d47a6311a62d0a6ddccc1f3157ec Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 16 Jan 2023 08:27:03 -0500 Subject: [PATCH 50/81] fix: filters support interfaces with BaseType --- graphene_sqlalchemy/converter.py | 6 +++- graphene_sqlalchemy/fields.py | 4 +-- graphene_sqlalchemy/filters.py | 45 ++++++++++++++--------------- graphene_sqlalchemy/registry.py | 49 ++++++++++++++++---------------- graphene_sqlalchemy/types.py | 30 ++++++++++--------- 5 files changed, 70 insertions(+), 64 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 6f37e767..40ec3127 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -15,7 +15,6 @@ from .batching import get_batch_resolver from .enums import enum_for_sa_enum -from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( @@ -243,8 +242,13 @@ def convert_sqlalchemy_type( # noqa type_arg: Any, column: Optional[Union[MapperProperty, hybrid_property]] = None, registry: Registry = None, + replace_type_vars: typing.Dict[str, Any] = None, **kwargs, ): + + if replace_type_vars and type_arg in replace_type_vars: + return replace_type_vars[type_arg] + # No valid type found, raise an error raise TypeError( diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 949505ab..78ef5555 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -10,7 +10,7 @@ from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .filters import ObjectTypeFilter +from .filters import BaseTypeFilter from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_nullable_type, get_query, get_session if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: @@ -97,7 +97,7 @@ def get_query(cls, model, info, sort=None, filter=None, **args): if filter is not None: assert isinstance(filter, dict) - filter_type: ObjectTypeFilter = type(filter) + filter_type: BaseTypeFilter = type(filter) query, clauses = filter_type.execute_filters(query, filter) query = query.filter(*clauses) return query diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index b1d4fd70..c833658a 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -11,8 +11,9 @@ ) from graphene_sqlalchemy.utils import is_list -ObjectTypeFilterSelf = TypeVar( - "ObjectTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer + +BaseTypeFilterSelf = TypeVar( + "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer ) @@ -35,13 +36,13 @@ def _get_functions_by_regex( return matching_functions -class ObjectTypeFilter(graphene.InputObjectType): +class BaseTypeFilter(graphene.InputObjectType): @classmethod def __init_subclass_with_meta__( cls, filter_fields=None, model=None, _meta=None, **options ): from graphene_sqlalchemy.converter import ( - convert_sqlalchemy_hybrid_property_type, + convert_sqlalchemy_type, ) # Init meta options class if it doesn't exist already @@ -59,8 +60,8 @@ def __init_subclass_with_meta__( ), "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class - replace_type_vars = {ObjectTypeFilterSelf: cls} - field_type = convert_sqlalchemy_hybrid_property_type( + replace_type_vars = {BaseTypeFilterSelf: cls} + field_type = convert_sqlalchemy_type( _annotations.get("val", str), replace_type_vars=replace_type_vars ) new_filter_fields.update({field_name: graphene.InputField(field_type)}) @@ -77,14 +78,14 @@ def __init_subclass_with_meta__( _meta.model = model - super(ObjectTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + super(BaseTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) @classmethod def and_logic( cls, query, - filter_type: "ObjectTypeFilter", - val: List[ObjectTypeFilterSelf], + filter_type: "BaseTypeFilter", + val: List[BaseTypeFilterSelf], ): # # Get the model to join on the Filter Query # joined_model = filter_type._meta.model @@ -107,8 +108,8 @@ def and_logic( def or_logic( cls, query, - filter_type: "ObjectTypeFilter", - val: List[ObjectTypeFilterSelf], + filter_type: "BaseTypeFilter", + val: List[BaseTypeFilterSelf], ): # # Get the model to join on the Filter Query # joined_model = filter_type._meta.model @@ -161,7 +162,7 @@ def execute_filters( clauses.extend(_clauses) else: model_field = getattr(model, field) - if issubclass(field_filter_type, ObjectTypeFilter): + if issubclass(field_filter_type, BaseTypeFilter): # Get the model to join on the Filter Query joined_model = field_filter_type._meta.model # Always alias the model @@ -219,7 +220,7 @@ class FieldFilter(graphene.InputObjectType): @classmethod def __init_subclass_with_meta__(cls, graphene_type=None, _meta=None, **options): - from .converter import convert_sqlalchemy_hybrid_property_type + from .converter import convert_sqlalchemy_type # get all filter functions @@ -240,9 +241,7 @@ def __init_subclass_with_meta__(cls, graphene_type=None, _meta=None, **options): ), "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class replace_type_vars = {ScalarFilterInputType: _meta.graphene_type} - field_type = convert_sqlalchemy_hybrid_property_type( - _annotations.get("val", str), replace_type_vars=replace_type_vars - ) + field_type = convert_sqlalchemy_type(_annotations.get("val", str), replace_type_vars=replace_type_vars) new_filter_fields.update({field_name: graphene.InputField(field_type)}) # Add all fields to the meta options. graphene.InputbjectType will take care of the rest @@ -354,9 +353,9 @@ class Meta: class RelationshipFilter(graphene.InputObjectType): @classmethod def __init_subclass_with_meta__( - cls, object_type_filter=None, model=None, _meta=None, **options + cls, base_type_filter=None, model=None, _meta=None, **options ): - if not object_type_filter: + if not base_type_filter: raise Exception("Relationship Filters must be specific to an object type") # Init meta options class if it doesn't exist already if not _meta: @@ -375,11 +374,11 @@ def __init_subclass_with_meta__( # If type is generic, replace with actual type of filter class if is_list(_annotations["val"]): relationship_filters.update( - {field_name: graphene.InputField(graphene.List(object_type_filter))} + {field_name: graphene.InputField(graphene.List(base_type_filter))} ) else: relationship_filters.update( - {field_name: graphene.InputField(object_type_filter)} + {field_name: graphene.InputField(base_type_filter)} ) # Add all fields to the meta options. graphene.InputObjectType will take care of the rest @@ -389,7 +388,7 @@ def __init_subclass_with_meta__( _meta.fields = relationship_filters _meta.model = model - _meta.object_type_filter = object_type_filter + _meta.base_type_filter = base_type_filter super(RelationshipFilter, cls).__init_subclass_with_meta__( _meta=_meta, **options ) @@ -414,7 +413,7 @@ def contains_filter( print("Joined model", relationship_prop) print(query) # pass the alias so group can join group - query, _clauses = cls._meta.object_type_filter.execute_filters( + query, _clauses = cls._meta.base_type_filter.execute_filters( query, v, model_alias=joined_model_alias ) clauses.append(and_(*_clauses)) @@ -439,7 +438,7 @@ def contains_exactly_filter( joined_model_alias = aliased(relationship_prop) subquery = session.query(joined_model_alias.id) - subquery, _clauses = cls._meta.object_type_filter.execute_filters( + subquery, _clauses = cls._meta.base_type_filter.execute_filters( subquery, v, model_alias=joined_model_alias ) subquery_ids = [s_id[0] for s_id in subquery.filter(and_(*_clauses)).all()] diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 07f76e22..d29c4f6f 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -5,11 +5,12 @@ import graphene from graphene import Enum +from graphene.types.base import BaseType if TYPE_CHECKING: # pragma: no_cover from graphene_sqlalchemy.filters import ( FieldFilter, - ObjectTypeFilter, + BaseTypeFilter, RelationshipFilter, ) @@ -24,7 +25,7 @@ def __init__(self): self._registry_sort_enums = {} self._registry_unions = {} self._registry_scalar_filters = {} - self._registry_object_type_filters = {} + self._registry_base_type_filters = {} self._registry_relationship_filters = {} def register(self, obj_type): @@ -140,7 +141,7 @@ def register_filter_for_enum_type( ): from .filters import FieldFilter - if not isinstance(enum_type, type(graphene.Enum)): + if not issubclass(enum_type, graphene.Enum): raise TypeError("Expected Enum, but got: {!r}".format(enum_type)) if not issubclass(filter_obj, FieldFilter): @@ -152,45 +153,45 @@ def get_filter_for_enum_type( ) -> Type["FieldFilter"]: return self._registry_enum_type_filters.get(enum_type) - # Filter Object Types - def register_filter_for_object_type( + # Filter Base Types + def register_filter_for_base_type( self, - object_type: Type[graphene.ObjectType], - filter_obj: Type["ObjectTypeFilter"], + base_type: Type[BaseType], + filter_obj: Type["BaseTypeFilter"], ): - from .filters import ObjectTypeFilter + from .filters import BaseTypeFilter - if not isinstance(object_type, type(graphene.ObjectType)): - raise TypeError("Expected Object Type, but got: {!r}".format(object_type)) + if not issubclass(base_type, BaseType): + raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) - if not issubclass(filter_obj, ObjectTypeFilter): + if not issubclass(filter_obj, BaseTypeFilter): raise TypeError( - "Expected ObjectTypeFilter, but got: {!r}".format(filter_obj) + "Expected BaseTypeFilter, but got: {!r}".format(filter_obj) ) - self._registry_object_type_filters[object_type] = filter_obj + self._registry_base_type_filters[base_type] = filter_obj - def get_filter_for_object_type(self, object_type: Type[graphene.ObjectType]): - return self._registry_object_type_filters.get(object_type) + def get_filter_for_base_type(self, base_type: Type[BaseType]): + return self._registry_base_type_filters.get(base_type) - # Filter Relationships between object types - def register_relationship_filter_for_object_type( - self, object_type: graphene.ObjectType, filter_obj: Type["RelationshipFilter"] + # Filter Relationships between base types + def register_relationship_filter_for_base_type( + self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] ): from .filters import RelationshipFilter - if not isinstance(object_type, type(graphene.ObjectType)): - raise TypeError("Expected Object Type, but got: {!r}".format(object_type)) + if not isinstance(base_type, type(BaseType)): + raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) if not issubclass(filter_obj, RelationshipFilter): raise TypeError( "Expected RelationshipFilter, but got: {!r}".format(filter_obj) ) - self._registry_relationship_filters[object_type] = filter_obj + self._registry_relationship_filters[base_type] = filter_obj - def get_relationship_filter_for_object_type( - self, object_type: Type[graphene.ObjectType] + def get_relationship_filter_for_base_type( + self, base_type: Type[BaseType] ) -> "RelationshipFilter": - return self._registry_relationship_filters.get(object_type) + return self._registry_relationship_filters.get(base_type) registry = None diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 5cf2cf4a..580b1b45 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -34,7 +34,7 @@ FloatFilter, IdFilter, IntFilter, - ObjectTypeFilter, + BaseTypeFilter, RelationshipFilter, StringFilter, ) @@ -131,19 +131,19 @@ class Meta: def get_or_create_relationship_filter( - obj_type: Type[ObjectType], registry: Registry + base_type: Type[BaseType], registry: Registry ) -> Type[RelationshipFilter]: - relationship_filter = registry.get_relationship_filter_for_object_type(obj_type) + relationship_filter = registry.get_relationship_filter_for_base_type(base_type) if not relationship_filter: - object_type_filter = registry.get_filter_for_object_type(obj_type) + base_type_filter = registry.get_filter_for_base_type(base_type) relationship_filter = RelationshipFilter.create_type( - f"{obj_type.__name__}RelationshipFilter", - object_type_filter=object_type_filter, - model=obj_type._meta.model, + f"{base_type.__name__}RelationshipFilter", + base_type_filter=base_type_filter, + model=base_type._meta.model, ) - registry.register_relationship_filter_for_object_type( - obj_type, relationship_filter + registry.register_relationship_filter_for_base_type( + base_type, relationship_filter ) return relationship_filter @@ -197,7 +197,7 @@ def resolve_dynamic(): print("filter class was none!!!") print(type_) return graphene.InputField(reg_res) - reg_res = registry.get_filter_for_object_type(type_.type) + reg_res = registry.get_filter_for_base_type(type_.type) return graphene.InputField(reg_res) else: @@ -210,7 +210,7 @@ def resolve_dynamic(): type_ = get_nullable_type(field.type) # Field might be a SQLAlchemyObjectType, due to hybrid properties if issubclass(type_, SQLAlchemyObjectType): - filter_class = registry.get_filter_for_object_type(type_) + filter_class = registry.get_filter_for_base_type(type_) return graphene.InputField(filter_class) filter_class = registry.get_filter_for_scalar_type(type_) if not filter_class: @@ -373,7 +373,7 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): registry = None # type: sqlalchemy.Registry connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] id = None # type: str - filter_class: Type[ObjectTypeFilter] = None + filter_class: Type[BaseTypeFilter] = None class SQLAlchemyBase(BaseType): @@ -491,10 +491,10 @@ def __init_subclass_with_meta__( filter_fields = yank_fields_from_attrs(filters, _as=InputField, sort=False) - _meta.filter_class = ObjectTypeFilter.create_type( + _meta.filter_class = BaseTypeFilter.create_type( f"{cls.__name__}Filter", filter_fields=filter_fields, model=model ) - registry.register_filter_for_object_type(cls, _meta.filter_class) + registry.register_filter_for_base_type(cls, _meta.filter_class) _meta.connection = connection _meta.id = id or "id" @@ -571,6 +571,7 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): registry = None # type: sqlalchemy.Registry connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] id = None # type: str + filter_class: Type[BaseTypeFilter] = None class SQLAlchemyObjectType(SQLAlchemyBase, ObjectType): @@ -605,6 +606,7 @@ class SQLAlchemyInterfaceOptions(InterfaceOptions): registry = None # type: sqlalchemy.Registry connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] id = None # type: str + filter_class: Type[BaseTypeFilter] = None class SQLAlchemyInterface(SQLAlchemyBase, Interface): From b88280d91c94c20fa004c119aa3e441e8ec7e00a Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 16 Jan 2023 08:45:05 -0500 Subject: [PATCH 51/81] test: fix basic tests for async --- graphene_sqlalchemy/fields.py | 3 + graphene_sqlalchemy/tests/test_filters.py | 143 ++++++++++++---------- 2 files changed, 80 insertions(+), 66 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 78ef5555..d978c736 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -99,7 +99,9 @@ def get_query(cls, model, info, sort=None, filter=None, **args): assert isinstance(filter, dict) filter_type: BaseTypeFilter = type(filter) query, clauses = filter_type.execute_filters(query, filter) + print("1: ", query, clauses) query = query.filter(*clauses) + print("2: ", query) return query @classmethod @@ -146,6 +148,7 @@ async def resolve_connection_async( session = get_session(info.context) if resolved is None: query = cls.get_query(model, info, **args) + print("HERE: ", query) resolved = (await session.scalars(query)).all() if isinstance(resolved, Query): _len = resolved.count() diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 2a233358..f706a19c 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -19,14 +19,14 @@ ShoppingCartItem, Tag, ) -from .utils import to_std_dicts +from .utils import eventually_await_session, to_std_dicts # TODO test that generated schema is correct for all examples with: # with open('schema.gql', 'w') as fp: # fp.write(str(schema)) -def add_test_data(session): +async def add_test_data(session): reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) @@ -59,7 +59,7 @@ def add_test_data(session): editor = Editor(name="Jack") session.add(editor) - session.commit() + await eventually_await_session(session, "commit") def create_schema(session): @@ -131,8 +131,9 @@ class Query(graphene.ObjectType): # Test a simple example of filtering -def test_filter_simple(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_filter_simple(session): + await add_test_data(session) Query = create_schema(session) @@ -151,15 +152,18 @@ def test_filter_simple(session): "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) + print(result) + print(result.errors) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test a custom filter type -def test_filter_custom_type(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_filter_custom_type(session): + await add_test_data(session) class MathFilter(FloatFilter): class Meta: @@ -200,20 +204,21 @@ class Query(graphene.ObjectType): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test a 1:1 relationship -def test_filter_relationship_one_to_one(session): +@pytest.mark.asyncio +async def test_filter_relationship_one_to_one(session): article = Article(headline="Hi!") image = Image(external_id=1, description="A beautiful image.") article.image = image session.add(article) session.add(image) - session.commit() + await eventually_await_session(session, "commit") Query = create_schema(session) @@ -234,15 +239,16 @@ def test_filter_relationship_one_to_one(session): "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test a 1:n relationship -def test_filter_relationship_one_to_many(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_filter_relationship_one_to_many(session): + await add_test_data(session) Query = create_schema(session) # test contains @@ -265,7 +271,7 @@ def test_filter_relationship_one_to_many(session): "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -294,13 +300,13 @@ def test_filter_relationship_one_to_many(session): "reporters": {"edges": [{"node": {"firstName": "John", "lastName": "Woe"}}]} } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def add_n2m_test_data(session): +async def add_n2m_test_data(session): # create objects reader1 = Reader(name="Ada") reader2 = Reader(name="Bip") @@ -328,12 +334,13 @@ def add_n2m_test_data(session): session.add(article2) session.add(reader1) session.add(reader2) - session.commit() + await eventually_await_session(session, "commit") # Test n:m relationship contains -def test_filter_relationship_many_to_many_contains(session): - add_n2m_test_data(session) +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains(session): + await add_n2m_test_data(session) Query = create_schema(session) # test contains 1 @@ -363,7 +370,7 @@ def test_filter_relationship_many_to_many_contains(session): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -394,7 +401,7 @@ def test_filter_relationship_many_to_many_contains(session): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -425,7 +432,7 @@ def test_filter_relationship_many_to_many_contains(session): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -433,8 +440,9 @@ def test_filter_relationship_many_to_many_contains(session): # Test n:m relationship containsExactly @pytest.mark.xfail -def test_filter_relationship_many_to_many_contains_exactly(session): - add_n2m_test_data(session) +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_exactly(session): + await add_n2m_test_data(session) Query = create_schema(session) # test containsExactly 1 @@ -460,7 +468,7 @@ def test_filter_relationship_many_to_many_contains_exactly(session): "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -487,7 +495,7 @@ def test_filter_relationship_many_to_many_contains_exactly(session): "articles": {"edges": [{"node": {"headline": "Article! Look!"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -515,15 +523,16 @@ def test_filter_relationship_many_to_many_contains_exactly(session): "tags": {"edges": [{"node": {"name": "eye-grabbing"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test n:m relationship both contains and containsExactly -def test_filter_relationship_many_to_many_contains_and_contains_exactly(session): - add_n2m_test_data(session) +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_and_contains_exactly(session): + await add_n2m_test_data(session) Query = create_schema(session) query = """ @@ -551,7 +560,7 @@ def test_filter_relationship_many_to_many_contains_and_contains_exactly(session) "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -559,8 +568,9 @@ def test_filter_relationship_many_to_many_contains_and_contains_exactly(session) # Test n:m nested relationship # TODO add containsExactly -def test_filter_relationship_many_to_many_nested(session): - add_n2m_test_data(session) +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_nested(session): + await add_n2m_test_data(session) Query = create_schema(session) # test readers->articles relationship @@ -585,7 +595,7 @@ def test_filter_relationship_many_to_many_nested(session): "readers": {"edges": [{"node": {"name": "Bip"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -618,7 +628,7 @@ def test_filter_relationship_many_to_many_nested(session): "readers": {"edges": [{"node": {"name": "Bip"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -651,7 +661,7 @@ def test_filter_relationship_many_to_many_nested(session): "tags": {"edges": [{"node": {"name": "sensational"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -685,15 +695,16 @@ def test_filter_relationship_many_to_many_nested(session): "readers": {"edges": [{"node": {"name": "Bip"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test connecting filters with "and" -def test_filter_logic_and(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_filter_logic_and(session): + await add_test_data(session) Query = create_schema(session) @@ -720,15 +731,16 @@ def test_filter_logic_and(session): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test connecting filters with "or" -def test_filter_logic_or(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_filter_logic_or(session): + await add_test_data(session) Query = create_schema(session) query = """ @@ -759,15 +771,16 @@ def test_filter_logic_or(session): } } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected # Test connecting filters with "and" and "or" together -def test_filter_logic_and_or(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_filter_logic_and_or(session): + await add_test_data(session) Query = create_schema(session) query = """ @@ -801,21 +814,16 @@ def test_filter_logic_and_or(session): } } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def add_hybrid_prop_test_data(session): - # create objects +async def add_hybrid_prop_test_data(session): cart = ShoppingCart() - - # set relationships - - # save session.add(cart) - session.commit() + await eventually_await_session(session, "commit") def create_hybrid_prop_schema(session): @@ -842,8 +850,9 @@ class Query(graphene.ObjectType): # Test filtering over and returning hybrid_property -def test_filter_hybrid_property(session): - add_hybrid_prop_test_data(session) +@pytest.mark.asyncio +async def test_filter_hybrid_property(session): + await add_hybrid_prop_test_data(session) Query = create_hybrid_prop_schema(session) # test hybrid_prop_int @@ -866,7 +875,7 @@ def test_filter_hybrid_property(session): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -891,7 +900,7 @@ def test_filter_hybrid_property(session): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -911,7 +920,7 @@ def test_filter_hybrid_property(session): } """ schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert len(result["carts"]["edges"]) == 1 @@ -932,7 +941,7 @@ def test_filter_hybrid_property(session): """ schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert len(result["carts"]["edges"]) == 1 @@ -952,7 +961,7 @@ def test_filter_hybrid_property(session): } """ schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert len(result["carts"]["edges"]) == 1 @@ -962,8 +971,9 @@ def test_filter_hybrid_property(session): # Test edge cases to improve test coverage -def test_filter_edge_cases(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_filter_edge_cases(session): + await add_test_data(session) # test disabling filtering class ArticleType(SQLAlchemyObjectType): @@ -982,8 +992,9 @@ class Query(graphene.ObjectType): # Test additional filter types to improve test coverage -def test_additional_filters(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_additional_filters(session): + await add_test_data(session) Query = create_schema(session) # test n_eq and not_in filters @@ -1002,7 +1013,7 @@ def test_additional_filters(session): "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -1023,7 +1034,7 @@ def test_additional_filters(session): "pets": {"edges": [{"node": {"name": "Snoopy"}}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected From 8e45a3a42677fa2b1cba75a4cc06f5584713685b Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 16 Jan 2023 08:56:25 -0500 Subject: [PATCH 52/81] fix: remove print statements --- graphene_sqlalchemy/fields.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index d978c736..80d093f6 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -11,7 +11,13 @@ from .batching import get_batch_resolver from .filters import BaseTypeFilter -from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_nullable_type, get_query, get_session +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + EnumValue, + get_nullable_type, + get_query, + get_session, +) if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession @@ -99,9 +105,7 @@ def get_query(cls, model, info, sort=None, filter=None, **args): assert isinstance(filter, dict) filter_type: BaseTypeFilter = type(filter) query, clauses = filter_type.execute_filters(query, filter) - print("1: ", query, clauses) query = query.filter(*clauses) - print("2: ", query) return query @classmethod @@ -148,7 +152,6 @@ async def resolve_connection_async( session = get_session(info.context) if resolved is None: query = cls.get_query(model, info, **args) - print("HERE: ", query) resolved = (await session.scalars(query)).all() if isinstance(resolved, Query): _len = resolved.count() From 2495448ecfd2c483cb9be21b210904867863cc13 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 16 Jan 2023 13:36:16 -0500 Subject: [PATCH 53/81] chore: cleanup --- graphene_sqlalchemy/converter.py | 2 ++ graphene_sqlalchemy/fields.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 177a5c3c..834c067b 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -192,6 +192,7 @@ def _convert_o2m_or_m2m_relationship( child_type = obj_type._meta.registry.get_type_for_model( relationship_prop.mapper.entity ) + if not child_type._meta.connection: # check if we need to use non-null fields list_type = ( @@ -641,6 +642,7 @@ def convert_sqlalchemy_hybrid_property_bare_str(type_arg: str, **kwargs): return convert_sqlalchemy_type(ForwardRef(type_arg), **kwargs) + def convert_hybrid_property_return_type(hybrid_prop): # Grab the original method's return type annotations from inside the hybrid property return_type_annotation = hybrid_prop.fget.__annotations__.get("return", None) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 80d093f6..ef798852 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -48,9 +48,9 @@ def __init__(self, type_, *args, **kwargs): nullable_type = get_nullable_type(type_) # Handle Sorting and Filtering if ( - nullable_type + "sort" not in kwargs + and nullable_type and issubclass(nullable_type, Connection) - and "sort" not in kwargs ): # Let super class raise if type is not a Connection try: @@ -66,9 +66,9 @@ def __init__(self, type_, *args, **kwargs): del kwargs["sort"] if ( - nullable_type + "filter" not in kwargs + and nullable_type and issubclass(nullable_type, Connection) - and "filter" not in kwargs ): # Only add filtering if a filter argument exists on the object type filter_argument = nullable_type.Edge.node._type.get_filter_argument() From f37dd8b8c266847b5874f6c629280652a21f37d3 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 16 Jan 2023 14:47:32 -0500 Subject: [PATCH 54/81] fix: make converter tests work --- graphene_sqlalchemy/types.py | 42 +++++++++++++++--------------------- tox.ini | 3 --- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 580b1b45..0fb59ddc 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,3 +1,4 @@ +import types import warnings from collections import OrderedDict from inspect import isawaitable @@ -30,11 +31,11 @@ sort_enum_for_object_type, ) from .filters import ( + BaseTypeFilter, BooleanFilter, FloatFilter, IdFilter, IntFilter, - BaseTypeFilter, RelationshipFilter, StringFilter, ) @@ -157,21 +158,11 @@ def filter_field_from_type_field( # If a custom filter type was set for this field, use it here if filter_type: return graphene.InputField(filter_type) - # fixme one test case fails where, find out why if issubclass(type(field), graphene.Scalar): filter_class = registry.get_filter_for_scalar_type(type(field)) return graphene.InputField(filter_class) - - elif isinstance(field.type, graphene.List): - print("got field with list type") - pass - elif isinstance(field, graphene.List): - print("Got list") - pass - elif isinstance(field.type, graphene.Dynamic): - pass # If the field is Dynamic, we don't know its type yet and can't select the right filter - elif isinstance(field, graphene.Dynamic): + if isinstance(field, graphene.Dynamic): def resolve_dynamic(): # Resolve Dynamic Type @@ -206,7 +197,18 @@ def resolve_dynamic(): return graphene.Dynamic(resolve_dynamic) - elif isinstance(field, graphene.Field): + if isinstance(field, graphene.List): + print("Got list") + return + if isinstance(field._type, types.FunctionType): + print("got field with function type") + return + if isinstance(field._type, graphene.Dynamic): + return + if isinstance(field._type, graphene.List): + print("got field with list type") + return + if isinstance(field, graphene.Field): type_ = get_nullable_type(field.type) # Field might be a SQLAlchemyObjectType, due to hybrid properties if issubclass(type_, SQLAlchemyObjectType): @@ -219,10 +221,8 @@ def resolve_dynamic(): ) return None return graphene.InputField(filter_class) - else: - raise Exception( - f"Expected a graphene.Field or graphene.Dynamic, but got: {field}" - ) + + raise Exception(f"Expected a graphene.Field or graphene.Dynamic, but got: {field}") def get_polymorphic_on(model): @@ -368,14 +368,6 @@ def construct_fields_and_filters( return fields, filters -class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): - model = None # type: sqlalchemy.Model - registry = None # type: sqlalchemy.Registry - connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] - id = None # type: str - filter_class: Type[BaseTypeFilter] = None - - class SQLAlchemyBase(BaseType): """ This class contains initialization code that is common to both ObjectTypes diff --git a/tox.ini b/tox.ini index 27be21f2..2802dee0 100644 --- a/tox.ini +++ b/tox.ini @@ -38,6 +38,3 @@ basepython = python3.10 deps = -e.[dev] commands = flake8 --exclude setup.py,docs,examples,tests,.tox --max-line-length 120 - -[pytest] -asyncio_mode = auto From 594229157a4ed1c89b5dcb97e523f1e4c65c1f92 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 16 Jan 2023 14:49:56 -0500 Subject: [PATCH 55/81] fix: revert ci --- .github/workflows/tests.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index de78190d..8b3cadfc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,6 +1,12 @@ name: Tests -on: [push, pull_request] +on: + push: + branches: + - 'master' + pull_request: + branches: + - '*' jobs: test: From ace0abac7ed711e1af4d35f551aba900a01014e4 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 13 Feb 2023 13:35:32 -0500 Subject: [PATCH 56/81] add initial filter docs --- docs/filters.rst | 217 +++++++++++++++++++++++++++++++++++++++++++++++ docs/index.rst | 1 + 2 files changed, 218 insertions(+) create mode 100644 docs/filters.rst diff --git a/docs/filters.rst b/docs/filters.rst new file mode 100644 index 00000000..dd797e65 --- /dev/null +++ b/docs/filters.rst @@ -0,0 +1,217 @@ +======= +Filters +======= + +Starting in graphene-sqlalchemy version 3, the SQLAlchemyConnectionField class implements filtering by default. The query utilizes a ``filter`` keyword to specify a filter class that inherits from ``graphene.InputObjectType``. + +Migrating from graphene-sqlalchemy-filter +--------------------------------------------- + +If like many of us, you have been using |graphene-sqlalchemy-filter|_ to implement filters and would like to use the in-built mechanism here, there are a couple key differences to note. Mainly, in an effort to simplify the generated schema, filter keywords are nested under their respective fields instead of concatenated. For example, the filter partial ``{usernameIn: ["moderator", "cool guy"]}`` would be represented as ``{username: {in: ["moderator", "cool guy"]}}``. + +.. |graphene-sqlalchemy-filter| replace:: ``graphene-sqlalchemy-filter`` +.. _graphene-sqlalchemy-filter: https://github.com/art1415926535/graphene-sqlalchemy-filter + +Further, some of the constructs found in libraries like `DGraph's DQL `_ have been implemented, so if you have created custom implementations for these features, you may want to take a look at the examples below. + + +Example model +------------- + +Take as example a Pet model similar to that in the sorting example. We will use variations on this arrangement for the following examples. + +.. code:: + + class Pet(Base): + __tablename__ = 'pets' + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + age = Column(Integer()) + + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + + + class Query(ObjectType): + allPets = SQLAlchemyConnectionField(PetNode.connection) + + +Simple filter example +--------------------- + +Filters are defined at the object level through the ``BaseTypeFilter`` class. The ``BaseType`` encompasses both Graphene ``ObjectType``\ s and ``Interface``\ s. Each ``BaseTypeFilter`` instance may define fields via ``FieldFilter`` and relationships via ``RelationshipFilter``. Here's a basic example querying a single field on the Pet model: + +.. code:: + + allPets(filter: {name: {eq: "Fido"}}){ + edges { + node { + name + } + } + } + +This will return all pets with the name "Fido". + + +Custom filter types +------------------- + +If you'd like to implement custom behavior for filtering a field, you can do so by extending one of the base filter classes in ``graphene_sqlalchemy.filters``. For example, if you'd like to add a ``divisible_by`` keyword to filter the age attribute on the ``Pet`` model, you can do so as follows: + +.. code:: python + + class MathFilter(FloatFilter): + class Meta: + graphene_type = graphene.Float + + @classmethod + def divisible_by_filter(cls, query, field, val: int) -> bool: + return is_(field % val, 0) + + class PetType(SQLAlchemyObjectType): + ... + + age = ORMField(filter_type=MathFilter) + + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) + + +Filtering over relationships with RelationshipFilter +---------------------------------------------------- + +When a filter class field refers to another object in a relationship, you may nest filters on relationship object attributes. This happens directly for 1:1 and m:1 relationships and through the ``contains`` and ``containsExactly`` keywords for 1:n and m:n relationships. + + +:1 relationships +^^^^^^^^^^^^^^^^ + +When an object or interface defines a singular relationship, relationship object attributes may be filtered directly like so: + +Take the following SQLAlchemy model definition as an example: + +.. code:: python + + class Pet + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + person_id = Column(Integer(), ForeignKey("persons.id")) + + class Person + id = Column(Integer(), primary_key=True) + pets = relationship("Pet", backref="person") + + +Then, this query will return all pets whose person is named "Ada": + +.. code:: + + allPets(filter: { + person: {name: {eq: "Ada"}} + }) { + ... + } + + +:n relationships +^^^^^^^^^^^^^^^^ + +However, for plural relationships, relationship object attributes must be filtered through either ``contains`` or ``containsExactly``: + +Now, using a many-to-many model definition: + +.. code:: python + + people_pets_table = sqlalchemy.Table( + "people_pets", + Base.metadata, + Column("person_id", ForeignKey("people.id")), + Column("pet_id", ForeignKey("pets.id")), + ) + + class Pet + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + + class Person + --tablename__ = "people" + id = Column(Integer(), primary_key=True) + pets = relationship("Pet", backref="people") + + +this query will return all pets which have a person named "Ben" in their ``persons`` list. + +.. code:: + + allPets(filter: { + people: { + contains: [{name: {eq: "Ben"}}], + } + }) { + ... + } + + +and this one will return all pets which hvae a person list that contains exactly the people "Ada" and "Ben" and no fewer or people with other names. + +.. code:: + + allPets(filter: { + articles: { + containsExactly: [ + {name: {eq: "Ada"}}, + {name: {eq: "Ben"}}, + ], + } + }) { + ... + } + +And/Or Logic +------------ + +Filters can also be chained together logically using `and` and `or` keywords nested under `filter`. Clauses are passed directly to `sqlalchemy.and_` and `slqlalchemy.or_`, respectively. To return all pets named "Fido" or "Spot", use: + + +.. code:: + + allPets(filter: { + or: [ + {name: {eq: "Fido"}}, + {name: {eq: "Spot"}}, + ] + }) { + ... + } + +And to return all pets that are named "Fido" or are 5 years old and named "Spot", use: + +.. code:: + + allPets(filter: { + or: [ + {name: {eq: "Fido"}}, + { and: [ + {name: {eq: "Spot"}}, + {age: {eq: 5}} + } + ] + }) { + ... + } + + +Hybrid Property support +----------------------- + +Filtering over SQLAlchemy `hybrid properties `_ is fully supported. + + +Reporting feedback and bugs +--------------------------- + +Filtering is a new feature to graphene-sqlalchemy, so please `post an issue on Github `_ if you run into any problems or have ideas on how to improve the implementation. diff --git a/docs/index.rst b/docs/index.rst index 81b2f316..99b91422 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,4 +8,5 @@ Contents: tutorial tips + filters examples From fa0eecc44d841a938868e7794fa65983597c3eea Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 13 Feb 2023 15:33:13 -0500 Subject: [PATCH 57/81] fix: typo --- docs/filters.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/filters.rst b/docs/filters.rst index dd797e65..e50bab98 100644 --- a/docs/filters.rst +++ b/docs/filters.rst @@ -138,7 +138,7 @@ Now, using a many-to-many model definition: name = Column(String(30)) class Person - --tablename__ = "people" + __tablename__ = "people" id = Column(Integer(), primary_key=True) pets = relationship("Pet", backref="people") From 3db74110bb6da7bff1face2ada495afd191f6545 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Tue, 14 Feb 2023 22:00:23 -0500 Subject: [PATCH 58/81] fix: test nits --- graphene_sqlalchemy/filters.py | 11 ++++++----- graphene_sqlalchemy/tests/test_filters.py | 8 +------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index c833658a..3967b0e4 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -11,7 +11,6 @@ ) from graphene_sqlalchemy.utils import is_list - BaseTypeFilterSelf = TypeVar( "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer ) @@ -41,9 +40,7 @@ class BaseTypeFilter(graphene.InputObjectType): def __init_subclass_with_meta__( cls, filter_fields=None, model=None, _meta=None, **options ): - from graphene_sqlalchemy.converter import ( - convert_sqlalchemy_type, - ) + from graphene_sqlalchemy.converter import convert_sqlalchemy_type # Init meta options class if it doesn't exist already if not _meta: @@ -241,7 +238,9 @@ def __init_subclass_with_meta__(cls, graphene_type=None, _meta=None, **options): ), "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class replace_type_vars = {ScalarFilterInputType: _meta.graphene_type} - field_type = convert_sqlalchemy_type(_annotations.get("val", str), replace_type_vars=replace_type_vars) + field_type = convert_sqlalchemy_type( + _annotations.get("val", str), replace_type_vars=replace_type_vars + ) new_filter_fields.update({field_name: graphene.InputField(field_type)}) # Add all fields to the meta options. graphene.InputbjectType will take care of the rest @@ -274,6 +273,8 @@ def in_filter(cls, query, field, val: List[ScalarFilterInputType]): def not_in_filter(cls, query, field, val: List[ScalarFilterInputType]): return field.notin_(val) + # TODO add like/ilike + @classmethod def execute_filters( cls, query, field, filter_dict: Dict[str, any] diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index f706a19c..5d6130e4 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -68,42 +68,36 @@ class Meta: model = Article name = "Article" interfaces = (relay.Node,) - connection_class = Connection class ImageType(SQLAlchemyObjectType): class Meta: model = Image name = "Image" interfaces = (relay.Node,) - connection_class = Connection class PetType(SQLAlchemyObjectType): class Meta: model = Pet name = "Pet" interfaces = (relay.Node,) - connection_class = Connection class ReaderType(SQLAlchemyObjectType): class Meta: model = Reader name = "Reader" interfaces = (relay.Node,) - connection_class = Connection class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter name = "Reporter" interfaces = (relay.Node,) - connection_class = Connection class TagType(SQLAlchemyObjectType): class Meta: model = Tag name = "Tag" interfaces = (relay.Node,) - connection_class = Connection class Query(graphene.ObjectType): node = relay.Node.Field() @@ -139,7 +133,7 @@ async def test_filter_simple(session): query = """ query { - reporters (filter: {lastName: {eq: "Roe"}}) { + reporters (filter: {lastName: {eq: "Roe", like: "oe"}}) { edges { node { firstName From 56620150364bb331ba285fa2b4ae758c32832713 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 27 Feb 2023 04:48:49 -0500 Subject: [PATCH 59/81] add basic filter example app --- docs/filters.rst | 18 +++++------- docs/requirements.txt | 1 + examples/filters/README.md | 47 +++++++++++++++++++++++++++++ examples/filters/__init__.py | 0 examples/filters/app.py | 18 ++++++++++++ examples/filters/database.py | 49 +++++++++++++++++++++++++++++++ examples/filters/models.py | 34 +++++++++++++++++++++ examples/filters/requirements.txt | 3 ++ examples/filters/run.sh | 1 + examples/filters/schema.py | 42 ++++++++++++++++++++++++++ 10 files changed, 202 insertions(+), 11 deletions(-) create mode 100644 examples/filters/README.md create mode 100644 examples/filters/__init__.py create mode 100644 examples/filters/app.py create mode 100644 examples/filters/database.py create mode 100644 examples/filters/models.py create mode 100644 examples/filters/requirements.txt create mode 100755 examples/filters/run.sh create mode 100644 examples/filters/schema.py diff --git a/docs/filters.rst b/docs/filters.rst index e50bab98..ac36803d 100644 --- a/docs/filters.rst +++ b/docs/filters.rst @@ -34,7 +34,7 @@ Take as example a Pet model similar to that in the sorting example. We will use model = Pet - class Query(ObjectType): + class Query(graphene.ObjectType): allPets = SQLAlchemyConnectionField(PetNode.connection) @@ -96,12 +96,11 @@ Take the following SQLAlchemy model definition as an example: .. code:: python class Pet - id = Column(Integer(), primary_key=True) - name = Column(String(30)) - person_id = Column(Integer(), ForeignKey("persons.id")) + ... + person_id = Column(Integer(), ForeignKey("people.id")) class Person - id = Column(Integer(), primary_key=True) + ... pets = relationship("Pet", backref="person") @@ -133,17 +132,14 @@ Now, using a many-to-many model definition: ) class Pet - __tablename__ = "pets" - id = Column(Integer(), primary_key=True) - name = Column(String(30)) + ... class Person - __tablename__ = "people" - id = Column(Integer(), primary_key=True) + ... pets = relationship("Pet", backref="people") -this query will return all pets which have a person named "Ben" in their ``persons`` list. +this query will return all pets which have a person named "Ben" in their ``people`` list. .. code:: diff --git a/docs/requirements.txt b/docs/requirements.txt index 666a8c9d..220b7cfb 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,3 @@ +sphinx # Docs template http://graphene-python.org/sphinx_graphene_theme.zip diff --git a/examples/filters/README.md b/examples/filters/README.md new file mode 100644 index 00000000..a72e75de --- /dev/null +++ b/examples/filters/README.md @@ -0,0 +1,47 @@ +Example Filters Project +================================ + +This example highlights the ability to filter queries in graphene-sqlalchemy. + +The project contains two models, one named `Department` and another +named `Employee`. + +Getting started +--------------- + +First you'll need to get the source of the project. Do this by cloning the +whole Graphene-SQLAlchemy repository: + +```bash +# Get the example project code +git clone https://github.com/graphql-python/graphene-sqlalchemy.git +cd graphene-sqlalchemy/examples/filters +``` + +It is recommended to create a virtual environment +for this project. We'll do this using +[virtualenv](http://docs.python-guide.org/en/latest/dev/virtualenvs/) +to keep things simple, +but you may also find something like +[virtualenvwrapper](https://virtualenvwrapper.readthedocs.org/en/latest/) +to be useful: + +```bash +# Create a virtualenv in which we can install the dependencies +virtualenv env +source env/bin/activate +``` + +Install our dependencies: + +```bash +pip install -r requirements.txt +``` + +The following command will setup the database, and start the server: + +```bash +python app.py +``` + +Now head over to your favorite GraphQL client, POST to [http://127.0.0.1:5000/graphql](http://127.0.0.1:5000/graphql) and run some queries! diff --git a/examples/filters/__init__.py b/examples/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/filters/app.py b/examples/filters/app.py new file mode 100644 index 00000000..6cb15633 --- /dev/null +++ b/examples/filters/app.py @@ -0,0 +1,18 @@ +from database import init_db +from fastapi import FastAPI +from schema import schema +from starlette_graphene3 import GraphQLApp, make_playground_handler + + +def create_app() -> FastAPI: + print("HERE") + init_db() + print("HERE?") + app = FastAPI() + + app.mount("/graphql", GraphQLApp(schema, on_get=make_playground_handler())) + + return app + + +app = create_app() diff --git a/examples/filters/database.py b/examples/filters/database.py new file mode 100644 index 00000000..8f6522f7 --- /dev/null +++ b/examples/filters/database.py @@ -0,0 +1,49 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() +engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, echo=True +) +session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +from sqlalchemy.orm import scoped_session as scoped_session_factory + +scoped_session = scoped_session_factory(session_factory) + +Base.query = scoped_session.query_property() +Base.metadata.bind = engine + + +def init_db(): + from models import Person, Pet, Toy + + Base.metadata.create_all() + scoped_session.execute("PRAGMA foreign_keys=on") + db = scoped_session() + + person1 = Person(name="A") + person2 = Person(name="B") + + pet1 = Pet(name="Spot") + pet2 = Pet(name="Milo") + + toy1 = Toy(name="disc") + toy2 = Toy(name="ball") + + person1.pet = pet1 + person2.pet = pet2 + + pet1.toys.append(toy1) + pet2.toys.append(toy1) + pet2.toys.append(toy2) + + db.add(person1) + db.add(person2) + db.add(pet1) + db.add(pet2) + db.add(toy1) + db.add(toy2) + + db.commit() diff --git a/examples/filters/models.py b/examples/filters/models.py new file mode 100644 index 00000000..1b22956b --- /dev/null +++ b/examples/filters/models.py @@ -0,0 +1,34 @@ +import sqlalchemy +from database import Base +from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + + +class Pet(Base): + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + age = Column(Integer()) + person_id = Column(Integer(), ForeignKey("people.id")) + + +class Person(Base): + __tablename__ = "people" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + pets = relationship("Pet", backref="person") + + +pets_toys_table = sqlalchemy.Table( + "pets_toys", + Base.metadata, + Column("pet_id", ForeignKey("pets.id")), + Column("toy_id", ForeignKey("toys.id")), +) + + +class Toy(Base): + __tablename__ = "toys" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pets = relationship("Pet", secondary=pets_toys_table, backref="toys") diff --git a/examples/filters/requirements.txt b/examples/filters/requirements.txt new file mode 100644 index 00000000..b433ec59 --- /dev/null +++ b/examples/filters/requirements.txt @@ -0,0 +1,3 @@ +-e ../../ +fastapi +uvicorn diff --git a/examples/filters/run.sh b/examples/filters/run.sh new file mode 100755 index 00000000..ec365444 --- /dev/null +++ b/examples/filters/run.sh @@ -0,0 +1 @@ +uvicorn app:app --port 5000 diff --git a/examples/filters/schema.py b/examples/filters/schema.py new file mode 100644 index 00000000..2728cab7 --- /dev/null +++ b/examples/filters/schema.py @@ -0,0 +1,42 @@ +from models import Person as PersonModel +from models import Pet as PetModel +from models import Toy as ToyModel + +import graphene +from graphene import relay +from graphene_sqlalchemy import SQLAlchemyObjectType +from graphene_sqlalchemy.fields import SQLAlchemyConnectionField + + +class Pet(SQLAlchemyObjectType): + class Meta: + model = PetModel + name = "Pet" + interfaces = (relay.Node,) + batching = True + + +class Person(SQLAlchemyObjectType): + class Meta: + model = PersonModel + name = "Person" + interfaces = (relay.Node,) + batching = True + + +class Toy(SQLAlchemyObjectType): + class Meta: + model = ToyModel + name = "Toy" + interfaces = (relay.Node,) + batching = True + + +class Query(graphene.ObjectType): + node = relay.Node.Field() + pets = SQLAlchemyConnectionField(Pet.connection) + people = SQLAlchemyConnectionField(Person.connection) + toys = SQLAlchemyConnectionField(Toy.connection) + + +schema = graphene.Schema(query=Query) From 9e8f1a5f07ab068f03a9e27d80b5c49ad6e9fd5b Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 27 Feb 2023 14:58:53 -0500 Subject: [PATCH 60/81] add like methods to StringFilter --- graphene_sqlalchemy/filters.py | 12 ++++++++++++ graphene_sqlalchemy/tests/test_filters.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 3967b0e4..f7790fa5 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -293,6 +293,18 @@ class StringFilter(FieldFilter): class Meta: graphene_type = graphene.String + @classmethod + def like_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.like(val) + + @classmethod + def ilike_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.ilike(val) + + @classmethod + def notlike_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.notlike(val) + class BooleanFilter(FieldFilter): class Meta: diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 5d6130e4..ab357502 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -133,7 +133,7 @@ async def test_filter_simple(session): query = """ query { - reporters (filter: {lastName: {eq: "Roe", like: "oe"}}) { + reporters (filter: {lastName: {eq: "Roe", like: "%oe"}}) { edges { node { firstName From d2360a1982d368ae93d408e409fedf95c19947f0 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Tue, 28 Feb 2023 15:47:34 -0500 Subject: [PATCH 61/81] raise errors in filters tests --- graphene_sqlalchemy/filters.py | 5 ++ graphene_sqlalchemy/tests/test_filters.py | 103 +++++++--------------- 2 files changed, 37 insertions(+), 71 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index f7790fa5..2fbc12e6 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -443,6 +443,11 @@ def contains_exactly_filter( ): print("Contains exactly called: ", query, val) session = query.session + # TODO change logic as follows: + # do select() + # write down query without session + # main_query.subqueryload() + # use query.where() instead of query.filter() child_model_ids = [] for v in val: print("Contains exactly loop: ", v) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index ab357502..48a4620a 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -26,6 +26,15 @@ # fp.write(str(schema)) +def assert_and_raise_result(result, expected): + if result.errors: + for error in result.errors: + raise error + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + async def add_test_data(session): reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) @@ -147,11 +156,7 @@ async def test_filter_simple(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - print(result) - print(result.errors) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # Test a custom filter type @@ -199,9 +204,7 @@ class Query(graphene.ObjectType): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # Test a 1:1 relationship @@ -234,9 +237,7 @@ async def test_filter_relationship_one_to_one(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # Test a 1:n relationship @@ -266,9 +267,7 @@ async def test_filter_relationship_one_to_many(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test containsExactly query = """ @@ -295,9 +294,7 @@ async def test_filter_relationship_one_to_many(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) async def add_n2m_test_data(session): @@ -365,9 +362,7 @@ async def test_filter_relationship_many_to_many_contains(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test contains 2 query = """ @@ -396,9 +391,7 @@ async def test_filter_relationship_many_to_many_contains(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test reverse query = """ @@ -427,9 +420,7 @@ async def test_filter_relationship_many_to_many_contains(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # Test n:m relationship containsExactly @@ -463,9 +454,7 @@ async def test_filter_relationship_many_to_many_contains_exactly(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test containsExactly 2 query = """ @@ -490,9 +479,7 @@ async def test_filter_relationship_many_to_many_contains_exactly(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test reverse query = """ @@ -518,9 +505,7 @@ async def test_filter_relationship_many_to_many_contains_exactly(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # Test n:m relationship both contains and containsExactly @@ -555,9 +540,7 @@ async def test_filter_relationship_many_to_many_contains_and_contains_exactly(se } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # Test n:m nested relationship @@ -590,9 +573,7 @@ async def test_filter_relationship_many_to_many_nested(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test nested readers->articles->tags query = """ @@ -623,9 +604,7 @@ async def test_filter_relationship_many_to_many_nested(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test nested reverse query = """ @@ -656,9 +635,7 @@ async def test_filter_relationship_many_to_many_nested(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test filter on both levels of nesting query = """ @@ -690,9 +667,7 @@ async def test_filter_relationship_many_to_many_nested(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # Test connecting filters with "and" @@ -726,9 +701,7 @@ async def test_filter_logic_and(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # Test connecting filters with "or" @@ -766,9 +739,7 @@ async def test_filter_logic_or(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # Test connecting filters with "and" and "or" together @@ -809,9 +780,7 @@ async def test_filter_logic_and_or(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) async def add_hybrid_prop_test_data(session): @@ -870,9 +839,7 @@ async def test_filter_hybrid_property(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test hybrid_prop_float query = """ @@ -895,9 +862,7 @@ async def test_filter_hybrid_property(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test hybrid_prop different model without expression query = """ @@ -1008,9 +973,7 @@ async def test_additional_filters(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) # test gt, lt, gte, and lte filters query = """ @@ -1029,6 +992,4 @@ async def test_additional_filters(session): } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) - assert not result.errors - result = to_std_dicts(result.data) - assert result == expected + assert_and_raise_result(result, expected) From 589c7d7551db7b277829d66d76234802396171e2 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 27 Mar 2023 15:14:57 -0400 Subject: [PATCH 62/81] remove contains_exactly logic --- graphene_sqlalchemy/filters.py | 37 +-------------- graphene_sqlalchemy/tests/test_filters.py | 56 ++++++++++++----------- 2 files changed, 31 insertions(+), 62 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 2fbc12e6..b04e27d7 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -441,42 +441,7 @@ def contains_exactly_filter( relationship_prop, val: List[ScalarFilterInputType], ): - print("Contains exactly called: ", query, val) - session = query.session - # TODO change logic as follows: - # do select() - # write down query without session - # main_query.subqueryload() - # use query.where() instead of query.filter() - child_model_ids = [] - for v in val: - print("Contains exactly loop: ", v) - - # Always alias the model - joined_model_alias = aliased(relationship_prop) - - subquery = session.query(joined_model_alias.id) - subquery, _clauses = cls._meta.base_type_filter.execute_filters( - subquery, v, model_alias=joined_model_alias - ) - subquery_ids = [s_id[0] for s_id in subquery.filter(and_(*_clauses)).all()] - child_model_ids.extend(subquery_ids) - - # Join the relationship onto the query - joined_model_alias = aliased(relationship_prop) - joined_field = field.of_type(joined_model_alias) - query = query.join(joined_field) - - # Construct clauses from child_model_ids - query = ( - query.filter(joined_model_alias.id.in_(child_model_ids)) - .group_by(parent_model) - .having(func.count(str(field)) == len(child_model_ids)) - # TODO should filter on aliased field - # .having(func.count(joined_field) == len(child_model_ids)) - ) - - return query, [] + raise NotImplementedError @classmethod def execute_filters( diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 48a4620a..be99222a 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -269,32 +269,33 @@ async def test_filter_relationship_one_to_many(session): result = await schema.execute_async(query, context_value={"session": session}) assert_and_raise_result(result, expected) - # test containsExactly - query = """ - query { - reporters (filter: { - articles: { - containsExactly: [ - {headline: {eq: "Hi!"}} - {headline: {eq: "Hello!"}} - ] - } - }) { - edges { - node { - firstName - lastName - } - } - } - } - """ - expected = { - "reporters": {"edges": [{"node": {"firstName": "John", "lastName": "Woe"}}]} - } - schema = graphene.Schema(query=Query) - result = await schema.execute_async(query, context_value={"session": session}) - assert_and_raise_result(result, expected) + # TODO test containsExactly + # # test containsExactly + # query = """ + # query { + # reporters (filter: { + # articles: { + # containsExactly: [ + # {headline: {eq: "Hi!"}} + # {headline: {eq: "Hello!"}} + # ] + # } + # }) { + # edges { + # node { + # firstName + # lastName + # } + # } + # } + # } + # """ + # expected = { + # "reporters": {"edges": [{"node": {"firstName": "John", "lastName": "Woe"}}]} + # } + # schema = graphene.Schema(query=Query) + # result = await schema.execute_async(query, context_value={"session": session}) + # assert_and_raise_result(result, expected) async def add_n2m_test_data(session): @@ -427,6 +428,7 @@ async def test_filter_relationship_many_to_many_contains(session): @pytest.mark.xfail @pytest.mark.asyncio async def test_filter_relationship_many_to_many_contains_exactly(session): + raise NotImplementedError await add_n2m_test_data(session) Query = create_schema(session) @@ -509,8 +511,10 @@ async def test_filter_relationship_many_to_many_contains_exactly(session): # Test n:m relationship both contains and containsExactly +@pytest.mark.xfail @pytest.mark.asyncio async def test_filter_relationship_many_to_many_contains_and_contains_exactly(session): + raise NotImplementedError await add_n2m_test_data(session) Query = create_schema(session) From 48b1c6c56151212d20af6968eb647e721950cf7c Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 10 Apr 2023 13:28:25 -0400 Subject: [PATCH 63/81] fix lint --- graphene_sqlalchemy/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index b04e27d7..b0ce31be 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,7 +1,7 @@ import re from typing import Any, Dict, List, Tuple, Type, TypeVar, Union -from sqlalchemy import and_, func, not_, or_ +from sqlalchemy import and_, not_, or_ from sqlalchemy.orm import Query, aliased # , selectinload import graphene From cab376a7614490d27be9275d541677afe45b18ee Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 24 Apr 2023 12:24:30 -0400 Subject: [PATCH 64/81] fix: sqla 1.4 async filter tests passing with distinct --- graphene_sqlalchemy/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index b0ce31be..becdf373 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -422,7 +422,7 @@ def contains_filter( joined_model_alias = aliased(relationship_prop) # Join the aliased model onto the query - query = query.join(field.of_type(joined_model_alias)) + query = query.join(field.of_type(joined_model_alias)).distinct() print("Joined model", relationship_prop) print(query) # pass the alias so group can join group From db9f794339d4448b7ea1e0c7b9668a8a84399a39 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 24 Apr 2023 14:38:45 -0400 Subject: [PATCH 65/81] cleanup: automatically register field filters --- examples/filters/app.py | 2 -- graphene_sqlalchemy/types.py | 40 +++++++++++++++--------------------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/examples/filters/app.py b/examples/filters/app.py index 6cb15633..ab918da7 100644 --- a/examples/filters/app.py +++ b/examples/filters/app.py @@ -5,9 +5,7 @@ def create_app() -> FastAPI: - print("HERE") init_db() - print("HERE?") app = FastAPI() app.mount("/graphql", GraphQLApp(schema, on_get=make_playground_handler())) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 25511fff..fee9d386 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,3 +1,4 @@ +import inspect import types import warnings from collections import OrderedDict @@ -18,6 +19,7 @@ from graphene.types.unmountedtype import UnmountedType from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType +import graphene_sqlalchemy.filters as gsa_filters from .converter import ( convert_sqlalchemy_column, @@ -30,15 +32,7 @@ sort_argument_for_object_type, sort_enum_for_object_type, ) -from .filters import ( - BaseTypeFilter, - BooleanFilter, - FloatFilter, - IdFilter, - IntFilter, - RelationshipFilter, - StringFilter, -) +from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( @@ -405,21 +399,21 @@ def __init_subclass_with_meta__( ) if not registry: + # TODO add documentation for users to register their own filters registry = get_global_registry() - # TODO way of doing this automatically? - get_global_registry().register_filter_for_scalar_type( - graphene.Float, FloatFilter - ) - get_global_registry().register_filter_for_scalar_type( - graphene.Int, IntFilter - ) - get_global_registry().register_filter_for_scalar_type( - graphene.String, StringFilter - ) - get_global_registry().register_filter_for_scalar_type( - graphene.Boolean, BooleanFilter - ) - get_global_registry().register_filter_for_scalar_type(graphene.ID, IdFilter) + field_filter_classes = [ + filter_cls[1] + for filter_cls in inspect.getmembers(gsa_filters, inspect.isclass) + if ( + filter_cls[1] is not FieldFilter + and FieldFilter in filter_cls[1].__mro__ + and getattr(filter_cls[1]._meta, "graphene_type", False) + ) + ] + for field_filter_class in field_filter_classes: + get_global_registry().register_filter_for_scalar_type( + field_filter_class._meta.graphene_type, field_filter_class + ) assert isinstance(registry, Registry), ( "The attribute registry in {} needs to be an instance of " From 21eba2e0ede2f8361d0e719e1b2082216f723fea Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 24 Apr 2023 22:08:05 +0200 Subject: [PATCH 66/81] fix: convert type vars in converter.py --- graphene_sqlalchemy/converter.py | 116 ++++++++++---------- graphene_sqlalchemy/tests/test_converter.py | 14 ++- 2 files changed, 73 insertions(+), 57 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 3d08c1b7..32062df3 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -3,7 +3,7 @@ import typing import uuid from decimal import Decimal -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union, cast, Dict, TypeVar from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql @@ -101,12 +101,12 @@ def is_column_nullable(column): def convert_sqlalchemy_relationship( - relationship_prop, - obj_type, - connection_field_factory, - batching, - orm_field_name, - **field_kwargs, + relationship_prop, + obj_type, + connection_field_factory, + batching, + orm_field_name, + **field_kwargs, ): """ :param sqlalchemy.RelationshipProperty relationship_prop: @@ -147,7 +147,7 @@ def dynamic_type(): def _convert_o2o_or_m2o_relationship( - relationship_prop, obj_type, batching, orm_field_name, **field_kwargs + relationship_prop, obj_type, batching, orm_field_name, **field_kwargs ): """ Convert one-to-one or many-to-one relationshsip. Return an object field. @@ -175,7 +175,7 @@ def _convert_o2o_or_m2o_relationship( def _convert_o2m_or_m2m_relationship( - relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs + relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs ): """ Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field. @@ -281,13 +281,12 @@ def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): @singledispatchbymatchfunction def convert_sqlalchemy_type( # noqa - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - replace_type_vars: typing.Dict[str, Any] = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + replace_type_vars: typing.Dict[str, Any] = None, + **kwargs, ): - if replace_type_vars and type_arg in replace_type_vars: return replace_type_vars[type_arg] @@ -302,7 +301,7 @@ def convert_sqlalchemy_type( # noqa @convert_sqlalchemy_type.register(safe_isinstance(DeclarativeMeta)) def convert_sqlalchemy_model_using_registry( - type_arg: Any, registry: Registry = None, **kwargs + type_arg: Any, registry: Registry = None, **kwargs ): registry_ = registry or get_global_registry() @@ -330,6 +329,11 @@ def convert_scalar_type(type_arg: Any, **kwargs): return type_arg +@convert_sqlalchemy_type.register(safe_isinstance(TypeVar)) +def convert_scalar_type(type_arg: Any, replace_type_vars: Dict[TypeVar, Any], **kwargs): + return replace_type_vars[type_arg] + + @convert_sqlalchemy_type.register(column_type_eq(str)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.String)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Text)) @@ -349,8 +353,8 @@ def convert_column_to_string(type_arg: Any, **kwargs): @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.UUIDType)) @convert_sqlalchemy_type.register(column_type_eq(uuid.UUID)) def convert_column_to_uuid( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.UUID @@ -358,8 +362,8 @@ def convert_column_to_uuid( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.DateTime)) @convert_sqlalchemy_type.register(column_type_eq(datetime.datetime)) def convert_column_to_datetime( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.DateTime @@ -367,8 +371,8 @@ def convert_column_to_datetime( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Time)) @convert_sqlalchemy_type.register(column_type_eq(datetime.time)) def convert_column_to_time( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Time @@ -376,8 +380,8 @@ def convert_column_to_time( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Date)) @convert_sqlalchemy_type.register(column_type_eq(datetime.date)) def convert_column_to_date( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Date @@ -386,10 +390,10 @@ def convert_column_to_date( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Integer)) @convert_sqlalchemy_type.register(column_type_eq(int)) def convert_column_to_int_or_id( - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): # fixme drop the primary key processing from here in another pr if column is not None: @@ -401,8 +405,8 @@ def convert_column_to_int_or_id( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Boolean)) @convert_sqlalchemy_type.register(column_type_eq(bool)) def convert_column_to_boolean( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Boolean @@ -412,8 +416,8 @@ def convert_column_to_boolean( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Numeric)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.BigInteger)) def convert_column_to_float( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Float @@ -421,10 +425,10 @@ def convert_column_to_float( @convert_sqlalchemy_type.register(column_type_eq(postgresql.ENUM)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Enum)) def convert_enum_to_enum( - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("SQL-Enum conversion requires a column") @@ -435,9 +439,9 @@ def convert_enum_to_enum( # TODO Make ChoiceType conversion consistent with other enums @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ChoiceType)) def convert_choice_to_enum( - type_arg: sqa_utils.ChoiceType, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - **kwargs, + type_arg: sqa_utils.ChoiceType, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("ChoiceType conversion requires a column") @@ -453,8 +457,8 @@ def convert_choice_to_enum( @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ScalarListType)) def convert_scalar_list_to_list( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.List(graphene.String) @@ -470,10 +474,10 @@ def init_array_list_recursive(inner_type, n): @convert_sqlalchemy_type.register(column_type_eq(sqa_types.ARRAY)) @convert_sqlalchemy_type.register(column_type_eq(postgresql.ARRAY)) def convert_array_to_list( - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("SQL-Array conversion requires a column") @@ -492,8 +496,8 @@ def convert_array_to_list( @convert_sqlalchemy_type.register(column_type_eq(postgresql.JSON)) @convert_sqlalchemy_type.register(column_type_eq(postgresql.JSONB)) def convert_json_to_string( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return JSONString @@ -501,18 +505,18 @@ def convert_json_to_string( @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.JSONType)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.JSON)) def convert_json_type_to_string( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return JSONString @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Variant)) def convert_variant_to_impl_type( - type_arg: sqa_types.Variant, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: sqa_types.Variant, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("Vaiant conversion requires a column") @@ -543,7 +547,7 @@ def is_union(type_arg: Any, **kwargs) -> bool: def graphene_union_for_py_union( - obj_types: typing.List[graphene.ObjectType], registry + obj_types: typing.List[graphene.ObjectType], registry ) -> graphene.Union: union_type = registry.get_union_for_object_types(obj_types) @@ -586,8 +590,8 @@ def convert_sqlalchemy_hybrid_property_union(type_arg: Any, **kwargs): # Now check if every type is instance of an ObjectType if not all( - isinstance(graphene_type, type(graphene.ObjectType)) - for graphene_type in graphene_types + isinstance(graphene_type, type(graphene.ObjectType)) + for graphene_type in graphene_types ): raise ValueError( "Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. " diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 929b4855..7ca42a6c 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,6 +1,6 @@ import enum import sys -from typing import Dict, Tuple, Union +from typing import Dict, Tuple, Union, TypeVar import pytest import sqlalchemy @@ -195,6 +195,18 @@ def hybrid_prop(self) -> "ShoppingCartItem": get_hybrid_property_type(hybrid_prop).type == ShoppingCartType +def test_converter_replace_type_var(): + + T = TypeVar("T") + + replace_type_vars = {T: graphene.String} + + field_type = convert_sqlalchemy_type( + T, replace_type_vars=replace_type_vars + ) + + assert field_type == graphene.String + @pytest.mark.skipif( sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" ) From 39834eda72bfbbd42a2dc06260c08c18c5e9a339 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 14 May 2023 15:21:23 +0200 Subject: [PATCH 67/81] chore: fix review comments --- .gitignore | 1 + graphene_sqlalchemy/tests/models.py | 14 ++--- graphene_sqlalchemy/tests/models_batching.py | 14 ++--- graphene_sqlalchemy/tests/test_converter.py | 63 +++++++------------- graphene_sqlalchemy/tests/utils.py | 13 +++- 5 files changed, 44 insertions(+), 61 deletions(-) diff --git a/.gitignore b/.gitignore index c4a735fe..47a82df0 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ __pycache__/ .Python env/ .venv/ +venv/ build/ develop-eggs/ dist/ diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index a2ccd82f..8349a394 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -16,12 +16,12 @@ String, Table, func, - select, ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, column_property, composite, mapper, relationship +from graphene_sqlalchemy.tests.utils import wrap_select_func from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 PetKind = Enum("cat", "dog", name="pet_kind") @@ -118,15 +118,9 @@ def hybrid_prop_bool(self) -> bool: def hybrid_prop_list(self) -> List[int]: return [1, 2, 3] - # TODO Remove when switching min sqlalchemy version to SQLAlchemy 1.4 - if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: - column_prop = column_property( - select(func.cast(func.count(id), Integer)), doc="Column property" - ) - else: - column_prop = column_property( - select([func.cast(func.count(id), Integer)]), doc="Column property" - ) + column_prop = column_property( + wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property" + ) composite_prop = composite( CompositeFullName, first_name, last_name, doc="Composite" diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py index dde6d45c..5dde366f 100644 --- a/graphene_sqlalchemy/tests/models_batching.py +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -11,12 +11,11 @@ String, Table, func, - select, ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import column_property, relationship -from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 +from graphene_sqlalchemy.tests.utils import wrap_select_func PetKind = Enum("cat", "dog", name="pet_kind") @@ -62,14 +61,9 @@ class Reporter(Base): articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) - if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: - column_prop = column_property( - select(func.cast(func.count(id), Integer)), doc="Column property" - ) - else: - column_prop = column_property( - select([func.cast(func.count(id), Integer)]), doc="Column property" - ) + column_prop = column_property( + wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property" + ) class Article(Base): diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index bfd3ee66..1a5e0093 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -2,19 +2,26 @@ import sys from typing import Dict, Union +import graphene import pytest import sqlalchemy_utils as sqa_utils -from sqlalchemy import Column, func, select, types +from graphene.relay import Node +from graphene.types.structures import Structure +from sqlalchemy import Column, func, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -import graphene -from graphene.relay import Node -from graphene.types.structures import Structure - +from .models import ( + Article, + CompositeFullName, + Pet, + Reporter, + ShoppingCart, + ShoppingCartItem, +) from ..converter import ( convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -24,15 +31,7 @@ from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory from ..registry import Registry, get_global_registry from ..types import ORMField, SQLAlchemyObjectType -from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_sqlalchemy_version_less_than -from .models import ( - Article, - CompositeFullName, - Pet, - Reporter, - ShoppingCart, - ShoppingCartItem, -) +from ..utils import is_sqlalchemy_version_less_than def mock_resolver(): @@ -88,9 +87,9 @@ def prop_method() -> int | str: return "not allowed in gql schema" with pytest.raises( - ValueError, - match=r"Cannot convert hybrid_property Union to " - r"graphene.Union: the Union contains scalars. \.*", + ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. \.*", ): get_hybrid_property_type(prop_method) @@ -337,25 +336,9 @@ class TestEnum(enum.IntEnum): assert graphene_type._meta.enum.__members__["two"].value == 2 -@pytest.mark.skipif( - not SQL_VERSION_HIGHER_EQUAL_THAN_1_4, - reason="SQLAlchemy <1.4 does not support this", -) -def test_should_columproperty_convert_sqa_20(): - field = get_field_from_column( - column_property(select(func.sum(func.cast(id, types.Integer))).where(id == 1)) - ) - - assert field.type == graphene.Int - - -@pytest.mark.skipif( - not is_sqlalchemy_version_less_than("2.0.0b1"), - reason="SQLAlchemy >=2.0 does not support this syntax, see convert_sqa_20", -) def test_should_columproperty_convert(): field = get_field_from_column( - column_property(select([func.sum(func.cast(id, types.Integer))]).where(id == 1)) + column_property(wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1)) ) assert field.type == graphene.Int @@ -654,8 +637,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_item_expected_types.items(): hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] @@ -666,7 +649,7 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property ################################################### @@ -714,8 +697,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_expected_types.items(): hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] @@ -726,5 +709,5 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py index 4a118243..6e843316 100644 --- a/graphene_sqlalchemy/tests/utils.py +++ b/graphene_sqlalchemy/tests/utils.py @@ -1,6 +1,10 @@ import inspect import re +from sqlalchemy import select + +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 + def to_std_dicts(value): """Convert nested ordered dicts to normal dicts for better comparison.""" @@ -18,8 +22,15 @@ def remove_cache_miss_stat(message): return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message) -async def eventually_await_session(session, func, *args): +def wrap_select_func(query): + # TODO remove this when we drop support for sqa < 2.0 + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + return select(query) + else: + return select([query]) + +async def eventually_await_session(session, func, *args): if inspect.iscoroutinefunction(getattr(session, func)): await getattr(session, func)(*args) else: From a9915d6af26ab8e6f0e74444db8c0cec7a8c6f05 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 14 May 2023 19:36:44 +0200 Subject: [PATCH 68/81] chore: update dependencies and fix test --- graphene_sqlalchemy/tests/test_converter.py | 1 + setup.py | 2 +- tox.ini | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 1a5e0093..4666d9a2 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -22,6 +22,7 @@ ShoppingCart, ShoppingCartItem, ) +from .utils import wrap_select_func from ..converter import ( convert_sqlalchemy_column, convert_sqlalchemy_composite, diff --git a/setup.py b/setup.py index 9650e6d2..0e828caa 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ # To keep things simple, we only support newer versions of Graphene "graphene>=3.0.0b7", "promise>=2.3", - "SQLAlchemy>=1.1,<2.1", + "SQLAlchemy>=1.1", "aiodataloader>=0.2.0,<1.0", ] diff --git a/tox.ini b/tox.ini index 1841cb1a..9ce901e4 100644 --- a/tox.ini +++ b/tox.ini @@ -24,7 +24,7 @@ deps = sql12: sqlalchemy>=1.2,<1.3 sql13: sqlalchemy>=1.3,<1.4 sql14: sqlalchemy>=1.4,<1.5 - sql20: sqlalchemy>=2.0.0b3,<2.1 + sql20: sqlalchemy>=2.0.0b3 setenv = SQLALCHEMY_WARN_20 = 1 commands = From 4712e10e5f16f04d1c780e6e14dfccc7a26798a7 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 14 May 2023 19:55:21 +0200 Subject: [PATCH 69/81] chore: update sqa-utils fix --- graphene_sqlalchemy/tests/conftest.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index d3fcedc9..89b357a4 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -3,13 +3,6 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -# fmt: off -# Fixme remove when https://github.com/kvesteri/sqlalchemy-utils/pull/644 is released #noqa -import sqlalchemy # noqa # isort:skip -if sqlalchemy.__version__ == "2.0.0b3": # noqa # isort:skip - sqlalchemy.__version__ = "2.0.0" # noqa # isort:skip -# fmt: on - import graphene from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 From a0abf8a00f09e7cdf3538c8e0bef7a3d281bc036 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 14 May 2023 21:41:45 +0200 Subject: [PATCH 70/81] fix: adjust test after sqlalchemy 2.0 update --- graphene_sqlalchemy/tests/models.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 68c96bf6..b638b5d4 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -20,11 +20,18 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, column_property, composite, mapper, relationship -from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter from sqlalchemy.sql.type_api import TypeEngine from graphene_sqlalchemy.tests.utils import wrap_select_func -from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2 + +# fmt: off +import sqlalchemy +if SQL_VERSION_HIGHER_EQUAL_THAN_2: + from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip +else: + from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip +# fmt: on PetKind = Enum("cat", "dog", name="pet_kind") @@ -343,7 +350,7 @@ class Employee(Person): ############################################ -class CustomIntegerColumn(_LookupExpressionAdapter, TypeEngine): +class CustomIntegerColumn(HasExpressionLookup, TypeEngine): """ Custom Column Type that our converters don't recognize Adapted from sqlalchemy.Integer From 1052b5269660ad4454fdcd0da97a54abbd836094 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 4 Jun 2023 22:38:13 +0200 Subject: [PATCH 71/81] fix: keep aliases during and + or filtering --- graphene_sqlalchemy/filters.py | 11 +-- graphene_sqlalchemy/tests/test_filters.py | 102 ++++++++++++++++++++++ 2 files changed, 108 insertions(+), 5 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index becdf373..1f227e17 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -83,19 +83,19 @@ def and_logic( query, filter_type: "BaseTypeFilter", val: List[BaseTypeFilterSelf], + model_alias=None, ): # # Get the model to join on the Filter Query # joined_model = filter_type._meta.model # # Always alias the model # joined_model_alias = aliased(joined_model) - clauses = [] for value in val: # # Join the aliased model onto the query # query = query.join(model_field.of_type(joined_model_alias)) query, _clauses = filter_type.execute_filters( - query, value + query, value, model_alias=model_alias ) # , model_alias=joined_model_alias) clauses += _clauses @@ -107,6 +107,7 @@ def or_logic( query, filter_type: "BaseTypeFilter", val: List[BaseTypeFilterSelf], + model_alias=None, ): # # Get the model to join on the Filter Query # joined_model = filter_type._meta.model @@ -119,7 +120,7 @@ def or_logic( # query = query.join(model_field.of_type(joined_model_alias)) query, _clauses = filter_type.execute_filters( - query, value + query, value, model_alias=model_alias ) # , model_alias=joined_model_alias) clauses += _clauses @@ -149,12 +150,12 @@ def execute_filters( # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) if field == "and": query, _clauses = cls.and_logic( - query, field_filter_type.of_type, field_filters + query, field_filter_type.of_type, field_filters, model_alias=model ) clauses.extend(_clauses) elif field == "or": query, _clauses = cls.or_logic( - query, field_filter_type.of_type, field_filters + query, field_filter_type.of_type, field_filters, model_alias=model ) clauses.extend(_clauses) else: diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index be99222a..20672c2f 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -424,6 +424,108 @@ async def test_filter_relationship_many_to_many_contains(session): assert_and_raise_result(result, expected) +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_with_and(session): + """ + This test is necessary to ensure we don't accidentally turn and-contains filter + into or-contains filters due to incorrect aliasing of the joined table. + """ + await add_n2m_test_data(session) + Query = create_schema(session) + + # test contains 1 + query = """ + query { + articles (filter: { + tags: { + contains: [{ + and: [ + { name: { in: ["sensational", "eye-grabbing"] } }, + { name: { eq: "eye-grabbing" } }, + ] + + } + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test contains 2 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { headline: { eq: "Article! Look!" } }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": { + "edges": [ + {"node": {"name": "sensational"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # Test n:m relationship containsExactly @pytest.mark.xfail @pytest.mark.asyncio From 0fb1db3e61c09a4513930f36e307f9a09d8c99fb Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 4 Jun 2023 22:50:59 +0200 Subject: [PATCH 72/81] chore: make flake8 happy --- graphene_sqlalchemy/converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 32062df3..3e37ce5f 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -330,7 +330,7 @@ def convert_scalar_type(type_arg: Any, **kwargs): @convert_sqlalchemy_type.register(safe_isinstance(TypeVar)) -def convert_scalar_type(type_arg: Any, replace_type_vars: Dict[TypeVar, Any], **kwargs): +def convert_type_var(type_arg: Any, replace_type_vars: Dict[TypeVar, Any], **kwargs): return replace_type_vars[type_arg] From 87bbd6fc1363c2662c13e07dab26f573bb7c963d Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 5 Jun 2023 15:40:18 -0400 Subject: [PATCH 73/81] test: breaking tests on enums --- graphene_sqlalchemy/tests/test_filters.py | 63 ++++++++++++++++++++--- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 20672c2f..cf833f81 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -207,6 +207,60 @@ class Query(graphene.ObjectType): assert_and_raise_result(result, expected) +# Test filtering on enums +@pytest.mark.asyncio +async def test_filter_enum(session): + await add_test_data(session) + + Query = create_schema(session) + + # test sqlalchemy enum + query = """ + query { + reporters (filter: { + favoritePetKind: {eq: "dog"} + } + ) { + edges { + node { + firstName + lastName + } + } + } + } + """ + expected = { + "pets": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test Python enum and sqlalchemy enum + query = """ + query { + pets (filter: { + and: [ + { hairKind: {eq: LONG} }, + { petKind: {eq: "dog"} } + ]}) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": {"edges": [{"node": {"name": "Lassie"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # Test a 1:1 relationship @pytest.mark.asyncio async def test_filter_relationship_one_to_one(session): @@ -443,7 +497,7 @@ async def test_filter_relationship_many_to_many_contains_with_and(session): { name: { in: ["sensational", "eye-grabbing"] } }, { name: { eq: "eye-grabbing" } }, ] - + } ] } @@ -788,8 +842,7 @@ async def test_filter_logic_and(session): reporters (filter: { and: [ { firstName: { eq: "John" } }, - # TODO get enums working for filters - # { favoritePetKind: { eq: "cat" } }, + { favoritePetKind: { eq: "cat" } }, ] }) { edges { @@ -821,8 +874,7 @@ async def test_filter_logic_or(session): reporters (filter: { or: [ { lastName: { eq: "Woe" } }, - # TODO get enums working for filters - #{ favoritePetKind: { eq: "dog" } }, + { favoritePetKind: { eq: "dog" } }, ] }) { edges { @@ -838,7 +890,6 @@ async def test_filter_logic_or(session): "reporters": { "edges": [ {"node": {"firstName": "John", "lastName": "Woe"}}, - # TODO get enums working for filters # {"node": {"firstName": "Jane", "lastName": "Roe"}}, ] } From 06c90cba2b50c555b9b740f2dc516c416bfcd4c6 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 28 Jul 2023 18:52:12 +0200 Subject: [PATCH 74/81] fix: create special enum filters. Code pending refactor. --- graphene_sqlalchemy/filters.py | 46 +++++- graphene_sqlalchemy/registry.py | 61 +++++--- graphene_sqlalchemy/tests/conftest.py | 28 ++-- graphene_sqlalchemy/tests/test_filters.py | 16 ++- graphene_sqlalchemy/types.py | 163 ++++++++++++---------- 5 files changed, 202 insertions(+), 112 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 1f227e17..0e0de900 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -49,7 +49,6 @@ def __init_subclass_with_meta__( logic_functions = _get_functions_by_regex(".+_logic$", "_logic$", cls) new_filter_fields = {} - print(f"Generating Filter for {cls.__name__} with model {model} ") # Generate Graphene Fields from the filter functions based on type hints for field_name, _annotations in logic_functions: assert ( @@ -70,9 +69,6 @@ def __init_subclass_with_meta__( _meta.fields = filter_fields _meta.fields.update(new_filter_fields) - for field in _meta.fields: - print(f"Added field {field} of type {_meta.fields[field].type}") - _meta.model = model super(BaseTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) @@ -289,6 +285,48 @@ def execute_filters( return query, clauses +class SQLEnumFilter(FieldFilter): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + class Meta: + graphene_type = graphene.Enum + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val.value + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val.value) + +class PyEnumFilter(FieldFilter): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + class Meta: + graphene_type = graphene.Enum + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val) + class StringFilter(FieldFilter): class Meta: diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index d29c4f6f..75693871 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,18 +1,16 @@ from collections import defaultdict from typing import TYPE_CHECKING, List, Type -from sqlalchemy.types import Enum as SQLAlchemyEnumType - import graphene from graphene import Enum from graphene.types.base import BaseType +from sqlalchemy.types import Enum as SQLAlchemyEnumType if TYPE_CHECKING: # pragma: no_cover from graphene_sqlalchemy.filters import ( FieldFilter, BaseTypeFilter, - RelationshipFilter, - ) + RelationshipFilter, ) class Registry(object): @@ -81,7 +79,7 @@ def register_sort_enum(self, obj_type, sort_enum: Enum): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -94,7 +92,7 @@ def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) def register_union_type( - self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]] + self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]] ): if not issubclass(union, graphene.Union): raise TypeError("Expected graphene.Union, but got: {!r}".format(union)) @@ -112,7 +110,7 @@ def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]) # Filter Scalar Fields of Object Types def register_filter_for_scalar_type( - self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] + self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] ): from .filters import FieldFilter @@ -123,21 +121,49 @@ def register_filter_for_scalar_type( raise TypeError("Expected ScalarFilter, but got: {!r}".format(filter_obj)) self._registry_scalar_filters[scalar_type] = filter_obj + def get_filter_for_sql_enum_type( + self, enum_type: Type[graphene.Enum] + ) -> Type["FieldFilter"]: + from .filters import SQLEnumFilter + + filter_type = self._registry_scalar_filters.get(enum_type) + if not filter_type: + filter_type = SQLEnumFilter.create_type( + f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type + ) + self._registry_scalar_filters[enum_type] = filter_type + return filter_type + + def get_filter_for_py_enum_type( + self, enum_type: Type[graphene.Enum] + ) -> Type["FieldFilter"]: + from .filters import PyEnumFilter + + filter_type = self._registry_scalar_filters.get(enum_type) + if not filter_type: + filter_type = PyEnumFilter.create_type( + f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type + ) + self._registry_scalar_filters[enum_type] = filter_type + return filter_type + def get_filter_for_scalar_type( - self, scalar_type: Type[graphene.Scalar] + self, scalar_type: Type[graphene.Scalar] ) -> Type["FieldFilter"]: from .filters import FieldFilter filter_type = self._registry_scalar_filters.get(scalar_type) if not filter_type: - return FieldFilter.create_type( + filter_type = FieldFilter.create_type( f"Default{scalar_type.__name__}ScalarFilter", graphene_type=scalar_type ) + self._registry_scalar_filters[scalar_type] = filter_type + return filter_type # TODO register enums automatically def register_filter_for_enum_type( - self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] + self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] ): from .filters import FieldFilter @@ -148,16 +174,11 @@ def register_filter_for_enum_type( raise TypeError("Expected FieldFilter, but got: {!r}".format(filter_obj)) self._registry_scalar_filters[enum_type] = filter_obj - def get_filter_for_enum_type( - self, enum_type: Type[graphene.Enum] - ) -> Type["FieldFilter"]: - return self._registry_enum_type_filters.get(enum_type) - # Filter Base Types def register_filter_for_base_type( - self, - base_type: Type[BaseType], - filter_obj: Type["BaseTypeFilter"], + self, + base_type: Type[BaseType], + filter_obj: Type["BaseTypeFilter"], ): from .filters import BaseTypeFilter @@ -175,7 +196,7 @@ def get_filter_for_base_type(self, base_type: Type[BaseType]): # Filter Relationships between base types def register_relationship_filter_for_base_type( - self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] + self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] ): from .filters import RelationshipFilter @@ -189,7 +210,7 @@ def register_relationship_filter_for_base_type( self._registry_relationship_filters[base_type] = filter_obj def get_relationship_filter_for_base_type( - self, base_type: Type[BaseType] + self, base_type: Type[BaseType] ) -> "RelationshipFilter": return self._registry_relationship_filters.get(base_type) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 89b357a4..80047357 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,14 +1,15 @@ +from typing import Literal + +import graphene import pytest import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -import graphene from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 - +from .models import Base, CompositeFullName from ..converter import convert_sqlalchemy_composite from ..registry import reset_global_registry -from .models import Base, CompositeFullName if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @@ -25,14 +26,23 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.fixture(params=[False, True]) -def async_session(request): +# make a typed literal for session one is sync and one is async +SESSION_TYPE = Literal["sync", "session_factory"] + + +@pytest.fixture(params=["sync", "async"]) +def session_type(request) -> SESSION_TYPE: return request.param @pytest.fixture -def test_db_url(async_session: bool): - if async_session: +def async_session(session_type): + return session_type == "async" + + +@pytest.fixture +def test_db_url(session_type: SESSION_TYPE): + if session_type == "async": return "sqlite+aiosqlite://" else: return "sqlite://" @@ -40,8 +50,8 @@ def test_db_url(async_session: bool): @pytest.mark.asyncio @pytest_asyncio.fixture(scope="function") -async def session_factory(async_session: bool, test_db_url: str): - if async_session: +async def session_factory(session_type: SESSION_TYPE, test_db_url: str): + if session_type == "async": if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: pytest.skip("Async Sessions only work in sql alchemy 1.4 and above") engine = create_async_engine(test_db_url) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index cf833f81..6a936413 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -218,20 +218,21 @@ async def test_filter_enum(session): query = """ query { reporters (filter: { - favoritePetKind: {eq: "dog"} + favoritePetKind: {eq: DOG} } ) { edges { node { firstName lastName + favoritePetKind } } } } """ expected = { - "pets": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe"}}]}, + "reporters": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe", "favoritePetKind": "DOG"}}]}, } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) @@ -243,7 +244,7 @@ async def test_filter_enum(session): pets (filter: { and: [ { hairKind: {eq: LONG} }, - { petKind: {eq: "dog"} } + { petKind: {eq: DOG} } ]}) { edges { node { @@ -842,7 +843,7 @@ async def test_filter_logic_and(session): reporters (filter: { and: [ { firstName: { eq: "John" } }, - { favoritePetKind: { eq: "cat" } }, + { favoritePetKind: { eq: CAT } }, ] }) { edges { @@ -874,13 +875,14 @@ async def test_filter_logic_or(session): reporters (filter: { or: [ { lastName: { eq: "Woe" } }, - { favoritePetKind: { eq: "dog" } }, + { favoritePetKind: { eq: DOG } }, ] }) { edges { node { firstName lastName + favoritePetKind } } } @@ -889,8 +891,8 @@ async def test_filter_logic_or(session): expected = { "reporters": { "edges": [ - {"node": {"firstName": "John", "lastName": "Woe"}}, - # {"node": {"firstName": "Jane", "lastName": "Roe"}}, + {"node": {"firstName": "John", "lastName": "Woe", "favoritePetKind": "CAT"}}, + {"node": {"firstName": "Jane", "lastName": "Roe", "favoritePetKind": "DOG"}}, ] } } diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index fee9d386..70715f2b 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,17 +1,13 @@ import inspect -import types import warnings from collections import OrderedDict +from functools import partial from inspect import isawaitable from typing import Any, Optional, Type, Union -import sqlalchemy -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty -from sqlalchemy.orm.exc import NoResultFound - import graphene -from graphene import Field, InputField +import sqlalchemy +from graphene import Field, InputField, Dynamic from graphene.relay import Connection, Node from graphene.types.base import BaseType from graphene.types.interface import Interface, InterfaceOptions @@ -19,8 +15,11 @@ from graphene.types.unmountedtype import UnmountedType from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType -import graphene_sqlalchemy.filters as gsa_filters +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty +from sqlalchemy.orm.exc import NoResultFound +import graphene_sqlalchemy.filters as gsa_filters from .converter import ( convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -50,17 +49,17 @@ class ORMField(OrderedType): def __init__( - self, - model_attr=None, - type_=None, - required=None, - description=None, - deprecation_reason=None, - batching=None, - create_filter=None, - filter_type: Optional[Type] = None, - _creation_counter=None, - **field_kwargs, + self, + model_attr=None, + type_=None, + required=None, + description=None, + deprecation_reason=None, + batching=None, + create_filter=None, + filter_type: Optional[Type] = None, + _creation_counter=None, + **field_kwargs, ): """ Use this to override fields automatically generated by SQLAlchemyObjectType. @@ -126,7 +125,7 @@ class Meta: def get_or_create_relationship_filter( - base_type: Type[BaseType], registry: Registry + base_type: Type[BaseType], registry: Registry ) -> Type[RelationshipFilter]: relationship_filter = registry.get_relationship_filter_for_base_type(base_type) @@ -144,10 +143,39 @@ def get_or_create_relationship_filter( return relationship_filter +def filter_field_from_field( + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + type_, + registry: Registry, + model_attr: Any, +) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + # Field might be a SQLAlchemyObjectType, due to hybrid properties + if issubclass(type_, SQLAlchemyObjectType): + filter_class = registry.get_filter_for_base_type(type_) + return graphene.InputField(filter_class) + # Enum Special Case + if issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty): + column = model_attr.columns[0] + model_enum_type: Optional[sqlalchemy.types.Enum] = getattr(column, "type", None) + if not getattr(model_enum_type, "enum_class", None): + filter_class = registry.get_filter_for_sql_enum_type(type_) + else: + filter_class = registry.get_filter_for_py_enum_type(type_) + else: + filter_class = registry.get_filter_for_scalar_type(type_) + if not filter_class: + warnings.warn( + f"No compatible filters found for {field.type}. Skipping field." + ) + return None + return graphene.InputField(filter_class) + + def filter_field_from_type_field( - field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], - registry: Registry, - filter_type: Optional[Type], + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + registry: Registry, + filter_type: Optional[Type], + model_attr: Any, ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # If a custom filter type was set for this field, use it here if filter_type: @@ -166,7 +194,7 @@ def resolve_dynamic(): from .fields import UnsortedSQLAlchemyConnectionField if isinstance(type_, SQLAlchemyConnectionField) or isinstance( - type_, UnsortedSQLAlchemyConnectionField + type_, UnsortedSQLAlchemyConnectionField ): inner_type = get_nullable_type(type_.type.Edge.node._type) reg_res = get_or_create_relationship_filter(inner_type, registry) @@ -189,32 +217,24 @@ def resolve_dynamic(): warnings.warn(f"Unexpected Dynamic Type: {type_}") # Investigate # raise Exception(f"Unexpected Dynamic Type: {type_}") - return graphene.Dynamic(resolve_dynamic) + return Dynamic(resolve_dynamic) if isinstance(field, graphene.List): print("Got list") return - if isinstance(field._type, types.FunctionType): - print("got field with function type") - return + # if isinstance(field._type, types.FunctionType): + # print("got field with function type") + # return if isinstance(field._type, graphene.Dynamic): return if isinstance(field._type, graphene.List): print("got field with list type") return if isinstance(field, graphene.Field): - type_ = get_nullable_type(field.type) - # Field might be a SQLAlchemyObjectType, due to hybrid properties - if issubclass(type_, SQLAlchemyObjectType): - filter_class = registry.get_filter_for_base_type(type_) - return graphene.InputField(filter_class) - filter_class = registry.get_filter_for_scalar_type(type_) - if not filter_class: - warnings.warn( - f"No compatible filters found for {field.type}. Skipping field." - ) - return None - return graphene.InputField(filter_class) + if inspect.isfunction(field._type) or isinstance(field._type, partial): + return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr)) + else: + return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr) raise Exception(f"Expected a graphene.Field or graphene.Dynamic, but got: {field}") @@ -232,14 +252,14 @@ def get_polymorphic_on(model): def construct_fields_and_filters( - obj_type, - model, - registry, - only_fields, - exclude_fields, - batching, - create_filters, - connection_field_factory, + obj_type, + model, + registry, + only_fields, + exclude_fields, + batching, + create_filters, + connection_field_factory, ): """ Construct all the fields for a SQLAlchemyObjectType. @@ -276,9 +296,9 @@ def construct_fields_and_filters( auto_orm_field_names = [] for attr_name, attr in all_model_attrs.items(): if ( - (only_fields and attr_name not in only_fields) - or (attr_name in exclude_fields) - or attr_name == polymorphic_on + (only_fields and attr_name not in only_fields) + or (attr_name in exclude_fields) + or attr_name == polymorphic_on ): continue auto_orm_field_names.append(attr_name) @@ -356,7 +376,7 @@ def construct_fields_and_filters( fields[orm_field_name] = field if filtering_enabled_for_field: filters[orm_field_name] = filter_field_from_type_field( - field, registry, filter_type + field, registry, filter_type, attr ) return fields, filters @@ -370,21 +390,21 @@ class SQLAlchemyBase(BaseType): @classmethod def __init_subclass_with_meta__( - cls, - model=None, - registry=None, - skip_registry=False, - only_fields=(), - exclude_fields=(), - connection=None, - connection_class=None, - use_connection=None, - interfaces=(), - id=None, - batching=False, - connection_field_factory=None, - _meta=None, - **options, + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + id=None, + batching=False, + connection_field_factory=None, + _meta=None, + **options, ): # We always want to bypass this hook unless we're defining a concrete # `SQLAlchemyObjectType` or `SQLAlchemyInterface`. @@ -405,9 +425,9 @@ def __init_subclass_with_meta__( filter_cls[1] for filter_cls in inspect.getmembers(gsa_filters, inspect.isclass) if ( - filter_cls[1] is not FieldFilter - and FieldFilter in filter_cls[1].__mro__ - and getattr(filter_cls[1]._meta, "graphene_type", False) + filter_cls[1] is not FieldFilter + and FieldFilter in filter_cls[1].__mro__ + and getattr(filter_cls[1]._meta, "graphene_type", False) ) ] for field_filter_class in field_filter_classes: @@ -522,7 +542,6 @@ def get_node(cls, info, id): session = get_session(info.context) if isinstance(session, AsyncSession): - async def get_result() -> Any: return await session.get(cls._meta.model, id) @@ -654,7 +673,7 @@ def __init_subclass_with_meta__(cls, _meta=None, **options): if hasattr(_meta.model, "__mapper__"): polymorphic_identity = _meta.model.__mapper__.polymorphic_identity assert ( - polymorphic_identity is None + polymorphic_identity is None ), '{}: An interface cannot map to a concrete type (polymorphic_identity is "{}")'.format( cls.__name__, polymorphic_identity ) From 4e34a79b16807bc8f7f6fe7adb224389f7109069 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 28 Jul 2023 18:57:34 +0200 Subject: [PATCH 75/81] fix: use typing extensions --- graphene_sqlalchemy/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 80047357..9489011b 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing_extensions import Literal import graphene import pytest From c38ebb333fdc73d4d07dcbd0aa7044778c5799c1 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 6 Oct 2023 21:09:07 +0200 Subject: [PATCH 76/81] refactor: cleanup filter type generation --- graphene_sqlalchemy/types.py | 98 +++++++++++++++++------------------- 1 file changed, 47 insertions(+), 51 deletions(-) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 70715f2b..269f451f 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,4 +1,5 @@ import inspect +import logging import warnings from collections import OrderedDict from functools import partial @@ -46,6 +47,8 @@ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession +logger = logging.getLogger(__name__) + class ORMField(OrderedType): def __init__( @@ -171,6 +174,37 @@ def filter_field_from_field( return graphene.InputField(filter_class) +def resolve_dynamic_relationship_filter( + field: graphene.Dynamic, + registry: Registry, + model_attr: Any, +) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + # Resolve Dynamic Type + type_ = get_nullable_type(field.get_type()) + from graphene_sqlalchemy import SQLAlchemyConnectionField + + # Connections always result in list filters + if isinstance(type_, SQLAlchemyConnectionField): + inner_type = get_nullable_type(type_.type.Edge.node._type) + reg_res = get_or_create_relationship_filter(inner_type, registry) + # Field relationships can either be a list or a single object + elif isinstance(type_, Field): + if isinstance(type_.type, graphene.List): + inner_type = get_nullable_type(type_.type.of_type) + reg_res = get_or_create_relationship_filter(inner_type, registry) + else: + reg_res = registry.get_filter_for_base_type(type_.type) + else: + # Other dynamic type constellation are not yet supported, + # please open an issue with reproduction if you need them + reg_res = None + + if not reg_res: + return None + + return graphene.InputField(reg_res) + + def filter_field_from_type_field( field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], registry: Registry, @@ -180,63 +214,25 @@ def filter_field_from_type_field( # If a custom filter type was set for this field, use it here if filter_type: return graphene.InputField(filter_type) - if issubclass(type(field), graphene.Scalar): + elif issubclass(type(field), graphene.Scalar): filter_class = registry.get_filter_for_scalar_type(type(field)) return graphene.InputField(filter_class) - # If the field is Dynamic, we don't know its type yet and can't select the right filter - if isinstance(field, graphene.Dynamic): - - def resolve_dynamic(): - # Resolve Dynamic Type - type_ = get_nullable_type(field.get_type()) - from graphene_sqlalchemy import SQLAlchemyConnectionField - - from .fields import UnsortedSQLAlchemyConnectionField - - if isinstance(type_, SQLAlchemyConnectionField) or isinstance( - type_, UnsortedSQLAlchemyConnectionField - ): - inner_type = get_nullable_type(type_.type.Edge.node._type) - reg_res = get_or_create_relationship_filter(inner_type, registry) - if not reg_res: - print("filter class was none!!!") - print(type_) - return graphene.InputField(reg_res) - elif isinstance(type_, Field): - if isinstance(type_.type, graphene.List): - inner_type = get_nullable_type(type_.type.of_type) - reg_res = get_or_create_relationship_filter(inner_type, registry) - if not reg_res: - print("filter class was none!!!") - print(type_) - return graphene.InputField(reg_res) - reg_res = registry.get_filter_for_base_type(type_.type) - - return graphene.InputField(reg_res) - else: - warnings.warn(f"Unexpected Dynamic Type: {type_}") # Investigate - # raise Exception(f"Unexpected Dynamic Type: {type_}") - - return Dynamic(resolve_dynamic) - - if isinstance(field, graphene.List): - print("Got list") - return - # if isinstance(field._type, types.FunctionType): - # print("got field with function type") - # return - if isinstance(field._type, graphene.Dynamic): - return - if isinstance(field._type, graphene.List): - print("got field with list type") - return - if isinstance(field, graphene.Field): + # If the generated field is Dynamic, it is always a relationship + # (due to graphene-sqlalchemy's conversion mechanism). + elif isinstance(field, graphene.Dynamic): + return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr)) + elif isinstance(field, graphene.Field): if inspect.isfunction(field._type) or isinstance(field._type, partial): return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr)) else: return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr) - - raise Exception(f"Expected a graphene.Field or graphene.Dynamic, but got: {field}") + # Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them + elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List): + # Pure lists are not yet supported + pass + elif isinstance(field._type, graphene.Dynamic): + # Fields with nested dynamic Dynamic are not yet supported + pass def get_polymorphic_on(model): From 064adc7c623061bd9c38c900981051651e4b3a99 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 6 Oct 2023 21:47:26 +0200 Subject: [PATCH 77/81] feat(filters): support filter aliasing (PR #378) --- graphene_sqlalchemy/filters.py | 32 ++++++++++- graphene_sqlalchemy/tests/test_filters.py | 65 ++++++++++++++++------- graphene_sqlalchemy/types.py | 39 ++++++++------ 3 files changed, 97 insertions(+), 39 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 0e0de900..0af0a698 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -1,6 +1,7 @@ import re from typing import Any, Dict, List, Tuple, Type, TypeVar, Union +from graphql import Undefined from sqlalchemy import and_, not_, or_ from sqlalchemy.orm import Query, aliased # , selectinload @@ -15,6 +16,31 @@ "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer ) +class SQLAlchemyFilterInputField(graphene.InputField): + def __init__( + self, + type_, + model_attr, + name=None, + default_value=Undefined, + deprecation_reason=None, + description=None, + required=False, + _creation_counter=None, + **extra_args, + ): + super(SQLAlchemyFilterInputField, self).__init__( + type_, + name, + default_value, + deprecation_reason, + description, + required, + _creation_counter, + **extra_args, + ) + + self.model_attr = model_attr def _get_functions_by_regex( regex: str, subtract_regex: str, class_: Type @@ -138,7 +164,8 @@ def execute_filters( # Check with a profiler is required to determine necessity input_field = cls._meta.fields[field] if isinstance(input_field, graphene.Dynamic): - field_filter_type = input_field.get_type().type + input_field = input_field.get_type() + field_filter_type = input_field.type else: field_filter_type = cls._meta.fields[field].type # raise Exception @@ -155,7 +182,8 @@ def execute_filters( ) clauses.extend(_clauses) else: - model_field = getattr(model, field) + # Get the model attr from the inputfield in case the field is aliased in graphql + model_field = getattr(model, input_field.model_attr or field) if issubclass(field_filter_type, BaseTypeFilter): # Get the model to join on the Filter Query joined_model = field_filter_type._meta.model diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 6a936413..026247ca 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -1,12 +1,8 @@ -import pytest -from sqlalchemy.sql.operators import is_ - import graphene +import pytest from graphene import Connection, relay +from sqlalchemy.sql.operators import is_ -from ..fields import SQLAlchemyConnectionField -from ..filters import FloatFilter -from ..types import ORMField, SQLAlchemyObjectType from .models import ( Article, Editor, @@ -20,6 +16,10 @@ Tag, ) from .utils import eventually_await_session, to_std_dicts +from ..fields import SQLAlchemyConnectionField +from ..filters import FloatFilter +from ..types import ORMField, SQLAlchemyObjectType + # TODO test that generated schema is correct for all examples with: # with open('schema.gql', 'w') as fp: @@ -110,26 +110,13 @@ class Meta: class Query(graphene.ObjectType): node = relay.Node.Field() - # # TODO how to create filterable singular field? - # article = graphene.Field(ArticleType) articles = SQLAlchemyConnectionField(ArticleType.connection) - # image = graphene.Field(ImageType) images = SQLAlchemyConnectionField(ImageType.connection) readers = SQLAlchemyConnectionField(ReaderType.connection) - # reporter = graphene.Field(ReporterType) reporters = SQLAlchemyConnectionField(ReporterType.connection) pets = SQLAlchemyConnectionField(PetType.connection) tags = SQLAlchemyConnectionField(TagType.connection) - # def resolve_article(self, _info): - # return session.query(Article).first() - - # def resolve_image(self, _info): - # return session.query(Image).first() - - # def resolve_reporter(self, _info): - # return session.query(Reporter).first() - return Query @@ -159,6 +146,44 @@ async def test_filter_simple(session): assert_and_raise_result(result, expected) +@pytest.mark.asyncio +async def test_filter_alias(session): + """ + Test aliasing of column names in the type + """ + await add_test_data(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + + lastNameAlias = ORMField(model_attr="last_name") + + class Query(graphene.ObjectType): + node = relay.Node.Field() + reporters = SQLAlchemyConnectionField(ReporterType.connection) + + query = """ + query { + reporters (filter: {lastNameAlias: {eq: "Roe", like: "%oe"}}) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # Test a custom filter type @pytest.mark.asyncio async def test_filter_custom_type(session): @@ -1084,7 +1109,7 @@ async def test_filter_hybrid_property(session): result = to_std_dicts(result.data) assert len(result["carts"]["edges"]) == 1 assert ( - len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 + len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 ) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 269f451f..7a693a4d 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -32,7 +32,7 @@ sort_argument_for_object_type, sort_enum_for_object_type, ) -from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter +from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter, SQLAlchemyFilterInputField from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( @@ -151,13 +151,13 @@ def filter_field_from_field( type_, registry: Registry, model_attr: Any, -) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + model_attr_name: str +) -> Optional[graphene.InputField]: # Field might be a SQLAlchemyObjectType, due to hybrid properties if issubclass(type_, SQLAlchemyObjectType): filter_class = registry.get_filter_for_base_type(type_) - return graphene.InputField(filter_class) # Enum Special Case - if issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty): + elif issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty): column = model_attr.columns[0] model_enum_type: Optional[sqlalchemy.types.Enum] = getattr(column, "type", None) if not getattr(model_enum_type, "enum_class", None): @@ -168,16 +168,16 @@ def filter_field_from_field( filter_class = registry.get_filter_for_scalar_type(type_) if not filter_class: warnings.warn( - f"No compatible filters found for {field.type}. Skipping field." + f"No compatible filters found for {field.type} with db name {model_attr_name}. Skipping field." ) return None - return graphene.InputField(filter_class) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) def resolve_dynamic_relationship_filter( field: graphene.Dynamic, registry: Registry, - model_attr: Any, + model_attr_name: str ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # Resolve Dynamic Type type_ = get_nullable_type(field.get_type()) @@ -200,9 +200,12 @@ def resolve_dynamic_relationship_filter( reg_res = None if not reg_res: + warnings.warn( + f"No compatible filters found for {field} with db name {model_attr_name}. Skipping field." + ) return None - return graphene.InputField(reg_res) + return SQLAlchemyFilterInputField(reg_res, model_attr_name) def filter_field_from_type_field( @@ -210,22 +213,18 @@ def filter_field_from_type_field( registry: Registry, filter_type: Optional[Type], model_attr: Any, + model_attr_name: str ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # If a custom filter type was set for this field, use it here if filter_type: - return graphene.InputField(filter_type) + return SQLAlchemyFilterInputField(filter_type, model_attr_name) elif issubclass(type(field), graphene.Scalar): filter_class = registry.get_filter_for_scalar_type(type(field)) - return graphene.InputField(filter_class) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) # If the generated field is Dynamic, it is always a relationship # (due to graphene-sqlalchemy's conversion mechanism). elif isinstance(field, graphene.Dynamic): - return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr)) - elif isinstance(field, graphene.Field): - if inspect.isfunction(field._type) or isinstance(field._type, partial): - return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr)) - else: - return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr) + return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr_name)) # Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List): # Pure lists are not yet supported @@ -233,6 +232,12 @@ def filter_field_from_type_field( elif isinstance(field._type, graphene.Dynamic): # Fields with nested dynamic Dynamic are not yet supported pass + # Order matters, this comes last as field._type == list also matches Field + elif isinstance(field, graphene.Field): + if inspect.isfunction(field._type) or isinstance(field._type, partial): + return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name)) + else: + return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name) def get_polymorphic_on(model): @@ -372,7 +377,7 @@ def construct_fields_and_filters( fields[orm_field_name] = field if filtering_enabled_for_field: filters[orm_field_name] = filter_field_from_type_field( - field, registry, filter_type, attr + field, registry, filter_type, attr, attr_name ) return fields, filters From 1aef7483a0bc4d9c056027f621e11f90cff88441 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 6 Oct 2023 21:48:31 +0200 Subject: [PATCH 78/81] chore: remove print statements --- graphene_sqlalchemy/filters.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 0af0a698..5a4e684e 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -191,13 +191,6 @@ def execute_filters( joined_model_alias = aliased(joined_model) # Join the aliased model onto the query query = query.join(model_field.of_type(joined_model_alias)) - - if model_alias: - print("=======================") - print( - f"joining model {joined_model} on {model_alias} with alias {joined_model_alias}" - ) - print(str(query)) # Pass the joined query down to the next object type filter for processing query, _clauses = field_filter_type.execute_filters( query, field_filters, model_alias=joined_model_alias @@ -221,7 +214,6 @@ def execute_filters( query, _clauses = field_filter_type.execute_filters( query, model_field, field_filters ) - print([str(cla) for cla in _clauses]) clauses.extend(_clauses) return query, clauses @@ -484,14 +476,11 @@ def contains_filter( ): clauses = [] for v in val: - print("executing contains filter", v) # Always alias the model joined_model_alias = aliased(relationship_prop) # Join the aliased model onto the query query = query.join(field.of_type(joined_model_alias)).distinct() - print("Joined model", relationship_prop) - print(query) # pass the alias so group can join group query, _clauses = cls._meta.base_type_filter.execute_filters( query, v, model_alias=joined_model_alias From 18a7c5488daf2bda36248f3f3df98cfd7e2f3592 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 6 Oct 2023 21:55:41 +0200 Subject: [PATCH 79/81] chore: move base filter creation --- graphene_sqlalchemy/registry.py | 23 ++++++++++++++++++++++- graphene_sqlalchemy/types.py | 15 --------------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 75693871..2a45e787 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,3 +1,4 @@ +import inspect from collections import defaultdict from typing import TYPE_CHECKING, List, Type @@ -7,7 +8,7 @@ from sqlalchemy.types import Enum as SQLAlchemyEnumType if TYPE_CHECKING: # pragma: no_cover - from graphene_sqlalchemy.filters import ( + from .filters import ( FieldFilter, BaseTypeFilter, RelationshipFilter, ) @@ -26,6 +27,26 @@ def __init__(self): self._registry_base_type_filters = {} self._registry_relationship_filters = {} + self._init_base_filters() + + def _init_base_filters(self): + import graphene_sqlalchemy.filters as gsqa_filters + + from .filters import (FieldFilter) + field_filter_classes = [ + filter_cls[1] + for filter_cls in inspect.getmembers(gsqa_filters, inspect.isclass) + if ( + filter_cls[1] is not FieldFilter + and FieldFilter in filter_cls[1].__mro__ + and getattr(filter_cls[1]._meta, "graphene_type", False) + ) + ] + for field_filter_class in field_filter_classes: + self.register_filter_for_scalar_type( + field_filter_class._meta.graphene_type, field_filter_class + ) + def register(self, obj_type): from .types import SQLAlchemyBase diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 7a693a4d..b99f0236 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -20,7 +20,6 @@ from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty from sqlalchemy.orm.exc import NoResultFound -import graphene_sqlalchemy.filters as gsa_filters from .converter import ( convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -420,21 +419,7 @@ def __init_subclass_with_meta__( ) if not registry: - # TODO add documentation for users to register their own filters registry = get_global_registry() - field_filter_classes = [ - filter_cls[1] - for filter_cls in inspect.getmembers(gsa_filters, inspect.isclass) - if ( - filter_cls[1] is not FieldFilter - and FieldFilter in filter_cls[1].__mro__ - and getattr(filter_cls[1]._meta, "graphene_type", False) - ) - ] - for field_filter_class in field_filter_classes: - get_global_registry().register_filter_for_scalar_type( - field_filter_class._meta.graphene_type, field_filter_class - ) assert isinstance(registry, Registry), ( "The attribute registry in {} needs to be an instance of " From a2b8a9bb5b2d4892b34967845281dd2a9622c01b Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 6 Oct 2023 22:29:00 +0200 Subject: [PATCH 80/81] chore: run pre commit on PR (hook update due to incompatible hooks with local machine) --- .pre-commit-config.yaml | 4 +- graphene_sqlalchemy/converter.py | 110 +++++++-------- graphene_sqlalchemy/filters.py | 6 + graphene_sqlalchemy/registry.py | 45 +++--- graphene_sqlalchemy/tests/conftest.py | 8 +- graphene_sqlalchemy/tests/models.py | 12 +- graphene_sqlalchemy/tests/models_batching.py | 11 +- graphene_sqlalchemy/tests/test_converter.py | 46 +++--- graphene_sqlalchemy/tests/test_filters.py | 42 ++++-- graphene_sqlalchemy/types.py | 140 +++++++++++-------- 10 files changed, 227 insertions(+), 197 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 470a29eb..262e7608 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.7 + python: python3.8 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.2.0 @@ -12,7 +12,7 @@ repos: - id: trailing-whitespace exclude: README.md - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort (python) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 3e37ce5f..161848f6 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -3,7 +3,7 @@ import typing import uuid from decimal import Decimal -from typing import Any, Optional, Union, cast, Dict, TypeVar +from typing import Any, Dict, Optional, TypeVar, Union, cast from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql @@ -101,12 +101,12 @@ def is_column_nullable(column): def convert_sqlalchemy_relationship( - relationship_prop, - obj_type, - connection_field_factory, - batching, - orm_field_name, - **field_kwargs, + relationship_prop, + obj_type, + connection_field_factory, + batching, + orm_field_name, + **field_kwargs, ): """ :param sqlalchemy.RelationshipProperty relationship_prop: @@ -147,7 +147,7 @@ def dynamic_type(): def _convert_o2o_or_m2o_relationship( - relationship_prop, obj_type, batching, orm_field_name, **field_kwargs + relationship_prop, obj_type, batching, orm_field_name, **field_kwargs ): """ Convert one-to-one or many-to-one relationshsip. Return an object field. @@ -175,7 +175,7 @@ def _convert_o2o_or_m2o_relationship( def _convert_o2m_or_m2m_relationship( - relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs + relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs ): """ Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field. @@ -281,11 +281,11 @@ def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): @singledispatchbymatchfunction def convert_sqlalchemy_type( # noqa - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - replace_type_vars: typing.Dict[str, Any] = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + replace_type_vars: typing.Dict[str, Any] = None, + **kwargs, ): if replace_type_vars and type_arg in replace_type_vars: return replace_type_vars[type_arg] @@ -301,7 +301,7 @@ def convert_sqlalchemy_type( # noqa @convert_sqlalchemy_type.register(safe_isinstance(DeclarativeMeta)) def convert_sqlalchemy_model_using_registry( - type_arg: Any, registry: Registry = None, **kwargs + type_arg: Any, registry: Registry = None, **kwargs ): registry_ = registry or get_global_registry() @@ -353,8 +353,8 @@ def convert_column_to_string(type_arg: Any, **kwargs): @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.UUIDType)) @convert_sqlalchemy_type.register(column_type_eq(uuid.UUID)) def convert_column_to_uuid( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.UUID @@ -362,8 +362,8 @@ def convert_column_to_uuid( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.DateTime)) @convert_sqlalchemy_type.register(column_type_eq(datetime.datetime)) def convert_column_to_datetime( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.DateTime @@ -371,8 +371,8 @@ def convert_column_to_datetime( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Time)) @convert_sqlalchemy_type.register(column_type_eq(datetime.time)) def convert_column_to_time( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Time @@ -380,8 +380,8 @@ def convert_column_to_time( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Date)) @convert_sqlalchemy_type.register(column_type_eq(datetime.date)) def convert_column_to_date( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Date @@ -390,10 +390,10 @@ def convert_column_to_date( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Integer)) @convert_sqlalchemy_type.register(column_type_eq(int)) def convert_column_to_int_or_id( - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): # fixme drop the primary key processing from here in another pr if column is not None: @@ -405,8 +405,8 @@ def convert_column_to_int_or_id( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Boolean)) @convert_sqlalchemy_type.register(column_type_eq(bool)) def convert_column_to_boolean( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Boolean @@ -416,8 +416,8 @@ def convert_column_to_boolean( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Numeric)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.BigInteger)) def convert_column_to_float( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Float @@ -425,10 +425,10 @@ def convert_column_to_float( @convert_sqlalchemy_type.register(column_type_eq(postgresql.ENUM)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Enum)) def convert_enum_to_enum( - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("SQL-Enum conversion requires a column") @@ -439,9 +439,9 @@ def convert_enum_to_enum( # TODO Make ChoiceType conversion consistent with other enums @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ChoiceType)) def convert_choice_to_enum( - type_arg: sqa_utils.ChoiceType, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - **kwargs, + type_arg: sqa_utils.ChoiceType, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("ChoiceType conversion requires a column") @@ -457,8 +457,8 @@ def convert_choice_to_enum( @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ScalarListType)) def convert_scalar_list_to_list( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.List(graphene.String) @@ -474,10 +474,10 @@ def init_array_list_recursive(inner_type, n): @convert_sqlalchemy_type.register(column_type_eq(sqa_types.ARRAY)) @convert_sqlalchemy_type.register(column_type_eq(postgresql.ARRAY)) def convert_array_to_list( - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("SQL-Array conversion requires a column") @@ -496,8 +496,8 @@ def convert_array_to_list( @convert_sqlalchemy_type.register(column_type_eq(postgresql.JSON)) @convert_sqlalchemy_type.register(column_type_eq(postgresql.JSONB)) def convert_json_to_string( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return JSONString @@ -505,18 +505,18 @@ def convert_json_to_string( @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.JSONType)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.JSON)) def convert_json_type_to_string( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return JSONString @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Variant)) def convert_variant_to_impl_type( - type_arg: sqa_types.Variant, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: sqa_types.Variant, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("Vaiant conversion requires a column") @@ -547,7 +547,7 @@ def is_union(type_arg: Any, **kwargs) -> bool: def graphene_union_for_py_union( - obj_types: typing.List[graphene.ObjectType], registry + obj_types: typing.List[graphene.ObjectType], registry ) -> graphene.Union: union_type = registry.get_union_for_object_types(obj_types) @@ -590,8 +590,8 @@ def convert_sqlalchemy_hybrid_property_union(type_arg: Any, **kwargs): # Now check if every type is instance of an ObjectType if not all( - isinstance(graphene_type, type(graphene.ObjectType)) - for graphene_type in graphene_types + isinstance(graphene_type, type(graphene.ObjectType)) + for graphene_type in graphene_types ): raise ValueError( "Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. " diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 5a4e684e..bb422724 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -16,6 +16,7 @@ "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer ) + class SQLAlchemyFilterInputField(graphene.InputField): def __init__( self, @@ -42,6 +43,7 @@ def __init__( self.model_attr = model_attr + def _get_functions_by_regex( regex: str, subtract_regex: str, class_: Type ) -> List[Tuple[str, Dict[str, Any]]]: @@ -305,11 +307,13 @@ def execute_filters( return query, clauses + class SQLEnumFilter(FieldFilter): """Basic Filter for Scalars in Graphene. We want this filter to use Dynamic fields so it provides the base filtering methods ("eq, nEq") for different types of scalars. The Dynamic fields will resolve to Meta.filtered_type""" + class Meta: graphene_type = graphene.Enum @@ -326,11 +330,13 @@ def n_eq_filter( ) -> Union[Tuple[Query, Any], Any]: return not_(field == val.value) + class PyEnumFilter(FieldFilter): """Basic Filter for Scalars in Graphene. We want this filter to use Dynamic fields so it provides the base filtering methods ("eq, nEq") for different types of scalars. The Dynamic fields will resolve to Meta.filtered_type""" + class Meta: graphene_type = graphene.Enum diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 2a45e787..b959d221 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -2,16 +2,14 @@ from collections import defaultdict from typing import TYPE_CHECKING, List, Type +from sqlalchemy.types import Enum as SQLAlchemyEnumType + import graphene from graphene import Enum from graphene.types.base import BaseType -from sqlalchemy.types import Enum as SQLAlchemyEnumType if TYPE_CHECKING: # pragma: no_cover - from .filters import ( - FieldFilter, - BaseTypeFilter, - RelationshipFilter, ) + from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter class Registry(object): @@ -32,14 +30,15 @@ def __init__(self): def _init_base_filters(self): import graphene_sqlalchemy.filters as gsqa_filters - from .filters import (FieldFilter) + from .filters import FieldFilter + field_filter_classes = [ filter_cls[1] for filter_cls in inspect.getmembers(gsqa_filters, inspect.isclass) if ( - filter_cls[1] is not FieldFilter - and FieldFilter in filter_cls[1].__mro__ - and getattr(filter_cls[1]._meta, "graphene_type", False) + filter_cls[1] is not FieldFilter + and FieldFilter in filter_cls[1].__mro__ + and getattr(filter_cls[1]._meta, "graphene_type", False) ) ] for field_filter_class in field_filter_classes: @@ -100,7 +99,7 @@ def register_sort_enum(self, obj_type, sort_enum: Enum): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -113,7 +112,7 @@ def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) def register_union_type( - self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]] + self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]] ): if not issubclass(union, graphene.Union): raise TypeError("Expected graphene.Union, but got: {!r}".format(union)) @@ -131,7 +130,7 @@ def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]) # Filter Scalar Fields of Object Types def register_filter_for_scalar_type( - self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] + self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] ): from .filters import FieldFilter @@ -143,7 +142,7 @@ def register_filter_for_scalar_type( self._registry_scalar_filters[scalar_type] = filter_obj def get_filter_for_sql_enum_type( - self, enum_type: Type[graphene.Enum] + self, enum_type: Type[graphene.Enum] ) -> Type["FieldFilter"]: from .filters import SQLEnumFilter @@ -156,7 +155,7 @@ def get_filter_for_sql_enum_type( return filter_type def get_filter_for_py_enum_type( - self, enum_type: Type[graphene.Enum] + self, enum_type: Type[graphene.Enum] ) -> Type["FieldFilter"]: from .filters import PyEnumFilter @@ -169,7 +168,7 @@ def get_filter_for_py_enum_type( return filter_type def get_filter_for_scalar_type( - self, scalar_type: Type[graphene.Scalar] + self, scalar_type: Type[graphene.Scalar] ) -> Type["FieldFilter"]: from .filters import FieldFilter @@ -184,7 +183,7 @@ def get_filter_for_scalar_type( # TODO register enums automatically def register_filter_for_enum_type( - self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] + self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] ): from .filters import FieldFilter @@ -197,9 +196,9 @@ def register_filter_for_enum_type( # Filter Base Types def register_filter_for_base_type( - self, - base_type: Type[BaseType], - filter_obj: Type["BaseTypeFilter"], + self, + base_type: Type[BaseType], + filter_obj: Type["BaseTypeFilter"], ): from .filters import BaseTypeFilter @@ -207,9 +206,7 @@ def register_filter_for_base_type( raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) if not issubclass(filter_obj, BaseTypeFilter): - raise TypeError( - "Expected BaseTypeFilter, but got: {!r}".format(filter_obj) - ) + raise TypeError("Expected BaseTypeFilter, but got: {!r}".format(filter_obj)) self._registry_base_type_filters[base_type] = filter_obj def get_filter_for_base_type(self, base_type: Type[BaseType]): @@ -217,7 +214,7 @@ def get_filter_for_base_type(self, base_type: Type[BaseType]): # Filter Relationships between base types def register_relationship_filter_for_base_type( - self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] + self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] ): from .filters import RelationshipFilter @@ -231,7 +228,7 @@ def register_relationship_filter_for_base_type( self._registry_relationship_filters[base_type] = filter_obj def get_relationship_filter_for_base_type( - self, base_type: Type[BaseType] + self, base_type: Type[BaseType] ) -> "RelationshipFilter": return self._registry_relationship_filters.get(base_type) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 9489011b..2c749da7 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,15 +1,15 @@ -from typing_extensions import Literal - -import graphene import pytest import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from typing_extensions import Literal +import graphene from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 -from .models import Base, CompositeFullName + from ..converter import convert_sqlalchemy_composite from ..registry import reset_global_registry +from .models import Base, CompositeFullName if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 7ec6de32..12554dc2 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -6,6 +6,7 @@ from decimal import Decimal from typing import List, Optional +# fmt: off from sqlalchemy import ( Column, Date, @@ -23,14 +24,15 @@ from sqlalchemy.sql.type_api import TypeEngine from graphene_sqlalchemy.tests.utils import wrap_select_func -from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2 +from graphene_sqlalchemy.utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + SQL_VERSION_HIGHER_EQUAL_THAN_2, +) -# fmt: off -import sqlalchemy if SQL_VERSION_HIGHER_EQUAL_THAN_2: - from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip else: - from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip # fmt: on PetKind = Enum("cat", "dog", name="pet_kind") diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py index 5dde366f..e0f5d4bd 100644 --- a/graphene_sqlalchemy/tests/models_batching.py +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -2,16 +2,7 @@ import enum -from sqlalchemy import ( - Column, - Date, - Enum, - ForeignKey, - Integer, - String, - Table, - func, -) +from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table, func from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import column_property, relationship diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 60bf5058..1b2e0ec5 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,13 +1,10 @@ import enum import sys -from typing import Dict, Tuple, Union, TypeVar +from typing import Dict, Tuple, TypeVar, Union -import graphene import pytest import sqlalchemy import sqlalchemy_utils as sqa_utils -from graphene.relay import Node -from graphene.types.structures import Structure from sqlalchemy import Column, func, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base @@ -15,15 +12,10 @@ from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -from .models import ( - Article, - CompositeFullName, - Pet, - Reporter, - ShoppingCart, - ShoppingCartItem, -) -from .utils import wrap_select_func +import graphene +from graphene.relay import Node +from graphene.types.structures import Structure + from ..converter import ( convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -45,6 +37,7 @@ ShoppingCart, ShoppingCartItem, ) +from .utils import wrap_select_func def mock_resolver(): @@ -210,12 +203,11 @@ def test_converter_replace_type_var(): replace_type_vars = {T: graphene.String} - field_type = convert_sqlalchemy_type( - T, replace_type_vars=replace_type_vars - ) + field_type = convert_sqlalchemy_type(T, replace_type_vars=replace_type_vars) assert field_type == graphene.String + @pytest.mark.skipif( sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" ) @@ -225,9 +217,9 @@ def prop_method() -> int | str: return "not allowed in gql schema" with pytest.raises( - ValueError, - match=r"Cannot convert hybrid_property Union to " - r"graphene.Union: the Union contains scalars. \.*", + ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. \.*", ): get_hybrid_property_type(prop_method) @@ -481,7 +473,9 @@ class TestEnum(enum.IntEnum): def test_should_columproperty_convert(): field = get_field_from_column( - column_property(wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1)) + column_property( + wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1) + ) ) assert field.type == graphene.Int @@ -840,8 +834,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_item_expected_types.items(): hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] @@ -852,7 +846,7 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property ################################################### @@ -900,8 +894,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_expected_types.items(): hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] @@ -912,5 +906,5 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 026247ca..4acf89a8 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -1,8 +1,12 @@ -import graphene import pytest -from graphene import Connection, relay from sqlalchemy.sql.operators import is_ +import graphene +from graphene import Connection, relay + +from ..fields import SQLAlchemyConnectionField +from ..filters import FloatFilter +from ..types import ORMField, SQLAlchemyObjectType from .models import ( Article, Editor, @@ -16,10 +20,6 @@ Tag, ) from .utils import eventually_await_session, to_std_dicts -from ..fields import SQLAlchemyConnectionField -from ..filters import FloatFilter -from ..types import ORMField, SQLAlchemyObjectType - # TODO test that generated schema is correct for all examples with: # with open('schema.gql', 'w') as fp: @@ -257,7 +257,17 @@ async def test_filter_enum(session): } """ expected = { - "reporters": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe", "favoritePetKind": "DOG"}}]}, + "reporters": { + "edges": [ + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + } + ] + }, } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) @@ -916,8 +926,20 @@ async def test_filter_logic_or(session): expected = { "reporters": { "edges": [ - {"node": {"firstName": "John", "lastName": "Woe", "favoritePetKind": "CAT"}}, - {"node": {"firstName": "Jane", "lastName": "Roe", "favoritePetKind": "DOG"}}, + { + "node": { + "firstName": "John", + "lastName": "Woe", + "favoritePetKind": "CAT", + } + }, + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + }, ] } } @@ -1109,7 +1131,7 @@ async def test_filter_hybrid_property(session): result = to_std_dicts(result.data) assert len(result["carts"]["edges"]) == 1 assert ( - len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 + len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 ) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index b99f0236..6b5ab4dd 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -6,9 +6,13 @@ from inspect import isawaitable from typing import Any, Optional, Type, Union -import graphene import sqlalchemy -from graphene import Field, InputField, Dynamic +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty +from sqlalchemy.orm.exc import NoResultFound + +import graphene +from graphene import Dynamic, Field, InputField from graphene.relay import Connection, Node from graphene.types.base import BaseType from graphene.types.interface import Interface, InterfaceOptions @@ -16,9 +20,6 @@ from graphene.types.unmountedtype import UnmountedType from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty -from sqlalchemy.orm.exc import NoResultFound from .converter import ( convert_sqlalchemy_column, @@ -31,7 +32,7 @@ sort_argument_for_object_type, sort_enum_for_object_type, ) -from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter, SQLAlchemyFilterInputField +from .filters import BaseTypeFilter, RelationshipFilter, SQLAlchemyFilterInputField from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( @@ -51,17 +52,17 @@ class ORMField(OrderedType): def __init__( - self, - model_attr=None, - type_=None, - required=None, - description=None, - deprecation_reason=None, - batching=None, - create_filter=None, - filter_type: Optional[Type] = None, - _creation_counter=None, - **field_kwargs, + self, + model_attr=None, + type_=None, + required=None, + description=None, + deprecation_reason=None, + batching=None, + create_filter=None, + filter_type: Optional[Type] = None, + _creation_counter=None, + **field_kwargs, ): """ Use this to override fields automatically generated by SQLAlchemyObjectType. @@ -127,7 +128,7 @@ class Meta: def get_or_create_relationship_filter( - base_type: Type[BaseType], registry: Registry + base_type: Type[BaseType], registry: Registry ) -> Type[RelationshipFilter]: relationship_filter = registry.get_relationship_filter_for_base_type(base_type) @@ -146,11 +147,11 @@ def get_or_create_relationship_filter( def filter_field_from_field( - field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], - type_, - registry: Registry, - model_attr: Any, - model_attr_name: str + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + type_, + registry: Registry, + model_attr: Any, + model_attr_name: str, ) -> Optional[graphene.InputField]: # Field might be a SQLAlchemyObjectType, due to hybrid properties if issubclass(type_, SQLAlchemyObjectType): @@ -174,9 +175,7 @@ def filter_field_from_field( def resolve_dynamic_relationship_filter( - field: graphene.Dynamic, - registry: Registry, - model_attr_name: str + field: graphene.Dynamic, registry: Registry, model_attr_name: str ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # Resolve Dynamic Type type_ = get_nullable_type(field.get_type()) @@ -208,11 +207,11 @@ def resolve_dynamic_relationship_filter( def filter_field_from_type_field( - field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], - registry: Registry, - filter_type: Optional[Type], - model_attr: Any, - model_attr_name: str + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + registry: Registry, + filter_type: Optional[Type], + model_attr: Any, + model_attr_name: str, ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # If a custom filter type was set for this field, use it here if filter_type: @@ -223,7 +222,11 @@ def filter_field_from_type_field( # If the generated field is Dynamic, it is always a relationship # (due to graphene-sqlalchemy's conversion mechanism). elif isinstance(field, graphene.Dynamic): - return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr_name)) + return Dynamic( + partial( + resolve_dynamic_relationship_filter, field, registry, model_attr_name + ) + ) # Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List): # Pure lists are not yet supported @@ -234,9 +237,23 @@ def filter_field_from_type_field( # Order matters, this comes last as field._type == list also matches Field elif isinstance(field, graphene.Field): if inspect.isfunction(field._type) or isinstance(field._type, partial): - return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name)) + return Dynamic( + lambda: filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) + ) else: - return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name) + return filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) def get_polymorphic_on(model): @@ -252,14 +269,14 @@ def get_polymorphic_on(model): def construct_fields_and_filters( - obj_type, - model, - registry, - only_fields, - exclude_fields, - batching, - create_filters, - connection_field_factory, + obj_type, + model, + registry, + only_fields, + exclude_fields, + batching, + create_filters, + connection_field_factory, ): """ Construct all the fields for a SQLAlchemyObjectType. @@ -296,9 +313,9 @@ def construct_fields_and_filters( auto_orm_field_names = [] for attr_name, attr in all_model_attrs.items(): if ( - (only_fields and attr_name not in only_fields) - or (attr_name in exclude_fields) - or attr_name == polymorphic_on + (only_fields and attr_name not in only_fields) + or (attr_name in exclude_fields) + or attr_name == polymorphic_on ): continue auto_orm_field_names.append(attr_name) @@ -390,21 +407,21 @@ class SQLAlchemyBase(BaseType): @classmethod def __init_subclass_with_meta__( - cls, - model=None, - registry=None, - skip_registry=False, - only_fields=(), - exclude_fields=(), - connection=None, - connection_class=None, - use_connection=None, - interfaces=(), - id=None, - batching=False, - connection_field_factory=None, - _meta=None, - **options, + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + id=None, + batching=False, + connection_field_factory=None, + _meta=None, + **options, ): # We always want to bypass this hook unless we're defining a concrete # `SQLAlchemyObjectType` or `SQLAlchemyInterface`. @@ -528,6 +545,7 @@ def get_node(cls, info, id): session = get_session(info.context) if isinstance(session, AsyncSession): + async def get_result() -> Any: return await session.get(cls._meta.model, id) @@ -659,7 +677,7 @@ def __init_subclass_with_meta__(cls, _meta=None, **options): if hasattr(_meta.model, "__mapper__"): polymorphic_identity = _meta.model.__mapper__.polymorphic_identity assert ( - polymorphic_identity is None + polymorphic_identity is None ), '{}: An interface cannot map to a concrete type (polymorphic_identity is "{}")'.format( cls.__name__, polymorphic_identity ) From e698d7c4369b034f4b69369f4af8affab5b38e71 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 4 Dec 2023 21:23:58 +0100 Subject: [PATCH 81/81] chore: fix newly added merge conflict due to association proxies --- graphene_sqlalchemy/types.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index b60b78e1..18d06eef 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -135,15 +135,19 @@ def get_or_create_relationship_filter( relationship_filter = registry.get_relationship_filter_for_base_type(base_type) if not relationship_filter: - base_type_filter = registry.get_filter_for_base_type(base_type) - relationship_filter = RelationshipFilter.create_type( - f"{base_type.__name__}RelationshipFilter", - base_type_filter=base_type_filter, - model=base_type._meta.model, - ) - registry.register_relationship_filter_for_base_type( - base_type, relationship_filter - ) + try: + base_type_filter = registry.get_filter_for_base_type(base_type) + relationship_filter = RelationshipFilter.create_type( + f"{base_type.__name__}RelationshipFilter", + base_type_filter=base_type_filter, + model=base_type._meta.model, + ) + registry.register_relationship_filter_for_base_type( + base_type, relationship_filter + ) + except Exception as e: + print("e") + raise e return relationship_filter @@ -397,14 +401,16 @@ def construct_fields_and_filters( connection_field_factory, batching, resolver, - **orm_field.kwargs + **orm_field.kwargs, ) else: raise Exception("Property type is not supported") # Should never happen registry.register_orm_field(obj_type, orm_field_name, attr) fields[orm_field_name] = field - if filtering_enabled_for_field: + if filtering_enabled_for_field and not isinstance(attr, AssociationProxy): + # we don't support filtering on association proxies yet. + # Support will be patched in a future release of graphene-sqlalchemy filters[orm_field_name] = filter_field_from_type_field( field, registry, filter_type, attr, attr_name )