Skip to content

Commit

Permalink
Rework functions - split
Browse files Browse the repository at this point in the history
  • Loading branch information
PopDaph committed Sep 8, 2023
1 parent 24297e4 commit d2f9d8e
Show file tree
Hide file tree
Showing 8 changed files with 589 additions and 335 deletions.
16 changes: 9 additions & 7 deletions front/lib/api/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {
import {
AgentActionSpecification,
AgentConfigurationType,
AgentFullConfigurationType,
} from "@app/types/assistant/agent";
import {
AgentMessageType,
Expand Down Expand Up @@ -166,6 +167,7 @@ export async function retrievalActionSpecification(
}

return {
id: configuration.id,
name: "search_data_sources",
description:
"Search the data sources specified by the user for information to answer their request." +
Expand Down Expand Up @@ -312,7 +314,7 @@ export type RetrievalSuccessEvent = {
// error is expected to be stored by the caller on the parent agent message.
export async function* runRetrieval(
auth: Authenticator,
configuration: AgentConfigurationType,
configuration: AgentFullConfigurationType,
conversation: ConversationType,
userMessage: UserMessageType,
agentMessage: AgentMessageType
Expand All @@ -332,7 +334,7 @@ export async function* runRetrieval(
return yield {
type: "retrieval_error",
created: Date.now(),
configurationId: configuration.sId,
configurationId: configuration.agent.sId,
messageId: agentMessage.sId,
error: {
code: "internal_server_error",
Expand All @@ -352,7 +354,7 @@ export async function* runRetrieval(
return yield {
type: "retrieval_error",
created: Date.now(),
configurationId: configuration.sId,
configurationId: configuration.agent.sId,
messageId: agentMessage.sId,
error: {
code: "retrieval_parameters_generation_error",
Expand Down Expand Up @@ -434,7 +436,7 @@ export async function* runRetrieval(
return yield {
type: "retrieval_error",
created: Date.now(),
configurationId: configuration.sId,
configurationId: configuration.agent.sId,
messageId: agentMessage.sId,
error: {
code: "retrieval_search_error",
Expand All @@ -451,7 +453,7 @@ export async function* runRetrieval(
return yield {
type: "retrieval_error",
created: Date.now(),
configurationId: configuration.sId,
configurationId: configuration.agent.sId,
messageId: agentMessage.sId,
error: {
code: "retrieval_search_error",
Expand Down Expand Up @@ -530,15 +532,15 @@ export async function* runRetrieval(
yield {
type: "retrieval_documents",
created: Date.now(),
configurationId: configuration.sId,
configurationId: configuration.agent.sId,
messageId: agentMessage.sId,
documents,
};

yield {
type: "retrieval_success",
created: Date.now(),
configurationId: configuration.sId,
configurationId: configuration.agent.sId,
messageId: agentMessage.sId,
action: {
id: action.id,
Expand Down
291 changes: 0 additions & 291 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
@@ -1,35 +1,14 @@
import { Op, Transaction } from "sequelize";

import {
cloneBaseConfig,
DustProdActionRegistry,
} from "@app/lib/actions/registry";
import { runAction } from "@app/lib/actions/server";
import { _getAgentConfigurationType } from "@app/lib/api/assistant/agent_utils";
import { Authenticator } from "@app/lib/auth";
import { front_sequelize } from "@app/lib/databases";
import { DataSource, Workspace } from "@app/lib/models";
import {
AgentDataSourceConfiguration,
AgentRetrievalConfiguration,
} from "@app/lib/models/assistant/actions/retrieval";
import {
AgentConfiguration,
AgentGenerationConfiguration,
} from "@app/lib/models/assistant/agent";
import { Err, Ok, Result } from "@app/lib/result";
import { generateModelSId } from "@app/lib/utils";
import {
AgentDataSourceConfigurationType,
isTemplatedQuery,
isTimeFrame,
} from "@app/types/assistant/actions/retrieval";
import {
AgentActionConfigurationType,
AgentActionSpecification,
AgentConfigurationStatus,
AgentConfigurationType,
AgentGenerationConfigurationType,
} from "@app/types/assistant/agent";
import {
AgentActionType,
Expand All @@ -43,276 +22,6 @@ import {
} from "./actions/retrieval";
import { renderConversationForModel } from "./conversation";

/**
 * Get an agent configuration from its name.
 *
 * Looks up the agent by `name` inside the workspace attached to `auth`, then
 * loads its generation configuration, its retrieval (action) configuration and
 * the data source configurations attached to that retrieval configuration.
 *
 * @param auth Authenticator whose workspace scopes the lookup.
 * @param name Name of the agent configuration to fetch.
 * @returns Whatever `_getAgentConfigurationType` builds from the rows found, or
 * `undefined` when `auth` has no workspace or no agent matches `name`.
 */
export async function getAgentConfiguration(auth: Authenticator, name: string) {
  const owner = auth.workspace();
  if (!owner) {
    return;
  }

  const agent = await AgentConfiguration.findOne({
    where: {
      name: name,
      workspaceId: owner.id,
    },
  });
  // Bail out before the dependent lookups: without an agent row there is no
  // id to filter on, and querying with `agentId: undefined` is at best a
  // wasted round-trip (NOTE(review): some Sequelize versions reject
  // `undefined` values in a where clause - confirm for the version in use).
  if (!agent) {
    return;
  }

  // The generation and retrieval lookups only depend on the agent id, so they
  // can run concurrently.
  const [agentGeneration, agentAction] = await Promise.all([
    AgentGenerationConfiguration.findOne({
      where: {
        agentId: agent.id,
      },
    }),
    AgentRetrievalConfiguration.findOne({
      where: {
        agentId: agent.id,
      },
    }),
  ]);

  // Data source configurations hang off the retrieval configuration; with no
  // retrieval configuration there is nothing to load.
  const agentDataSources = agentAction?.id
    ? await AgentDataSourceConfiguration.findAll({
        where: {
          retrievalConfigurationId: agentAction.id,
        },
      })
    : [];

  return await _getAgentConfigurationType({
    agent: agent,
    generation: agentGeneration,
    action: agentAction,
    dataSources: agentDataSources,
  });
}

/**
 * Create a new agent configuration, together with its optional generation and
 * retrieval (action) configurations, inside a single database transaction.
 *
 * @param auth Authenticator; its workspace becomes the owner of the new agent.
 * @param params.name Name of the agent.
 * @param params.pictureUrl Optional picture URL, stored as `null` when absent.
 * @param params.action Optional retrieval action configuration to persist.
 * @param params.generation Optional generation configuration to persist.
 * @returns Whatever `_getAgentConfigurationType` builds from the created rows,
 * or nothing when `auth` has no workspace.
 */
export async function createAgentConfiguration(
  auth: Authenticator,
  {
    name,
    pictureUrl,
    action,
    generation,
  }: {
    name: string;
    pictureUrl?: string;
    action?: AgentActionConfigurationType;
    generation?: AgentGenerationConfigurationType;
  }
): Promise<AgentConfigurationType | void> {
  const owner = auth.workspace();
  if (!owner) {
    return;
  }

  return await front_sequelize.transaction(async (t) => {
    // The agent row comes first: the other rows reference its id.
    const agentRow = await AgentConfiguration.create(
      {
        sId: generateModelSId(),
        status: "active",
        name: name,
        pictureUrl: pictureUrl ?? null,
        scope: "workspace",
        workspaceId: owner.id,
      },
      { transaction: t }
    );

    // Generation configuration is optional.
    const generationRow: AgentGenerationConfiguration | null = generation
      ? await AgentGenerationConfiguration.create(
          {
            prompt: generation.prompt,
            modelProvider: generation.model.providerId,
            modelId: generation.model.modelId,
            agentId: agentRow.id,
          },
          { transaction: t }
        )
      : null;

    // Retrieval configuration and its data sources are optional as well.
    let retrievalRow: AgentRetrievalConfiguration | null = null;
    let dataSourceRows: AgentDataSourceConfiguration[] = [];
    if (action) {
      const { query, relativeTimeFrame: timeframe } = action;
      retrievalRow = await AgentRetrievalConfiguration.create(
        {
          // A templated query is stored as the "templated" marker plus its
          // template; any other query value is stored as-is.
          query: isTemplatedQuery(query) ? "templated" : query,
          queryTemplate: isTemplatedQuery(query) ? query.template : null,
          // A concrete time frame is stored as "custom" plus duration/unit.
          relativeTimeFrame: isTimeFrame(timeframe) ? "custom" : timeframe,
          relativeTimeFrameDuration: isTimeFrame(timeframe)
            ? timeframe.duration
            : null,
          relativeTimeFrameUnit: isTimeFrame(timeframe) ? timeframe.unit : null,
          topK: action.topK,
          agentId: agentRow.id,
        },
        { transaction: t }
      );
      dataSourceRows = await _createAgentDataSourcesConfigData(
        t,
        action.dataSources,
        retrievalRow.id
      );
    }

    return await _getAgentConfigurationType({
      agent: agentRow,
      action: retrievalRow,
      generation: generationRow,
      dataSources: dataSourceRows,
    });
  });
}

/**
 * Create the AgentDataSourceConfiguration rows in database.
 *
 * A data source is uniquely identified by its name and its workspaceId, while
 * the incoming configuration identifies the workspace by its sId. We therefore
 * resolve workspace sIds to ids first, then fetch the data sources with as few
 * queries as possible (one per workspace).
 *
 * Every configuration entry is resolved to a data source BEFORE any row is
 * inserted, so a missing workspace or data source fails fast without leaving
 * INSERTs in flight on a transaction the caller is about to roll back.
 *
 * @param t Transaction in which the rows are created.
 * @param dataSourcesConfig Data source configurations to persist.
 * @param agentActionId Id of the retrieval configuration the rows belong to.
 * @returns The created rows, in the same order as `dataSourcesConfig`.
 * @throws Error("Workspace not found") when a workspaceSId matches no workspace.
 * @throws Error("DataSource not found") when a (workspace, name) pair matches
 * no data source.
 */
async function _createAgentDataSourcesConfigData(
  t: Transaction,
  dataSourcesConfig: AgentDataSourceConfigurationType[],
  agentActionId: number
): Promise<AgentDataSourceConfiguration[]> {
  // dataSourcesConfig contains this format:
  // [
  //   { workspaceSId: s1o1u1p, dataSourceName: "managed-notion", filter: { tags: null, parents: null } },
  //   { workspaceSId: s1o1u1p, dataSourceName: "managed-slack", filter: { tags: null, parents: null } },
  //   { workspaceSId: i2n2o2u, dataSourceName: "managed-notion", filter: { tags: null, parents: null } },
  // ]

  // Nothing to create: skip the lookups entirely.
  if (dataSourcesConfig.length === 0) {
    return [];
  }

  // First we get the workspaces, because we need the mapping between
  // workspaceSId and workspaceId.
  const workspaces = await Workspace.findAll({
    where: {
      sId: Array.from(
        new Set(dataSourcesConfig.map((dsConfig) => dsConfig.workspaceSId))
      ),
    },
    attributes: ["id", "sId"],
  });
  const workspaceIdBySId = new Map<string, number>();
  for (const workspace of workspaces) {
    if (!workspaceIdBySId.has(workspace.sId)) {
      workspaceIdBySId.set(workspace.sId, workspace.id);
    }
  }

  // Group the data source names by workspaceId so that we run only one query
  // per workspace. We want this:
  //   1 => ["managed-notion", "managed-slack"]
  //   2 => ["managed-notion"]
  const dataSourceNamesByWorkspaceId = new Map<number, string[]>();
  for (const dsConfig of dataSourcesConfig) {
    const workspaceId = workspaceIdBySId.get(dsConfig.workspaceSId);
    if (workspaceId === undefined) {
      throw new Error("Workspace not found");
    }
    const names = dataSourceNamesByWorkspaceId.get(workspaceId);
    if (names) {
      names.push(dsConfig.dataSourceName);
    } else {
      dataSourceNamesByWorkspaceId.set(workspaceId, [dsConfig.dataSourceName]);
    }
  }

  // Then we run one findAll query per workspaceId, in a Promise.all.
  const results = await Promise.all(
    Array.from(dataSourceNamesByWorkspaceId.entries()).map(
      ([workspaceId, dataSourceNames]) =>
        DataSource.findAll({
          where: {
            workspaceId,
            name: {
              [Op.in]: dataSourceNames,
            },
          },
        })
    )
  );
  const dataSources = results.flat();

  // Index the data sources by (workspaceId, name). The workspaceId is numeric,
  // so the first ":" always separates it from the name unambiguously.
  const keyOf = (workspaceId: number, dataSourceName: string) =>
    `${workspaceId}:${dataSourceName}`;
  const dataSourceByKey = new Map<string, (typeof dataSources)[number]>();
  for (const ds of dataSources) {
    const key = keyOf(ds.workspaceId, ds.name);
    // Keep the first match for a given key.
    if (!dataSourceByKey.has(key)) {
      dataSourceByKey.set(key, ds);
    }
  }

  // Resolve every configuration entry before inserting anything.
  const resolved = dataSourcesConfig.map((dsConfig) => {
    const workspaceId = workspaceIdBySId.get(dsConfig.workspaceSId);
    const dataSource =
      workspaceId === undefined
        ? undefined
        : dataSourceByKey.get(keyOf(workspaceId, dsConfig.dataSourceName));
    if (!dataSource) {
      throw new Error("DataSource not found");
    }
    return { dsConfig, dataSource };
  });

  const agentDataSourcesConfigRows: AgentDataSourceConfiguration[] =
    await Promise.all(
      resolved.map(({ dsConfig, dataSource }) =>
        AgentDataSourceConfiguration.create(
          {
            dataSourceId: dataSource.id,
            tagsIn: dsConfig.filter.tags?.in,
            tagsNotIn: dsConfig.filter.tags?.not,
            parentsIn: dsConfig.filter.parents?.in,
            parentsNotIn: dsConfig.filter.parents?.not,
            retrievalConfigurationId: agentActionId,
          },
          { transaction: t }
        )
      )
    );
  return agentDataSourcesConfigRows;
}

/**
 * Update an agent configuration.
 *
 * NOTE(review): this body performs no database write and never reads `auth`;
 * it only echoes the provided fields back as an `AgentConfigurationType`.
 * Presumably a placeholder until persistence is implemented - confirm.
 *
 * @param auth Authenticator of the caller (currently unused).
 * @param configurationId sId of the agent configuration being updated.
 * @param params.name New name of the agent.
 * @param params.pictureUrl Optional picture URL, returned as `null` when absent.
 * @param params.status New status of the agent.
 * @param params.action Optional action configuration, `null` when absent.
 * @param params.generation Optional generation configuration, `null` when absent.
 * @returns The configuration identified by `configurationId`, carrying the
 * provided field values.
 */
export async function updateAgentConfiguration(
  auth: Authenticator,
  configurationId: string,
  {
    name,
    pictureUrl,
    status,
    action,
    generation,
  }: {
    name: string;
    pictureUrl?: string;
    status: AgentConfigurationStatus;
    action?: AgentActionConfigurationType;
    generation?: AgentGenerationConfigurationType;
  }
): Promise<AgentConfigurationType> {
  return {
    // An update must keep the identity of the configuration it targets;
    // minting a fresh sId here would hand callers an id unrelated to their
    // request.
    sId: configurationId,
    name,
    pictureUrl: pictureUrl ?? null,
    status,
    action: action ?? null,
    generation: generation ?? null,
  };
}

/**
* Action Inputs generation.
*/
Expand Down
Loading

0 comments on commit d2f9d8e

Please sign in to comment.