diff --git a/src/execution/collectFields.ts b/src/execution/collectFields.ts index 1d0341b4cc..073d9fc905 100644 --- a/src/execution/collectFields.ts +++ b/src/execution/collectFields.ts @@ -5,6 +5,7 @@ import { isSameSet } from '../jsutils/isSameSet.js'; import type { ObjMap } from '../jsutils/ObjMap.js'; import type { + DirectiveNode, FieldNode, FragmentDefinitionNode, FragmentSpreadNode, @@ -26,7 +27,7 @@ import type { GraphQLSchema } from '../type/schema.js'; import { typeFromAST } from '../utilities/typeFromAST.js'; -import { getDirectiveValues } from './values.js'; +import { getArgumentValues, getDirectiveValues } from './values.js'; export interface DeferUsage { label: string | undefined; @@ -60,6 +61,7 @@ export interface CollectFieldsResult { groupedFieldSet: GroupedFieldSet; newGroupedFieldSetDetails: Map; newDeferUsages: ReadonlyArray; + forbiddenDirectiveInstances: ReadonlyArray; } interface CollectFieldsContext { @@ -72,6 +74,7 @@ interface CollectFieldsContext { fieldsByTarget: Map>; newDeferUsages: Array; visitedFragmentNames: Set; + forbiddenDirectiveInstances: Array; } /** @@ -100,6 +103,7 @@ export function collectFields( targetsByKey: new Map(), newDeferUsages: [], visitedFragmentNames: new Set(), + forbiddenDirectiveInstances: [], }; collectFieldsImpl(context, operation.selectionSet); @@ -107,9 +111,20 @@ export function collectFields( return { ...buildGroupedFieldSets(context.targetsByKey, context.fieldsByTarget), newDeferUsages: context.newDeferUsages, + forbiddenDirectiveInstances: context.forbiddenDirectiveInstances, }; } +/** + * This variable is the empty variables used during the validation phase (where + * no variables exist) for field collection; if a `@skip` or `@include` + * directive is ever seen when `variableValues` is set to this, it should + * throw. + */ +export const VALIDATION_PHASE_EMPTY_VARIABLES: { + [variable: string]: any; +} = Object.freeze(Object.create(null)); + /** * Given an array of field nodes, collects all of the subfields of the passed * in fields, and returns them at the end. @@ -139,6 +154,7 @@ export function collectSubfields( targetsByKey: new Map(), newDeferUsages: [], visitedFragmentNames: new Set(), + forbiddenDirectiveInstances: [], }; for (const fieldDetails of fieldGroup.fields) { @@ -155,6 +171,7 @@ export function collectSubfields( fieldGroup.targets, ), newDeferUsages: context.newDeferUsages, + forbiddenDirectiveInstances: context.forbiddenDirectiveInstances, }; } @@ -179,7 +196,7 @@ function collectFieldsImpl( for (const selection of selectionSet.selections) { switch (selection.kind) { case Kind.FIELD: { - if (!shouldIncludeNode(variableValues, selection)) { + if (!shouldIncludeNode(context, variableValues, selection)) { continue; } const key = getFieldEntryKey(selection); @@ -200,7 +217,7 @@ function collectFieldsImpl( } case Kind.INLINE_FRAGMENT: { if ( - !shouldIncludeNode(variableValues, selection) || + !shouldIncludeNode(context, variableValues, selection) || !doesFragmentConditionMatch(schema, selection, runtimeType) ) { continue; @@ -232,7 +249,7 @@ function collectFieldsImpl( case Kind.FRAGMENT_SPREAD: { const fragName = selection.name.value; - if (!shouldIncludeNode(variableValues, selection)) { + if (!shouldIncludeNode(context, variableValues, selection)) { continue; } @@ -304,19 +321,44 @@ function getDeferValues( * directives, where `@skip` has higher precedence than `@include`. */ function shouldIncludeNode( + context: CollectFieldsContext, variableValues: { [variable: string]: unknown }, node: FragmentSpreadNode | FieldNode | InlineFragmentNode, ): boolean { - const skip = getDirectiveValues(GraphQLSkipDirective, node, variableValues); + const skipDirectiveNode = node.directives?.find( + (directive) => directive.name.value === GraphQLSkipDirective.name, + ); + if ( + skipDirectiveNode && + variableValues === VALIDATION_PHASE_EMPTY_VARIABLES + ) { + context.forbiddenDirectiveInstances.push(skipDirectiveNode); + return false; + } + const skip = skipDirectiveNode + ? getArgumentValues(GraphQLSkipDirective, skipDirectiveNode, variableValues) + : undefined; if (skip?.if === true) { return false; } - const include = getDirectiveValues( - GraphQLIncludeDirective, - node, - variableValues, + const includeDirectiveNode = node.directives?.find( + (directive) => directive.name.value === GraphQLIncludeDirective.name, ); + if ( + includeDirectiveNode && + variableValues === VALIDATION_PHASE_EMPTY_VARIABLES + ) { + context.forbiddenDirectiveInstances.push(includeDirectiveNode); + return false; + } + const include = includeDirectiveNode + ? getArgumentValues( + GraphQLIncludeDirective, + includeDirectiveNode, + variableValues, + ) + : undefined; if (include?.if === false) { return false; } diff --git a/src/validation/__tests__/SingleFieldSubscriptionsRule-test.ts b/src/validation/__tests__/SingleFieldSubscriptionsRule-test.ts index 7e0a227d07..7e301b3720 100644 --- a/src/validation/__tests__/SingleFieldSubscriptionsRule-test.ts +++ b/src/validation/__tests__/SingleFieldSubscriptionsRule-test.ts @@ -286,6 +286,48 @@ describe('Validate: Subscriptions with single field', () => { ]); }); + it('fails with @skip or @include directive', () => { + expectErrors(` + subscription RequiredRuntimeValidation($bool: Boolean!) { + newMessage @include(if: $bool) { + body + sender + } + disallowedSecondRootField @skip(if: $bool) + } + `).toDeepEqual([ + { + message: + 'Subscription "RequiredRuntimeValidation" must not use `@skip` or `@include` directives in the top level selection.', + locations: [ + { line: 3, column: 20 }, + { line: 7, column: 35 }, + ], + }, + ]); + }); + + it('fails with @skip or @include directive in anonymous subscription', () => { + expectErrors(` + subscription ($bool: Boolean!) { + newMessage @include(if: $bool) { + body + sender + } + disallowedSecondRootField @skip(if: $bool) + } + `).toDeepEqual([ + { + message: + 'Anonymous Subscription must not use `@skip` or `@include` directives in the top level selection.', + locations: [ + { line: 3, column: 20 }, + { line: 7, column: 35 }, + ], + }, + ]); + }); + it('skips if not subscription type', () => { const emptySchema = buildSchema(` type Query { diff --git a/src/validation/rules/SingleFieldSubscriptionsRule.ts b/src/validation/rules/SingleFieldSubscriptionsRule.ts index c0d1031103..aeb76af51a 100644 --- a/src/validation/rules/SingleFieldSubscriptionsRule.ts +++ b/src/validation/rules/SingleFieldSubscriptionsRule.ts @@ -11,7 +11,10 @@ import { Kind } from '../../language/kinds.js'; import type { ASTVisitor } from '../../language/visitor.js'; import type { FieldGroup } from '../../execution/collectFields.js'; -import { collectFields } from '../../execution/collectFields.js'; +import { + collectFields, + VALIDATION_PHASE_EMPTY_VARIABLES, +} from '../../execution/collectFields.js'; import type { ValidationContext } from '../ValidationContext.js'; @@ -23,7 +26,8 @@ function toNodes(fieldGroup: FieldGroup): ReadonlyArray { * Subscriptions must only include a non-introspection field. * * A GraphQL subscription is valid only if it contains a single root field and - * that root field is not an introspection field. + * that root field is not an introspection field. `@skip` and `@include` + * directives are forbidden. * * See https://spec.graphql.org/draft/#sec-Single-root-field */ @@ -37,9 +41,7 @@ export function SingleFieldSubscriptionsRule( const subscriptionType = schema.getSubscriptionType(); if (subscriptionType) { const operationName = node.name ? node.name.value : null; - const variableValues: { - [variable: string]: any; - } = Object.create(null); + const variableValues = VALIDATION_PHASE_EMPTY_VARIABLES; const document = context.getDocument(); const fragments: ObjMap = Object.create(null); for (const definition of document.definitions) { @@ -47,13 +49,25 @@ export function SingleFieldSubscriptionsRule( fragments[definition.name.value] = definition; } } - const { groupedFieldSet } = collectFields( - schema, - fragments, - variableValues, - subscriptionType, - node, - ); + const { groupedFieldSet, forbiddenDirectiveInstances } = + collectFields( + schema, + fragments, + variableValues, + subscriptionType, + node, + ); + if (forbiddenDirectiveInstances.length > 0) { + context.reportError( + new GraphQLError( + operationName != null + ? `Subscription "${operationName}" must not use \`@skip\` or \`@include\` directives in the top level selection.` + : 'Anonymous Subscription must not use `@skip` or `@include` directives in the top level selection.', + { nodes: forbiddenDirectiveInstances }, + ), + ); + return; + } if (groupedFieldSet.size > 1) { const fieldGroups = [...groupedFieldSet.values()]; const extraFieldGroups = fieldGroups.slice(1);