From d264e844f9aa19e3601319cc35aac42309be96cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daphn=C3=A9=20Popin?= Date: Tue, 28 Nov 2023 18:13:05 +0100 Subject: [PATCH] Batch render user message (#2692) * Batch render user messages * Render batch for agent messages & content fragment * apply feedback * Batch render user message diff (#2686) * code simplification * nit * clean-up * Global agents are loaded from getAgetnConfiguration * Update swr * Don't load full messages on reactions route * Fix agent message retry * Apply feedback --------- Co-authored-by: Stanislas Polu --- front/lib/api/assistant/conversation.ts | 343 ++++++++++++++---------- front/lib/api/assistant/reaction.ts | 1 + front/package-lock.json | 2 +- front/package.json | 2 +- 4 files changed, 210 insertions(+), 138 deletions(-) diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index 170533a63250..59c80dfb815e 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -25,6 +25,7 @@ import { GPT_3_5_TURBO_MODEL_CONFIG } from "@app/lib/assistant"; import { Authenticator } from "@app/lib/auth"; import { front_sequelize } from "@app/lib/databases"; import { + AgentDustAppRunAction, AgentMessage, Conversation, ConversationParticipant, @@ -38,6 +39,7 @@ import { updateWorkspacePerMonthlyActiveUsersSubscriptionUsage } from "@app/lib/ import { Err, Ok, Result } from "@app/lib/result"; import { generateModelSId } from "@app/lib/utils"; import logger from "@app/logger/logger"; +import { AgentConfigurationType } from "@app/types/assistant/agent"; import { AgentMessageType, ContentFragmentContentType, @@ -57,7 +59,6 @@ import { import { PlanType } from "@app/types/plan"; import { WorkspaceType } from "@app/types/user"; -import { renderDustAppRunActionByModelId } from "./actions/dust_app_run"; import { renderRetrievalActionByModelId } from "./actions/retrieval"; /** * Conversation Creation, update and deletion @@ -182,14 +183,14 @@ export async function deleteConversation( * Conversation Rendering */ -async function renderUserMessage( - message: Message, - userMessage: UserMessage -): Promise { - const [mentions, user] = await Promise.all([ +async function batchRenderUserMessages(messages: Message[]) { + const userMessages = messages.filter( + (m) => m.userMessage !== null && m.userMessage !== undefined + ); + const [mentions, users] = await Promise.all([ Mention.findAll({ where: { - messageId: message.id, + messageId: userMessages.map((m) => m.id), }, include: [ { @@ -200,126 +201,191 @@ async function renderUserMessage( ], }), (async () => { - if (userMessage.userId) { - return await User.findOne({ - where: { - id: userMessage.userId, - }, - }); + const userIds = userMessages + .map((m) => m.userMessage?.userId) + .filter((id) => !!id) as number[]; + if (userIds.length === 0) { + return []; } - return null; + return await User.findAll({ + where: { + id: userIds, + }, + }); })(), ]); - return { - id: message.id, - sId: message.sId, - type: "user_message", - visibility: message.visibility, - version: message.version, - created: message.createdAt.getTime(), - user: user - ? { - id: user.id, - provider: user.provider, - providerId: user.providerId, - username: user.username, - email: user.email, - firstName: user.firstName, - lastName: user.lastName, - fullName: user.firstName + (user.lastName ? ` ${user.lastName}` : ""), - image: null, - workspaces: [], + return userMessages.map((message) => { + if (!message.userMessage) { + throw new Error( + "Unreachable: batchRenderUserMessages has been filtered on user messages" + ); + } + const userMessage = message.userMessage; + const messageMentions = mentions.filter((m) => m.messageId === message.id); + const user = users.find((u) => u.id === userMessage.userId) || null; + + const m = { + id: message.id, + sId: message.sId, + type: "user_message", + visibility: message.visibility, + version: message.version, + created: message.createdAt.getTime(), + user: user + ? { + id: user.id, + provider: user.provider, + providerId: user.providerId, + username: user.username, + email: user.email, + firstName: user.firstName, + lastName: user.lastName, + fullName: + user.firstName + (user.lastName ? ` ${user.lastName}` : ""), + image: null, + workspaces: [], + } + : null, + mentions: messageMentions.map((m) => { + if (m.agentConfigurationId) { + return { + configurationId: m.agentConfigurationId, + }; } - : null, - mentions: mentions.map((m) => { - if (m.agentConfigurationId) { - return { - configurationId: m.agentConfigurationId, - }; - } - if (m.user) { - return { - provider: m.user.provider, - providerId: m.user.providerId, - }; - } - throw new Error("Unreachable: mention must be either agent or user"); - }), - content: userMessage.content, - context: { - username: userMessage.userContextUsername, - timezone: userMessage.userContextTimezone, - fullName: userMessage.userContextFullName, - email: userMessage.userContextEmail, - profilePictureUrl: userMessage.userContextProfilePictureUrl, - }, - }; + if (m.user) { + return { + provider: m.user.provider, + providerId: m.user.providerId, + }; + } + throw new Error("Unreachable: mention must be either agent or user"); + }), + content: userMessage.content, + context: { + username: userMessage.userContextUsername, + timezone: userMessage.userContextTimezone, + fullName: userMessage.userContextFullName, + email: userMessage.userContextEmail, + profilePictureUrl: userMessage.userContextProfilePictureUrl, + }, + }; + return { m, rank: message.rank, version: message.version }; + }); } -async function renderAgentMessage( +async function batchRenderAgentMessages( auth: Authenticator, - { - message, - agentMessage, - messages, - }: { message: Message; agentMessage: AgentMessage; messages: Message[] } -): Promise { - const [agentConfiguration, agentRetrievalAction, agentDustAppRunAction] = + messages: Message[] +) { + const agentMessages = messages.filter((m) => !!m.agentMessage); + const userMessages = messages.filter((m) => !!m.userMessage); + + const [agentConfigurations, agentRetrievalActions, agentDustAppRunActions] = await Promise.all([ - getAgentConfiguration(auth, agentMessage.agentConfigurationId), (async () => { - if (agentMessage.agentRetrievalActionId) { - return await renderRetrievalActionByModelId( - agentMessage.agentRetrievalActionId - ); - } - return null; + const agentConfigurationIds: string[] = agentMessages.reduce( + (acc: string[], m) => { + const agentId = m.agentMessage?.agentConfigurationId; + if (agentId && !acc.includes(agentId)) { + acc.push(agentId); + } + return acc; + }, + [] + ); + const agents = ( + await Promise.all( + agentConfigurationIds.map((agentConfigId) => { + return getAgentConfiguration(auth, agentConfigId); + }) + ) + ).filter((a) => a !== null) as AgentConfigurationType[]; + return agents; })(), (async () => { - if (agentMessage.agentDustAppRunActionId) { - return await renderDustAppRunActionByModelId( - agentMessage.agentDustAppRunActionId - ); - } - return null; + return await Promise.all( + agentMessages + .filter((m) => m.agentMessage?.agentRetrievalActionId) + .map((m) => { + return renderRetrievalActionByModelId( + m.agentMessage?.agentRetrievalActionId as number + ); + }) + ); + })(), + (async () => { + const actions = await AgentDustAppRunAction.findAll({ + where: { + id: { + [Op.in]: agentMessages + .filter((m) => m.agentMessage?.agentDustAppRunActionId) + .map((m) => m.agentMessage?.agentDustAppRunActionId as number), + }, + }, + }); + return actions.map((action) => { + return { + id: action.id, + type: "dust_app_run_action", + appWorkspaceId: action.appWorkspaceId, + appId: action.appId, + appName: action.appName, + params: action.params, + runningBlock: null, + output: action.output, + }; + }); })(), ]); - if (!agentConfiguration) { - throw new Error( - `Configuration ${agentMessage.agentConfigurationId} not found` + return agentMessages.map((message) => { + if (!message.agentMessage) { + throw new Error( + "Unreachable: batchRenderAgentMessages has been filtered on agent message" + ); + } + const agentMessage = message.agentMessage; + const action = + agentRetrievalActions.find( + (a) => a.id === agentMessage?.agentRetrievalActionId + ) ?? + agentDustAppRunActions.find( + (a) => a.id === agentMessage.agentDustAppRunActionId + ); + const agentConfiguration = agentConfigurations.find( + (a) => a.sId === message.agentMessage?.agentConfigurationId ); - } - const action = agentRetrievalAction ?? agentDustAppRunAction; + let error: { + code: string; + message: string; + } | null = null; - let error: { - code: string; - message: string; - } | null = null; - if (agentMessage.errorCode !== null && agentMessage.errorMessage !== null) { - error = { - code: agentMessage.errorCode, - message: agentMessage.errorMessage, - }; - } + if (agentMessage.errorCode !== null && agentMessage.errorMessage !== null) { + error = { + code: agentMessage.errorCode, + message: agentMessage.errorMessage, + }; + } - return { - id: message.id, - sId: message.sId, - created: message.createdAt.getTime(), - type: "agent_message", - visibility: message.visibility, - version: message.version, - parentMessageId: - messages.find((m) => m.id === message.parentId)?.sId ?? null, - status: agentMessage.status, - action, - content: agentMessage.content, - error, - configuration: agentConfiguration, - }; + const m = { + id: message.id, + sId: message.sId, + created: message.createdAt.getTime(), + type: "agent_message", + visibility: message.visibility, + version: message.version, + parentMessageId: + userMessages.find((m) => m.id === message.parentId)?.sId ?? null, + status: agentMessage.status, + action, + content: agentMessage.content, + error, + configuration: agentConfiguration, + }; + return { m, rank: message.rank, version: message.version }; + }); } function renderContentFragment({ @@ -349,6 +415,32 @@ function renderContentFragment({ }; } +async function batchRenderContentFragment(messages: Message[]) { + const messagesWithContentFragment = messages.filter( + (m) => !!m.contentFragment + ); + if (messagesWithContentFragment.find((m) => !m.contentFragment)) { + throw new Error( + "Unreachable: batchRenderContentFragment must be called with only content fragments" + ); + } + + return messagesWithContentFragment.map((message) => { + if (!message.contentFragment) { + throw new Error( + "Unreachable: batchRenderContentFragment must be called with only content fragments" + ); + } + const contentFragment = message.contentFragment; + + return { + m: renderContentFragment({ message, contentFragment }), + rank: message.rank, + version: message.version, + }; + }); +} + export async function getUserConversations( auth: Authenticator, includeDeleted?: boolean @@ -455,34 +547,13 @@ export async function getConversation( ], }); - const render = await Promise.all( - messages.map((message) => { - return (async () => { - if (message.userMessage) { - const m = await renderUserMessage(message, message.userMessage); - return { m, rank: message.rank, version: message.version }; - } - if (message.agentMessage) { - const m = await renderAgentMessage(auth, { - message, - agentMessage: message.agentMessage, - messages, - }); - return { m, rank: message.rank, version: message.version }; - } - if (message.contentFragment) { - const m = await renderContentFragment({ - message: message, - contentFragment: message.contentFragment, - }); - return { m, rank: message.rank, version: message.version }; - } - throw new Error( - "Unreachable: message must be either user, agent or content fragment" - ); - })(); - }) - ); + const [userMessages, agentMessages, contentFragments] = await Promise.all([ + batchRenderUserMessages(messages), + batchRenderAgentMessages(auth, messages), + batchRenderContentFragment(messages), + ]); + + const render = [...userMessages, ...agentMessages, ...contentFragments]; render.sort((a, b) => { if (a.rank !== b.rank) { return a.rank - b.rank; diff --git a/front/lib/api/assistant/reaction.ts b/front/lib/api/assistant/reaction.ts index 9dc29b0e376d..cc5e6019b833 100644 --- a/front/lib/api/assistant/reaction.ts +++ b/front/lib/api/assistant/reaction.ts @@ -24,6 +24,7 @@ export async function getMessageReactions( where: { conversationId: conversation.id, }, + attributes: ["sId"], include: [ { model: MessageReaction, diff --git a/front/package-lock.json b/front/package-lock.json index 7915ac59fdfa..772a153c711c 100644 --- a/front/package-lock.json +++ b/front/package-lock.json @@ -67,7 +67,7 @@ "sqlite3": "^5.1.4", "sse.js": "^0.6.1", "stripe": "^14.2.0", - "swr": "^2.0.2", + "swr": "^2.2.4", "tailwind-scrollbar-hide": "^1.1.7", "tailwindcss": "^3.2.4", "three": "^0.155.0", diff --git a/front/package.json b/front/package.json index 32a69536ce4a..fa6704bff4d6 100644 --- a/front/package.json +++ b/front/package.json @@ -75,7 +75,7 @@ "sqlite3": "^5.1.4", "sse.js": "^0.6.1", "stripe": "^14.2.0", - "swr": "^2.0.2", + "swr": "^2.2.4", "tailwind-scrollbar-hide": "^1.1.7", "tailwindcss": "^3.2.4", "three": "^0.155.0",