diff --git a/front/lib/api/assistant/actions/dust_app_run.ts b/front/lib/api/assistant/actions/dust_app_run.ts index 7881326119ec..657f34f85cc6 100644 --- a/front/lib/api/assistant/actions/dust_app_run.ts +++ b/front/lib/api/assistant/actions/dust_app_run.ts @@ -224,10 +224,17 @@ export async function renderDustAppRunActionByModelId( // message. export async function* runDustApp( auth: Authenticator, - configuration: AgentConfigurationType, - conversation: ConversationType, - userMessage: UserMessageType, - agentMessage: AgentMessageType + { + configuration, + conversation, + userMessage, + agentMessage, + }: { + configuration: AgentConfigurationType; + conversation: ConversationType; + userMessage: UserMessageType; + agentMessage: AgentMessageType; + } ): AsyncGenerator< | DustAppRunParamsEvent | DustAppRunBlockEvent diff --git a/front/lib/api/assistant/actions/retrieval.ts b/front/lib/api/assistant/actions/retrieval.ts index 844e4ce1070a..5ce45ec7e18a 100644 --- a/front/lib/api/assistant/actions/retrieval.ts +++ b/front/lib/api/assistant/actions/retrieval.ts @@ -483,10 +483,17 @@ const getRefs = () => { // error is expected to be stored by the caller on the parent agent message. export async function* runRetrieval( auth: Authenticator, - configuration: AgentConfigurationType, - conversation: ConversationType, - userMessage: UserMessageType, - agentMessage: AgentMessageType + { + configuration, + conversation, + userMessage, + agentMessage, + }: { + configuration: AgentConfigurationType; + conversation: ConversationType; + userMessage: UserMessageType; + agentMessage: AgentMessageType; + } ): AsyncGenerator< RetrievalParamsEvent | RetrievalSuccessEvent | RetrievalErrorEvent, void diff --git a/front/lib/api/assistant/actions/tables_query.ts b/front/lib/api/assistant/actions/tables_query.ts index d30c392fc395..052e5c743ae7 100644 --- a/front/lib/api/assistant/actions/tables_query.ts +++ b/front/lib/api/assistant/actions/tables_query.ts @@ -112,19 +112,20 @@ export async function generateTablesQueryAppParams( /** * Run the TablesQuery app. */ -export async function* runTablesQuery({ - auth, - configuration, - conversation, - userMessage, - agentMessage, -}: { - auth: Authenticator; - configuration: AgentConfigurationType; - conversation: ConversationType; - userMessage: UserMessageType; - agentMessage: AgentMessageType; -}): AsyncGenerator< +export async function* runTablesQuery( + auth: Authenticator, + { + configuration, + conversation, + userMessage, + agentMessage, + }: { + configuration: AgentConfigurationType; + conversation: ConversationType; + userMessage: UserMessageType; + agentMessage: AgentMessageType; + } +): AsyncGenerator< | TablesQueryErrorEvent | TablesQuerySuccessEvent | TablesQueryParamsEvent diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index 7e18cc161ef0..36055840554e 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -1,4 +1,5 @@ import type { + AgentActionConfigurationType, AgentActionEvent, AgentActionSpecification, AgentActionSuccessEvent, @@ -188,156 +189,17 @@ export async function* runAgent( ); } - const action = deprecatedGetFirstActionConfiguration(fullConfiguration); + const action = await pickAction(auth, fullConfiguration, conversation); if (action !== null) { - if (isRetrievalConfiguration(action)) { - const eventStream = runRetrieval( - auth, - fullConfiguration, - conversation, - userMessage, - agentMessage - ); - - for await (const event of eventStream) { - switch (event.type) { - case "retrieval_params": - yield event; - break; - case "retrieval_error": - yield { - type: "agent_error", - created: event.created, - configurationId: configuration.sId, - messageId: agentMessage.sId, - error: { - code: event.error.code, - message: event.error.message, - }, - }; - return; - case "retrieval_success": - yield { - type: "agent_action_success", - created: event.created, - configurationId: configuration.sId, - messageId: agentMessage.sId, - action: event.action, - }; - - // We stitch the action into the agent message. The conversation is expected to include - // the agentMessage object, updating this object will update the conversation as well. - agentMessage.action = event.action; - break; - - default: - ((event: never) => { - logger.error("Unknown `runAgent` event type", event); - })(event); - return; - } - } - } else if (isDustAppRunConfiguration(action)) { - const eventStream = runDustApp( - auth, - fullConfiguration, - conversation, - userMessage, - agentMessage - ); - - for await (const event of eventStream) { - switch (event.type) { - case "dust_app_run_params": - yield event; - break; - case "dust_app_run_error": - yield { - type: "agent_error", - created: event.created, - configurationId: configuration.sId, - messageId: agentMessage.sId, - error: { - code: event.error.code, - message: event.error.message, - }, - }; - return; - case "dust_app_run_block": - yield event; - break; - case "dust_app_run_success": - yield { - type: "agent_action_success", - created: event.created, - configurationId: configuration.sId, - messageId: agentMessage.sId, - action: event.action, - }; - - // We stitch the action into the agent message. The conversation is expected to include - // the agentMessage object, updating this object will update the conversation as well. - agentMessage.action = event.action; - break; - - default: - ((event: never) => { - logger.error("Unknown `runAgent` event type", event); - })(event); - return; - } - } - } else if (isTablesQueryConfiguration(action)) { - const eventStream = runTablesQuery({ - auth, - configuration: fullConfiguration, - conversation, - userMessage, - agentMessage, - }); - for await (const event of eventStream) { - switch (event.type) { - case "tables_query_params": - case "tables_query_output": - yield event; - break; - case "tables_query_error": - yield { - type: "agent_error", - created: event.created, - configurationId: configuration.sId, - messageId: agentMessage.sId, - error: { - code: event.error.code, - message: event.error.message, - }, - }; - return; - case "tables_query_success": - yield { - type: "agent_action_success", - created: event.created, - configurationId: configuration.sId, - messageId: agentMessage.sId, - action: event.action, - }; - - // We stitch the action into the agent message. The conversation is expected to include - // the agentMessage object, updating this object will update the conversation as well. - agentMessage.action = event.action; - break; - default: - ((event: never) => { - logger.error("Unknown `runAgent` event type", event); - })(event); - return; - } - } - } else { - ((a: never) => { - throw new Error(`Unexpected action type: ${a}`); - })(action); + for await (const event of runAction(auth, { + configuration: fullConfiguration, + conversation, + userMessage, + agentMessage, + action, + })) { + yield event; } } @@ -409,3 +271,183 @@ export async function* runAgent( message: agentMessage, }; } + +async function pickAction( + auth: Authenticator, + configuration: AgentConfigurationType, + conversation: ConversationType +): Promise { + // TODO(@fontanierh): This is a placeholder for the action selection logic. + // This will be replaced by the multi-actions "pick tool" logic. + // This should also include the inputs-generation logic (generating the arguments for the action) + void auth; + void conversation; + + return deprecatedGetFirstActionConfiguration(configuration); +} + +async function* runAction( + auth: Authenticator, + { + configuration, + conversation, + userMessage, + agentMessage, + action, + }: { + configuration: AgentConfigurationType; + conversation: ConversationType; + userMessage: UserMessageType; + agentMessage: AgentMessageType; + action: AgentActionConfigurationType; + } +): AsyncGenerator< + AgentErrorEvent | AgentActionEvent | AgentActionSuccessEvent, + void +> { + if (isRetrievalConfiguration(action)) { + const eventStream = runRetrieval(auth, { + configuration, + conversation, + userMessage, + agentMessage, + }); + + for await (const event of eventStream) { + switch (event.type) { + case "retrieval_params": + yield event; + break; + case "retrieval_error": + yield { + type: "agent_error", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + error: { + code: event.error.code, + message: event.error.message, + }, + }; + return; + case "retrieval_success": + yield { + type: "agent_action_success", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + action: event.action, + }; + + // We stitch the action into the agent message. The conversation is expected to include + // the agentMessage object, updating this object will update the conversation as well. + agentMessage.action = event.action; + break; + + default: + ((event: never) => { + logger.error("Unknown `runAgent` event type", event); + })(event); + return; + } + } + } else if (isDustAppRunConfiguration(action)) { + const eventStream = runDustApp(auth, { + configuration, + conversation, + userMessage, + agentMessage, + }); + + for await (const event of eventStream) { + switch (event.type) { + case "dust_app_run_params": + yield event; + break; + case "dust_app_run_error": + yield { + type: "agent_error", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + error: { + code: event.error.code, + message: event.error.message, + }, + }; + return; + case "dust_app_run_block": + yield event; + break; + case "dust_app_run_success": + yield { + type: "agent_action_success", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + action: event.action, + }; + + // We stitch the action into the agent message. The conversation is expected to include + // the agentMessage object, updating this object will update the conversation as well. + agentMessage.action = event.action; + break; + + default: + ((event: never) => { + logger.error("Unknown `runAgent` event type", event); + })(event); + return; + } + } + } else if (isTablesQueryConfiguration(action)) { + const eventStream = runTablesQuery(auth, { + configuration, + conversation, + userMessage, + agentMessage, + }); + for await (const event of eventStream) { + switch (event.type) { + case "tables_query_params": + case "tables_query_output": + yield event; + break; + case "tables_query_error": + yield { + type: "agent_error", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + error: { + code: event.error.code, + message: event.error.message, + }, + }; + return; + case "tables_query_success": + yield { + type: "agent_action_success", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + action: event.action, + }; + + // We stitch the action into the agent message. The conversation is expected to include + // the agentMessage object, updating this object will update the conversation as well. + agentMessage.action = event.action; + break; + default: + ((event: never) => { + logger.error("Unknown `runAgent` event type", event); + })(event); + return; + } + } + } else { + ((a: never) => { + throw new Error(`Unexpected action type: ${a}`); + })(action); + } +}