Skip to content

Commit

Permalink
fix: include promises from late waitUntil calls in FetchEventResult.w…
Browse files Browse the repository at this point in the history
…aitUntil
  • Loading branch information
lubieowoce committed Aug 8, 2024
1 parent dcab3e5 commit 8efe900
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,28 @@ export function getRender({
request: NextRequestHint,
event?: NextFetchEvent
) {
const isAfterEnabled = !!process.env.__NEXT_AFTER

const extendedReq = new WebNextRequest(request)
const extendedRes = new WebNextResponse(
undefined,
// tracking onClose adds overhead, so only do it if `experimental.after` is on.
!!process.env.__NEXT_AFTER
isAfterEnabled
)

handler(extendedReq, extendedRes)
const result = await extendedRes.toResponse()
request.fetchMetrics = extendedReq.fetchMetrics

if (event?.waitUntil) {
if (isAfterEnabled) {
// make sure that NextRequestHint's awaiter stays open long enough
// for late `waitUntil`s called during streaming to get picked up.
event.waitUntil(
new Promise<void>((resolve) => extendedRes.onClose(resolve))
)
}

// TODO(after):
// remove `internal_runWithWaitUntil` and the `internal-edge-wait-until` module
// when consumers switch to `unstable_after`.
Expand Down
81 changes: 81 additions & 0 deletions packages/next/src/server/lib/awaiter.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import { InvariantError } from '../../shared/lib/invariant-error'
import { AwaiterMulti, AwaiterOnce } from './awaiter'

describe('AwaiterOnce/AwaiterMulti', () => {
describe.each([
{ name: 'AwaiterMulti', impl: AwaiterMulti },
{ name: 'AwaiterOnce', impl: AwaiterOnce },
])('$name', ({ impl: AwaiterImpl }) => {
it('awaits promises added by other promises', async () => {
const awaiter = new AwaiterImpl()

const MAX_DEPTH = 5
const promises: TrackedPromise<unknown>[] = []

const waitUntil = (promise: Promise<unknown>) => {
promises.push(trackPromiseSettled(promise))
awaiter.waitUntil(promise)
}

const makeNestedPromise = async () => {
if (promises.length >= MAX_DEPTH) {
return
}
await sleep(100)
waitUntil(makeNestedPromise())
}

waitUntil(makeNestedPromise())

await awaiter.awaiting()

for (const promise of promises) {
expect(promise.isSettled).toBe(true)
}
})

it('calls onError for rejected promises', async () => {
const onError = jest.fn<void, [error: unknown]>()
const awaiter = new AwaiterImpl({ onError })

awaiter.waitUntil(Promise.reject('error 1'))
awaiter.waitUntil(
sleep(100).then(() => awaiter.waitUntil(Promise.reject('error 2')))
)

await awaiter.awaiting()

expect(onError).toHaveBeenCalledWith('error 1')
expect(onError).toHaveBeenCalledWith('error 2')
})
})
})

describe('AwaiterOnce', () => {
it("does not allow calling waitUntil after it's been awaited", async () => {
const awaiter = new AwaiterOnce()
awaiter.waitUntil(Promise.resolve(1))
await awaiter.awaiting()
expect(() => awaiter.waitUntil(Promise.resolve(2))).toThrow(InvariantError)
})
})

type TrackedPromise<T> = Promise<T> & { isSettled: boolean }

function trackPromiseSettled<T>(promise: Promise<T>): TrackedPromise<T> {
const tracked = promise as TrackedPromise<T>
tracked.isSettled = false
tracked.then(
() => {
tracked.isSettled = true
},
() => {
tracked.isSettled = true
}
)
return tracked
}

function sleep(duration: number) {
return new Promise<void>((resolve) => setTimeout(resolve, duration))
}
70 changes: 70 additions & 0 deletions packages/next/src/server/lib/awaiter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { InvariantError } from '../../shared/lib/invariant-error'

/**
* Provides a `waitUntil` implementation which gathers promises to be awaited later (via {@link AwaiterMulti.awaiting}).
* Unlike a simple `Promise.all`, {@link AwaiterMulti} works recursively --
* if a promise passed to {@link AwaiterMulti.waitUntil} calls `waitUntil` again,
* that second promise will also be awaited.
*/
export class AwaiterMulti {
private promises: Set<Promise<unknown>> = new Set()
private onError: (error: unknown) => void

constructor({ onError }: { onError?: (error: unknown) => void } = {}) {
this.onError = onError ?? console.error
}

public waitUntil = (promise: Promise<unknown>): void => {
// if a promise settles before we await it, we can drop it.
const cleanup = () => {
this.promises.delete(promise)
}

this.promises.add(
promise.then(cleanup, (err) => {
cleanup()
this.onError(err)
})
)
}

public async awaiting(): Promise<void> {
while (this.promises.size > 0) {
const promises = Array.from(this.promises)
this.promises.clear()
await Promise.all(promises)
}
}
}

/**
* Like {@link AwaiterMulti}, but can only be awaited once.
* If {@link AwaiterOnce.waitUntil} is called after that, it will throw.
*/
export class AwaiterOnce {
private awaiter: AwaiterMulti
private done: boolean = false
private pending: Promise<void> | undefined

constructor(options: { onError?: (error: unknown) => void } = {}) {
this.awaiter = new AwaiterMulti(options)
}

public waitUntil = (promise: Promise<unknown>): void => {
if (this.done) {
throw new InvariantError(
'Cannot call waitUntil() on an AwaiterOnce that was already awaited'
)
}
return this.awaiter.waitUntil(promise)
}

public async awaiting(): Promise<void> {
if (!this.pending) {
this.pending = this.awaiter.awaiting().finally(() => {
this.done = true
})
}
return this.pending
}
}
2 changes: 1 addition & 1 deletion packages/next/src/server/web/adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ export async function adapter(

return {
response: finalResponse,
waitUntil: Promise.all(event[waitUntilSymbol]),
waitUntil: event[waitUntilSymbol](),
fetchMetrics: request.fetchMetrics,
}
}
7 changes: 7 additions & 0 deletions packages/next/src/server/web/edge-route-module-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ export class EdgeRouteModuleWrapper {
const trackedBody = trackStreamConsumed(res.body, () =>
_closeController.dispatchClose()
)

// make sure that NextRequestHint's awaiter stays open long enough
// for `waitUntil`s called late during streaming to get picked up.
evt.waitUntil(
new Promise<void>((resolve) => _closeController.onClose(resolve))
)

res = new Response(trackedBody, {
status: res.status,
statusText: res.statusText,
Expand Down
14 changes: 11 additions & 3 deletions packages/next/src/server/web/spec-extension/fetch-event.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import { AwaiterOnce } from '../../lib/awaiter'
import { PageSignatureError } from '../error'
import type { NextRequest } from './request'

const responseSymbol = Symbol('response')
const passThroughSymbol = Symbol('passThrough')
const awaiterSymbol = Symbol('awaiter')

export const waitUntilSymbol = Symbol('waitUntil')

class FetchEvent {
readonly [waitUntilSymbol]: Promise<any>[] = [];
[responseSymbol]?: Promise<Response>;
[passThroughSymbol] = false
[passThroughSymbol] = false;

[awaiterSymbol] = new AwaiterOnce();

[waitUntilSymbol] = () => {
return this[awaiterSymbol].awaiting()
}

// eslint-disable-next-line @typescript-eslint/no-useless-constructor
constructor(_request: Request) {}
Expand All @@ -24,7 +32,7 @@ class FetchEvent {
}

waitUntil(promise: Promise<any>): void {
this[waitUntilSymbol].push(promise)
this[awaiterSymbol].waitUntil(promise)
}
}

Expand Down

0 comments on commit 8efe900

Please sign in to comment.