diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index d4cd85d348f4..170533a63250 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -25,7 +25,6 @@ import { GPT_3_5_TURBO_MODEL_CONFIG } from "@app/lib/assistant"; import { Authenticator } from "@app/lib/auth"; import { front_sequelize } from "@app/lib/databases"; import { - AgentDustAppRunAction, AgentMessage, Conversation, ConversationParticipant, @@ -39,7 +38,6 @@ import { updateWorkspacePerMonthlyActiveUsersSubscriptionUsage } from "@app/lib/ import { Err, Ok, Result } from "@app/lib/result"; import { generateModelSId } from "@app/lib/utils"; import logger from "@app/logger/logger"; -import { AgentConfigurationType } from "@app/types/assistant/agent"; import { AgentMessageType, ContentFragmentContentType, @@ -59,6 +57,7 @@ import { import { PlanType } from "@app/types/plan"; import { WorkspaceType } from "@app/types/user"; +import { renderDustAppRunActionByModelId } from "./actions/dust_app_run"; import { renderRetrievalActionByModelId } from "./actions/retrieval"; /** * Conversation Creation, update and deletion @@ -183,17 +182,14 @@ export async function deleteConversation( * Conversation Rendering */ -async function batchRenderUserMessages(messages: Message[]) { - if (messages.find((m) => !m.userMessage)) { - throw new Error( - "Unreachable: batchRenderUserMessages must be called with only user messages" - ); - } - - const [mentions, users] = await Promise.all([ +async function renderUserMessage( + message: Message, + userMessage: UserMessage +): Promise { + const [mentions, user] = await Promise.all([ Mention.findAll({ where: { - messageId: messages.map((m) => m.id), + messageId: message.id, }, include: [ { @@ -204,194 +200,126 @@ async function batchRenderUserMessages(messages: Message[]) { ], }), (async () => { - const userIds = messages - .map((m) => m.userMessage?.userId) - .filter((id) => !!id) as number[]; - if (userIds.length === 0) { - return []; + if (userMessage.userId) { + return await User.findOne({ + where: { + id: userMessage.userId, + }, + }); } - return await User.findAll({ - where: { - id: userIds, - }, - }); + return null; })(), ]); - return messages.map((message) => { - if (!message.userMessage) { - throw new Error( - "Unreachable: batchRenderUserMessages must be called with only user messages" - ); - } - const userMessage = message.userMessage; - const messageMentions = mentions.filter((m) => m.messageId === message.id); - const user = users.find((u) => u.id === userMessage.userId) || null; - - const m = { - id: message.id, - sId: message.sId, - type: "user_message", - visibility: message.visibility, - version: message.version, - created: message.createdAt.getTime(), - user: user - ? { - id: user.id, - provider: user.provider, - providerId: user.providerId, - username: user.username, - email: user.email, - firstName: user.firstName, - lastName: user.lastName, - fullName: - user.firstName + (user.lastName ? ` ${user.lastName}` : ""), - image: null, - workspaces: [], - } - : null, - mentions: messageMentions.map((m) => { - if (m.agentConfigurationId) { - return { - configurationId: m.agentConfigurationId, - }; - } - if (m.user) { - return { - provider: m.user.provider, - providerId: m.user.providerId, - }; + return { + id: message.id, + sId: message.sId, + type: "user_message", + visibility: message.visibility, + version: message.version, + created: message.createdAt.getTime(), + user: user + ? { + id: user.id, + provider: user.provider, + providerId: user.providerId, + username: user.username, + email: user.email, + firstName: user.firstName, + lastName: user.lastName, + fullName: user.firstName + (user.lastName ? ` ${user.lastName}` : ""), + image: null, + workspaces: [], } - throw new Error("Unreachable: mention must be either agent or user"); - }), - content: userMessage.content, - context: { - username: userMessage.userContextUsername, - timezone: userMessage.userContextTimezone, - fullName: userMessage.userContextFullName, - email: userMessage.userContextEmail, - profilePictureUrl: userMessage.userContextProfilePictureUrl, - }, - }; - return { m, rank: message.rank, version: message.version }; - }); + : null, + mentions: mentions.map((m) => { + if (m.agentConfigurationId) { + return { + configurationId: m.agentConfigurationId, + }; + } + if (m.user) { + return { + provider: m.user.provider, + providerId: m.user.providerId, + }; + } + throw new Error("Unreachable: mention must be either agent or user"); + }), + content: userMessage.content, + context: { + username: userMessage.userContextUsername, + timezone: userMessage.userContextTimezone, + fullName: userMessage.userContextFullName, + email: userMessage.userContextEmail, + profilePictureUrl: userMessage.userContextProfilePictureUrl, + }, + }; } -async function batchRenderAgentMessages( +async function renderAgentMessage( auth: Authenticator, - messages: Message[] -) { - if (messages.find((m) => !m.agentMessage)) { - throw new Error( - "Unreachable: batchRenderAgentMessages must be called with only agent messages" - ); - } - - const [agentConfigurations, agentRetrievalActions, agentDustAppRunActions] = + { + message, + agentMessage, + messages, + }: { message: Message; agentMessage: AgentMessage; messages: Message[] } +): Promise { + const [agentConfiguration, agentRetrievalAction, agentDustAppRunAction] = await Promise.all([ + getAgentConfiguration(auth, agentMessage.agentConfigurationId), (async () => { - const agentConfigurationIds: string[] = messages.reduce( - (acc: string[], m) => { - const agentId = m.agentMessage?.agentConfigurationId; - if (agentId && !acc.includes(agentId)) { - acc.push(agentId); - } - return acc; - }, - [] - ); - const agents = ( - await Promise.all( - agentConfigurationIds.map((agentConfigId) => { - return getAgentConfiguration(auth, agentConfigId); - }) - ) - ).filter((a) => a !== null) as AgentConfigurationType[]; - return agents; - })(), - (async () => { - return await Promise.all( - messages - .filter((m) => m.agentMessage?.agentRetrievalActionId) - .map((m) => { - return renderRetrievalActionByModelId( - m.agentMessage?.agentRetrievalActionId as number - ); - }) - ); + if (agentMessage.agentRetrievalActionId) { + return await renderRetrievalActionByModelId( + agentMessage.agentRetrievalActionId + ); + } + return null; })(), (async () => { - const actions = await AgentDustAppRunAction.findAll({ - where: { - id: { - [Op.in]: messages - .filter((m) => m.agentMessage?.agentDustAppRunActionId) - .map((m) => m.agentMessage?.agentDustAppRunActionId as number), - }, - }, - }); - return actions.map((action) => { - return { - id: action.id, - type: "dust_app_run_action", - appWorkspaceId: action.appWorkspaceId, - appId: action.appId, - appName: action.appName, - params: action.params, - runningBlock: null, - output: action.output, - }; - }); + if (agentMessage.agentDustAppRunActionId) { + return await renderDustAppRunActionByModelId( + agentMessage.agentDustAppRunActionId + ); + } + return null; })(), ]); - return messages.map((message) => { - if (!message.agentMessage) { - throw new Error( - "Unreachable: batchRenderUserMessages must be called with only user messages" - ); - } - const agentMessage = message.agentMessage; - const action = - agentRetrievalActions.find( - (a) => a.id === agentMessage?.agentRetrievalActionId - ) ?? - agentDustAppRunActions.find( - (a) => a.id === agentMessage.agentDustAppRunActionId - ); - const agentConfiguration = agentConfigurations.find( - (a) => a.sId === message.agentMessage?.agentConfigurationId + if (!agentConfiguration) { + throw new Error( + `Configuration ${agentMessage.agentConfigurationId} not found` ); + } - let error: { - code: string; - message: string; - } | null = null; - - if (agentMessage.errorCode !== null && agentMessage.errorMessage !== null) { - error = { - code: agentMessage.errorCode, - message: agentMessage.errorMessage, - }; - } + const action = agentRetrievalAction ?? agentDustAppRunAction; - const m = { - id: message.id, - sId: message.sId, - created: message.createdAt.getTime(), - type: "agent_message", - visibility: message.visibility, - version: message.version, - parentMessageId: - messages.find((m) => m.id === message.parentId)?.sId ?? null, - status: agentMessage.status, - action, - content: agentMessage.content, - error, - configuration: agentConfiguration, + let error: { + code: string; + message: string; + } | null = null; + if (agentMessage.errorCode !== null && agentMessage.errorMessage !== null) { + error = { + code: agentMessage.errorCode, + message: agentMessage.errorMessage, }; - return { m, rank: message.rank, version: message.version }; - }); + } + + return { + id: message.id, + sId: message.sId, + created: message.createdAt.getTime(), + type: "agent_message", + visibility: message.visibility, + version: message.version, + parentMessageId: + messages.find((m) => m.id === message.parentId)?.sId ?? null, + status: agentMessage.status, + action, + content: agentMessage.content, + error, + configuration: agentConfiguration, + }; } function renderContentFragment({ @@ -421,29 +349,6 @@ function renderContentFragment({ }; } -async function batchRenderContentFragment(messages: Message[]) { - if (messages.find((m) => !m.contentFragment)) { - throw new Error( - "Unreachable: batchRenderContentFragment must be called with only content fragments" - ); - } - - return messages.map((message) => { - if (!message.contentFragment) { - throw new Error( - "Unreachable: batchRenderContentFragment must be called with only content fragments" - ); - } - const contentFragment = message.contentFragment; - - return { - m: renderContentFragment({ message, contentFragment }), - rank: message.rank, - version: message.version, - }; - }); -} - export async function getUserConversations( auth: Authenticator, includeDeleted?: boolean @@ -550,16 +455,34 @@ export async function getConversation( ], }); - const [userMessages, agentMessages, contentFragments] = await Promise.all([ - batchRenderUserMessages(messages.filter((m) => !!m.userMessage)), - batchRenderAgentMessages( - auth, - messages.filter((m) => !!m.agentMessage) - ), - batchRenderContentFragment(messages.filter((m) => !!m.contentFragment)), - ]); - - const render = [...userMessages, ...agentMessages, ...contentFragments]; + const render = await Promise.all( + messages.map((message) => { + return (async () => { + if (message.userMessage) { + const m = await renderUserMessage(message, message.userMessage); + return { m, rank: message.rank, version: message.version }; + } + if (message.agentMessage) { + const m = await renderAgentMessage(auth, { + message, + agentMessage: message.agentMessage, + messages, + }); + return { m, rank: message.rank, version: message.version }; + } + if (message.contentFragment) { + const m = await renderContentFragment({ + message: message, + contentFragment: message.contentFragment, + }); + return { m, rank: message.rank, version: message.version }; + } + throw new Error( + "Unreachable: message must be either user, agent or content fragment" + ); + })(); + }) + ); render.sort((a, b) => { if (a.rank !== b.rank) { return a.rank - b.rank; diff --git a/front/lib/api/assistant/reaction.ts b/front/lib/api/assistant/reaction.ts index cc5e6019b833..9dc29b0e376d 100644 --- a/front/lib/api/assistant/reaction.ts +++ b/front/lib/api/assistant/reaction.ts @@ -24,7 +24,6 @@ export async function getMessageReactions( where: { conversationId: conversation.id, }, - attributes: ["sId"], include: [ { model: MessageReaction, diff --git a/front/package-lock.json b/front/package-lock.json index ce0df6bc9d4a..c41245abaa31 100644 --- a/front/package-lock.json +++ b/front/package-lock.json @@ -66,7 +66,7 @@ "sqlite3": "^5.1.4", "sse.js": "^0.6.1", "stripe": "^14.2.0", - "swr": "^2.2.4", + "swr": "^2.0.2", "tailwind-scrollbar-hide": "^1.1.7", "tailwindcss": "^3.2.4", "three": "^0.155.0", diff --git a/front/package.json b/front/package.json index 5081b7e38123..97ce11085123 100644 --- a/front/package.json +++ b/front/package.json @@ -74,7 +74,7 @@ "sqlite3": "^5.1.4", "sse.js": "^0.6.1", "stripe": "^14.2.0", - "swr": "^2.2.4", + "swr": "^2.0.2", "tailwind-scrollbar-hide": "^1.1.7", "tailwindcss": "^3.2.4", "three": "^0.155.0",