Skip to content

Commit

Permalink
fix: Resolve engine halting issue due to service worker being killed
Browse files Browse the repository at this point in the history
  • Loading branch information
Neet-Nestor committed May 16, 2024
1 parent bc848ed commit c5d9470
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 108 deletions.
2 changes: 0 additions & 2 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,5 @@ export interface LLMModelProvider {
export abstract class LLMApi {
abstract chat(options: ChatOptions): Promise<void>;
abstract usage(): Promise<LLMUsage>;
abstract models(): Promise<LLMModel[]>;
abstract abort(): Promise<void>;
abstract clear(): void;
}
179 changes: 109 additions & 70 deletions app/client/webllm.ts
Original file line number Diff line number Diff line change
@@ -1,92 +1,143 @@
import { createContext } from "react";
import {
CreateWebServiceWorkerEngine,
InitProgressReport,
prebuiltAppConfig,
ChatCompletionMessageParam,
WebServiceWorkerEngine,
WebServiceWorker,
ChatCompletionChunk,
ChatCompletion,
} from "@neet-nestor/web-llm";

import { ChatOptions, LLMApi, LLMConfig } from "./api";
import { ChatOptions, LLMApi, LLMConfig, RequestMessage } from "./api";

const KEEP_ALIVE_INTERVAL = 10000;

export class WebLLMApi implements LLMApi {
private currentModel?: string;
private engine?: WebServiceWorkerEngine;
private llmConfig?: LLMConfig;
engine?: WebServiceWorkerEngine;

constructor(onEngineCrash: () => void) {
setInterval(() => {
if ((this.engine?.missedHeatbeat || 0) > 2) {
onEngineCrash?.();
}
}, 10000);
constructor() {
this.engine = new WebServiceWorkerEngine(new WebServiceWorker());
this.engine.keepAlive(
window.location.href + "ping.txt",
KEEP_ALIVE_INTERVAL,
);
}

async initModel(onUpdate?: (message: string, chunk: string) => void) {
if (!this.llmConfig) {
throw Error("llmConfig is undefined");
}
if (!this.engine) {
this.engine = new WebServiceWorkerEngine(new WebServiceWorker());
}
let hasResponse = false;
this.engine.setInitProgressCallback((report: InitProgressReport) => {
onUpdate?.(report.text, report.text);
hasResponse = true;
});
let initRequest = this.engine.init(this.llmConfig.model, this.llmConfig, {
...prebuiltAppConfig,
useIndexedDBCache: this.llmConfig.cache === "index_db",
});
// In case the service worker is dead, init will halt indefinitely
// so we manually retry if timeout
let retry = 0;
let engine = this.engine;
let llmConfig = this.llmConfig;
let retryInterval: NodeJS.Timeout;

await new Promise<void>((resolve, reject) => {
retryInterval = setInterval(() => {
if (hasResponse) {
clearInterval(retryInterval);
initRequest.then(resolve);
return;
}
if (retry >= 5) {
clearInterval(retryInterval);
reject("Model initialization timed out for too many times");
return;
}
retry += 1;
initRequest = engine.init(llmConfig.model, llmConfig, {
...prebuiltAppConfig,
useIndexedDBCache: llmConfig.cache === "index_db",
});
}, 5000);
});
}

clear() {
this.engine = undefined;
isConfigChanged(config: LLMConfig) {
return (
this.llmConfig?.model !== config.model ||
this.llmConfig?.cache !== config.cache ||
this.llmConfig?.temperature !== config.temperature ||
this.llmConfig?.top_p !== config.top_p ||
this.llmConfig?.presence_penalty !== config.presence_penalty ||
this.llmConfig?.frequency_penalty !== config.frequency_penalty
);
}

async initModel(
config: LLMConfig,
async chatCompletion(
stream: boolean,
messages: RequestMessage[],
onUpdate?: (message: string, chunk: string) => void,
) {
this.currentModel = config.model;
this.engine = await CreateWebServiceWorkerEngine(config.model, {
chatOpts: {
temperature: config.temperature,
top_p: config.top_p,
presence_penalty: config.presence_penalty,
frequency_penalty: config.frequency_penalty,
},
appConfig: {
...prebuiltAppConfig,
useIndexedDBCache: config.cache === "index_db",
},
initProgressCallback: (report: InitProgressReport) => {
onUpdate?.(report.text, report.text);
},
let reply: string | null = "";

const completion = await this.engine!.chatCompletion({
stream: stream,
messages: messages as ChatCompletionMessageParam[],
});

if (stream) {
const asyncGenerator = completion as AsyncIterable<ChatCompletionChunk>;
for await (const chunk of asyncGenerator) {
if (chunk.choices[0].delta.content) {
reply += chunk.choices[0].delta.content;
onUpdate?.(reply, chunk.choices[0].delta.content);
}
}
return reply;
}
return (completion as ChatCompletion).choices[0].message.content;
}

async chat(options: ChatOptions): Promise<void> {
if (options.config.model !== this.currentModel) {
// in case the service worker is dead, revive it by firing a fetch event
fetch("/ping.txt");

if (this.isConfigChanged(options.config)) {
this.llmConfig = options.config;
try {
await this.initModel(options.config, options.onUpdate);
await this.initModel(options.onUpdate);
} catch (e) {
console.error("Error in initModel", e);
}
}

let reply: string | null = "";
if (options.config.stream) {
try {
const asyncChunkGenerator = await this.engine!.chatCompletion({
stream: options.config.stream,
messages: options.messages as ChatCompletionMessageParam[],
});

for await (const chunk of asyncChunkGenerator) {
if (chunk.choices[0].delta.content) {
reply += chunk.choices[0].delta.content;
options.onUpdate?.(reply, chunk.choices[0].delta.content);
}
}
} catch (err) {
console.error("Error in streaming chatCompletion", err);
options.onError?.(err as Error);
return;
}
} else {
try {
const completion = await this.engine!.chatCompletion({
stream: options.config.stream,
messages: options.messages as ChatCompletionMessageParam[],
});
reply = completion.choices[0].message.content;
} catch (err) {
console.error("Error in non-streaming chatCompletion", err);
try {
reply = await this.chatCompletion(
!!options.config.stream,
options.messages,
options.onUpdate,
);
} catch (err: any) {
if (err.toString().includes("Please call `Engine.reload(model)` first")) {
console.error("Error in chatCompletion", err);
options.onError?.(err as Error);
return;
}
// Service worker has been stopped. Restart it
await this.initModel(options.onUpdate);
reply = await this.chatCompletion(
!!options.config.stream,
options.messages,
options.onUpdate,
);
}

if (reply) {
Expand All @@ -106,18 +157,6 @@ export class WebLLMApi implements LLMApi {
total: 0,
};
}

async models() {
return prebuiltAppConfig.model_list.map((record) => ({
name: record.model_id,
available: true,
provider: {
id: "huggingface",
providerName: "huggingface",
providerType: "huggingface",
},
}));
}
}

export const WebLLMContext = createContext<WebLLMApi | null>(null);
9 changes: 4 additions & 5 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ import { ExportMessageModal } from "./exporter";
import { getClientConfig } from "../config/client";
import { useAllModels } from "../utils/hooks";
import { MultimodalContent } from "../client/api";
import { WebLLMApi, WebLLMContext } from "../client/webllm";
import { WebLLMContext } from "../client/webllm";

const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
loading: () => <LoadingIcon />,
Expand Down Expand Up @@ -682,8 +682,7 @@ function _Chat() {
const navigate = useNavigate();
const [attachImages, setAttachImages] = useState<string[]>([]);
const [uploading, setUploading] = useState(false);

const webllm = useContext(WebLLMContext);
const webllm = useContext(WebLLMContext)!;

// prompt hints
const promptStore = usePromptStore();
Expand Down Expand Up @@ -764,7 +763,7 @@ function _Chat() {
if (isStreaming) return;
setIsLoading(true);
chatStore
.onUserInput(userInput, webllm!, attachImages)
.onUserInput(userInput, webllm, attachImages)
.then(() => setIsLoading(false));
setAttachImages([]);
localStorage.setItem(LAST_INPUT_KEY, userInput);
Expand Down Expand Up @@ -922,7 +921,7 @@ function _Chat() {
const textContent = getMessageTextContent(userMessage);
const images = getMessageImages(userMessage);
chatStore
.onUserInput(textContent, webllm!, images)
.onUserInput(textContent, webllm, images)
.then(() => setIsLoading(false));
inputRef.current?.focus();
};
Expand Down
53 changes: 35 additions & 18 deletions app/components/home.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { useAppConfig } from "../store/config";
import { getClientConfig } from "../config/client";
import { WebLLMApi, WebLLMContext } from "../client/webllm";
import Locale from "../locales";
import { prebuiltAppConfig } from "@neet-nestor/web-llm";

export function Loading(props: { noLogo?: boolean }) {
return (
Expand Down Expand Up @@ -177,40 +178,56 @@ function Screen() {
);
}

export function useLoadData(webllm: WebLLMApi) {
export function useLoadData() {
const config = useAppConfig();

useEffect(() => {
(async () => {
if (webllm) {
const models = await webllm.models();
config.mergeModels(models);
}
})();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [webllm]);
config.mergeModels(
prebuiltAppConfig.model_list.map((record) => ({
name: record.model_id,
available: true,
provider: {
id: "huggingface",
providerName: "huggingface",
providerType: "huggingface",
},
})),
);
}, []);
}

const useWebLLM = () => {
const [webllm, setWebLLM] = useState<WebLLMApi | null>(null);
const [isSWAlive, setSWAlive] = useState(true);

useEffect(() => {
setWebLLM(new WebLLMApi());
}, []);

setInterval(() => {
if (webllm) {
// 10s per heartbeat, dead after 1 min of inactivity
setSWAlive(!!webllm.engine && webllm.engine.missedHeatbeat < 6);
}
});

return { webllm, isWebllmAlive: isSWAlive };
};

export function Home() {
const hasHydrated = useHasHydrated();
const isServiceWorkerReady = useServiceWorkerReady();
const [isEngineCrash, setEngineCrash] = useState(false);

const webllm = useMemo(() => {
return new WebLLMApi(() => {
setEngineCrash(true);
});
}, []);
const { webllm, isWebllmAlive } = useWebLLM();

useLoadData(webllm);
useLoadData();
useSwitchTheme();
useHtmlLang();

if (!hasHydrated || !isServiceWorkerReady) {
return <Loading />;
}

if (isEngineCrash) {
if (!isWebllmAlive) {
return <ErrorScreen message={Locale.ServiceWorker.Error} />;
}

Expand Down
1 change: 0 additions & 1 deletion app/components/settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,6 @@ export function Settings() {
<Select
value="cache"
onChange={(e) => {
webllm?.clear();
updateConfig(
(config) =>
(config.cacheType = e.currentTarget
Expand Down
Loading

0 comments on commit c5d9470

Please sign in to comment.