Skip to content

Commit

Permalink
Improve auto creation of Graphene Enums.
Browse files Browse the repository at this point in the history
The created Graphene Enums are now registered and reused,
because their names must be unique in a GraphQL schema.
Also the naming conventions for Enum type names (CamelCase)
and options (UPPER_CASE) are applied when creating them.
  • Loading branch information
Cito committed Apr 11, 2019
1 parent 8819829 commit c2c4a77
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 76 deletions.
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .types import SQLAlchemyObjectType
from .fields import SQLAlchemyConnectionField
from .types import SQLAlchemyObjectType
from .utils import get_query, get_session

__version__ = "2.1.1"
Expand Down
18 changes: 7 additions & 11 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,13 @@ def convert_column_to_float(type, column, registry=None):

@convert_sqlalchemy_type.register(types.Enum)
def convert_enum_to_enum(type, column, registry=None):
enum_class = getattr(type, 'enum_class', None)
if enum_class: # Check if an enum.Enum type is used
graphene_type = Enum.from_enum(enum_class)
else: # Nope, just a list of string options
items = zip(type.enums, type.enums)
graphene_type = Enum(type.name, items)
return Field(
graphene_type,
description=get_column_doc(column),
required=not (is_column_nullable(column)),
)
if registry is None:
from .registry import get_global_registry
registry = get_global_registry()
graphene_type = registry.get_type_for_enum(type)
return Field(graphene_type,
description=get_column_doc(column),
required=not(is_column_nullable(column)))


@convert_sqlalchemy_type.register(ChoiceType)
Expand Down
43 changes: 43 additions & 0 deletions graphene_sqlalchemy/registry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@

from collections import OrderedDict

from sqlalchemy.types import Enum as SQLAlchemyEnumType

from graphene import Enum

from .utils import to_enum_value_name, to_type_name


class Registry(object):
def __init__(self):
self._registry = {}
self._registry_models = {}
self._registry_composites = {}
self._registry_enums = {}

def register(self, cls):
from .types import SQLAlchemyObjectType
Expand All @@ -27,6 +38,38 @@ def register_composite_converter(self, composite, converter):
def get_converter_for_composite(self, composite):
return self._registry_composites.get(composite)

def get_type_for_enum(self, sql_type):
if not isinstance(sql_type, SQLAlchemyEnumType):
raise TypeError(
'Only sqlalchemy.Enum objects can be registered as enum, '
'received "{}"'.format(sql_type))
enum_class = sql_type.enum_class
if enum_class:
name = enum_class.__name__
members = OrderedDict(
(to_enum_value_name(key), value.value)
for key, value in enum_class.__members__.items())
else:
name = sql_type.name
name = to_type_name(name) if name else 'Enum{}'.format(
len(self._registry_enums) + 1)
members = OrderedDict(
(to_enum_value_name(key), key) for key in sql_type.enums)
graphene_type = self._registry_enums.get(name)
if graphene_type:
existing_members = {
key: value.value for key, value
in graphene_type._meta.enum.__members__.items()}
if members != existing_members:
raise TypeError(
'Different enums with the same name "{}":'
' tried to register {}, but {} existed already.'.format(
name, members, existing_members))
else:
graphene_type = Enum(name, members)
self._registry_enums[name] = graphene_type
return graphene_type


registry = None

Expand Down
9 changes: 6 additions & 3 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import mapper, relationship

PetKind = Enum("cat", "dog", name="pet_kind")

class Hairkind(enum.Enum):

class HairKind(enum.Enum):
LONG = 'long'
SHORT = 'short'

Expand All @@ -32,8 +34,8 @@ class Pet(Base):
__tablename__ = "pets"
id = Column(Integer(), primary_key=True)
name = Column(String(30))
pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False)
hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False)
pet_kind = Column(PetKind, nullable=False)
hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False)
reporter_id = Column(Integer(), ForeignKey("reporters.id"))


Expand All @@ -43,6 +45,7 @@ class Reporter(Base):
first_name = Column(String(30))
last_name = Column(String(30))
email = Column(String())
favorite_pet_kind = Column(PetKind)
pets = relationship("Pet", secondary=association_table, backref="reporters")
articles = relationship("Article", backref="reporter")
favorite_article = relationship("Article", uselist=False)
Expand Down
41 changes: 30 additions & 11 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,37 @@ def test_should_unicodetext_convert_string():


def test_should_enum_convert_enum():
field = assert_column_conversion(
types.Enum(enum.Enum("one", "two")), graphene.Field
)
field = assert_column_conversion(types.Enum("one", "two"), graphene.Field)
field_type = field.type()
assert field_type.__class__.__name__.startswith("Enum")
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")
assert hasattr(field_type, "ONE")
assert not hasattr(field_type, "one")
assert hasattr(field_type, "TWO")

field = assert_column_conversion(
types.Enum("one", "two", name="two_numbers"), graphene.Field
)
field_type = field.type()
assert field_type.__class__.__name__ == "two_numbers"
assert field_type.__class__.__name__ == "TwoNumbers"
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "ONE")
assert not hasattr(field_type, "one")
assert hasattr(field_type, "TWO")


def test_conflicting_enum_should_raise_error():
some_type = types.Enum(enum.Enum("ConflictingEnum", "cat cow"))
field = assert_column_conversion(some_type, graphene.Field)
field_type = field.type()
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")
assert hasattr(field_type, "COW")
same_type = types.Enum(enum.Enum("ConflictingEnum", "cat cow"))
field = assert_column_conversion(same_type, graphene.Field)
assert field_type == field.type()
conflicting_type = types.Enum(enum.Enum("ConflictingEnum", "cat horse"))
with raises(TypeError):
assert_column_conversion(conflicting_type, graphene.Field)


def test_should_small_integer_convert_int():
Expand Down Expand Up @@ -272,19 +290,20 @@ def test_should_postgresql_enum_convert():
postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field
)
field_type = field.type()
assert field_type.__class__.__name__ == "two_numbers"
assert field_type.__class__.__name__ == "TwoNumbers"
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")
assert hasattr(field_type, "TWO")


def test_should_postgresql_py_enum_convert():
field = assert_column_conversion(
postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field
postgresql.ENUM(enum.Enum("TwoNumbersEnum", "one two"), name="two_numbers"),
graphene.Field,
)
field_type = field.type()
assert field_type.__class__.__name__ == "TwoNumbers"
assert field_type.__class__.__name__ == "TwoNumbersEnum"
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")
assert hasattr(field_type, "TWO")


def test_should_postgresql_array_convert():
Expand Down
Loading

0 comments on commit c2c4a77

Please sign in to comment.