From b20d0257536ea680f8d0e6a625d09cf08274a0ca Mon Sep 17 00:00:00 2001 From: PopDaph Date: Fri, 8 Sep 2023 14:12:53 +0200 Subject: [PATCH] WIP rework everything --- front/lib/api/assistant/actions/retrieval.ts | 25 +- front/lib/api/assistant/agent.ts | 6 +- front/lib/api/assistant/agent/agent_create.ts | 283 -------- front/lib/api/assistant/agent/agent_get.ts | 195 ------ front/lib/api/assistant/agent/agent_update.ts | 208 ------ front/lib/api/assistant/configuration.ts | 618 ++++++++++++++++++ front/lib/api/assistant/conversation.ts | 111 +--- .../lib/models/assistant/actions/retrieval.ts | 192 +----- front/lib/models/assistant/agent.ts | 167 ----- front/lib/models/assistant/configuration.ts | 352 ++++++++++ front/lib/models/index.ts | 6 +- front/types/assistant/actions/retrieval.ts | 45 +- front/types/assistant/agent.ts | 48 +- front/types/assistant/configuration.ts | 66 ++ front/types/assistant/conversation.ts | 4 +- 15 files changed, 1085 insertions(+), 1241 deletions(-) delete mode 100644 front/lib/api/assistant/agent/agent_create.ts delete mode 100644 front/lib/api/assistant/agent/agent_get.ts delete mode 100644 front/lib/api/assistant/agent/agent_update.ts create mode 100644 front/lib/api/assistant/configuration.ts delete mode 100644 front/lib/models/assistant/agent.ts create mode 100644 front/lib/models/assistant/configuration.ts create mode 100644 front/types/assistant/configuration.ts diff --git a/front/lib/api/assistant/actions/retrieval.ts b/front/lib/api/assistant/actions/retrieval.ts index f73e1be1d286..6730bef1fc89 100644 --- a/front/lib/api/assistant/actions/retrieval.ts +++ b/front/lib/api/assistant/actions/retrieval.ts @@ -17,17 +17,17 @@ import { Err, Ok, Result } from "@app/lib/result"; import { new_id } from "@app/lib/utils"; import logger from "@app/logger/logger"; import { - AgentDataSourceConfigurationType, - isRetrievalConfiguration, RetrievalActionType, - RetrievalConfigurationType, RetrievalDocumentType, TimeFrame, } from "@app/types/assistant/actions/retrieval"; +import { AgentActionSpecification } from "@app/types/assistant/agent"; import { - AgentActionSpecification, - AgentFullConfigurationType, -} from "@app/types/assistant/agent"; + AgentConfigurationType, + AgentDataSourceConfigurationType, + isRetrievalConfiguration, + RetrievalConfigurationType, +} from "@app/types/assistant/configuration"; import { AgentMessageType, ConversationType, @@ -167,7 +167,6 @@ export async function retrievalActionSpecification( } return { - id: configuration.id, name: "search_data_sources", description: "Search the data sources specified by the user for information to answer their request." + @@ -316,7 +315,7 @@ export type RetrievalSuccessEvent = { // error is expected to be stored by the caller on the parent agent message. export async function* runRetrieval( auth: Authenticator, - configuration: AgentFullConfigurationType, + configuration: AgentConfigurationType, conversation: ConversationType, userMessage: UserMessageType, agentMessage: AgentMessageType @@ -349,7 +348,7 @@ export async function* runRetrieval( return yield { type: "retrieval_error", created: Date.now(), - configurationId: configuration.agent.sId, + configurationId: configuration.sId, messageId: agentMessage.sId, error: { code: "retrieval_parameters_generation_error", @@ -433,7 +432,7 @@ export async function* runRetrieval( return yield { type: "retrieval_error", created: Date.now(), - configurationId: configuration.agent.sId, + configurationId: configuration.sId, messageId: agentMessage.sId, error: { code: "retrieval_search_error", @@ -450,7 +449,7 @@ export async function* runRetrieval( return yield { type: "retrieval_error", created: Date.now(), - configurationId: configuration.agent.sId, + configurationId: configuration.sId, messageId: agentMessage.sId, error: { code: "retrieval_search_error", @@ -529,7 +528,7 @@ export async function* runRetrieval( yield { type: "retrieval_documents", created: Date.now(), - configurationId: configuration.agent.sId, + configurationId: configuration.sId, messageId: agentMessage.sId, documents, }; @@ -537,7 +536,7 @@ export async function* runRetrieval( yield { type: "retrieval_success", created: Date.now(), - configurationId: configuration.agent.sId, + configurationId: configuration.sId, messageId: agentMessage.sId, action: { id: action.id, diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index 7799d8dd7549..4d9891171f10 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -16,11 +16,11 @@ import { import { Authenticator } from "@app/lib/auth"; import { Err, Ok, Result } from "@app/lib/result"; import { generateModelSId } from "@app/lib/utils"; -import { isRetrievalConfiguration } from "@app/types/assistant/actions/retrieval"; +import { AgentActionSpecification } from "@app/types/assistant/agent"; import { - AgentActionSpecification, AgentConfigurationType, -} from "@app/types/assistant/agent"; + isRetrievalConfiguration, +} from "@app/types/assistant/configuration"; import { AgentActionType, AgentMessageType, diff --git a/front/lib/api/assistant/agent/agent_create.ts b/front/lib/api/assistant/agent/agent_create.ts deleted file mode 100644 index 45458446ceec..000000000000 --- a/front/lib/api/assistant/agent/agent_create.ts +++ /dev/null @@ -1,283 +0,0 @@ -import { Op, Transaction } from "sequelize"; - -import { - _buildAgentActionConfigurationTypeFromModel, - _buildAgentConfigurationTypeFromModel, - _buildAgentGenerationConfigurationTypeFromModel, -} from "@app/lib/api/assistant/agent/agent_get"; -import { Authenticator } from "@app/lib/auth"; -import { front_sequelize } from "@app/lib/databases"; -import { DataSource, Workspace } from "@app/lib/models"; -import { - AgentDataSourceConfiguration, - AgentRetrievalConfiguration, -} from "@app/lib/models/assistant/actions/retrieval"; -import { - AgentConfiguration, - AgentGenerationConfiguration, -} from "@app/lib/models/assistant/agent"; -import { generateModelSId } from "@app/lib/utils"; -import { - AgentDataSourceConfigurationType, - isTemplatedQuery, - isTimeFrame, - RetrievalDataSourcesConfiguration, - RetrievalQuery, - RetrievalTimeframe, -} from "@app/types/assistant/actions/retrieval"; -import { - AgentActionConfigurationType, - AgentConfigurationStatus, - AgentConfigurationType, - AgentGenerationConfigurationType, -} from "@app/types/assistant/agent"; - -/** - * Create Agent Configuration - */ -export async function createAgentConfiguration( - auth: Authenticator, - { - name, - pictureUrl, - status, - }: { - name: string; - pictureUrl: string; - status: AgentConfigurationStatus; - } -): Promise { - const owner = auth.workspace(); - if (!owner) { - throw new Error("Cannot create AgentConfiguration without workspace"); - } - - const agentConfig = await AgentConfiguration.create({ - sId: generateModelSId(), - status: status, - name: name, - pictureUrl: pictureUrl, - scope: "workspace", - workspaceId: owner.id, - }); - - return _buildAgentConfigurationTypeFromModel({ - agent: agentConfig, - }); -} - -/** - * Create Agent Generation Configuration - */ -export async function createAgentGenerationConfiguration( - auth: Authenticator, - agentSid: string, - { - prompt, - modelProvider, - modelId, - }: { - prompt: string; - modelProvider: string; - modelId: string; - } -): Promise { - const owner = auth.workspace(); - if (!owner) { - throw new Error( - "Cannot create AgentGenerationConfiguration: Workspace not found" - ); - } - - const agentConfig = await AgentConfiguration.findOne({ - where: { - sId: agentSid, - }, - }); - if (!agentConfig) { - throw new Error( - "Cannot create AgentGenerationConfiguration: Agent not found" - ); - } - - const generation = await AgentGenerationConfiguration.create({ - prompt: prompt, - modelProvider: modelProvider, - modelId: modelId, - agentId: agentConfig.id, - }); - - return _buildAgentGenerationConfigurationTypeFromModel(generation); -} - -/** - * Create Agent Action Configuration (Retrieval) - */ -export async function createAgentActionRetrievalConfiguration( - auth: Authenticator, - agentSid: string, - { - query, - timeframe, - topK, - dataSources, - }: { - query: RetrievalQuery; - timeframe: RetrievalTimeframe; - topK: number; - dataSources: RetrievalDataSourcesConfiguration; - } -): Promise { - const owner = auth.workspace(); - if (!owner) { - throw new Error( - "Cannot create AgentActionConfiguration: Workspace not found" - ); - } - - const agentConfig = await AgentConfiguration.findOne({ - where: { - sId: agentSid, - }, - }); - if (!agentConfig) { - throw new Error("Cannot create AgentActionConfiguration: Agent not found"); - } - return await front_sequelize.transaction(async (t) => { - const agentActionConfigRow = await AgentRetrievalConfiguration.create( - { - query: isTemplatedQuery(query) ? "templated" : query, - queryTemplate: isTemplatedQuery(query) ? query.template : null, - relativeTimeFrame: isTimeFrame(timeframe) ? "custom" : timeframe, - relativeTimeFrameDuration: isTimeFrame(timeframe) - ? timeframe.duration - : null, - relativeTimeFrameUnit: isTimeFrame(timeframe) ? timeframe.unit : null, - topK: topK, - agentId: agentConfig.id, - }, - { transaction: t } - ); - const agentDataSourcesConfigRows = await _createAgentDataSourcesConfigData( - t, - dataSources, - agentActionConfigRow.id - ); - return await _buildAgentActionConfigurationTypeFromModel( - agentActionConfigRow, - agentDataSourcesConfigRows - ); - }); -} - -/** - * Create the AgentDataSourceConfiguration rows in database. - * - * Knowing that a datasource is uniquely identified by its name and its workspaceId - * We need to fetch the dataSources from the database from that. - * We obvisously need to do as few queries as possible. - */ -export async function _createAgentDataSourcesConfigData( - t: Transaction, - dataSourcesConfig: AgentDataSourceConfigurationType[], - agentActionId: number -): Promise { - // dsConfig contains this format: - // [ - // { workspaceSId: s1o1u1p, dataSourceName: "managed-notion", filter: { tags: null, parents: null } }, - // { workspaceSId: s1o1u1p, dataSourceName: "managed-slack", filter: { tags: null, parents: null } }, - // { workspaceSId: i2n2o2u, dataSourceName: "managed-notion", filter: { tags: null, parents: null } }, - // ] - - // First we get the list of workspaces because we need the mapping between workspaceSId and workspaceId - const workspaces = await Workspace.findAll({ - where: { - sId: dataSourcesConfig.map((dsConfig) => dsConfig.workspaceSId), - }, - attributes: ["id", "sId"], - }); - - // Now will want to group the datasource names by workspaceId to do only one query per workspace. - // We want this: - // [ - // { workspaceId: 1, dataSourceNames: [""managed-notion", "managed-slack"] }, - // { workspaceId: 2, dataSourceNames: ["managed-notion"] } - // ] - type _DsNamesPerWorkspaceIdType = { - workspaceId: number; - dataSourceNames: string[]; - }; - const dsNamesPerWorkspaceId = dataSourcesConfig.reduce( - ( - acc: _DsNamesPerWorkspaceIdType[], - curr: AgentDataSourceConfigurationType - ) => { - // First we need to get the workspaceId from the workspaceSId - const workspace = workspaces.find((w) => w.sId === curr.workspaceSId); - if (!workspace) { - throw new Error("Workspace not found"); - } - - // Find an existing entry for this workspaceId - const existingEntry: _DsNamesPerWorkspaceIdType | undefined = acc.find( - (entry: _DsNamesPerWorkspaceIdType) => - entry.workspaceId === workspace.id - ); - if (existingEntry) { - // Append dataSourceName to existing entry - existingEntry.dataSourceNames.push(curr.dataSourceName); - } else { - // Add a new entry for this workspaceId - acc.push({ - workspaceId: workspace.id, - dataSourceNames: [curr.dataSourceName], - }); - } - return acc; - }, - [] - ); - - // Then we get do one findAllQuery per workspaceId, in a Promise.all - const getDataSourcesQueries = dsNamesPerWorkspaceId.map( - ({ workspaceId, dataSourceNames }) => { - return DataSource.findAll({ - where: { - workspaceId, - name: { - [Op.in]: dataSourceNames, - }, - }, - }); - } - ); - const results = await Promise.all(getDataSourcesQueries); - const dataSources = results.flat(); - - const agentDataSourcesConfigRows: AgentDataSourceConfiguration[] = - await Promise.all( - dataSourcesConfig.map(async (dsConfig) => { - const dataSource = dataSources.find( - (ds) => - ds.name === dsConfig.dataSourceName && - ds.workspaceId === - workspaces.find((w) => w.sId === dsConfig.workspaceSId)?.id - ); - if (!dataSource) { - throw new Error("DataSource not found"); - } - return AgentDataSourceConfiguration.create( - { - dataSourceId: dataSource.id, - tagsIn: dsConfig.filter.tags?.in, - tagsNotIn: dsConfig.filter.tags?.not, - parentsIn: dsConfig.filter.parents?.in, - parentsNotIn: dsConfig.filter.parents?.not, - retrievalConfigurationId: agentActionId, - }, - { transaction: t } - ); - }) - ); - return agentDataSourcesConfigRows; -} diff --git a/front/lib/api/assistant/agent/agent_get.ts b/front/lib/api/assistant/agent/agent_get.ts deleted file mode 100644 index 360d99ef0ba4..000000000000 --- a/front/lib/api/assistant/agent/agent_get.ts +++ /dev/null @@ -1,195 +0,0 @@ -import { Op } from "sequelize"; - -import { Authenticator } from "@app/lib/auth"; -import { DataSource, Workspace } from "@app/lib/models"; -import { - AgentDataSourceConfiguration, - AgentRetrievalConfiguration, -} from "@app/lib/models/assistant/actions/retrieval"; -import { - AgentConfiguration, - AgentGenerationConfiguration, -} from "@app/lib/models/assistant/agent"; -import { - RetrievalDataSourcesConfiguration, - RetrievalQuery, - RetrievalTimeframe, -} from "@app/types/assistant/actions/retrieval"; -import { - AgentActionConfigurationType, - AgentConfigurationType, - AgentFullConfigurationType as AgentFullConfigurationType, - AgentGenerationConfigurationType, -} from "@app/types/assistant/agent"; - -/** - * Get an agent full configuration from its name - */ -export async function getAgent( - auth: Authenticator, - sId: string -): Promise { - const owner = auth.workspace(); - if (!owner) { - throw new Error("Cannot find Agent: no workspace"); - } - const agent = await AgentConfiguration.findOne({ - where: { - sId: sId, - workspaceId: owner.id, - }, - }); - if (!agent) { - throw new Error("Cannot find Agent: no workspace"); - } - const agentGeneration = await AgentGenerationConfiguration.findOne({ - where: { - agentId: agent.id, - }, - }); - const agentAction = await AgentRetrievalConfiguration.findOne({ - where: { - agentId: agent.id, - }, - }); - const agentDataSources = agentAction?.id - ? await AgentDataSourceConfiguration.findAll({ - where: { - retrievalConfigurationId: agentAction?.id, - }, - }) - : []; - - return { - agent: await _buildAgentConfigurationTypeFromModel({ agent }), - action: agentAction - ? await _buildAgentActionConfigurationTypeFromModel( - agentAction, - agentDataSources || [] - ) - : null, - generation: agentGeneration - ? _buildAgentGenerationConfigurationTypeFromModel(agentGeneration) - : null, - }; -} - -/** - * Builds the agent configuration type from the model - */ -export async function _buildAgentConfigurationTypeFromModel({ - agent, -}: { - agent: AgentConfiguration; -}): Promise { - return { - id: agent.id, - sId: agent.sId, - name: agent.name, - pictureUrl: agent.pictureUrl, - status: agent.status, - }; -} - -/** - * Builds the agent generation configuration type from the model - */ -export function _buildAgentGenerationConfigurationTypeFromModel( - generation: AgentGenerationConfiguration -): AgentGenerationConfigurationType { - return { - id: generation.id, - prompt: generation.prompt, - model: { - providerId: generation.modelProvider, - modelId: generation.modelId, - }, - }; -} - -/** - * Builds the agent action configuration type from the model - */ -export async function _buildAgentActionConfigurationTypeFromModel( - action: AgentRetrievalConfiguration, - dataSourcesConfig: AgentDataSourceConfiguration[] -): Promise { - // Build Retrieval Timeframe - let timeframe: RetrievalTimeframe = "auto"; - if ( - action.relativeTimeFrame === "custom" && - action.relativeTimeFrameDuration && - action.relativeTimeFrameUnit - ) { - timeframe = { - duration: action.relativeTimeFrameDuration, - unit: action.relativeTimeFrameUnit, - }; - } else if (action.relativeTimeFrame === "none") { - timeframe = "none"; - } - - // Build Retrieval Query - let query: RetrievalQuery = "auto"; - if (action.query === "templated" && action.queryTemplate) { - query = { - template: action.queryTemplate, - }; - } else if (action.query === "none") { - query = "none"; - } - - // Build Retrieval DataSources - const retrievalDataSourcesConfig: RetrievalDataSourcesConfiguration = []; - - const dataSourcesIds = dataSourcesConfig?.map((ds) => ds.dataSourceId); - const dataSources = await DataSource.findAll({ - where: { - id: { [Op.in]: dataSourcesIds }, - }, - attributes: ["id", "name", "workspaceId"], - }); - const workspaceIds = dataSources.map((ds) => ds.workspaceId); - const workspaces = await Workspace.findAll({ - where: { - id: { [Op.in]: workspaceIds }, - }, - attributes: ["id", "sId"], - }); - - let dataSource: DataSource | undefined; - let workspace: Workspace | undefined; - - dataSourcesConfig.forEach(async (dsConfig) => { - dataSource = dataSources.find((ds) => ds.id === dsConfig.dataSourceId); - workspace = workspaces.find((w) => w.id === dataSource?.workspaceId); - - if (!dataSource || !workspace) { - throw new Error("Could not find dataSource or workspace"); - } - - retrievalDataSourcesConfig.push({ - dataSourceName: dataSource.name, - workspaceSId: workspace.sId, - filter: { - tags: - dsConfig.tagsIn && dsConfig.tagsNotIn - ? { in: dsConfig.tagsIn, not: dsConfig.tagsNotIn } - : null, - parents: - dsConfig.parentsIn && dsConfig.parentsNotIn - ? { in: dsConfig.parentsIn, not: dsConfig.parentsNotIn } - : null, - }, - }); - }); - - return { - id: action.id, - query: query, - relativeTimeFrame: timeframe, - topK: action.topK, - type: "retrieval_configuration", - dataSources: retrievalDataSourcesConfig, - }; -} diff --git a/front/lib/api/assistant/agent/agent_update.ts b/front/lib/api/assistant/agent/agent_update.ts deleted file mode 100644 index 2b9eb20b98e8..000000000000 --- a/front/lib/api/assistant/agent/agent_update.ts +++ /dev/null @@ -1,208 +0,0 @@ -import { - _buildAgentActionConfigurationTypeFromModel, - _buildAgentConfigurationTypeFromModel, - _buildAgentGenerationConfigurationTypeFromModel, -} from "@app/lib/api/assistant/agent/agent_get"; -import { Authenticator } from "@app/lib/auth"; -import { front_sequelize } from "@app/lib/databases"; -import { - AgentConfiguration, - AgentDataSourceConfiguration, - AgentGenerationConfiguration, - AgentRetrievalConfiguration, -} from "@app/lib/models"; -import { - isTemplatedQuery, - isTimeFrame, - RetrievalConfigurationType, - RetrievalDataSourcesConfiguration, - RetrievalQuery, - RetrievalTimeframe, -} from "@app/types/assistant/actions/retrieval"; -import { - AgentConfigurationStatus, - AgentConfigurationType, - AgentGenerationConfigurationType, -} from "@app/types/assistant/agent"; - -import { _createAgentDataSourcesConfigData } from "./agent_create"; - -/** - * Update Agent Configuration - */ -export async function updateAgentConfiguration( - auth: Authenticator, - agentSid: string, - { - name, - pictureUrl, - status, - }: { - name: string; - pictureUrl: string; - status: AgentConfigurationStatus; - } -): Promise { - const owner = auth.workspace(); - if (!owner) { - throw new Error( - "Cannot create AgentGenerationConfiguration: Workspace not found" - ); - } - - const agentConfig = await AgentConfiguration.findOne({ - where: { - sId: agentSid, - }, - }); - if (!agentConfig) { - throw new Error( - "Cannot create AgentGenerationConfiguration: Agent not found" - ); - } - - const updatedAgent = await agentConfig.update({ - name: name, - pictureUrl: pictureUrl, - status: status, - }); - - return _buildAgentConfigurationTypeFromModel({ - agent: updatedAgent, - }); -} - -/** - * Update Agent Generation Configuration - */ -export async function updateAgentGenerationConfiguration( - auth: Authenticator, - agentSid: string, - { - prompt, - modelProvider, - modelId, - }: { - prompt: string; - modelProvider: string; - modelId: string; - } -): Promise { - const owner = auth.workspace(); - if (!owner) { - throw new Error( - "Cannot create AgentGenerationConfiguration: Workspace not found" - ); - } - - const agentConfig = await AgentConfiguration.findOne({ - where: { - sId: agentSid, - }, - }); - if (!agentConfig) { - throw new Error( - "Cannot create AgentGenerationConfiguration: Agent not found" - ); - } - - const generation = await AgentGenerationConfiguration.findOne({ - where: { - agentId: agentConfig.id, - }, - }); - if (!generation) { - throw new Error( - "Cannot update AgentGenerationConfiguration: Config not found" - ); - } - - const updatedGeneration = await generation.update({ - prompt: prompt, - modelProvider: modelProvider, - modelId: modelId, - }); - - return _buildAgentGenerationConfigurationTypeFromModel(updatedGeneration); -} - -/** - * Update Agent Generation Configuration - */ -export async function updateAgentActionRetrievalConfiguration( - auth: Authenticator, - agentSid: string, - { - query, - timeframe, - topK, - dataSources, - }: { - query: RetrievalQuery; - timeframe: RetrievalTimeframe; - topK: number; - dataSources: RetrievalDataSourcesConfiguration; - } -): Promise { - const owner = auth.workspace(); - if (!owner) { - throw new Error( - "Cannot create AgentActionConfiguration: Workspace not found" - ); - } - - const agentConfig = await AgentConfiguration.findOne({ - where: { - sId: agentSid, - }, - }); - if (!agentConfig) { - throw new Error("Cannot create AgentActionConfiguration: Agent not found"); - } - - const action = await AgentRetrievalConfiguration.findOne({ - where: { - agentId: agentConfig.id, - }, - }); - if (!action) { - throw new Error("Cannot update AgentActionConfiguration: Config not found"); - } - - // Updating both the Action and datasources in a single transaction - // So that we update both or none - return await front_sequelize.transaction(async (t) => { - // Update Action - const updatedAction = await action.update( - { - query: isTemplatedQuery(query) ? "templated" : query, - queryTemplate: isTemplatedQuery(query) ? query.template : null, - relativeTimeFrame: isTimeFrame(timeframe) ? "custom" : timeframe, - relativeTimeFrameDuration: isTimeFrame(timeframe) - ? timeframe.duration - : null, - relativeTimeFrameUnit: isTimeFrame(timeframe) ? timeframe.unit : null, - topK: topK, - agentId: agentConfig.id, - }, - { transaction: t } - ); - - // Update datasources: we drop and create them all - await AgentDataSourceConfiguration.destroy({ - where: { - retrievalConfigurationId: action.id, - }, - }); - const agentDataSourcesConfigRows = await _createAgentDataSourcesConfigData( - t, - dataSources, - action.id - ); - - return _buildAgentActionConfigurationTypeFromModel( - updatedAction, - agentDataSourcesConfigRows - ); - }); -} diff --git a/front/lib/api/assistant/configuration.ts b/front/lib/api/assistant/configuration.ts new file mode 100644 index 000000000000..a6a74f62c3e1 --- /dev/null +++ b/front/lib/api/assistant/configuration.ts @@ -0,0 +1,618 @@ +import { Op, Transaction } from "sequelize"; + +import { Authenticator } from "@app/lib/auth"; +import { front_sequelize } from "@app/lib/databases"; +import { + AgentConfiguration, + AgentDataSourceConfiguration, + AgentGenerationConfiguration, + AgentRetrievalConfiguration, + DataSource, + Workspace, +} from "@app/lib/models"; +import { generateModelSId } from "@app/lib/utils"; +import { + isTemplatedQuery, + isTimeFrame, + RetrievalQuery, + RetrievalTimeframe, +} from "@app/types/assistant/actions/retrieval"; +import { + AgentActionConfigurationType, + AgentConfigurationStatus, + AgentConfigurationType, + AgentDataSourceConfigurationType, +} from "@app/types/assistant/configuration"; + +/** + * Get an agent configuration + */ +export async function getAgentConfiguration( + auth: Authenticator, + agentId: string +): Promise { + const owner = auth.workspace(); + if (!owner) { + throw new Error("Cannot find Agent: no workspace"); + } + const agent = await AgentConfiguration.findOne({ + where: { + sId: agentId, + workspaceId: owner.id, + }, + }); + if (!agent) { + throw new Error("Cannot find Agent: no workspace"); + } + + const generation = agent.generationId + ? await AgentGenerationConfiguration.findOne({ + where: { + id: agent.generationId, + }, + }) + : null; + + const action = agent.retrievalId + ? await AgentRetrievalConfiguration.findOne({ + where: { + id: agent.retrievalId, + }, + }) + : null; + const datasources = action?.id + ? await AgentDataSourceConfiguration.findAll({ + where: { + retrievalConfigurationId: action.id, + }, + }) + : []; + + return { + sId: agent.sId, + name: agent.name, + pictureUrl: agent.pictureUrl, + status: agent.status, + action: action ? await _agentActionType(action, datasources) : null, + generation: generation + ? { + id: generation.id, + prompt: generation.prompt, + model: { + providerId: generation.providerId, + modelId: generation.modelId, + }, + } + : null, + }; +} + +/** + * Create Agent Configuration + */ +export async function createAgentConfiguration( + auth: Authenticator, + { + name, + pictureUrl, + status, + generation, + action, + }: { + name: string; + pictureUrl: string; + status: AgentConfigurationStatus; + generation: { + prompt: string; + model: { + providerId: string; + modelId: string; + }; + } | null; + action: { + type: string; + query: RetrievalQuery; + timeframe: RetrievalTimeframe; + topK: number; + dataSources: AgentDataSourceConfigurationType[]; + } | null; + } +): Promise { + const owner = auth.workspace(); + if (!owner) { + throw new Error("Cannot create AgentConfiguration without workspace"); + } + + return await front_sequelize.transaction(async (t) => { + let genConfig: AgentGenerationConfiguration | null = null; + let retrievalConfig: AgentRetrievalConfiguration | null = null; + let dataSourcesConfig: AgentDataSourceConfiguration[] = []; + + // Create Generation config + if (generation) { + const { prompt, model } = generation; + genConfig = await AgentGenerationConfiguration.create({ + prompt: prompt, + providerId: model.providerId, + modelId: model.modelId, + }); + } + + // Create Retrieval & Datasources configs + if (action && action.type === "retrieval_configuration") { + const { query, timeframe, topK, dataSources } = action; + retrievalConfig = await AgentRetrievalConfiguration.create( + { + query: isTemplatedQuery(query) ? "templated" : query, + queryTemplate: isTemplatedQuery(query) ? query.template : null, + relativeTimeFrame: isTimeFrame(timeframe) ? "custom" : timeframe, + relativeTimeFrameDuration: isTimeFrame(timeframe) + ? timeframe.duration + : null, + relativeTimeFrameUnit: isTimeFrame(timeframe) ? timeframe.unit : null, + topK: topK, + }, + { transaction: t } + ); + dataSourcesConfig = await _createAgentDataSourcesConfigData( + t, + dataSources, + retrievalConfig.id + ); + } + + // Create Agent config + const agentConfig = await AgentConfiguration.create({ + sId: generateModelSId(), + status: status, + name: name, + pictureUrl: pictureUrl, + scope: "workspace", + workspaceId: owner.id, + generationId: genConfig?.id ?? null, + retrievalId: retrievalConfig?.id ?? null, + }); + + return { + sId: agentConfig.sId, + name: agentConfig.name, + pictureUrl: agentConfig.pictureUrl, + status: agentConfig.status, + action: retrievalConfig + ? await _agentActionType(retrievalConfig, dataSourcesConfig) + : null, + generation: genConfig + ? { + id: genConfig.id, + prompt: genConfig.prompt, + model: { + providerId: genConfig.providerId, + modelId: genConfig.modelId, + }, + } + : null, + }; + }); +} + +/** + * Update Agent Generation Configuration + */ +export async function updateAgentGenerationConfiguration( + auth: Authenticator, + agentId: string, + { + name, + pictureUrl, + status, + generation, + }: { + name: string; + pictureUrl: string; + status: AgentConfigurationStatus; + generation: { + prompt: string; + model: { + providerId: string; + modelId: string; + }; + } | null; + } +): Promise { + const owner = auth.workspace(); + if (!owner) { + throw new Error( + "Cannot create AgentGenerationConfiguration: Workspace not found" + ); + } + const agentConfig = await AgentConfiguration.findOne({ + where: { + sId: agentId, + }, + }); + if (!agentConfig) { + throw new Error( + "Cannot create AgentGenerationConfiguration: Agent not found" + ); + } + const existingGeneration = agentConfig.generationId + ? await AgentGenerationConfiguration.findOne({ + where: { + id: agentConfig.generationId, + }, + }) + : null; + + const existingRetrivalConfig = agentConfig.retrievalId + ? await AgentRetrievalConfiguration.findOne({ + where: { + id: agentConfig.retrievalId, + }, + }) + : null; + + const existingDataSourcesConfig = existingRetrivalConfig?.id + ? await AgentDataSourceConfiguration.findAll({ + where: { + retrievalConfigurationId: existingRetrivalConfig.id, + }, + }) + : []; + + return await front_sequelize.transaction(async (t) => { + // Upserting Agent Config + const updatedAgentConfig = await agentConfig.update( + { + name: name, + pictureUrl: pictureUrl, + status: status, + }, + { transaction: t } + ); + + // Upserting Generation Config + let upsertedGenerationConfig: AgentGenerationConfiguration | null = null; + if (generation) { + const { prompt, model } = generation; + if (existingGeneration) { + upsertedGenerationConfig = await existingGeneration.update( + { + prompt: prompt, + providerId: model.providerId, + modelId: model.modelId, + }, + { transaction: t } + ); + } else { + upsertedGenerationConfig = await AgentGenerationConfiguration.create( + { + prompt: prompt, + providerId: model.providerId, + modelId: model.modelId, + }, + { transaction: t } + ); + } + } else if (existingGeneration) { + await existingGeneration.destroy(); + } + + return { + sId: updatedAgentConfig.sId, + name: updatedAgentConfig.name, + pictureUrl: updatedAgentConfig.pictureUrl, + status: updatedAgentConfig.status, + action: existingRetrivalConfig + ? await _agentActionType( + existingRetrivalConfig, + existingDataSourcesConfig + ) + : null, + generation: + generation && upsertedGenerationConfig + ? { + id: upsertedGenerationConfig.id, + prompt: upsertedGenerationConfig.prompt, + model: { + providerId: upsertedGenerationConfig.providerId, + modelId: upsertedGenerationConfig.modelId, + }, + } + : null, + }; + }); +} + +/** + * Update Agent Retrieval Configuration + * This will destroy and recreate the retrieval config + */ +export async function updateAgentRetrievalConfiguration( + auth: Authenticator, + agentId: string, + { + query, + timeframe, + topK, + dataSources, + }: { + query: RetrievalQuery; + timeframe: RetrievalTimeframe; + topK: number; + dataSources: AgentDataSourceConfigurationType[]; + } +): Promise { + const owner = auth.workspace(); + if (!owner) { + throw new Error( + "Cannot create AgentGenerationConfiguration: Workspace not found" + ); + } + const agentConfig = await AgentConfiguration.findOne({ + where: { + sId: agentId, + }, + }); + if (!agentConfig) { + throw new Error( + "Cannot create AgentGenerationConfiguration: Agent not found" + ); + } + const generationConfig = agentConfig.generationId + ? await AgentGenerationConfiguration.findOne({ + where: { + id: agentConfig.generationId, + }, + }) + : null; + + return await front_sequelize.transaction(async (t) => { + if (agentConfig.retrievalId) { + const existingRetrivalConfig = await AgentRetrievalConfiguration.findOne({ + where: { + id: agentConfig.retrievalId, + }, + }); + if (existingRetrivalConfig) { + await existingRetrivalConfig.destroy(); // That will destroy the dataSourcesConfig too + } + } + + const newRetrievalConfig = await AgentRetrievalConfiguration.create( + { + query: isTemplatedQuery(query) ? "templated" : query, + queryTemplate: isTemplatedQuery(query) ? query.template : null, + relativeTimeFrame: isTimeFrame(timeframe) ? "custom" : timeframe, + relativeTimeFrameDuration: isTimeFrame(timeframe) + ? timeframe.duration + : null, + relativeTimeFrameUnit: isTimeFrame(timeframe) ? timeframe.unit : null, + topK: topK, + }, + { transaction: t } + ); + const dataSourcesConfig = await _createAgentDataSourcesConfigData( + t, + dataSources, + newRetrievalConfig.id + ); + + return { + sId: agentConfig.sId, + name: agentConfig.name, + pictureUrl: agentConfig.pictureUrl, + status: agentConfig.status, + action: newRetrievalConfig + ? await _agentActionType(newRetrievalConfig, dataSourcesConfig) + : null, + generation: generationConfig + ? { + id: generationConfig.id, + prompt: generationConfig.prompt, + model: { + providerId: generationConfig.providerId, + modelId: generationConfig.modelId, + }, + } + : null, + }; + }); +} + +/** + * Builds the agent action configuration type from the model + */ +export async function _agentActionType( + action: AgentRetrievalConfiguration, + dataSourcesConfig: AgentDataSourceConfiguration[] +): Promise { + // Build Retrieval Timeframe + let timeframe: RetrievalTimeframe = "auto"; + if ( + action.relativeTimeFrame === "custom" && + action.relativeTimeFrameDuration && + action.relativeTimeFrameUnit + ) { + timeframe = { + duration: action.relativeTimeFrameDuration, + unit: action.relativeTimeFrameUnit, + }; + } else if (action.relativeTimeFrame === "none") { + timeframe = "none"; + } + + // Build Retrieval Query + let query: RetrievalQuery = "auto"; + if (action.query === "templated" && action.queryTemplate) { + query = { + template: action.queryTemplate, + }; + } else if (action.query === "none") { + query = "none"; + } + + // Build Retrieval DataSources + const dataSourcesIds = dataSourcesConfig?.map((ds) => ds.dataSourceId); + const dataSources = await DataSource.findAll({ + where: { + id: { [Op.in]: dataSourcesIds }, + }, + attributes: ["id", "name", "workspaceId"], + }); + const workspaceIds = dataSources.map((ds) => ds.workspaceId); + const workspaces = await Workspace.findAll({ + where: { + id: { [Op.in]: workspaceIds }, + }, + attributes: ["id", "sId"], + }); + + let dataSource: DataSource | undefined; + let workspace: Workspace | undefined; + const dataSourcesConfigType: AgentDataSourceConfigurationType[] = []; + + dataSourcesConfig.forEach(async (dsConfig) => { + dataSource = dataSources.find((ds) => ds.id === dsConfig.dataSourceId); + workspace = workspaces.find((w) => w.id === dataSource?.workspaceId); + + if (!dataSource || !workspace) { + throw new Error("Could not find dataSource or workspace"); + } + + dataSourcesConfigType.push({ + dataSourceName: dataSource.name, + workspaceSId: workspace.sId, + filter: { + tags: + dsConfig.tagsIn && dsConfig.tagsNotIn + ? { in: dsConfig.tagsIn, not: dsConfig.tagsNotIn } + : null, + parents: + dsConfig.parentsIn && dsConfig.parentsNotIn + ? { in: dsConfig.parentsIn, not: dsConfig.parentsNotIn } + : null, + }, + }); + }); + + return { + id: action.id, + type: "retrieval_configuration", + query: query, + relativeTimeFrame: timeframe, + topK: action.topK, + dataSources: dataSourcesConfigType, + }; +} + +/** + * Create the AgentDataSourceConfiguration rows in database. + * + * Knowing that a datasource is uniquely identified by its name and its workspaceId + * We need to fetch the dataSources from the database from that. + * We obvisously need to do as few queries as possible. + */ +export async function _createAgentDataSourcesConfigData( + t: Transaction, + dataSourcesConfig: AgentDataSourceConfigurationType[], + agentActionId: number +): Promise { + // dsConfig contains this format: + // [ + // { workspaceSId: s1o1u1p, dataSourceName: "managed-notion", filter: { tags: null, parents: null } }, + // { workspaceSId: s1o1u1p, dataSourceName: "managed-slack", filter: { tags: null, parents: null } }, + // { workspaceSId: i2n2o2u, dataSourceName: "managed-notion", filter: { tags: null, parents: null } }, + // ] + + // First we get the list of workspaces because we need the mapping between workspaceSId and workspaceId + const workspaces = await Workspace.findAll({ + where: { + sId: dataSourcesConfig.map((dsConfig) => dsConfig.workspaceSId), + }, + attributes: ["id", "sId"], + }); + + // Now will want to group the datasource names by workspaceId to do only one query per workspace. + // We want this: + // [ + // { workspaceId: 1, dataSourceNames: [""managed-notion", "managed-slack"] }, + // { workspaceId: 2, dataSourceNames: ["managed-notion"] } + // ] + type _DsNamesPerWorkspaceIdType = { + workspaceId: number; + dataSourceNames: string[]; + }; + const dsNamesPerWorkspaceId = dataSourcesConfig.reduce( + ( + acc: _DsNamesPerWorkspaceIdType[], + curr: AgentDataSourceConfigurationType + ) => { + // First we need to get the workspaceId from the workspaceSId + const workspace = workspaces.find((w) => w.sId === curr.workspaceSId); + if (!workspace) { + throw new Error("Workspace not found"); + } + + // Find an existing entry for this workspaceId + const existingEntry: _DsNamesPerWorkspaceIdType | undefined = acc.find( + (entry: _DsNamesPerWorkspaceIdType) => + entry.workspaceId === workspace.id + ); + if (existingEntry) { + // Append dataSourceName to existing entry + existingEntry.dataSourceNames.push(curr.dataSourceName); + } else { + // Add a new entry for this workspaceId + acc.push({ + workspaceId: workspace.id, + dataSourceNames: [curr.dataSourceName], + }); + } + return acc; + }, + [] + ); + + // Then we get do one findAllQuery per workspaceId, in a Promise.all + const getDataSourcesQueries = dsNamesPerWorkspaceId.map( + ({ workspaceId, dataSourceNames }) => { + return DataSource.findAll({ + where: { + workspaceId, + name: { + [Op.in]: dataSourceNames, + }, + }, + }); + } + ); + const results = await Promise.all(getDataSourcesQueries); + const dataSources = results.flat(); + + const agentDataSourcesConfigRows: AgentDataSourceConfiguration[] = + await Promise.all( + dataSourcesConfig.map(async (dsConfig) => { + const dataSource = dataSources.find( + (ds) => + ds.name === dsConfig.dataSourceName && + ds.workspaceId === + workspaces.find((w) => w.sId === dsConfig.workspaceSId)?.id + ); + if (!dataSource) { + throw new Error("DataSource not found"); + } + return AgentDataSourceConfiguration.create( + { + dataSourceId: dataSource.id, + tagsIn: dsConfig.filter.tags?.in, + tagsNotIn: dsConfig.filter.tags?.not, + parentsIn: dsConfig.filter.parents?.in, + parentsNotIn: dsConfig.filter.parents?.not, + retrievalConfigurationId: agentActionId, + }, + { transaction: t } + ); + }) + ); + return agentDataSourcesConfigRows; +} diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index b89e253484d3..1dd3da5926ef 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -75,92 +75,31 @@ export async function* postUserMessage( where: { conversationId: conversation.id, }, - transaction: t, - })) ?? -1) + 1; - - const m = await Message.create( - { - sId: generateModelSId(), - rank: nextMessageRank++, - conversationId: conversation.id, - parentId: null, - userMessageId: ( - await UserMessage.create( - { - message: message, - userContextUsername: context.username, - userContextTimezone: context.timezone, - userContextFullName: context.fullName, - userContextEmail: context.email, - userContextProfilePictureUrl: context.profilePictureUrl, - userId: user ? user.id : null, - }, - { transaction: t } - ) - ).id, - }, - { - transaction: t, - } - ); - - const userMessage: UserMessageType = { - id: m.id, - sId: m.sId, - type: "user_message", - visibility: "visible", - version: 0, - user: user, - mentions: mentions, - message: message, - context: context, - }; - - const agentMessages: AgentMessageType[] = []; - const agentMessageRows: AgentMessage[] = []; - - // for each assistant mention, create an "empty" agent message - for (const mention of mentions) { - if (isAgentMention(mention)) { - const agentMessageRow = await AgentMessage.create( - {}, - { transaction: t } - ); - const m = await Message.create( - { - sId: generateModelSId(), - rank: nextMessageRank++, - conversationId: conversation.id, - parentId: userMessage.id, - agentMessageId: agentMessageRow.id, - }, - { - transaction: t, - } - ); - agentMessageRows.push(agentMessageRow); - agentMessages.push({ - id: m.id, - sId: m.sId, - type: "agent_message", - visibility: "visible", - version: 0, - parentMessageId: userMessage.sId, - status: "created", - action: null, - message: null, - feedbacks: [], - error: null, - configuration: { - sId: mention.configurationId, - status: "active", - name: "foo", // TODO - pictureUrl: null, // TODO - action: null, // TODO - generation: null, // TODO - }, - }); - } + { + transaction: t, + } + ); + agentMessages.push({ + id: agentMessageRow.id, + sId: agentMessageRow.sId, + type: "agent_message", + visibility: "visible", + version: 0, + parentMessageId: userMessage.sId, + status: "created", + action: null, + message: null, + feedbacks: [], + error: null, + configuration: { + sId: m.configurationId, + status: "active", + name: "foo", // TODO + pictureUrl: null, // TODO + action: null, // TODO + generation: null, // TODO + }, + }); } return { userMessage, agentMessages, agentMessageRows }; diff --git a/front/lib/models/assistant/actions/retrieval.ts b/front/lib/models/assistant/actions/retrieval.ts index 95adea5566fd..6394c461465c 100644 --- a/front/lib/models/assistant/actions/retrieval.ts +++ b/front/lib/models/assistant/actions/retrieval.ts @@ -8,199 +8,9 @@ import { } from "sequelize"; import { front_sequelize } from "@app/lib/databases"; -import { AgentConfiguration } from "@app/lib/models/assistant/agent"; -import { DataSource } from "@app/lib/models/data_source"; +import { AgentRetrievalConfiguration } from "@app/lib/models/assistant/configuration"; import { TimeframeUnit } from "@app/types/assistant/actions/retrieval"; -/** - * Action Retrieval configuration - */ -export class AgentRetrievalConfiguration extends Model< - InferAttributes, - InferCreationAttributes -> { - declare id: CreationOptional; - declare createdAt: CreationOptional; - declare updatedAt: CreationOptional; - - declare query: "auto" | "none" | "templated"; - declare queryTemplate: string | null; - declare relativeTimeFrame: "auto" | "none" | "custom"; - declare relativeTimeFrameDuration: number | null; - declare relativeTimeFrameUnit: TimeframeUnit | null; - declare topK: number; - - declare agentId: ForeignKey; -} -AgentRetrievalConfiguration.init( - { - id: { - type: DataTypes.INTEGER, - autoIncrement: true, - primaryKey: true, - }, - createdAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - updatedAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - query: { - type: DataTypes.STRING, - allowNull: false, - defaultValue: "auto", - }, - queryTemplate: { - type: DataTypes.TEXT, - allowNull: true, - }, - relativeTimeFrame: { - type: DataTypes.STRING, - allowNull: false, - defaultValue: "auto", - }, - relativeTimeFrameDuration: { - type: DataTypes.INTEGER, - allowNull: true, - }, - relativeTimeFrameUnit: { - type: DataTypes.STRING, - allowNull: true, - }, - topK: { - type: DataTypes.INTEGER, - allowNull: false, - }, - }, - { - modelName: "agent_retrieval_configuration", - sequelize: front_sequelize, - hooks: { - beforeValidate: (retrieval: AgentRetrievalConfiguration) => { - // Validation for templated Query - if (retrieval.query == "templated") { - if (retrieval.queryTemplate === null) { - throw new Error("Must set a template for templated query"); - } - } else if (retrieval.queryTemplate !== null) { - throw new Error("Can't set a template without templated query"); - } - - // Validation for Timeframe - if (retrieval.relativeTimeFrame == "custom") { - if ( - retrieval.relativeTimeFrameDuration === null || - retrieval.relativeTimeFrameUnit === null - ) { - throw new Error( - "Custom relative time frame must have a duration and unit set" - ); - } - } - }, - }, - } -); - -/** - * Configuration of Datasources used for Retrieval Action. - */ -export class AgentDataSourceConfiguration extends Model< - InferAttributes, - InferCreationAttributes -> { - declare id: CreationOptional; - declare createdAt: CreationOptional; - declare updatedAt: CreationOptional; - - declare tagsIn: string[] | null; - declare tagsNotIn: string[] | null; - declare parentsIn: string[] | null; - declare parentsNotIn: string[] | null; - - declare dataSourceId: ForeignKey; - declare retrievalConfigurationId: ForeignKey< - AgentRetrievalConfiguration["id"] - >; -} -AgentDataSourceConfiguration.init( - { - id: { - type: DataTypes.INTEGER, - autoIncrement: true, - primaryKey: true, - }, - createdAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - updatedAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - tagsIn: { - type: DataTypes.ARRAY(DataTypes.STRING), - allowNull: true, - }, - tagsNotIn: { - type: DataTypes.ARRAY(DataTypes.STRING), - allowNull: true, - }, - parentsIn: { - type: DataTypes.ARRAY(DataTypes.STRING), - allowNull: true, - }, - parentsNotIn: { - type: DataTypes.ARRAY(DataTypes.STRING), - allowNull: true, - }, - }, - { - modelName: "agent_data_source_configuration", - sequelize: front_sequelize, - hooks: { - beforeValidate: (dataSourceConfig: AgentDataSourceConfiguration) => { - if ( - (dataSourceConfig.tagsIn === null) !== - (dataSourceConfig.tagsNotIn === null) - ) { - throw new Error("Tags must be both set or both null"); - } - if ( - (dataSourceConfig.parentsIn === null) !== - (dataSourceConfig.parentsNotIn === null) - ) { - throw new Error("Parents must be both set or both null"); - } - }, - }, - } -); - -// Retrieval config <> data source config -AgentRetrievalConfiguration.hasMany(AgentDataSourceConfiguration, { - foreignKey: { name: "retrievalId", allowNull: false }, - onDelete: "CASCADE", -}); - -// Data source <> Data source config -DataSource.hasMany(AgentDataSourceConfiguration, { - foreignKey: { name: "dataSourceId", allowNull: false }, - onDelete: "CASCADE", -}); - -// Agent config <> Retrieval config -AgentConfiguration.hasOne(AgentRetrievalConfiguration, { - foreignKey: { name: "agentId", allowNull: true }, // null = no generation set for this Agent - onDelete: "CASCADE", -}); - /** * Retrieval Action */ diff --git a/front/lib/models/assistant/agent.ts b/front/lib/models/assistant/agent.ts deleted file mode 100644 index c10393cceac4..000000000000 --- a/front/lib/models/assistant/agent.ts +++ /dev/null @@ -1,167 +0,0 @@ -import { - CreationOptional, - DataTypes, - ForeignKey, - InferAttributes, - InferCreationAttributes, - Model, -} from "sequelize"; - -import { front_sequelize } from "@app/lib/databases"; -import { AgentRetrievalConfiguration } from "@app/lib/models/assistant/actions/retrieval"; -import { Workspace } from "@app/lib/models/workspace"; -import { - AgentConfigurationScope, - AgentConfigurationStatus, -} from "@app/types/assistant/agent"; - -/** - * Agent configuration - */ -export class AgentConfiguration extends Model< - InferAttributes, - InferCreationAttributes -> { - declare id: CreationOptional; - declare createdAt: CreationOptional; - declare updatedAt: CreationOptional; - - declare sId: string; - declare status: AgentConfigurationStatus; - declare name: string; - declare pictureUrl: string | null; - - declare scope: AgentConfigurationScope; - declare workspaceId: ForeignKey | null; // null = it's a global agent - - declare model: ForeignKey | null; -} -AgentConfiguration.init( - { - id: { - type: DataTypes.INTEGER, - autoIncrement: true, - primaryKey: true, - }, - createdAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - updatedAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - sId: { - type: DataTypes.STRING, - allowNull: false, - unique: true, - }, - status: { - type: DataTypes.STRING, - allowNull: false, - defaultValue: "active", - }, - name: { - type: DataTypes.TEXT, - allowNull: false, - }, - pictureUrl: { - type: DataTypes.TEXT, - allowNull: true, - }, - scope: { - type: DataTypes.STRING, - allowNull: false, - defaultValue: "workspace", - }, - }, - { - modelName: "agent_configuration", - sequelize: front_sequelize, - indexes: [ - { fields: ["workspaceId"] }, - // Unique name per workspace. - // Note that on PostgreSQL a unique constraint on multiple columns will treat NULL - // as distinct from any other value, so we can create twice the same name if at least - // one of the workspaceId is null. We're okay with it. - { fields: ["workspaceId", "name", "scope"], unique: true }, - { fields: ["sId"], unique: true }, - ], - hooks: { - beforeValidate: (agent: AgentConfiguration) => { - if (agent.scope !== "workspace" && agent.workspaceId) { - throw new Error("Workspace id must be null for global agent"); - } else if (agent.scope === "workspace" && !agent.workspaceId) { - throw new Error("Workspace id must be set for non-global agent"); - } - }, - }, - } -); - -/** - * Configuration of Agent generation. - */ -export class AgentGenerationConfiguration extends Model< - InferAttributes, - InferCreationAttributes -> { - declare id: CreationOptional; - declare createdAt: CreationOptional; - declare updatedAt: CreationOptional; - - declare prompt: string; - declare modelProvider: string; - declare modelId: string; - - declare agentId: ForeignKey; -} -AgentGenerationConfiguration.init( - { - id: { - type: DataTypes.INTEGER, - autoIncrement: true, - primaryKey: true, - }, - createdAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - updatedAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - prompt: { - type: DataTypes.TEXT, - allowNull: false, - }, - modelProvider: { - type: DataTypes.STRING, - allowNull: false, - }, - modelId: { - type: DataTypes.STRING, - allowNull: false, - }, - }, - { - modelName: "agent_generation_configuration", - sequelize: front_sequelize, - } -); - -// Workspace <> Agent config -Workspace.hasMany(AgentConfiguration, { - foreignKey: { name: "workspaceId", allowNull: true }, // null = global Agent - onDelete: "CASCADE", -}); - -// Agent config <> Generation config -AgentConfiguration.hasOne(AgentGenerationConfiguration, { - foreignKey: { name: "agentId", allowNull: false }, // null = no retrieval action set for this Agent - onDelete: "CASCADE", -}); diff --git a/front/lib/models/assistant/configuration.ts b/front/lib/models/assistant/configuration.ts new file mode 100644 index 000000000000..018fc1683830 --- /dev/null +++ b/front/lib/models/assistant/configuration.ts @@ -0,0 +1,352 @@ +import { + CreationOptional, + DataTypes, + ForeignKey, + InferAttributes, + InferCreationAttributes, + Model, +} from "sequelize"; + +import { front_sequelize } from "@app/lib/databases"; +import { DataSource } from "@app/lib/models/data_source"; +import { Workspace } from "@app/lib/models/workspace"; +import { TimeframeUnit } from "@app/types/assistant/actions/retrieval"; +import { + AgentConfigurationScope, + AgentConfigurationStatus, +} from "@app/types/assistant/configuration"; + +/** + * Agent configuration + */ +export class AgentConfiguration extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare sId: string; + declare status: AgentConfigurationStatus; + declare name: string; + declare pictureUrl: string | null; + declare scope: AgentConfigurationScope; + + declare workspaceId: ForeignKey | null; // null = it's a global agent + declare generationId: ForeignKey | null; + declare retrievalId: ForeignKey | null; +} +AgentConfiguration.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + sId: { + type: DataTypes.STRING, + allowNull: false, + unique: true, + }, + status: { + type: DataTypes.STRING, + allowNull: false, + defaultValue: "active", + }, + name: { + type: DataTypes.TEXT, + allowNull: false, + }, + pictureUrl: { + type: DataTypes.TEXT, + allowNull: true, + }, + scope: { + type: DataTypes.STRING, + allowNull: false, + defaultValue: "workspace", + }, + }, + { + modelName: "agent_configuration", + sequelize: front_sequelize, + indexes: [ + { fields: ["workspaceId"] }, + // Unique name per workspace. + // Note that on PostgreSQL a unique constraint on multiple columns will treat NULL + // as distinct from any other value, so we can create twice the same name if at least + // one of the workspaceId is null. We're okay with it. + { fields: ["workspaceId", "name", "scope"], unique: true }, + { fields: ["sId"], unique: true }, + ], + hooks: { + beforeValidate: (agent: AgentConfiguration) => { + if (agent.scope !== "workspace" && agent.workspaceId) { + throw new Error("Workspace id must be null for global agent"); + } else if (agent.scope === "workspace" && !agent.workspaceId) { + throw new Error("Workspace id must be set for non-global agent"); + } + }, + }, + } +); + +/** + * Configuration of Agent generation. + */ +export class AgentGenerationConfiguration extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare prompt: string; + declare providerId: string; + declare modelId: string; +} +AgentGenerationConfiguration.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + prompt: { + type: DataTypes.TEXT, + allowNull: false, + }, + providerId: { + type: DataTypes.STRING, + allowNull: false, + }, + modelId: { + type: DataTypes.STRING, + allowNull: false, + }, + }, + { + modelName: "agent_generation_configuration", + sequelize: front_sequelize, + } +); + +/** + * Action Retrieval configuration + */ +export class AgentRetrievalConfiguration extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare query: "auto" | "none" | "templated"; + declare queryTemplate: string | null; + declare relativeTimeFrame: "auto" | "none" | "custom"; + declare relativeTimeFrameDuration: number | null; + declare relativeTimeFrameUnit: TimeframeUnit | null; + declare topK: number; +} +AgentRetrievalConfiguration.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + query: { + type: DataTypes.STRING, + allowNull: false, + defaultValue: "auto", + }, + queryTemplate: { + type: DataTypes.TEXT, + allowNull: true, + }, + relativeTimeFrame: { + type: DataTypes.STRING, + allowNull: false, + defaultValue: "auto", + }, + relativeTimeFrameDuration: { + type: DataTypes.INTEGER, + allowNull: true, + }, + relativeTimeFrameUnit: { + type: DataTypes.STRING, + allowNull: true, + }, + topK: { + type: DataTypes.INTEGER, + allowNull: false, + }, + }, + { + modelName: "agent_retrieval_configuration", + sequelize: front_sequelize, + hooks: { + beforeValidate: (retrieval: AgentRetrievalConfiguration) => { + // Validation for templated Query + if (retrieval.query == "templated") { + if (retrieval.queryTemplate === null) { + throw new Error("Must set a template for templated query"); + } + } else if (retrieval.queryTemplate !== null) { + throw new Error("Can't set a template without templated query"); + } + + // Validation for Timeframe + if (retrieval.relativeTimeFrame == "custom") { + if ( + retrieval.relativeTimeFrameDuration === null || + retrieval.relativeTimeFrameUnit === null + ) { + throw new Error( + "Custom relative time frame must have a duration and unit set" + ); + } + } + }, + }, + } +); + +/** + * Configuration of Datasources used for Retrieval Action. + */ +export class AgentDataSourceConfiguration extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare tagsIn: string[] | null; + declare tagsNotIn: string[] | null; + declare parentsIn: string[] | null; + declare parentsNotIn: string[] | null; + + declare dataSourceId: ForeignKey; + declare retrievalConfigurationId: ForeignKey< + AgentRetrievalConfiguration["id"] + >; +} +AgentDataSourceConfiguration.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + tagsIn: { + type: DataTypes.ARRAY(DataTypes.STRING), + allowNull: true, + }, + tagsNotIn: { + type: DataTypes.ARRAY(DataTypes.STRING), + allowNull: true, + }, + parentsIn: { + type: DataTypes.ARRAY(DataTypes.STRING), + allowNull: true, + }, + parentsNotIn: { + type: DataTypes.ARRAY(DataTypes.STRING), + allowNull: true, + }, + }, + { + modelName: "agent_data_source_configuration", + sequelize: front_sequelize, + hooks: { + beforeValidate: (dataSourceConfig: AgentDataSourceConfiguration) => { + if ( + (dataSourceConfig.tagsIn === null) !== + (dataSourceConfig.tagsNotIn === null) + ) { + throw new Error("Tags must be both set or both null"); + } + if ( + (dataSourceConfig.parentsIn === null) !== + (dataSourceConfig.parentsNotIn === null) + ) { + throw new Error("Parents must be both set or both null"); + } + }, + }, + } +); + +// Agent config <> Workspace +Workspace.hasMany(AgentConfiguration, { + foreignKey: { name: "workspaceId", allowNull: true }, // null = global Agent + onDelete: "CASCADE", +}); + +// Agent config <> Generation config +AgentConfiguration.hasOne(AgentGenerationConfiguration, { + foreignKey: { name: "generationId", allowNull: true }, // null = no generation set for this Agent + onDelete: "CASCADE", +}); +// Agent config <> Retrieval config +AgentConfiguration.hasOne(AgentRetrievalConfiguration, { + foreignKey: { name: "retrievalId", allowNull: true }, // null = no retrieval action set for this Agent + onDelete: "CASCADE", +}); + +// Retrieval config <> Data source config +AgentRetrievalConfiguration.hasMany(AgentDataSourceConfiguration, { + foreignKey: { name: "retrievalId", allowNull: false }, + onDelete: "CASCADE", +}); + +// Data source config <> Data source +DataSource.hasMany(AgentDataSourceConfiguration, { + foreignKey: { name: "dataSourceId", allowNull: false }, + onDelete: "CASCADE", +}); diff --git a/front/lib/models/index.ts b/front/lib/models/index.ts index 0bc9edbbbdc0..d15b932065a2 100644 --- a/front/lib/models/index.ts +++ b/front/lib/models/index.ts @@ -1,15 +1,15 @@ import { App, Clone, Dataset, Provider, Run } from "@app/lib/models/apps"; import { - AgentDataSourceConfiguration, AgentRetrievalAction, - AgentRetrievalConfiguration, RetrievalDocument, RetrievalDocumentChunk, } from "@app/lib/models/assistant/actions/retrieval"; import { AgentConfiguration, + AgentDataSourceConfiguration, AgentGenerationConfiguration, -} from "@app/lib/models/assistant/agent"; + AgentRetrievalConfiguration, +} from "@app/lib/models/assistant/configuration"; import { AgentMessage, Conversation, diff --git a/front/types/assistant/actions/retrieval.ts b/front/types/assistant/actions/retrieval.ts index eff6727df63b..47a8138a6120 100644 --- a/front/types/assistant/actions/retrieval.ts +++ b/front/types/assistant/actions/retrieval.ts @@ -1,9 +1,5 @@ -/** - * Data Source configuration - */ - import { ModelId } from "@app/lib/databases"; -import { AgentActionConfigurationType } from "@app/types/assistant/agent"; +import { AgentDataSourceConfigurationType } from "@app/types/assistant/configuration"; import { AgentActionType } from "@app/types/assistant/conversation"; export type TimeframeUnit = "hour" | "day" | "week" | "month" | "year"; @@ -12,22 +8,6 @@ export type TimeFrame = { unit: TimeframeUnit; }; -export type DataSourceFilter = { - tags: { in: string[]; not: string[] } | null; - parents: { in: string[]; not: string[] } | null; -}; - -// This is used to talk with Dust Apps and Core, so it store external Ids. -export type AgentDataSourceConfigurationType = { - workspaceSId: string; // = Workspace.sId - dataSourceName: string; // = Datasource.name - filter: DataSourceFilter; -}; - -/** - * Retrieval configuration - */ - export type TemplatedQuery = { template: string; }; @@ -51,27 +31,6 @@ export function isTimeFrame(arg: RetrievalTimeframe): arg is TimeFrame { // `dataSources` to query the data. export type RetrievalTimeframe = "auto" | "none" | TimeFrame; export type RetrievalQuery = "auto" | "none" | TemplatedQuery; -export type RetrievalDataSourcesConfiguration = - AgentDataSourceConfigurationType[]; - -export type RetrievalConfigurationType = { - id: ModelId; - - type: "retrieval_configuration"; - dataSources: RetrievalDataSourcesConfiguration; - query: RetrievalQuery; - relativeTimeFrame: RetrievalTimeframe; - topK: number; - - // Dynamically decide to skip, if needed in the future - // autoSkip: boolean; -}; - -export function isRetrievalConfiguration( - arg: AgentActionConfigurationType | null -): arg is RetrievalConfigurationType { - return arg !== null && arg.type && arg.type === "retrieval_configuration"; -} /** * Retrieval action @@ -103,7 +62,7 @@ export type RetrievalActionType = { id: ModelId; // AgentRetrieval. type: "retrieval_action"; params: { - dataSources: "all" | AgentDataSourceConfigurationType[]; + dataSources: AgentDataSourceConfigurationType[]; relativeTimeFrame: TimeFrame | null; query: string | null; topK: number; diff --git a/front/types/assistant/agent.ts b/front/types/assistant/agent.ts index f25d0197c184..4eb7cd6c0040 100644 --- a/front/types/assistant/agent.ts +++ b/front/types/assistant/agent.ts @@ -1,15 +1,7 @@ -import { ModelId } from "@app/lib/databases"; -import { RetrievalConfigurationType } from "@app/types/assistant/actions/retrieval"; - /** - * Agent Action configuration + * Agent Action */ -// New AgentActionConfigurationType checklist: -// - Add the type to the union type below -// - Add model rendering support in `renderConversationForModel` -export type AgentActionConfigurationType = RetrievalConfigurationType; - // Each AgentActionConfigurationType is capable of generating this type at runtime to specify which // inputs should be generated by the model. As an example, to run the retrieval action for which the // `relativeTimeFrame` has been specified in the configuration but for which the `query` is "auto", @@ -32,7 +24,6 @@ export type AgentActionConfigurationType = RetrievalConfigurationType; // ``` export type AgentActionSpecification = { - id: ModelId; name: string; description: string; inputs: { @@ -41,40 +32,3 @@ export type AgentActionSpecification = { type: "string" | "number" | "boolean"; }[]; }; - -/** - * Agent Message configuration - */ - -export type AgentGenerationConfigurationType = { - id: ModelId; - prompt: string; - model: { - providerId: string; - modelId: string; - }; -}; - -/** - * Agent configuration - */ - -export type AgentConfigurationStatus = "active" | "archived"; -export type AgentConfigurationScope = "global" | "workspace"; - -export type AgentConfigurationType = { - id: ModelId; - sId: string; - status: AgentConfigurationStatus; - name: string; - pictureUrl: string | null; -}; - -export type AgentFullConfigurationType = { - agent: AgentConfigurationType; - // If undefined, no action performed, otherwise the action is - // performed (potentially NoOp eg autoSkip above). - action: AgentActionConfigurationType | null; - // If undefined, no text generation. - generation: AgentGenerationConfigurationType | null; -}; diff --git a/front/types/assistant/configuration.ts b/front/types/assistant/configuration.ts new file mode 100644 index 000000000000..62d825a62beb --- /dev/null +++ b/front/types/assistant/configuration.ts @@ -0,0 +1,66 @@ +import { ModelId } from "@app/lib/databases"; +import { + RetrievalQuery, + RetrievalTimeframe, +} from "@app/types/assistant/actions/retrieval"; + +/** + * Agent config + */ +export type AgentConfigurationStatus = "active" | "archived"; +export type AgentConfigurationScope = "global" | "workspace"; +export type AgentConfigurationType = { + sId: string; + status: AgentConfigurationStatus; + name: string; + pictureUrl: string | null; + action: AgentActionConfigurationType | null; // If undefined, no action performed + generation: AgentGenerationConfigurationType | null; // If undefined, no text generation. +}; + +/** + * Generation config + */ +export type AgentGenerationConfigurationType = { + id: ModelId; + prompt: string; + model: { + providerId: string; + modelId: string; + }; +}; + +/** + * Action > Retrieval + */ +export type AgentActionConfigurationType = RetrievalConfigurationType; + +/** + * Retrieval Action config + */ +export type RetrievalConfigurationType = { + id: ModelId; + + type: "retrieval_configuration"; + dataSources: AgentDataSourceConfigurationType[]; + query: RetrievalQuery; + relativeTimeFrame: RetrievalTimeframe; + topK: number; +}; +export function isRetrievalConfiguration( + arg: AgentActionConfigurationType | null +): arg is RetrievalConfigurationType { + return arg !== null && arg.type && arg.type === "retrieval_configuration"; +} + +/** + * Datasources config for Retrieval Action + */ +export type AgentDataSourceConfigurationType = { + workspaceSId: string; // need sId to talk with Core (external id) + dataSourceName: string; // need Datasource.name to talk with Core (external id) + filter: { + tags: { in: string[]; not: string[] } | null; + parents: { in: string[]; not: string[] } | null; + }; +}; diff --git a/front/types/assistant/conversation.ts b/front/types/assistant/conversation.ts index 89b1d3cbd9a4..342d1f8a246f 100644 --- a/front/types/assistant/conversation.ts +++ b/front/types/assistant/conversation.ts @@ -1,5 +1,5 @@ import { ModelId } from "@app/lib/databases"; -import { AgentFullConfigurationType } from "@app/types/assistant/agent"; +import { AgentConfigurationType } from "@app/types/assistant/configuration"; import { UserType } from "@app/types/user"; import { RetrievalActionType } from "./actions/retrieval"; @@ -91,7 +91,7 @@ export type AgentMessageType = { version: number; parentMessageId: string | null; - configuration: AgentFullConfigurationType; + configuration: AgentConfigurationType; status: AgentMessageStatus; action: AgentActionType | null; message: string | null;