From 8974e81e8579e92fe42d8f7a2582c8da44f8edf2 Mon Sep 17 00:00:00 2001 From: PopDaph Date: Thu, 7 Sep 2023 09:27:48 +0200 Subject: [PATCH] Assistant: Lib function createAgentConfiguration --- front/lib/api/assistant/actions/retrieval.ts | 69 +++------ front/lib/api/assistant/agent.ts | 122 ++++++++++++++-- .../lib/models/assistant/actions/retrieval.ts | 25 ++-- front/types/assistant/actions/retrieval.ts | 24 +++- front/types/assistant/agent-utils.ts | 134 ++++++++++++++++++ 5 files changed, 292 insertions(+), 82 deletions(-) create mode 100644 front/types/assistant/agent-utils.ts diff --git a/front/lib/api/assistant/actions/retrieval.ts b/front/lib/api/assistant/actions/retrieval.ts index 8f7557dfbaac4..fc77ca00c7b4b 100644 --- a/front/lib/api/assistant/actions/retrieval.ts +++ b/front/lib/api/assistant/actions/retrieval.ts @@ -389,66 +389,31 @@ export async function* runRetrieval( ); // Handle data sources list and parents/tags filtering. - if (c.dataSources === "all") { - const prodCredentials = await prodAPICredentialsForOwner(owner); - const api = new DustAPI(prodCredentials); + config.DATASOURCE.data_sources = c.dataSources.map((d) => ({ + workspace_id: d.workspaceId, + name: d.name, + })); + + for (const ds of c.dataSources) { + if (ds.filter.tags) { + if (!config.DATASOURCE.filter.tags) { + config.DATASOURCE.filter.tags = { in: [], not: [] }; + } - const dsRes = await api.getDataSources(prodCredentials.workspaceId); - if (dsRes.isErr()) { - return yield { - type: "retrieval_error", - created: Date.now(), - configurationId: configuration.sId, - messageId: agentMessage.sId, - error: { - code: "retrieval_data_sources_error", - message: `Error retrieving workspace data sources: ${dsRes.error.message}`, - }, - }; + config.DATASOURCE.filter.tags.in.push(...ds.filter.tags.in); + config.DATASOURCE.filter.tags.not.push(...ds.filter.tags.not); } - const ds = dsRes.value.filter((d) => d.assistantDefaultSelected); - - config.DATASOURCE.data_sources = ds.map((d) => { - return { - workspace_id: prodCredentials.workspaceId, - data_source_id: d.name, - }; - }); - } else { - config.DATASOURCE.data_sources = c.dataSources.map((d) => ({ - workspace_id: d.workspaceId, - data_source_id: d.dataSourceId, - })); - - for (const ds of c.dataSources) { - if (ds.filter.tags) { - if (!config.DATASOURCE.filter.tags) { - config.DATASOURCE.filter.tags = { in: [], not: [] }; - } - - config.DATASOURCE.filter.tags.in.push(...ds.filter.tags.in); - config.DATASOURCE.filter.tags.not.push(...ds.filter.tags.not); + if (ds.filter.parents) { + if (!config.DATASOURCE.filter.parents) { + config.DATASOURCE.filter.parents = { in: [], not: [] }; } - if (ds.filter.parents) { - if (!config.DATASOURCE.filter.parents) { - config.DATASOURCE.filter.parents = { in: [], not: [] }; - } - - config.DATASOURCE.filter.parents.in.push(...ds.filter.parents.in); - config.DATASOURCE.filter.parents.not.push(...ds.filter.parents.not); - } + config.DATASOURCE.filter.parents.in.push(...ds.filter.parents.in); + config.DATASOURCE.filter.parents.not.push(...ds.filter.parents.not); } } - // Handle timestamp filtering. - if (params.relativeTimeFrame) { - config.DATASOURCE.filter.timestamp = { - gt: timeFrameFromNow(params.relativeTimeFrame), - }; - } - // Handle top k. config.DATASOURCE.top_k = params.topK; diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index 2618d80552a7a..84d8b150c9031 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -4,8 +4,23 @@ import { } from "@app/lib/actions/registry"; import { runAction } from "@app/lib/actions/server"; import { Authenticator } from "@app/lib/auth"; +import { front_sequelize } from "@app/lib/databases"; +import { DataSource } from "@app/lib/models"; +import { + AgentDataSourceConfiguration, + AgentRetrievalConfiguration, +} from "@app/lib/models/assistant/actions/retrieval"; +import { + AgentConfiguration, + AgentGenerationConfiguration, +} from "@app/lib/models/assistant/agent"; import { Err, Ok, Result } from "@app/lib/result"; import { generateModelSId } from "@app/lib/utils"; +import { + isTemplatedQuery, + isTimeFrame, + RetrievalDataSourcesConfiguration, +} from "@app/types/assistant/actions/retrieval"; import { AgentActionConfigurationType, AgentActionSpecification, @@ -13,6 +28,7 @@ import { AgentConfigurationType, AgentGenerationConfigurationType, } from "@app/types/assistant/agent"; +import { _getAgentConfigurationType } from "@app/types/assistant/agent-utils"; import { AgentActionType, AgentMessageType, @@ -42,15 +58,103 @@ export async function createAgentConfiguration( action?: AgentActionConfigurationType; generation?: AgentGenerationConfigurationType; } -): Promise { - return { - sId: generateModelSId(), - name, - pictureUrl: pictureUrl ?? null, - status: "active", - action: action ?? null, - generation: generation ?? null, - }; +): Promise { + const owner = auth.workspace(); + + if (!owner) { + return; + } + + return await front_sequelize.transaction(async (t) => { + // Create AgentConfiguration + const agentConfigRow = await AgentConfiguration.create( + { + sId: generateModelSId(), + status: "active", + name: name, + pictureUrl: pictureUrl ?? null, + scope: "workspace", + workspaceId: owner.id, + }, + { transaction: t } + ); + + let agentGenerationConfigRow: AgentGenerationConfiguration | null = null; + let agentActionConfigRow: AgentRetrievalConfiguration | null = null; + const agentDataSourcesConfigRows: AgentDataSourceConfiguration[] = []; + + // Create AgentGenerationConfiguration + if (generation) { + agentGenerationConfigRow = await AgentGenerationConfiguration.create( + { + prompt: generation.prompt, + modelProvider: generation.model.providerId, + modelId: generation.model.modelId, + agentId: agentConfigRow.id, + }, + { transaction: t } + ); + } + + // Create AgentRetrievalConfiguration & associated AgentDataSourceConfiguration + if (action) { + const query = action.query; + const timeframe = action.relativeTimeFrame; + + agentActionConfigRow = await AgentRetrievalConfiguration.create( + { + query: isTemplatedQuery(query) ? "templated" : query, + queryTemplate: isTemplatedQuery(query) ? query.template : null, + relativeTimeFrame: isTimeFrame(timeframe) ? "custom" : timeframe, + relativeTimeFrameDuration: isTimeFrame(timeframe) + ? timeframe.duration + : null, + relativeTimeFrameUnit: isTimeFrame(timeframe) ? timeframe.unit : null, + topK: action.topK, + agentId: agentConfigRow.id, + }, + { transaction: t } + ); + + if (!agentActionConfigRow) { + return; + } + + if (action.dataSources) { + let dsRow: AgentDataSourceConfiguration | null = null; + action.dataSources.map(async (d) => { + const ds = await DataSource.findOne({ + where: { + name: d.name, + workspaceId: d.workspaceId, + }, + }); + + if (ds && agentActionConfigRow) { + dsRow = await AgentDataSourceConfiguration.create( + { + dataSourceId: ds.id, + tagsIn: d.filter.tags?.in, + tagsNotIn: d.filter.tags?.not, + parentsIn: d.filter.parents?.in, + parentsNotIn: d.filter.parents?.not, + retrievalConfigurationId: agentActionConfigRow.id, + }, + { transaction: t } + ); + agentDataSourcesConfigRows.push(dsRow); + } + }); + } + } + + return _getAgentConfigurationType({ + agent: agentConfigRow, + action: agentActionConfigRow, + generation: agentGenerationConfigRow, + dataSources: agentDataSourcesConfigRows, + }); + }); } export async function updateAgentConfiguration( diff --git a/front/lib/models/assistant/actions/retrieval.ts b/front/lib/models/assistant/actions/retrieval.ts index f123728103069..98dcaac793461 100644 --- a/front/lib/models/assistant/actions/retrieval.ts +++ b/front/lib/models/assistant/actions/retrieval.ts @@ -117,9 +117,6 @@ export class AgentDataSourceConfiguration extends Model< declare createdAt: CreationOptional; declare updatedAt: CreationOptional; - declare timeframeDuration: number | null; - declare timeframeUnit: TimeframeUnit | null; - declare tagsIn: string[] | null; declare tagsNotIn: string[] | null; declare parentsIn: string[] | null; @@ -147,14 +144,6 @@ AgentDataSourceConfiguration.init( allowNull: false, defaultValue: DataTypes.NOW, }, - timeframeDuration: { - type: DataTypes.INTEGER, - allowNull: true, - }, - timeframeUnit: { - type: DataTypes.STRING, - allowNull: true, - }, tagsIn: { type: DataTypes.ARRAY(DataTypes.STRING), allowNull: true, @@ -178,12 +167,16 @@ AgentDataSourceConfiguration.init( hooks: { beforeValidate: (dataSourceConfig: AgentDataSourceConfiguration) => { if ( - (dataSourceConfig.timeframeDuration === null) !== - (dataSourceConfig.timeframeUnit === null) + (dataSourceConfig.tagsIn === null) !== + (dataSourceConfig.tagsNotIn === null) ) { - throw new Error( - "Timeframe duration/unit must be both set or both null" - ); + throw new Error("Tags must be both set or both null"); + } + if ( + (dataSourceConfig.parentsIn === null) !== + (dataSourceConfig.parentsNotIn === null) + ) { + throw new Error("Parents must be both set or both null"); } }, }, diff --git a/front/types/assistant/actions/retrieval.ts b/front/types/assistant/actions/retrieval.ts index 2544edeb7ee7b..4d875d2ccc3e9 100644 --- a/front/types/assistant/actions/retrieval.ts +++ b/front/types/assistant/actions/retrieval.ts @@ -17,9 +17,10 @@ export type DataSourceFilter = { parents: { in: string[]; not: string[] } | null; }; +// DataSources have a unique pair (name, workspaceId) export type DataSourceConfiguration = { - workspaceId: string; - dataSourceId: string; + workspaceId: ModelId; + name: string; filter: DataSourceFilter; }; @@ -30,6 +31,15 @@ export type DataSourceConfiguration = { export type TemplatedQuery = { template: string; }; +export function isTemplatedQuery(arg: RetrievalQuery): arg is TemplatedQuery { + return (arg as TemplatedQuery).template !== undefined; +} +export function isTimeFrame(arg: RetrievalTimeframe): arg is TimeFrame { + return ( + (arg as TimeFrame).duration !== undefined && + (arg as TimeFrame).unit !== undefined + ); +} // Retrieval specifies a list of data sources (with possible parent / tags filtering, possible "all" // data sources), a query ("auto" generated by the model "none", no query, `TemplatedQuery`, fixed @@ -39,13 +49,17 @@ export type TemplatedQuery = { // `query` and `relativeTimeFrame` will be used to generate the inputs specification for the model // in charge of generating the action inputs. The results will be used along with `topK` and // `dataSources` to query the data. +export type RetrievalTimeframe = "auto" | "none" | TimeFrame; +export type RetrievalQuery = "auto" | "none" | TemplatedQuery; +export type RetrievalDataSourcesConfiguration = DataSourceConfiguration[]; + export type RetrievalConfigurationType = { id: ModelId; type: "retrieval_configuration"; - dataSources: "all" | DataSourceConfiguration[]; - query: "auto" | "none" | TemplatedQuery; - relativeTimeFrame: "auto" | "none" | TimeFrame; + dataSources: RetrievalDataSourcesConfiguration; + query: RetrievalQuery; + relativeTimeFrame: RetrievalTimeframe; topK: number; // Dynamically decide to skip, if needed in the future diff --git a/front/types/assistant/agent-utils.ts b/front/types/assistant/agent-utils.ts new file mode 100644 index 0000000000000..3a3abdc09a49f --- /dev/null +++ b/front/types/assistant/agent-utils.ts @@ -0,0 +1,134 @@ +import { DataSource } from "@app/lib/models"; +import { + AgentDataSourceConfiguration, + AgentRetrievalConfiguration, +} from "@app/lib/models/assistant/actions/retrieval"; +import { + AgentConfiguration, + AgentGenerationConfiguration, +} from "@app/lib/models/assistant/agent"; +import { + RetrievalDataSourcesConfiguration, + RetrievalQuery, + RetrievalTimeframe, +} from "@app/types/assistant/actions/retrieval"; +import { + AgentActionConfigurationType, + AgentConfigurationType, + AgentGenerationConfigurationType, +} from "@app/types/assistant/agent"; + +/** + * Builds the agent configuration type from the model + */ +export function _getAgentConfigurationType({ + agent, + action, + generation, + dataSources, +}: { + agent: AgentConfiguration; + action: AgentRetrievalConfiguration | null; + generation: AgentGenerationConfiguration | null; + dataSources: AgentDataSourceConfiguration[] | null; +}): AgentConfigurationType { + return { + sId: agent.sId, + name: agent.name, + pictureUrl: agent.pictureUrl, + status: agent.status, + action: action + ? _buildAgentActionConfigurationType(action, dataSources) + : null, + generation: generation + ? _buildAgentGenerationConfigurationType(generation) + : null, + }; +} + +/** + * Builds the agent action configuration type from the model + */ +export function _buildAgentActionConfigurationType( + action: AgentRetrievalConfiguration, + dataSourcesConfig: AgentDataSourceConfiguration[] | null +): AgentActionConfigurationType { + // Build Retrieval Timeframe + let timeframe: RetrievalTimeframe = "auto"; + if ( + action.relativeTimeFrame === "custom" && + action.relativeTimeFrameDuration && + action.relativeTimeFrameUnit + ) { + timeframe = { + duration: action.relativeTimeFrameDuration, + unit: action.relativeTimeFrameUnit, + }; + } else if (action.relativeTimeFrame === "none") { + timeframe = "none"; + } + + // Build Retrieval Query + let query: RetrievalQuery = "auto"; + if (action.query === "templated" && action.queryTemplate) { + query = { + template: action.queryTemplate, + }; + } else if (action.query === "none") { + query = "none"; + } + + // Build Retrieval DataSources + const retrievalDataSourcesConfig: RetrievalDataSourcesConfiguration = []; + let dataSource: DataSource | null = null; + + dataSourcesConfig?.forEach(async (dsConfig) => { + dataSource = await DataSource.findOne({ + where: { + id: dsConfig.dataSourceId, + }, + }); + + if (!dataSource) { + return; + } + retrievalDataSourcesConfig.push({ + name: dataSource.name, + workspaceId: dataSource.workspaceId, + filter: { + tags: + dsConfig.tagsIn && dsConfig.tagsNotIn + ? { in: dsConfig.tagsIn, not: dsConfig.tagsNotIn } + : null, + parents: + dsConfig.parentsIn && dsConfig.parentsNotIn + ? { in: dsConfig.parentsIn, not: dsConfig.parentsNotIn } + : null, + }, + }); + }); + + return { + id: action.id, + query: query, + relativeTimeFrame: timeframe, + topK: action.topK, + type: "retrieval_configuration", + dataSources: retrievalDataSourcesConfig, + }; +} + +/** + * Builds the agent generation configuration type from the model + */ +export function _buildAgentGenerationConfigurationType( + generation: AgentGenerationConfiguration +): AgentGenerationConfigurationType { + return { + prompt: generation.prompt, + model: { + providerId: generation.modelProvider, + modelId: generation.modelId, + }, + }; +}