diff --git a/src/execution/Canceller.ts b/src/execution/Canceller.ts new file mode 100644 index 0000000000..19f797c35b --- /dev/null +++ b/src/execution/Canceller.ts @@ -0,0 +1,52 @@ +import { promiseWithResolvers } from '../jsutils/promiseWithResolvers.js'; + +/** + * A Canceller object that can be used to cancel multiple promises + * using a single AbortSignal. + * + * @internal + */ +export class Canceller { + abortSignal: AbortSignal; + abort: () => void; + + private _aborts: Set<() => void>; + + constructor(abortSignal: AbortSignal) { + this.abortSignal = abortSignal; + this._aborts = new Set<() => void>(); + this.abort = () => { + for (const abort of this._aborts) { + abort(); + } + }; + + abortSignal.addEventListener('abort', this.abort); + } + + unsubscribe(): void { + this.abortSignal.removeEventListener('abort', this.abort); + } + + withCancellation(originalPromise: Promise): Promise { + if (this.abortSignal === undefined) { + return originalPromise; + } + + const { promise, resolve, reject } = promiseWithResolvers(); + const abort = () => reject(this.abortSignal.reason); + this._aborts.add(abort); + originalPromise.then( + (resolved) => { + this._aborts.delete(abort); + resolve(resolved); + }, + (error: unknown) => { + this._aborts.delete(abort); + reject(error); + }, + ); + + return promise; + } +} diff --git a/src/execution/IncrementalPublisher.ts b/src/execution/IncrementalPublisher.ts index d060ad2463..24ef905a7a 100644 --- a/src/execution/IncrementalPublisher.ts +++ b/src/execution/IncrementalPublisher.ts @@ -4,6 +4,7 @@ import { pathToArray } from '../jsutils/Path.js'; import type { GraphQLError } from '../error/GraphQLError.js'; +import type { Canceller } from './Canceller.js'; import { IncrementalGraph } from './IncrementalGraph.js'; import type { CancellableStreamRecord, @@ -43,6 +44,7 @@ export function buildIncrementalResponse( } interface IncrementalPublisherContext { + canceller: Canceller | undefined; cancellableStreams: Set | undefined; } @@ -171,6 +173,7 @@ class IncrementalPublisher { batch = await this._incrementalGraph.nextCompletedBatch(); } while (batch !== undefined); + this._context.canceller?.unsubscribe(); await this._returnAsyncIteratorsIgnoringErrors(); return { value: undefined, done: true }; }; diff --git a/src/execution/__tests__/abort-signal-test.ts b/src/execution/__tests__/abort-signal-test.ts index ad9ba6c332..e25cbb3b38 100644 --- a/src/execution/__tests__/abort-signal-test.ts +++ b/src/execution/__tests__/abort-signal-test.ts @@ -9,7 +9,11 @@ import { parse } from '../../language/parser.js'; import { buildSchema } from '../../utilities/buildASTSchema.js'; -import { execute, experimentalExecuteIncrementally } from '../execute.js'; +import { + execute, + experimentalExecuteIncrementally, + subscribe, +} from '../execute.js'; import type { InitialIncrementalExecutionResult, SubsequentIncrementalExecutionResult, @@ -52,12 +56,17 @@ const schema = buildSchema(` type Query { todo: Todo + nonNullableTodo: Todo! } type Mutation { foo: String bar: String } + + type Subscription { + foo: String + } `); describe('Execute: Cancellation', () => { @@ -300,6 +309,97 @@ describe('Execute: Cancellation', () => { }); }); + it('should stop the execution when aborted despite a hanging resolver', async () => { + const abortController = new AbortController(); + const document = parse(` + query { + todo { + id + author { + id + } + } + } + `); + + const resultPromise = execute({ + document, + schema, + abortSignal: abortController.signal, + rootValue: { + todo: () => + new Promise(() => { + /* will never resolve */ + }), + }, + }); + + abortController.abort(); + + const result = await resultPromise; + + expect(result.errors?.[0].originalError?.name).to.equal('AbortError'); + + expectJSON(result).toDeepEqual({ + data: { + todo: null, + }, + errors: [ + { + message: 'This operation was aborted', + path: ['todo'], + locations: [{ line: 3, column: 9 }], + }, + ], + }); + }); + + it('should stop the execution when aborted with proper null bubbling', async () => { + const abortController = new AbortController(); + const document = parse(` + query { + nonNullableTodo { + id + author { + id + } + } + } + `); + + const resultPromise = execute({ + document, + schema, + abortSignal: abortController.signal, + rootValue: { + nonNullableTodo: async () => + Promise.resolve({ + id: '1', + text: 'Hello, World!', + /* c8 ignore next */ + author: () => expect.fail('Should not be called'), + }), + }, + }); + + abortController.abort(); + + const result = await resultPromise; + + expect(result.errors?.[0].originalError?.name).to.equal('AbortError'); + + expectJSON(result).toDeepEqual({ + data: null, + errors: [ + { + message: 'This operation was aborted', + path: ['nonNullableTodo'], + locations: [{ line: 3, column: 9 }], + }, + ], + }); + }); + it('should stop deferred execution when aborted', async () => { const abortController = new AbortController(); const document = parse(` @@ -353,14 +453,12 @@ describe('Execute: Cancellation', () => { const abortController = new AbortController(); const document = parse(` query { - todo { - id - ... on Todo @defer { + ... on Query @defer { + todo { + id text author { - ... on Author @defer { - id - } + id } } } @@ -382,41 +480,27 @@ describe('Execute: Cancellation', () => { abortController.signal, ); - await resolveOnNextTick(); - await resolveOnNextTick(); - await resolveOnNextTick(); - abortController.abort(); const result = await resultPromise; expectJSON(result).toDeepEqual([ { - data: { - todo: { - id: '1', - }, - }, - pending: [{ id: '0', path: ['todo'] }], + data: {}, + pending: [{ id: '0', path: [] }], hasNext: true, }, { incremental: [ { data: { - text: 'hello world', - author: null, + todo: null, }, errors: [ { - locations: [ - { - column: 13, - line: 7, - }, - ], message: 'This operation was aborted', - path: ['todo', 'author'], + path: ['todo'], + locations: [{ line: 4, column: 11 }], }, ], id: '0', @@ -448,6 +532,10 @@ describe('Execute: Cancellation', () => { }, }); + await resolveOnNextTick(); + await resolveOnNextTick(); + await resolveOnNextTick(); + abortController.abort(); const result = await resultPromise; @@ -498,4 +586,39 @@ describe('Execute: Cancellation', () => { ], }); }); + + it('should stop the execution when aborted during subscription', async () => { + const abortController = new AbortController(); + const document = parse(` + subscription { + foo + } + `); + + const resultPromise = subscribe({ + document, + schema, + abortSignal: abortController.signal, + rootValue: { + foo: async () => + new Promise(() => { + /* will never resolve */ + }), + }, + }); + + abortController.abort(); + + const result = await resultPromise; + + expectJSON(result).toDeepEqual({ + errors: [ + { + message: 'This operation was aborted', + path: ['foo'], + locations: [{ line: 3, column: 9 }], + }, + ], + }); + }); }); diff --git a/src/execution/execute.ts b/src/execution/execute.ts index 7c06624414..7bf5dea2a3 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -50,6 +50,7 @@ import { assertValidSchema } from '../type/validate.js'; import type { DeferUsageSet, ExecutionPlan } from './buildExecutionPlan.js'; import { buildExecutionPlan } from './buildExecutionPlan.js'; +import { Canceller } from './Canceller.js'; import type { DeferUsage, FieldDetailsList, @@ -163,6 +164,7 @@ export interface ValidatedExecutionArgs { export interface ExecutionContext { validatedExecutionArgs: ValidatedExecutionArgs; errors: Array | undefined; + canceller: Canceller | undefined; cancellableStreams: Set | undefined; } @@ -310,9 +312,11 @@ export function executeQueryOrMutationOrSubscriptionEvent( export function experimentalExecuteQueryOrMutationOrSubscriptionEvent( validatedExecutionArgs: ValidatedExecutionArgs, ): PromiseOrValue { + const abortSignal = validatedExecutionArgs.abortSignal; const exeContext: ExecutionContext = { validatedExecutionArgs, errors: undefined, + canceller: abortSignal ? new Canceller(abortSignal) : undefined, cancellableStreams: undefined, }; try { @@ -364,14 +368,18 @@ export function experimentalExecuteQueryOrMutationOrSubscriptionEvent( if (isPromise(graphqlWrappedResult)) { return graphqlWrappedResult.then( (resolved) => buildDataResponse(exeContext, resolved), - (error: unknown) => ({ - data: null, - errors: withError(exeContext.errors, error as GraphQLError), - }), + (error: unknown) => { + exeContext.canceller?.unsubscribe(); + return { + data: null, + errors: withError(exeContext.errors, error as GraphQLError), + }; + }, ); } return buildDataResponse(exeContext, graphqlWrappedResult); } catch (error) { + exeContext.canceller?.unsubscribe(); return { data: null, errors: withError(exeContext.errors, error) }; } } @@ -462,6 +470,7 @@ function buildDataResponse( const { rawResult: data, incrementalDataRecords } = graphqlWrappedResult; const errors = exeContext.errors; if (incrementalDataRecords === undefined) { + exeContext.canceller?.unsubscribe(); return errors !== undefined ? { errors, data } : { data }; } @@ -660,11 +669,12 @@ function executeFieldsSerially( incrementalContext: IncrementalContext | undefined, deferMap: ReadonlyMap | undefined, ): PromiseOrValue>> { + const abortSignal = exeContext.validatedExecutionArgs.abortSignal; return promiseReduce( groupedFieldSet, (graphqlWrappedResult, [responseName, fieldDetailsList]) => { const fieldPath = addPath(path, responseName, parentType.name); - const abortSignal = exeContext.validatedExecutionArgs.abortSignal; + if (abortSignal?.aborted) { handleFieldError( abortSignal.reason, @@ -811,7 +821,7 @@ function executeField( incrementalContext: IncrementalContext | undefined, deferMap: ReadonlyMap | undefined, ): PromiseOrValue> | undefined { - const validatedExecutionArgs = exeContext.validatedExecutionArgs; + const { validatedExecutionArgs, canceller } = exeContext; const { schema, contextValue, variableValues, hideSuggestions, abortSignal } = validatedExecutionArgs; const fieldName = fieldDetailsList[0].node.name.value; @@ -856,7 +866,7 @@ function executeField( fieldDetailsList, info, path, - result, + canceller?.withCancellation(result) ?? result, incrementalContext, deferMap, ); @@ -1745,23 +1755,13 @@ function completeObjectValue( incrementalContext: IncrementalContext | undefined, deferMap: ReadonlyMap | undefined, ): PromiseOrValue>> { - const validatedExecutionArgs = exeContext.validatedExecutionArgs; - const abortSignal = validatedExecutionArgs.abortSignal; - if (abortSignal?.aborted) { - throw locatedError( - abortSignal.reason, - toNodes(fieldDetailsList), - pathToArray(path), - ); - } - // If there is an isTypeOf predicate function, call it with the // current result. If isTypeOf returns false, then raise an error rather // than continuing execution. if (returnType.isTypeOf) { const isTypeOf = returnType.isTypeOf( result, - validatedExecutionArgs.contextValue, + exeContext.validatedExecutionArgs.contextValue, info, ); @@ -2201,11 +2201,18 @@ function executeSubscription( const result = resolveFn(rootValue, args, contextValue, info, abortSignal); if (isPromise(result)) { - return result - .then(assertEventStream) - .then(undefined, (error: unknown) => { + const canceller = abortSignal ? new Canceller(abortSignal) : undefined; + const promise = canceller?.withCancellation(result) ?? result; + return promise.then(assertEventStream).then( + (resolved) => { + canceller?.unsubscribe(); + return resolved; + }, + (error: unknown) => { + canceller?.unsubscribe(); throw locatedError(error, fieldNodes, pathToArray(path)); - }); + }, + ); } return assertEventStream(result);