Skip to content

Commit

Permalink
Enable sorting when batching is enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulSchweizer committed Jul 23, 2022
1 parent dfee3e9 commit ac57fd4
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 153 deletions.
48 changes: 36 additions & 12 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
from asyncio import get_event_loop

import aiodataloader
import sqlalchemy
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext

from .utils import is_sqlalchemy_version_less_than

# Cache this across `batch_load_fn` calls
# This is so SQL string generation is cached under-the-hood via `bakery`
# Caching the relationship loader for each relationship prop.
RELATIONSHIP_LOADERS_CACHE = {}

def get_batch_resolver(relationship_prop):

# Cache this across `batch_load_fn` calls
# This is so SQL string generation is cached under-the-hood via `bakery`
selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))
def get_batch_resolver(relationship_prop):

class RelationshipLoader(aiodataloader.DataLoader):
cache = False

def __init__(self, relationship_prop, selectin_loader):
super().__init__()
self.relationship_prop = relationship_prop
self.selectin_loader = selectin_loader

async def batch_load_fn(self, parents):
"""
Batch loads the relationships of all the parents as one SQL statement.
Expand All @@ -38,8 +46,8 @@ async def batch_load_fn(self, parents):
SQLAlchemy's main maitainer suggestion.
See https://git.io/JewQ7
"""
child_mapper = relationship_prop.mapper
parent_mapper = relationship_prop.parent
child_mapper = self.relationship_prop.mapper
parent_mapper = self.relationship_prop.parent
session = Session.object_session(parents[0])

# These issues are very unlikely to happen in practice...
Expand All @@ -62,26 +70,42 @@ async def batch_load_fn(self, parents):
query_context = parent_mapper_query._compile_context()

if is_sqlalchemy_version_less_than('1.4'):
selectin_loader._load_for_path(
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper
)
else:
selectin_loader._load_for_path(
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
None
)

return [getattr(parent, relationship_prop.key) for parent in parents]

loader = RelationshipLoader()
return [getattr(parent, self.relationship_prop.key) for parent in parents]

def _get_loader(relationship_prop):
"""Retrieve the cached loader of the given relationship."""
loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None)
if loader is None:
selectin_loader = strategies.SelectInLoader(
relationship_prop,
(('lazy', 'selectin'),)
)
loader = RelationshipLoader(
relationship_prop=relationship_prop,
selectin_loader=selectin_loader
)
RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader
else:
loader.loop = get_event_loop()
return loader

loader = _get_loader(relationship_prop)

async def resolve(root, info, **args):
return await loader.load(root)
Expand Down
27 changes: 19 additions & 8 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,20 +129,31 @@ def get_query(cls, model, info, sort=None, **args):
return query


class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField):
"""
This is currently experimental.
The API and behavior may change in future versions.
Use at your own risk.
"""

def wrap_resolve(self, parent_resolver):
return partial(
self.connection_resolver,
self.resolver,
get_nullable_type(self.type),
self.model,
)
@classmethod
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
if root is None:
resolved = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
else:
relationship_prop = None
for relationship in root.__class__.__mapper__.relationships:
if relationship.mapper.class_ == model:
relationship_prop = relationship
break
resolved = get_batch_resolver(relationship_prop)(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection_type, root, info, args)

if is_thenable(resolved):
return Promise.resolve(resolved).then(on_resolve)

return on_resolve(resolved)

@classmethod
def from_relationship(cls, relationship, registry, **field_kwargs):
Expand Down
18 changes: 18 additions & 0 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,24 @@ class Article(Base):
headline = Column(String(100))
pub_date = Column(Date())
reporter_id = Column(Integer(), ForeignKey("reporters.id"))
readers = relationship(
"Reader", secondary="articles_readers", back_populates="articles"
)


class Reader(Base):
__tablename__ = "readers"
id = Column(Integer(), primary_key=True)
name = Column(String(100))
articles = relationship(
"Article", secondary="articles_readers", back_populates="readers"
)


class ArticleReader(Base):
__tablename__ = "articles_readers"
article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True)
reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True)


class ReflectedEditor(type):
Expand Down
Loading

0 comments on commit ac57fd4

Please sign in to comment.