Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): binary chunk streaming support #77

Merged
merged 14 commits into from
Aug 6, 2024
9 changes: 7 additions & 2 deletions apps/demo-nextjs-app-router/app/queue/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ function Error(props: ErrorProps) {
);
}

const DEFAULT_ENDPOINT_ID = 'fal-ai/fast-sdxl';
const DEFAULT_INPUT = `{
"prompt": "A beautiful sunset over the ocean"
}`;

export default function Home() {
// Input state
const [endpointId, setEndpointId] = useState<string>('');
const [input, setInput] = useState<string>('{}');
const [endpointId, setEndpointId] = useState<string>(DEFAULT_ENDPOINT_ID);
const [input, setInput] = useState<string>(DEFAULT_INPUT);
// Result state
const [loading, setLoading] = useState(false);
const [error, setError] = useState<Error | null>(null);
Expand Down
2 changes: 1 addition & 1 deletion libs/client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.14.0-alpha.3",
"version": "0.14.0",
"license": "MIT",
"repository": {
"type": "git",
Expand Down
13 changes: 0 additions & 13 deletions libs/client/src/function.spec.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
import uuid from 'uuid-random';
import { buildUrl } from './function';

describe('The function test suite', () => {
it('should build the URL with a function UUIDv4', () => {
const id = uuid();
const url = buildUrl(`12345/${id}`);
expect(url).toMatch(`trigger/12345/${id}`);
});

it('should build the URL with a function user-id-app-alias', () => {
const alias = '12345-some-alias';
const url = buildUrl(alias);
expect(url).toMatch(`fal.run/12345/some-alias`);
});

it('should build the URL with a function username/app-alias', () => {
const alias = 'fal-ai/text-to-image';
const url = buildUrl(alias);
Expand Down
54 changes: 33 additions & 21 deletions libs/client/src/function.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import { getTemporaryAuthToken } from './auth';
import { dispatchRequest } from './request';
import { storageImpl } from './storage';
import { FalStream } from './streaming';
import { FalStream, StreamingConnectionMode } from './streaming';
import {
CompletedQueueStatus,
EnqueueResult,
QueueStatus,
RequestLog,
} from './types';
import { ensureAppIdFormat, isUUIDv4, isValidUrl, parseAppId } from './utils';
import { ensureAppIdFormat, isValidUrl, parseAppId } from './utils';

/**
* The function input and other configuration when running
Expand Down Expand Up @@ -80,20 +79,13 @@ export function buildUrl<Input>(
Object.keys(params).length > 0
? `?${new URLSearchParams(params).toString()}`
: '';
const parts = id.split('/');

// if a fal url is passed, just use it
if (isValidUrl(id)) {
const url = id.endsWith('/') ? id : `${id}/`;
return `${url}${path}${queryParams}`;
}

// TODO remove this after some time, fal.run should be preferred
if (parts.length === 2 && isUUIDv4(parts[1])) {
const host = 'gateway.shark.fal.ai';
return `https://${host}/trigger/${id}/${path}${queryParams}`;
}

const appId = ensureAppIdFormat(id);
const subdomain = options.subdomain ? `${options.subdomain}.` : '';
const url = `https://${subdomain}fal.run/${appId}/${path}`;
Expand Down Expand Up @@ -199,6 +191,12 @@ type QueueSubscribeOptions = {
}
| {
mode: 'streaming';

/**
* The connection mode to use for streaming updates. It defaults to `server`.
* Set to `client` if your server proxy doesn't support streaming.
*/
connectionMode?: StreamingConnectionMode;
}
);

Expand Down Expand Up @@ -228,6 +226,14 @@ type QueueStatusOptions = BaseQueueOptions & {
logs?: boolean;
};

type QueueStatusStreamOptions = QueueStatusOptions & {
/**
* The connection mode to use for streaming updates. It defaults to `server`.
* Set to `client` if your server proxy doesn't support streaming.
*/
connectionMode?: StreamingConnectionMode;
};

/**
* Represents a request queue with methods for submitting requests,
* checking their status, retrieving results, and subscribing to updates.
Expand Down Expand Up @@ -263,7 +269,7 @@ interface Queue {
*/
streamStatus(
endpointId: string,
options: QueueStatusOptions
options: QueueStatusStreamOptions
): Promise<FalStream<unknown, QueueStatus>>;

/**
Expand Down Expand Up @@ -340,24 +346,26 @@ export const queue: Queue = {

async streamStatus(
endpointId: string,
{ requestId, logs = false }: QueueStatusOptions
{ requestId, logs = false, connectionMode }: QueueStatusStreamOptions
): Promise<FalStream<unknown, QueueStatus>> {
const appId = parseAppId(endpointId);
const prefix = appId.namespace ? `${appId.namespace}/` : '';
const token = await getTemporaryAuthToken(endpointId);

const queryParams = {
logs: logs ? '1' : '0',
};

const url = buildUrl(`${prefix}${appId.owner}/${appId.alias}`, {
subdomain: 'queue',
path: `/requests/${requestId}/status/stream`,
query: queryParams,
});

const queryParams = new URLSearchParams({
fal_jwt_token: token,
logs: logs ? '1' : '0',
});

return new FalStream<unknown, QueueStatus>(`${url}?${queryParams}`, {
input: {},
return new FalStream<unknown, QueueStatus>(endpointId, {
url,
method: 'get',
connectionMode,
queryParams,
});
},

Expand All @@ -375,6 +383,10 @@ export const queue: Queue = {
const status = await queue.streamStatus(endpointId, {
requestId,
logs: options.logs,
connectionMode:
'connectionMode' in options
? (options.connectionMode as StreamingConnectionMode)
: undefined,
});
const logs: RequestLog[] = [];
if (timeout) {
Expand All @@ -390,7 +402,7 @@ export const queue: Queue = {
);
}, timeout);
}
status.on('message', (data: QueueStatus) => {
status.on('data', (data: QueueStatus) => {
if (options.onQueueUpdate) {
// accumulate logs to match previous polling behavior
if (
Expand Down
19 changes: 16 additions & 3 deletions libs/client/src/request.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import { getConfig } from './config';
import { ResponseHandler } from './response';
import { getUserAgent, isBrowser } from './runtime';

const isCloudflareWorkers =
typeof navigator !== 'undefined' &&
navigator?.userAgent === 'Cloudflare-Workers';

type RequestOptions = {
responseHandler?: ResponseHandler<any>;

Check warning on line 10 in libs/client/src/request.ts

View workflow job for this annotation

GitHub Actions / build

Unexpected any. Specify a different type
};

export async function dispatchRequest<Input, Output>(
method: string,
targetUrl: string,
input: Input
input: Input,
options: RequestOptions & RequestInit = {}
): Promise<Output> {
const {
credentials: credentialsValue,
Expand Down Expand Up @@ -39,14 +45,21 @@
...userAgent,
...(headers ?? {}),
} as HeadersInit;

const { responseHandler: customResponseHandler, ...requestInit } = options;
const response = await fetch(url, {
...requestInit,
method,
headers: requestHeaders,
headers: {
...requestHeaders,
...(requestInit.headers ?? {}),
},
...(!isCloudflareWorkers && { mode: 'cors' }),
body:
method.toLowerCase() !== 'get' && input
? JSON.stringify(input)
: undefined,
});
return await responseHandler(response);
const handleResponse = customResponseHandler ?? responseHandler;
return await handleResponse(response);
}
Loading
Loading