Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable sorting when batching is enabled #355

Merged
merged 7 commits into from
Sep 9, 2022
178 changes: 103 additions & 75 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The dataloader uses "select in loading" strategy to load related entities."""
from typing import Any
from asyncio import get_event_loop
from typing import Any, Dict

import aiodataloader
import sqlalchemy
Expand All @@ -10,6 +11,90 @@
is_sqlalchemy_version_less_than)


class RelationshipLoader(aiodataloader.DataLoader):
cache = False

def __init__(self, relationship_prop, selectin_loader):
super().__init__()
self.relationship_prop = relationship_prop
self.selectin_loader = selectin_loader

async def batch_load_fn(self, parents):
"""
Batch loads the relationships of all the parents as one SQL statement.

There is no way to do this out-of-the-box with SQLAlchemy but
we can piggyback on some internal APIs of the `selectin`
eager loading strategy. It's a bit hacky but it's preferable
than re-implementing and maintainnig a big chunk of the `selectin`
loader logic ourselves.

The approach here is to build a regular query that
selects the parent and `selectin` load the relationship.
But instead of having the query emits 2 `SELECT` statements
when callling `all()`, we skip the first `SELECT` statement
and jump right before the `selectin` loader is called.
To accomplish this, we have to construct objects that are
normally built in the first part of the query in order
to call directly `SelectInLoader._load_for_path`.

TODO Move this logic to a util in the SQLAlchemy repo as per
SQLAlchemy's main maitainer suggestion.
See https://git.io/JewQ7
"""
child_mapper = self.relationship_prop.mapper
parent_mapper = self.relationship_prop.parent
session = Session.object_session(parents[0])

# These issues are very unlikely to happen in practice...
for parent in parents:
# assert parent.__mapper__ is parent_mapper
# All instances must share the same session
assert session is Session.object_session(parent)
# The behavior of `selectin` is undefined if the parent is dirty
assert parent not in session.dirty

# Should the boolean be set to False? Does it matter for our purposes?
states = [(sqlalchemy.inspect(parent), True) for parent in parents]

# For our purposes, the query_context will only used to get the session
query_context = None
if is_sqlalchemy_version_less_than('1.4'):
query_context = QueryContext(session.query(parent_mapper.entity))
else:
parent_mapper_query = session.query(parent_mapper.entity)
query_context = parent_mapper_query._compile_context()

if is_sqlalchemy_version_less_than('1.4'):
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
)
else:
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
None,
)
return [
getattr(parent, self.relationship_prop.key) for parent in parents
]


# Cache this across `batch_load_fn` calls
# This is so SQL string generation is cached under-the-hood via `bakery`
# Caching the relationship loader for each relationship prop.
RELATIONSHIP_LOADERS_CACHE: Dict[
erikwrede marked this conversation as resolved.
Show resolved Hide resolved
sqlalchemy.orm.relationships.RelationshipProperty, RelationshipLoader
] = {}


def get_data_loader_impl() -> Any: # pragma: no cover
"""Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility,
aiodataloader is used in conjunction with older versions of graphene"""
Expand All @@ -25,80 +110,23 @@ def get_data_loader_impl() -> Any: # pragma: no cover


def get_batch_resolver(relationship_prop):
# Cache this across `batch_load_fn` calls
# This is so SQL string generation is cached under-the-hood via `bakery`
selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))

class RelationshipLoader(aiodataloader.DataLoader):
cache = False

async def batch_load_fn(self, parents):
"""
Batch loads the relationships of all the parents as one SQL statement.

There is no way to do this out-of-the-box with SQLAlchemy but
we can piggyback on some internal APIs of the `selectin`
eager loading strategy. It's a bit hacky but it's preferable
than re-implementing and maintainnig a big chunk of the `selectin`
loader logic ourselves.

The approach here is to build a regular query that
selects the parent and `selectin` load the relationship.
But instead of having the query emits 2 `SELECT` statements
when callling `all()`, we skip the first `SELECT` statement
and jump right before the `selectin` loader is called.
To accomplish this, we have to construct objects that are
normally built in the first part of the query in order
to call directly `SelectInLoader._load_for_path`.

TODO Move this logic to a util in the SQLAlchemy repo as per
SQLAlchemy's main maitainer suggestion.
See https://git.io/JewQ7
"""
child_mapper = relationship_prop.mapper
parent_mapper = relationship_prop.parent
session = Session.object_session(parents[0])

# These issues are very unlikely to happen in practice...
for parent in parents:
# assert parent.__mapper__ is parent_mapper
# All instances must share the same session
assert session is Session.object_session(parent)
# The behavior of `selectin` is undefined if the parent is dirty
assert parent not in session.dirty

# Should the boolean be set to False? Does it matter for our purposes?
states = [(sqlalchemy.inspect(parent), True) for parent in parents]

# For our purposes, the query_context will only used to get the session
query_context = None
if is_sqlalchemy_version_less_than('1.4'):
query_context = QueryContext(session.query(parent_mapper.entity))
else:
parent_mapper_query = session.query(parent_mapper.entity)
query_context = parent_mapper_query._compile_context()

if is_sqlalchemy_version_less_than('1.4'):
selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper
)
else:
selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
None
)

return [getattr(parent, relationship_prop.key) for parent in parents]

loader = RelationshipLoader()
"""Get the resolve function for the given relationship."""

def _get_loader(relationship_prop):
"""Retrieve the cached loader of the given relationship."""
loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None)
if loader is None or loader.loop != get_event_loop():
selectin_loader = strategies.SelectInLoader(
relationship_prop, (('lazy', 'selectin'),)
)
loader = RelationshipLoader(
relationship_prop=relationship_prop,
selectin_loader=selectin_loader,
)
RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader
return loader

loader = _get_loader(relationship_prop)

async def resolve(root, info, **args):
return await loader.load(root)
Expand Down
4 changes: 2 additions & 2 deletions graphene_sqlalchemy/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def sort_enum_for_object_type(
column = orm_field.columns[0]
if only_indexed and not (column.primary_key or column.index):
continue
asc_name = get_name(column.key, True)
sabard marked this conversation as resolved.
Show resolved Hide resolved
asc_name = get_name(field_name, True)
asc_value = EnumValue(asc_name, column.asc())
desc_name = get_name(column.key, False)
desc_name = get_name(field_name, False)
desc_value = EnumValue(desc_name, column.desc())
if column.primary_key:
default.append(asc_value)
Expand Down
116 changes: 69 additions & 47 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .utils import EnumValue, get_query


class UnsortedSQLAlchemyConnectionField(ConnectionField):
class SQLAlchemyConnectionField(ConnectionField):
@property
def type(self):
from .types import SQLAlchemyObjectType
Expand All @@ -37,13 +37,45 @@ def type(self):
)
return nullable_type.connection

def __init__(self, type_, *args, **kwargs):
nullable_type = get_nullable_type(type_)
if "sort" not in kwargs and nullable_type and issubclass(nullable_type, Connection):
# Let super class raise if type is not a Connection
try:
kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
except (AttributeError, TypeError):
raise TypeError(
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
" to None to disabling the creation of the sort query argument".format(
nullable_type.__name__
)
)
elif "sort" in kwargs and kwargs["sort"] is None:
del kwargs["sort"]
super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)

@property
def model(self):
return get_nullable_type(self.type)._meta.node._meta.model

@classmethod
def get_query(cls, model, info, **args):
return get_query(model, info.context)
def get_query(cls, model, info, sort=None, **args):
query = get_query(model, info.context)
if sort is not None:
if not isinstance(sort, list):
sort = [sort]
sort_args = []
# ensure consistent handling of graphene Enums, enum values and
# plain strings
for item in sort:
if isinstance(item, enum.Enum):
sort_args.append(item.value.value)
elif isinstance(item, EnumValue):
sort_args.append(item.value)
else:
sort_args.append(item)
query = query.order_by(*sort_args)
return query

@classmethod
def resolve_connection(cls, connection_type, model, info, args, resolved):
Expand Down Expand Up @@ -90,59 +122,49 @@ def wrap_resolve(self, parent_resolver):
)


# TODO Rename this to SortableSQLAlchemyConnectionField
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
# TODO Remove in next major version
class UnsortedSQLAlchemyConnectionField(SQLAlchemyConnectionField):
erikwrede marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, type_, *args, **kwargs):
nullable_type = get_nullable_type(type_)
if "sort" not in kwargs and issubclass(nullable_type, Connection):
# Let super class raise if type is not a Connection
try:
kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
except (AttributeError, TypeError):
raise TypeError(
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
" to None to disabling the creation of the sort query argument".format(
nullable_type.__name__
)
)
elif "sort" in kwargs and kwargs["sort"] is None:
del kwargs["sort"]
super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)

@classmethod
def get_query(cls, model, info, sort=None, **args):
query = get_query(model, info.context)
if sort is not None:
if not isinstance(sort, list):
sort = [sort]
sort_args = []
# ensure consistent handling of graphene Enums, enum values and
# plain strings
for item in sort:
if isinstance(item, enum.Enum):
sort_args.append(item.value.value)
elif isinstance(item, EnumValue):
sort_args.append(item.value)
else:
sort_args.append(item)
query = query.order_by(*sort_args)
return query
if "sort" in kwargs and kwargs["sort"] is not None:
warnings.warn(
"UnsortedSQLAlchemyConnectionField does not support sorting. "
"All sorting arguments will be ignored."
)
kwargs["sort"] = None
warnings.warn(
"UnsortedSQLAlchemyConnectionField is deprecated and will be removed in the next "
"major version. Use SQLAlchemyConnectionField instead and either don't "
"provide the `sort` argument or set it to None if you do not want sorting.",
DeprecationWarning,
)
super(UnsortedSQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)


class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField):
erikwrede marked this conversation as resolved.
Show resolved Hide resolved
"""
This is currently experimental.
The API and behavior may change in future versions.
Use at your own risk.
"""

def wrap_resolve(self, parent_resolver):
return partial(
self.connection_resolver,
self.resolver,
get_nullable_type(self.type),
self.model,
)
@classmethod
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
if root is None:
resolved = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
else:
relationship_prop = None
for relationship in root.__class__.__mapper__.relationships:
if relationship.mapper.class_ == model:
relationship_prop = relationship
break
resolved = get_batch_resolver(relationship_prop)(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection_type, root, info, args)

if is_thenable(resolved):
return Promise.resolve(resolved).then(on_resolve)

return on_resolve(resolved)

@classmethod
def from_relationship(cls, relationship, registry, **field_kwargs):
Expand Down
18 changes: 18 additions & 0 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,24 @@ class Article(Base):
headline = Column(String(100))
pub_date = Column(Date())
reporter_id = Column(Integer(), ForeignKey("reporters.id"))
readers = relationship(
"Reader", secondary="articles_readers", back_populates="articles"
)


class Reader(Base):
__tablename__ = "readers"
id = Column(Integer(), primary_key=True)
name = Column(String(100))
articles = relationship(
"Article", secondary="articles_readers", back_populates="readers"
)


class ArticleReader(Base):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to add some deeper nesting in the test data for more robust tests

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a Tags class in #357. Will see if there's a way to consolidate over there after this is merged.

__tablename__ = "articles_readers"
article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True)
reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True)


class ReflectedEditor(type):
Expand Down
Loading