Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp committed Dec 14, 2023
1 parent 46a3b0c commit 03dfd5a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions ariadne/enums_default_values.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from graphql import GraphQLSchema
from graphql import GraphQLInputField, GraphQLSchema

from .enums_values_visitor import (
GraphQLASTEnumDefaultValueLocation,
Expand Down Expand Up @@ -130,5 +130,5 @@ def visit_schema_enum_default_value(
location.default_value[location.default_value_path] = valid_default
elif location.arg_def:
location.arg_def.default_value = valid_default
elif location.field_def:
elif isinstance(location.field_def, GraphQLInputField):
location.field_def.default_value = valid_default
14 changes: 8 additions & 6 deletions ariadne/enums_values_visitor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, TypeGuard, Union, cast

from graphql import (
EnumValueNode,
Expand All @@ -15,6 +15,7 @@
GraphQLObjectType,
GraphQLSchema,
GraphQLType,
InputValueDefinitionNode,
ListValueNode,
ObjectValueNode,
)
Expand Down Expand Up @@ -134,7 +135,7 @@ def visit_value(
field_def,
arg_name,
arg_def,
src_def.type,
src_type,
src_def.default_value,
)

Expand Down Expand Up @@ -334,7 +335,8 @@ def visit_value(
elif isinstance(arg_def, GraphQLArgument):
src_def = arg_def

default_value_ast = src_def.ast_node.default_value
ast_node = cast(InputValueDefinitionNode, src_def.ast_node)
default_value_ast = ast_node.default_value
if is_graphql_list(src_def.type) and isinstance(
default_value_ast, ListValueNode
):
Expand Down Expand Up @@ -403,7 +405,7 @@ def visit_list_value(
arg_name,
arg_def,
value_type,
value_item,
cast(ListValueNode, value_item),
)

elif isinstance(value_type, GraphQLEnumType):
Expand Down Expand Up @@ -514,7 +516,7 @@ class GraphQLASTEnumDefaultValueLocation:
enum_name: str
enum_value: Any
object_name: str
object_def: Union[GraphQLInputObjectType, GraphQLObjectType]
object_def: Union[GraphQLInputObjectType, GraphQLInterfaceType, GraphQLObjectType]
field_name: str
field_def: Union[GraphQLField, GraphQLInputField]
arg_name: Optional[str] = None
Expand Down Expand Up @@ -547,7 +549,7 @@ def unwrap_nonnull_type(graphql_type: GraphQLType) -> GraphQLType:
return graphql_type


def is_graphql_list(graphql_type: GraphQLType) -> GraphQLType:
def is_graphql_list(graphql_type: GraphQLType) -> TypeGuard[GraphQLList]:
if isinstance(graphql_type, GraphQLNonNull):
return is_graphql_list(graphql_type.of_type)

Expand Down

0 comments on commit 03dfd5a

Please sign in to comment.