From 25016554fc94a34a679aef4095f4aefb9c504738 Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Wed, 6 Sep 2023 18:01:55 +0200 Subject: [PATCH] RetrievalAction models --- front/lib/api/assistant/actions/retrieval.ts | 5 +- .../lib/models/assistant/actions/retrieval.ts | 222 ++++++++++++++++++ front/lib/models/assistant/conversation.ts | 14 ++ front/lib/models/index.ts | 12 + 4 files changed, 250 insertions(+), 3 deletions(-) diff --git a/front/lib/api/assistant/actions/retrieval.ts b/front/lib/api/assistant/actions/retrieval.ts index 554a95027f382..cb3dfe49805b4 100644 --- a/front/lib/api/assistant/actions/retrieval.ts +++ b/front/lib/api/assistant/actions/retrieval.ts @@ -11,7 +11,6 @@ import { Err, Ok, Result } from "@app/lib/result"; import logger from "@app/logger/logger"; import { DataSourceConfiguration, - DataSourceFilter, isRetrievalConfiguration, RetrievalActionType, RetrievalConfigurationType, @@ -299,8 +298,8 @@ export type RetrievalSuccessEvent = { action: RetrievalActionType; }; -// This method is in charge of running the retrieval and creating an AgentRetrieval DB -// object in the database (along with the RetrievedDocument objects). It does not create any generic +// This method is in charge of running the retrieval and creating an AgentRetrievalAction +// object in the database (along with the RetrievalDocument and RetrievalDocumentChunk objects). It does not create any generic // model related to the conversation. export async function* runRetrieval( auth: Authenticator, diff --git a/front/lib/models/assistant/actions/retrieval.ts b/front/lib/models/assistant/actions/retrieval.ts index 5121a796e941f..3b705d27a3e83 100644 --- a/front/lib/models/assistant/actions/retrieval.ts +++ b/front/lib/models/assistant/actions/retrieval.ts @@ -1,4 +1,5 @@ import { + CreationOptional, DataTypes, ForeignKey, InferAttributes, @@ -195,3 +196,224 @@ AgentConfiguration.hasOne(AgentRetrievalConfiguration, { foreignKey: { name: "agentId", allowNull: true }, // null = no generation set for this Agent onDelete: "CASCADE", }); + +/** + * Retrieval Action + */ +export class AgentRetrievalAction extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: number; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare query: string | null; + declare relativeTimeFrameDuration: number | null; + declare relativeTimeFrameUnit: TimeframeUnit | null; + declare topK: number; + + declare retrievalConfigurationId: ForeignKey< + AgentRetrievalConfiguration["id"] + >; +} +AgentRetrievalAction.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + query: { + type: DataTypes.TEXT, + allowNull: true, + }, + relativeTimeFrameDuration: { + type: DataTypes.INTEGER, + allowNull: true, + }, + relativeTimeFrameUnit: { + type: DataTypes.STRING, + allowNull: true, + }, + topK: { + type: DataTypes.INTEGER, + allowNull: false, + }, + }, + { + modelName: "agent_retrieval_action", + sequelize: front_sequelize, + hooks: { + beforeValidate: (retrieval: AgentRetrievalAction) => { + // Validation for Timeframe + if ( + retrieval.relativeTimeFrameDuration === null && + retrieval.relativeTimeFrameUnit !== null + ) { + throw new Error( + "Relative time frame must have a duration and unit set or they should both be null" + ); + } + if ( + retrieval.relativeTimeFrameDuration !== null && + retrieval.relativeTimeFrameUnit === null + ) { + throw new Error( + "Relative time frame must have a duration and unit set or they should both be null" + ); + } + }, + }, + } +); + +AgentRetrievalConfiguration.hasMany(AgentRetrievalAction, { + foreignKey: { name: "retrievalConfigurationId", allowNull: false }, + // We don't want to delete the action when the configuration is deleted + // But really we don't want to delete configurations ever. +}); + +export class RetrievalDocument extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare dataSourceId: string; + declare sourceUrl: string | null; + declare documentId: string; + declare reference: string; + declare timestamp: number; + declare tags: string[]; + declare score: number; + + declare retrievalActionId: ForeignKey; +} + +RetrievalDocument.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + dataSourceId: { + type: DataTypes.STRING, + allowNull: false, + }, + sourceUrl: { + type: DataTypes.STRING, + allowNull: true, + }, + documentId: { + type: DataTypes.STRING, + allowNull: false, + }, + reference: { + type: DataTypes.STRING, + allowNull: false, + }, + timestamp: { + type: DataTypes.BIGINT, + allowNull: false, + }, + tags: { + type: DataTypes.ARRAY(DataTypes.STRING), + allowNull: false, + }, + score: { + type: DataTypes.REAL, + allowNull: false, + }, + }, + { + modelName: "retrieval_document", + sequelize: front_sequelize, + indexes: [{ fields: ["retrievalActionId"] }], + } +); + +AgentRetrievalAction.hasMany(RetrievalDocument, { + foreignKey: { name: "retrievalActionId", allowNull: false }, + onDelete: "CASCADE", +}); + +export class RetrievalDocumentChunk extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare text: string; + declare offset: number; + declare score: number; + + declare retrievalDocumentId: ForeignKey; +} + +RetrievalDocumentChunk.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + text: { + type: DataTypes.TEXT, + allowNull: false, + }, + offset: { + type: DataTypes.INTEGER, + allowNull: false, + }, + score: { + type: DataTypes.REAL, + allowNull: false, + }, + }, + { + modelName: "retrieval_document_chunk", + sequelize: front_sequelize, + indexes: [{ fields: ["retrievalDocumentId"] }], + } +); + +RetrievalDocument.hasMany(RetrievalDocumentChunk, { + foreignKey: { name: "retrievalDocumentId", allowNull: false }, + onDelete: "CASCADE", +}); diff --git a/front/lib/models/assistant/conversation.ts b/front/lib/models/assistant/conversation.ts index a30c377e3ffd4..2e3ac681f9640 100644 --- a/front/lib/models/assistant/conversation.ts +++ b/front/lib/models/assistant/conversation.ts @@ -8,6 +8,7 @@ import { } from "sequelize"; import { front_sequelize } from "@app/lib/databases"; +import { AgentRetrievalAction } from "@app/lib/models/assistant/actions/retrieval"; import { User } from "@app/lib/models/user"; import { AgentMessageStatus, @@ -128,6 +129,8 @@ export class AgentMessage extends Model< declare message: string | null; declare errorCode: string | null; declare errorMessage: string | null; + + declare agentRetrievalActionId: ForeignKey; } AgentMessage.init( @@ -157,10 +160,21 @@ AgentMessage.init( }, { modelName: "agent_message", + indexes: [ + { + unique: true, + fields: ["agentRetrievalActionId"], + }, + ], sequelize: front_sequelize, } ); +AgentRetrievalAction.hasOne(AgentMessage, { + foreignKey: { name: "agentRetrievalActionId", allowNull: true }, // null = no retrieval action set for this Agent + onDelete: "CASCADE", +}); + export class Message extends Model< InferAttributes, InferCreationAttributes diff --git a/front/lib/models/index.ts b/front/lib/models/index.ts index 6d5a9ad276d73..8dd9974663ffe 100644 --- a/front/lib/models/index.ts +++ b/front/lib/models/index.ts @@ -1,4 +1,11 @@ import { App, Clone, Dataset, Provider, Run } from "@app/lib/models/apps"; +import { + AgentDataSourceConfiguration, + AgentRetrievalAction, + AgentRetrievalConfiguration, + RetrievalDocument, + RetrievalDocumentChunk, +} from "@app/lib/models/assistant/actions/retrieval"; import { AgentMessage, Conversation, @@ -27,7 +34,10 @@ import { import { XP1Run, XP1User } from "@app/lib/models/xp1"; export { + AgentDataSourceConfiguration, AgentMessage, + AgentRetrievalAction, + AgentRetrievalConfiguration, App, ChatMessage, ChatRetrievedDocument, @@ -45,6 +55,8 @@ export { MembershipInvitation, Message, Provider, + RetrievalDocument, + RetrievalDocumentChunk, Run, TrackedDocument, User,