From 20b2e15ab78d7a21d025bba58e68a87d9f828136 Mon Sep 17 00:00:00 2001 From: PopDaph Date: Thu, 7 Sep 2023 09:27:48 +0200 Subject: [PATCH] Assistant: Lib function createAgentConfiguration --- front/lib/api/assistant/agent.ts | 146 +++++++++++++++++++++++++++++-- 1 file changed, 137 insertions(+), 9 deletions(-) diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index 2618d80552a7a..ccc5c6d0200cd 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -4,6 +4,12 @@ import { } from "@app/lib/actions/registry"; import { runAction } from "@app/lib/actions/server"; import { Authenticator } from "@app/lib/auth"; +import { front_sequelize } from "@app/lib/databases"; +import { AgentRetrievalConfiguration } from "@app/lib/models/assistant/actions/retrieval"; +import { + AgentConfiguration, + AgentGenerationConfiguration, +} from "@app/lib/models/assistant/agent"; import { Err, Ok, Result } from "@app/lib/result"; import { generateModelSId } from "@app/lib/utils"; import { @@ -24,11 +30,76 @@ import { RetrievalParamsEvent, } from "./actions/retrieval"; import { renderConversationForModel } from "./conversation"; +import { TimeFrame } from "@app/types/assistant/actions/retrieval"; /** * Agent configuration. */ +function _buildAgentActionConfigurationType( + action: AgentRetrievalConfiguration +): AgentActionConfigurationType { + + let timeframe: "auto" | "none" | TimeFrame = "auto"; + if (action.relativeTimeFrame === "custom" && action.relativeTimeFrameDuration && action.relativeTimeFrameUnit) { + timeframe = { + duration: action.relativeTimeFrameDuration, + unit: action.relativeTimeFrameUnit, + }; + } else if (action.relativeTimeFrame === "none") { + timeframe = action.relativeTimeFrame; + } + + let query: "auto" | "none" | { template: string } = "auto"; + if (action.query === "templated" && action.queryTemplate) { + query = { + template: action.queryTemplate, + }; + } else if (action.query === "none") { + query = "none"; + } + + return { + query: query, + relativeTimeFrame: timeframe, + topK: action.topK, + type: "retrieval_configuration", + dataSources: "all", + }; +} + +function _buildAgentGenerationConfigurationType( + generation: AgentGenerationConfiguration +): AgentGenerationConfigurationType { + return { + prompt: generation.prompt, + model: { + providerId: generation.modelProvider, + modelId: generation.modelId, + }, + }; +} + +function _getAgentConfigurationType({ + agent, + action, + generation, +}: { + agent: AgentConfiguration; + action: AgentRetrievalConfiguration | null; + generation: AgentGenerationConfiguration | null; +}): AgentConfigurationType { + + return { + sId: agent.sId, + name: agent.name, + pictureUrl: agent.pictureUrl, + status: agent.status, + action: action? _buildAgentActionConfigurationType(action) : null, + generation: generation? _buildAgentGenerationConfigurationType(generation) : null, + }; +} + export async function createAgentConfiguration( auth: Authenticator, { @@ -42,15 +113,72 @@ export async function createAgentConfiguration( action?: AgentActionConfigurationType; generation?: AgentGenerationConfigurationType; } -): Promise { - return { - sId: generateModelSId(), - name, - pictureUrl: pictureUrl ?? null, - status: "active", - action: action ?? null, - generation: generation ?? null, - }; +): Promise { + const owner = auth.workspace(); + + if (!owner) { + return; + } + + let agent: AgentConfigurationType | null = null; + + await front_sequelize.transaction(async (t) => { + const agentConfigRow = await AgentConfiguration.create( + { + sId: generateModelSId(), + status: "active", + name: name, + pictureUrl: pictureUrl ?? null, + scope: "workspace", + workspaceId: owner.id, + }, + { transaction: t } + ); + + let agentGenerationConfigRow: AgentGenerationConfiguration | null = null; + let agentActionConfigRow: AgentRetrievalConfiguration | null = null; + + if (generation) { + agentGenerationConfigRow = + await AgentGenerationConfiguration.create( + { + prompt: generation.prompt, + modelProvider: generation.model.providerId, + modelId: generation.model.modelId, + agentId: agentConfigRow.id, + }, + { transaction: t } + ); + } + if (action) { + agentActionConfigRow = await AgentRetrievalConfiguration.create({ + query: typeof action.query === "object" ? "templated" : action.query, + queryTemplate: + typeof action.query === "object" ? action.query.template : null, + relativeTimeFrame: + typeof action.relativeTimeFrame === "object" + ? "custom" + : action.relativeTimeFrame, + relativeTimeFrameDuration: + typeof action.relativeTimeFrame === "object" + ? action.relativeTimeFrame.duration + : null, + relativeTimeFrameUnit: + typeof action.relativeTimeFrame === "object" + ? action.relativeTimeFrame.unit + : null, + topK: action.topK, + agentId: agentConfigRow.id, + }, { transaction: t }); + } + + agent = _getAgentConfigurationType({ + agent: agentConfigRow, + action: agentActionConfigRow, + generation: agentGenerationConfigRow, + }); + + return agent; } export async function updateAgentConfiguration(