Skip to content

Commit

Permalink
Conversation actions scafolding + initial list_files action (#8646)
Browse files Browse the repository at this point in the history
* jit actions scafolding

* list_files types

* more

* simplify

* simplify

* clean

* simplify

* one way of doing it

* fix imports

* typing work

* fixes

* rename

* emulated actions injection/scrub

* reintroduce step

* fix imports

* clean-up visualiation

* fix steps

* rebase and comments

* rename jitFiles

* assert message

* stepByIndex

* remove comments
  • Loading branch information
spolu authored Nov 15, 2024
1 parent d3f6e3d commit 0e7714d
Show file tree
Hide file tree
Showing 10 changed files with 271 additions and 152 deletions.
4 changes: 4 additions & 0 deletions front/components/actions/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ const actionsSpecification: ActionSpecifications = {
detailsComponent: BrowseActionDetails,
runningLabel: ACTION_RUNNING_LABELS.browse_action,
},
conversation_list_files_action: {
detailsComponent: () => null,
runningLabel: ACTION_RUNNING_LABELS.conversation_list_files_action,
},
};

export function getActionSpecification<T extends ActionType>(
Expand Down
105 changes: 105 additions & 0 deletions front/lib/api/assistant/actions/conversation/list_files.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import type {
AgentMessageType,
ConversationFileType,
ConversationListFilesActionType,
ConversationType,
FunctionCallType,
FunctionMessageTypeModel,
ModelId,
} from "@dust-tt/types";
import {
BaseAction,
getTablesQueryResultsFileTitle,
isAgentMessageType,
isContentFragmentType,
isTablesQueryActionType,
} from "@dust-tt/types";

interface ConversationListFilesActionBlob {
agentMessageId: ModelId;
functionCallId: string | null;
functionCallName: string | null;
files: ConversationFileType[];
}

export class ConversationListFilesAction extends BaseAction {
readonly agentMessageId: ModelId;
readonly files: ConversationFileType[];
readonly functionCallId: string | null;
readonly functionCallName: string | null;
readonly step: number = -1;
readonly type = "conversation_list_files_action";

constructor(blob: ConversationListFilesActionBlob) {
super(-1, "conversation_list_files_action");

this.agentMessageId = blob.agentMessageId;
this.files = blob.files;
this.functionCallId = blob.functionCallId;
this.functionCallName = blob.functionCallName;
}

renderForFunctionCall(): FunctionCallType {
return {
id: this.functionCallId ?? `call_${this.id.toString()}`,
name: this.functionCallName ?? "list_conversation_files",
arguments: JSON.stringify({}),
};
}

renderForMultiActionsModel(): FunctionMessageTypeModel {
let content = "CONVERSATION FILES:\n";
for (const f of this.files) {
content += `<file id="${f.fileId}" name="${f.title}" type="${f.contentType}" />\n`;
}

return {
role: "function" as const,
name: this.functionCallName ?? "list_conversation_files",
function_call_id: this.functionCallId ?? `call_${this.id.toString()}`,
content,
};
}
}

export function makeConversationListFilesAction(
agentMessage: AgentMessageType,
conversation: ConversationType
): ConversationListFilesActionType | null {
const files: ConversationFileType[] = [];

for (const m of conversation.content.flat(1)) {
if (isContentFragmentType(m)) {
if (m.fileId) {
files.push({
fileId: m.fileId,
title: m.title,
contentType: m.contentType,
});
}
} else if (isAgentMessageType(m)) {
for (const a of m.actions) {
if (isTablesQueryActionType(a)) {
if (a.resultsFileId && a.resultsFileSnippet) {
files.push({
fileId: a.resultsFileId,
contentType: "text/csv",
title: getTablesQueryResultsFileTitle({ output: a.output }),
});
}
}
}
}
}

if (files.length === 0) {
return null;
}

return new ConversationListFilesAction({
functionCallId: "call_" + Math.random().toString(36).substring(7),
functionCallName: "list_conversation_files",
files,
agentMessageId: agentMessage.agentMessageId,
});
}
44 changes: 42 additions & 2 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type {
AgentActionSpecification,
AgentActionSpecificEvent,
AgentActionSuccessEvent,
AgentActionType,
AgentChainOfThoughtEvent,
AgentConfigurationType,
AgentContentEvent,
Expand All @@ -29,9 +30,13 @@ import {
isWebsearchConfiguration,
SUPPORTED_MODEL_CONFIGS,
} from "@dust-tt/types";
import assert from "assert";

import { runActionStreamed } from "@app/lib/actions/server";
import { isJITActionsEnabled } from "@app/lib/api/assistant//jit_actions";
import { makeConversationListFilesAction } from "@app/lib/api/assistant/actions/conversation/list_files";
import { getRunnerForActionConfiguration } from "@app/lib/api/assistant/actions/runners";
import { getCitationsCount } from "@app/lib/api/assistant/actions/utils";
import {
AgentMessageContentParser,
getDelimitersConfiguration,
Expand All @@ -49,8 +54,6 @@ import { AgentMessageContent } from "@app/lib/models/assistant/agent_message_con
import { cloneBaseConfig, DustProdActionRegistry } from "@app/lib/registry";
import logger from "@app/logger/logger";

import { getCitationsCount } from "./actions/utils";

const CANCELLATION_CHECK_INTERVAL = 500;
const MAX_ACTIONS_PER_STEP = 16;

Expand Down Expand Up @@ -292,6 +295,29 @@ async function* runMultiActionsAgentLoop(
}
}

async function getEmulatedAgentMessageActions(
auth: Authenticator,
{
agentMessage,
conversation,
}: { agentMessage: AgentMessageType; conversation: ConversationType }
): Promise<AgentActionType[]> {
const actions: AgentActionType[] = [];
if (await isJITActionsEnabled(auth)) {
const a = makeConversationListFilesAction(agentMessage, conversation);
if (a) {
actions.push(a);
}
}

// We ensure that all emulated actions are injected with step -1.
assert(
actions.every((a) => a.step === -1),
"Emulated actions must have step -1"
);
return actions;
}

// This method is used by the multi-actions execution loop to pick the next action to execute and
// generate its inputs.
async function* runMultiActionsAgent(
Expand Down Expand Up @@ -362,6 +388,15 @@ async function* runMultiActionsAgent(

const MIN_GENERATION_TOKENS = 2048;

const emulatedActions = await getEmulatedAgentMessageActions(auth, {
agentMessage,
conversation,
});

// Prepend emulated actions to the current agent message before rendering the conversation for the
// model.
agentMessage.actions = emulatedActions.concat(agentMessage.actions);

// Turn the conversation into a digest that can be presented to the model.
const modelConversationRes = await renderConversationForModel(auth, {
conversation,
Expand All @@ -370,6 +405,11 @@ async function* runMultiActionsAgent(
allowedTokenCount: model.contextSize - MIN_GENERATION_TOKENS,
});

// Scrub emulated actions from the agent message after rendering.
agentMessage.actions = agentMessage.actions.filter(
(a) => !emulatedActions.includes(a)
);

if (modelConversationRes.isErr()) {
logger.error(
{
Expand Down
121 changes: 25 additions & 96 deletions front/lib/api/assistant/jit_actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ import type {
import {
assertNever,
Err,
getTablesQueryResultsFileAttachment,
isAgentMessageType,
isContentFragmentMessageTypeModel,
isContentFragmentType,
isDevelopment,
isTablesQueryActionType,
isTextContent,
isUserMessageType,
Ok,
Expand Down Expand Up @@ -81,46 +79,55 @@ export async function renderConversationForModelJIT({
const now = Date.now();
const messages: ModelMessageTypeMultiActions[] = [];

// Render loop.
// Render all messages and all actions.
// Render loop: dender all messages and all actions.
for (const versions of conversation.content) {
const m = versions[versions.length - 1];

if (isAgentMessageType(m)) {
const actions = removeNulls(m.actions);

// This array is 2D, because we can have multiple calls per agent message (parallel calls).

const steps = [] as Array<{
contents: string[];
actions: Array<{
call: FunctionCallType;
result: FunctionMessageTypeModel;
}>;
}>;
// This is a record of arrays, because we can have multiple calls per agent message (parallel
// calls). Actions all have a step index which indicates how they should be grouped but some
// actions injected by `getEmulatedAgentMessageActions` have a step index of `-1`. We
// therefore group by index, then order and transform in a 2D array to present to the model.
const stepByStepIndex = {} as Record<
string,
{
contents: string[];
actions: Array<{
call: FunctionCallType;
result: FunctionMessageTypeModel;
}>;
}
>;

const emptyStep = () =>
({
contents: [],
actions: [],
}) satisfies (typeof steps)[number];
}) satisfies (typeof stepByStepIndex)[number];

for (const action of actions) {
const stepIndex = action.step;
steps[stepIndex] = steps[stepIndex] || emptyStep();
steps[stepIndex].actions.push({
stepByStepIndex[stepIndex] = stepByStepIndex[stepIndex] || emptyStep();
stepByStepIndex[stepIndex].actions.push({
call: action.renderForFunctionCall(),
result: action.renderForMultiActionsModel(),
});
}

for (const content of m.rawContents) {
steps[content.step] = steps[content.step] || emptyStep();
stepByStepIndex[content.step] =
stepByStepIndex[content.step] || emptyStep();
if (content.content.trim()) {
steps[content.step].contents.push(content.content);
stepByStepIndex[content.step].contents.push(content.content);
}
}

const steps = Object.entries(stepByStepIndex)
.sort(([a], [b]) => Number(a) - Number(b))
.map(([, step]) => step);

if (excludeActions) {
// In Exclude Actions mode, we only render the last step that has content.
const stepsWithContent = steps.filter((s) => s?.contents.length);
Expand Down Expand Up @@ -222,44 +229,6 @@ export async function renderConversationForModelJIT({
}
}

// If we have messages...
if (messages.length > 0) {
const { filesAsXML, hasFiles } = listConversationFiles({
conversation,
});

// ... and files, we simulate a function call to list the files at the end of the conversation.
if (hasFiles) {
const randomCallId = "tool_" + Math.random().toString(36).substring(7);
const functionName = "list_conversation_files";

const simulatedAgentMessages = [
// 1. We add a message from the agent, asking to use the files listing function
{
role: "assistant",
function_calls: [
{
id: randomCallId,
name: functionName,
arguments: "{}",
},
],
} as AssistantFunctionCallMessageTypeModel,

// 2. We add a message with the resulting files listing
{
function_call_id: randomCallId,
role: "function",
name: functionName,
content: filesAsXML,
} as FunctionMessageTypeModel,
];

// Append the simulated messages to the end of the conversation.
messages.push(...simulatedAgentMessages);
}
}

// Compute in parallel the token count for each message and the prompt.
const res = await tokenCountForTexts(
[prompt, ...getTextRepresentationFromMessages(messages)],
Expand Down Expand Up @@ -402,43 +371,3 @@ export async function renderConversationForModelJIT({
tokensUsed,
});
}

function listConversationFiles({
conversation,
}: {
conversation: ConversationType;
}) {
const fileAttachments: string[] = [];
for (const m of conversation.content.flat(1)) {
if (isContentFragmentType(m)) {
if (!m.fileId) {
continue;
}
fileAttachments.push(
`<file id="${m.fileId}" name="${m.title}" type="${m.contentType}" />`
);
} else if (isAgentMessageType(m)) {
for (const a of m.actions) {
if (isTablesQueryActionType(a)) {
const attachment = getTablesQueryResultsFileAttachment({
resultsFileId: a.resultsFileId,
resultsFileSnippet: a.resultsFileSnippet,
output: a.output,
includeSnippet: false,
});
if (attachment) {
fileAttachments.push(attachment);
}
}
}
}
}
let filesAsXML = "<files>\n";

if (fileAttachments.length > 0) {
filesAsXML += fileAttachments.join("\n");
}
filesAsXML += "\n</files>";

return { filesAsXML, hasFiles: fileAttachments.length > 0 };
}
Loading

0 comments on commit 0e7714d

Please sign in to comment.