Skip to content

Commit

Permalink
feat: support streaming for Gemini Pro (ChatGPTNextWeb#3688)
Browse files Browse the repository at this point in the history
* feat: support streaming for Gemini Pro

* feat: display texts smoothly

* chore: remove comments
  • Loading branch information
fredliang44 authored Dec 28, 2023
1 parent 3ba5986 commit 5cf58d9
Showing 1 changed file with 48 additions and 80 deletions.
128 changes: 48 additions & 80 deletions app/client/platforms/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export class GeminiProApi implements LLMApi {
);
}
async chat(options: ChatOptions): Promise<void> {
const apiClient = this;
const messages = options.messages.map((v) => ({
role: v.role.replace("assistant", "model").replace("system", "user"),
parts: [{ text: v.content }],
Expand Down Expand Up @@ -61,8 +62,7 @@ export class GeminiProApi implements LLMApi {

console.log("[Request] google payload: ", requestPayload);

// todo: support stream later
const shouldStream = false;
const shouldStream = !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);
try {
Expand All @@ -82,13 +82,21 @@ export class GeminiProApi implements LLMApi {
if (shouldStream) {
let responseText = "";
let remainText = "";
let streamChatPath = chatPath.replace(
"generateContent",
"streamGenerateContent",
);
let finished = false;
const finish = () => {
finished = true;
options.onFinish(responseText + remainText);
};

// animate response to make it looks smooth
function animateResponseText() {
if (finished || controller.signal.aborted) {
responseText += remainText;
console.log("[Response Animation] finished");
finish();
return;
}

Expand All @@ -105,88 +113,41 @@ export class GeminiProApi implements LLMApi {

// start animaion
animateResponseText();
fetch(streamChatPath, chatPayload)
.then((response) => {
const reader = response?.body?.getReader();
const decoder = new TextDecoder();
let partialData = "";

return reader?.read().then(function processText({
done,
value,
}): Promise<any> {
if (done) {
console.log("Stream complete");
// options.onFinish(responseText + remainText);
finished = true;
return Promise.resolve();
}

const finish = () => {
if (!finished) {
finished = true;
options.onFinish(responseText + remainText);
}
};
partialData += decoder.decode(value, { stream: true });

controller.signal.onabort = finish;

fetchEventSource(chatPath, {
...chatPayload,
async onopen(res) {
clearTimeout(requestTimeoutId);
const contentType = res.headers.get("content-type");
console.log(
"[OpenAI] request response content type: ",
contentType,
);

if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text();
return finish();
}

if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}

if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
let data = JSON.parse(ensureProperEnding(partialData));
console.log(data);
let fetchText = apiClient.extractMessage(data[data.length - 1]);
console.log("[Response Animation] fetchText: ", fetchText);
remainText += fetchText;
} catch (error) {
// skip error message when parsing json
}

if (extraInfo) {
responseTexts.push(extraInfo);
}

responseText = responseTexts.join("\n\n");

return finish();
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
}
const text = msg.data;
try {
const json = JSON.parse(text) as {
choices: Array<{
delta: {
content: string;
};
}>;
};
const delta = json.choices[0]?.delta?.content;
if (delta) {
remainText += delta;
}
} catch (e) {
console.error("[Request] parse error", text);
}
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
throw e;
},
openWhenHidden: true,
});
return reader.read().then(processText);
});
})
.catch((error) => {
console.error("Error:", error);
});
} else {
const res = await fetch(chatPath, chatPayload);
clearTimeout(requestTimeoutId);
Expand Down Expand Up @@ -220,3 +181,10 @@ export class GeminiProApi implements LLMApi {
return "/api/google/" + path;
}
}

function ensureProperEnding(str: string) {
if (str.startsWith("[") && !str.endsWith("]")) {
return str + "]";
}
return str;
}

0 comments on commit 5cf58d9

Please sign in to comment.