From 0e7714df047fe17cc552468bae40b14256fbf7bd Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Fri, 15 Nov 2024 15:06:15 +0100 Subject: [PATCH] Conversation actions scaffolding + initial list_files action (#8646) * jit actions scaffolding * list_files types * more * simplify * simplify * clean * simplify * one way of doing it * fix imports * typing work * fixes * rename * emulated actions injection/scrub * reintroduce step * fix imports * clean-up visualization * fix steps * rebase and comments * rename jitFiles * assert message * stepByIndex * remove comments --- front/components/actions/types.ts | 4 + .../actions/conversation/list_files.ts | 105 +++++++++++++++ front/lib/api/assistant/agent.ts | 44 ++++++- front/lib/api/assistant/jit_actions.ts | 121 ++++-------------- front/lib/api/assistant/visualization.ts | 101 ++++++++------- sdks/js/src/types.ts | 16 +++ .../actions/conversation/list_files.ts | 17 +++ types/src/front/assistant/conversation.ts | 11 +- .../front/lib/api/assistant/actions/index.ts | 3 +- types/src/index.ts | 1 + 10 files changed, 271 insertions(+), 152 deletions(-) create mode 100644 front/lib/api/assistant/actions/conversation/list_files.ts create mode 100644 types/src/front/assistant/actions/conversation/list_files.ts diff --git a/front/components/actions/types.ts b/front/components/actions/types.ts index 9a280c405a9d..c4b3cf74dbff 100644 --- a/front/components/actions/types.ts +++ b/front/components/actions/types.ts @@ -52,6 +52,10 @@ const actionsSpecification: ActionSpecifications = { detailsComponent: BrowseActionDetails, runningLabel: ACTION_RUNNING_LABELS.browse_action, }, + conversation_list_files_action: { + detailsComponent: () => null, + runningLabel: ACTION_RUNNING_LABELS.conversation_list_files_action, + }, }; export function getActionSpecification( diff --git a/front/lib/api/assistant/actions/conversation/list_files.ts b/front/lib/api/assistant/actions/conversation/list_files.ts new file mode 100644 index 
000000000000..4a66694c14c5 --- /dev/null +++ b/front/lib/api/assistant/actions/conversation/list_files.ts @@ -0,0 +1,105 @@ +import type { + AgentMessageType, + ConversationFileType, + ConversationListFilesActionType, + ConversationType, + FunctionCallType, + FunctionMessageTypeModel, + ModelId, +} from "@dust-tt/types"; +import { + BaseAction, + getTablesQueryResultsFileTitle, + isAgentMessageType, + isContentFragmentType, + isTablesQueryActionType, +} from "@dust-tt/types"; + +interface ConversationListFilesActionBlob { + agentMessageId: ModelId; + functionCallId: string | null; + functionCallName: string | null; + files: ConversationFileType[]; +} + +export class ConversationListFilesAction extends BaseAction { + readonly agentMessageId: ModelId; + readonly files: ConversationFileType[]; + readonly functionCallId: string | null; + readonly functionCallName: string | null; + readonly step: number = -1; + readonly type = "conversation_list_files_action"; + + constructor(blob: ConversationListFilesActionBlob) { + super(-1, "conversation_list_files_action"); + + this.agentMessageId = blob.agentMessageId; + this.files = blob.files; + this.functionCallId = blob.functionCallId; + this.functionCallName = blob.functionCallName; + } + + renderForFunctionCall(): FunctionCallType { + return { + id: this.functionCallId ?? `call_${this.id.toString()}`, + name: this.functionCallName ?? "list_conversation_files", + arguments: JSON.stringify({}), + }; + } + + renderForMultiActionsModel(): FunctionMessageTypeModel { + let content = "CONVERSATION FILES:\n"; + for (const f of this.files) { + content += `\n`; + } + + return { + role: "function" as const, + name: this.functionCallName ?? "list_conversation_files", + function_call_id: this.functionCallId ?? 
`call_${this.id.toString()}`, + content, + }; + } +} + +export function makeConversationListFilesAction( + agentMessage: AgentMessageType, + conversation: ConversationType +): ConversationListFilesActionType | null { + const files: ConversationFileType[] = []; + + for (const m of conversation.content.flat(1)) { + if (isContentFragmentType(m)) { + if (m.fileId) { + files.push({ + fileId: m.fileId, + title: m.title, + contentType: m.contentType, + }); + } + } else if (isAgentMessageType(m)) { + for (const a of m.actions) { + if (isTablesQueryActionType(a)) { + if (a.resultsFileId && a.resultsFileSnippet) { + files.push({ + fileId: a.resultsFileId, + contentType: "text/csv", + title: getTablesQueryResultsFileTitle({ output: a.output }), + }); + } + } + } + } + } + + if (files.length === 0) { + return null; + } + + return new ConversationListFilesAction({ + functionCallId: "call_" + Math.random().toString(36).substring(7), + functionCallName: "list_conversation_files", + files, + agentMessageId: agentMessage.agentMessageId, + }); +} diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index fb88edfd1fa1..9e57c810a246 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -4,6 +4,7 @@ import type { AgentActionSpecification, AgentActionSpecificEvent, AgentActionSuccessEvent, + AgentActionType, AgentChainOfThoughtEvent, AgentConfigurationType, AgentContentEvent, @@ -29,9 +30,13 @@ import { isWebsearchConfiguration, SUPPORTED_MODEL_CONFIGS, } from "@dust-tt/types"; +import assert from "assert"; import { runActionStreamed } from "@app/lib/actions/server"; +import { isJITActionsEnabled } from "@app/lib/api/assistant//jit_actions"; +import { makeConversationListFilesAction } from "@app/lib/api/assistant/actions/conversation/list_files"; import { getRunnerForActionConfiguration } from "@app/lib/api/assistant/actions/runners"; +import { getCitationsCount } from "@app/lib/api/assistant/actions/utils"; import { 
AgentMessageContentParser, getDelimitersConfiguration, @@ -49,8 +54,6 @@ import { AgentMessageContent } from "@app/lib/models/assistant/agent_message_con import { cloneBaseConfig, DustProdActionRegistry } from "@app/lib/registry"; import logger from "@app/logger/logger"; -import { getCitationsCount } from "./actions/utils"; - const CANCELLATION_CHECK_INTERVAL = 500; const MAX_ACTIONS_PER_STEP = 16; @@ -292,6 +295,29 @@ async function* runMultiActionsAgentLoop( } } +async function getEmulatedAgentMessageActions( + auth: Authenticator, + { + agentMessage, + conversation, + }: { agentMessage: AgentMessageType; conversation: ConversationType } +): Promise { + const actions: AgentActionType[] = []; + if (await isJITActionsEnabled(auth)) { + const a = makeConversationListFilesAction(agentMessage, conversation); + if (a) { + actions.push(a); + } + } + + // We ensure that all emulated actions are injected with step -1. + assert( + actions.every((a) => a.step === -1), + "Emulated actions must have step -1" + ); + return actions; +} + // This method is used by the multi-actions execution loop to pick the next action to execute and // generate its inputs. async function* runMultiActionsAgent( @@ -362,6 +388,15 @@ async function* runMultiActionsAgent( const MIN_GENERATION_TOKENS = 2048; + const emulatedActions = await getEmulatedAgentMessageActions(auth, { + agentMessage, + conversation, + }); + + // Prepend emulated actions to the current agent message before rendering the conversation for the + // model. + agentMessage.actions = emulatedActions.concat(agentMessage.actions); + // Turn the conversation into a digest that can be presented to the model. const modelConversationRes = await renderConversationForModel(auth, { conversation, @@ -370,6 +405,11 @@ async function* runMultiActionsAgent( allowedTokenCount: model.contextSize - MIN_GENERATION_TOKENS, }); + // Scrub emulated actions from the agent message after rendering. 
+ agentMessage.actions = agentMessage.actions.filter( + (a) => !emulatedActions.includes(a) + ); + if (modelConversationRes.isErr()) { logger.error( { diff --git a/front/lib/api/assistant/jit_actions.ts b/front/lib/api/assistant/jit_actions.ts index 49226eaeec92..71dd1b1db1f6 100644 --- a/front/lib/api/assistant/jit_actions.ts +++ b/front/lib/api/assistant/jit_actions.ts @@ -12,12 +12,10 @@ import type { import { assertNever, Err, - getTablesQueryResultsFileAttachment, isAgentMessageType, isContentFragmentMessageTypeModel, isContentFragmentType, isDevelopment, - isTablesQueryActionType, isTextContent, isUserMessageType, Ok, @@ -81,46 +79,55 @@ export async function renderConversationForModelJIT({ const now = Date.now(); const messages: ModelMessageTypeMultiActions[] = []; - // Render loop. - // Render all messages and all actions. + // Render loop: render all messages and all actions. for (const versions of conversation.content) { const m = versions[versions.length - 1]; if (isAgentMessageType(m)) { const actions = removeNulls(m.actions); - // This array is 2D, because we can have multiple calls per agent message (parallel calls). - - const steps = [] as Array<{ - contents: string[]; - actions: Array<{ - call: FunctionCallType; - result: FunctionMessageTypeModel; - }>; - }>; + // This is a record of arrays, because we can have multiple calls per agent message (parallel + // calls). Actions all have a step index which indicates how they should be grouped but some + // actions injected by `getEmulatedAgentMessageActions` have a step index of `-1`. We + // therefore group by index, then order and transform in a 2D array to present to the model. 
+ const stepByStepIndex = {} as Record< + string, + { + contents: string[]; + actions: Array<{ + call: FunctionCallType; + result: FunctionMessageTypeModel; + }>; + } + >; const emptyStep = () => ({ contents: [], actions: [], - }) satisfies (typeof steps)[number]; + }) satisfies (typeof stepByStepIndex)[number]; for (const action of actions) { const stepIndex = action.step; - steps[stepIndex] = steps[stepIndex] || emptyStep(); - steps[stepIndex].actions.push({ + stepByStepIndex[stepIndex] = stepByStepIndex[stepIndex] || emptyStep(); + stepByStepIndex[stepIndex].actions.push({ call: action.renderForFunctionCall(), result: action.renderForMultiActionsModel(), }); } for (const content of m.rawContents) { - steps[content.step] = steps[content.step] || emptyStep(); + stepByStepIndex[content.step] = + stepByStepIndex[content.step] || emptyStep(); if (content.content.trim()) { - steps[content.step].contents.push(content.content); + stepByStepIndex[content.step].contents.push(content.content); } } + const steps = Object.entries(stepByStepIndex) + .sort(([a], [b]) => Number(a) - Number(b)) + .map(([, step]) => step); + if (excludeActions) { // In Exclude Actions mode, we only render the last step that has content. const stepsWithContent = steps.filter((s) => s?.contents.length); @@ -222,44 +229,6 @@ export async function renderConversationForModelJIT({ } } - // If we have messages... - if (messages.length > 0) { - const { filesAsXML, hasFiles } = listConversationFiles({ - conversation, - }); - - // ... and files, we simulate a function call to list the files at the end of the conversation. - if (hasFiles) { - const randomCallId = "tool_" + Math.random().toString(36).substring(7); - const functionName = "list_conversation_files"; - - const simulatedAgentMessages = [ - // 1. 
We add a message from the agent, asking to use the files listing function - { - role: "assistant", - function_calls: [ - { - id: randomCallId, - name: functionName, - arguments: "{}", - }, - ], - } as AssistantFunctionCallMessageTypeModel, - - // 2. We add a message with the resulting files listing - { - function_call_id: randomCallId, - role: "function", - name: functionName, - content: filesAsXML, - } as FunctionMessageTypeModel, - ]; - - // Append the simulated messages to the end of the conversation. - messages.push(...simulatedAgentMessages); - } - } - // Compute in parallel the token count for each message and the prompt. const res = await tokenCountForTexts( [prompt, ...getTextRepresentationFromMessages(messages)], @@ -402,43 +371,3 @@ export async function renderConversationForModelJIT({ tokensUsed, }); } - -function listConversationFiles({ - conversation, -}: { - conversation: ConversationType; -}) { - const fileAttachments: string[] = []; - for (const m of conversation.content.flat(1)) { - if (isContentFragmentType(m)) { - if (!m.fileId) { - continue; - } - fileAttachments.push( - `` - ); - } else if (isAgentMessageType(m)) { - for (const a of m.actions) { - if (isTablesQueryActionType(a)) { - const attachment = getTablesQueryResultsFileAttachment({ - resultsFileId: a.resultsFileId, - resultsFileSnippet: a.resultsFileSnippet, - output: a.output, - includeSnippet: false, - }); - if (attachment) { - fileAttachments.push(attachment); - } - } - } - } - } - let filesAsXML = "\n"; - - if (fileAttachments.length > 0) { - filesAsXML += fileAttachments.join("\n"); - } - filesAsXML += "\n"; - - return { filesAsXML, hasFiles: fileAttachments.length > 0 }; -} diff --git a/front/lib/api/assistant/visualization.ts b/front/lib/api/assistant/visualization.ts index a9a0caed72af..b5c08770e97a 100644 --- a/front/lib/api/assistant/visualization.ts +++ b/front/lib/api/assistant/visualization.ts @@ -19,65 +19,64 @@ export async function getVisualizationPrompt({ auth: 
Authenticator; conversation: ConversationType; }) { - const isJITEnabled = await isJITActionsEnabled(auth); + // If `jit_conversations_actions` is enabled we rely on the `conversations_list_files` emulated + // actions to make the list of files available to the agent. + if (await isJITActionsEnabled(auth)) { + return visualizationSystemPrompt; + } - // When JIT is enabled, we return the visualization prompt directly without listing the files as the files will be made available to the model via another mechanism (simulated function call). - if (isJITEnabled) { - return visualizationSystemPrompt.trim(); - } else { - const contentFragmentMessages: Array = []; - for (const m of conversation.content.flat(1)) { - if (isContentFragmentType(m)) { - contentFragmentMessages.push(m); - } + const contentFragmentMessages: Array = []; + for (const m of conversation.content.flat(1)) { + if (isContentFragmentType(m)) { + contentFragmentMessages.push(m); } - const contentFragmentFileBySid = _.keyBy( - await FileResource.fetchByIds( - auth, - removeNulls(contentFragmentMessages.map((m) => m.fileId)) - ), - "sId" - ); - - let prompt = visualizationSystemPrompt.trim() + "\n\n"; - - const fileAttachments: string[] = []; - for (const m of conversation.content.flat(1)) { - if (isContentFragmentType(m)) { - if (!m.fileId || !contentFragmentFileBySid[m.fileId]) { - continue; - } - fileAttachments.push( - `` - ); - } else if (isAgentMessageType(m)) { - for (const a of m.actions) { - if (isTablesQueryActionType(a)) { - const attachment = getTablesQueryResultsFileAttachment({ - resultsFileId: a.resultsFileId, - resultsFileSnippet: a.resultsFileSnippet, - output: a.output, - includeSnippet: false, - }); - if (attachment) { - fileAttachments.push(attachment); - } + } + const contentFragmentFileBySid = _.keyBy( + await FileResource.fetchByIds( + auth, + removeNulls(contentFragmentMessages.map((m) => m.fileId)) + ), + "sId" + ); + + let prompt = visualizationSystemPrompt.trim() + "\n\n"; + + 
const fileAttachments: string[] = []; + for (const m of conversation.content.flat(1)) { + if (isContentFragmentType(m)) { + if (!m.fileId || !contentFragmentFileBySid[m.fileId]) { + continue; + } + fileAttachments.push( + `` + ); + } else if (isAgentMessageType(m)) { + for (const a of m.actions) { + if (isTablesQueryActionType(a)) { + const attachment = getTablesQueryResultsFileAttachment({ + resultsFileId: a.resultsFileId, + resultsFileSnippet: a.resultsFileSnippet, + output: a.output, + includeSnippet: false, + }); + if (attachment) { + fileAttachments.push(attachment); } } } } + } - if (fileAttachments.length > 0) { - prompt += - "Files accessible to the :::visualization directive environment:\n"; - prompt += fileAttachments.join("\n"); - } else { - prompt += - "No files are currently accessible to the :::visualization directive environment in this conversation."; - } - - return prompt; + if (fileAttachments.length > 0) { + prompt += + "Files accessible to the :::visualization directive environment:\n"; + prompt += fileAttachments.join("\n"); + } else { + prompt += + "No files are currently accessible to the :::visualization directive environment in this conversation."; } + + return prompt; } export const visualizationSystemPrompt = `\ diff --git a/sdks/js/src/types.ts b/sdks/js/src/types.ts index 2d974328c4db..cf2a46902831 100644 --- a/sdks/js/src/types.ts +++ b/sdks/js/src/types.ts @@ -460,6 +460,21 @@ const BrowseActionTypeSchema = BaseActionSchema.extend({ }); type BrowseActionPublicType = z.infer; +const ConversationFileTypeSchema = z.object({ + fileId: z.string(), + title: z.string(), + contentType: z.string(), +}); + +const ConversationListFilesActionTypeSchema = BaseActionSchema.extend({ + files: z.array(ConversationFileTypeSchema), + functionCallId: z.string().nullable(), + functionCallName: z.string().nullable(), + agentMessageId: ModelIdSchema, + step: z.number(), + type: z.literal("conversation_list_files_action"), +}); + const 
DustAppParametersSchema = z.record( z.union([z.string(), z.number(), z.boolean()]) ); @@ -839,6 +854,7 @@ const AgentActionTypeSchema = z.union([ ProcessActionTypeSchema, WebsearchActionTypeSchema, BrowseActionTypeSchema, + ConversationListFilesActionTypeSchema, ]); export type AgentActionPublicType = z.infer; diff --git a/types/src/front/assistant/actions/conversation/list_files.ts b/types/src/front/assistant/actions/conversation/list_files.ts new file mode 100644 index 000000000000..3d098c87ee2a --- /dev/null +++ b/types/src/front/assistant/actions/conversation/list_files.ts @@ -0,0 +1,17 @@ +import { BaseAction } from "../../../../front/lib/api/assistant/actions/index"; +import { ModelId } from "../../../../shared/model_id"; + +export type ConversationFileType = { + fileId: string; + title: string; + contentType: string; +}; + +export interface ConversationListFilesActionType extends BaseAction { + agentMessageId: ModelId; + files: ConversationFileType[]; + functionCallId: string | null; + functionCallName: string | null; + step: number; + type: "conversation_list_files_action"; +} diff --git a/types/src/front/assistant/conversation.ts b/types/src/front/assistant/conversation.ts index fd82bf0db958..1bf83a1f18ae 100644 --- a/types/src/front/assistant/conversation.ts +++ b/types/src/front/assistant/conversation.ts @@ -7,6 +7,7 @@ import { UserType, WorkspaceType } from "../../front/user"; import { ModelId } from "../../shared/model_id"; import { ContentFragmentType } from "../content_fragment"; import { BrowseActionType } from "./actions/browse"; +import { ConversationListFilesActionType } from "./actions/conversation/list_files"; import { WebsearchActionType } from "./actions/websearch"; /** @@ -103,8 +104,7 @@ export function isUserMessageType(arg: MessageType): arg is UserMessageType { /** * Agent messages */ - -export type AgentActionType = +export type ConfigurableAgentActionType = | RetrievalActionType | DustAppRunActionType | TablesQueryActionType @@ -112,6 
+112,12 @@ export type AgentActionType = | WebsearchActionType | BrowseActionType; +export type ConversationAgentActionType = ConversationListFilesActionType; + +export type AgentActionType = + | ConfigurableAgentActionType + | ConversationAgentActionType; + export type AgentMessageStatus = | "created" | "succeeded" @@ -125,6 +131,7 @@ export const ACTION_RUNNING_LABELS: Record = { tables_query_action: "Querying tables", websearch_action: "Searching the web", browse_action: "Browsing page", + conversation_list_files_action: "Listing conversation files", }; /** diff --git a/types/src/front/lib/api/assistant/actions/index.ts b/types/src/front/lib/api/assistant/actions/index.ts index 28a0da1f2b29..ceaf4c7c0bbe 100644 --- a/types/src/front/lib/api/assistant/actions/index.ts +++ b/types/src/front/lib/api/assistant/actions/index.ts @@ -8,7 +8,8 @@ type BaseActionType = | "process_action" | "websearch_action" | "browse_action" - | "visualization_action"; + | "visualization_action" + | "conversation_list_files_action"; export abstract class BaseAction { readonly id: ModelId; diff --git a/types/src/index.ts b/types/src/index.ts index 9673a27a9d4d..57c97ba80867 100644 --- a/types/src/index.ts +++ b/types/src/index.ts @@ -21,6 +21,7 @@ export * from "./front/api_handlers/public/data_sources"; export * from "./front/api_handlers/public/spaces"; export * from "./front/app"; export * from "./front/assistant/actions/browse"; +export * from "./front/assistant/actions/conversation/list_files"; export * from "./front/assistant/actions/dust_app_run"; export * from "./front/assistant/actions/guards"; export * from "./front/assistant/actions/process";