diff --git a/front/lib/api/assistant/configuration.ts b/front/lib/api/assistant/configuration.ts index 7414667cfb206..4f78fb50aa83f 100644 --- a/front/lib/api/assistant/configuration.ts +++ b/front/lib/api/assistant/configuration.ts @@ -37,37 +37,49 @@ export async function getAgentConfiguration( throw new Error("Cannot find AgentConfiguration: no workspace."); } const agent = await AgentConfiguration.findOne({ + //logging: console.log, where: { sId: agentId, workspaceId: owner.id, }, + include: [ + { + model: AgentGenerationConfiguration, + as: "generationConfiguration", + }, + { + model: AgentRetrievalConfiguration, + as: "retrievalConfiguration", + include: [ + { + model: AgentDataSourceConfiguration, + as: "dataSourceConfigurations", + include: [ + { + model: DataSource, + as: "ds", + attributes: ["id", "name", "workspaceId"], + include: [ + { + model: Workspace, + as: "w", + attributes: ["id", "sId"], + }, + ], + }, + ], + }, + ], + }, + ], }); if (!agent) { throw new Error("Cannot find AgentConfiguration."); } - const generationConfig = agent.generationConfigurationId - ? await AgentGenerationConfiguration.findOne({ - where: { - id: agent.generationConfigurationId, - }, - }) - : null; - - const actionConfig = agent.retrievalConfigurationId - ? await AgentRetrievalConfiguration.findOne({ - where: { - id: agent.retrievalConfigurationId, - }, - }) - : null; - const dataSourcesConfig = actionConfig?.id - ? await AgentDataSourceConfiguration.findAll({ - where: { - retrievalConfigurationId: actionConfig.id, - }, - }) - : []; + const generationConfig = agent.generationConfiguration; + const actionConfig = agent.retrievalConfiguration; + const dataSourcesConfig = actionConfig?.dataSourceConfigurations || []; return { sId: agent.sId, @@ -75,10 +87,29 @@ export async function getAgentConfiguration( pictureUrl: agent.pictureUrl, status: agent.status, action: actionConfig - ? await renderAgentActionConfigurationType( - actionConfig, - dataSourcesConfig - ) + ? { + id: actionConfig.id, + type: "retrieval_configuration", + query: renderRetrievalQueryType(actionConfig), + relativeTimeFrame: renderRetrievalTimeframeType(actionConfig), + topK: actionConfig.topK, + dataSources: dataSourcesConfig.map((dsConfig) => { + return { + dataSourceId: dsConfig.ds.name, + workspaceId: dsConfig.ds.w.sId, + filter: { + tags: + dsConfig.tagsIn && dsConfig.tagsNotIn + ? { in: dsConfig.tagsIn, not: dsConfig.tagsNotIn } + : null, + parents: + dsConfig.parentsIn && dsConfig.parentsNotIn + ? { in: dsConfig.parentsIn, not: dsConfig.parentsNotIn } + : null, + }, + }; + }), + } : null, generation: generationConfig ? { @@ -164,7 +195,38 @@ export async function updateAgentConfiguration( const agentConfig = await AgentConfiguration.findOne({ where: { sId: agentId, + workspaceId: owner.id, }, + include: [ + { + model: AgentGenerationConfiguration, + as: "generationConfiguration", + }, + { + model: AgentRetrievalConfiguration, + as: "retrievalConfiguration", + include: [ + { + model: AgentDataSourceConfiguration, + as: "dataSourceConfigurations", + include: [ + { + model: DataSource, + as: "ds", + attributes: ["id", "name", "workspaceId"], + include: [ + { + model: Workspace, + as: "w", + attributes: ["id", "sId"], + }, + ], + }, + ], + }, + ], + }, + ], }); if (!agentConfig) { throw new Error( @@ -178,49 +240,47 @@ export async function updateAgentConfiguration( status: status, }); - // Return the config with Generation and Action if any - const existingGeneration = agentConfig.generationConfigurationId - ? await AgentGenerationConfiguration.findOne({ - where: { - id: agentConfig.generationConfigurationId, - }, - }) - : null; - - const existingRetrivalConfig = agentConfig.retrievalConfigurationId - ? await AgentRetrievalConfiguration.findOne({ - where: { - id: agentConfig.retrievalConfigurationId, - }, - }) - : null; - - const existingDataSourcesConfig = existingRetrivalConfig?.id - ? await AgentDataSourceConfiguration.findAll({ - where: { - retrievalConfigurationId: existingRetrivalConfig.id, - }, - }) - : []; + const generationConfig = agentConfig.generationConfiguration; + const actionConfig = agentConfig.retrievalConfiguration; + const dataSourcesConfig = actionConfig?.dataSourceConfigurations || []; return { - sId: updatedAgentConfig.sId, + sId: agentConfig.sId, name: updatedAgentConfig.name, pictureUrl: updatedAgentConfig.pictureUrl, status: updatedAgentConfig.status, - action: existingRetrivalConfig - ? await renderAgentActionConfigurationType( - existingRetrivalConfig, - existingDataSourcesConfig - ) + action: actionConfig + ? { + id: actionConfig.id, + type: "retrieval_configuration", + query: renderRetrievalQueryType(actionConfig), + relativeTimeFrame: renderRetrievalTimeframeType(actionConfig), + topK: actionConfig.topK, + dataSources: dataSourcesConfig.map((dsConfig) => { + return { + dataSourceId: dsConfig.ds.name, + workspaceId: dsConfig.ds.w.sId, + filter: { + tags: + dsConfig.tagsIn && dsConfig.tagsNotIn + ? { in: dsConfig.tagsIn, not: dsConfig.tagsNotIn } + : null, + parents: + dsConfig.parentsIn && dsConfig.parentsNotIn + ? { in: dsConfig.parentsIn, not: dsConfig.parentsNotIn } + : null, + }, + }; + }), + } : null, - generation: existingGeneration + generation: generationConfig ? { - id: existingGeneration.id, - prompt: existingGeneration.prompt, + id: generationConfig.id, + prompt: generationConfig.prompt, model: { - providerId: existingGeneration.providerId, - modelId: existingGeneration.modelId, + providerId: generationConfig.providerId, + modelId: generationConfig.modelId, }, } : null, @@ -374,16 +434,16 @@ export async function createAgentActionConfiguration( }, { transaction: t } ); - const dataSourcesConfig = await _createAgentDataSourcesConfigData( - t, + await _createAgentDataSourcesConfigData(t, dataSources, retrievalConfig.id); + + return { + id: retrievalConfig.id, + type: "retrieval_configuration", + query, + relativeTimeFrame: timeframe, + topK, dataSources, - retrievalConfig.id - ); - - return await renderAgentActionConfigurationType( - retrievalConfig, - dataSourcesConfig - ); + }; }); } @@ -466,27 +526,24 @@ export async function updateAgentActionConfiguration( }); // Create new dataSources config - const dataSourcesConfig = await _createAgentDataSourcesConfigData( + await _createAgentDataSourcesConfigData( t, dataSources, updatedRetrievalConfig.id ); - return await renderAgentActionConfigurationType( - updatedRetrievalConfig, - dataSourcesConfig - ); + return { + id: updatedRetrievalConfig.id, + type: "retrieval_configuration", + query, + relativeTimeFrame: timeframe, + topK, + dataSources, + }; }); } -/** - * Builds the agent action configuration type from the model - */ -async function renderAgentActionConfigurationType( - action: AgentRetrievalConfiguration, - dataSourcesConfig: AgentDataSourceConfiguration[] -): Promise { - // Build Retrieval Timeframe +function renderRetrievalTimeframeType(action: AgentRetrievalConfiguration) { let timeframe: RetrievalTimeframe = "auto"; if ( action.relativeTimeFrame === "custom" && @@ -500,8 +557,10 @@ async function renderAgentActionConfigurationType( } else if (action.relativeTimeFrame === "none") { timeframe = "none"; } + return timeframe; +} - // Build Retrieval Query +function renderRetrievalQueryType(action: AgentRetrievalConfiguration) { let query: RetrievalQuery = "auto"; if (action.query === "templated" && action.queryTemplate) { query = { @@ -510,59 +569,7 @@ async function renderAgentActionConfigurationType( } else if (action.query === "none") { query = "none"; } - - // Build Retrieval DataSources - const dataSourcesIds = dataSourcesConfig?.map((ds) => ds.dataSourceId); - const dataSources = await DataSource.findAll({ - where: { - id: { [Op.in]: dataSourcesIds }, - }, - attributes: ["id", "name", "workspaceId"], - }); - const workspaceIds = dataSources.map((ds) => ds.workspaceId); - const workspaces = await Workspace.findAll({ - where: { - id: { [Op.in]: workspaceIds }, - }, - attributes: ["id", "sId"], - }); - - let dataSource: DataSource | undefined; - let workspace: Workspace | undefined; - const dataSourcesConfigType: DataSourceConfiguration[] = []; - - dataSourcesConfig.forEach(async (dsConfig) => { - dataSource = dataSources.find((ds) => ds.id === dsConfig.dataSourceId); - workspace = workspaces.find((w) => w.id === dataSource?.workspaceId); - - if (!dataSource || !workspace) { - throw new Error("Can't render Agent Retrieval dataSources: not found."); - } - - dataSourcesConfigType.push({ - dataSourceId: dataSource.name, - workspaceId: workspace.sId, - filter: { - tags: - dsConfig.tagsIn && dsConfig.tagsNotIn - ? { in: dsConfig.tagsIn, not: dsConfig.tagsNotIn } - : null, - parents: - dsConfig.parentsIn && dsConfig.parentsNotIn - ? { in: dsConfig.parentsIn, not: dsConfig.parentsNotIn } - : null, - }, - }); - }); - - return { - id: action.id, - type: "retrieval_configuration", - query: query, - relativeTimeFrame: timeframe, - topK: action.topK, - dataSources: dataSourcesConfigType, - }; + return query; } /** diff --git a/front/lib/models/assistant/actions/retrieval.ts b/front/lib/models/assistant/actions/retrieval.ts index 238057dab3f8e..ce55b62dcb89f 100644 --- a/front/lib/models/assistant/actions/retrieval.ts +++ b/front/lib/models/assistant/actions/retrieval.ts @@ -5,6 +5,7 @@ import { InferAttributes, InferCreationAttributes, Model, + NonAttribute, } from "sequelize"; import { front_sequelize } from "@app/lib/databases"; @@ -29,6 +30,10 @@ export class AgentRetrievalConfiguration extends Model< declare relativeTimeFrameDuration: number | null; declare relativeTimeFrameUnit: TimeframeUnit | null; declare topK: number; + + declare dataSourceConfigurations: NonAttribute< + AgentDataSourceConfiguration[] + >; } AgentRetrievalConfiguration.init( { @@ -124,6 +129,8 @@ export class AgentDataSourceConfiguration extends Model< declare retrievalConfigurationId: ForeignKey< AgentRetrievalConfiguration["id"] >; + + declare ds: NonAttribute; } AgentDataSourceConfiguration.init( { @@ -185,18 +192,32 @@ AgentDataSourceConfiguration.init( AgentRetrievalConfiguration.hasOne(AgentConfiguration, { foreignKey: { name: "retrievalConfigurationId", allowNull: true }, // null = no retrieval action set for this Agent }); +AgentConfiguration.belongsTo(AgentRetrievalConfiguration, { + as: "retrievalConfiguration", + foreignKey: { name: "retrievalConfigurationId", allowNull: true }, +}); // Retrieval config <> Data source config AgentRetrievalConfiguration.hasMany(AgentDataSourceConfiguration, { + as: "dataSourceConfigurations", foreignKey: { name: "retrievalConfigurationId", allowNull: false }, onDelete: "CASCADE", }); +AgentDataSourceConfiguration.belongsTo(AgentRetrievalConfiguration, { + as: "dataSourceConfigurations", + foreignKey: "retrievalConfigurationId", +}); // Data source config <> Data source DataSource.hasMany(AgentDataSourceConfiguration, { + as: "ds", foreignKey: { name: "dataSourceId", allowNull: false }, onDelete: "CASCADE", }); +AgentDataSourceConfiguration.belongsTo(DataSource, { + as: "ds", + foreignKey: "dataSourceId", +}); /** * Retrieval Action diff --git a/front/lib/models/assistant/agent.ts b/front/lib/models/assistant/agent.ts index 18e1dfc744424..ba5da389f1e40 100644 --- a/front/lib/models/assistant/agent.ts +++ b/front/lib/models/assistant/agent.ts @@ -5,6 +5,7 @@ import { InferAttributes, InferCreationAttributes, Model, + NonAttribute, } from "sequelize"; import { front_sequelize } from "@app/lib/databases"; @@ -39,6 +40,9 @@ export class AgentConfiguration extends Model< declare retrievalConfigurationId: ForeignKey< AgentRetrievalConfiguration["id"] > | null; + + declare generationConfiguration: NonAttribute; + declare retrievalConfiguration: NonAttribute; } AgentConfiguration.init( { @@ -166,3 +170,8 @@ Workspace.hasMany(AgentConfiguration, { AgentGenerationConfiguration.hasOne(AgentConfiguration, { foreignKey: { name: "generationConfigurationId", allowNull: true }, // null = no generation set for this Agent }); + +AgentConfiguration.belongsTo(AgentGenerationConfiguration, { + as: "generationConfiguration", + foreignKey: { name: "generationConfigurationId", allowNull: true }, +}); diff --git a/front/lib/models/data_source.ts b/front/lib/models/data_source.ts index fe6ef61d8755f..20f5f6c6e8aa5 100644 --- a/front/lib/models/data_source.ts +++ b/front/lib/models/data_source.ts @@ -5,6 +5,7 @@ import { InferAttributes, InferCreationAttributes, Model, + NonAttribute, } from "sequelize"; import { ConnectorProvider } from "@app/lib/connectors_api"; @@ -28,6 +29,8 @@ export class DataSource extends Model< declare connectorId: string | null; declare connectorProvider: ConnectorProvider | null; declare workspaceId: ForeignKey; + + declare w: NonAttribute; } DataSource.init( @@ -88,6 +91,11 @@ DataSource.init( } ); Workspace.hasMany(DataSource, { - foreignKey: { allowNull: false }, + as: "w", + foreignKey: { name: "workspaceId", allowNull: false }, onDelete: "CASCADE", }); +DataSource.belongsTo(Workspace, { + as: "w", + foreignKey: "workspaceId", +});