From 9390ad3cec36695297a4c8c2da4c2012952fbdea Mon Sep 17 00:00:00 2001 From: Abhijit Bhole Date: Mon, 5 Nov 2018 07:24:55 +0530 Subject: [PATCH 1/2] One test not working --- graphql/execution/executor.py | 238 +++++++++++++++++++--------------- graphql/type/__init__.py | 1 + graphql/type/directives.py | 14 ++ 3 files changed, 145 insertions(+), 108 deletions(-) diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 1b5e884e..4fe47151 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -21,6 +21,7 @@ GraphQLScalarType, GraphQLSchema, GraphQLUnionType, + GraphQLDeferDirective, ) from .base import ( ExecutionContext, @@ -54,17 +55,17 @@ def subscribe(*args, **kwargs): def execute( - schema, # type: GraphQLSchema - document_ast, # type: Document - root=None, # type: Any - context=None, # type: Optional[Any] - variables=None, # type: Optional[Any] - operation_name=None, # type: Optional[str] - executor=None, # type: Any - return_promise=False, # type: bool - middleware=None, # type: Optional[Any] - allow_subscriptions=False, # type: bool - **options # type: Any + schema, # type: GraphQLSchema + document_ast, # type: Document + root=None, # type: Any + context=None, # type: Optional[Any] + variables=None, # type: Optional[Any] + operation_name=None, # type: Optional[str] + executor=None, # type: Any + return_promise=False, # type: bool + middleware=None, # type: Optional[Any] + allow_subscriptions=False, # type: bool + **options # type: Any ): # type: (...) -> Union[ExecutionResult, Promise[ExecutionResult]] @@ -91,8 +92,8 @@ def execute( variables = options["variable_values"] assert schema, "Must provide schema" assert isinstance(schema, GraphQLSchema), ( - "Schema must be an instance of GraphQLSchema. Also ensure that there are " - + "not multiple versions of GraphQL installed in your node_modules directory." + "Schema must be an instance of GraphQLSchema. Also ensure that there are " + + "not multiple versions of GraphQL installed in your node_modules directory." ) if middleware: @@ -117,7 +118,7 @@ def execute( executor, middleware, allow_subscriptions, - ) + ) def promise_executor(v): # type: (Optional[Any]) -> Union[Dict, Promise[Dict], Observable] @@ -154,9 +155,9 @@ def on_resolve(data): def execute_operation( - exe_context, # type: ExecutionContext - operation, # type: OperationDefinition - root_value, # type: Any + exe_context, # type: ExecutionContext + operation, # type: OperationDefinition + root_value, # type: Any ): # type: (...) -> Union[Dict, Promise[Dict]] type = get_operation_root_type(exe_context.schema, operation) @@ -176,15 +177,20 @@ def execute_operation( ) return subscribe_fields(exe_context, type, root_value, fields) - return execute_fields(exe_context, type, root_value, fields, [], None) + deferred = [] + result = execute_fields(exe_context, type, root_value, fields, [], None, deferred) + if len(deferred) > 0: + return Promise.all((result, deferred)) + else: + return result def execute_fields_serially( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source_value, # type: Any - path, # type: List - fields, # type: DefaultOrderedDict + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source_value, # type: Any + path, # type: List + fields, # type: DefaultOrderedDict ): # type: (...) -> Promise def execute_field_callback(results, response_name): @@ -197,6 +203,7 @@ def execute_field_callback(results, response_name): field_asts, None, path + [response_name], + [] ) if result is Undefined: return results @@ -225,12 +232,13 @@ def execute_field(prev_promise, response_name): def execute_fields( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source_value, # type: Any - fields, # type: DefaultOrderedDict - path, # type: List[Union[int, str]] - info, # type: Optional[ResolveInfo] + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source_value, # type: Any + fields, # type: DefaultOrderedDict + path, # type: List[Union[int, str]] + info, # type: Optional[ResolveInfo] + deferred, #type: List[Promise] ): # type: (...) -> Union[Dict, Promise[Dict]] contains_promise = False @@ -245,13 +253,20 @@ def execute_fields( field_asts, info, path + [response_name], + deferred ) if result is Undefined: continue - final_results[response_name] = result if is_thenable(result): - contains_promise = True + if any(d.name.value == GraphQLDeferDirective.name for a in field_asts for d in a.directives): + final_results[response_name] = None + deferred.append((path + [response_name], result)) + else: + final_results[response_name] = result + contains_promise = True + else: + final_results[response_name] = result if not contains_promise: return final_results @@ -260,10 +275,10 @@ def execute_fields( def subscribe_fields( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source_value, # type: Any - fields, # type: DefaultOrderedDict + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source_value, # type: Any + fields, # type: DefaultOrderedDict ): # type: (...) -> Observable subscriber_exe_context = SubscriberExecutionContext(exe_context) @@ -310,12 +325,13 @@ def catch_error(error): def resolve_field( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source, # type: Any - field_asts, # type: List[Field] - parent_info, # type: Optional[ResolveInfo] - field_path, # type: List[Union[int, str]] + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source, # type: Any + field_asts, # type: List[Field] + parent_info, # type: Optional[ResolveInfo] + field_path, # type: List[Union[int, str]] + deferred, #type: List[Promise] ): # type: (...) -> Any field_ast = field_asts[0] @@ -360,16 +376,16 @@ def resolve_field( result = resolve_or_error(resolve_fn_middleware, source, info, args, executor) return complete_value_catching_error( - exe_context, return_type, field_asts, info, field_path, result + exe_context, return_type, field_asts, info, field_path, result, deferred ) def subscribe_field( - exe_context, # type: SubscriberExecutionContext - parent_type, # type: GraphQLObjectType - source, # type: Any - field_asts, # type: List[Field] - path, # type: List[str] + exe_context, # type: SubscriberExecutionContext + parent_type, # type: GraphQLObjectType + source, # type: Any + field_asts, # type: List[Field] + path, # type: List[str] ): # type: (...) -> Observable field_ast = field_asts[0] @@ -436,11 +452,11 @@ def subscribe_field( def resolve_or_error( - resolve_fn, # type: Callable - source, # type: Any - info, # type: ResolveInfo - args, # type: Dict - executor, # type: Any + resolve_fn, # type: Callable + source, # type: Any + info, # type: ResolveInfo + args, # type: Dict + executor, # type: Any ): # type: (...) -> Any try: @@ -456,24 +472,25 @@ def resolve_or_error( def complete_value_catching_error( - exe_context, # type: ExecutionContext - return_type, # type: Any - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any + exe_context, # type: ExecutionContext + return_type, # type: Any + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: List[Promise] ): # type: (...) -> Any # If the field type is non-nullable, then it is resolved without any # protection from errors. if isinstance(return_type, GraphQLNonNull): - return complete_value(exe_context, return_type, field_asts, info, path, result) + return complete_value(exe_context, return_type, field_asts, info, path, result, deferred) # Otherwise, error protection is applied, logging the error and # resolving a null value for this field if one is encountered. try: completed = complete_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) if is_thenable(completed): @@ -493,12 +510,13 @@ def handle_error(error): def complete_value( - exe_context, # type: ExecutionContext - return_type, # type: Any - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any + exe_context, # type: ExecutionContext + return_type, # type: Any + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: List[Promise] ): # type: (...) -> Any """ @@ -524,7 +542,7 @@ def complete_value( if is_thenable(result): return Promise.resolve(result).then( lambda resolved: complete_value( - exe_context, return_type, field_asts, info, path, resolved + exe_context, return_type, field_asts, info, path, resolved, deferred ), lambda error: Promise.rejected( GraphQLLocatedError(field_asts, original_error=error, path=path) @@ -537,7 +555,7 @@ def complete_value( if isinstance(return_type, GraphQLNonNull): return complete_nonnull_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) # If result is null-like, return null. @@ -547,7 +565,7 @@ def complete_value( # If field type is List, complete each item in the list with the inner type if isinstance(return_type, GraphQLList): return complete_list_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) # If field type is Scalar or Enum, serialize to a valid value, returning @@ -557,31 +575,32 @@ def complete_value( if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): return complete_abstract_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) if isinstance(return_type, GraphQLObjectType): return complete_object_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) assert False, u'Cannot complete value of unexpected type "{}".'.format(return_type) def complete_list_value( - exe_context, # type: ExecutionContext - return_type, # type: GraphQLList - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any + exe_context, # type: ExecutionContext + return_type, # type: GraphQLList + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: List[Promise] ): # type: (...) -> List[Any] """ Complete a list value by completing each item in the list with the inner type """ assert isinstance(result, collections.Iterable), ( - "User Error: expected iterable, but did not find one " + "for field {}.{}." + "User Error: expected iterable, but did not find one " + "for field {}.{}." ).format(info.parent_type, info.field_name) item_type = return_type.of_type @@ -591,7 +610,7 @@ def complete_list_value( index = 0 for item in result: completed_item = complete_value_catching_error( - exe_context, item_type, field_asts, info, path + [index], item + exe_context, item_type, field_asts, info, path + [index], item, deferred ) if not contains_promise and is_thenable(completed_item): contains_promise = True @@ -603,9 +622,9 @@ def complete_list_value( def complete_leaf_value( - return_type, # type: Union[GraphQLEnumType, GraphQLScalarType] - path, # type: List[Union[int, str]] - result, # type: Any + return_type, # type: Union[GraphQLEnumType, GraphQLScalarType] + path, # type: List[Union[int, str]] + result, # type: Any ): # type: (...) -> Union[int, str, float, bool] """ @@ -625,12 +644,13 @@ def complete_leaf_value( def complete_abstract_value( - exe_context, # type: ExecutionContext - return_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any + exe_context, # type: ExecutionContext + return_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: List[Promise] ): # type: (...) -> Dict[str, Any] """ @@ -652,8 +672,8 @@ def complete_abstract_value( if not isinstance(runtime_type, GraphQLObjectType): raise GraphQLError( ( - "Abstract type {} must resolve to an Object type at runtime " - + 'for field {}.{} with value "{}", received "{}".' + "Abstract type {} must resolve to an Object type at runtime " + + 'for field {}.{} with value "{}", received "{}".' ).format( return_type, info.parent_type, info.field_name, result, runtime_type ), @@ -669,14 +689,14 @@ def complete_abstract_value( ) return complete_object_value( - exe_context, runtime_type, field_asts, info, path, result + exe_context, runtime_type, field_asts, info, path, result, deferred ) def get_default_resolve_type_fn( - value, # type: Any - info, # type: ResolveInfo - abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] + value, # type: Any + info, # type: ResolveInfo + abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] ): # type: (...) -> Optional[GraphQLObjectType] possible_types = info.schema.get_possible_types(abstract_type) @@ -687,12 +707,13 @@ def get_default_resolve_type_fn( def complete_object_value( - exe_context, # type: ExecutionContext - return_type, # type: GraphQLObjectType - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any + exe_context, # type: ExecutionContext + return_type, # type: GraphQLObjectType + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: List[Promise] ): # type: (...) -> Dict[str, Any] """ @@ -708,23 +729,24 @@ def complete_object_value( # Collect sub-fields to execute to complete this value. subfield_asts = exe_context.get_sub_fields(return_type, field_asts) - return execute_fields(exe_context, return_type, result, subfield_asts, path, info) + return execute_fields(exe_context, return_type, result, subfield_asts, path, info, deferred) def complete_nonnull_value( - exe_context, # type: ExecutionContext - return_type, # type: GraphQLNonNull - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any + exe_context, # type: ExecutionContext + return_type, # type: GraphQLNonNull + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: List[Promise] ): # type: (...) -> Any """ Complete a NonNull value by completing the inner type """ completed = complete_value( - exe_context, return_type.of_type, field_asts, info, path, result + exe_context, return_type.of_type, field_asts, info, path, result, deferred ) if completed is None: raise GraphQLError( diff --git a/graphql/type/__init__.py b/graphql/type/__init__.py index 41a11115..82224b20 100644 --- a/graphql/type/__init__.py +++ b/graphql/type/__init__.py @@ -31,6 +31,7 @@ GraphQLSkipDirective, GraphQLIncludeDirective, GraphQLDeprecatedDirective, + GraphQLDeferDirective, # Constant Deprecation Reason DEFAULT_DEPRECATION_REASON, ) diff --git a/graphql/type/directives.py b/graphql/type/directives.py index ef7417c4..e6bad3d0 100644 --- a/graphql/type/directives.py +++ b/graphql/type/directives.py @@ -117,8 +117,22 @@ def __init__(self, name, description=None, args=None, locations=None): locations=[DirectiveLocation.FIELD_DEFINITION, DirectiveLocation.ENUM_VALUE], ) + +"""Used to defer the result of an element.""" +GraphQLDeferDirective = GraphQLDirective( + name="defer", + description="Marks an element of a GraphQL schema as deferred.", + args={}, + locations=[ + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + ], +) + specified_directives = [ GraphQLIncludeDirective, GraphQLSkipDirective, GraphQLDeprecatedDirective, + # GraphQLDeferDirective, ] From f7e9758bb0e59afd410994f071b2df1094ff6914 Mon Sep 17 00:00:00 2001 From: Abhijit Bhole Date: Mon, 5 Nov 2018 14:22:07 +0530 Subject: [PATCH 2/2] Tests working --- graphql/__init__.py | 2 + graphql/execution/executor.py | 246 ++++++++++--------- graphql/type/directives.py | 8 +- graphql/utils/build_ast_schema.py | 7 + graphql/utils/schema_printer.py | 2 +- graphql/utils/tests/test_build_ast_schema.py | 20 +- graphql/utils/tests/test_schema_printer.py | 2 + 7 files changed, 163 insertions(+), 124 deletions(-) diff --git a/graphql/__init__.py b/graphql/__init__.py index 2365383f..2461dbbd 100644 --- a/graphql/__init__.py +++ b/graphql/__init__.py @@ -59,6 +59,7 @@ GraphQLSkipDirective, GraphQLIncludeDirective, GraphQLDeprecatedDirective, + GraphQLDeferDirective, # Constant Deprecation Reason DEFAULT_DEPRECATION_REASON, # GraphQL Types for introspection. @@ -198,6 +199,7 @@ "GraphQLSkipDirective", "GraphQLIncludeDirective", "GraphQLDeprecatedDirective", + "GraphQLDeferDirective", "DEFAULT_DEPRECATION_REASON", "TypeKind", "DirectiveLocation", diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 4fe47151..72f2b602 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -55,17 +55,18 @@ def subscribe(*args, **kwargs): def execute( - schema, # type: GraphQLSchema - document_ast, # type: Document - root=None, # type: Any - context=None, # type: Optional[Any] - variables=None, # type: Optional[Any] - operation_name=None, # type: Optional[str] - executor=None, # type: Any - return_promise=False, # type: bool - middleware=None, # type: Optional[Any] - allow_subscriptions=False, # type: bool - **options # type: Any + schema, # type: GraphQLSchema + document_ast, # type: Document + root=None, # type: Any + context=None, # type: Optional[Any] + variables=None, # type: Optional[Any] + operation_name=None, # type: Optional[str] + executor=None, # type: Any + return_promise=False, # type: bool + middleware=None, # type: Optional[Any] + allow_subscriptions=False, # type: bool, + deferred_results = None, #type: Optional[List[Tuple[String, Promise[ExecutionResult]]]] + **options # type: Any ): # type: (...) -> Union[ExecutionResult, Promise[ExecutionResult]] @@ -92,8 +93,8 @@ def execute( variables = options["variable_values"] assert schema, "Must provide schema" assert isinstance(schema, GraphQLSchema), ( - "Schema must be an instance of GraphQLSchema. Also ensure that there are " - + "not multiple versions of GraphQL installed in your node_modules directory." + "Schema must be an instance of GraphQLSchema. Also ensure that there are " + + "not multiple versions of GraphQL installed in your node_modules directory." ) if middleware: @@ -120,9 +121,14 @@ def execute( allow_subscriptions, ) + if deferred_results is not None: + deferred = [] + else: + deferred = None + def promise_executor(v): # type: (Optional[Any]) -> Union[Dict, Promise[Dict], Observable] - return execute_operation(exe_context, exe_context.operation, root) + return execute_operation(exe_context, exe_context.operation, root, deferred) def on_rejected(error): # type: (Exception) -> None @@ -143,6 +149,19 @@ def on_resolve(data): Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve) ) + def on_deferred_resolve(data, errors): + if len(errors) == 0: + return ExecutionResult(data=data) + return ExecutionResult(data=data, errors=errors) + + if deferred_results is not None: + for path, deferred_promise in deferred: + errors = [] + deferred_results.append( + (path, deferred_promise + .catch(errors.append) + .then(functools.partial(on_deferred_resolve, errors=errors)))) + if not return_promise: exe_context.executor.wait_until_finished() return promise.get() @@ -155,9 +174,10 @@ def on_resolve(data): def execute_operation( - exe_context, # type: ExecutionContext - operation, # type: OperationDefinition - root_value, # type: Any + exe_context, # type: ExecutionContext + operation, # type: OperationDefinition + root_value, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Union[Dict, Promise[Dict]] type = get_operation_root_type(exe_context.schema, operation) @@ -177,20 +197,22 @@ def execute_operation( ) return subscribe_fields(exe_context, type, root_value, fields) - deferred = [] + if deferred is None: + deferred = [] + result = execute_fields(exe_context, type, root_value, fields, [], None, deferred) - if len(deferred) > 0: - return Promise.all((result, deferred)) - else: - return result + # if len(deferred) > 0: + # return Promise.all((result, deferred)) + # else: + return result def execute_fields_serially( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source_value, # type: Any - path, # type: List - fields, # type: DefaultOrderedDict + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source_value, # type: Any + path, # type: List + fields, # type: DefaultOrderedDict ): # type: (...) -> Promise def execute_field_callback(results, response_name): @@ -232,13 +254,13 @@ def execute_field(prev_promise, response_name): def execute_fields( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source_value, # type: Any - fields, # type: DefaultOrderedDict - path, # type: List[Union[int, str]] - info, # type: Optional[ResolveInfo] - deferred, #type: List[Promise] + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source_value, # type: Any + fields, # type: DefaultOrderedDict + path, # type: List[Union[int, str]] + info, # type: Optional[ResolveInfo] + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Union[Dict, Promise[Dict]] contains_promise = False @@ -258,14 +280,14 @@ def execute_fields( if result is Undefined: continue - if is_thenable(result): - if any(d.name.value == GraphQLDeferDirective.name for a in field_asts for d in a.directives): - final_results[response_name] = None - deferred.append((path + [response_name], result)) - else: - final_results[response_name] = result - contains_promise = True + if deferred is not None and any( + d.name.value == GraphQLDeferDirective.name for a in field_asts for d in a.directives): + final_results[response_name] = None + deferred.append((path + [response_name], Promise.resolve(result))) else: + if is_thenable(result): + contains_promise = True + final_results[response_name] = result if not contains_promise: @@ -275,10 +297,10 @@ def execute_fields( def subscribe_fields( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source_value, # type: Any - fields, # type: DefaultOrderedDict + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source_value, # type: Any + fields, # type: DefaultOrderedDict ): # type: (...) -> Observable subscriber_exe_context = SubscriberExecutionContext(exe_context) @@ -325,13 +347,13 @@ def catch_error(error): def resolve_field( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source, # type: Any - field_asts, # type: List[Field] - parent_info, # type: Optional[ResolveInfo] - field_path, # type: List[Union[int, str]] - deferred, #type: List[Promise] + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source, # type: Any + field_asts, # type: List[Field] + parent_info, # type: Optional[ResolveInfo] + field_path, # type: List[Union[int, str]] + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Any field_ast = field_asts[0] @@ -381,11 +403,11 @@ def resolve_field( def subscribe_field( - exe_context, # type: SubscriberExecutionContext - parent_type, # type: GraphQLObjectType - source, # type: Any - field_asts, # type: List[Field] - path, # type: List[str] + exe_context, # type: SubscriberExecutionContext + parent_type, # type: GraphQLObjectType + source, # type: Any + field_asts, # type: List[Field] + path, # type: List[str] ): # type: (...) -> Observable field_ast = field_asts[0] @@ -452,11 +474,11 @@ def subscribe_field( def resolve_or_error( - resolve_fn, # type: Callable - source, # type: Any - info, # type: ResolveInfo - args, # type: Dict - executor, # type: Any + resolve_fn, # type: Callable + source, # type: Any + info, # type: ResolveInfo + args, # type: Dict + executor, # type: Any ): # type: (...) -> Any try: @@ -472,13 +494,13 @@ def resolve_or_error( def complete_value_catching_error( - exe_context, # type: ExecutionContext - return_type, # type: Any - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any - deferred, #type: List[Promise] + exe_context, # type: ExecutionContext + return_type, # type: Any + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Any # If the field type is non-nullable, then it is resolved without any @@ -510,13 +532,13 @@ def handle_error(error): def complete_value( - exe_context, # type: ExecutionContext - return_type, # type: Any - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any - deferred, #type: List[Promise] + exe_context, # type: ExecutionContext + return_type, # type: Any + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Any """ @@ -587,13 +609,13 @@ def complete_value( def complete_list_value( - exe_context, # type: ExecutionContext - return_type, # type: GraphQLList - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any - deferred, #type: List[Promise] + exe_context, # type: ExecutionContext + return_type, # type: GraphQLList + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> List[Any] """ @@ -622,9 +644,9 @@ def complete_list_value( def complete_leaf_value( - return_type, # type: Union[GraphQLEnumType, GraphQLScalarType] - path, # type: List[Union[int, str]] - result, # type: Any + return_type, # type: Union[GraphQLEnumType, GraphQLScalarType] + path, # type: List[Union[int, str]] + result, # type: Any ): # type: (...) -> Union[int, str, float, bool] """ @@ -644,13 +666,13 @@ def complete_leaf_value( def complete_abstract_value( - exe_context, # type: ExecutionContext - return_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any - deferred, #type: List[Promise] + exe_context, # type: ExecutionContext + return_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Dict[str, Any] """ @@ -672,8 +694,8 @@ def complete_abstract_value( if not isinstance(runtime_type, GraphQLObjectType): raise GraphQLError( ( - "Abstract type {} must resolve to an Object type at runtime " - + 'for field {}.{} with value "{}", received "{}".' + "Abstract type {} must resolve to an Object type at runtime " + + 'for field {}.{} with value "{}", received "{}".' ).format( return_type, info.parent_type, info.field_name, result, runtime_type ), @@ -694,9 +716,9 @@ def complete_abstract_value( def get_default_resolve_type_fn( - value, # type: Any - info, # type: ResolveInfo - abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] + value, # type: Any + info, # type: ResolveInfo + abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] ): # type: (...) -> Optional[GraphQLObjectType] possible_types = info.schema.get_possible_types(abstract_type) @@ -707,13 +729,13 @@ def get_default_resolve_type_fn( def complete_object_value( - exe_context, # type: ExecutionContext - return_type, # type: GraphQLObjectType - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any - deferred, #type: List[Promise] + exe_context, # type: ExecutionContext + return_type, # type: GraphQLObjectType + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Dict[str, Any] """ @@ -733,13 +755,13 @@ def complete_object_value( def complete_nonnull_value( - exe_context, # type: ExecutionContext - return_type, # type: GraphQLNonNull - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any - deferred, #type: List[Promise] + exe_context, # type: ExecutionContext + return_type, # type: GraphQLNonNull + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Any """ diff --git a/graphql/type/directives.py b/graphql/type/directives.py index e6bad3d0..85b7885a 100644 --- a/graphql/type/directives.py +++ b/graphql/type/directives.py @@ -121,12 +121,10 @@ def __init__(self, name, description=None, args=None, locations=None): """Used to defer the result of an element.""" GraphQLDeferDirective = GraphQLDirective( name="defer", - description="Marks an element of a GraphQL schema as deferred.", + description='Defers this field', args={}, locations=[ - DirectiveLocation.FIELD, - DirectiveLocation.FRAGMENT_SPREAD, - DirectiveLocation.INLINE_FRAGMENT, + DirectiveLocation.FIELD ], ) @@ -134,5 +132,5 @@ def __init__(self, name, description=None, args=None, locations=None): GraphQLIncludeDirective, GraphQLSkipDirective, GraphQLDeprecatedDirective, - # GraphQLDeferDirective, + GraphQLDeferDirective, ] diff --git a/graphql/utils/build_ast_schema.py b/graphql/utils/build_ast_schema.py index 8d28f4be..962bd2d0 100644 --- a/graphql/utils/build_ast_schema.py +++ b/graphql/utils/build_ast_schema.py @@ -4,6 +4,7 @@ from ..type import ( GraphQLArgument, GraphQLBoolean, + GraphQLDeferDirective, GraphQLDeprecatedDirective, GraphQLDirective, GraphQLEnumType, @@ -308,6 +309,9 @@ def make_input_object_def(definition): find_deprecated_directive = ( directive.name for directive in directives if directive.name == "deprecated" ) + find_defer_directive = ( + directive.name for directive in directives if directive.name == "defer" + ) if not next(find_skip_directive, None): directives.append(GraphQLSkipDirective) @@ -318,6 +322,9 @@ def make_input_object_def(definition): if not next(find_deprecated_directive, None): directives.append(GraphQLDeprecatedDirective) + if not next(find_defer_directive, None): + directives.append(GraphQLDeferDirective) + schema_kwargs = {"query": get_object_type(ast_map[query_type_name])} if mutation_type_name: diff --git a/graphql/utils/schema_printer.py b/graphql/utils/schema_printer.py index 30a28abd..8d8ba71e 100644 --- a/graphql/utils/schema_printer.py +++ b/graphql/utils/schema_printer.py @@ -39,7 +39,7 @@ def print_introspection_schema(schema): def is_spec_directive(directive_name): # type: (str) -> bool - return directive_name in ("skip", "include", "deprecated") + return directive_name in ("skip", "include", "deprecated", "defer") def _is_defined_type(typename): diff --git a/graphql/utils/tests/test_build_ast_schema.py b/graphql/utils/tests/test_build_ast_schema.py index 6f84aa64..97828767 100644 --- a/graphql/utils/tests/test_build_ast_schema.py +++ b/graphql/utils/tests/test_build_ast_schema.py @@ -8,6 +8,7 @@ GraphQLDeprecatedDirective, GraphQLIncludeDirective, GraphQLSkipDirective, + GraphQLDeferDirective ) @@ -66,10 +67,11 @@ def test_maintains_skip_and_include_directives(): """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 3 + assert len(schema.get_directives()) == 4 assert schema.get_directive("skip") == GraphQLSkipDirective assert schema.get_directive("include") == GraphQLIncludeDirective assert schema.get_directive("deprecated") == GraphQLDeprecatedDirective + assert schema.get_directive("defer") == GraphQLDeferDirective def test_overriding_directives_excludes_specified(): @@ -81,20 +83,23 @@ def test_overriding_directives_excludes_specified(): directive @skip on FIELD directive @include on FIELD directive @deprecated on FIELD_DEFINITION - + directive @defer on FIELD + type Hello { str: String } """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 3 + assert len(schema.get_directives()) == 4 assert schema.get_directive("skip") != GraphQLSkipDirective assert schema.get_directive("skip") is not None assert schema.get_directive("include") != GraphQLIncludeDirective assert schema.get_directive("include") is not None assert schema.get_directive("deprecated") != GraphQLDeprecatedDirective assert schema.get_directive("deprecated") is not None + assert schema.get_directive("defer") != GraphQLDeferDirective + assert schema.get_directive("defer") is not None def test_overriding_skip_directive_excludes_built_in_one(): @@ -111,11 +116,12 @@ def test_overriding_skip_directive_excludes_built_in_one(): """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 3 + assert len(schema.get_directives()) == 4 assert schema.get_directive("skip") != GraphQLSkipDirective assert schema.get_directive("skip") is not None assert schema.get_directive("include") == GraphQLIncludeDirective assert schema.get_directive("deprecated") == GraphQLDeprecatedDirective + assert schema.get_directive("defer") == GraphQLDeferDirective def test_overriding_include_directive_excludes_built_in_one(): @@ -132,11 +138,12 @@ def test_overriding_include_directive_excludes_built_in_one(): """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 3 + assert len(schema.get_directives()) == 4 assert schema.get_directive("skip") == GraphQLSkipDirective assert schema.get_directive("deprecated") == GraphQLDeprecatedDirective assert schema.get_directive("include") != GraphQLIncludeDirective assert schema.get_directive("include") is not None + assert schema.get_directive("defer") == GraphQLDeferDirective def test_adding_directives_maintains_skip_and_include_directives(): @@ -153,10 +160,11 @@ def test_adding_directives_maintains_skip_and_include_directives(): """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 4 + assert len(schema.get_directives()) == 5 assert schema.get_directive("skip") == GraphQLSkipDirective assert schema.get_directive("include") == GraphQLIncludeDirective assert schema.get_directive("deprecated") == GraphQLDeprecatedDirective + assert schema.get_directive("defer") == GraphQLDeferDirective def test_type_modifiers(): diff --git a/graphql/utils/tests/test_schema_printer.py b/graphql/utils/tests/test_schema_printer.py index 0d61facf..31837389 100644 --- a/graphql/utils/tests/test_schema_printer.py +++ b/graphql/utils/tests/test_schema_printer.py @@ -580,6 +580,8 @@ def test_print_introspection_schema(): directive @deprecated(reason: String = "No longer supported") on FIELD_DEFINITION | ENUM_VALUE +directive @defer on FIELD + type __Directive { name: String! description: String