From 816800d9659cbec878a35d6d33d8146329d3f01a Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Fri, 8 Sep 2023 11:08:41 +0200 Subject: [PATCH] runAgent implementation, finalized postUserMessage modulo configuraiton --- front/lib/api/assistant/actions/retrieval.ts | 13 +- front/lib/api/assistant/agent.ts | 92 ++++++++++++ front/lib/api/assistant/conversation.ts | 147 +++++++++---------- 3 files changed, 168 insertions(+), 84 deletions(-) diff --git a/front/lib/api/assistant/actions/retrieval.ts b/front/lib/api/assistant/actions/retrieval.ts index 56d49d3bc346..e7a77e7ecc45 100644 --- a/front/lib/api/assistant/actions/retrieval.ts +++ b/front/lib/api/assistant/actions/retrieval.ts @@ -332,16 +332,9 @@ export async function* runRetrieval( const c = configuration.action; if (!isRetrievalConfiguration(c)) { - return yield { - type: "retrieval_error", - created: Date.now(), - configurationId: configuration.sId, - messageId: agentMessage.sId, - error: { - code: "internal_server_error", - message: "Unexpected action configuration received in `runRetrieval`", - }, - }; + throw new Error( + "Unexpected action configuration received in `runRetrieval`" + ); } const paramsRes = await generateRetrievalParams( diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index 137a08d3028e..75dce5fe133c 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -6,14 +6,17 @@ import { runAction } from "@app/lib/actions/server"; import { RetrievalDocumentsEvent, RetrievalParamsEvent, + runRetrieval, } from "@app/lib/api/assistant/actions/retrieval"; import { GenerationTokensEvent, renderConversationForModel, + runGeneration, } from "@app/lib/api/assistant/generation"; import { Authenticator } from "@app/lib/auth"; import { Err, Ok, Result } from "@app/lib/result"; import { generateModelSId } from "@app/lib/utils"; +import { isRetrievalConfiguration } from "@app/types/assistant/actions/retrieval"; import { AgentActionConfigurationType, AgentActionSpecification, @@ -216,6 +219,95 @@ export async function* runAgent( | AgentGenerationSuccessEvent | AgentMessageSuccessEvent > { + // First run the action if a configuration is present. + if (configuration.action !== null) { + if (isRetrievalConfiguration(configuration.action)) { + const eventStream = runRetrieval( + auth, + configuration, + conversation, + userMessage, + agentMessage + ); + + for await (const event of eventStream) { + if (event.type === "retrieval_params") { + yield event; + } + if (event.type === "retrieval_documents") { + yield event; + } + if (event.type === "retrieval_error") { + yield { + type: "agent_error", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + error: { + code: event.error.code, + message: event.error.message, + }, + }; + } + if (event.type === "retrieval_success") { + yield { + type: "agent_action_success", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + action: event.action, + }; + + // We stitch the action into the agent message. The conversation is expected to include + // the agentMessage object, updating this object will update the conversation as well. + agentMessage.action = event.action; + } + } + } else { + throw new Error( + "runAgent implementation missing for action configuration" + ); + } + + // Then run the generation if a configuration is present. + if (configuration.generation !== null) { + const eventStream = runGeneration( + auth, + configuration, + conversation, + userMessage, + agentMessage + ); + + for await (const event of eventStream) { + if (event.type === "generation_tokens") { + yield event; + } + if (event.type === "generation_error") { + yield { + type: "agent_error", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + error: { + code: event.error.code, + message: event.error.message, + }, + }; + } + if (event.type === "generation_success") { + yield { + type: "agent_generation_success", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + text: event.text, + }; + } + } + } + } + yield { type: "agent_error", created: Date.now(), diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index 6e8baf94450c..b89e253484d3 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -177,80 +177,88 @@ export async function* postUserMessage( message: userMessage, }; - for (let i = 0; i < agentMessages.length; i++) { - const agentMessage = agentMessages[i]; - const agentMessageRow = agentMessageRows[i]; + await Promise.allSettled( + agentMessages.map(async function* (agentMessage, i) { + //for (let i = 0; i < agentMessages.length; i++) { + //const agentMessage = agentMessages[i]; + const agentMessageRow = agentMessageRows[i]; - yield { - type: "agent_message_new", - created: Date.now(), - configurationId: agentMessage.configuration.sId, - messageId: agentMessage.sId, - message: agentMessage, - }; - - const eventStream = runAgent( - auth, - agentMessage.configuration, - conversation, - userMessage, - agentMessage - ); + yield { + type: "agent_message_new", + created: Date.now(), + configurationId: agentMessage.configuration.sId, + messageId: agentMessage.sId, + message: agentMessage, + }; - for await (const event of eventStream) { - if (event.type === "agent_error") { - // Store error in database. - await agentMessageRow.update({ - status: "failed", - errorCode: event.error.code, - errorMessage: event.error.message, - }); - yield event; - } + // For each agent we stitch the conversation to add the user message and only that agent message + // so that it can be used to prompt the agent. + const eventStream = runAgent( + auth, + agentMessage.configuration, + { + ...conversation, + content: [...conversation.content, [userMessage], [agentMessage]], + }, + userMessage, + agentMessage + ); - if (event.type === "agent_action_success") { - // Store action in database. - if (event.action.type === "retrieval_action") { + for await (const event of eventStream) { + if (event.type === "agent_error") { + // Store error in database. await agentMessageRow.update({ - agentRetrievalActionId: event.action.id, + status: "failed", + errorCode: event.error.code, + errorMessage: event.error.message, }); - } else { - throw new Error( - `Action type ${event.action.type} agent_action_success handling not implemented` - ); + yield event; } - yield event; - } - if (event.type === "agent_generation_success") { - // Store message in database. - await agentMessageRow.update({ - message: event.text, - }); - yield event; - } + if (event.type === "agent_action_success") { + // Store action in database. + if (event.action.type === "retrieval_action") { + await agentMessageRow.update({ + agentRetrievalActionId: event.action.id, + }); + } else { + throw new Error( + `Action type ${event.action.type} agent_action_success handling not implemented` + ); + } + yield event; + } - if (event.type === "agent_message_success") { - // Update status in database. - await agentMessageRow.update({ - status: "succeeded", - }); - yield event; - } + if (event.type === "agent_generation_success") { + // Store message in database. + await agentMessageRow.update({ + message: event.text, + }); + yield event; + } + + if (event.type === "agent_message_success") { + // Update status in database. + await agentMessageRow.update({ + status: "succeeded", + }); + yield event; + } - // All other events that won't impact the database and are related to actions or tokens - // generation. - if ( - [ - "retrieval_params", - "retrieval_documents", - "generation_tokens", - ].includes(event.type) - ) { - yield event; + // All other events that won't impact the database and are related to actions or tokens + // generation. + if ( + [ + "retrieval_params", + "retrieval_documents", + "generation_tokens", + ].includes(event.type) + ) { + yield event; + } } - } - } + }) + ); } // This method is in charge of re-running an agent interaction (generating a new @@ -298,16 +306,7 @@ export async function* editUserMessage( message: UserMessageType; content: string; } -): AsyncGenerator< - | UserMessageNewEvent - | AgentMessageNewEvent - | AgentErrorEvent - | AgentActionEvent - | AgentActionSuccessEvent - | GenerationTokensEvent - | AgentGenerationSuccessEvent - | AgentMessageSuccessEvent -> { +): AsyncGenerator { yield { type: "agent_error", created: Date.now(),