Skip to content

Commit

Permalink
more progress
Browse files Browse the repository at this point in the history
  • Loading branch information
spolu committed Sep 9, 2023
1 parent 175abdc commit c2fd7a4
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 51 deletions.
78 changes: 74 additions & 4 deletions front/lib/api/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ import {
import { runAction } from "@app/lib/actions/server";
import { generateActionInputs } from "@app/lib/api/assistant/agent";
import { ModelMessageType } from "@app/lib/api/assistant/generation";
import { Authenticator, prodAPICredentialsForOwner } from "@app/lib/auth";
import { front_sequelize } from "@app/lib/databases";
import { DustAPI } from "@app/lib/dust_api";
import { Authenticator } from "@app/lib/auth";
import { front_sequelize, ModelId } from "@app/lib/databases";
import {
AgentRetrievalAction,
RetrievalDocument,
Expand Down Expand Up @@ -263,6 +262,78 @@ export async function generateRetrievalParams(
});
}

/**
* Action rendering.
*/

// Internal interface for the retrieval and rendering of a retrieval action. This should no be
// outside of api/assistant. We allow a ModelId interface here to save a round trip to the database.
export async function renderRetrievalActionByModelId(
id: ModelId
): Promise<RetrievalActionType> {
const action = await AgentRetrievalAction.findByPk(id);
if (!action) {
throw new Error(`No retrieval action found with id ${id}`);
}

const documentRows = await RetrievalDocument.findAll({
where: {
retrievalActionId: action.id,
},
});

const chunkRows = await RetrievalDocumentChunk.findAll({
where: {
retrievalDocumentId: documentRows.map((d) => d.id),
},
});

let relativeTimeFrame: TimeFrame | null = null;
if (action.relativeTimeFrameDuration && action.relativeTimeFrameUnit) {
relativeTimeFrame = {
duration: action.relativeTimeFrameDuration,
unit: action.relativeTimeFrameUnit,
};
}

const documents: RetrievalDocumentType[] = documentRows.map((d) => {
const chunks = chunkRows
.filter((c) => c.retrievalDocumentId === d.id)
.map((c) => ({
text: c.text,
offset: c.offset,
score: c.score,
}));

chunks.sort((a, b) => b.score - a.score);

return {
id: d.id,
dataSourceId: d.dataSourceId,
sourceUrl: d.sourceUrl,
documentId: d.documentId,
reference: d.reference,
timestamp: d.timestamp,
tags: d.tags,
score: d.score,
chunks,
};
});

documents.sort((a, b) => b.score - a.score);

return {
id: action.id,
type: "retrieval_action",
params: {
query: action.query,
relativeTimeFrame,
topK: action.topK,
},
documents,
};
}

/**
* Action execution.
*/
Expand Down Expand Up @@ -542,7 +613,6 @@ export async function* runRetrieval(
id: action.id,
type: "retrieval_action",
params: {
dataSources: c.dataSources,
relativeTimeFrame: params.relativeTimeFrame,
query: params.query,
topK: params.topK,
Expand Down
75 changes: 29 additions & 46 deletions front/lib/api/assistant/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
AgentMessageSuccessEvent,
runAgent,
} from "@app/lib/api/assistant/agent";
import { getAgentConfiguration } from "@app/lib/api/assistant/configuration";
import { GenerationTokensEvent } from "@app/lib/api/assistant/generation";
import { Authenticator } from "@app/lib/auth";
import { front_sequelize } from "@app/lib/databases";
Expand All @@ -31,6 +32,8 @@ import {
UserMessageType,
} from "@app/types/assistant/conversation";

import { renderRetrievalActionByModelId } from "./actions/retrieval";

/**
* Conversation Creation and Update
*/
Expand Down Expand Up @@ -159,11 +162,9 @@ async function renderAgentMessage(
}),
(async () => {
if (agentMessage.agentRetrievalActionId) {
return await AgentRetrievalAction.findOne({
where: {
id: agentMessage.agentRetrievalActionId,
},
});
return await renderRetrievalActionByModelId(
agentMessage.agentRetrievalActionId
);
}
return null;
})(),
Expand All @@ -175,6 +176,11 @@ async function renderAgentMessage(
);
}

const configuration = await getAgentConfiguration(
auth,
agentConfiguration.sId
);

return {
id: message.id,
sId: message.sId,
Expand All @@ -183,31 +189,11 @@ async function renderAgentMessage(
version: message.version,
parentMessageId: null,
status: agentMessage.status,
action: agentRetrievalAction
? {
id: agentRetrievalAction.id,
type: "retrieval_action",
params: {
dataSources: [], // TODO
query: agentRetrievalAction.query,
relativeTimeFrame: null, // TODO
topK: agentRetrievalAction.topK,
},
documents: [], // TODO
}
: null,
action: agentRetrievalAction,
message: agentMessage.message,
feedbacks: [],
error: null,
configuration: {
sId: agentConfiguration.sId,
status: "active",
name: agentConfiguration.name,
pictureUrl: agentConfiguration.pictureUrl,
// TODO(spolu)
action: null,
generation: null,
},
configuration,
};
}

Expand Down Expand Up @@ -443,19 +429,24 @@ export async function* postUserMessage(
// for each assistant mention, create an "empty" agent message
for (const mention of mentions) {
if (isAgentMention(mention)) {
// TODO(spolu): retrieve configuration from mention.
// Mention.create({
// messageId: m.id,
// configurationId: mention.configurationId,
// });
const configuration = await getAgentConfiguration(
auth,
mention.configurationId
);

await Mention.create({
messageId: m.id,
agentConfigurationId: configuration.id,
});

const agentMessageRow = await AgentMessage.create(
{
// TODO(spolu): add agentConfigurationId
status: "created",
agentConfigurationId: configuration.id,
},
{ transaction: t }
);
const m = await Message.create(
const messageRow = await Message.create(
{
sId: generateModelSId(),
rank: nextMessageRank++,
Expand All @@ -467,10 +458,11 @@ export async function* postUserMessage(
transaction: t,
}
);

agentMessageRows.push(agentMessageRow);
agentMessages.push({
id: m.id,
sId: m.sId,
id: messageRow.id,
sId: messageRow.sId,
type: "agent_message",
visibility: "visible",
version: 0,
Expand All @@ -480,14 +472,7 @@ export async function* postUserMessage(
message: null,
feedbacks: [],
error: null,
configuration: {
sId: mention.configurationId,
status: "active",
name: "foo", // TODO
pictureUrl: null, // TODO
action: null, // TODO
generation: null, // TODO
},
configuration,
});
}

Expand Down Expand Up @@ -524,8 +509,6 @@ export async function* postUserMessage(

await Promise.allSettled(
agentMessages.map(async function* (agentMessage, i) {
//for (let i = 0; i < agentMessages.length; i++) {
//const agentMessage = agentMessages[i];
const agentMessageRow = agentMessageRows[i];

yield {
Expand Down
1 change: 0 additions & 1 deletion front/types/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ export type RetrievalActionType = {
id: ModelId; // AgentRetrieval.
type: "retrieval_action";
params: {
dataSources: "all" | DataSourceConfiguration[];
relativeTimeFrame: TimeFrame | null;
query: string | null;
topK: number;
Expand Down
1 change: 1 addition & 0 deletions front/types/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ export type AgentConfigurationStatus = "active" | "archived";
export type AgentConfigurationScope = "global" | "workspace";

export type AgentConfigurationType = {
id: ModelId;
sId: string;
status: AgentConfigurationStatus;

Expand Down

0 comments on commit c2fd7a4

Please sign in to comment.