Skip to content

Commit

Permalink
Assistant: Lib function createAgentConfiguration
Browse files Browse the repository at this point in the history
  • Loading branch information
PopDaph committed Sep 7, 2023
1 parent 57e0182 commit 8974e81
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 82 deletions.
69 changes: 17 additions & 52 deletions front/lib/api/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -389,66 +389,31 @@ export async function* runRetrieval(
);

// Handle data sources list and parents/tags filtering.
if (c.dataSources === "all") {
const prodCredentials = await prodAPICredentialsForOwner(owner);
const api = new DustAPI(prodCredentials);
config.DATASOURCE.data_sources = c.dataSources.map((d) => ({
workspace_id: d.workspaceId,
name: d.name,
}));

for (const ds of c.dataSources) {
if (ds.filter.tags) {
if (!config.DATASOURCE.filter.tags) {
config.DATASOURCE.filter.tags = { in: [], not: [] };
}

const dsRes = await api.getDataSources(prodCredentials.workspaceId);
if (dsRes.isErr()) {
return yield {
type: "retrieval_error",
created: Date.now(),
configurationId: configuration.sId,
messageId: agentMessage.sId,
error: {
code: "retrieval_data_sources_error",
message: `Error retrieving workspace data sources: ${dsRes.error.message}`,
},
};
config.DATASOURCE.filter.tags.in.push(...ds.filter.tags.in);
config.DATASOURCE.filter.tags.not.push(...ds.filter.tags.not);
}

const ds = dsRes.value.filter((d) => d.assistantDefaultSelected);

config.DATASOURCE.data_sources = ds.map((d) => {
return {
workspace_id: prodCredentials.workspaceId,
data_source_id: d.name,
};
});
} else {
config.DATASOURCE.data_sources = c.dataSources.map((d) => ({
workspace_id: d.workspaceId,
data_source_id: d.dataSourceId,
}));

for (const ds of c.dataSources) {
if (ds.filter.tags) {
if (!config.DATASOURCE.filter.tags) {
config.DATASOURCE.filter.tags = { in: [], not: [] };
}

config.DATASOURCE.filter.tags.in.push(...ds.filter.tags.in);
config.DATASOURCE.filter.tags.not.push(...ds.filter.tags.not);
if (ds.filter.parents) {
if (!config.DATASOURCE.filter.parents) {
config.DATASOURCE.filter.parents = { in: [], not: [] };
}

if (ds.filter.parents) {
if (!config.DATASOURCE.filter.parents) {
config.DATASOURCE.filter.parents = { in: [], not: [] };
}

config.DATASOURCE.filter.parents.in.push(...ds.filter.parents.in);
config.DATASOURCE.filter.parents.not.push(...ds.filter.parents.not);
}
config.DATASOURCE.filter.parents.in.push(...ds.filter.parents.in);
config.DATASOURCE.filter.parents.not.push(...ds.filter.parents.not);
}
}

// Handle timestamp filtering.
if (params.relativeTimeFrame) {
config.DATASOURCE.filter.timestamp = {
gt: timeFrameFromNow(params.relativeTimeFrame),
};
}

// Handle top k.
config.DATASOURCE.top_k = params.topK;

Expand Down
122 changes: 113 additions & 9 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,31 @@ import {
} from "@app/lib/actions/registry";
import { runAction } from "@app/lib/actions/server";
import { Authenticator } from "@app/lib/auth";
import { front_sequelize } from "@app/lib/databases";
import { DataSource } from "@app/lib/models";
import {
AgentDataSourceConfiguration,
AgentRetrievalConfiguration,
} from "@app/lib/models/assistant/actions/retrieval";
import {
AgentConfiguration,
AgentGenerationConfiguration,
} from "@app/lib/models/assistant/agent";
import { Err, Ok, Result } from "@app/lib/result";
import { generateModelSId } from "@app/lib/utils";
import {
isTemplatedQuery,
isTimeFrame,
RetrievalDataSourcesConfiguration,
} from "@app/types/assistant/actions/retrieval";
import {
AgentActionConfigurationType,
AgentActionSpecification,
AgentConfigurationStatus,
AgentConfigurationType,
AgentGenerationConfigurationType,
} from "@app/types/assistant/agent";
import { _getAgentConfigurationType } from "@app/types/assistant/agent-utils";
import {
AgentActionType,
AgentMessageType,
Expand Down Expand Up @@ -42,15 +58,103 @@ export async function createAgentConfiguration(
action?: AgentActionConfigurationType;
generation?: AgentGenerationConfigurationType;
}
): Promise<AgentConfigurationType> {
return {
sId: generateModelSId(),
name,
pictureUrl: pictureUrl ?? null,
status: "active",
action: action ?? null,
generation: generation ?? null,
};
): Promise<AgentConfigurationType | void> {
const owner = auth.workspace();

if (!owner) {
return;
}

return await front_sequelize.transaction(async (t) => {
// Create AgentConfiguration
const agentConfigRow = await AgentConfiguration.create(
{
sId: generateModelSId(),
status: "active",
name: name,
pictureUrl: pictureUrl ?? null,
scope: "workspace",
workspaceId: owner.id,
},
{ transaction: t }
);

let agentGenerationConfigRow: AgentGenerationConfiguration | null = null;
let agentActionConfigRow: AgentRetrievalConfiguration | null = null;
const agentDataSourcesConfigRows: AgentDataSourceConfiguration[] = [];

// Create AgentGenerationConfiguration
if (generation) {
agentGenerationConfigRow = await AgentGenerationConfiguration.create(
{
prompt: generation.prompt,
modelProvider: generation.model.providerId,
modelId: generation.model.modelId,
agentId: agentConfigRow.id,
},
{ transaction: t }
);
}

// Create AgentRetrievalConfiguration & associated AgentDataSourceConfiguration
if (action) {
const query = action.query;
const timeframe = action.relativeTimeFrame;

agentActionConfigRow = await AgentRetrievalConfiguration.create(
{
query: isTemplatedQuery(query) ? "templated" : query,
queryTemplate: isTemplatedQuery(query) ? query.template : null,
relativeTimeFrame: isTimeFrame(timeframe) ? "custom" : timeframe,
relativeTimeFrameDuration: isTimeFrame(timeframe)
? timeframe.duration
: null,
relativeTimeFrameUnit: isTimeFrame(timeframe) ? timeframe.unit : null,
topK: action.topK,
agentId: agentConfigRow.id,
},
{ transaction: t }
);

if (!agentActionConfigRow) {
return;
}

if (action.dataSources) {
let dsRow: AgentDataSourceConfiguration | null = null;
action.dataSources.map(async (d) => {
const ds = await DataSource.findOne({
where: {
name: d.name,
workspaceId: d.workspaceId,
},
});

if (ds && agentActionConfigRow) {
dsRow = await AgentDataSourceConfiguration.create(
{
dataSourceId: ds.id,
tagsIn: d.filter.tags?.in,
tagsNotIn: d.filter.tags?.not,
parentsIn: d.filter.parents?.in,
parentsNotIn: d.filter.parents?.not,
retrievalConfigurationId: agentActionConfigRow.id,
},
{ transaction: t }
);
agentDataSourcesConfigRows.push(dsRow);
}
});
}
}

return _getAgentConfigurationType({
agent: agentConfigRow,
action: agentActionConfigRow,
generation: agentGenerationConfigRow,
dataSources: agentDataSourcesConfigRows,
});
});
}

export async function updateAgentConfiguration(
Expand Down
25 changes: 9 additions & 16 deletions front/lib/models/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,6 @@ export class AgentDataSourceConfiguration extends Model<
declare createdAt: CreationOptional<Date>;
declare updatedAt: CreationOptional<Date>;

declare timeframeDuration: number | null;
declare timeframeUnit: TimeframeUnit | null;

declare tagsIn: string[] | null;
declare tagsNotIn: string[] | null;
declare parentsIn: string[] | null;
Expand Down Expand Up @@ -147,14 +144,6 @@ AgentDataSourceConfiguration.init(
allowNull: false,
defaultValue: DataTypes.NOW,
},
timeframeDuration: {
type: DataTypes.INTEGER,
allowNull: true,
},
timeframeUnit: {
type: DataTypes.STRING,
allowNull: true,
},
tagsIn: {
type: DataTypes.ARRAY(DataTypes.STRING),
allowNull: true,
Expand All @@ -178,12 +167,16 @@ AgentDataSourceConfiguration.init(
hooks: {
beforeValidate: (dataSourceConfig: AgentDataSourceConfiguration) => {
if (
(dataSourceConfig.timeframeDuration === null) !==
(dataSourceConfig.timeframeUnit === null)
(dataSourceConfig.tagsIn === null) !==
(dataSourceConfig.tagsNotIn === null)
) {
throw new Error(
"Timeframe duration/unit must be both set or both null"
);
throw new Error("Tags must be both set or both null");
}
if (
(dataSourceConfig.parentsIn === null) !==
(dataSourceConfig.parentsNotIn === null)
) {
throw new Error("Parents must be both set or both null");
}
},
},
Expand Down
24 changes: 19 additions & 5 deletions front/types/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ export type DataSourceFilter = {
parents: { in: string[]; not: string[] } | null;
};

// DataSources have a unique pair (name, workspaceId)
export type DataSourceConfiguration = {
workspaceId: string;
dataSourceId: string;
workspaceId: ModelId;
name: string;
filter: DataSourceFilter;
};

Expand All @@ -30,6 +31,15 @@ export type DataSourceConfiguration = {
export type TemplatedQuery = {
template: string;
};
export function isTemplatedQuery(arg: RetrievalQuery): arg is TemplatedQuery {
return (arg as TemplatedQuery).template !== undefined;
}
export function isTimeFrame(arg: RetrievalTimeframe): arg is TimeFrame {
return (
(arg as TimeFrame).duration !== undefined &&
(arg as TimeFrame).unit !== undefined
);
}

// Retrieval specifies a list of data sources (with possible parent / tags filtering, possible "all"
// data sources), a query ("auto" generated by the model "none", no query, `TemplatedQuery`, fixed
Expand All @@ -39,13 +49,17 @@ export type TemplatedQuery = {
// `query` and `relativeTimeFrame` will be used to generate the inputs specification for the model
// in charge of generating the action inputs. The results will be used along with `topK` and
// `dataSources` to query the data.
export type RetrievalTimeframe = "auto" | "none" | TimeFrame;
export type RetrievalQuery = "auto" | "none" | TemplatedQuery;
export type RetrievalDataSourcesConfiguration = DataSourceConfiguration[];

export type RetrievalConfigurationType = {
id: ModelId;

type: "retrieval_configuration";
dataSources: "all" | DataSourceConfiguration[];
query: "auto" | "none" | TemplatedQuery;
relativeTimeFrame: "auto" | "none" | TimeFrame;
dataSources: RetrievalDataSourcesConfiguration;
query: RetrievalQuery;
relativeTimeFrame: RetrievalTimeframe;
topK: number;

// Dynamically decide to skip, if needed in the future
Expand Down
Loading

0 comments on commit 8974e81

Please sign in to comment.