diff --git a/graphene_sqlalchemy/mutations.py b/graphene_sqlalchemy/mutations.py new file mode 100644 index 00000000..1193ae09 --- /dev/null +++ b/graphene_sqlalchemy/mutations.py @@ -0,0 +1,27 @@ +from graphene_sqlalchemy.types import SQLAlchemyMutation + + +def create(of_type): + class CreateModel(SQLAlchemyMutation): + class Meta: + model = of_type._meta.model + create = True + Output = of_type + return CreateModel.Field() + + +def update(of_type): + class UpdateModel(SQLAlchemyMutation): + class Meta: + model = of_type._meta.model + Output = of_type + return UpdateModel.Field() + + +def delete(of_type): + class DeleteModel(SQLAlchemyMutation): + class Meta: + model = of_type._meta.model + delete = True + Output = of_type + return DeleteModel.Field() diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index e4c3f835..5dd65f8a 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -7,7 +7,7 @@ from ..registry import reset_global_registry from ..fields import SQLAlchemyConnectionField -from ..types import SQLAlchemyObjectType +from ..types import SQLAlchemyObjectType, SQLAlchemyList from .models import Article, Base, Editor, Reporter db = create_engine('sqlite:///test_sqlalchemy.sqlite3') @@ -45,11 +45,10 @@ def setup_fixtures(session): session.commit() -def test_should_query_well(session): +def test_should_query_well_with_graphene_types(session): setup_fixtures(session) class ReporterType(SQLAlchemyObjectType): - class Meta: model = Reporter @@ -93,24 +92,257 @@ def resolve_reporters(self, *args, **kwargs): assert result.data == expected +def test_should_filter_with_sqlalchemy_fields(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType) + + query = ''' + query ReporterQuery { + reporters(firstName: "ABA") { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_filter_with_custom_argument(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, contains_o=graphene.Boolean()) + + def query_reporters(self, info, query, **kwargs): + return query.filter(Reporter.first_name.contains('O') == kwargs['contains_o']) + + query = ''' + query ReporterQuery { + reporters(lastName: "Y", containsO: true) { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + query = ''' + query ReporterQuery { + reporters(containsO: false) { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_filter_with_custom_operator(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, operator='like') + + query = ''' + query ReporterQuery { + reporters(firstName: "%BO%") { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_order_by(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, order_by='firstName') + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [{ + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + }, + { + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_order_by_asc(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, order_by='first_name ASC') + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [ + { + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + }, + { + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + } + ] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + +def test_should_order_by_desc(session): + setup_fixtures(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporters = SQLAlchemyList(ReporterType, order_by='firstName desc') + + query = ''' + query ReporterQuery { + reporters { + firstName, + lastName, + email + } + } + ''' + expected = { + 'reporters': [ + { + 'firstName': 'ABO', + 'lastName': 'Y', + 'email': None + }, + { + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None + } + ] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected + + def test_should_node(session): setup_fixtures(session) class ReporterNode(SQLAlchemyObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) @classmethod def get_node(cls, id, info): return Reporter(id=2, first_name='Cookie Monster') class ArticleNode(SQLAlchemyObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) # @classmethod # def get_node(cls, id, info): @@ -169,9 +401,9 @@ def resolve_article(self, *args, **kwargs): 'email': None, 'articles': { 'edges': [{ - 'node': { - 'headline': 'Hi!' - } + 'node': { + 'headline': 'Hi!' + } }] }, }, @@ -197,10 +429,9 @@ def test_should_custom_identifier(session): setup_fixtures(session) class EditorNode(SQLAlchemyObjectType): - class Meta: model = Editor - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): node = Node.Field() @@ -247,29 +478,25 @@ def test_should_mutate_well(session): setup_fixtures(session) class EditorNode(SQLAlchemyObjectType): - class Meta: model = Editor - interfaces = (Node, ) + interfaces = (Node,) class ReporterNode(SQLAlchemyObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) @classmethod def get_node(cls, id, info): return Reporter(id=2, first_name='Cookie Monster') class ArticleNode(SQLAlchemyObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) class CreateArticle(graphene.Mutation): - class Arguments: headline = graphene.String() reporter_id = graphene.ID() diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index 8af3c61e..91fadc98 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -1,6 +1,6 @@ from graphene import ObjectType, Schema, String -from ..utils import get_session +from ..utils import get_session, get_operator_function def test_get_session(): @@ -22,3 +22,28 @@ def resolve_x(self, info): result = schema.execute(query, context_value={'session': session}) assert not result.errors assert result.data['x'] == session + + +def test_get_operator_function(): + func = get_operator_function('=') + assert func(1, 1) + assert not func(1, 2) + + func = get_operator_function('!=') + assert not func(1, 1) + assert func(1, 2) + + class DummyLike: + def __init__(self, value): + self.value = value + + def like(self, other): + return self.value == other + + func = get_operator_function('like') + assert func(DummyLike(1), 1) + assert not func(DummyLike(1), 2) + + func = get_operator_function(lambda x, y: x > y) + assert func(5, 3) + assert not func(3, 5) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 04d1a8a6..4fd09639 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,10 +1,12 @@ +import six from collections import OrderedDict +from functools import partial from sqlalchemy.inspection import inspect as sqlalchemyinspect from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm.exc import NoResultFound -from graphene import Field # , annotate, ResolveInfo +from graphene import Field, Mutation, Argument, List, String # , annotate, ResolveInfo from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs @@ -14,7 +16,8 @@ convert_sqlalchemy_relationship, convert_sqlalchemy_hybrid_method) from .registry import Registry, get_global_registry -from .utils import get_query, is_mapped_class, is_mapped_instance +from .utils import get_query, is_mapped_class, is_mapped_instance, get_session, get_operator_function, \ + get_snake_or_camel_attr def construct_fields(model, registry, only_fields, exclude_fields): @@ -139,8 +142,8 @@ def is_type_of(cls, root, info): return True if not is_mapped_instance(root): raise Exception(( - 'Received incompatible instance "{}".' - ).format(root)) + 'Received incompatible instance "{}".' + ).format(root)) return isinstance(root, cls._meta.model) @classmethod @@ -159,3 +162,157 @@ def resolve_id(self, info): # graphene_type = info.parent_type.graphene_type keys = self.__mapper__.primary_key_from_instance(self) return tuple(keys) if len(keys) > 1 else keys[0] + + +class SQLAlchemyList(List): + def __init__(self, of_type, exclude_fields=(), include_fields=(), operator=None, order_by=(), *args, **kwargs): + columns_dict = self.build_columns_dict(of_type) + + if include_fields: + columns_dict = {k: columns_dict[k] for k in include_fields} + for exclude_field in exclude_fields: + if exclude_field in columns_dict.keys(): + del columns_dict[exclude_field] + + kwargs.update(**columns_dict) + kwargs['operator'] = String(description="Operator to use for filtering") + kwargs['order_by'] = List(String, description="Fields to use for results ordering") + + default_operator = get_operator_function(operator) + if isinstance(order_by, six.string_types): + order_by = (order_by, ) + default_order_by = order_by + + def filters_resolver(self, info, **kwargs): + operator = default_operator + if 'operator' in kwargs: + operator = get_operator_function(kwargs['operator']) + + query = of_type.get_query(info) + + for (k, v) in kwargs.items(): + if hasattr(of_type._meta.model, k): + query = query.filter(operator(get_snake_or_camel_attr(of_type._meta.model, k), v)) + + order_by = default_order_by + if 'order_by' in kwargs: + order_by = kwargs['order_by'] + + for order_by_item in order_by: + if order_by_item.lower().endswith(' asc'): + order_by_item = order_by_item[:-len(' asc')] + query = query.order_by(get_snake_or_camel_attr(of_type._meta.model, order_by_item).asc()) + elif order_by_item.lower().endswith(' desc'): + order_by_item = order_by_item[:-len(' desc')] + query = query.order_by(get_snake_or_camel_attr(of_type._meta.model, order_by_item).desc()) + else: + query = query.order_by(get_snake_or_camel_attr(of_type._meta.model, order_by_item)) + + query_transformer = getattr(info.parent_type.graphene_type, 'query_' + info.field_name, False) + if callable(query_transformer): + transformed_query = query_transformer(info.parent_type.graphene_type, info, query, **kwargs) + if transformed_query: + query = transformed_query + + return query.all() + + kwargs['resolver'] = filters_resolver + super(SQLAlchemyList, self).__init__(of_type, *args, **kwargs) + + @staticmethod + def build_columns_dict(of_type): + columns_dict = {} + inspected_model = sqlalchemyinspect(of_type._meta.model) + for column in inspected_model.columns: + column.nullable = True # Set nullable to false to build an optional graph type + graphene_type = convert_sqlalchemy_column(column) + columns_dict[column.name] = graphene_type + return columns_dict + + def __eq__(self, other): + return isinstance(other, SQLAlchemyList) and ( + self.of_type == other.of_type and + self.args == other.args and + self.kwargs == other.kwargs + ) + + +class SQLAlchemyMutationOptions(ObjectTypeOptions): + model = None # type: Model + create = False # type: Boolean + delete = False # type: Boolean + + +class SQLAlchemyMutation(Mutation): + @classmethod + def __init_subclass_with_meta__(cls, model=None, create=False, delete=False, registry=None, only_fields=(), + exclude_fields=None, + **options): + meta = SQLAlchemyMutationOptions(cls) + meta.create = create + meta.delete = delete + meta.model = model + + model_inspect = sqlalchemyinspect(model) + cls._model_inspect = model_inspect + + if not registry: + registry = get_global_registry() + + if meta.delete: + only_fields = [] + exclude_fields = () + for primary_key_column in model_inspect.primary_key: + only_fields.append(primary_key_column.name) + else: + if not isinstance(exclude_fields, list): + if exclude_fields: + exclude_fields = list(exclude_fields) + else: + exclude_fields = [] + + if meta.create: + for primary_key_column in model_inspect.primary_key: + exclude_fields.append(primary_key_column.name) + + for relationship in model_inspect.relationships: + exclude_fields.append(relationship.key) + + arguments = yank_fields_from_attrs( + construct_fields(model, registry, only_fields, exclude_fields), + _as=Argument, + ) + + super(SQLAlchemyMutation, cls).__init_subclass_with_meta__(_meta=meta, arguments=arguments, **options) + + @classmethod + def mutate(cls, self, info, **kwargs): + session = get_session(info.context) + + meta = cls._meta + model = None + + if meta.create: + model = meta.model() + session.add(model) + else: + query = session.query(meta.model) + for primary_key_column in cls._model_inspect.primary_key: + query = query.filter(getattr(meta.model, primary_key_column.key) == kwargs[primary_key_column.name]) + model = query.one() + if meta.delete: + session.delete(model) + + if not meta.delete: + for key, value in kwargs.items(): + setattr(model, key, value) + + session.commit() + + return model + + @classmethod + def Field(cls, *args, **kwargs): + return Field( + cls._meta.output, args=cls._meta.arguments, resolver=cls._meta.resolver + ) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index e78c9802..4a4fd957 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,6 +1,7 @@ from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene.utils.str_converters import to_camel_case, to_snake_case def get_session(context): @@ -18,6 +19,46 @@ def get_query(model, context): return query +_operator_aliases = { + '==': lambda x, y: x == y, + '=': lambda x, y: x == y, + '>=': lambda x, y: x >= y, + '>': lambda x, y: x > y, + '<=': lambda x, y: x <= y, + '<': lambda x, y: x < y, + '!=': lambda x, y: x != y +} + + +def get_operator_function(operator=None): + if not operator: + def operator(x, y): + return x == y + + if not callable(operator): + operator_attr = operator + + if operator in _operator_aliases: + operator = _operator_aliases[operator_attr] + else: + def operator(x, y): + return getattr(x, operator_attr)(y) + + return operator + + +def get_snake_or_camel_attr(model, attr): + try: + return getattr(model, to_snake_case(attr)) + except Exception: + pass + try: + return getattr(model, to_camel_case(attr)) + except Exception: + pass + return getattr(model, attr) + + def is_mapped_class(cls): try: class_mapper(cls)