diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index d3ae812..7c5330b 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -2,13 +2,12 @@ import sys import typing import uuid -import warnings from decimal import Decimal -from functools import singledispatch -from typing import Any, cast +from typing import Any, Optional, Union, cast from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import interfaces, strategies import graphene @@ -17,16 +16,31 @@ from .batching import get_batch_resolver from .enums import enum_for_sa_enum from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory -from .registry import get_global_registry +from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, DummyImport, + column_type_eq, registry_sqlalchemy_model_from_str, safe_isinstance, + safe_issubclass, singledispatchbymatchfunction, - value_equals, ) +# Import path changed in 1.4 +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.orm import DeclarativeMeta +else: + from sqlalchemy.ext.declarative import DeclarativeMeta + +# We just use MapperProperties for type hints, they don't exist in sqlalchemy < 1.4 +try: + from sqlalchemy import MapperProperty +except ImportError: + # sqlalchemy < 1.4 + MapperProperty = Any + try: from typing import ForwardRef except ImportError: @@ -207,10 +221,15 @@ def inner(fn): def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): column = column_prop.columns[0] - + column_type = getattr(column, "type", None) + # The converter expects a type to find the right conversion function. + # If we get an instance instead, we need to convert it to a type. + # The conversion function will still be able to access the instance via the column argument. + if not isinstance(column_type, type): + column_type = type(column_type) field_kwargs.setdefault( "type_", - convert_sqlalchemy_type(getattr(column, "type", None), column, registry), + convert_sqlalchemy_type(column_type, column=column, registry=registry), ) field_kwargs.setdefault("required", not is_column_nullable(column)) field_kwargs.setdefault("description", get_column_doc(column)) @@ -218,86 +237,178 @@ def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): return graphene.Field(resolver=resolver, **field_kwargs) -@singledispatch -def convert_sqlalchemy_type(type, column, registry=None): - raise Exception( - "Don't know how to convert the SQLAlchemy field %s (%s)" - % (column, column.__class__) +@singledispatchbymatchfunction +def convert_sqlalchemy_type( # noqa + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + # No valid type found, raise an error + + raise TypeError( + "Don't know how to convert the SQLAlchemy field %s (%s, %s). " + "Please add a type converter or set the type manually using ORMField(type_=your_type)" + % (column, column.__class__ or "no column provided", type_arg) ) -@convert_sqlalchemy_type.register(sqa_types.String) -@convert_sqlalchemy_type.register(sqa_types.Text) -@convert_sqlalchemy_type.register(sqa_types.Unicode) -@convert_sqlalchemy_type.register(sqa_types.UnicodeText) -@convert_sqlalchemy_type.register(postgresql.INET) -@convert_sqlalchemy_type.register(postgresql.CIDR) -@convert_sqlalchemy_type.register(sqa_utils.TSVectorType) -@convert_sqlalchemy_type.register(sqa_utils.EmailType) -@convert_sqlalchemy_type.register(sqa_utils.URLType) -@convert_sqlalchemy_type.register(sqa_utils.IPAddressType) -def convert_column_to_string(type, column, registry=None): +@convert_sqlalchemy_type.register(safe_isinstance(DeclarativeMeta)) +def convert_sqlalchemy_model_using_registry( + type_arg: Any, registry: Registry = None, **kwargs +): + registry_ = registry or get_global_registry() + + def get_type_from_registry(): + existing_graphql_type = registry_.get_type_for_model(type_arg) + if existing_graphql_type: + return existing_graphql_type + + raise TypeError( + "No model found in Registry for type %s. " + "Only references to SQLAlchemy Models mapped to " + "SQLAlchemyObjectTypes are allowed." % type_arg + ) + + return get_type_from_registry() + + +@convert_sqlalchemy_type.register(safe_issubclass(graphene.ObjectType)) +def convert_object_type(type_arg: Any, **kwargs): + return type_arg + + +@convert_sqlalchemy_type.register(safe_issubclass(graphene.Scalar)) +def convert_scalar_type(type_arg: Any, **kwargs): + return type_arg + + +@convert_sqlalchemy_type.register(column_type_eq(str)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.String)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Text)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Unicode)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.UnicodeText)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.INET)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.CIDR)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.TSVectorType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.EmailType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.URLType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.IPAddressType)) +def convert_column_to_string(type_arg: Any, **kwargs): return graphene.String -@convert_sqlalchemy_type.register(postgresql.UUID) -@convert_sqlalchemy_type.register(sqa_utils.UUIDType) -def convert_column_to_uuid(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(postgresql.UUID)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.UUIDType)) +@convert_sqlalchemy_type.register(column_type_eq(uuid.UUID)) +def convert_column_to_uuid( + type_arg: Any, + **kwargs, +): return graphene.UUID -@convert_sqlalchemy_type.register(sqa_types.DateTime) -def convert_column_to_datetime(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.DateTime)) +@convert_sqlalchemy_type.register(column_type_eq(datetime.datetime)) +def convert_column_to_datetime( + type_arg: Any, + **kwargs, +): return graphene.DateTime -@convert_sqlalchemy_type.register(sqa_types.Time) -def convert_column_to_time(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Time)) +@convert_sqlalchemy_type.register(column_type_eq(datetime.time)) +def convert_column_to_time( + type_arg: Any, + **kwargs, +): return graphene.Time -@convert_sqlalchemy_type.register(sqa_types.Date) -def convert_column_to_date(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Date)) +@convert_sqlalchemy_type.register(column_type_eq(datetime.date)) +def convert_column_to_date( + type_arg: Any, + **kwargs, +): return graphene.Date -@convert_sqlalchemy_type.register(sqa_types.SmallInteger) -@convert_sqlalchemy_type.register(sqa_types.Integer) -def convert_column_to_int_or_id(type, column, registry=None): - return graphene.ID if column.primary_key else graphene.Int +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.SmallInteger)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Integer)) +@convert_sqlalchemy_type.register(column_type_eq(int)) +def convert_column_to_int_or_id( + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + # fixme drop the primary key processing from here in another pr + if column is not None: + if getattr(column, "primary_key", False) is True: + return graphene.ID + return graphene.Int -@convert_sqlalchemy_type.register(sqa_types.Boolean) -def convert_column_to_boolean(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Boolean)) +@convert_sqlalchemy_type.register(column_type_eq(bool)) +def convert_column_to_boolean( + type_arg: Any, + **kwargs, +): return graphene.Boolean -@convert_sqlalchemy_type.register(sqa_types.Float) -@convert_sqlalchemy_type.register(sqa_types.Numeric) -@convert_sqlalchemy_type.register(sqa_types.BigInteger) -def convert_column_to_float(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(float)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Float)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Numeric)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.BigInteger)) +def convert_column_to_float( + type_arg: Any, + **kwargs, +): return graphene.Float -@convert_sqlalchemy_type.register(sqa_types.Enum) -def convert_enum_to_enum(type, column, registry=None): - return lambda: enum_for_sa_enum(type, registry or get_global_registry()) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.ENUM)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Enum)) +def convert_enum_to_enum( + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("SQL-Enum conversion requires a column") + + return lambda: enum_for_sa_enum(column.type, registry or get_global_registry()) # TODO Make ChoiceType conversion consistent with other enums -@convert_sqlalchemy_type.register(sqa_utils.ChoiceType) -def convert_choice_to_enum(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ChoiceType)) +def convert_choice_to_enum( + type_arg: sqa_utils.ChoiceType, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("ChoiceType conversion requires a column") + name = "{}_{}".format(column.table.name, column.key).upper() - if isinstance(type.type_impl, EnumTypeImpl): + if isinstance(column.type.type_impl, EnumTypeImpl): # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta # do not use from_enum here because we can have more than one enum column in table - return graphene.Enum(name, list((v.name, v.value) for v in type.choices)) + return graphene.Enum(name, list((v.name, v.value) for v in column.type.choices)) else: - return graphene.Enum(name, type.choices) + return graphene.Enum(name, column.type.choices) -@convert_sqlalchemy_type.register(sqa_utils.ScalarListType) -def convert_scalar_list_to_list(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ScalarListType)) +def convert_scalar_list_to_list( + type_arg: Any, + **kwargs, +): return graphene.List(graphene.String) @@ -309,108 +420,79 @@ def init_array_list_recursive(inner_type, n): ) -@convert_sqlalchemy_type.register(sqa_types.ARRAY) -@convert_sqlalchemy_type.register(postgresql.ARRAY) -def convert_array_to_list(_type, column, registry=None): - inner_type = convert_sqlalchemy_type(column.type.item_type, column) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.ARRAY)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.ARRAY)) +def convert_array_to_list( + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("SQL-Array conversion requires a column") + item_type = column.type.item_type + if not isinstance(item_type, type): + item_type = type(item_type) + inner_type = convert_sqlalchemy_type( + item_type, column=column, registry=registry, **kwargs + ) return graphene.List( init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1) ) -@convert_sqlalchemy_type.register(postgresql.HSTORE) -@convert_sqlalchemy_type.register(postgresql.JSON) -@convert_sqlalchemy_type.register(postgresql.JSONB) -def convert_json_to_string(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(postgresql.HSTORE)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.JSON)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.JSONB)) +def convert_json_to_string( + type_arg: Any, + **kwargs, +): return JSONString -@convert_sqlalchemy_type.register(sqa_utils.JSONType) -@convert_sqlalchemy_type.register(sqa_types.JSON) -def convert_json_type_to_string(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.JSONType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.JSON)) +def convert_json_type_to_string( + type_arg: Any, + **kwargs, +): return JSONString -@convert_sqlalchemy_type.register(sqa_types.Variant) -def convert_variant_to_impl_type(type, column, registry=None): - return convert_sqlalchemy_type(type.impl, column, registry=registry) - - -@singledispatchbymatchfunction -def convert_sqlalchemy_hybrid_property_type(arg: Any): - existing_graphql_type = get_global_registry().get_type_for_model(arg) - if existing_graphql_type: - return existing_graphql_type - - if isinstance(arg, type(graphene.ObjectType)): - return arg - - if isinstance(arg, type(graphene.Scalar)): - return arg - - # No valid type found, warn and fall back to graphene.String - warnings.warn( - f'I don\'t know how to generate a GraphQL type out of a "{arg}" type.' - 'Falling back to "graphene.String"' +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Variant)) +def convert_variant_to_impl_type( + type_arg: sqa_types.Variant, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("Vaiant conversion requires a column") + + type_impl = column.type.impl + if not isinstance(type_impl, type): + type_impl = type(type_impl) + return convert_sqlalchemy_type( + type_impl, column=column, registry=registry, **kwargs ) - return graphene.String - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(str)) -def convert_sqlalchemy_hybrid_property_type_str(arg): - return graphene.String - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(int)) -def convert_sqlalchemy_hybrid_property_type_int(arg): - return graphene.Int - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(float)) -def convert_sqlalchemy_hybrid_property_type_float(arg): - return graphene.Float -@convert_sqlalchemy_hybrid_property_type.register(value_equals(Decimal)) -def convert_sqlalchemy_hybrid_property_type_decimal(arg): +@convert_sqlalchemy_type.register(column_type_eq(Decimal)) +def convert_sqlalchemy_hybrid_property_type_decimal(type_arg: Any, **kwargs): # The reason Decimal should be serialized as a String is because this is a # base10 type used in things like money, and string allows it to not # lose precision (which would happen if we downcasted to a Float, for example) return graphene.String -@convert_sqlalchemy_hybrid_property_type.register(value_equals(bool)) -def convert_sqlalchemy_hybrid_property_type_bool(arg): - return graphene.Boolean - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.datetime)) -def convert_sqlalchemy_hybrid_property_type_datetime(arg): - return graphene.DateTime - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.date)) -def convert_sqlalchemy_hybrid_property_type_date(arg): - return graphene.Date - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.time)) -def convert_sqlalchemy_hybrid_property_type_time(arg): - return graphene.Time - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(uuid.UUID)) -def convert_sqlalchemy_hybrid_property_type_uuid(arg): - return graphene.UUID - - -def is_union(arg) -> bool: +def is_union(type_arg: Any, **kwargs) -> bool: if sys.version_info >= (3, 10): from types import UnionType - if isinstance(arg, UnionType): + if isinstance(type_arg, UnionType): return True - return getattr(arg, "__origin__", None) == typing.Union + return getattr(type_arg, "__origin__", None) == typing.Union def graphene_union_for_py_union( @@ -421,14 +503,14 @@ def graphene_union_for_py_union( if union_type is None: # Union Name is name of the three union_name = "".join(sorted(obj_type._meta.name for obj_type in obj_types)) - union_type = graphene.Union(union_name, obj_types) + union_type = graphene.Union.create_type(union_name, types=obj_types) registry.register_union_type(union_type, obj_types) return union_type -@convert_sqlalchemy_hybrid_property_type.register(is_union) -def convert_sqlalchemy_hybrid_property_union(arg): +@convert_sqlalchemy_type.register(is_union) +def convert_sqlalchemy_hybrid_property_union(type_arg: Any, **kwargs): """ Converts Unions (Union[X,Y], or X | Y for python > 3.10) to the corresponding graphene schema object. Since Optionals are internally represented as Union[T, ], they are handled here as well. @@ -444,11 +526,11 @@ def convert_sqlalchemy_hybrid_property_union(arg): # Option is actually Union[T, ] # Just get the T out of the list of arguments by filtering out the NoneType - nested_types = list(filter(lambda x: not type(None) == x, arg.__args__)) + nested_types = list(filter(lambda x: not type(None) == x, type_arg.__args__)) # Map the graphene types to the nested types. # We use convert_sqlalchemy_hybrid_property_type instead of the registry to account for ForwardRefs, Lists,... - graphene_types = list(map(convert_sqlalchemy_hybrid_property_type, nested_types)) + graphene_types = list(map(convert_sqlalchemy_type, nested_types)) # If only one type is left after filtering out NoneType, the Union was an Optional if len(graphene_types) == 1: @@ -471,20 +553,20 @@ def convert_sqlalchemy_hybrid_property_union(arg): ) -@convert_sqlalchemy_hybrid_property_type.register( +@convert_sqlalchemy_type.register( lambda x: getattr(x, "__origin__", None) in [list, typing.List] ) -def convert_sqlalchemy_hybrid_property_type_list_t(arg): +def convert_sqlalchemy_hybrid_property_type_list_t(type_arg: Any, **kwargs): # type is either list[T] or List[T], generic argument at __args__[0] - internal_type = arg.__args__[0] + internal_type = type_arg.__args__[0] - graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + graphql_internal_type = convert_sqlalchemy_type(internal_type, **kwargs) return graphene.List(graphql_internal_type) -@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(ForwardRef)) -def convert_sqlalchemy_hybrid_property_forwardref(arg): +@convert_sqlalchemy_type.register(safe_isinstance(ForwardRef)) +def convert_sqlalchemy_hybrid_property_forwardref(type_arg: Any, **kwargs): """ Generate a lambda that will resolve the type at runtime This takes care of self-references @@ -492,26 +574,36 @@ def convert_sqlalchemy_hybrid_property_forwardref(arg): from .registry import get_global_registry def forward_reference_solver(): - model = registry_sqlalchemy_model_from_str(arg.__forward_arg__) + model = registry_sqlalchemy_model_from_str(type_arg.__forward_arg__) if not model: - return graphene.String + raise TypeError( + "No model found in Registry for forward reference for type %s. " + "Only forward references to other SQLAlchemy Models mapped to " + "SQLAlchemyObjectTypes are allowed." % type_arg + ) # Always fall back to string if no ForwardRef type found. return get_global_registry().get_type_for_model(model) return forward_reference_solver -@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(str)) -def convert_sqlalchemy_hybrid_property_bare_str(arg): +@convert_sqlalchemy_type.register(safe_isinstance(str)) +def convert_sqlalchemy_hybrid_property_bare_str(type_arg: str, **kwargs): """ Convert Bare String into a ForwardRef """ - return convert_sqlalchemy_hybrid_property_type(ForwardRef(arg)) + return convert_sqlalchemy_type(ForwardRef(type_arg), **kwargs) def convert_hybrid_property_return_type(hybrid_prop): # Grab the original method's return type annotations from inside the hybrid property - return_type_annotation = hybrid_prop.fget.__annotations__.get("return", str) + return_type_annotation = hybrid_prop.fget.__annotations__.get("return", None) + if not return_type_annotation: + raise TypeError( + "Cannot convert hybrid property type {} to a valid graphene type. " + "Please make sure to annotate the return type of the hybrid property or use the " + "type_ attribute of ORMField to set the type.".format(hybrid_prop) + ) - return convert_sqlalchemy_hybrid_property_type(return_type_annotation) + return convert_sqlalchemy_type(return_type_annotation, column=hybrid_prop) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index cc4b02b..3c46301 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -83,13 +83,13 @@ def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) def register_union_type( - self, union: graphene.Union, obj_types: List[Type[graphene.ObjectType]] + self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]] ): - if not isinstance(union, graphene.Union): + if not issubclass(union, graphene.Union): raise TypeError("Expected graphene.Union, but got: {!r}".format(union)) for obj_type in obj_types: - if not isinstance(obj_type, type(graphene.ObjectType)): + if not issubclass(obj_type, graphene.ObjectType): raise TypeError( "Expected Graphene ObjectType, but got: {!r}".format(obj_type) ) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index ee28658..9531aaa 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -4,7 +4,7 @@ import enum import uuid from decimal import Decimal -from typing import List, Optional, Tuple +from typing import List, Optional from sqlalchemy import ( Column, @@ -88,12 +88,12 @@ class Reporter(Base): favorite_article = relationship("Article", uselist=False, lazy="selectin") @hybrid_property - def hybrid_prop_with_doc(self): + def hybrid_prop_with_doc(self) -> str: """Docstring test""" return self.first_name @hybrid_property - def hybrid_prop(self): + def hybrid_prop(self) -> str: return self.first_name @hybrid_property @@ -253,11 +253,6 @@ def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: return [ShoppingCartItem(id=1), ShoppingCartItem(id=2)] - # Unsupported Type - @hybrid_property - def hybrid_prop_unsupported_type_tuple(self) -> Tuple[str, str]: - return "this will actually", "be a string" - # Self-references @hybrid_property diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index b9a1c15..e903396 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,6 +1,6 @@ import enum import sys -from typing import Dict, Union +from typing import Dict, Tuple, Union import pytest import sqlalchemy_utils as sqa_utils @@ -20,6 +20,7 @@ convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship, + convert_sqlalchemy_type, ) from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory from ..registry import Registry, get_global_registry @@ -78,6 +79,110 @@ def prop_method() -> int: assert get_hybrid_property_type(prop_method).type == graphene.Int +def test_hybrid_unknown_annotation(): + @hybrid_property + def hybrid_prop(self): + return "This should fail" + + with pytest.raises( + TypeError, + match=r"(.*)Please make sure to annotate the return type of the hybrid property or use the " + "type_ attribute of ORMField to set the type.(.*)", + ): + get_hybrid_property_type(hybrid_prop) + + +def test_hybrid_prop_no_type_annotation(): + @hybrid_property + def hybrid_prop(self) -> Tuple[str, str]: + return "This should Fail because", "we don't support Tuples in GQL" + + with pytest.raises( + TypeError, match=r"(.*)Don't know how to convert the SQLAlchemy field(.*)" + ): + get_hybrid_property_type(hybrid_prop) + + +def test_hybrid_invalid_forward_reference(): + class MyTypeNotInRegistry: + pass + + @hybrid_property + def hybrid_prop(self) -> "MyTypeNotInRegistry": + return MyTypeNotInRegistry() + + with pytest.raises( + TypeError, + match=r"(.*)Only forward references to other SQLAlchemy Models mapped to " + "SQLAlchemyObjectTypes are allowed.(.*)", + ): + get_hybrid_property_type(hybrid_prop).type + + +def test_hybrid_prop_object_type(): + class MyObjectType(graphene.ObjectType): + string = graphene.String() + + @hybrid_property + def hybrid_prop(self) -> MyObjectType: + return MyObjectType() + + assert get_hybrid_property_type(hybrid_prop).type == MyObjectType + + +def test_hybrid_prop_scalar_type(): + @hybrid_property + def hybrid_prop(self) -> graphene.String: + return "This should work" + + assert get_hybrid_property_type(hybrid_prop).type == graphene.String + + +def test_hybrid_prop_not_mapped_to_graphene_type(): + @hybrid_property + def hybrid_prop(self) -> ShoppingCartItem: + return "This shouldn't work" + + with pytest.raises(TypeError, match=r"(.*)No model found in Registry for type(.*)"): + get_hybrid_property_type(hybrid_prop).type + + +def test_hybrid_prop_mapped_to_graphene_type(): + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + + @hybrid_property + def hybrid_prop(self) -> ShoppingCartItem: + return "Dummy return value" + + get_hybrid_property_type(hybrid_prop).type == ShoppingCartType + + +def test_hybrid_prop_forward_ref_not_mapped_to_graphene_type(): + @hybrid_property + def hybrid_prop(self) -> "ShoppingCartItem": + return "This shouldn't work" + + with pytest.raises( + TypeError, + match=r"(.*)No model found in Registry for forward reference for type(.*)", + ): + get_hybrid_property_type(hybrid_prop).type + + +def test_hybrid_prop_forward_ref_mapped_to_graphene_type(): + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + + @hybrid_property + def hybrid_prop(self) -> "ShoppingCartItem": + return "Dummy return value" + + get_hybrid_property_type(hybrid_prop).type == ShoppingCartType + + @pytest.mark.skipif( sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" ) @@ -131,11 +236,10 @@ def prop_method_2() -> Union[ShoppingCartType, PetType]: field_type_1 = get_hybrid_property_type(prop_method).type field_type_2 = get_hybrid_property_type(prop_method_2).type - assert isinstance(field_type_1, graphene.Union) + assert issubclass(field_type_1, graphene.Union) + assert field_type_1._meta.types == [PetType, ShoppingCartType] assert field_type_1 is field_type_2 - # TODO verify types of the union - @pytest.mark.skipif( sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" @@ -164,10 +268,16 @@ def prop_method_2() -> ShoppingCartType | PetType: field_type_1 = get_hybrid_property_type(prop_method).type field_type_2 = get_hybrid_property_type(prop_method_2).type - assert isinstance(field_type_1, graphene.Union) + assert issubclass(field_type_1, graphene.Union) + assert field_type_1._meta.types == [PetType, ShoppingCartType] assert field_type_1 is field_type_2 +def test_should_unknown_type_raise_error(): + with pytest.raises(Exception): + converted_type = convert_sqlalchemy_type(ZeroDivisionError) # noqa + + def test_should_datetime_convert_datetime(): assert get_field(types.DateTime()).type == graphene.DateTime @@ -667,7 +777,6 @@ class Meta: ), "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, "hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType), - "hybrid_prop_unsupported_type_tuple": graphene.String, # Self Referential List "hybrid_prop_self_referential": ShoppingCartType, "hybrid_prop_self_referential_list": graphene.List(ShoppingCartType), diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index 68b5404..e54f08b 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -142,7 +142,7 @@ class Meta: model = Reporter union_types = [PetType, ReporterType] - union = graphene.Union("ReporterPet", tuple(union_types)) + union = graphene.Union.create_type("ReporterPet", types=tuple(union_types)) reg.register_union_type(union, union_types) @@ -155,7 +155,7 @@ def test_register_union_scalar(): reg = Registry() union_types = [graphene.String, graphene.Int] - union = graphene.Union("StringInt", tuple(union_types)) + union = graphene.Union.create_type("StringInt", types=union_types) re_err = r"Expected Graphene ObjectType, but got: .*String.*" with pytest.raises(TypeError, match=re_err): diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 62c71d8..1bf361f 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -196,18 +196,17 @@ def __call__(self, *args, **kwargs): # No match, using default. return self.default(*args, **kwargs) - def register(self, matcher_function: Callable[[Any], bool]): - def grab_function_from_outside(f): - self.registry[matcher_function] = f - return self + def register(self, matcher_function: Callable[[Any], bool], func=None): + if func is None: + return lambda f: self.register(matcher_function, f) + self.registry[matcher_function] = func + return func - return grab_function_from_outside - -def value_equals(value): +def column_type_eq(value: Any) -> Callable[[Any], bool]: """A simple function that makes the equality based matcher functions for SingleDispatchByMatchFunction prettier""" - return lambda x: x == value + return lambda x: (x == value) def safe_isinstance(cls): @@ -220,6 +219,16 @@ def safe_isinstance_checker(arg): return safe_isinstance_checker +def safe_issubclass(cls): + def safe_issubclass_checker(arg): + try: + return issubclass(arg, cls) + except TypeError: + pass + + return safe_issubclass_checker + + def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: from graphene_sqlalchemy.registry import get_global_registry