From 187e5da55526dc72b3ec136db9e4d2f172080ac2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daphn=C3=A9=20Popin?= Date: Wed, 6 Sep 2023 16:51:29 +0200 Subject: [PATCH] Assistant V2: Models for configuration of Agent (#1283) * Assistant V2: Models for configuration of Agent * Rename remove prefix Assistant from models * Apply feedback * update timeframe after rebase * Fix missed in rename * Remove topKdefault value * Another fix after rebase --- front/admin/db.ts | 13 ++ front/lib/api/assistant/actions/retrieval.ts | 28 +-- front/lib/api/assistant/agent.ts | 28 +-- front/lib/api/assistant/conversation.ts | 18 +- .../lib/models/assistant/actions/retrieval.ts | 197 ++++++++++++++++++ front/lib/models/assistant/agent.ts | 142 +++++++++++++ front/types/assistant/actions/retrieval.ts | 5 +- front/types/assistant/agent.ts | 5 +- 8 files changed, 395 insertions(+), 41 deletions(-) create mode 100644 front/lib/models/assistant/actions/retrieval.ts create mode 100644 front/lib/models/assistant/agent.ts diff --git a/front/admin/db.ts b/front/admin/db.ts index 9fd574e5d253..9141ddad148a 100644 --- a/front/admin/db.ts +++ b/front/admin/db.ts @@ -26,6 +26,14 @@ import { XP1Run, XP1User, } from "@app/lib/models"; +import { + AgentDataSourceConfiguration, + AgentRetrievalConfiguration, +} from "@app/lib/models/assistant/actions/retrieval"; +import { + AgentConfiguration, + AgentGenerationConfiguration, +} from "@app/lib/models/assistant/agent"; async function main() { await User.sync({ alter: true }); @@ -54,6 +62,11 @@ async function main() { await AgentMessage.sync({ alter: true }); await Message.sync({ alter: true }); + await AgentConfiguration.sync({ alter: true }); + await AgentGenerationConfiguration.sync({ alter: true }); + await AgentRetrievalConfiguration.sync({ alter: true }); + await AgentDataSourceConfiguration.sync({ alter: true }); + await XP1User.sync({ alter: true }); await XP1Run.sync({ alter: true }); diff --git a/front/lib/api/assistant/actions/retrieval.ts b/front/lib/api/assistant/actions/retrieval.ts index 922b1fc17f71..554a95027f38 100644 --- a/front/lib/api/assistant/actions/retrieval.ts +++ b/front/lib/api/assistant/actions/retrieval.ts @@ -45,32 +45,32 @@ export function parseTimeFrame(raw: string): TimeFrame | null { return null; } - const count = parseInt(m[1], 10); - if (isNaN(count)) { + const duration = parseInt(m[1], 10); + if (isNaN(duration)) { return null; } - let duration: TimeFrame["duration"]; + let unit: TimeFrame["unit"]; switch (m[2]) { case "d": - duration = "day"; + unit = "day"; break; case "w": - duration = "week"; + unit = "week"; break; case "m": - duration = "month"; + unit = "month"; break; case "y": - duration = "year"; + unit = "year"; break; default: return null; } return { - count, duration, + unit, }; } @@ -78,17 +78,17 @@ export function parseTimeFrame(raw: string): TimeFrame | null { export function timeFrameFromNow(timeFrame: TimeFrame): number { const now = Date.now(); - switch (timeFrame.duration) { + switch (timeFrame.unit) { case "hour": - return now - timeFrame.count * 60 * 60 * 1000; + return now - timeFrame.duration * 60 * 60 * 1000; case "day": - return now - timeFrame.count * 24 * 60 * 60 * 1000; + return now - timeFrame.duration * 24 * 60 * 60 * 1000; case "week": - return now - timeFrame.count * 7 * 24 * 60 * 60 * 1000; + return now - timeFrame.duration * 7 * 24 * 60 * 60 * 1000; case "month": - return now - timeFrame.count * 30 * 24 * 60 * 60 * 1000; + return now - timeFrame.duration * 30 * 24 * 60 * 60 * 1000; case "year": - return now - timeFrame.count * 365 * 24 * 60 * 60 * 1000; + return now - timeFrame.duration * 365 * 24 * 60 * 60 * 1000; } } diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index 2e6fbfa5e943..2618d80552a7 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -11,7 +11,7 @@ import { AgentActionSpecification, AgentConfigurationStatus, AgentConfigurationType, - AgentMessageConfigurationType, + AgentGenerationConfigurationType, } from "@app/types/assistant/agent"; import { AgentActionType, @@ -35,12 +35,12 @@ export async function createAgentConfiguration( name, pictureUrl, action, - message, + generation, }: { name: string; pictureUrl?: string; action?: AgentActionConfigurationType; - message?: AgentMessageConfigurationType; + generation?: AgentGenerationConfigurationType; } ): Promise { return { @@ -49,7 +49,7 @@ export async function createAgentConfiguration( pictureUrl: pictureUrl ?? null, status: "active", action: action ?? null, - message: message ?? null, + generation: generation ?? null, }; } @@ -61,13 +61,13 @@ export async function updateAgentConfiguration( pictureUrl, status, action, - message, + generation, }: { name: string; pictureUrl?: string; status: AgentConfigurationStatus; action?: AgentActionConfigurationType; - message?: AgentMessageConfigurationType; + generation?: AgentGenerationConfigurationType; } ): Promise { return { @@ -76,7 +76,7 @@ export async function updateAgentConfiguration( pictureUrl: pictureUrl ?? null, status, action: action ?? null, - message: message ?? null, + generation: generation ?? null, }; } @@ -188,8 +188,8 @@ export type AgentActionSuccessEvent = { }; // Event sent when tokens are streamed as the the agent is generating a message. -export type AgentMessageTokensEvent = { - type: "agent_message_tokens"; +export type AgentGenerationTokensEvent = { + type: "agent_generation_tokens"; created: number; configurationId: string; messageId: string; @@ -197,11 +197,11 @@ export type AgentMessageTokensEvent = { }; // Event sent once the message is completed and successful. -export type AgentMessageSuccessEvent = { - type: "agent_message_success"; +export type AgentGenerationSuccessEvent = { + type: "agent_generation_success"; created: number; configurationId: string; - messageId: string; + generationId: string; message: AgentMessageType; }; @@ -217,8 +217,8 @@ export async function* runAgent( | AgentErrorEvent | AgentActionEvent | AgentActionSuccessEvent - | AgentMessageTokensEvent - | AgentMessageSuccessEvent + | AgentGenerationTokensEvent + | AgentGenerationSuccessEvent > { yield { type: "agent_error", diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index 9d5f1384871a..601c0f0acd69 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -2,9 +2,9 @@ import { AgentActionEvent, AgentActionSuccessEvent, AgentErrorEvent, + AgentGenerationSuccessEvent, + AgentGenerationTokensEvent, AgentMessageNewEvent, - AgentMessageSuccessEvent, - AgentMessageTokensEvent, } from "@app/lib/api/assistant/agent"; import { Authenticator } from "@app/lib/auth"; import { CoreAPI } from "@app/lib/core_api"; @@ -176,8 +176,8 @@ export async function* postUserMessage( | AgentErrorEvent | AgentActionEvent | AgentActionSuccessEvent - | AgentMessageTokensEvent - | AgentMessageSuccessEvent + | AgentGenerationTokensEvent + | AgentGenerationSuccessEvent > { const user = auth.user(); @@ -266,7 +266,7 @@ export async function* postUserMessage( name: "foo", // TODO pictureUrl: null, // TODO action: null, // TODO - message: null, // TODO + generation: null, // TODO }, }); } @@ -320,8 +320,8 @@ export async function* retryAgentMessage( | AgentErrorEvent | AgentActionEvent | AgentActionSuccessEvent - | AgentMessageTokensEvent - | AgentMessageSuccessEvent + | AgentGenerationTokensEvent + | AgentGenerationSuccessEvent > { yield { type: "agent_error", @@ -354,8 +354,8 @@ export async function* editUserMessage( | AgentErrorEvent | AgentActionEvent | AgentActionSuccessEvent - | AgentMessageTokensEvent - | AgentMessageSuccessEvent + | AgentGenerationTokensEvent + | AgentGenerationSuccessEvent > { yield { type: "agent_error", diff --git a/front/lib/models/assistant/actions/retrieval.ts b/front/lib/models/assistant/actions/retrieval.ts new file mode 100644 index 000000000000..5121a796e941 --- /dev/null +++ b/front/lib/models/assistant/actions/retrieval.ts @@ -0,0 +1,197 @@ +import { + DataTypes, + ForeignKey, + InferAttributes, + InferCreationAttributes, + Model, +} from "sequelize"; + +import { front_sequelize } from "@app/lib/databases"; +import { AgentConfiguration } from "@app/lib/models/assistant/agent"; +import { DataSource } from "@app/lib/models/data_source"; +import { TimeframeUnit } from "@app/types/assistant/actions/retrieval"; + +/** + * Action Retrieval configuration + */ +export class AgentRetrievalConfiguration extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: number; + + declare query: "auto" | "none" | "templated"; + declare queryTemplate: string | null; + declare relativeTimeFrame: "auto" | "none" | "custom"; + declare relativeTimeFrameDuration: number | null; + declare relativeTimeFrameUnit: TimeframeUnit | null; + declare topK: number; + + declare agentId: ForeignKey; +} +AgentRetrievalConfiguration.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + query: { + type: DataTypes.STRING, + allowNull: false, + defaultValue: "auto", + }, + queryTemplate: { + type: DataTypes.TEXT, + allowNull: true, + }, + relativeTimeFrame: { + type: DataTypes.STRING, + allowNull: false, + defaultValue: "auto", + }, + relativeTimeFrameDuration: { + type: DataTypes.INTEGER, + allowNull: true, + }, + relativeTimeFrameUnit: { + type: DataTypes.STRING, + allowNull: true, + }, + topK: { + type: DataTypes.INTEGER, + allowNull: false, + }, + }, + { + modelName: "agent_retrieval_configuration", + sequelize: front_sequelize, + hooks: { + beforeValidate: (retrieval: AgentRetrievalConfiguration) => { + // Validation for templated Query + if (retrieval.query == "templated") { + if (!retrieval.queryTemplate) { + throw new Error("Must set a template for templated query"); + } + } else if (retrieval.queryTemplate) { + throw new Error("Can't set a template without templated query"); + } + + // Validation for Timeframe + if (retrieval.relativeTimeFrame == "custom") { + if ( + !retrieval.relativeTimeFrameDuration || + !retrieval.relativeTimeFrameUnit + ) { + throw new Error( + "Custom relative time frame must have a duration and unit set" + ); + } + } + }, + }, + } +); + +/** + * Configuration of Datasources used for Retrieval Action. + */ +export class AgentDataSourceConfiguration extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: number; + + declare minTimestamp: number | null; + declare maxTimestamp: number | null; + declare timeframeDuration: number | null; + declare timeframeUnit: TimeframeUnit | null; + + declare tagsIn: string[] | null; + declare tagsNotIn: string[] | null; + declare parentsIn: string[] | null; + declare parentsNotIn: string[] | null; + + declare dataSourceId: ForeignKey; + declare retrievalConfigurationId: ForeignKey< + AgentRetrievalConfiguration["id"] + >; +} +AgentDataSourceConfiguration.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + minTimestamp: { + type: DataTypes.INTEGER, + allowNull: true, + }, + maxTimestamp: { + type: DataTypes.INTEGER, + allowNull: true, + }, + timeframeDuration: { + type: DataTypes.INTEGER, + allowNull: true, + }, + timeframeUnit: { + type: DataTypes.STRING, + allowNull: true, + }, + tagsIn: { + type: DataTypes.ARRAY(DataTypes.STRING), + allowNull: true, + }, + tagsNotIn: { + type: DataTypes.ARRAY(DataTypes.STRING), + allowNull: true, + }, + parentsIn: { + type: DataTypes.ARRAY(DataTypes.STRING), + allowNull: true, + }, + parentsNotIn: { + type: DataTypes.ARRAY(DataTypes.STRING), + allowNull: true, + }, + }, + { + modelName: "agent_data_source_configuration", + sequelize: front_sequelize, + hooks: { + beforeValidate: (dataSourceConfig: AgentDataSourceConfiguration) => { + if (!dataSourceConfig.minTimestamp !== !dataSourceConfig.maxTimestamp) { + throw new Error("Timestamps must be both set or both null"); + } + if ( + !dataSourceConfig.timeframeDuration !== + !dataSourceConfig.timeframeUnit + ) { + throw new Error( + "Timeframe duration/unit must be both set or both null" + ); + } + if ( + (dataSourceConfig.minTimestamp || dataSourceConfig.maxTimestamp) && + (dataSourceConfig.timeframeDuration || dataSourceConfig.timeframeUnit) + ) { + throw new Error("Cannot use both timestamps and timeframe"); + } + }, + }, + } +); + +// Retrieval config <> data source config +AgentRetrievalConfiguration.hasMany(AgentDataSourceConfiguration, { + foreignKey: { name: "retrievalId", allowNull: false }, + onDelete: "CASCADE", +}); + +// Agent config <> Retrieval config +AgentConfiguration.hasOne(AgentRetrievalConfiguration, { + foreignKey: { name: "agentId", allowNull: true }, // null = no generation set for this Agent + onDelete: "CASCADE", +}); diff --git a/front/lib/models/assistant/agent.ts b/front/lib/models/assistant/agent.ts new file mode 100644 index 000000000000..dcda67f5b35c --- /dev/null +++ b/front/lib/models/assistant/agent.ts @@ -0,0 +1,142 @@ +import { + DataTypes, + ForeignKey, + InferAttributes, + InferCreationAttributes, + Model, +} from "sequelize"; + +import { front_sequelize } from "@app/lib/databases"; +import { AgentRetrievalConfiguration } from "@app/lib/models/assistant/actions/retrieval"; +import { Workspace } from "@app/lib/models/workspace"; +import { + AgentConfigurationScope, + AgentConfigurationStatus, +} from "@app/types/assistant/agent"; + +/** + * Agent configuration + */ +export class AgentConfiguration extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: number; + + declare sId: string; + declare status: AgentConfigurationStatus; + declare name: string; + declare pictureUrl: string | null; + + declare scope: AgentConfigurationScope; + declare workspaceId: ForeignKey | null; // null = it's a global agent + + declare model: ForeignKey | null; +} +AgentConfiguration.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + sId: { + type: DataTypes.STRING, + allowNull: false, + unique: true, + }, + status: { + type: DataTypes.STRING, + allowNull: false, + defaultValue: "active", + }, + name: { + type: DataTypes.TEXT, + allowNull: false, + }, + pictureUrl: { + type: DataTypes.TEXT, + allowNull: true, + }, + scope: { + type: DataTypes.STRING, + allowNull: false, + defaultValue: "workspace", + }, + }, + { + modelName: "agent_configuration", + sequelize: front_sequelize, + indexes: [ + { fields: ["workspaceId"] }, + // Unique name per workspace. + // Note that on PostgreSQL a unique constraint on multiple columns will treat NULL + // as distinct from any other value, so we can create twice the same name if at least + // one of the workspaceId is null. We're okay with it. + { fields: ["workspaceId", "name", "scope"], unique: true }, + { fields: ["sId"], unique: true }, + ], + hooks: { + beforeValidate: (agent: AgentConfiguration) => { + if (agent.scope !== "workspace" && agent.workspaceId) { + throw new Error("Workspace id must be null for global agent"); + } else if (agent.scope === "workspace" && !agent.workspaceId) { + throw new Error("Workspace id must be set for non-global agent"); + } + }, + }, + } +); + +/** + * Configuration of Agent generation. + */ +export class AgentGenerationConfiguration extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: number; + + declare prompt: string; + declare modelProvider: string; + declare modelId: string; + + declare agentId: ForeignKey; +} +AgentGenerationConfiguration.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + prompt: { + type: DataTypes.TEXT, + allowNull: false, + }, + modelProvider: { + type: DataTypes.STRING, + allowNull: false, + }, + modelId: { + type: DataTypes.STRING, + allowNull: false, + }, + }, + { + modelName: "agent_generation_configuration", + sequelize: front_sequelize, + } +); + +// Workspace <> Agent config +Workspace.hasMany(AgentConfiguration, { + foreignKey: { name: "workspaceId", allowNull: true }, // null = global Agent + onDelete: "CASCADE", +}); + +// Agent config <> Generation config +AgentConfiguration.hasOne(AgentGenerationConfiguration, { + foreignKey: { name: "agentId", allowNull: false }, // null = no retrieval action set for this Agent + onDelete: "CASCADE", +}); diff --git a/front/types/assistant/actions/retrieval.ts b/front/types/assistant/actions/retrieval.ts index 8117c070c9ee..3a473bbe86ce 100644 --- a/front/types/assistant/actions/retrieval.ts +++ b/front/types/assistant/actions/retrieval.ts @@ -6,9 +6,10 @@ import { ModelId } from "@app/lib/databases"; import { AgentActionConfigurationType } from "@app/types/assistant/agent"; import { AgentActionType } from "@app/types/assistant/conversation"; +export type TimeframeUnit = "hour" | "day" | "week" | "month" | "year"; export type TimeFrame = { - count: number; - duration: "hour" | "day" | "week" | "month" | "year"; + duration: number; + unit: TimeframeUnit; }; export type DataSourceFilter = { diff --git a/front/types/assistant/agent.ts b/front/types/assistant/agent.ts index f099621e0ee7..457320076e1a 100644 --- a/front/types/assistant/agent.ts +++ b/front/types/assistant/agent.ts @@ -43,7 +43,7 @@ export type AgentActionSpecification = { * Agent Message configuration */ -export type AgentMessageConfigurationType = { +export type AgentGenerationConfigurationType = { prompt: string; model: { providerId: string; @@ -56,6 +56,7 @@ export type AgentMessageConfigurationType = { */ export type AgentConfigurationStatus = "active" | "archived"; +export type AgentConfigurationScope = "global" | "workspace"; export type AgentConfigurationType = { sId: string; @@ -69,5 +70,5 @@ export type AgentConfigurationType = { action: AgentActionConfigurationType | null; // If undefined, no text generation. - message: AgentMessageConfigurationType | null; + generation: AgentGenerationConfigurationType | null; };