Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: include promises from late waitUntil calls in FetchEventResult.waitUntil #66747

Draft
wants to merge 1 commit into
base: fix-pass-waituntil-to-webserver
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,28 @@ export function getRender({
request: NextRequestHint,
event?: NextFetchEvent
) {
const isAfterEnabled = !!process.env.__NEXT_AFTER

const extendedReq = new WebNextRequest(request)
const extendedRes = new WebNextResponse(
undefined,
// tracking onClose adds overhead, so only do it if `experimental.after` is on.
!!process.env.__NEXT_AFTER
isAfterEnabled
)

handler(extendedReq, extendedRes)
const result = await extendedRes.toResponse()
request.fetchMetrics = extendedReq.fetchMetrics

if (event?.waitUntil) {
if (isAfterEnabled) {
// make sure that NextRequestHint's awaiter stays open long enough
// for late `waitUntil`s called during streaming to get picked up.
event.waitUntil(
new Promise<void>((resolve) => extendedRes.onClose(resolve))
)
}

// TODO(after):
// remove `internal_runWithWaitUntil` and the `internal-edge-wait-until` module
// when consumers switch to `unstable_after`.
Expand Down
81 changes: 81 additions & 0 deletions packages/next/src/server/lib/awaiter.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import { InvariantError } from '../../shared/lib/invariant-error'
import { AwaiterMulti, AwaiterOnce } from './awaiter'

describe('AwaiterOnce/AwaiterMulti', () => {
describe.each([
{ name: 'AwaiterMulti', impl: AwaiterMulti },
{ name: 'AwaiterOnce', impl: AwaiterOnce },
])('$name', ({ impl: AwaiterImpl }) => {
it('awaits promises added by other promises', async () => {
const awaiter = new AwaiterImpl()

const MAX_DEPTH = 5
const promises: TrackedPromise<unknown>[] = []

const waitUntil = (promise: Promise<unknown>) => {
promises.push(trackPromiseSettled(promise))
awaiter.waitUntil(promise)
}

const makeNestedPromise = async () => {
if (promises.length >= MAX_DEPTH) {
return
}
await sleep(100)
waitUntil(makeNestedPromise())
}

waitUntil(makeNestedPromise())

await awaiter.awaiting()

for (const promise of promises) {
expect(promise.isSettled).toBe(true)
}
})

it('calls onError for rejected promises', async () => {
const onError = jest.fn<void, [error: unknown]>()
const awaiter = new AwaiterImpl({ onError })

awaiter.waitUntil(Promise.reject('error 1'))
awaiter.waitUntil(
sleep(100).then(() => awaiter.waitUntil(Promise.reject('error 2')))
)

await awaiter.awaiting()

expect(onError).toHaveBeenCalledWith('error 1')
expect(onError).toHaveBeenCalledWith('error 2')
})
})
})

describe('AwaiterOnce', () => {
it("does not allow calling waitUntil after it's been awaited", async () => {
const awaiter = new AwaiterOnce()
awaiter.waitUntil(Promise.resolve(1))
await awaiter.awaiting()
expect(() => awaiter.waitUntil(Promise.resolve(2))).toThrow(InvariantError)
})
})

type TrackedPromise<T> = Promise<T> & { isSettled: boolean }

function trackPromiseSettled<T>(promise: Promise<T>): TrackedPromise<T> {
const tracked = promise as TrackedPromise<T>
tracked.isSettled = false
tracked.then(
() => {
tracked.isSettled = true
},
() => {
tracked.isSettled = true
}
)
return tracked
}

function sleep(duration: number) {
return new Promise<void>((resolve) => setTimeout(resolve, duration))
}
70 changes: 70 additions & 0 deletions packages/next/src/server/lib/awaiter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { InvariantError } from '../../shared/lib/invariant-error'

/**
* Provides a `waitUntil` implementation which gathers promises to be awaited later (via {@link AwaiterMulti.awaiting}).
* Unlike a simple `Promise.all`, {@link AwaiterMulti} works recursively --
* if a promise passed to {@link AwaiterMulti.waitUntil} calls `waitUntil` again,
* that second promise will also be awaited.
*/
export class AwaiterMulti {
private promises: Set<Promise<unknown>> = new Set()
private onError: (error: unknown) => void

constructor({ onError }: { onError?: (error: unknown) => void } = {}) {
this.onError = onError ?? console.error
}

public waitUntil = (promise: Promise<unknown>): void => {
// if a promise settles before we await it, we can drop it.
const cleanup = () => {
this.promises.delete(promise)
}

this.promises.add(
promise.then(cleanup, (err) => {
cleanup()
this.onError(err)
})
)
}

public async awaiting(): Promise<void> {
while (this.promises.size > 0) {
const promises = Array.from(this.promises)
this.promises.clear()
await Promise.all(promises)
}
}
}

/**
* Like {@link AwaiterMulti}, but can only be awaited once.
* If {@link AwaiterOnce.waitUntil} is called after that, it will throw.
*/
export class AwaiterOnce {
private awaiter: AwaiterMulti
private done: boolean = false
private pending: Promise<void> | undefined

constructor(options: { onError?: (error: unknown) => void } = {}) {
this.awaiter = new AwaiterMulti(options)
}

public waitUntil = (promise: Promise<unknown>): void => {
if (this.done) {
throw new InvariantError(
'Cannot call waitUntil() on an AwaiterOnce that was already awaited'
)
}
return this.awaiter.waitUntil(promise)
}

public async awaiting(): Promise<void> {
if (!this.pending) {
this.pending = this.awaiter.awaiting().finally(() => {
this.done = true
})
}
return this.pending
}
}
2 changes: 1 addition & 1 deletion packages/next/src/server/web/adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ export async function adapter(

return {
response: finalResponse,
waitUntil: Promise.all(event[waitUntilSymbol]),
waitUntil: event[waitUntilSymbol](),
fetchMetrics: request.fetchMetrics,
}
}
7 changes: 7 additions & 0 deletions packages/next/src/server/web/edge-route-module-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ export class EdgeRouteModuleWrapper {
const trackedBody = trackStreamConsumed(res.body, () =>
_closeController.dispatchClose()
)

// make sure that NextRequestHint's awaiter stays open long enough
// for `waitUntil`s called late during streaming to get picked up.
evt.waitUntil(
new Promise<void>((resolve) => _closeController.onClose(resolve))
)

res = new Response(trackedBody, {
status: res.status,
statusText: res.statusText,
Expand Down
14 changes: 11 additions & 3 deletions packages/next/src/server/web/spec-extension/fetch-event.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import { AwaiterOnce } from '../../lib/awaiter'
import { PageSignatureError } from '../error'
import type { NextRequest } from './request'

const responseSymbol = Symbol('response')
const passThroughSymbol = Symbol('passThrough')
const awaiterSymbol = Symbol('awaiter')

export const waitUntilSymbol = Symbol('waitUntil')

class FetchEvent {
readonly [waitUntilSymbol]: Promise<any>[] = [];
[responseSymbol]?: Promise<Response>;
[passThroughSymbol] = false
[passThroughSymbol] = false;

[awaiterSymbol] = new AwaiterOnce();

[waitUntilSymbol] = () => {
return this[awaiterSymbol].awaiting()
}

// eslint-disable-next-line @typescript-eslint/no-useless-constructor
constructor(_request: Request) {}
Expand All @@ -24,7 +32,7 @@ class FetchEvent {
}

waitUntil(promise: Promise<any>): void {
this[waitUntilSymbol].push(promise)
this[awaiterSymbol].waitUntil(promise)
}
}

Expand Down
Loading