diff --git a/graphene_sqlalchemy/mutations.py b/graphene_sqlalchemy/mutations.py new file mode 100644 index 00000000..f1f50620 --- /dev/null +++ b/graphene_sqlalchemy/mutations.py @@ -0,0 +1,266 @@ +from graphene import Argument, Field, List, Mutation +from graphene.types.objecttype import ObjectTypeOptions +from graphene.types.utils import yank_fields_from_attrs +from sqlalchemy.inspection import inspect as sqlalchemyinspect + +from graphene_sqlalchemy.types import construct_fields +from .registry import get_global_registry +from .utils import get_session, get_snake_or_camel_attr + + +class SQLAlchemyMutationOptions(ObjectTypeOptions): + model = None # type: Model + + +class SQLAlchemyCreate(Mutation): + @classmethod + def __init_subclass_with_meta__(cls, model=None, registry=None, only_fields=(), exclude_fields=None, **options): + meta = SQLAlchemyMutationOptions(cls) + meta.model = model + + model_inspect = sqlalchemyinspect(model) + cls._model_inspect = model_inspect + + if not isinstance(exclude_fields, list): + if exclude_fields: + exclude_fields = list(exclude_fields) + else: + exclude_fields = [] + + for primary_key_column in model_inspect.primary_key: + if primary_key_column.autoincrement: + exclude_fields.append(primary_key_column.name) + + for relationship in model_inspect.relationships: + exclude_fields.append(relationship.key) + + if not registry: + registry = get_global_registry() + + arguments = yank_fields_from_attrs( + construct_fields(model, registry, only_fields, exclude_fields), + _as=Argument, + ) + + super(SQLAlchemyCreate, cls).__init_subclass_with_meta__(_meta=meta, arguments=arguments, **options) + + @classmethod + def mutate(cls, self, info, **kwargs): + session = get_session(info.context) + + meta = cls._meta + + model = meta.model() + session.add(model) + + for key, value in kwargs.items(): + setattr(model, key, value) + + session.commit() + + return model + + @classmethod + def Field(cls, *args, **kwargs): + return Field( + cls._meta.output, args=cls._meta.arguments, resolver=cls._meta.resolver + ) + + +class SQLAlchemyUpdate(Mutation): + @classmethod + def __init_subclass_with_meta__(cls, model=None, registry=None, only_fields=(), exclude_fields=None, **options): + meta = SQLAlchemyMutationOptions(cls) + meta.model = model + + model_inspect = sqlalchemyinspect(model) + cls._model_inspect = model_inspect + + if not isinstance(exclude_fields, list): + if exclude_fields: + exclude_fields = list(exclude_fields) + else: + exclude_fields = [] + + for relationship in model_inspect.relationships: + exclude_fields.append(relationship.key) + + if not registry: + registry = get_global_registry() + + arguments = yank_fields_from_attrs( + construct_fields(model, registry, only_fields, exclude_fields), + _as=Argument + ) + + super(SQLAlchemyUpdate, cls).__init_subclass_with_meta__(_meta=meta, arguments=arguments, **options) + + @classmethod + def mutate(cls, self, info, **kwargs): + session = get_session(info.context) + + meta = cls._meta + + query = session.query(meta.model) + for primary_key_column in cls._model_inspect.primary_key: + query = query.filter(getattr(meta.model, primary_key_column.key) == kwargs[primary_key_column.name]) + model = query.one() + + for key, value in kwargs.items(): + setattr(model, key, value) + + session.commit() + + return model + + @classmethod + def Field(cls, *args, **kwargs): + return Field( + cls._meta.output, args=cls._meta.arguments, resolver=cls._meta.resolver + ) + + +class SQLAlchemyDelete(Mutation): + @classmethod + def __init_subclass_with_meta__(cls, model=None, registry=None, only_fields=(), + exclude_fields=None, **options): + meta = SQLAlchemyMutationOptions(cls) + meta.model = model + + model_inspect = sqlalchemyinspect(model) + cls._model_inspect = model_inspect + + only_fields = [] + exclude_fields = () + for primary_key_column in model_inspect.primary_key: + only_fields.append(primary_key_column.name) + + if not registry: + registry = get_global_registry() + + arguments = yank_fields_from_attrs( + construct_fields(model, registry, only_fields, exclude_fields), + _as=Argument + ) + + super(SQLAlchemyDelete, cls).__init_subclass_with_meta__(_meta=meta, arguments=arguments, **options) + + @classmethod + def mutate(cls, self, info, **kwargs): + session = get_session(info.context) + + meta = cls._meta + + query = session.query(meta.model) + + for primary_key_column in cls._model_inspect.primary_key: + query = query.filter(getattr(meta.model, primary_key_column.key) == kwargs[primary_key_column.name]) + model = query.one() + session.delete(model) + + session.commit() + + return model + + @classmethod + def Field(cls, *args, **kwargs): + return Field( + cls._meta.output, args=cls._meta.arguments, resolver=cls._meta.resolver + ) + + +class SQLAlchemyListDelete(Mutation): + @classmethod + def __init_subclass_with_meta__(cls, model=None, registry=None, only_fields=(), + exclude_fields=None, **options): + meta = SQLAlchemyMutationOptions(cls) + meta.model = model + + model_inspect = sqlalchemyinspect(model) + for column in model_inspect.columns: + column.nullable = True + + cls._model_inspect = model_inspect + + if not isinstance(exclude_fields, list): + if exclude_fields: + exclude_fields = list(exclude_fields) + else: + exclude_fields = [] + + for relationship in model_inspect.relationships: + exclude_fields.append(relationship.key) + + if not registry: + registry = get_global_registry() + + arguments = yank_fields_from_attrs( + construct_fields(model, registry, only_fields, exclude_fields), + _as=Argument + ) + + super(SQLAlchemyListDelete, cls).__init_subclass_with_meta__(_meta=meta, arguments=arguments, **options) + + @classmethod + def mutate(cls, self, info, **kwargs): + session = get_session(info.context) + + meta = cls._meta + + query = session.query(meta.model) + for key, value in kwargs.items(): + query = query.filter(get_snake_or_camel_attr(meta.model, key) == value) + + models = query.all() + for model in models: + session.delete(model) + + session.commit() + + return models + + @classmethod + def Field(cls, *args, **kwargs): + return Field( + cls._meta.output, args=cls._meta.arguments, resolver=cls._meta.resolver + ) + + +def create(of_type): + class CreateModel(SQLAlchemyCreate): + class Meta: + model = of_type._meta.model + + Output = of_type + + return CreateModel.Field() + + +def update(of_type): + class UpdateModel(SQLAlchemyUpdate): + class Meta: + model = of_type._meta.model + + Output = of_type + + return UpdateModel.Field() + + +def delete(of_type): + class DeleteModel(SQLAlchemyDelete): + class Meta: + model = of_type._meta.model + + Output = of_type + + return DeleteModel.Field() + + +def delete_all(of_type): + class DeleteListModel(SQLAlchemyListDelete): + class Meta: + model = of_type._meta.model + + Output = List(of_type) + + return DeleteListModel.Field() diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py new file mode 100644 index 00000000..b5d8e07f --- /dev/null +++ b/graphene_sqlalchemy/tests/conftest.py @@ -0,0 +1,7 @@ +import pytest +from sqlalchemy import create_engine + + +@pytest.fixture(scope='session') +def db(): + return create_engine('sqlite:///test_sqlalchemy.sqlite3') diff --git a/graphene_sqlalchemy/tests/test_mutations.py b/graphene_sqlalchemy/tests/test_mutations.py new file mode 100644 index 00000000..316efeea --- /dev/null +++ b/graphene_sqlalchemy/tests/test_mutations.py @@ -0,0 +1,334 @@ +import graphene +import pytest +from sqlalchemy.orm import scoped_session, sessionmaker + +from graphene_sqlalchemy import SQLAlchemyObjectType +from graphene_sqlalchemy.mutations import create, delete, update, delete_all +from graphene_sqlalchemy.registry import reset_global_registry +from graphene_sqlalchemy.tests.models import Base, Reporter + + +@pytest.yield_fixture(scope='function') +def session(db): + reset_global_registry() + connection = db.engine.connect() + transaction = connection.begin() + Base.metadata.create_all(connection) + + # options = dict(bind=connection, binds={}) + session_factory = sessionmaker(bind=connection) + session = scoped_session(session_factory) + + yield session + + # Finalize test here + transaction.rollback() + connection.close() + session.remove() + + +def setup_fixtures(session): + reporter = Reporter(first_name='ABC', last_name='def') + session.add(reporter) + reporter2 = Reporter(first_name='CBA', last_name='fed') + session.add(reporter2) + session.commit() + + +def test_should_create_with_create_field(session): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = graphene.List(ReporterType) + + def resolve_reporters(self, *args, **kwargs): + return session.query(Reporter) + + class Mutation(graphene.ObjectType): + createReporter = create(ReporterType) + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [] + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + query = ''' + mutation createReporter{ + createReporter (firstName: "ABC", lastName: "def") { + firstName, + lastName, + email + } + } + ''' + expected = { + 'createReporter': { + 'firstName': 'ABC', + 'lastName': 'def', + 'email': None, + } + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_delete_with_delete_field(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = graphene.List(ReporterType) + + def resolve_reporters(self, *args, **kwargs): + return session.query(Reporter) + + class Mutation(graphene.ObjectType): + deleteReporter = delete(ReporterType) + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [ + { + 'firstName': 'ABC', + 'lastName': 'def', + 'email': None + }, + { + 'firstName': 'CBA', + 'lastName': 'fed', + 'email': None + } + ] + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + query = ''' + mutation deleteReporter { + deleteReporter (id: 1) { + firstName, + lastName, + email + } + } + ''' + expected = { + 'deleteReporter': { + 'firstName': 'ABC', + 'lastName': 'def', + 'email': None, + } + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [ + { + 'firstName': 'CBA', + 'lastName': 'fed', + 'email': None + } + ] + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_should_delete_all_with_delete_all_field(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = graphene.List(ReporterType) + + def resolve_reporters(self, *args, **kwargs): + return session.query(Reporter) + + class Mutation(graphene.ObjectType): + deleteReporters = delete_all(ReporterType) + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [ + { + 'firstName': 'ABC', + 'lastName': 'def', + 'email': None + }, + { + 'firstName': 'CBA', + 'lastName': 'fed', + 'email': None + } + ] + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + query = ''' + mutation deleteReporters { + deleteReporters (firstName: "ABC") { + firstName, + lastName, + email + } + } + ''' + expected = { + 'deleteReporters': [ + { + 'firstName': 'ABC', + 'lastName': 'def', + 'email': None, + } + ] + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [ + { + 'firstName': 'CBA', + 'lastName': 'fed', + 'email': None + } + ] + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_should_update_with_update_field(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = graphene.List(ReporterType) + + def resolve_reporters(self, *args, **kwargs): + return session.query(Reporter) + + class Mutation(graphene.ObjectType): + updateReporter = update(ReporterType) + + query = ''' + mutation updateReporter { + updateReporter (id: 1, lastName: "updated", email: "test@test.io") { + firstName, + lastName, + email + } + } + ''' + expected = { + 'updateReporter': { + 'firstName': 'ABC', + 'lastName': 'updated', + 'email': 'test@test.io', + } + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [ + { + 'firstName': 'ABC', + 'lastName': 'updated', + 'email': 'test@test.io', + }, + { + 'firstName': 'CBA', + 'lastName': 'fed', + 'email': None + } + ] + } + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index e4c3f835..179fbfaa 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -7,14 +7,12 @@ from ..registry import reset_global_registry from ..fields import SQLAlchemyConnectionField -from ..types import SQLAlchemyObjectType +from ..types import SQLAlchemyObjectType, SQLAlchemyList from .models import Article, Base, Editor, Reporter -db = create_engine('sqlite:///test_sqlalchemy.sqlite3') - @pytest.yield_fixture(scope='function') -def session(): +def session(db): reset_global_registry() connection = db.engine.connect() transaction = connection.begin() @@ -45,11 +43,10 @@ def setup_fixtures(session): session.commit() -def test_should_query_well(session): +def test_should_query_well_with_graphene_types(session): setup_fixtures(session) class ReporterType(SQLAlchemyObjectType): - class Meta: model = Reporter @@ -93,24 +90,257 @@ def resolve_reporters(self, *args, **kwargs): assert result.data == expected +def test_should_filter_with_sqlalchemy_fields(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType) + + query = ''' + query ReporterQuery { + reporters(firstName: "ABA") { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_filter_with_custom_argument(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, contains_o=graphene.Boolean()) + + def query_reporters(self, info, query, **kwargs): + return query.filter(Reporter.first_name.contains('O') == kwargs['contains_o']) + + query = ''' + query ReporterQuery { + reporters(lastName: "Y", containsO: true) { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + query = ''' + query ReporterQuery { + reporters(containsO: false) { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_filter_with_custom_operator(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, operator='like') + + query = ''' + query ReporterQuery { + reporters(firstName: "%BO%") { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_order_by(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, order_by='firstName') + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + }, + { + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_order_by_asc(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, order_by='first_name ASC') + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [ + { + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + }, + { + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + } + ] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_order_by_desc(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, order_by='firstName desc') + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [ + { + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + }, + { + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + } + ] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + def test_should_node(session): setup_fixtures(session) class ReporterNode(SQLAlchemyObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) @classmethod def get_node(cls, id, info): return Reporter(id=2, first_name='Cookie Monster') class ArticleNode(SQLAlchemyObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) # @classmethod # def get_node(cls, id, info): @@ -169,9 +399,9 @@ def resolve_article(self, *args, **kwargs): 'email': None, 'articles': { 'edges': [{ - 'node': { - 'headline': 'Hi!' - } + 'node': { + 'headline': 'Hi!' + } }] }, }, @@ -197,10 +427,9 @@ def test_should_custom_identifier(session): setup_fixtures(session) class EditorNode(SQLAlchemyObjectType): - class Meta: model = Editor - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): node = Node.Field() @@ -247,29 +476,25 @@ def test_should_mutate_well(session): setup_fixtures(session) class EditorNode(SQLAlchemyObjectType): - class Meta: model = Editor - interfaces = (Node, ) + interfaces = (Node,) class ReporterNode(SQLAlchemyObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) @classmethod def get_node(cls, id, info): return Reporter(id=2, first_name='Cookie Monster') class ArticleNode(SQLAlchemyObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) class CreateArticle(graphene.Mutation): - class Arguments: headline = graphene.String() reporter_id = graphene.ID() diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index 8af3c61e..91fadc98 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -1,6 +1,6 @@ from graphene import ObjectType, Schema, String -from ..utils import get_session +from ..utils import get_session, get_operator_function def test_get_session(): @@ -22,3 +22,28 @@ def resolve_x(self, info): result = schema.execute(query, context_value={'session': session}) assert not result.errors assert result.data['x'] == session + + +def test_get_operator_function(): + func = get_operator_function('=') + assert func(1, 1) + assert not func(1, 2) + + func = get_operator_function('!=') + assert not func(1, 1) + assert func(1, 2) + + class DummyLike: + def __init__(self, value): + self.value = value + + def like(self, other): + return self.value == other + + func = get_operator_function('like') + assert func(DummyLike(1), 1) + assert not func(DummyLike(1), 2) + + func = get_operator_function(lambda x, y: x > y) + assert func(5, 3) + assert not func(3, 5) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 04d1a8a6..9599963b 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,20 +1,17 @@ +import six from collections import OrderedDict - -from sqlalchemy.inspection import inspect as sqlalchemyinspect -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm.exc import NoResultFound - -from graphene import Field # , annotate, ResolveInfo +from graphene import Field, List, String # , annotate, ResolveInfo from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.inspection import inspect as sqlalchemyinspect +from sqlalchemy.orm.exc import NoResultFound -from .converter import (convert_sqlalchemy_column, - convert_sqlalchemy_composite, - convert_sqlalchemy_relationship, - convert_sqlalchemy_hybrid_method) +from .converter import convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, \ + convert_sqlalchemy_relationship from .registry import Registry, get_global_registry -from .utils import get_query, is_mapped_class, is_mapped_instance +from .utils import get_operator_function, get_query, get_snake_or_camel_attr, is_mapped_class, is_mapped_instance def construct_fields(model, registry, only_fields, exclude_fields): @@ -139,8 +136,8 @@ def is_type_of(cls, root, info): return True if not is_mapped_instance(root): raise Exception(( - 'Received incompatible instance "{}".' - ).format(root)) + 'Received incompatible instance "{}".' + ).format(root)) return isinstance(root, cls._meta.model) @classmethod @@ -159,3 +156,76 @@ def resolve_id(self, info): # graphene_type = info.parent_type.graphene_type keys = self.__mapper__.primary_key_from_instance(self) return tuple(keys) if len(keys) > 1 else keys[0] + + +class SQLAlchemyList(List): + def __init__(self, of_type, exclude_fields=(), include_fields=(), operator=None, order_by=(), *args, **kwargs): + columns_dict = self.build_columns_dict(of_type) + + if include_fields: + columns_dict = {k: columns_dict[k] for k in include_fields} + for exclude_field in exclude_fields: + if exclude_field in columns_dict.keys(): + del columns_dict[exclude_field] + + kwargs.update(**columns_dict) + kwargs['operator'] = String(description="Operator to use for filtering") + kwargs['order_by'] = List(String, description="Fields to use for results ordering") + + default_operator = get_operator_function(operator) + if isinstance(order_by, six.string_types): + order_by = (order_by,) + default_order_by = order_by + + def filters_resolver(self, info, **kwargs): + operator = default_operator + if 'operator' in kwargs: + operator = get_operator_function(kwargs['operator']) + + query = of_type.get_query(info) + + for (k, v) in kwargs.items(): + if hasattr(of_type._meta.model, k): + query = query.filter(operator(get_snake_or_camel_attr(of_type._meta.model, k), v)) + + order_by = default_order_by + if 'order_by' in kwargs: + order_by = kwargs['order_by'] + + for order_by_item in order_by: + if order_by_item.lower().endswith(' asc'): + order_by_item = order_by_item[:-len(' asc')] + query = query.order_by(get_snake_or_camel_attr(of_type._meta.model, order_by_item).asc()) + elif order_by_item.lower().endswith(' desc'): + order_by_item = order_by_item[:-len(' desc')] + query = query.order_by(get_snake_or_camel_attr(of_type._meta.model, order_by_item).desc()) + else: + query = query.order_by(get_snake_or_camel_attr(of_type._meta.model, order_by_item)) + + query_transformer = getattr(info.parent_type.graphene_type, 'query_' + info.field_name, False) + if callable(query_transformer): + transformed_query = query_transformer(info.parent_type.graphene_type(), info, query, **kwargs) + if transformed_query: + query = transformed_query + + return query.all() + + kwargs['resolver'] = filters_resolver + super(SQLAlchemyList, self).__init__(of_type, *args, **kwargs) + + @staticmethod + def build_columns_dict(of_type): + columns_dict = {} + inspected_model = sqlalchemyinspect(of_type._meta.model) + for column in inspected_model.columns: + column.nullable = True # Set nullable to false to build an optional graph type + graphene_type = convert_sqlalchemy_column(column) + columns_dict[column.name] = graphene_type + return columns_dict + + def __eq__(self, other): + return isinstance(other, SQLAlchemyList) and ( + self.of_type == other.of_type and + self.args == other.args and + self.kwargs == other.kwargs + ) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index e78c9802..4a4fd957 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,6 +1,7 @@ from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene.utils.str_converters import to_camel_case, to_snake_case def get_session(context): @@ -18,6 +19,46 @@ def get_query(model, context): return query +_operator_aliases = { + '==': lambda x, y: x == y, + '=': lambda x, y: x == y, + '>=': lambda x, y: x >= y, + '>': lambda x, y: x > y, + '<=': lambda x, y: x <= y, + '<': lambda x, y: x < y, + '!=': lambda x, y: x != y +} + + +def get_operator_function(operator=None): + if not operator: + def operator(x, y): + return x == y + + if not callable(operator): + operator_attr = operator + + if operator in _operator_aliases: + operator = _operator_aliases[operator_attr] + else: + def operator(x, y): + return getattr(x, operator_attr)(y) + + return operator + + +def get_snake_or_camel_attr(model, attr): + try: + return getattr(model, to_snake_case(attr)) + except Exception: + pass + try: + return getattr(model, to_camel_case(attr)) + except Exception: + pass + return getattr(model, attr) + + def is_mapped_class(cls): try: class_mapper(cls)