diff --git a/.gitignore b/.gitignore index b7444c516d..b88f6522e1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ apollo.config.js schema.graphql .env +.mypy_cache .tox diff --git a/graphene_tornado/__init__.py b/graphene_tornado/__init__.py index 192acd6513..28dd61d767 100644 --- a/graphene_tornado/__init__.py +++ b/graphene_tornado/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.6.1' +__version__ = '3.0.0.b0' __all__ = [ '__version__' diff --git a/graphene_tornado/apollo_tooling/operation_id.py b/graphene_tornado/apollo_tooling/operation_id.py index ee9cd7e2b6..2babefe33d 100644 --- a/graphene_tornado/apollo_tooling/operation_id.py +++ b/graphene_tornado/apollo_tooling/operation_id.py @@ -3,9 +3,11 @@ """ from graphene_tornado.apollo_tooling.transforms import print_with_reduced_whitespace, sort_ast, remove_aliases, hide_literals, \ drop_unused_definitions +from graphql.language.ast import DocumentNode +from typing import Optional -def default_engine_reporting_signature(ast, operation_name): +def default_engine_reporting_signature(ast: DocumentNode, operation_name: str) -> str: """ The engine reporting signature function consists of removing extra whitespace, sorting the AST in a deterministic manner, hiding literals, and removing diff --git a/graphene_tornado/apollo_tooling/query_hash.py b/graphene_tornado/apollo_tooling/query_hash.py index 23b0e8ce6f..b70306ec12 100644 --- a/graphene_tornado/apollo_tooling/query_hash.py +++ b/graphene_tornado/apollo_tooling/query_hash.py @@ -1,7 +1,7 @@ import hashlib -def compute(query): +def compute(query: str) -> str: # type (str) -> str """ Computes the query hash via SHA-256. diff --git a/graphene_tornado/apollo_tooling/seperate_operations.py b/graphene_tornado/apollo_tooling/seperate_operations.py deleted file mode 100644 index e7835fb459..0000000000 --- a/graphene_tornado/apollo_tooling/seperate_operations.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Backported from graphql-core-next - -https://github.com/graphql-python/graphql-core-next/blob/master/src/graphql/utilities/separate_operations.py -""" - -from collections import defaultdict - -from graphql.language.ast import Document, OperationDefinition, FragmentDefinition, FragmentSpread -from graphql.language.visitor import Visitor, visit -from typing import Dict, Set - -__all__ = ["separate_operations"] - - -DepGraph = Dict[str, Set[str]] - - -def separate_operations(document_ast): - """Separate operations in a given AST document. - This function accepts a single AST document which may contain many operations and - fragments and returns a collection of AST documents each of which contains a single - operation as well the fragment definitions it refers to. - """ - - # Populate metadata and build a dependency graph. - visitor = SeparateOperations() - visit(document_ast, visitor) - operations = visitor.operations - fragments = visitor.fragments - positions = visitor.positions - dep_graph = visitor.dep_graph - - # For each operation, produce a new synthesized AST which includes only what is - # necessary for completing that operation. - separated_document_asts = {} - for operation in operations: - operation_name = op_name(operation) - dependencies = set() - collect_transitive_dependencies(dependencies, dep_graph, operation_name) - - # The list of definition nodes to be included for this operation, sorted to - # retain the same order as the original document. - definitions = [operation] - for name in dependencies: - definitions.append(fragments[name]) - definitions.sort(key=lambda n: positions.get(n, 0)) - - separated_document_asts[operation_name] = Document(definitions=definitions) - - return separated_document_asts - - -class SeparateOperations(Visitor): - - def __init__(self): - super(SeparateOperations, self).__init__() - self.operations = [] - self.fragments = {} - self.positions = {} - self.dep_graph = defaultdict(set) - self.from_name = None - self.idx = 0 - - def enter(self, node, key, parent, path, ancestors): - if isinstance(node, OperationDefinition): - self.from_name = op_name(node) - self.operations.append(node) - self.positions[node] = self.idx - self.idx += 1 - elif isinstance(node, FragmentDefinition): - self.from_name = node.name.value - self.fragments[self.from_name] = node - self.positions[node] = self.idx - self.idx += 1 - elif isinstance(node, FragmentSpread): - to_name = node.name.value - self.dep_graph[self.from_name].add(to_name) - return node - - -def op_name(operation): - """Provide the empty string for anonymous operations.""" - return operation.name.value if operation.name else "" - - -def collect_transitive_dependencies(collected, dep_graph, from_name): - """Collect transitive dependencies. - From a dependency graph, collects a list of transitive dependencies by recursing - through a dependency graph. - """ - immediate_deps = dep_graph[from_name] - for to_name in immediate_deps: - if to_name not in collected: - collected.add(to_name) - collect_transitive_dependencies(collected, dep_graph, to_name) diff --git a/graphene_tornado/apollo_tooling/transforms.py b/graphene_tornado/apollo_tooling/transforms.py index 774666b3ee..662eecb75c 100644 --- a/graphene_tornado/apollo_tooling/transforms.py +++ b/graphene_tornado/apollo_tooling/transforms.py @@ -4,16 +4,26 @@ import re import six +from graphql import DirectiveNode +from graphql import DocumentNode +from graphql import FieldNode +from graphql import FloatValueNode +from graphql import FragmentDefinitionNode +from graphql import FragmentSpreadNode +from graphql import InlineFragmentNode +from graphql import IntValueNode +from graphql import ListValueNode +from graphql import ObjectValueNode +from graphql import OperationDefinitionNode from graphql import print_ast -from graphql.language.ast import Document, IntValue, FloatValue, StringValue, ListValue, ObjectValue, Field, \ - Directive, FragmentDefinition, InlineFragment, FragmentSpread, SelectionSet, OperationDefinition +from graphql import SelectionSetNode +from graphql import separate_operations +from graphql import StringValueNode from graphql.language.visitor import Visitor, visit -from graphene_tornado.apollo_tooling.seperate_operations import separate_operations - -def hide_literals(ast: Document) -> Document: +def hide_literals(ast: DocumentNode) -> DocumentNode: """ Replace numeric, string, list, and object literals with "empty" values. Leaves enums alone (since there's no consistent "zero" enum). This @@ -26,7 +36,7 @@ def hide_literals(ast: Document) -> Document: return ast -def hide_string_and_numeric_literals(ast: Document) -> Document: +def hide_string_and_numeric_literals(ast: DocumentNode) -> DocumentNode: """ In the same spirit as the similarly named `hideLiterals` function, only hide string and numeric literals. @@ -35,7 +45,7 @@ def hide_string_and_numeric_literals(ast: Document) -> Document: return ast -def drop_unused_definitions(ast: Document, operation_name: str) -> Document: +def drop_unused_definitions(ast: DocumentNode, operation_name: str) -> DocumentNode: """ A GraphQL query may contain multiple named operations, with the operation to use specified separately by the client. This transformation drops unused @@ -49,7 +59,7 @@ def drop_unused_definitions(ast: Document, operation_name: str) -> Document: return separated -def sort_ast(ast: Document) -> Document: +def sort_ast(ast: DocumentNode) -> DocumentNode: """ sortAST sorts most multi-child nodes alphabetically. Using this as part of your signature calculation function may make it easier to tell the difference @@ -61,7 +71,7 @@ def sort_ast(ast: Document) -> Document: return ast -def remove_aliases(ast: Document) -> Document: +def remove_aliases(ast: DocumentNode) -> DocumentNode: """ removeAliases gets rid of GraphQL aliases, a feature by which you can tell a server to return a field's data under a different name from the field @@ -72,7 +82,7 @@ def remove_aliases(ast: Document) -> Document: return ast -def print_with_reduced_whitespace(ast: Document) -> str: +def print_with_reduced_whitespace(ast: DocumentNode) -> str: """ Like the graphql-js print function, but deleting whitespace wherever feasible. Specifically, all whitespace (outside of string literals) is @@ -115,15 +125,15 @@ def __init__(self, only_string_and_numeric=False): self._only_string_and_numeric = only_string_and_numeric def enter(self, node, key, parent, path, ancestors): - if isinstance(node, IntValue): + if isinstance(node, IntValueNode): node.value = 0 - elif isinstance(node, FloatValue): + elif isinstance(node, FloatValueNode): node.value = 0 - elif isinstance(node, StringValue): + elif isinstance(node, StringValueNode): node.value = "" - elif not self._only_string_and_numeric and isinstance(node, ListValue): + elif not self._only_string_and_numeric and isinstance(node, ListValueNode): node.values = [] - elif not self._only_string_and_numeric and isinstance(node, ObjectValue): + elif not self._only_string_and_numeric and isinstance(node, ObjectValueNode): node.fields = [] return node @@ -135,7 +145,7 @@ def enter(self, node, key, parent, path, ancestors): class _RemoveAliasesVisitor(Visitor): def enter(self, node, key, parent, path, ancestors): - if isinstance(node, Field): + if isinstance(node, FieldNode): node.alias = None return node @@ -146,7 +156,7 @@ def enter(self, node, key, parent, path, ancestors): class _HexConversionVisitor(Visitor): def enter(self, node, key, parent, path, ancestors): - if isinstance(node, StringValue) and node.value is not None: + if isinstance(node, StringValueNode) and node.value is not None: if six.PY3: encoded = node.value.encode('utf-8').hex() else: @@ -161,26 +171,26 @@ def enter(self, node, key, parent, path, ancestors): class _SortingVisitor(Visitor): def enter(self, node, key, parent, path, ancestors): - if isinstance(node, Document): + if isinstance(node, DocumentNode): node.definitions = _sorted(node.definitions, lambda x: (x.__class__.__name__, self._by_name(x))) - elif isinstance(node, OperationDefinition): + elif isinstance(node, OperationDefinitionNode): node.variable_definitions = _sorted(node.variable_definitions, self._by_variable_name) - elif isinstance(node, SelectionSet): + elif isinstance(node, SelectionSetNode): node.selections = _sorted(node.selections, lambda x: (x.__class__.__name__, self._by_name(x))) - elif isinstance(node, Field): + elif isinstance(node, FieldNode): node.arguments = _sorted(node.arguments, self._by_name) - elif isinstance(node, FragmentSpread): + elif isinstance(node, FragmentSpreadNode): node.directives = _sorted(node.directives, self._by_name) - elif isinstance(node, InlineFragment): + elif isinstance(node, InlineFragmentNode): node.directives = _sorted(node.directives, self._by_type_definition) - elif isinstance(node, FragmentDefinition): + elif isinstance(node, FragmentDefinitionNode): node.directives = _sorted(node.directives, self._by_name) - elif isinstance(node, Directive): + elif isinstance(node, DirectiveNode): node.arguments = _sorted(node.arguments, self._by_name) return node def _by_name(self, node): - if isinstance(node, InlineFragment): + if isinstance(node, InlineFragmentNode): return self._by_type_definition(node) elif node.name is not None: return node.name.value diff --git a/graphene_tornado/compiler/execution.py b/graphene_tornado/compiler/execution.py new file mode 100644 index 0000000000..69565ffc60 --- /dev/null +++ b/graphene_tornado/compiler/execution.py @@ -0,0 +1,407 @@ +import json +from dataclasses import dataclass +from typing import Any, Optional, List, Callable, Dict + +from graphql import ( + GraphQLObjectType, + GraphQLError, + GraphQLSchema, + TypeNameMetaFieldDef, +) +from graphql.execution.utils import ( + ExecutionContext as GraphQLContext, + get_operation_root_type, + collect_fields, +) +from graphql.language.ast import ( + OperationDefinition, + FragmentDefinition, + Document, + Field, +) +from graphql.pyutils.default_ordered_dict import DefaultOrderedDict + +from graphene_tornado.compiler.json import query_to_json_schema, fast_json +from graphene_tornado.compiler.path import Path +from graphene_tornado.compiler.variables import compile_variable_parsing + +SAFETY_CHECK_PREFIX = "__validNode" +GLOBAL_DATA_NAME = "__context.data" +GLOBAL_ERRORS_NAME = "__context.errors" +GLOBAL_NULL_ERRORS_NAME = "__context.nullErrors" +GLOBAL_ROOT_NAME = "__context.rootValue" +GLOBAL_VARIABLES_NAME = "__context.variables" +GLOBAL_CONTEXT_NAME = "__context.context" +GLOBAL_EXECUTION_CONTEXT = "__context" +GLOBAL_PROMISE_COUNTER = "__context.promiseCounter" +GLOBAL_INSPECT_NAME = "__context.inspect" +GLOBAL_SAFE_MAP_NAME = "__context.safeMap" +GRAPHQL_ERROR = "__context.GraphQLError" +GLOBAL_RESOLVE = "__context.resolve" +GLOBAL_PARENT_NAME = "__parent" + + +@dataclass +class CompilerOptions(object): + custom_json_serializer: bool = False + disable_leaf_serialization: bool = False + disable_capturing_stack_errors: bool = False + custom_serializers: Dict[str, Any] = None + resolver_info_enricher: Optional[Callable] = None + + +@dataclass +class DeferredField: + name: str + responsePath: Path + originPaths: List[str] + destinationPaths: List[str] + parentType: GraphQLObjectType + fieldName: str + fieldType: Any + fieldNodes: List[Any] + args: Any + + +class CompilationContext(GraphQLContext): + resolvers: Dict[str, Any] + hoistedFunctions: List[str] + hoistedFunctionNames: Dict[str, int] + typeResolvers: Any + isTypeOfs: Any + resolveInfos: Dict[str, Any] + deferred: List[DeferredField] + options: CompilerOptions + depth: int + operation: OperationDefinition + + def __init__( + self, + schema=None, + document_ast=None, + root_value=None, + context_value=None, + variable_values=None, + operation_name=None, + executor=None, + middleware=None, + allow_subscriptions=None, + fragments=None, + resolvers=None, + hoistedFunctions=None, + hostedFunctionNames=None, + typeResolvers=None, + isTypeOfs=None, + resolveInfos=None, + deferred=None, + options=None, + depth=None, + operation=None, + serializers=None, + ): + # super().__init__(schema, document_ast, root_value, context_value, variable_values, operation_name, executor, + # middleware, allow_subscriptions) + self.resolvers = resolvers + self.hoistedFunctions = hoistedFunctions + self.hoistedFunctionNames = hostedFunctionNames + self.typeResolvers = typeResolvers + self.isTypeOfs = isTypeOfs + self.resolveInfos = resolveInfos + self.deferred = deferred + self.options = options + self.depth = depth + self.fragments = fragments + self.operation = operation + self.serializers = serializers + + +class ExecutionContext: + promiseCounter: int + data: Any + errors: Any + nullErrors: Any + resolve: Optional[Callable] + inspect: Any + variables: Any + context: Any + rootValue: Any + safeMap: Any + GraphQLError: Any + resolvers: Any + trimmer: Any + serializers: Any + typeResolvers: Any + isTypeOfs: Any + resolveInfos: Any + + +class CompiledQuery: + operationName: Optional[str] + query: Any + stringify: Any + + +class ObjectPath: + pass +# prev: ObjectPath | undefined; +# key: string; +# type: ResponsePathType; +# } + + +def compile_query( + schema: GraphQLSchema, + document: Document, + operation_name: Optional[str] = None, + options: Optional[str] = None, +) -> CompiledQuery: + if not options: + options = CompilerOptions() + + if not schema: + raise ValueError("schema") + + if not document: + raise ValueError("document") + + context = build_compilation_context(schema, document, options, operation_name) + + if options.custom_json_serializer: + json_schema = query_to_json_schema(context) + stringify = fast_json(json_schema) + else: + stringify = json.dumps + + get_variables = compile_variable_parsing( + schema, context.operation.variable_definitions or [] + ) + + function_body = compile_operation(context) + compiled_query = { + "query": create_bound_query( + context, + document, + None, + get_variables, + context.operation.name.value if context.operation.name else None, + ) + } + + return compiled_query + + +def build_compilation_context( + schema, document: Document, options, operation_name +) -> CompilationContext: + errors = [] + operation = None + has_multiple_assumed_operations = [] + fragments = {} + for definition in document.definitions: + if isinstance(definition, OperationDefinition): + if operation_name and operation: + has_multiple_assumed_operations = True + elif not operation_name or ( + definition.name and definition.name.value == operation_name + ): + operation = definition + elif isinstance(definition, FragmentDefinition): + fragments[definition.name.value] = definition + + if not operation: + if operation_name: + raise GraphQLError(f'Unknown operation named "{operation_name}".') + else: + raise GraphQLError("Must provide dan operation") + + return CompilationContext( + schema=schema, + fragments=fragments, + root_value=None, + context_value=None, + operation=operation, + resolvers={}, + serializers={}, + typeResolvers={}, + isTypeOfs={}, + resolveInfos={}, + ) + + +def compile_operation(context: CompilationContext) -> str: + type = get_operation_root_type(context.schema, context.operation) + serial_execution = context.operation.operation == "mutation" + field_map = collect_fields( + context, type, context.operation.selection_set, DefaultOrderedDict(), set(), + ) + top_level = compile_object_type( + context, + type, + [], + [GLOBAL_ROOT_NAME], + [GLOBAL_ROOT_NAME], + None, + GLOBAL_ERRORS_NAME, + field_map, + True, + ) + body = """ +function query (${GLOBAL_EXECUTION_CONTEXT}) { + "use strict"; +""" + if serial_execution: + body += f"${GLOBAL_EXECUTION_CONTEXT}.queue = [];" + body += generate_unique_declarations(context, True) + body += f"{GLOBAL_DATA_NAME} = {top_level}\n" + if serial_execution: + body += compile_deferred_fields_serially(context) + body += """ + ${GLOBAL_EXECUTION_CONTEXT}.finalResolve = () => {}; + ${GLOBAL_RESOLVE} = (context) => { + if (context.jobCounter >= context.queue.length) { + // All mutations have finished + context.finalResolve(context); + return; + } + context.queue[context.jobCounter++](context); + }; + // There might not be a job to run due to invalid queries + if (${GLOBAL_EXECUTION_CONTEXT}.queue.length > 0) { + ${GLOBAL_EXECUTION_CONTEXT}.jobCounter = 1; // since the first one will be run manually + ${GLOBAL_EXECUTION_CONTEXT}.queue[0](${GLOBAL_EXECUTION_CONTEXT}); + } + // Promises have been scheduled so a new promise is returned + // that will be resolved once every promise is done + if (${GLOBAL_PROMISE_COUNTER} > 0) { + return new Promise(resolve => ${GLOBAL_EXECUTION_CONTEXT}.finalResolve = resolve); + } +""" + else: + body += compile_deferred_fields(context) + body += """ + // Promises have been scheduled so a new promise is returned + // that will be resolved once every promise is done + if (${GLOBAL_PROMISE_COUNTER} > 0) { + return new Promise(resolve => ${GLOBAL_RESOLVE} = resolve); + """ + body += """ + // sync execution, the results are ready + return undefined; + """ + body += context.hoistedFunctions.join("\n") + return body + + +def add_path(response_path, name): + pass + + +def resolve_field_def(context, type, field_nodes): + pass + + +def get_argument_defs(field, param): + pass + + +def compile_object_type( + context: CompilationContext, + type: GraphQLObjectType, + field_nodes: List[Field], + origin_paths: List[str], + destination_paths: List[str], + response_path: Optional[str], + error_destination: str, + field_map: Dict[str, List[Field]], + always_defer: bool, +) -> str: + body = "(" + if isinstance(type.is_type_of, Callable) and not always_defer: + context.isTypeOfs[type.name + "IsTypeOf"] = type.is_type_of + body += """ +!${GLOBAL_EXECUTION_CONTEXT}.isTypeOfs["${ + type.name + }IsTypeOf"](${originPaths.join( + "." + )}) ? (${errorDestination}.push(${createErrorObject( + context, + fieldNodes, + responsePath as any, + ``Expected value of type "${ + type.name + }" but got: ${${GLOBAL_INSPECT_NAME}(${originPaths.join(".")})}.`` + )}), null) : + """ + body += "{" + for name, field_nodes in field_map.items(): + field = resolve_field_def(context, type, field_nodes) + if not field: + continue + body += "${name}, " + + if field == TypeNameMetaFieldDef: + body += "${type.name}, " + continue + + resolver = field.resolve + if not resolver and always_defer: + field_name = field.name + resolver = lambda parent: parent and parent[field_name] + + if resolver: + context.deferred.append( + DeferredField( + name=name, + responsePath=add_path(response_path, name), + originPaths=origin_paths, + destinationPaths=destination_paths, + parentType=type, + fieldName=field.name, + fieldType=field.return_type, + fieldNodes=field_nodes, + args=get_argument_defs(field, field_nodes[0]), + ) + ) + else: + body += compile_type( + context, + field.type, + field_nodes, + origin_paths, + destination_paths, + add_path(response_path, name), + ) + body += "," + body += "})" + return body + + +def compile_deferred_fields(context: CompilationContext): + pass + + +def generate_unique_declarations(context: CompilationContext, param): + pass + + +def create_bound_query( + context: CompilationContext, + document: Document, + param, + get_variables, + operation_name: Optional[str] = None, +) -> Any: + pass + + +def is_compiled_query(prepared) -> bool: + pass + + +def compile_deferred_fields_serially(context: CompilationContext): + pass + + +def compile_type(context: CompilationContext, parent_type: GraphQLContext, field_nodes: List[Field], origin_paths: +List[str], destination_paths: List[str], previous_path: ObjectPath) -> str: + pass diff --git a/graphene_tornado/ext/apollo_engine_reporting/engine_agent.py b/graphene_tornado/ext/apollo_engine_reporting/engine_agent.py index bbeae377b6..45a0c66be4 100644 --- a/graphene_tornado/ext/apollo_engine_reporting/engine_agent.py +++ b/graphene_tornado/ext/apollo_engine_reporting/engine_agent.py @@ -1,21 +1,21 @@ -from __future__ import absolute_import, print_function - import gzip import logging import os import socket import sys -from typing import NamedTuple, Optional, Callable +from typing import Callable +from typing import NamedTuple +from typing import Optional -import six from google.protobuf.json_format import MessageToJson from google.protobuf.message import Message -from six import StringIO, BytesIO +from six import BytesIO from tornado.httpclient import AsyncHTTPClient from tornado_retry_client import RetryClient from graphene_tornado.apollo_tooling.operation_id import default_engine_reporting_signature -from .reports_pb2 import ReportHeader, FullTracesReport +from .reports_pb2 import FullTracesReport +from .reports_pb2 import ReportHeader LOGGER = logging.getLogger(__name__) @@ -51,21 +51,21 @@ ('schema_tag', Optional[str]), ('generate_client_info', Optional[GenerateClientInfo]) ]) -EngineReportingOptions.__new__.__defaults__ = (None,) * len(EngineReportingOptions._fields) +EngineReportingOptions.__new__.__defaults__ = (None,) * len(EngineReportingOptions._fields) # type: ignore def _serialize(message: Message) -> bytes: - out = BytesIO() if six.PY3 else StringIO() + out = BytesIO() with gzip.GzipFile(fileobj=out, mode="w") as f: f.write(message.SerializeToString()) return out.getvalue() -def _get_trace_signature(operation_name, document_ast, query_string): - if not document_ast: +def _get_trace_signature(operation_name, document, query_string): + if not document: return query_string else: - return default_engine_reporting_signature(document_ast, operation_name) + return default_engine_reporting_signature(document, operation_name) class EngineReportingAgent: @@ -101,13 +101,13 @@ def __init__(self, options: EngineReportingOptions, schema_hash: str) -> None: def _options(self) -> EngineReportingOptions: return self.options - async def add_trace(self, operation_name, document_ast, query_string, trace): + async def add_trace(self, operation_name, document, query_string, trace): operation_name = operation_name or '-' if self._stopped: return - signature = _get_trace_signature(operation_name, document_ast, query_string) + signature = _get_trace_signature(operation_name, document, query_string) stats_report_key = "# " + operation_name + '\n' + signature traces_per_query = self.report.traces_per_query.get(stats_report_key, None) if not traces_per_query: diff --git a/graphene_tornado/ext/apollo_engine_reporting/engine_extension.py b/graphene_tornado/ext/apollo_engine_reporting/engine_extension.py index f8b1051805..fef551ea32 100644 --- a/graphene_tornado/ext/apollo_engine_reporting/engine_extension.py +++ b/graphene_tornado/ext/apollo_engine_reporting/engine_extension.py @@ -3,9 +3,15 @@ import json import time from numbers import Number +from typing import Any from typing import Callable, NamedTuple +from typing import cast +from typing import List +from typing import Optional +from typing import Union from google.protobuf.timestamp_pb2 import Timestamp +from graphql.pyutils import Path from tornado.httputil import HTTPServerRequest from graphene_tornado.ext.apollo_engine_reporting.engine_agent import EngineReportingOptions @@ -32,13 +38,13 @@ def generate_client_info(request: HTTPServerRequest) -> ClientInfo: ) -def response_path_as_string(path): - if not path: +def response_path_as_string(path: Optional[List[Union[str, int]]]) -> str: + if not path or len(path) == 0: return '' - return '.'.join((str(x) for x in path)) + return '.'.join([str(p) for p in path]) -def now_ns(): +def now_ns() -> int: return time.time_ns() @@ -59,7 +65,7 @@ def __init__(self, options: EngineReportingOptions, add_trace: Callable) -> None self.trace = Trace(root=root) self.nodes = {response_path_as_string(None): root} self.generate_client_info = options.generate_client_info or generate_client_info - self.resolver_stats = list() + self.resolver_stats: List[Any] = list() async def request_started(self, request, query_string, parsed_query, operation_name, variables, context, request_context): self.trace.start_time.GetCurrentTime() @@ -82,9 +88,7 @@ async def on_request_ended(errors): op_name = self.operation_name or '' self.trace.root.MergeFrom(self.nodes.get('')) - document = request_context.get('document', None) - document_ast = document.document_ast if document else None - await self.add_trace(op_name, document_ast, self.query_string, self.trace) + await self.add_trace(op_name, request_context.get('document', None), self.query_string, self.trace) return on_request_ended @@ -143,27 +147,28 @@ def _get_http_method(self, request): except: return Trace.HTTP.UNKNOWN - def _new_node(self, path): + def _new_node(self, path: Path): node = Trace.Node() - id = path[-1] - if isinstance(id, Number): + path_list = path.as_list() + + id = path_list[-1] + if isinstance(id, int): node.index = id else: - node.response_name = id + node.response_name = cast(str, id) - self.nodes[response_path_as_string(path)] = node + self.nodes[response_path_as_string(path_list)] = node parent_node = self._ensure_parent_node(path) n = parent_node.child.add() n.MergeFrom(node) - self.nodes[response_path_as_string(path)] = n + self.nodes[response_path_as_string(path_list)] = n return n - def _ensure_parent_node(self, path): - prev = [''] if len(path) == 1 else path[:-1] - parent_path = response_path_as_string(prev) + def _ensure_parent_node(self, path: Path): + parent_path = response_path_as_string(path.prev) parent_node = self.nodes.get(parent_path, None) if parent_node: return parent_node - return self._new_node(prev) + return self._new_node(path.prev) diff --git a/graphene_tornado/ext/apollo_engine_reporting/schema_utils.py b/graphene_tornado/ext/apollo_engine_reporting/schema_utils.py index cbccf09b61..5ea446a222 100644 --- a/graphene_tornado/ext/apollo_engine_reporting/schema_utils.py +++ b/graphene_tornado/ext/apollo_engine_reporting/schema_utils.py @@ -1,19 +1,24 @@ -from __future__ import absolute_import, print_function +from __future__ import absolute_import +from __future__ import print_function import hashlib +from typing import cast -from graphene import Schema -from graphql import parse, execute, GraphQLError -from graphql.utils.introspection_query import introspection_query +from graphene.types.schema import introspection_query +from graphql import execute +from graphql import ExecutionResult +from graphql import GraphQLError +from graphql import GraphQLSchema +from graphql import parse from json_stable_stringify_python import stringify -def generate_schema_hash(schema: Schema) -> str: +def generate_schema_hash(schema: GraphQLSchema) -> str: """ Generates a stable hash of the current schema using an introspection query. """ ast = parse(introspection_query) - result = execute(schema, ast) + result = cast(ExecutionResult, execute(schema, ast)) if result and not result.data: raise GraphQLError('Unable to generate server introspection document') diff --git a/graphene_tornado/ext/apollo_engine_reporting/tests/snapshots/__init__.py b/graphene_tornado/ext/apollo_engine_reporting/tests/snapshots/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/graphene_tornado/ext/apollo_engine_reporting/tests/snapshots/snap_test_engine_extension.py b/graphene_tornado/ext/apollo_engine_reporting/tests/snapshots/snap_test_engine_extension.py new file mode 100644 index 0000000000..03e7081673 --- /dev/null +++ b/graphene_tornado/ext/apollo_engine_reporting/tests/snapshots/snap_test_engine_extension.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import Snapshot + + +snapshots = Snapshot() + +snapshots['test_can_send_report_to_engine 1'] = '''{ + "durationNs": "-1", + "endTime": "2019-09-21T19:37:09.908919Z", + "http": { + "method": "GET" + }, + "root": { + "child": [ + { + "endTime": "-1", + "parentType": "Query", + "responseName": "author", + "startTime": "-1", + "type": "User" + }, + { + "endTime": "-1", + "parentType": "Query", + "responseName": "aBoolean", + "startTime": "-1", + "type": "Boolean" + }, + { + "child": [ + { + "endTime": "-1", + "parentType": "User", + "responseName": "name", + "startTime": "-1", + "type": "String" + } + ], + "responseName": "author" + }, + { + "child": [ + { + "endTime": "-1", + "parentType": "User", + "responseName": "posts", + "startTime": "-1", + "type": "[Post]" + } + ], + "responseName": "author" + }, + { + "child": [ + { + "child": [ + { + "child": [ + { + "endTime": "-1", + "parentType": "Post", + "responseName": "id", + "startTime": "-1", + "type": "Int" + } + ], + "index": 0 + } + ], + "responseName": "posts" + } + ], + "responseName": "author" + }, + { + "child": [ + { + "child": [ + { + "child": [ + { + "endTime": "-1", + "parentType": "Post", + "responseName": "id", + "startTime": "-1", + "type": "Int" + } + ], + "index": 1 + } + ], + "responseName": "posts" + } + ], + "responseName": "author" + } + ], + "endTime": "-1", + "startTime": "-1" + }, + "startTime": "2019-09-21T19:37:09.908919Z" +}''' diff --git a/graphene_tornado/ext/apollo_engine_reporting/tests/test_engine_extension.py b/graphene_tornado/ext/apollo_engine_reporting/tests/test_engine_extension.py index 0705b6d779..9f1e157da9 100644 --- a/graphene_tornado/ext/apollo_engine_reporting/tests/test_engine_extension.py +++ b/graphene_tornado/ext/apollo_engine_reporting/tests/test_engine_extension.py @@ -1,7 +1,6 @@ import re import pytest -import six import tornado from google.protobuf.json_format import MessageToJson from graphql import parse @@ -10,80 +9,11 @@ from graphene_tornado.ext.apollo_engine_reporting.engine_extension import EngineReportingExtension from graphene_tornado.ext.apollo_engine_reporting.tests.schema import schema from graphene_tornado.tests.http_helper import HttpHelper -from graphene_tornado.tests.test_graphql import GRAPHQL_HEADER, url_string, response_json +from graphene_tornado.tests.test_graphql import GRAPHQL_HEADER +from graphene_tornado.tests.test_graphql import response_json +from graphene_tornado.tests.test_graphql import url_string from graphene_tornado.tornado_graphql_handler import TornadoGraphQLHandler -expected = """{ - "durationNs": "-1", - "endTime": "2019-09-21T19:37:09.908919Z", - "http": { - "method": "GET" - }, - "root": { - "child": [ - { - "child": [ - { - "endTime": "-1", - "parentType": "User", - "responseName": "name", - "startTime": "-1", - "type": "String" - }, - { - "child": [ - { - "child": [ - { - "endTime": "-1", - "parentType": "Post", - "responseName": "id", - "startTime": "-1", - "type": "Int" - } - ], - "index": 0 - }, - { - "child": [ - { - "endTime": "-1", - "parentType": "Post", - "responseName": "id", - "startTime": "-1", - "type": "Int" - } - ], - "index": 1 - } - ], - "endTime": "-1", - "parentType": "User", - "responseName": "posts", - "startTime": "-1", - "type": "[Post]" - } - ], - "endTime": "-1", - "parentType": "Query", - "responseName": "author", - "startTime": "-1", - "type": "User" - }, - { - "endTime": "-1", - "parentType": "Query", - "responseName": "aBoolean", - "startTime": "-1", - "type": "Boolean" - } - ], - "endTime": "-1", - "startTime": "-1" - }, - "startTime": "2019-09-21T19:37:09.908919Z" -}""" - QUERY = """ query { @@ -129,7 +59,7 @@ def http_helper(http_client, base_url): @pytest.mark.gen_test() -def test_can_send_report_to_engine(http_helper): +def test_can_send_report_to_engine(http_helper, snapshot): response = yield http_helper.get(url_string(query=QUERY), headers=GRAPHQL_HEADER) assert response.code == 200 assert 'data' in response_json(response) @@ -147,7 +77,4 @@ def test_can_send_report_to_engine(http_helper): '"2019-09-21T19:37:09.908919Z"', trace_json) - e = expected - if six.PY3: - e = re.sub(r'\s+\n', '\n', expected) - assert e == trace_json + snapshot.assert_match(trace_json) diff --git a/graphene_tornado/ext/extension_helpers.py b/graphene_tornado/ext/extension_helpers.py index f19bc69a68..e36dbe2181 100644 --- a/graphene_tornado/ext/extension_helpers.py +++ b/graphene_tornado/ext/extension_helpers.py @@ -45,7 +45,7 @@ def get_signature(request_context, operation_name, document, query_string): if signature is None: if document: calculate_signature = default_engine_reporting_signature - signature = calculate_signature(document.document_ast, operation_name) + signature = calculate_signature(document, operation_name) elif query_string: signature = query_string request_context[SIGNATURE] = signature diff --git a/graphene_tornado/ext/opencensus/opencensus_tracing_extension.py b/graphene_tornado/ext/opencensus/opencensus_tracing_extension.py index b8ac9b8fe7..32ca89a823 100644 --- a/graphene_tornado/ext/opencensus/opencensus_tracing_extension.py +++ b/graphene_tornado/ext/opencensus/opencensus_tracing_extension.py @@ -82,7 +82,7 @@ async def will_resolve_field(self, root, info, **args): # If we wanted to be fancy, we could build up a tree like the ApolloEngineExtension does so that the # whole tree appears as nested spans. However, this is a bit tricky to do with the current OpenCensus # API because when you request a span, a bunch of context variables are set. This keeps it simple for now. - tracer.start_span('.'.join(str(x) for x in info.path)) + tracer.start_span('.'.join(info.path.as_list())) async def on_end(errors=None, result=None): tracer.end_span() diff --git a/graphene_tornado/ext/opencensus/tests/test_opencensus_tracing.py b/graphene_tornado/ext/opencensus/tests/test_opencensus_tracing.py index 0b45a3c520..98024aa5a0 100644 --- a/graphene_tornado/ext/opencensus/tests/test_opencensus_tracing.py +++ b/graphene_tornado/ext/opencensus/tests/test_opencensus_tracing.py @@ -2,7 +2,7 @@ import pytest import tornado -from graphql import parse, GraphQLBackend +from graphql import parse from opencensus.trace import tracer as tracer_module, execution_context from opencensus.trace.base_exporter import Exporter from opencensus.trace.propagation.google_cloud_format import GoogleCloudFormatPropagator @@ -20,11 +20,11 @@ class GQLHandler(TornadoGraphQLHandler): - def initialize(self, schema=None, executor=None, middleware: Optional[Any] = None, root_value: Any = None, - graphiql: bool = False, pretty: bool = False, batch: bool = False, backend: GraphQLBackend = None, + def initialize(self, schema=None, middleware: Optional[Any] = None, root_value: Any = None, + graphiql: bool = False, pretty: bool = False, batch: bool = False, extensions: List[Union[Callable[[], GraphQLExtension], GraphQLExtension]] = None, exporter=None): - super().initialize(schema, executor, middleware, root_value, graphiql, pretty, batch, backend, extensions) + super().initialize(schema, middleware, root_value, graphiql, pretty, batch, extensions) execution_context.set_opencensus_tracer(tracer_module.Tracer( sampler=AlwaysOnSampler(), exporter=exporter, diff --git a/graphene_tornado/graphql_extension.py b/graphene_tornado/graphql_extension.py index 16a1ed651f..75e613be14 100644 --- a/graphene_tornado/graphql_extension.py +++ b/graphene_tornado/graphql_extension.py @@ -4,16 +4,22 @@ Extensions are also middleware but have additional hooks. """ -from __future__ import absolute_import, print_function +from __future__ import absolute_import +from __future__ import print_function -from abc import ABCMeta, abstractmethod -from typing import NewType, List, Callable, Optional, Dict, Any +from abc import ABCMeta +from abc import abstractmethod +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from graphql import DocumentNode from graphql import GraphQLSchema -from graphql.language.ast import Document from tornado.httputil import HTTPServerRequest -EndHandler = NewType('EndHandler', Optional[List[Callable[[List[Exception]], None]]]) +EndHandler = Optional[List[Callable[[List[Exception]], None]]] class GraphQLExtension: @@ -24,7 +30,7 @@ class GraphQLExtension: def request_started(self, request: HTTPServerRequest, query_string: Optional[str], - parsed_query: Optional[Document], + parsed_query: Optional[DocumentNode], operation_name: Optional[str], variables: Optional[Dict[str, Any]], context: Any, @@ -43,7 +49,7 @@ def validation_started(self) -> EndHandler: @abstractmethod def execution_started(self, schema: GraphQLSchema, - document: Document, + document: DocumentNode, root: Any, context: Optional[Any], variables: Optional[Any], diff --git a/graphene_tornado/schema.py b/graphene_tornado/schema.py index 296d272058..3582664f7f 100644 --- a/graphene_tornado/schema.py +++ b/graphene_tornado/schema.py @@ -3,6 +3,8 @@ import graphene from graphene import ObjectType, Schema from tornado.escape import to_unicode +from graphql.type.definition import GraphQLResolveInfo +from typing import Optional class QueryRoot(ObjectType): @@ -14,17 +16,17 @@ class QueryRoot(ObjectType): def resolve_thrower(self, info): raise Exception("Throws!") - def resolve_request(self, info): + def resolve_request(self, info: GraphQLResolveInfo) -> str: return to_unicode(info.context.arguments['q'][0]) - def resolve_test(self, info, who=None): + def resolve_test(self, info: GraphQLResolveInfo, who: Optional[str]=None) -> str: return 'Hello %s' % (who or 'World') class MutationRoot(ObjectType): write_test = graphene.Field(QueryRoot) - def resolve_write_test(self, info): + def resolve_write_test(self, info: GraphQLResolveInfo) -> QueryRoot: return QueryRoot() diff --git a/graphene_tornado/tests/test_graphiql.py b/graphene_tornado/tests/test_graphiql.py index 8acdc68f78..76971e40c1 100644 --- a/graphene_tornado/tests/test_graphiql.py +++ b/graphene_tornado/tests/test_graphiql.py @@ -58,19 +58,6 @@ def test_graphiql_renders_pretty(http_helper): assert pretty_response in to_unicode(response.body) -@pytest.mark.gen_test -def test_graphiql_renders_pretty(http_helper): - response = yield http_helper.get('/graphql?query={test}', headers={'Accept': 'text/html'}) - pretty_response = ( - '{\n' - ' "data": {\n' - ' "test": "Hello World"\n' - ' }\n' - '}' - ).replace("\"", "\\\"").replace("\n", "\\n") - assert pretty_response in to_unicode(response.body) - - @pytest.mark.gen_test def test_handles_empty_vars(http_helper): response = yield http_helper.post_json('/graphql', headers={'Accept': 'text/html'}, post_data=dict( diff --git a/graphene_tornado/tests/test_graphql.py b/graphene_tornado/tests/test_graphql.py index 33b71b8eaf..028ce6b337 100644 --- a/graphene_tornado/tests/test_graphql.py +++ b/graphene_tornado/tests/test_graphql.py @@ -82,12 +82,14 @@ def test_reports_validation_errors(http_helper): assert response_json(context.value.response) == { 'errors': [ { - 'message': 'Cannot query field "unknownOne" on type "QueryRoot".', - 'locations': [{'line': 1, 'column': 9}] + 'message': "Cannot query field 'unknownOne' on type 'QueryRoot'.", + 'locations': [{'line': 1, 'column': 9}], + 'path': None, }, { - 'message': 'Cannot query field "unknownTwo" on type "QueryRoot".', - 'locations': [{'line': 1, 'column': 21}] + 'message': "Cannot query field 'unknownTwo' on type 'QueryRoot'.", + 'locations': [{'line': 1, 'column': 21}], + 'path': None, } ] } @@ -103,11 +105,10 @@ def test_errors_when_missing_operation_name(http_helper): ''' )) - assert context.value.code == 400 assert response_json(context.value.response) == { 'errors': [ { - 'message': 'Must provide operation name if query contains multiple operations.' + 'message': 'Must provide operation name if query contains multiple operations.', } ] } @@ -147,9 +148,9 @@ def test_errors_when_selecting_a_mutation_within_a_get(http_helper): assert response_json(context.value.response) == { 'errors': [ { - 'message': 'Can only perform a mutation operation from a POST request.' + 'message': 'Can only perform a mutation operation from a POST request.', } - ] + ], } @@ -444,8 +445,9 @@ def test_handles_syntax_errors_caught_by_graphql(http_helper): assert context.value.code == 400 assert response_json(context.value.response) == { 'errors': [{'locations': [{'column': 1, 'line': 1}], - 'message': 'Syntax Error GraphQL (1:1) ' - 'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n'}] + 'message': "Syntax Error: Unexpected Name 'syntaxerror'.", + 'path': None} + ] } diff --git a/graphene_tornado/tornado_graphql_handler.py b/graphene_tornado/tornado_graphql_handler.py index 990584b756..f4342e72b9 100644 --- a/graphene_tornado/tornado_graphql_handler.py +++ b/graphene_tornado/tornado_graphql_handler.py @@ -2,19 +2,35 @@ import json import sys import traceback -from typing import List, Union, Callable, Any, Optional - -import six -from graphql import get_default_backend, execute, validate, GraphQLBackend -from graphql.error import GraphQLError +from asyncio import iscoroutinefunction +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from graphene.types.schema import Schema +from graphql import DocumentNode +from graphql import execute +from graphql import get_operation_ast +from graphql import OperationType +from graphql import parse +from graphql import validate from graphql.error import format_error as format_graphql_error -from graphql.execution import ExecutionResult -from graphql.execution.executors.asyncio import AsyncioExecutor +from graphql.error.graphql_error import GraphQLError +from graphql.error.syntax_error import GraphQLSyntaxError +from graphql.execution.execute import ExecutionResult +from graphql.pyutils import is_awaitable +from mypy_extensions import TypedDict from tornado import web -from tornado.escape import json_encode, to_unicode -from tornado.locks import Event +from tornado.escape import json_encode +from tornado.escape import to_unicode +from tornado.httputil import HTTPServerRequest from tornado.log import app_log from tornado.web import HTTPError +from typing_extensions import Awaitable from werkzeug.datastructures import MIMEAccept from werkzeug.http import parse_accept_header @@ -35,34 +51,30 @@ def __init__(self, status_code=400, errors=None): class TornadoGraphQLHandler(web.RequestHandler): - executor = None - schema = None - batch = False - middleware = [] - pretty = False - root_value = None - graphiql = False - graphiql_version = None - graphiql_template = None - graphiql_html_title = None - backend = None - document = None - graphql_params = None - parsed_body = None + schema: Schema + batch: bool = False + middleware: List[Any] = [] + pretty: bool = False + root_value: Optional[Any] = None + graphiql: bool = False + graphiql_version: Optional[str] = None + graphiql_template: Optional[str] = None + graphiql_html_title: Optional[str] = None + document: Optional[DocumentNode] + graphql_params: Optional[Tuple[Any, Any, Any, Any]] = None + parsed_body: Optional[Dict[str, Any]] = None extension_stack = GraphQLExtensionStack([]) - request_context = {} + request_context: Dict[str, Any] = {} def initialize(self, - schema=None, - executor=None, + schema: Optional[Schema]=None, middleware: Optional[Any] = None, root_value: Any = None, graphiql: bool = False, pretty: bool = False, batch: bool = False, - backend: GraphQLBackend = None, extensions: List[Union[Callable[[], GraphQLExtension], GraphQLExtension]] = None - ): + ) -> None: super(TornadoGraphQLHandler, self).initialize() self.schema = schema @@ -78,45 +90,40 @@ def initialize(self, if len(middlewares) > 0: self.middleware = middlewares - self.executor = executor self.root_value = root_value self.pretty = pretty self.graphiql = graphiql self.batch = batch - self.backend = backend or get_default_backend() @property - def context(self): + def context(self) -> HTTPServerRequest: return self.request - def get_root(self): + def get_root(self) -> Any: return self.root_value - def get_middleware(self): + def get_middleware(self) -> List[Callable]: return self.middleware - def get_backend(self): - return self.backend - - def get_document(self): + def get_document(self) -> Optional[DocumentNode]: return self.document def get_parsed_body(self): return self.parsed_body - async def get(self): + async def get(self) -> None: try: await self.run('get') except Exception as ex: self.handle_error(ex) - async def post(self): + async def post(self) -> None: try: await self.run('post') except Exception as ex: self.handle_error(ex) - async def run(self, method): + async def run(self, method: str) -> None: show_graphiql = self.graphiql and self.should_display_graphiql() if show_graphiql: @@ -152,7 +159,7 @@ async def run(self, method): self.write(result) await self.finish() - def parse_body(self): + def parse_body(self) -> Any: content_type = self.content_type if content_type == 'application/graphql': @@ -199,7 +206,7 @@ async def get_response(self, data, method, show_graphiql=False): self.context, self.request_context) try: - execution_result = await self.execute_graphql_request( + execution_result, invalid = await self.execute_graphql_request( method, query, variables, @@ -211,11 +218,8 @@ async def get_response(self, data, method, show_graphiql=False): if execution_result: response = {} - if getattr(execution_result, 'is_pending', False): - event = Event() - on_resolve = lambda *_: event.set() # noqa - await execution_result.then(on_resolve).catch(on_resolve) - await event.wait() + if is_awaitable(execution_result) or iscoroutinefunction(execution_result): + execution_result = await execution_result if hasattr(execution_result, 'get'): execution_result = execution_result.get() @@ -223,7 +227,7 @@ async def get_response(self, data, method, show_graphiql=False): if execution_result.errors: response['errors'] = [self.format_error(e) for e in execution_result.errors] - if execution_result.invalid: + if invalid: status_code = 400 else: response['data'] = execution_result.data @@ -242,48 +246,51 @@ async def get_response(self, data, method, show_graphiql=False): finally: await request_end() - async def execute_graphql_request(self, method, query, variables, operation_name, show_graphiql=False): + async def execute_graphql_request(self, method: str, query: Optional[str], variables: Optional[Dict[str, str]], operation_name: Optional[str], show_graphiql: bool = False) -> Tuple[Optional[Union[Awaitable[ExecutionResult], ExecutionResult]], Optional[bool]]: if not query: if show_graphiql: - return None + return None, None raise HTTPError(400, 'Must provide query string.') parsing_ended = await self.extension_stack.parsing_started(query) try: - backend = self.get_backend() - self.document = backend.document_from_string(self.schema, query) + self.document = parse(query) await parsing_ended() - except Exception as e: + except GraphQLError as e: await parsing_ended(e) - return ExecutionResult(errors=[e], invalid=True) + return ExecutionResult(errors=[e], data=None), True validation_ended = await self.extension_stack.validation_started() try: - validation_errors = validate(self.schema, self.document.document_ast) - except Exception as e: + validation_errors = validate(self.schema.graphql_schema, self.document) + except GraphQLError as e: await validation_ended([e]) - return ExecutionResult(errors=[e], invalid=True) + return ExecutionResult(errors=[e], data=None), True if validation_errors: - validation_ended(validation_errors) + await validation_ended(validation_errors) return ExecutionResult( errors=validation_errors, - invalid=True, - ) + data=None, + ), True else: await validation_ended() if method.lower() == 'get': - operation_type = self.document.get_operation_type(operation_name) - if operation_type and operation_type != "query": + operation_node = get_operation_ast(self.document, operation_name) + if not operation_node: if show_graphiql: - return None + return None, None + raise HTTPError(405, 'Must provide operation name if query contains multiple operations.') + if not operation_node.operation == OperationType.QUERY: + if show_graphiql: + return None, None raise HTTPError(405, 'Can only perform a {} operation from a POST request.' - .format(operation_type)) + .format(operation_node.operation.value)) execution_ended = await self.extension_stack.execution_started( - schema=self.schema, + schema=self.schema.graphql_schema, document=self.document, root=self.root_value, context=self.context, @@ -293,32 +300,30 @@ async def execute_graphql_request(self, method, query, variables, operation_name ) try: result = await self.execute( - self.document.document_ast, - root=self.get_root(), - variables=variables, + self.document, + root_value=self.get_root(), + variable_values=variables, operation_name=operation_name, - context=self.context, + context_value=self.context, middleware=self.get_middleware(), - executor=self.executor or AsyncioExecutor(), - return_promise=True ) await execution_ended() - except Exception as e: + except GraphQLError as e: await execution_ended([e]) - return ExecutionResult(errors=[e], invalid=True) + return ExecutionResult(errors=[e], data=None), True - return result + return result, False - async def execute(self, *args, **kwargs): - return execute(self.schema, *args, **kwargs) + async def execute(self, *args, **kwargs) -> Union[Awaitable[ExecutionResult], ExecutionResult]: + return execute(self.schema.graphql_schema, *args, **kwargs) - def json_encode(self, d, pretty=False): - if pretty or self.get_query_argument('pretty', False): + def json_encode(self, d: Dict[str, Any], pretty: bool = False) -> str: + if pretty or self.get_query_argument('pretty', False): # type: ignore return json.dumps(d, sort_keys=True, indent=2, separators=(',', ': ')) return json.dumps(d, separators=(',', ':')) - def render_graphiql(self, query, variables, operation_name, result): + def render_graphiql(self, query: str, variables: str, operation_name: str, result: str) -> str: return render_graphiql( query=query, variables=variables, @@ -329,18 +334,18 @@ def render_graphiql(self, query, variables, operation_name, result): graphiql_html_title=self.graphiql_html_title, ) - def should_display_graphiql(self): + def should_display_graphiql(self) -> bool: raw = 'raw' in self.request.query_arguments.keys() or 'raw' in self.request.arguments return not raw and self.request_wants_html() - def request_wants_html(self): + def request_wants_html(self) -> bool: accept_header = self.request.headers.get('Accept', '') accept_mimetypes = parse_accept_header(accept_header, MIMEAccept) best = accept_mimetypes.best_match(['application/json', 'text/html']) return best == 'text/html' and accept_mimetypes[best] > accept_mimetypes['application/json'] @property - def content_type(self): + def content_type(self) -> str: return self.request.headers.get('Content-Type', 'text/plain').split(';')[0] @staticmethod @@ -351,19 +356,19 @@ def instantiate_middleware(middlewares): continue yield middleware - def get_graphql_params(self, request, data): + def get_graphql_params(self, request: HTTPServerRequest, data: Dict[str, Any]) -> Any: if self.graphql_params: return self.graphql_params single_args = {} for key in request.arguments.keys(): - single_args[key] = self.decode_argument(request.arguments.get(key)[0]) + single_args[key] = self.decode_argument(request.arguments.get(key)[0]) # type: ignore query = single_args.get('query') or data.get('query') variables = single_args.get('variables') or data.get('variables') id = single_args.get('id') or data.get('id') - if variables and isinstance(variables, six.string_types): + if variables and isinstance(variables, str): try: variables = json.loads(variables) except: # noqa @@ -376,7 +381,7 @@ def get_graphql_params(self, request, data): self.graphql_params = query, variables, operation_name, id return self.graphql_params - def handle_error(self, ex): + def handle_error(self, ex: Exception) -> None: if not isinstance(ex, (web.HTTPError, ExecutionError, GraphQLError)): tb = ''.join(traceback.format_exception(*sys.exc_info())) app_log.error('Error: {0} {1}'.format(ex, tb)) @@ -386,7 +391,7 @@ def handle_error(self, ex): self.write(error_json) @staticmethod - def error_status(exception): + def error_status(exception: Exception) -> int: if isinstance(exception, web.HTTPError): return exception.status_code elif isinstance(exception, (ExecutionError, GraphQLError)): @@ -395,7 +400,7 @@ def error_status(exception): return 500 @staticmethod - def error_format(exception): + def error_format(exception: Exception) -> List[Dict[str, Any]]: if isinstance(exception, ExecutionError): return [{'message': e} for e in exception.errors] elif isinstance(exception, GraphQLError): @@ -406,8 +411,8 @@ def error_format(exception): return [{'message': 'Unknown server error'}] @staticmethod - def format_error(error): + def format_error(error: Union[GraphQLError, GraphQLSyntaxError]) -> Dict[str, Any]: if isinstance(error, GraphQLError): return format_graphql_error(error) - return {'message': six.text_type(error)} + return {'message': str(error)} diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000..2603afdf14 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,16 @@ +[mypy] + +[mypy-graphene.*] +ignore_missing_imports = True + +[mypy-opencensus.*] +ignore_missing_imports = True + +[mypy-json_stable_stringify_python] +ignore_missing_imports = True + +[mypy-tornado_retry_client.*] +ignore_missing_imports = True + +[mypy-graphene_tornado.*.tests.*] +ignore_errors = True diff --git a/requirements-test.txt b/requirements-test.txt index 853c319865..1b604ad309 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,5 +1,6 @@ coveralls==1.5.1 mock==2.0.0 +mypy==0.770 pytest==4.4.1 pytest-cov==2.6.1 pytest-tornado==0.8.0 diff --git a/requirements.txt b/requirements.txt index 8e65fe5642..03168d6558 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,10 @@ six>=1.10.0 json-stable-stringify-python==0.2 -graphene>=2.0.1 +graphene>=3.0b1 Jinja2>=2.10.1, <2.11.0 opencensus>=0.7.7 protobuf>=3.7.1 +snapshottest==0.5.1 tornado>=5.1.0 tornado-retry-client==0.6.1 tox diff --git a/setup.py b/setup.py index 897594271c..4a57b8ee90 100644 --- a/setup.py +++ b/setup.py @@ -35,12 +35,13 @@ 'Development Status :: 3 - Alpha', 'Intended Audience :: Developers', 'Topic :: Software Development :: Libraries', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: Implementation :: PyPy', ], @@ -50,13 +51,14 @@ install_requires=[ 'six>=1.10.0', - 'graphene>=2.1,<3', + 'graphene>3.0b1', 'Jinja2>=2.10.1', 'tornado>=5.1.0', 'werkzeug==0.12.2' ], setup_requires=[ 'pytest', + 'snapshottest' ], tests_require=tests_require, extras_require={ diff --git a/tox.ini b/tox.ini index c8609bda6f..cd21b453b3 100644 --- a/tox.ini +++ b/tox.ini @@ -5,13 +5,10 @@ skipsdist = true [testenv] setenv = PYTHONPATH = {toxinidir} +pip_pre = true deps = -rrequirements.txt -rrequirements-test.txt commands = - py{py,27,37}: py.test -vv -p no:warnings --cov=graphene_tornado graphene_tornado {posargs} - -[testenv:mypy] -basepython=python3.7 -deps = mypy -commands = - mypy graphene_tornado --ignore-missing-imports + mypy --config-file=mypy.ini graphene_tornado + py{py,37}: py.test -vv -p no:warnings --cov=graphene_tornado graphene_tornado {posargs} +whitelist_externals=mypy