diff --git a/front/admin/db.ts b/front/admin/db.ts index 29890f3a3d0be..214baa6956fc0 100644 --- a/front/admin/db.ts +++ b/front/admin/db.ts @@ -20,6 +20,7 @@ import { Key, Membership, MembershipInvitation, + Mention, Message, Provider, RetrievalDocument, @@ -69,6 +70,7 @@ async function main() { await UserMessage.sync({ alter: true }); await AgentMessage.sync({ alter: true }); await Message.sync({ alter: true }); + await Mention.sync({ alter: true }); await XP1User.sync({ alter: true }); await XP1Run.sync({ alter: true }); diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index b89e253484d3e..85e1b089a440c 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -9,17 +9,354 @@ import { import { GenerationTokensEvent } from "@app/lib/api/assistant/generation"; import { Authenticator } from "@app/lib/auth"; import { front_sequelize } from "@app/lib/databases"; -import { AgentMessage, Message, UserMessage } from "@app/lib/models"; +import { + AgentConfiguration, + AgentMessage, + AgentRetrievalAction, + Conversation, + Mention, + Message, + User, + UserMessage, +} from "@app/lib/models"; import { generateModelSId } from "@app/lib/utils"; import { AgentMessageType, ConversationType, + ConversationVisibility, isAgentMention, - Mention, + isUserMention, + MentionType, UserMessageContext, UserMessageType, } from "@app/types/assistant/conversation"; +/** + * Conversation Creation and Update + */ + +export async function createConversation( + auth: Authenticator, + { + title, + visibility, + }: { + title: string | null; + visibility: ConversationVisibility; + } +): Promise { + const conversation = await Conversation.create({ + sId: generateModelSId(), + title: title, + visibility: visibility, + }); + + return { + id: conversation.id, + created: conversation.createdAt.getTime(), + sId: conversation.sId, + title: conversation.title, + visibility: conversation.visibility, + content: [], + }; +} + +/** + * Conversation Rendering + */ + +async function renderUserMessage( + auth: Authenticator, + message: Message, + userMessage: UserMessage +): Promise { + const [mentions, user] = await Promise.all([ + Mention.findAll({ + where: { + messageId: message.id, + }, + include: [ + { + model: User, + as: "user", + required: false, + }, + { + model: AgentConfiguration, + as: "agentConfiguration", + required: false, + }, + ], + }), + (async () => { + if (userMessage.userId) { + return await User.findOne({ + where: { + id: userMessage.userId, + }, + }); + } + return null; + })(), + ]); + + return { + id: message.id, + sId: message.sId, + type: "user_message", + visibility: message.visibility, + version: message.version, + user: user + ? { + id: user.id, + provider: user.provider, + providerId: user.providerId, + username: user.username, + email: user.email, + name: user.name, + image: null, + workspaces: [], + isDustSuperUser: false, + } + : null, + mentions: mentions.map((m) => { + if (m.agentConfiguration) { + return { + id: m.id, + configurationId: m.agentConfiguration.sId, + }; + } + if (m.user) { + return { + id: m.id, + provider: m.user.provider, + providerId: m.user.providerId, + }; + } + throw new Error("Unreachable: mention must be either agent or user"); + }), + message: userMessage.message, + context: { + username: userMessage.userContextUsername, + timezone: userMessage.userContextTimezone, + fullName: userMessage.userContextFullName, + email: userMessage.userContextEmail, + profilePictureUrl: userMessage.userContextProfilePictureUrl, + }, + }; +} + +async function renderAgentMessage( + auth: Authenticator, + message: Message, + agentMessage: AgentMessage +): Promise { + const [agentConfiguration, agentRetrievalAction] = await Promise.all([ + AgentConfiguration.findOne({ + where: { + id: agentMessage.agentConfigurationId, + }, + }), + (async () => { + if (agentMessage.agentRetrievalActionId) { + return await AgentRetrievalAction.findOne({ + where: { + id: agentMessage.agentRetrievalActionId, + }, + }); + } + return null; + })(), + ]); + + if (!agentConfiguration) { + throw new Error( + `Agent configuration ${agentMessage.agentConfigurationId} not found` + ); + } + + return { + id: message.id, + sId: message.sId, + type: "agent_message", + visibility: message.visibility, + version: message.version, + parentMessageId: null, + status: agentMessage.status, + action: agentRetrievalAction + ? { + id: agentRetrievalAction.id, + type: "retrieval_action", + params: { + dataSources: [], // TODO + query: agentRetrievalAction.query, + relativeTimeFrame: null, // TODO + topK: agentRetrievalAction.topK, + }, + documents: [], // TODO + } + : null, + message: agentMessage.message, + feedbacks: [], + error: null, + configuration: { + sId: agentConfiguration.sId, + status: "active", + name: agentConfiguration.name, + pictureUrl: agentConfiguration.pictureUrl, + // TODO(spolu) + action: null, + generation: null, + }, + }; +} + +export async function getConversation( + auth: Authenticator, + conversationId: string +): Promise { + const conversation = await Conversation.findOne({ + where: { + sId: conversationId, + }, + }); + + if (!conversation) { + return null; + } + + const messages = await Message.findAll({ + where: { + conversationId: conversation.id, + }, + order: [ + ["rank", "ASC"], + ["version", "ASC"], + ], + include: [ + { + model: UserMessage, + as: "userMessage", + required: false, + }, + { + model: AgentMessage, + as: "agentMessage", + required: false, + include: [ + { + model: AgentRetrievalAction, + as: "agentRetrievalAction", + required: false, + }, + { + model: AgentConfiguration, + as: "agentConfiguration", + required: true, + }, + ], + }, + ], + }); + + const maxRank = messages.reduce((acc, m) => Math.max(acc, m.rank), -1); + const content: (UserMessageType | AgentMessageType)[][] = Array.from( + { length: maxRank + 1 }, + () => [] + ); + + for (const message of messages) { + if (message.userMessage) { + //content[message.rank].push({ + // id: message.id, + // sId: message.sId, + // type: "user_message", + // visibility: message.visibility, + // version: message.version, + // user: null, + // mentions: [], + // message: message.userMessage.message, + // context: { + // username: message.userMessage.userContextUsername, + // timezone: message.userMessage.userContextTimezone, + // fullName: message.userMessage.userContextFullName, + // email: message.userMessage.userContextEmail, + // profilePictureUrl: message.userMessage.userContextProfilePictureUrl, + // }, + //}); + } + if (message.agentMessage) { + // if (message.agentMessage.agentRetrievalActionId) { + // } + // content[message.rank].push({ + // id: message.id, + // sId: message.sId, + // type: "agent_message", + // visibility: message.visibility, + // version: message.version, + // parentMessageId: null, + // status: message.agentMessage.status, + // action: null, + // message: message.agentMessage.message, + // feedbacks: [], + // error: null, + // configuration: { + // sId: "foo", + // status: "active", + // name: "foo", // TODO + // pictureUrl: null, // TODO + // action: null, // TODO + // generation: null, // TODO + // }, + // }); + } + } + + return { + id: conversation.id, + created: conversation.createdAt.getTime(), + sId: conversation.sId, + title: conversation.title, + visibility: conversation.visibility, + content: [], + }; +} + +export async function updateConversation( + auth: Authenticator, + conversationId: string, + { + title, + visibility, + }: { + title: string | null; + visibility: ConversationVisibility; + } +): Promise { + const conversation = await Conversation.findOne({ + where: { + sId: conversationId, + }, + }); + + if (!conversation) { + throw new Error(`Conversation ${conversationId} not found`); + } + + await conversation.update({ + title: title, + visibility: visibility, + }); + + const c = await getConversation(auth, conversationId); + + if (!c) { + throw new Error(`Conversation ${conversationId} not found`); + } + + return c; +} + /** * Conversation API */ @@ -53,7 +390,7 @@ export async function* postUserMessage( }: { conversation: ConversationType; message: string; - mentions: Mention[]; + mentions: MentionType[]; context: UserMessageContext; } ): AsyncGenerator< @@ -122,8 +459,16 @@ export async function* postUserMessage( // for each assistant mention, create an "empty" agent message for (const mention of mentions) { if (isAgentMention(mention)) { + // TODO(spolu): retrieve configuration from mention. + // Mention.create({ + // messageId: m.id, + // configurationId: mention.configurationId, + // }); + const agentMessageRow = await AgentMessage.create( - {}, + { + // TODO(spolu): add agentConfigurationId + }, { transaction: t } ); const m = await Message.create( @@ -161,6 +506,22 @@ export async function* postUserMessage( }, }); } + + if (isUserMention(mention)) { + const user = await User.findOne({ + where: { + provider: mention.provider, + providerId: mention.providerId, + }, + }); + + if (user) { + await Mention.create({ + messageId: m.id, + userId: user.id, + }); + } + } } return { userMessage, agentMessages, agentMessageRows }; diff --git a/front/lib/models/assistant/conversation.ts b/front/lib/models/assistant/conversation.ts index a9ba1c72bd602..d3efd437085e2 100644 --- a/front/lib/models/assistant/conversation.ts +++ b/front/lib/models/assistant/conversation.ts @@ -5,17 +5,21 @@ import { InferAttributes, InferCreationAttributes, Model, + NonAttribute, } from "sequelize"; import { front_sequelize } from "@app/lib/databases"; import { AgentRetrievalAction } from "@app/lib/models/assistant/actions/retrieval"; import { User } from "@app/lib/models/user"; +import { Workspace } from "@app/lib/models/workspace"; import { AgentMessageStatus, ConversationVisibility, MessageVisibility, } from "@app/types/assistant/conversation"; +import { AgentConfiguration } from "./agent"; + export class Conversation extends Model< InferAttributes, InferCreationAttributes @@ -26,7 +30,6 @@ export class Conversation extends Model< declare sId: string; declare title: string | null; - declare created: Date; declare visibility: CreationOptional; } @@ -56,22 +59,29 @@ Conversation.init( type: DataTypes.TEXT, allowNull: true, }, - created: { - type: DataTypes.DATE, - allowNull: false, - }, visibility: { type: DataTypes.STRING, allowNull: false, - defaultValue: "private", + defaultValue: "unlisted", }, }, { modelName: "conversation", + indexes: [ + { + unique: true, + fields: ["sId"], + }, + ], sequelize: front_sequelize, } ); +Workspace.hasMany(Conversation, { + foreignKey: { allowNull: false }, + onDelete: "CASCADE", +}); + export class UserMessage extends Model< InferAttributes, InferCreationAttributes @@ -140,7 +150,7 @@ UserMessage.init( ); User.hasMany(UserMessage, { - foreignKey: { name: "userId" }, + foreignKey: { name: "userId", allowNull: true }, // null = message is not associated with a user }); export class AgentMessage extends Model< @@ -158,6 +168,7 @@ export class AgentMessage extends Model< declare errorMessage: string | null; declare agentRetrievalActionId: ForeignKey | null; + declare agentConfigurationId: ForeignKey; } AgentMessage.init( @@ -210,6 +221,12 @@ AgentMessage.init( AgentRetrievalAction.hasOne(AgentMessage, { foreignKey: { name: "agentRetrievalActionId", allowNull: true }, // null = no retrieval action set for this Agent onDelete: "CASCADE", + as: "agentRetrievalAction", +}); + +AgentConfiguration.hasMany(AgentMessage, { + foreignKey: { name: "agentConfigurationId", allowNull: false }, + as: "agentConfiguration", }); export class Message extends Model< @@ -231,6 +248,9 @@ export class Message extends Model< declare parentId: ForeignKey | null; declare userMessageId: ForeignKey | null; declare agentMessageId: ForeignKey | null; + + declare userMessage?: NonAttribute; + declare agentMessage?: NonAttribute; } Message.init( @@ -276,7 +296,11 @@ Message.init( indexes: [ { unique: true, - fields: ["version", "conversationId", "rank"], + fields: ["conversationId", "rank", "version"], + }, + { + unique: true, + fields: ["sId"], }, ], hooks: { @@ -296,12 +320,65 @@ Conversation.hasMany(Message, { }); UserMessage.hasOne(Message, { foreignKey: "userMessageId", - as: "_message", + as: "userMessage", }); AgentMessage.hasOne(Message, { foreignKey: "agentMessageId", - as: "_message", + as: "agentMessage", }); Message.belongsTo(Message, { foreignKey: "parentId", }); + +export class Mention extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare messageId: ForeignKey; + declare userId: ForeignKey | null; + declare agentConfigurationId: ForeignKey | null; + + declare user?: NonAttribute; + declare agentConfiguration?: NonAttribute; +} + +Mention.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + }, + { + modelName: "mention", + sequelize: front_sequelize, + } +); + +Message.hasMany(Mention, { + foreignKey: { name: "messageId", allowNull: false }, + onDelete: "CASCADE", +}); +User.hasMany(Mention, { + foreignKey: { name: "userId", allowNull: true }, // null = mention is not a user mention + as: "user", +}); +AgentConfiguration.hasMany(Mention, { + foreignKey: { name: "agentConfigurationId", allowNull: true }, // null = mention is not an agent mention + as: "agentConfiguration", +}); diff --git a/front/lib/models/index.ts b/front/lib/models/index.ts index 0bc9edbbbdc07..c9d162cf3e091 100644 --- a/front/lib/models/index.ts +++ b/front/lib/models/index.ts @@ -13,6 +13,7 @@ import { import { AgentMessage, Conversation, + Mention, Message, UserMessage, } from "@app/lib/models/assistant/conversation"; @@ -59,6 +60,7 @@ export { Key, Membership, MembershipInvitation, + Mention, Message, Provider, RetrievalDocument, diff --git a/front/types/assistant/actions/retrieval.ts b/front/types/assistant/actions/retrieval.ts index 2544edeb7ee7b..09a26ca6cab8a 100644 --- a/front/types/assistant/actions/retrieval.ts +++ b/front/types/assistant/actions/retrieval.ts @@ -88,7 +88,6 @@ export type RetrievalActionType = { id: ModelId; // AgentRetrieval. type: "retrieval_action"; params: { - dataSources: "all" | DataSourceConfiguration[]; relativeTimeFrame: TimeFrame | null; query: string | null; topK: number; diff --git a/front/types/assistant/conversation.ts b/front/types/assistant/conversation.ts index 14ff04663172f..12150ca2888ab 100644 --- a/front/types/assistant/conversation.ts +++ b/front/types/assistant/conversation.ts @@ -9,23 +9,25 @@ import { AgentConfigurationType } from "./agent"; */ export type AgentMention = { + id: ModelId; configurationId: string; }; export type UserMention = { + id: ModelId; provider: string; providerId: string; }; -export type Mention = AgentMention | UserMention; +export type MentionType = AgentMention | UserMention; export type MessageVisibility = "visible" | "deleted"; -export function isAgentMention(arg: Mention): arg is AgentMention { +export function isAgentMention(arg: MentionType): arg is AgentMention { return (arg as AgentMention).configurationId !== undefined; } -export function isUserMention(arg: Mention): arg is UserMention { +export function isUserMention(arg: MentionType): arg is UserMention { const maybeUserMention = arg as UserMention; return ( maybeUserMention.provider !== undefined && @@ -52,7 +54,7 @@ export type UserMessageType = { visibility: MessageVisibility; version: number; user: UserType | null; - mentions: Mention[]; + mentions: MentionType[]; message: string; context: UserMessageContext; }; @@ -112,7 +114,7 @@ export function isAgentMessageType( * Conversations */ -export type ConversationVisibility = "private" | "workspace"; +export type ConversationVisibility = "unlisted" | "workspace"; /** * content [][] structure is intended to allow retries (of agent messages) or edits (of user @@ -123,6 +125,6 @@ export type ConversationType = { created: number; sId: string; title: string | null; - content: (UserMessageType[] | AgentMessageType[])[]; visibility: ConversationVisibility; + content: (UserMessageType[] | AgentMessageType[])[]; };