Skip to content

Commit

Permalink
Add: ability to do a table query on attachment (#8796)
Browse files Browse the repository at this point in the history
* Add: ability to do a table query on attachment

* Review fdbk
  • Loading branch information
Fraggle authored Nov 21, 2024
1 parent 0eea1c4 commit 17347f8
Show file tree
Hide file tree
Showing 11 changed files with 210 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ export class ConversationIncludeFileAction extends BaseAction {
fileId,
title: m.title,
contentType: m.contentType,
snippet: m.snippet,
},
content: text,
});
Expand Down
62 changes: 19 additions & 43 deletions front/lib/api/assistant/actions/conversation/list_files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,11 @@ import type {
FunctionMessageTypeModel,
ModelId,
} from "@dust-tt/types";
import {
BaseAction,
getTablesQueryResultsFileTitle,
isAgentMessageType,
isContentFragmentType,
isSupportedPlainTextContentType,
isTablesQueryActionType,
} from "@dust-tt/types";
import { BaseAction } from "@dust-tt/types";
import _ from "lodash";

import { isConversationIncludableFileContentType } from "@app/lib/api/assistant/actions/conversation/include_file";
import { listFiles } from "@app/lib/api/assistant/jit_actions";

interface ConversationListFilesActionBlob {
agentMessageId: ModelId;
Expand Down Expand Up @@ -59,8 +54,14 @@ export class ConversationListFilesAction extends BaseAction {
`\n`;
for (const f of this.files) {
content +=
`<file id="${f.fileId}" name="${f.title}" type="${f.contentType}" ` +
`includable="${isConversationIncludableFileContentType(f.contentType)}"/>\n`;
`<file id="${f.fileId}" name="${_.escape(f.title)}" type="${f.contentType}" ` +
`includable="${isConversationIncludableFileContentType(f.contentType)}" queryable="${!!f.snippet}"`;

if (f.snippet) {
content += ` snippet="${_.escape(f.snippet)}"`;
}

content += "/>\n";
}

return {
Expand All @@ -72,39 +73,14 @@ export class ConversationListFilesAction extends BaseAction {
}
}

export function makeConversationListFilesAction(
agentMessage: AgentMessageType,
conversation: ConversationType
): ConversationListFilesActionType | null {
const files: ConversationFileType[] = [];

for (const m of conversation.content.flat(1)) {
if (
isContentFragmentType(m) &&
isSupportedPlainTextContentType(m.contentType) &&
m.contentFragmentVersion === "latest"
) {
if (m.fileId) {
files.push({
fileId: m.fileId,
title: m.title,
contentType: m.contentType,
});
}
} else if (isAgentMessageType(m)) {
for (const a of m.actions) {
if (isTablesQueryActionType(a)) {
if (a.resultsFileId && a.resultsFileSnippet) {
files.push({
fileId: a.resultsFileId,
contentType: "text/csv",
title: getTablesQueryResultsFileTitle({ output: a.output }),
});
}
}
}
}
}
export function makeConversationListFilesAction({
agentMessage,
conversation,
}: {
agentMessage: AgentMessageType;
conversation: ConversationType;
}): ConversationListFilesActionType | null {
const files: ConversationFileType[] = listFiles(conversation);

if (files.length === 0) {
return null;
Expand Down
17 changes: 15 additions & 2 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ import {
import assert from "assert";

import { runActionStreamed } from "@app/lib/actions/server";
import { isJITActionsEnabled } from "@app/lib/api/assistant//jit_actions";
import {
getJITActions,
isJITActionsEnabled,
} from "@app/lib/api/assistant//jit_actions";
import { makeConversationListFilesAction } from "@app/lib/api/assistant/actions/conversation/list_files";
import { getRunnerForActionConfiguration } from "@app/lib/api/assistant/actions/runners";
import { getCitationsCount } from "@app/lib/api/assistant/actions/utils";
Expand Down Expand Up @@ -91,6 +94,13 @@ export async function* runAgent(
throw new Error("Unreachable: could not find owner workspace for agent");
}

// Add JIT actions for available files in the conversation.
if (await isJITActionsEnabled(auth)) {
fullConfiguration.actions = fullConfiguration.actions.concat(
await getJITActions(auth, { conversation })
);
}

const stream = runMultiActionsAgentLoop(
auth,
fullConfiguration,
Expand Down Expand Up @@ -306,7 +316,10 @@ async function getEmulatedAgentMessageActions(
): Promise<AgentActionType[]> {
const actions: AgentActionType[] = [];
if (await isJITActionsEnabled(auth)) {
const a = makeConversationListFilesAction(agentMessage, conversation);
const a = makeConversationListFilesAction({
agentMessage,
conversation,
});
if (a) {
actions.push(a);
}
Expand Down
13 changes: 6 additions & 7 deletions front/lib/api/assistant/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1832,14 +1832,13 @@ export async function postNewContentFragment(
return { contentFragment, messageRow };
}
);
const render = await contentFragment.renderFromMessage({
auth,
conversationId: conversation.sId,
message: messageRow,
});

return new Ok(
contentFragment.renderFromMessage({
auth,
conversationId: conversation.sId,
message: messageRow,
})
);
return new Ok(render);
}

async function* streamRunAgentEvents(
Expand Down
106 changes: 106 additions & 0 deletions front/lib/api/assistant/jit_actions.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
import type {
AgentActionConfigurationType,
AssistantContentMessageTypeModel,
AssistantFunctionCallMessageTypeModel,
ConversationFileType,
ConversationType,
FunctionCallType,
FunctionMessageTypeModel,
ModelConfigurationType,
ModelConversationTypeMultiActions,
ModelMessageTypeMultiActions,
Result,
TablesQueryConfigurationType,
} from "@dust-tt/types";
import {
assertNever,
Err,
getTablesQueryResultsFileTitle,
isAgentMessageType,
isContentFragmentMessageTypeModel,
isContentFragmentType,
isDevelopment,
isSupportedPlainTextContentType,
isTablesQueryActionType,
isTextContent,
isUserMessageType,
Ok,
Expand All @@ -29,6 +35,8 @@ import {
import type { Authenticator } from "@app/lib/auth";
import { getFeatureFlags } from "@app/lib/auth";
import { renderContentFragmentForModel } from "@app/lib/resources/content_fragment_resource";
import { DataSourceViewResource } from "@app/lib/resources/data_source_view_resource";
import { generateRandomModelSId } from "@app/lib/resources/string_ids";
import { tokenCountForTexts, tokenSplit } from "@app/lib/tokenization";
import logger from "@app/logger/logger";

Expand All @@ -47,6 +55,104 @@ export async function isJITActionsEnabled(
return use;
}

/**
 * Lists the files that are available in a conversation.
 *
 * Two sources are scanned, in conversation order:
 * - content fragments that are the "latest" version, have a supported plain
 *   text content type and reference a file;
 * - results files produced by tables query actions of agent messages (only
 *   when both a results file id and a results file snippet are present).
 *
 * @param conversation The conversation whose messages are scanned.
 * @returns One entry per file found, in the order encountered.
 */
export function listFiles(
  conversation: ConversationType
): ConversationFileType[] {
  const collected: ConversationFileType[] = [];

  for (const message of conversation.content.flat(1)) {
    if (
      isContentFragmentType(message) &&
      isSupportedPlainTextContentType(message.contentType) &&
      message.contentFragmentVersion === "latest"
    ) {
      // Content fragments without a backing file are not listed.
      if (message.fileId) {
        collected.push({
          fileId: message.fileId,
          title: message.title,
          contentType: message.contentType,
          snippet: message.snippet,
        });
      }
      continue;
    }

    if (!isAgentMessageType(message)) {
      continue;
    }

    for (const action of message.actions) {
      if (!isTablesQueryActionType(action)) {
        continue;
      }
      if (!action.resultsFileId || !action.resultsFileSnippet) {
        continue;
      }

      collected.push({
        fileId: action.resultsFileId,
        contentType: "text/csv",
        title: getTablesQueryResultsFileTitle({ output: action.output }),
        // A null snippet means the file can't be used for JIT actions (the
        // resultsFileSnippet is not the same kind of snippet).
        snippet: null,
      });
    }
  }

  return collected;
}

/**
 * Builds the JIT (just-in-time) action configurations for a conversation.
 *
 * When JIT actions are enabled and the conversation contains CSV files that
 * carry a snippet, a single tables query configuration named
 * `query_conversation_tables` is returned, exposing each of those files as a
 * table of the conversation's data source view.
 *
 * @param auth Authenticator used for the feature check, the data source view
 *   lookup and the workspace id of each table.
 * @param conversation The conversation whose files are inspected.
 * @returns The JIT action configurations; empty when JIT actions are
 *   disabled, when no usable file exists, or when the conversation has no
 *   data source view (a warning is logged in that last case).
 */
export async function getJITActions(
  auth: Authenticator,
  { conversation }: { conversation: ConversationType }
): Promise<AgentActionConfigurationType[]> {
  if (!(await isJITActionsEnabled(auth))) {
    return [];
  }

  // Only files carrying a snippet are usable for JIT actions.
  const snippetFiles = listFiles(conversation).filter(
    (file) => !!file.snippet
  );
  if (snippetFiles.length === 0) {
    return [];
  }

  // Get the datasource view for the conversation.
  const conversationView = await DataSourceViewResource.fetchByConversation(
    auth,
    conversation
  );
  if (!conversationView) {
    logger.warn(
      {
        conversationId: conversation.sId,
        fileIds: snippetFiles.map((file) => file.fileId),
        workspaceId: conversation.owner.sId,
      },
      "No default datasource view found for conversation when trying to get JIT actions"
    );
    return [];
  }

  // Check tables for the table query action.
  // TODO: there should not be a hardcoded value here
  const csvFiles = snippetFiles.filter(
    (file) => file.contentType === "text/csv"
  );
  if (csvFiles.length === 0) {
    return [];
  }

  // TODO(jit) Shall we look for an existing table query action and update it instead of creating a new one? This would allow join between the tables.
  const tablesQueryAction: TablesQueryConfigurationType = {
    description: csvFiles
      .map((file) => `tableId: ${file.fileId}\n${file.snippet}`)
      .join("\n\n"),
    type: "tables_query_configuration",
    id: -1,
    name: "query_conversation_tables",
    sId: generateRandomModelSId(),
    tables: csvFiles.map((file) => ({
      workspaceId: auth.getNonNullableWorkspace().sId,
      dataSourceViewId: conversationView.sId,
      tableId: file.fileId,
    })),
  };

  return [tablesQueryAction];
}

/**
* Model conversation rendering - JIT actions
*/
Expand Down
23 changes: 15 additions & 8 deletions front/lib/api/assistant/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,22 @@ async function batchRenderContentFragment(
);
}

return messagesWithContentFragment.map((message: Message) => {
const contentFragment = ContentFragmentResource.fromMessage(message);
return Promise.all(
messagesWithContentFragment.map(async (message: Message) => {
const contentFragment = ContentFragmentResource.fromMessage(message);
const render = await contentFragment.renderFromMessage({
auth,
conversationId,
message,
});

return {
m: contentFragment.renderFromMessage({ auth, conversationId, message }),
rank: message.rank,
version: message.version,
};
});
return {
m: render,
rank: message.rank,
version: message.version,
};
})
);
}

/**
Expand Down
3 changes: 2 additions & 1 deletion front/lib/api/files/upsert.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
isSupportedPlainTextContentType,
Ok,
removeNulls,
slugify,
} from "@dust-tt/types";
import { Writable } from "stream";
import { pipeline } from "stream/promises";
Expand Down Expand Up @@ -163,7 +164,7 @@ const upsertTableToDatasource: ProcessingFunction = async ({
const tableId = file.sId; // Use the file sId as the table id to make it easy to track the table back to the file.
const upsertTableRes = await upsertTable({
tableId,
name: file.fileName,
name: slugify(file.fileName),
description: "Table uploaded from file",
truncate: true,
csv: content,
Expand Down
21 changes: 16 additions & 5 deletions front/lib/resources/content_fragment_resource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,15 @@ export class ContentFragmentResource extends BaseResource<ContentFragmentModel>
return this.update({ sourceUrl });
}

renderFromMessage({
async renderFromMessage({
auth,
conversationId,
message,
}: {
auth: Authenticator;
conversationId: string;
message: Message;
}): ContentFragmentType {
}): Promise<ContentFragmentType> {
const owner = auth.workspace();
if (!owner) {
throw new Error(
Expand All @@ -226,11 +226,22 @@ export class ContentFragmentResource extends BaseResource<ContentFragmentModel>
contentFormat: "text",
});

let fileSid: string | null = null;
let snippet: string | null = null;

if (this.fileId) {
const file = await FileResource.fetchByModelId(this.fileId);
fileSid = file?.sId ?? null;

// Note: For CSV files outputted by tools, we have a "snippet" version of the output with the first rows stored in GCP, maybe it's better than our "summary" snippet stored on File.
// Need more testing, for now we are using the "summary" snippet.
snippet = file?.snippet ?? null;
}

return {
id: message.id,
fileId: this.fileId
? FileResource.modelIdToSId({ id: this.fileId, workspaceId: owner.id })
: null,
fileId: fileSid,
snippet: snippet,
sId: message.sId,
created: message.createdAt.getTime(),
type: "content_fragment",
Expand Down
Loading

0 comments on commit 17347f8

Please sign in to comment.