Skip to content

Commit

Permalink
Render batch for agent messages & content fragment
Browse files Browse the repository at this point in the history
  • Loading branch information
PopDaph committed Nov 27, 2023
1 parent 0e05c2d commit 743ad37
Showing 1 changed file with 171 additions and 81 deletions.
252 changes: 171 additions & 81 deletions front/lib/api/assistant/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import { GPT_3_5_TURBO_MODEL_CONFIG } from "@app/lib/assistant";
import { Authenticator } from "@app/lib/auth";
import { front_sequelize } from "@app/lib/databases";
import {
AgentConfiguration,
AgentDustAppRunAction,
AgentMessage,
Conversation,
ConversationParticipant,
Expand All @@ -38,6 +40,7 @@ import { updateWorkspacePerMonthlyActiveUsersSubscriptionUsage } from "@app/lib/
import { Err, Ok, Result } from "@app/lib/result";
import { generateModelSId } from "@app/lib/utils";
import logger from "@app/logger/logger";
import { AgentConfigurationType } from "@app/types/assistant/agent";
import {
AgentMessageType,
ContentFragmentContentType,
Expand All @@ -59,6 +62,7 @@ import { WorkspaceType } from "@app/types/user";

import { renderDustAppRunActionByModelId } from "./actions/dust_app_run";
import { renderRetrievalActionByModelId } from "./actions/retrieval";
import { getGlobalAgents } from "./global_agents";
/**
* Conversation Creation, update and deletion
*/
Expand Down Expand Up @@ -278,69 +282,135 @@ async function batchRenderUserMessages(messages: Message[]) {
});
}

async function renderAgentMessage(
async function batchRenderAgentMessages(
auth: Authenticator,
{
message,
agentMessage,
messages,
}: { message: Message; agentMessage: AgentMessage; messages: Message[] }
): Promise<AgentMessageType> {
const [agentConfiguration, agentRetrievalAction, agentDustAppRunAction] =
messages: Message[]
) {
if (messages.find((m) => !m.agentMessage)) {
throw new Error(
"Unreachable: batchRenderAgentMessages must be called with only agent messages"
);
}

const [agentConfigurations, agentRetrievalActions, agentDustAppRunActions] =
await Promise.all([
getAgentConfiguration(auth, agentMessage.agentConfigurationId),
(async () => {
if (agentMessage.agentRetrievalActionId) {
return await renderRetrievalActionByModelId(
agentMessage.agentRetrievalActionId
);
}
return null;
const agentConfigurationIds: string[] = messages.reduce(
(acc: string[], m) => {
const agentId = m.agentMessage?.agentConfigurationId;
if (agentId && !acc.includes(agentId)) {
acc.push(agentId);
}
return acc;
},
[]
);
const agents = (
await Promise.all(
agentConfigurationIds.map(async (agentConfigId) => {
return await getAgentConfiguration(auth, agentConfigId);
})
)
).filter((a) => a !== null) as AgentConfigurationType[];
const globalAgents = await getGlobalAgents(auth);
return [...globalAgents, ...agents];
})(),
(async () => {
if (agentMessage.agentDustAppRunActionId) {
return await renderDustAppRunActionByModelId(
agentMessage.agentDustAppRunActionId
);
}
return null;
const agentRetrievalActionIds: number[] = messages.reduce(
(acc: number[], m) => {
const agentId = m.agentMessage?.agentRetrievalActionId;
if (agentId && !acc.includes(agentId)) {
acc.push(agentId);
}
return acc;
},
[]
);
return await Promise.all(
agentRetrievalActionIds.map(async (agentRetrievalActionId) => {
return await renderRetrievalActionByModelId(agentRetrievalActionId);
})
);
})(),
(async () => {
const agentDustAppRunActionsIds: number[] = messages.reduce(
(acc: number[], m) => {
const agentId = m.agentMessage?.agentDustAppRunActionId;
if (agentId && !acc.includes(agentId)) {
acc.push(agentId);
}
return acc;
},
[]
);
const actions = await AgentDustAppRunAction.findAll({
where: {
id: {
[Op.in]: agentDustAppRunActionsIds,
},
},
});
return actions.map((action) => {
return {
id: action.id,
type: "dust_app_run_action",
appWorkspaceId: action.appWorkspaceId,
appId: action.appId,
appName: action.appName,
params: action.params,
runningBlock: null,
output: action.output,
};
});
})(),
]);

if (!agentConfiguration) {
throw new Error(
`Configuration ${agentMessage.agentConfigurationId} not found`
return messages.map((message) => {
if (!message.agentMessage) {
throw new Error(
"Unreachable: batchRenderUserMessages must be called with only user messages"
);
}
const agentMessage = message.agentMessage;
const action =
agentRetrievalActions.find(
(a) => a.id === agentMessage?.agentRetrievalActionId
) ??
agentDustAppRunActions.find(
(a) => a.id === agentMessage.agentDustAppRunActionId
);
const agentConfiguration = agentConfigurations.find(
(a) => a.sId === message.agentMessage?.agentConfigurationId
);
}

const action = agentRetrievalAction ?? agentDustAppRunAction;
let error: {
code: string;
message: string;
} | null = null;
if (agentMessage.errorCode !== null && agentMessage.errorMessage !== null) {
error = {
code: agentMessage.errorCode,
message: agentMessage.errorMessage,
};
}

let error: {
code: string;
message: string;
} | null = null;
if (agentMessage.errorCode !== null && agentMessage.errorMessage !== null) {
error = {
code: agentMessage.errorCode,
message: agentMessage.errorMessage,
const m = {
id: message.id,
sId: message.sId,
created: message.createdAt.getTime(),
type: "agent_message",
visibility: message.visibility,
version: message.version,
parentMessageId:
messages.find((m) => m.id === message.parentId)?.sId ?? null,
status: agentMessage.status,
action,
content: agentMessage.content,
error,
configuration: agentConfiguration,
};
}

return {
id: message.id,
sId: message.sId,
created: message.createdAt.getTime(),
type: "agent_message",
visibility: message.visibility,
version: message.version,
parentMessageId:
messages.find((m) => m.id === message.parentId)?.sId ?? null,
status: agentMessage.status,
action,
content: agentMessage.content,
error,
configuration: agentConfiguration,
};
return { m, rank: message.rank, version: message.version };
});
}

function renderContentFragment({
Expand Down Expand Up @@ -370,6 +440,43 @@ function renderContentFragment({
};
}

async function batchRenderContentFragment(messages: Message[]) {
if (messages.find((m) => !m.contentFragment)) {
throw new Error(
"Unreachable: batchRenderContentFragment must be called with only content fragments"
);
}

return messages.map((message) => {
if (!message.contentFragment) {
throw new Error(
"Unreachable: batchRenderContentFragment must be called with only content fragments"
);
}
const contentFragment = message.contentFragment;

const m = {
id: message.id,
sId: message.sId,
created: message.createdAt.getTime(),
type: "content_fragment",
visibility: message.visibility,
version: message.version,
title: contentFragment.title,
content: contentFragment.content,
url: contentFragment.url,
contentType: contentFragment.contentType,
context: {
profilePictureUrl: contentFragment.userContextProfilePictureUrl,
fullName: contentFragment.userContextFullName,
email: contentFragment.userContextEmail,
username: contentFragment.userContextUsername,
},
};
return { m, rank: message.rank, version: message.version };
});
}

export async function getUserConversations(
auth: Authenticator,
includeDeleted?: boolean
Expand Down Expand Up @@ -476,43 +583,26 @@ export async function getConversation(
],
});

const [userMessages] = await Promise.all([
const [userMessages, agentMessages, contentFragments] = await Promise.all([
(async () => {
return await batchRenderUserMessages(
messages.filter((m) => !!m.userMessage)
);
})(),
(async () => {
return await batchRenderAgentMessages(
auth,
messages.filter((m) => !!m.agentMessage)
);
})(),
(async () => {
return await batchRenderContentFragment(
messages.filter((m) => !!m.contentFragment)
);
})(),
]);

const renderAgentAndContentFragments = await Promise.all(
messages
.filter((m) => !m.userMessage)
.map((message) => {
return (async () => {
if (message.agentMessage) {
const m = await renderAgentMessage(auth, {
message,
agentMessage: message.agentMessage,
messages,
});
return { m, rank: message.rank, version: message.version };
}
if (message.contentFragment) {
const m = await renderContentFragment({
message: message,
contentFragment: message.contentFragment,
});
return { m, rank: message.rank, version: message.version };
}
throw new Error(
"Unreachable: message must be either user, agent or content fragment"
);
})();
})
);

const render = [...userMessages, ...renderAgentAndContentFragments];

const render = [...userMessages, ...agentMessages, ...contentFragments];
render.sort((a, b) => {
if (a.rank !== b.rank) {
return a.rank - b.rank;
Expand Down

0 comments on commit 743ad37

Please sign in to comment.