diff --git a/ariadne/contrib/federation/utils.py b/ariadne/contrib/federation/utils.py index bb6bc1d8..a93acc45 100644 --- a/ariadne/contrib/federation/utils.py +++ b/ariadne/contrib/federation/utils.py @@ -1,9 +1,14 @@ # pylint: disable=cell-var-from-loop -import re from inspect import isawaitable -from typing import Any, List +from typing import Any, List, Tuple, cast +from graphql import ( + DirectiveDefinitionNode, + Node, + parse, + print_ast, +) from graphql.language import DirectiveNode from graphql.type import ( GraphQLNamedType, @@ -14,36 +19,6 @@ ) -_i_token_delimiter = r"(?:^|[\s]+|$)" -_i_token_name = "[_A-Za-z][_0-9A-Za-z]*" -_i_token_arguments = r"\([^)]*\)" -_i_token_location = "[_A-Za-z][_0-9A-Za-z]*" -_i_token_description_block_string = r"(?:\"{3}(?:[^\"]{1,}|[\s])\"{3})" -_i_token_description_single_line = r"(?:\"(?:[^\"\n\r])*?\")" - -_r_directive_definition = re.compile( - "(" - f"(?:{_i_token_delimiter}(?:" - f"{_i_token_description_block_string}|{_i_token_description_single_line}" - "))??" - f"{_i_token_delimiter}directive" - f"(?:{_i_token_delimiter})?@({_i_token_name})" - f"(?:(?:{_i_token_delimiter})?{_i_token_arguments})?" - f"{_i_token_delimiter}on" - f"{_i_token_delimiter}(?:[|]{_i_token_delimiter})?{_i_token_location}" - f"(?:{_i_token_delimiter}[|]{_i_token_delimiter}{_i_token_location})*" - ")" - f"(?={_i_token_delimiter})", -) - -_r_directive = re.compile( - "(" - f"(?:{_i_token_delimiter})?@({_i_token_name})" - f"(?:(?:{_i_token_delimiter})?{_i_token_arguments})?" - ")" - f"(?={_i_token_delimiter})", -) - _allowed_directives = [ "skip", # Default directive as per specs. "include", # Default directive as per specs. @@ -66,14 +41,39 @@ ] +def _purge_directive_nodes(nodes: Tuple[Node, ...]) -> Tuple[Node, ...]: + return tuple( + node + for node in nodes + if not isinstance(node, (DirectiveNode, DirectiveDefinitionNode)) + or node.name.value in _allowed_directives + ) + + +def _purge_type_directives(definition: Node): + # Recursively check every field defined on the Node definition + # and remove any directives found. + for key in definition.keys: + value = getattr(definition, key, None) + if isinstance(value, tuple): + # Remove directive nodes from the tuple + # e.g. doc -> definitions [DirectiveDefinitionNode] + next_value = _purge_directive_nodes(cast(Tuple[Node, ...], value)) + for item in next_value: + if isinstance(item, Node): + # Look for directive nodes on sub-nodes + # e.g. doc -> definitions [ObjectTypeDefinitionNode] -> fields -> directives + _purge_type_directives(item) + setattr(definition, key, next_value) + elif isinstance(value, Node): + _purge_type_directives(value) + + def purge_schema_directives(joined_type_defs: str) -> str: """Remove custom schema directives from federation.""" - joined_type_defs = _r_directive_definition.sub("", joined_type_defs) - joined_type_defs = _r_directive.sub( - lambda m: m.group(1) if m.group(2) in _allowed_directives else "", - joined_type_defs, - ) - return joined_type_defs + ast_document = parse(joined_type_defs) + _purge_type_directives(ast_document) + return print_ast(ast_document) def resolve_entities(_: Any, info: GraphQLResolveInfo, **kwargs) -> Any: diff --git a/tests/federation/test_utils.py b/tests/federation/test_utils.py index d6ef51ef..9fccdab7 100644 --- a/tests/federation/test_utils.py +++ b/tests/federation/test_utils.py @@ -63,6 +63,8 @@ def test_purge_directives_remove_custom_directives(): directive @another on FIELD + directive @plural repeatable on FIELD + type Query { field1: String @custom field2: String @other @@ -107,6 +109,10 @@ def test_purge_directives_remove_custom_directives_with_single_line_description( "Any Description" directive @custom on FIELD + type Entity { + field: String @custom + } + type Query { rootField: String @custom } @@ -114,6 +120,10 @@ def test_purge_directives_remove_custom_directives_with_single_line_description( assert sic(purge_schema_directives(type_defs)) == sic( """ + type Entity { + field: String + } + type Query { rootField: String } @@ -127,6 +137,58 @@ def test_purge_directives_without_leading_whitespace(): assert sic(purge_schema_directives(type_defs)) == "" +def test_purge_directives_remove_custom_directives_from_interfaces(): + type_defs = """ + directive @custom on INTERFACE + + interface EntityInterface @custom { + field: String + } + + type Entity implements EntityInterface { + field: String + } + + type Query { + rootField: Entity + } + """ + + assert sic(purge_schema_directives(type_defs)) == sic( + """ + interface EntityInterface { + field: String + } + + type Entity implements EntityInterface { + field: String + } + + type Query { + rootField: Entity + } + """ + ) + + +def test_purge_directives_remove_custom_directive_with_arguments(): + type_defs = """ + directive @custom(arg: String) on FIELD + + type Query { + rootField: String @custom(arg: "value") + } + """ + + assert sic(purge_schema_directives(type_defs)) == sic( + """ + type Query { + rootField: String + } + """ + ) + + def test_get_entity_types_with_key_directive(): type_defs = """ type Query {