Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve auto creation of Graphene Enums. #98

Closed
wants to merge 11 commits into from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ __pycache__/

# Distribution / packaging
.Python
.venv/
env/
build/
develop-eggs/
Expand Down
5 changes: 2 additions & 3 deletions examples/flask_sqlalchemy/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#!/usr/bin/env python

from database import db_session, init_db
from flask import Flask
from schema import schema

from flask_graphql import GraphQLView

from .database import db_session, init_db
from .schema import schema

app = Flask(__name__)
app.debug = True

Expand Down
2 changes: 1 addition & 1 deletion examples/flask_sqlalchemy/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def init_db():
# import all modules here that might define models so that
# they will be registered properly on the metadata. Otherwise
# you will have to import them first before calling init_db()
from .models import Department, Employee, Role
from models import Department, Employee, Role
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)

Expand Down
3 changes: 1 addition & 2 deletions examples/flask_sqlalchemy/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from database import Base
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func
from sqlalchemy.orm import backref, relationship

from .database import Base


class Department(Base):
__tablename__ = 'department'
Expand Down
11 changes: 5 additions & 6 deletions examples/flask_sqlalchemy/schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from models import Department as DepartmentModel
from models import Employee as EmployeeModel
from models import Role as RoleModel

import graphene
from graphene import relay
from graphene_sqlalchemy import (SQLAlchemyConnectionField,
SQLAlchemyObjectType, utils)

from .models import Department as DepartmentModel
from .models import Employee as EmployeeModel
from .models import Role as RoleModel


class Department(SQLAlchemyObjectType):
class Meta:
Expand All @@ -26,8 +26,7 @@ class Meta:
interfaces = (relay.Node, )


SortEnumEmployee = utils.sort_enum_for_model(EmployeeModel, 'SortEnumEmployee',
lambda c, d: c.upper() + ('_ASC' if d else '_DESC'))
SortEnumEmployee = utils.get_sort_enum_for_model(EmployeeModel)


class Query(graphene.ObjectType):
Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .types import SQLAlchemyObjectType
from .fields import SQLAlchemyConnectionField
from .types import SQLAlchemyObjectType
from .utils import get_query, get_session

__version__ = "2.1.1"
Expand Down
20 changes: 8 additions & 12 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,21 +147,17 @@ def convert_column_to_float(type, column, registry=None):

@convert_sqlalchemy_type.register(types.Enum)
def convert_enum_to_enum(type, column, registry=None):
enum_class = getattr(type, 'enum_class', None)
if enum_class: # Check if an enum.Enum type is used
graphene_type = Enum.from_enum(enum_class)
else: # Nope, just a list of string options
items = zip(type.enums, type.enums)
graphene_type = Enum(type.name, items)
return Field(
graphene_type,
description=get_column_doc(column),
required=not (is_column_nullable(column)),
)
if registry is None:
from .registry import get_global_registry
jnak marked this conversation as resolved.
Show resolved Hide resolved
registry = get_global_registry()
graphene_type = registry.get_type_for_enum(type)
return Field(graphene_type,
description=get_column_doc(column),
required=not(is_column_nullable(column)))


@convert_sqlalchemy_type.register(ChoiceType)
def convert_column_to_enum(type, column, registry=None):
def convert_choice_to_enum(type, column, registry=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this rename? You are still converting a column here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The functions must have specific names, because columns with ChoiceType and Enum both convert to Enum. The function name would then be convert_column_to_enum for both. Actually the other functions should also have more specific names, i.e. contain not only the target type, but also the source type.

name = "{}_{}".format(column.table.name, column.name).upper()
return Enum(name, type.choices, description=get_column_doc(column))

Expand Down
4 changes: 2 additions & 2 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice

from .utils import get_query, sort_argument_for_model
from .utils import get_query, get_sort_argument_for_model


class UnsortedSQLAlchemyConnectionField(ConnectionField):
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(self, type, *args, **kwargs):
# Let super class raise if type is not a Connection
try:
model = type.Edge.node._type._meta.model
kwargs.setdefault("sort", sort_argument_for_model(model))
kwargs.setdefault("sort", get_sort_argument_for_model(model))
except Exception:
raise Exception(
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
Expand Down
111 changes: 100 additions & 11 deletions graphene_sqlalchemy/registry.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,83 @@
from collections import OrderedDict

from sqlalchemy.types import Enum as SQLAlchemyEnumType

from graphene import Enum

from .utils import to_enum_value_name, to_type_name


class Registry(object):
def __init__(self):
def __init__(self, check_duplicate_registration=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the legitimate cases where one would want to disable duplicate registration?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The get_type_for_model() method (used by convert_sqlalchemy_relationship()) assumes that there is only one graphene type per model. You may want to make sure that this assumption is met and you haven't accidentally associated two different types with one model.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case makes sense to me. I was asking about the legitimate cases where you one does not want this function to raise when there are duplicates.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, ok - good question. Not checking was the behavior until now, and test_types would break because CustomCharacter uses the same model as Character. It only matters if you have relationships. Suggestion: Maybe we should always allow duplicate registration, but memorize all of them instead of only the last. And only when get_type_for_model is called, we throw an error which lists the duplicates. I.e. you only get an error, when this is really a problem.

self.check_duplicate_registration = check_duplicate_registration
self._registry = {}
self._registry_models = {}
self._registry_composites = {}
self._registry_enums = {}
self._registry_sort_params = {}

def register(self, cls):
from .types import SQLAlchemyObjectType

assert issubclass(cls, SQLAlchemyObjectType), (
"Only classes of type SQLAlchemyObjectType can be registered, "
'received "{}"'
).format(cls.__name__)
assert cls._meta.registry == self, "Registry for a Model have to match."
# assert self.get_type_for_model(cls._meta.model) in [None, cls], (
# 'SQLAlchemy model "{}" already associated with '
# 'another type "{}".'
# ).format(cls._meta.model, self._registry[cls._meta.model])
self._registry[cls._meta.model] = cls
if not issubclass(cls, SQLAlchemyObjectType):
raise TypeError(
"Only classes of type SQLAlchemyObjectType can be registered, "
'received "{}"'.format(cls.__name__)
)
if cls._meta.registry != self:
raise TypeError("Registry for a Model have to match.")

registered_cls = (
self._registry.get(cls._meta.model)
if self.check_duplicate_registration
else None
)
if registered_cls:
if cls != registered_cls:
raise TypeError(
"Different object types registered for the same model {}:"
" tried to register {}, but {} existed already.".format(
cls._meta.model, cls, registered_cls
)
)
else:
self._registry[cls._meta.model] = cls

def register_enum(self, name, members):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does register_enum takes a name instead of SQLAlchemyEnumType? It seems inconsistent with register and register_sort_params

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because not all enums are derived from SQLAlchemyEnumTypes. Some are created for sorting purposes, they don't exist as SQLAlchemyEnumTypes. All of the enums should use the same registry because Enum names in the schema must be unique - you want to get an error when a sort enum accidentally has the same name as an existing enum for a column, and you want to have a way for requesting registered enums by name.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. How about we instead have this function take an enum or a column that maps to SQLAlchemyEnumType to keep all that logic encapsulated here? In both cases, we would be able to extract a name.

Regarding name collisions, the names need to be unique across all types (vs just within Enums). graphene already raises an error to prevent this. Why do we need to handle this in graphene-sqlalchemy as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. I will think about the whole mechanism again and try to create a better-thought-out solution. But maybe I'll do this as a new PR.

graphene_enum = self._registry_enums.get(name)
if graphene_enum:
registered_members = {
key: value.value
for key, value in graphene_enum._meta.enum.__members__.items()
}
if members != registered_members:
raise TypeError(
'Different enums with the same name "{}":'
" tried to register {}, but {} existed already.".format(
name, members, registered_members
)
)
else:
graphene_enum = Enum(name, members)
self._registry_enums[name] = graphene_enum
return graphene_enum

def register_sort_params(self, cls, sort_params):
registered_sort_params = (
self._registry_sort_params.get(cls)
if self.check_duplicate_registration
else None
)
if registered_sort_params:
if registered_sort_params != sort_params:
raise TypeError(
"Different sort args for the same model {}:"
" tried to register {}, but {} existed already.".format(
cls, sort_params, registered_sort_params
)
)
else:
self._registry_sort_params[cls] = sort_params

def get_type_for_model(self, model):
return self._registry.get(model)
Expand All @@ -27,6 +88,34 @@ def register_composite_converter(self, composite, converter):
def get_converter_for_composite(self, composite):
return self._registry_composites.get(composite)

def get_type_for_enum(self, sql_type):
if not isinstance(sql_type, SQLAlchemyEnumType):
raise TypeError(
"Only sqlalchemy.Enum objects can be registered as enum, "
'received "{}"'.format(sql_type)
)
enum_class = sql_type.enum_class
if enum_class:
name = enum_class.__name__
members = OrderedDict(
(to_enum_value_name(key), value.value)
for key, value in enum_class.__members__.items()
)
else:
name = sql_type.name
name = (
to_type_name(name)
if name
else "Enum{}".format(len(self._registry_enums) + 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm in theory all columns should have a name. I wonder if this can only happen in our tests because SQLAlchemy is not fully initialized.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the name of the SQLAlchemy Enum type, not the name of the column. That name may not exist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. What are the benefits of having this function take a SQLAlchemyEnumType instead of a column that maps to a SQLAlchemyEnumType? I can't think of a case where we would have a SQLAlchemyEnumType without its associated column.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You would need to get the SQLAlchemy enum type anyway with column.type. The convert_sqlalchemy_column function already does this. We need the name of the enum type, not the name of the column, because the same enum could be used in different columns.

)
members = OrderedDict(
(to_enum_value_name(key), key) for key in sql_type.enums
)
return self.register_enum(name, members)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does get_type_for_enum need to register the enum? It seems inconsistent with get_type_for_model and get_sort_params_for_model?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naming of the methods and division between registry and utils modules could certainly be improved. Need to sleep over it, or we refactor this in a new PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to address this in this PR because this may lead to unexpected behaviors. For example, as is we are getting different enums for the same un-named sql enum:

reg = Registry()
sa_enum_1 = SQLAlchemyEnum('red', 'blue')
assert reg.get_type_for_enum(sa_enum_1) is reg.get_type_for_enum(sa_enum_1)  # fails

Ideally a getter should not have any side effects. Why can't we just return the enum from the registry?

self. _registry_enums[enum]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See below, I'll try to come up with a better solution.


def get_sort_params_for_model(self, model):
return self._registry_sort_params.get(model)


registry = None

Expand Down
9 changes: 6 additions & 3 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import mapper, relationship

PetKind = Enum("cat", "dog", name="pet_kind")

class Hairkind(enum.Enum):

class HairKind(enum.Enum):
LONG = 'long'
SHORT = 'short'

Expand All @@ -32,8 +34,8 @@ class Pet(Base):
__tablename__ = "pets"
id = Column(Integer(), primary_key=True)
name = Column(String(30))
pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False)
hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False)
pet_kind = Column(PetKind, nullable=False)
hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False)
reporter_id = Column(Integer(), ForeignKey("reporters.id"))


Expand All @@ -43,6 +45,7 @@ class Reporter(Base):
first_name = Column(String(30))
last_name = Column(String(30))
email = Column(String())
favorite_pet_kind = Column(PetKind)
pets = relationship("Pet", secondary=association_table, backref="reporters")
articles = relationship("Article", backref="reporter")
favorite_article = relationship("Article", uselist=False)
Expand Down
43 changes: 31 additions & 12 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,37 @@ def test_should_unicodetext_convert_string():


def test_should_enum_convert_enum():
field = assert_column_conversion(
types.Enum(enum.Enum("one", "two")), graphene.Field
)
field = assert_column_conversion(types.Enum("one", "two"), graphene.Field)
field_type = field.type()
assert field_type.__class__.__name__.startswith("Enum")
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")
assert hasattr(field_type, "ONE")
assert not hasattr(field_type, "one")
assert hasattr(field_type, "TWO")

field = assert_column_conversion(
types.Enum("one", "two", name="two_numbers"), graphene.Field
)
field_type = field.type()
assert field_type.__class__.__name__ == "two_numbers"
assert field_type.__class__.__name__ == "TwoNumbers"
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "ONE")
assert not hasattr(field_type, "one")
assert hasattr(field_type, "TWO")


def test_conflicting_enum_should_raise_error():
some_type = types.Enum(enum.Enum("ConflictingEnum", "cat cow"))
field = assert_column_conversion(some_type, graphene.Field)
field_type = field.type()
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")
assert hasattr(field_type, "COW")
same_type = types.Enum(enum.Enum("ConflictingEnum", "cat cow"))
field = assert_column_conversion(same_type, graphene.Field)
assert field_type == field.type()
conflicting_type = types.Enum(enum.Enum("ConflictingEnum", "cat horse"))
with raises(TypeError):
assert_column_conversion(conflicting_type, graphene.Field)


def test_should_small_integer_convert_int():
Expand Down Expand Up @@ -272,19 +290,20 @@ def test_should_postgresql_enum_convert():
postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field
)
field_type = field.type()
assert field_type.__class__.__name__ == "two_numbers"
assert field_type.__class__.__name__ == "TwoNumbers"
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")
assert hasattr(field_type, "TWO")


def test_should_postgresql_py_enum_convert():
field = assert_column_conversion(
postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field
postgresql.ENUM(enum.Enum("TwoNumbersEnum", "one two"), name="two_numbers"),
graphene.Field,
)
field_type = field.type()
assert field_type.__class__.__name__ == "TwoNumbers"
assert field_type.__class__.__name__ == "TwoNumbersEnum"
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")
assert hasattr(field_type, "TWO")


def test_should_postgresql_array_convert():
Expand All @@ -304,7 +323,7 @@ def test_should_postgresql_hstore_convert():


def test_should_composite_convert():
class CompositeClass(object):
class CompositeClass:
def __init__(self, col1, col2):
self.col1 = col1
self.col2 = col2
Expand Down
8 changes: 4 additions & 4 deletions graphene_sqlalchemy/tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ..fields import SQLAlchemyConnectionField
from ..types import SQLAlchemyObjectType
from ..utils import sort_argument_for_model
from ..utils import get_sort_argument_for_model
from .models import Editor
from .models import Pet as PetModel

Expand All @@ -22,7 +22,7 @@ class Meta:
def test_sort_added_by_default():
arg = SQLAlchemyConnectionField(PetConn)
assert "sort" in arg.args
assert arg.args["sort"] == sort_argument_for_model(PetModel)
assert arg.args["sort"] == get_sort_argument_for_model(PetModel)


def test_sort_can_be_removed():
Expand All @@ -31,8 +31,8 @@ def test_sort_can_be_removed():


def test_custom_sort():
arg = SQLAlchemyConnectionField(PetConn, sort=sort_argument_for_model(Editor))
assert arg.args["sort"] == sort_argument_for_model(Editor)
arg = SQLAlchemyConnectionField(PetConn, sort=get_sort_argument_for_model(Editor))
assert arg.args["sort"] == get_sort_argument_for_model(Editor)


def test_init_raises():
Expand Down
Loading