diff --git a/ariadne/enums_default_values.py b/ariadne/enums_default_values.py index c1b01ca73..8b8808e9a 100644 --- a/ariadne/enums_default_values.py +++ b/ariadne/enums_default_values.py @@ -1,4 +1,4 @@ -from graphql import GraphQLSchema +from graphql import GraphQLInputField, GraphQLSchema from .enums_values_visitor import ( GraphQLASTEnumDefaultValueLocation, @@ -130,5 +130,5 @@ def visit_schema_enum_default_value( location.default_value[location.default_value_path] = valid_default elif location.arg_def: location.arg_def.default_value = valid_default - elif location.field_def: + elif isinstance(location.field_def, GraphQLInputField): location.field_def.default_value = valid_default diff --git a/ariadne/enums_values_visitor.py b/ariadne/enums_values_visitor.py index f6d179c21..501c7a396 100644 --- a/ariadne/enums_values_visitor.py +++ b/ariadne/enums_values_visitor.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, TypeGuard, Union, cast from graphql import ( EnumValueNode, @@ -15,6 +15,7 @@ GraphQLObjectType, GraphQLSchema, GraphQLType, + InputValueDefinitionNode, ListValueNode, ObjectValueNode, ) @@ -134,7 +135,7 @@ def visit_value( field_def, arg_name, arg_def, - src_def.type, + src_type, src_def.default_value, ) @@ -334,7 +335,8 @@ def visit_value( elif isinstance(arg_def, GraphQLArgument): src_def = arg_def - default_value_ast = src_def.ast_node.default_value + ast_node = cast(InputValueDefinitionNode, src_def.ast_node) + default_value_ast = ast_node.default_value if is_graphql_list(src_def.type) and isinstance( default_value_ast, ListValueNode ): @@ -403,7 +405,7 @@ def visit_list_value( arg_name, arg_def, value_type, - value_item, + cast(ListValueNode, value_item), ) elif isinstance(value_type, GraphQLEnumType): @@ -514,7 +516,7 @@ class GraphQLASTEnumDefaultValueLocation: enum_name: str enum_value: Any object_name: str - object_def: Union[GraphQLInputObjectType, GraphQLObjectType] + object_def: Union[GraphQLInputObjectType, GraphQLInterfaceType, GraphQLObjectType] field_name: str field_def: Union[GraphQLField, GraphQLInputField] arg_name: Optional[str] = None @@ -547,7 +549,7 @@ def unwrap_nonnull_type(graphql_type: GraphQLType) -> GraphQLType: return graphql_type -def is_graphql_list(graphql_type: GraphQLType) -> GraphQLType: +def is_graphql_list(graphql_type: GraphQLType) -> TypeGuard[GraphQLList]: if isinstance(graphql_type, GraphQLNonNull): return is_graphql_list(graphql_type.of_type)