From 743ad3793c4a9c651a771396b727505458e182bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daphn=C3=A9=20Popin?= Date: Mon, 27 Nov 2023 16:40:29 +0100 Subject: [PATCH] Render batch for agent messages & content fragment --- front/lib/api/assistant/conversation.ts | 252 ++++++++++++++++-------- 1 file changed, 171 insertions(+), 81 deletions(-) diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index d838e4f658a73..6f7de4add737f 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -25,6 +25,8 @@ import { GPT_3_5_TURBO_MODEL_CONFIG } from "@app/lib/assistant"; import { Authenticator } from "@app/lib/auth"; import { front_sequelize } from "@app/lib/databases"; import { + AgentConfiguration, + AgentDustAppRunAction, AgentMessage, Conversation, ConversationParticipant, @@ -38,6 +40,7 @@ import { updateWorkspacePerMonthlyActiveUsersSubscriptionUsage } from "@app/lib/ import { Err, Ok, Result } from "@app/lib/result"; import { generateModelSId } from "@app/lib/utils"; import logger from "@app/logger/logger"; +import { AgentConfigurationType } from "@app/types/assistant/agent"; import { AgentMessageType, ContentFragmentContentType, @@ -59,6 +62,7 @@ import { WorkspaceType } from "@app/types/user"; import { renderDustAppRunActionByModelId } from "./actions/dust_app_run"; import { renderRetrievalActionByModelId } from "./actions/retrieval"; +import { getGlobalAgents } from "./global_agents"; /** * Conversation Creation, update and deletion */ @@ -278,69 +282,135 @@ async function batchRenderUserMessages(messages: Message[]) { }); } -async function renderAgentMessage( +async function batchRenderAgentMessages( auth: Authenticator, - { - message, - agentMessage, - messages, - }: { message: Message; agentMessage: AgentMessage; messages: Message[] } -): Promise { - const [agentConfiguration, agentRetrievalAction, agentDustAppRunAction] = + messages: Message[] +) { + if (messages.find((m) => !m.agentMessage)) { + throw new Error( + "Unreachable: batchRenderAgentMessages must be called with only agent messages" + ); + } + + const [agentConfigurations, agentRetrievalActions, agentDustAppRunActions] = await Promise.all([ - getAgentConfiguration(auth, agentMessage.agentConfigurationId), (async () => { - if (agentMessage.agentRetrievalActionId) { - return await renderRetrievalActionByModelId( - agentMessage.agentRetrievalActionId - ); - } - return null; + const agentConfigurationIds: string[] = messages.reduce( + (acc: string[], m) => { + const agentId = m.agentMessage?.agentConfigurationId; + if (agentId && !acc.includes(agentId)) { + acc.push(agentId); + } + return acc; + }, + [] + ); + const agents = ( + await Promise.all( + agentConfigurationIds.map(async (agentConfigId) => { + return await getAgentConfiguration(auth, agentConfigId); + }) + ) + ).filter((a) => a !== null) as AgentConfigurationType[]; + const globalAgents = await getGlobalAgents(auth); + return [...globalAgents, ...agents]; })(), (async () => { - if (agentMessage.agentDustAppRunActionId) { - return await renderDustAppRunActionByModelId( - agentMessage.agentDustAppRunActionId - ); - } - return null; + const agentRetrievalActionIds: number[] = messages.reduce( + (acc: number[], m) => { + const agentId = m.agentMessage?.agentRetrievalActionId; + if (agentId && !acc.includes(agentId)) { + acc.push(agentId); + } + return acc; + }, + [] + ); + return await Promise.all( + agentRetrievalActionIds.map(async (agentRetrievalActionId) => { + return await renderRetrievalActionByModelId(agentRetrievalActionId); + }) + ); + })(), + (async () => { + const agentDustAppRunActionsIds: number[] = messages.reduce( + (acc: number[], m) => { + const agentId = m.agentMessage?.agentDustAppRunActionId; + if (agentId && !acc.includes(agentId)) { + acc.push(agentId); + } + return acc; + }, + [] + ); + const actions = await AgentDustAppRunAction.findAll({ + where: { + id: { + [Op.in]: agentDustAppRunActionsIds, + }, + }, + }); + return actions.map((action) => { + return { + id: action.id, + type: "dust_app_run_action", + appWorkspaceId: action.appWorkspaceId, + appId: action.appId, + appName: action.appName, + params: action.params, + runningBlock: null, + output: action.output, + }; + }); })(), ]); - if (!agentConfiguration) { - throw new Error( - `Configuration ${agentMessage.agentConfigurationId} not found` + return messages.map((message) => { + if (!message.agentMessage) { + throw new Error( + "Unreachable: batchRenderUserMessages must be called with only user messages" + ); + } + const agentMessage = message.agentMessage; + const action = + agentRetrievalActions.find( + (a) => a.id === agentMessage?.agentRetrievalActionId + ) ?? + agentDustAppRunActions.find( + (a) => a.id === agentMessage.agentDustAppRunActionId + ); + const agentConfiguration = agentConfigurations.find( + (a) => a.sId === message.agentMessage?.agentConfigurationId ); - } - const action = agentRetrievalAction ?? agentDustAppRunAction; + let error: { + code: string; + message: string; + } | null = null; + if (agentMessage.errorCode !== null && agentMessage.errorMessage !== null) { + error = { + code: agentMessage.errorCode, + message: agentMessage.errorMessage, + }; + } - let error: { - code: string; - message: string; - } | null = null; - if (agentMessage.errorCode !== null && agentMessage.errorMessage !== null) { - error = { - code: agentMessage.errorCode, - message: agentMessage.errorMessage, + const m = { + id: message.id, + sId: message.sId, + created: message.createdAt.getTime(), + type: "agent_message", + visibility: message.visibility, + version: message.version, + parentMessageId: + messages.find((m) => m.id === message.parentId)?.sId ?? null, + status: agentMessage.status, + action, + content: agentMessage.content, + error, + configuration: agentConfiguration, }; - } - - return { - id: message.id, - sId: message.sId, - created: message.createdAt.getTime(), - type: "agent_message", - visibility: message.visibility, - version: message.version, - parentMessageId: - messages.find((m) => m.id === message.parentId)?.sId ?? null, - status: agentMessage.status, - action, - content: agentMessage.content, - error, - configuration: agentConfiguration, - }; + return { m, rank: message.rank, version: message.version }; + }); } function renderContentFragment({ @@ -370,6 +440,43 @@ function renderContentFragment({ }; } +async function batchRenderContentFragment(messages: Message[]) { + if (messages.find((m) => !m.contentFragment)) { + throw new Error( + "Unreachable: batchRenderContentFragment must be called with only content fragments" + ); + } + + return messages.map((message) => { + if (!message.contentFragment) { + throw new Error( + "Unreachable: batchRenderContentFragment must be called with only content fragments" + ); + } + const contentFragment = message.contentFragment; + + const m = { + id: message.id, + sId: message.sId, + created: message.createdAt.getTime(), + type: "content_fragment", + visibility: message.visibility, + version: message.version, + title: contentFragment.title, + content: contentFragment.content, + url: contentFragment.url, + contentType: contentFragment.contentType, + context: { + profilePictureUrl: contentFragment.userContextProfilePictureUrl, + fullName: contentFragment.userContextFullName, + email: contentFragment.userContextEmail, + username: contentFragment.userContextUsername, + }, + }; + return { m, rank: message.rank, version: message.version }; + }); +} + export async function getUserConversations( auth: Authenticator, includeDeleted?: boolean @@ -476,43 +583,26 @@ export async function getConversation( ], }); - const [userMessages] = await Promise.all([ + const [userMessages, agentMessages, contentFragments] = await Promise.all([ (async () => { return await batchRenderUserMessages( messages.filter((m) => !!m.userMessage) ); })(), + (async () => { + return await batchRenderAgentMessages( + auth, + messages.filter((m) => !!m.agentMessage) + ); + })(), + (async () => { + return await batchRenderContentFragment( + messages.filter((m) => !!m.contentFragment) + ); + })(), ]); - const renderAgentAndContentFragments = await Promise.all( - messages - .filter((m) => !m.userMessage) - .map((message) => { - return (async () => { - if (message.agentMessage) { - const m = await renderAgentMessage(auth, { - message, - agentMessage: message.agentMessage, - messages, - }); - return { m, rank: message.rank, version: message.version }; - } - if (message.contentFragment) { - const m = await renderContentFragment({ - message: message, - contentFragment: message.contentFragment, - }); - return { m, rank: message.rank, version: message.version }; - } - throw new Error( - "Unreachable: message must be either user, agent or content fragment" - ); - })(); - }) - ); - - const render = [...userMessages, ...renderAgentAndContentFragments]; - + const render = [...userMessages, ...agentMessages, ...contentFragments]; render.sort((a, b) => { if (a.rank !== b.rank) { return a.rank - b.rank;