diff --git a/.changeset/thin-oranges-laugh.md b/.changeset/thin-oranges-laugh.md new file mode 100644 index 00000000000..9cda6f16002 --- /dev/null +++ b/.changeset/thin-oranges-laugh.md @@ -0,0 +1,18 @@ +--- +"@apollo/client": patch +--- + +Allow `RetryLink` to retry an operation when fatal [transport-level errors](https://www.apollographql.com/docs/graphos/routing/operations/subscriptions/multipart-protocol#message-and-error-format) are emitted from multipart subscriptions. + +```js +const retryLink = new RetryLink({ + attempts: (count, operation, error) => { + if (error instanceof ApolloError) { + // errors available on the `protocolErrors` field in `ApolloError` + console.log(error.protocolErrors) + } + + return true; + } +}); +``` diff --git a/src/link/retry/__tests__/retryLink.ts b/src/link/retry/__tests__/retryLink.ts index 85955021588..973b8ab92e6 100644 --- a/src/link/retry/__tests__/retryLink.ts +++ b/src/link/retry/__tests__/retryLink.ts @@ -5,7 +5,11 @@ import { execute } from "../../core/execute"; import { Observable } from "../../../utilities/observables/Observable"; import { fromError } from "../../utils/fromError"; import { RetryLink } from "../retryLink"; -import { ObservableStream } from "../../../testing/internal"; +import { + mockMultipartSubscriptionStream, + ObservableStream, +} from "../../../testing/internal"; +import { ApolloError } from "../../../core"; const query = gql` { @@ -210,4 +214,64 @@ describe("RetryLink", () => { [3, operation, standardError], ]); }); + + it("handles protocol errors from multipart subscriptions", async () => { + const subscription = gql` + subscription MySubscription { + aNewDieWasCreated { + die { + roll + sides + color + } + } + } + `; + + const attemptStub = jest.fn(); + attemptStub.mockReturnValueOnce(true); + + const retryLink = new RetryLink({ + delay: { initial: 1 }, + attempts: attemptStub, + }); + + const { httpLink, enqueuePayloadResult, enqueueProtocolErrors } = + mockMultipartSubscriptionStream(); + const link = ApolloLink.from([retryLink, httpLink]); + const stream = new ObservableStream(execute(link, { query: subscription })); + + enqueueProtocolErrors([ + { message: "Error field", extensions: { code: "INTERNAL_SERVER_ERROR" } }, + ]); + + enqueuePayloadResult({ + data: { + aNewDieWasCreated: { die: { color: "blue", roll: 2, sides: 6 } }, + }, + }); + + await expect(stream).toEmitValue({ + data: { + aNewDieWasCreated: { die: { color: "blue", roll: 2, sides: 6 } }, + }, + }); + + expect(attemptStub).toHaveBeenCalledTimes(1); + expect(attemptStub).toHaveBeenCalledWith( + 1, + expect.objectContaining({ + operationName: "MySubscription", + query: subscription, + }), + new ApolloError({ + protocolErrors: [ + { + message: "Error field", + extensions: { code: "INTERNAL_SERVER_ERROR" }, + }, + ], + }) + ); + }); }); diff --git a/src/link/retry/retryLink.ts b/src/link/retry/retryLink.ts index cde2dd2ea9c..37c293d99df 100644 --- a/src/link/retry/retryLink.ts +++ b/src/link/retry/retryLink.ts @@ -7,6 +7,11 @@ import { buildDelayFunction } from "./delayFunction.js"; import type { RetryFunction, RetryFunctionOptions } from "./retryFunction.js"; import { buildRetryFunction } from "./retryFunction.js"; import type { SubscriptionObserver } from "zen-observable-ts"; +import { + ApolloError, + graphQLResultHasProtocolErrors, + PROTOCOL_ERRORS_SYMBOL, +} from "../../errors/index.js"; export namespace RetryLink { export interface Options { @@ -54,7 +59,21 @@ class RetryableOperation { private try() { this.currentSubscription = this.forward(this.operation).subscribe({ - next: this.observer.next.bind(this.observer), + next: (result) => { + if (graphQLResultHasProtocolErrors(result)) { + this.onError( + new ApolloError({ + protocolErrors: result.extensions[PROTOCOL_ERRORS_SYMBOL], + }) + ); + // Unsubscribe from the current subscription to prevent the `complete` + // handler to be called as a result of the stream closing. + this.currentSubscription?.unsubscribe(); + return; + } + + this.observer.next(result); + }, error: this.onError, complete: this.observer.complete.bind(this.observer), });