diff --git a/graphql/__init__.py b/graphql/__init__.py index 2365383f..2461dbbd 100644 --- a/graphql/__init__.py +++ b/graphql/__init__.py @@ -59,6 +59,7 @@ GraphQLSkipDirective, GraphQLIncludeDirective, GraphQLDeprecatedDirective, + GraphQLDeferDirective, # Constant Deprecation Reason DEFAULT_DEPRECATION_REASON, # GraphQL Types for introspection. @@ -198,6 +199,7 @@ "GraphQLSkipDirective", "GraphQLIncludeDirective", "GraphQLDeprecatedDirective", + "GraphQLDeferDirective", "DEFAULT_DEPRECATION_REASON", "TypeKind", "DirectiveLocation", diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 1b5e884e..72f2b602 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -21,6 +21,7 @@ GraphQLScalarType, GraphQLSchema, GraphQLUnionType, + GraphQLDeferDirective, ) from .base import ( ExecutionContext, @@ -63,7 +64,8 @@ def execute( executor=None, # type: Any return_promise=False, # type: bool middleware=None, # type: Optional[Any] - allow_subscriptions=False, # type: bool + allow_subscriptions=False, # type: bool, + deferred_results = None, #type: Optional[List[Tuple[String, Promise[ExecutionResult]]]] **options # type: Any ): # type: (...) -> Union[ExecutionResult, Promise[ExecutionResult]] @@ -117,11 +119,16 @@ def execute( executor, middleware, allow_subscriptions, - ) + ) + + if deferred_results is not None: + deferred = [] + else: + deferred = None def promise_executor(v): # type: (Optional[Any]) -> Union[Dict, Promise[Dict], Observable] - return execute_operation(exe_context, exe_context.operation, root) + return execute_operation(exe_context, exe_context.operation, root, deferred) def on_rejected(error): # type: (Exception) -> None @@ -142,6 +149,19 @@ def on_resolve(data): Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve) ) + def on_deferred_resolve(data, errors): + if len(errors) == 0: + return ExecutionResult(data=data) + return ExecutionResult(data=data, errors=errors) + + if deferred_results is not None: + for path, deferred_promise in deferred: + errors = [] + deferred_results.append( + (path, deferred_promise + .catch(errors.append) + .then(functools.partial(on_deferred_resolve, errors=errors)))) + if not return_promise: exe_context.executor.wait_until_finished() return promise.get() @@ -157,6 +177,7 @@ def execute_operation( exe_context, # type: ExecutionContext operation, # type: OperationDefinition root_value, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Union[Dict, Promise[Dict]] type = get_operation_root_type(exe_context.schema, operation) @@ -176,7 +197,14 @@ def execute_operation( ) return subscribe_fields(exe_context, type, root_value, fields) - return execute_fields(exe_context, type, root_value, fields, [], None) + if deferred is None: + deferred = [] + + result = execute_fields(exe_context, type, root_value, fields, [], None, deferred) + # if len(deferred) > 0: + # return Promise.all((result, deferred)) + # else: + return result def execute_fields_serially( @@ -197,6 +225,7 @@ def execute_field_callback(results, response_name): field_asts, None, path + [response_name], + [] ) if result is Undefined: return results @@ -231,6 +260,7 @@ def execute_fields( fields, # type: DefaultOrderedDict path, # type: List[Union[int, str]] info, # type: Optional[ResolveInfo] + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Union[Dict, Promise[Dict]] contains_promise = False @@ -245,13 +275,20 @@ def execute_fields( field_asts, info, path + [response_name], + deferred ) if result is Undefined: continue - final_results[response_name] = result - if is_thenable(result): - contains_promise = True + if deferred is not None and any( + d.name.value == GraphQLDeferDirective.name for a in field_asts for d in a.directives): + final_results[response_name] = None + deferred.append((path + [response_name], Promise.resolve(result))) + else: + if is_thenable(result): + contains_promise = True + + final_results[response_name] = result if not contains_promise: return final_results @@ -316,6 +353,7 @@ def resolve_field( field_asts, # type: List[Field] parent_info, # type: Optional[ResolveInfo] field_path, # type: List[Union[int, str]] + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Any field_ast = field_asts[0] @@ -360,7 +398,7 @@ def resolve_field( result = resolve_or_error(resolve_fn_middleware, source, info, args, executor) return complete_value_catching_error( - exe_context, return_type, field_asts, info, field_path, result + exe_context, return_type, field_asts, info, field_path, result, deferred ) @@ -462,18 +500,19 @@ def complete_value_catching_error( info, # type: ResolveInfo path, # type: List[Union[int, str]] result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Any # If the field type is non-nullable, then it is resolved without any # protection from errors. if isinstance(return_type, GraphQLNonNull): - return complete_value(exe_context, return_type, field_asts, info, path, result) + return complete_value(exe_context, return_type, field_asts, info, path, result, deferred) # Otherwise, error protection is applied, logging the error and # resolving a null value for this field if one is encountered. try: completed = complete_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) if is_thenable(completed): @@ -499,6 +538,7 @@ def complete_value( info, # type: ResolveInfo path, # type: List[Union[int, str]] result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Any """ @@ -524,7 +564,7 @@ def complete_value( if is_thenable(result): return Promise.resolve(result).then( lambda resolved: complete_value( - exe_context, return_type, field_asts, info, path, resolved + exe_context, return_type, field_asts, info, path, resolved, deferred ), lambda error: Promise.rejected( GraphQLLocatedError(field_asts, original_error=error, path=path) @@ -537,7 +577,7 @@ def complete_value( if isinstance(return_type, GraphQLNonNull): return complete_nonnull_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) # If result is null-like, return null. @@ -547,7 +587,7 @@ def complete_value( # If field type is List, complete each item in the list with the inner type if isinstance(return_type, GraphQLList): return complete_list_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) # If field type is Scalar or Enum, serialize to a valid value, returning @@ -557,12 +597,12 @@ def complete_value( if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): return complete_abstract_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) if isinstance(return_type, GraphQLObjectType): return complete_object_value( - exe_context, return_type, field_asts, info, path, result + exe_context, return_type, field_asts, info, path, result, deferred ) assert False, u'Cannot complete value of unexpected type "{}".'.format(return_type) @@ -575,13 +615,14 @@ def complete_list_value( info, # type: ResolveInfo path, # type: List[Union[int, str]] result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> List[Any] """ Complete a list value by completing each item in the list with the inner type """ assert isinstance(result, collections.Iterable), ( - "User Error: expected iterable, but did not find one " + "for field {}.{}." + "User Error: expected iterable, but did not find one " + "for field {}.{}." ).format(info.parent_type, info.field_name) item_type = return_type.of_type @@ -591,7 +632,7 @@ def complete_list_value( index = 0 for item in result: completed_item = complete_value_catching_error( - exe_context, item_type, field_asts, info, path + [index], item + exe_context, item_type, field_asts, info, path + [index], item, deferred ) if not contains_promise and is_thenable(completed_item): contains_promise = True @@ -631,6 +672,7 @@ def complete_abstract_value( info, # type: ResolveInfo path, # type: List[Union[int, str]] result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Dict[str, Any] """ @@ -669,7 +711,7 @@ def complete_abstract_value( ) return complete_object_value( - exe_context, runtime_type, field_asts, info, path, result + exe_context, runtime_type, field_asts, info, path, result, deferred ) @@ -693,6 +735,7 @@ def complete_object_value( info, # type: ResolveInfo path, # type: List[Union[int, str]] result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Dict[str, Any] """ @@ -708,7 +751,7 @@ def complete_object_value( # Collect sub-fields to execute to complete this value. subfield_asts = exe_context.get_sub_fields(return_type, field_asts) - return execute_fields(exe_context, return_type, result, subfield_asts, path, info) + return execute_fields(exe_context, return_type, result, subfield_asts, path, info, deferred) def complete_nonnull_value( @@ -718,13 +761,14 @@ def complete_nonnull_value( info, # type: ResolveInfo path, # type: List[Union[int, str]] result, # type: Any + deferred, #type: Optional[List[Promise]] ): # type: (...) -> Any """ Complete a NonNull value by completing the inner type """ completed = complete_value( - exe_context, return_type.of_type, field_asts, info, path, result + exe_context, return_type.of_type, field_asts, info, path, result, deferred ) if completed is None: raise GraphQLError( diff --git a/graphql/type/__init__.py b/graphql/type/__init__.py index 41a11115..82224b20 100644 --- a/graphql/type/__init__.py +++ b/graphql/type/__init__.py @@ -31,6 +31,7 @@ GraphQLSkipDirective, GraphQLIncludeDirective, GraphQLDeprecatedDirective, + GraphQLDeferDirective, # Constant Deprecation Reason DEFAULT_DEPRECATION_REASON, ) diff --git a/graphql/type/directives.py b/graphql/type/directives.py index ef7417c4..85b7885a 100644 --- a/graphql/type/directives.py +++ b/graphql/type/directives.py @@ -117,8 +117,20 @@ def __init__(self, name, description=None, args=None, locations=None): locations=[DirectiveLocation.FIELD_DEFINITION, DirectiveLocation.ENUM_VALUE], ) + +"""Used to defer the result of an element.""" +GraphQLDeferDirective = GraphQLDirective( + name="defer", + description='Defers this field', + args={}, + locations=[ + DirectiveLocation.FIELD + ], +) + specified_directives = [ GraphQLIncludeDirective, GraphQLSkipDirective, GraphQLDeprecatedDirective, + GraphQLDeferDirective, ] diff --git a/graphql/utils/build_ast_schema.py b/graphql/utils/build_ast_schema.py index 8d28f4be..962bd2d0 100644 --- a/graphql/utils/build_ast_schema.py +++ b/graphql/utils/build_ast_schema.py @@ -4,6 +4,7 @@ from ..type import ( GraphQLArgument, GraphQLBoolean, + GraphQLDeferDirective, GraphQLDeprecatedDirective, GraphQLDirective, GraphQLEnumType, @@ -308,6 +309,9 @@ def make_input_object_def(definition): find_deprecated_directive = ( directive.name for directive in directives if directive.name == "deprecated" ) + find_defer_directive = ( + directive.name for directive in directives if directive.name == "defer" + ) if not next(find_skip_directive, None): directives.append(GraphQLSkipDirective) @@ -318,6 +322,9 @@ def make_input_object_def(definition): if not next(find_deprecated_directive, None): directives.append(GraphQLDeprecatedDirective) + if not next(find_defer_directive, None): + directives.append(GraphQLDeferDirective) + schema_kwargs = {"query": get_object_type(ast_map[query_type_name])} if mutation_type_name: diff --git a/graphql/utils/schema_printer.py b/graphql/utils/schema_printer.py index 30a28abd..8d8ba71e 100644 --- a/graphql/utils/schema_printer.py +++ b/graphql/utils/schema_printer.py @@ -39,7 +39,7 @@ def print_introspection_schema(schema): def is_spec_directive(directive_name): # type: (str) -> bool - return directive_name in ("skip", "include", "deprecated") + return directive_name in ("skip", "include", "deprecated", "defer") def _is_defined_type(typename): diff --git a/graphql/utils/tests/test_build_ast_schema.py b/graphql/utils/tests/test_build_ast_schema.py index 6f84aa64..97828767 100644 --- a/graphql/utils/tests/test_build_ast_schema.py +++ b/graphql/utils/tests/test_build_ast_schema.py @@ -8,6 +8,7 @@ GraphQLDeprecatedDirective, GraphQLIncludeDirective, GraphQLSkipDirective, + GraphQLDeferDirective ) @@ -66,10 +67,11 @@ def test_maintains_skip_and_include_directives(): """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 3 + assert len(schema.get_directives()) == 4 assert schema.get_directive("skip") == GraphQLSkipDirective assert schema.get_directive("include") == GraphQLIncludeDirective assert schema.get_directive("deprecated") == GraphQLDeprecatedDirective + assert schema.get_directive("defer") == GraphQLDeferDirective def test_overriding_directives_excludes_specified(): @@ -81,20 +83,23 @@ def test_overriding_directives_excludes_specified(): directive @skip on FIELD directive @include on FIELD directive @deprecated on FIELD_DEFINITION - + directive @defer on FIELD + type Hello { str: String } """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 3 + assert len(schema.get_directives()) == 4 assert schema.get_directive("skip") != GraphQLSkipDirective assert schema.get_directive("skip") is not None assert schema.get_directive("include") != GraphQLIncludeDirective assert schema.get_directive("include") is not None assert schema.get_directive("deprecated") != GraphQLDeprecatedDirective assert schema.get_directive("deprecated") is not None + assert schema.get_directive("defer") != GraphQLDeferDirective + assert schema.get_directive("defer") is not None def test_overriding_skip_directive_excludes_built_in_one(): @@ -111,11 +116,12 @@ def test_overriding_skip_directive_excludes_built_in_one(): """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 3 + assert len(schema.get_directives()) == 4 assert schema.get_directive("skip") != GraphQLSkipDirective assert schema.get_directive("skip") is not None assert schema.get_directive("include") == GraphQLIncludeDirective assert schema.get_directive("deprecated") == GraphQLDeprecatedDirective + assert schema.get_directive("defer") == GraphQLDeferDirective def test_overriding_include_directive_excludes_built_in_one(): @@ -132,11 +138,12 @@ def test_overriding_include_directive_excludes_built_in_one(): """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 3 + assert len(schema.get_directives()) == 4 assert schema.get_directive("skip") == GraphQLSkipDirective assert schema.get_directive("deprecated") == GraphQLDeprecatedDirective assert schema.get_directive("include") != GraphQLIncludeDirective assert schema.get_directive("include") is not None + assert schema.get_directive("defer") == GraphQLDeferDirective def test_adding_directives_maintains_skip_and_include_directives(): @@ -153,10 +160,11 @@ def test_adding_directives_maintains_skip_and_include_directives(): """ schema = build_ast_schema(parse(body)) - assert len(schema.get_directives()) == 4 + assert len(schema.get_directives()) == 5 assert schema.get_directive("skip") == GraphQLSkipDirective assert schema.get_directive("include") == GraphQLIncludeDirective assert schema.get_directive("deprecated") == GraphQLDeprecatedDirective + assert schema.get_directive("defer") == GraphQLDeferDirective def test_type_modifiers(): diff --git a/graphql/utils/tests/test_schema_printer.py b/graphql/utils/tests/test_schema_printer.py index 0d61facf..31837389 100644 --- a/graphql/utils/tests/test_schema_printer.py +++ b/graphql/utils/tests/test_schema_printer.py @@ -580,6 +580,8 @@ def test_print_introspection_schema(): directive @deprecated(reason: String = "No longer supported") on FIELD_DEFINITION | ENUM_VALUE +directive @defer on FIELD + type __Directive { name: String! description: String