Skip to content

Commit

Permalink
⚡️ perf: disable smoohting by default in fetchSSE (lobehub#3787)
Browse files Browse the repository at this point in the history
* ⚡️ perf: disable smoohting by default in fetchSSE

* ⚡️ perf: disable smoohting by default in fetchSSE

* ⚡️ perf: add smoothing config for Google
  • Loading branch information
arvinxx authored Sep 6, 2024
1 parent c22b886 commit 24cacf2
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 12 deletions.
4 changes: 4 additions & 0 deletions src/config/modelProviders/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ const Google: ModelProviderCard = {
proxyUrl: {
placeholder: 'https://generativelanguage.googleapis.com',
},
smoothing: {
speed: 2,
text: true,
},
};

export default Google;
4 changes: 4 additions & 0 deletions src/services/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { produce } from 'immer';
import { merge } from 'lodash-es';

import { createErrorResponse } from '@/app/api/errorResponse';
import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';
import { INBOX_GUIDE_SYSTEMROLE } from '@/const/guide';
import { INBOX_SESSION_ID } from '@/const/session';
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
Expand Down Expand Up @@ -304,6 +305,8 @@ class ChatService {
provider,
});

const providerConfig = DEFAULT_MODEL_PROVIDER_LIST.find((item) => item.id === provider);

return fetchSSE(API_ENDPOINTS.chat(provider), {
body: JSON.stringify(payload),
fetcher: fetcher,
Expand All @@ -314,6 +317,7 @@ class ChatService {
onFinish: options?.onFinish,
onMessageHandle: options?.onMessageHandle,
signal,
smoothing: providerConfig?.smoothing,
});
};

Expand Down
11 changes: 11 additions & 0 deletions src/types/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ export interface ChatModelCard {
vision?: boolean;
}

export interface SmoothingParams {
speed?: number;
text?: boolean;
toolsCalling?: boolean;
}

export interface ModelProviderCard {
chatModels: ChatModelCard[];
/**
Expand Down Expand Up @@ -89,6 +95,11 @@ export interface ModelProviderCard {
* so provider like ollama don't need api key field
*/
showApiKey?: boolean;

/**
* whether to smoothing the output
*/
smoothing?: SmoothingParams;
}

// 语言模型的设置参数
Expand Down
3 changes: 3 additions & 0 deletions src/utils/fetch/__tests__/fetchSSE.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ describe('fetchSSE', () => {
await fetchSSE('/', {
onMessageHandle: mockOnMessageHandle,
onFinish: mockOnFinish,
smoothing: true,
});

expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'Hell', type: 'text' });
Expand Down Expand Up @@ -183,6 +184,7 @@ describe('fetchSSE', () => {
await fetchSSE('/', {
onMessageHandle: mockOnMessageHandle,
onFinish: mockOnFinish,
smoothing: true,
});

// TODO: need to check whether the `aarg1` is correct
Expand Down Expand Up @@ -234,6 +236,7 @@ describe('fetchSSE', () => {
onMessageHandle: mockOnMessageHandle,
onFinish: mockOnFinish,
signal: abortController.signal,
smoothing: true,
});

expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'Hell', type: 'text' });
Expand Down
36 changes: 24 additions & 12 deletions src/utils/fetch/fetchSSE.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { MESSAGE_CANCEL_FLAT } from '@/const/message';
import { LOBE_CHAT_OBSERVATION_ID, LOBE_CHAT_TRACE_ID } from '@/const/trace';
import { ChatErrorType } from '@/types/fetch';
import { SmoothingParams } from '@/types/llm';
import {
ChatMessageError,
MessageToolCall,
Expand Down Expand Up @@ -41,14 +42,19 @@ export interface FetchSSEOptions {
onErrorHandle?: (error: ChatMessageError) => void;
onFinish?: OnFinishHandler;
onMessageHandle?: (chunk: MessageTextChunk | MessageToolCallsChunk) => void;
smoothing?: boolean;
smoothing?: SmoothingParams | boolean;
}

const START_ANIMATION_SPEED = 4;

const END_ANIMATION_SPEED = 15;

const createSmoothMessage = (params: { onTextUpdate: (delta: string, text: string) => void }) => {
const createSmoothMessage = (params: {
onTextUpdate: (delta: string, text: string) => void;
startSpeed?: number;
}) => {
const { startSpeed = START_ANIMATION_SPEED } = params;

let buffer = '';
// why use queue: https://shareg.pt/GLBrjpK
let outputQueue: string[] = [];
Expand All @@ -66,7 +72,7 @@ const createSmoothMessage = (params: { onTextUpdate: (delta: string, text: strin

// define startAnimation function to display the text in buffer smooth
// when you need to start the animation, call this function
const startAnimation = (speed = START_ANIMATION_SPEED) =>
const startAnimation = (speed = startSpeed) =>
new Promise<void>((resolve) => {
if (isAnimationActive) {
resolve();
Expand Down Expand Up @@ -122,7 +128,9 @@ const createSmoothMessage = (params: { onTextUpdate: (delta: string, text: strin

const createSmoothToolCalls = (params: {
onToolCallsUpdate: (toolCalls: MessageToolCall[], isAnimationActives: boolean[]) => void;
startSpeed?: number;
}) => {
const { startSpeed = START_ANIMATION_SPEED } = params;
let toolCallsBuffer: MessageToolCall[] = [];

// 为每个 tool_call 维护一个输出队列和动画控制器
Expand All @@ -139,7 +147,7 @@ const createSmoothToolCalls = (params: {
}
};

const startAnimation = (index: number, speed = START_ANIMATION_SPEED) =>
const startAnimation = (index: number, speed = startSpeed) =>
new Promise<void>((resolve) => {
if (isAnimationActives[index]) {
resolve();
Expand Down Expand Up @@ -194,7 +202,7 @@ const createSmoothToolCalls = (params: {
});
};

const startAnimations = async (speed = START_ANIMATION_SPEED) => {
const startAnimations = async (speed = startSpeed) => {
const pools = toolCallsBuffer.map(async (_, index) => {
if (outputQueues[index].length > 0 && !isAnimationActives[index]) {
await startAnimation(index, speed);
Expand Down Expand Up @@ -230,19 +238,26 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio
let finishedType: SSEFinishType = 'done';
let response!: Response;

const { smoothing = true } = options;
const { smoothing } = options;

const textSmoothing = typeof smoothing === 'boolean' ? smoothing : smoothing?.text;
const toolsCallingSmoothing =
typeof smoothing === 'boolean' ? smoothing : (smoothing?.toolsCalling ?? true);
const smoothingSpeed = typeof smoothing === 'object' ? smoothing.speed : undefined;

const textController = createSmoothMessage({
onTextUpdate: (delta, text) => {
output = text;
options.onMessageHandle?.({ text: delta, type: 'text' });
},
startSpeed: smoothingSpeed,
});

const toolCallsController = createSmoothToolCalls({
onToolCallsUpdate: (toolCalls, isAnimationActives) => {
options.onMessageHandle?.({ isAnimationActives, tool_calls: toolCalls, type: 'tool_calls' });
},
startSpeed: smoothingSpeed,
});

await fetchEventSource(url, {
Expand Down Expand Up @@ -305,7 +320,7 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio
}

case 'text': {
if (smoothing) {
if (textSmoothing) {
textController.pushToQueue(data);

if (!textController.isAnimationActive) textController.startAnimation();
Expand All @@ -323,7 +338,7 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio
if (!toolCalls) toolCalls = [];
toolCalls = parseToolCalls(toolCalls, data);

if (smoothing) {
if (toolsCallingSmoothing) {
// make the tool calls smooth

// push the tool calls to the smooth queue
Expand All @@ -333,10 +348,7 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio
toolCallsController.startAnimations();
}
} else {
options.onMessageHandle?.({
tool_calls: toolCalls,
type: 'tool_calls',
});
options.onMessageHandle?.({ tool_calls: toolCalls, type: 'tool_calls' });
}
}
}
Expand Down

0 comments on commit 24cacf2

Please sign in to comment.