runGeneration implementation and associated Dust app #1302

Merged: 5 commits, Sep 7, 2023 (changes shown from 4 commits)
17 changes: 17 additions & 0 deletions front/lib/actions/registry.ts
@@ -44,6 +44,23 @@ export const DustProdActionRegistry = createActionRegistry({
},
},
},
"assistant-v2-generator": {
app: {
workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID,
appId: "6a27050429",
appHash:
"356e16f5254284cc1c08512bebf9638bbc3e94eb5b29ac27599ccce7bee7843c",
},
config: {
MODEL: {
provider_id: "openai",
model_id: "gpt-4",
function_call: null,
use_cache: false,
use_stream: true,
},
},
},

"chat-retrieval": {
app: {
6 changes: 4 additions & 2 deletions front/lib/api/assistant/actions/retrieval.ts
@@ -4,7 +4,6 @@ import {
} from "@app/lib/actions/registry";
import { runAction } from "@app/lib/actions/server";
import { generateActionInputs } from "@app/lib/api/assistant/agent";
import { ModelMessageType } from "@app/lib/api/assistant/conversation";
import { Authenticator, prodAPICredentialsForOwner } from "@app/lib/auth";
import { front_sequelize } from "@app/lib/databases";
import { DustAPI } from "@app/lib/dust_api";
@@ -14,6 +13,7 @@ import {
RetrievalDocumentChunk,
} from "@app/lib/models";
import { Err, Ok, Result } from "@app/lib/result";
import { new_id } from "@app/lib/utils";
import logger from "@app/logger/logger";
import {
DataSourceConfiguration,
@@ -33,6 +33,8 @@ import {
UserMessageType,
} from "@app/types/assistant/conversation";

import { ModelMessageType } from "../generation";
Contributor: NIT: relative import


/**
* TimeFrame parsing
*/
@@ -507,7 +509,7 @@ export async function* runRetrieval(
id: 0, // dummy pending database insertion
dataSourceId: d.data_source_id,
documentId: d.document_id,
reference: "",
reference: new_id().slice(0, 3),
Contributor: NIT: In the type, this field is described as: "Short random string so that the model can refer to the document." Since "reference" is not self-explanatory, what do you think about calling it sId and slicing it to 10 chars?

Contributor Author: We want it to be much shorter. The idea is for the model (with proper prompting) to generate references such as [a32] or [ef2]. We want it short so that the generation is snappy and we can render it as a nice reference in the UI.

This is quite WIP of course.
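
A minimal sketch of the idea, outside the diff (the helper and type names below are hypothetical; only the `new_id().slice(0, 3)` pattern mirrors the actual change): each retrieved document gets a short random reference that the prompt can then ask the model to cite inline.

```typescript
// Sketch only. `newId` stands in for the repo's `new_id()` helper and is assumed
// to return a random hex string; `RetrievedDocument` is a simplified shape.
import { randomBytes } from "crypto";

function newId(): string {
  return randomBytes(16).toString("hex");
}

type RetrievedDocument = { documentId: string; reference: string };

function withReferences(documentIds: string[]): RetrievedDocument[] {
  return documentIds.map((documentId) => ({
    documentId,
    // A 3-character slice keeps citations like [a32] cheap for the model to
    // generate and easy to render as compact references in the UI.
    reference: newId().slice(0, 3),
  }));
}

// Usage: the prompt would instruct the model to cite documents as [<reference>].
const docs = withReferences(["doc-1", "doc-2"]);
console.log(docs.map((d) => `[${d.reference}] -> ${d.documentId}`));
```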

timestamp: d.timestamp,
tags: d.tags,
sourceUrl: d.source_url ?? null,
27 changes: 15 additions & 12 deletions front/lib/api/assistant/agent.ts
@@ -3,6 +3,14 @@ import {
DustProdActionRegistry,
} from "@app/lib/actions/registry";
import { runAction } from "@app/lib/actions/server";
import {
RetrievalDocumentsEvent,
RetrievalParamsEvent,
} from "@app/lib/api/assistant/actions/retrieval";
import {
GenerationTokensEvent,
renderConversationForModel,
} from "@app/lib/api/assistant/generation";
import { Authenticator } from "@app/lib/auth";
import { Err, Ok, Result } from "@app/lib/result";
import { generateModelSId } from "@app/lib/utils";
@@ -19,12 +27,6 @@ import {
ConversationType,
} from "@app/types/assistant/conversation";

import {
RetrievalDocumentsEvent,
RetrievalParamsEvent,
} from "./actions/retrieval";
import { renderConversationForModel } from "./conversation";

/**
* Agent configuration.
*/
@@ -187,18 +189,18 @@ export type AgentActionSuccessEvent = {
action: AgentActionType;
};

// Event sent when tokens are streamed as the the agent is generating a message.
export type AgentGenerationTokensEvent = {
type: "agent_generation_tokens";
// Event sent once the generation is completed.
export type AgentGenerationSuccessEvent = {
type: "agent_generation_success";
created: number;
configurationId: string;
messageId: string;
text: string;
};

// Event sent once the message is completed and successful.
export type AgentGenerationSuccessEvent = {
type: "agent_generation_success";
export type AgentSuccessEvent = {
Contributor: NIT: Looking at the comment, should it be "AgentMessageSuccessEvent"? (Asking because "once the message is completed and successful" means it is attached to the generation, right?)

Or will we use this event also after a successful action, in which case it's the comment that needs to be updated?

Contributor Author: Happy to call it AgentMessageSuccessEvent 👍 The idea is for this one to be the final success event of the agent, meaning that everything is ready and done. Since we have succeeded on the AgentMessage, let's add Message to the event name here 👍

type: "agent_success";
created: number;
configurationId: string;
generationId: string;
@@ -217,8 +219,9 @@ export async function* runAgent(
| AgentErrorEvent
| AgentActionEvent
| AgentActionSuccessEvent
| AgentGenerationTokensEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentSuccessEvent
> {
yield {
type: "agent_error",
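
An aside on the event plumbing updated above: a minimal sketch of how a caller might consume the stream yielded by runAgent. The event shapes below are simplified assumptions based only on the fields visible in this diff, not the actual definitions in the codebase, and the "generation_tokens" type string is an assumption as well.

```typescript
// Assumed, simplified shapes for the events touched in this diff.
type GenerationTokensEvent = { type: "generation_tokens"; text: string };
type AgentGenerationSuccessEvent = { type: "agent_generation_success"; text: string };
type AgentSuccessEvent = { type: "agent_success"; generationId: string };
type AgentErrorEvent = { type: "agent_error"; message: string };

type AgentEvent =
  | GenerationTokensEvent
  | AgentGenerationSuccessEvent
  | AgentSuccessEvent
  | AgentErrorEvent;

// Streams tokens as they arrive, then settles on the final success or error event.
async function consume(events: AsyncIterable<AgentEvent>): Promise<void> {
  let text = "";
  for await (const event of events) {
    switch (event.type) {
      case "generation_tokens":
        text += event.text; // incremental rendering in the UI
        break;
      case "agent_generation_success":
        text = event.text; // full text of the completed generation
        break;
      case "agent_success":
        console.log(`agent done (generation ${event.generationId}): ${text}`);
        break;
      case "agent_error":
        console.error(`agent error: ${event.message}`);
        break;
    }
  }
}
```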
137 changes: 8 additions & 129 deletions front/lib/api/assistant/conversation.ts
@@ -3,147 +3,23 @@ import {
AgentActionSuccessEvent,
AgentErrorEvent,
AgentGenerationSuccessEvent,
AgentGenerationTokensEvent,
AgentMessageNewEvent,
AgentSuccessEvent,
} from "@app/lib/api/assistant/agent";
import { Authenticator } from "@app/lib/auth";
import { CoreAPI } from "@app/lib/core_api";
import { front_sequelize } from "@app/lib/databases";
import { AgentMessage, Message, UserMessage } from "@app/lib/models";
import { Err, Ok, Result } from "@app/lib/result";
import { generateModelSId } from "@app/lib/utils";
import logger from "@app/logger/logger";
import { isRetrievalActionType } from "@app/types/assistant/actions/retrieval";
import {
AgentMessageType,
ConversationType,
isAgentMention,
isAgentMessageType,
isUserMessageType,
Mention,
UserMessageContext,
UserMessageType,
} from "@app/types/assistant/conversation";

import { renderRetrievalActionForModel } from "./actions/retrieval";

/**
* Model rendering of conversations.
*/

export type ModelMessageType = {
role: "action" | "agent" | "user";
name: string;
content: string;
};

export type ModelConversationType = {
messages: ModelMessageType[];
};

// This function transforms a conversation in a simplified format that we feed the model as context.
// It takes care of truncating the conversation all the way to `allowedTokenCount` tokens.
export async function renderConversationForModel({
conversation,
model,
allowedTokenCount,
}: {
conversation: ConversationType;
model: { providerId: string; modelId: string };
allowedTokenCount: number;
}): Promise<Result<ModelConversationType, Error>> {
const messages = [];

let retrievalFound = false;

// Render all messages and all actions but only keep the latest retrieval action.
for (let i = conversation.content.length - 1; i >= 0; i--) {
const versions = conversation.content[i];
const m = versions[versions.length - 1];

if (isAgentMessageType(m)) {
if (m.action) {
if (isRetrievalActionType(m.action) && !retrievalFound) {
messages.unshift(renderRetrievalActionForModel(m.action));
retrievalFound = true;
} else {
return new Err(
new Error(
"Unsupported action type during conversation model rendering"
)
);
}
}
if (m.message) {
messages.unshift({
role: "agent" as const,
name: m.configuration.name,
content: m.message,
});
}
}
if (isUserMessageType(m)) {
messages.unshift({
role: "user" as const,
name: m.context.username,
content: m.message,
});
}
}

async function tokenCountForMessage(
message: ModelMessageType,
model: { providerId: string; modelId: string }
): Promise<Result<number, Error>> {
const res = await CoreAPI.tokenize({
text: message.content,
providerId: model.providerId,
modelId: model.modelId,
});

if (res.isErr()) {
return new Err(new Error(`Error tokenizing model message: ${res.error}`));
}

return new Ok(res.value.tokens.length);
}

const now = Date.now();

// This is a bit aggressive but fuck it.
const tokenCountRes = await Promise.all(
messages.map((m) => {
return tokenCountForMessage(m, model);
})
);

logger.info(
{
messageCount: messages.length,
elapsed: Date.now() - now,
},
"[ASSISTANT_STATS] message token counts for model conversation rendering"
);

// Go backward and accumulate as much as we can within allowedTokenCount.
const selected = [];
let tokensUsed = 0;
for (let i = messages.length - 1; i >= 0; i--) {
const r = tokenCountRes[i];
if (r.isErr()) {
return new Err(r.error);
}
const c = r.value;
if (tokensUsed + c <= allowedTokenCount) {
tokensUsed += c;
selected.unshift(messages[i]);
}
}

return new Ok({
messages: selected,
});
}
import { GenerationTokensEvent } from "./generation";
Contributor: NIT: relative import


/**
* Conversation API
@@ -176,8 +52,9 @@ export async function* postUserMessage(
| AgentErrorEvent
| AgentActionEvent
| AgentActionSuccessEvent
| AgentGenerationTokensEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentSuccessEvent
> {
const user = auth.user();

@@ -320,8 +197,9 @@ export async function* retryAgentMessage(
| AgentErrorEvent
| AgentActionEvent
| AgentActionSuccessEvent
| AgentGenerationTokensEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentSuccessEvent
> {
yield {
type: "agent_error",
@@ -354,8 +232,9 @@ export async function* editUserMessage(
| AgentErrorEvent
| AgentActionEvent
| AgentActionSuccessEvent
| AgentGenerationTokensEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentSuccessEvent
> {
yield {
type: "agent_error",