-
Notifications
You must be signed in to change notification settings - Fork 112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
runGeneration
implementation and associated Dust app
#1302
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,6 @@ import { | |
} from "@app/lib/actions/registry"; | ||
import { runAction } from "@app/lib/actions/server"; | ||
import { generateActionInputs } from "@app/lib/api/assistant/agent"; | ||
import { ModelMessageType } from "@app/lib/api/assistant/conversation"; | ||
import { Authenticator, prodAPICredentialsForOwner } from "@app/lib/auth"; | ||
import { front_sequelize } from "@app/lib/databases"; | ||
import { DustAPI } from "@app/lib/dust_api"; | ||
|
@@ -14,6 +13,7 @@ import { | |
RetrievalDocumentChunk, | ||
} from "@app/lib/models"; | ||
import { Err, Ok, Result } from "@app/lib/result"; | ||
import { new_id } from "@app/lib/utils"; | ||
import logger from "@app/logger/logger"; | ||
import { | ||
DataSourceConfiguration, | ||
|
@@ -33,6 +33,8 @@ import { | |
UserMessageType, | ||
} from "@app/types/assistant/conversation"; | ||
|
||
import { ModelMessageType } from "../generation"; | ||
|
||
/** | ||
* TimeFrame parsing | ||
*/ | ||
|
@@ -507,7 +509,7 @@ export async function* runRetrieval( | |
id: 0, // dummy pending database insertion | ||
dataSourceId: d.data_source_id, | ||
documentId: d.document_id, | ||
reference: "", | ||
reference: new_id().slice(0, 3), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. NIT: In the type, this field is described as: "Short random string so that the model can refer to the document." There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So we want it to be much shorter. The idea is for the model (with proper prompting) to generate references using e.g. […]. This is quite WIP of course. |
||
timestamp: d.timestamp, | ||
tags: d.tags, | ||
sourceUrl: d.source_url ?? null, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,14 @@ import { | |
DustProdActionRegistry, | ||
} from "@app/lib/actions/registry"; | ||
import { runAction } from "@app/lib/actions/server"; | ||
import { | ||
RetrievalDocumentsEvent, | ||
RetrievalParamsEvent, | ||
} from "@app/lib/api/assistant/actions/retrieval"; | ||
import { | ||
GenerationTokensEvent, | ||
renderConversationForModel, | ||
} from "@app/lib/api/assistant/generation"; | ||
import { Authenticator } from "@app/lib/auth"; | ||
import { Err, Ok, Result } from "@app/lib/result"; | ||
import { generateModelSId } from "@app/lib/utils"; | ||
|
@@ -19,12 +27,6 @@ import { | |
ConversationType, | ||
} from "@app/types/assistant/conversation"; | ||
|
||
import { | ||
RetrievalDocumentsEvent, | ||
RetrievalParamsEvent, | ||
} from "./actions/retrieval"; | ||
import { renderConversationForModel } from "./conversation"; | ||
|
||
/** | ||
* Agent configuration. | ||
*/ | ||
|
@@ -187,18 +189,18 @@ export type AgentActionSuccessEvent = { | |
action: AgentActionType; | ||
}; | ||
|
||
// Event sent when tokens are streamed as the agent is generating a message. | ||
export type AgentGenerationTokensEvent = { | ||
type: "agent_generation_tokens"; | ||
// Event sent once the generation is completed. | ||
export type AgentGenerationSuccessEvent = { | ||
type: "agent_generation_success"; | ||
created: number; | ||
configurationId: string; | ||
messageId: string; | ||
text: string; | ||
}; | ||
|
||
// Event sent once the message is completed and successful. | ||
export type AgentGenerationSuccessEvent = { | ||
type: "agent_generation_success"; | ||
export type AgentSuccessEvent = { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. NIT: Looking at the comment, should it be "AgentMessageSuccessEvent"? (Asking because "once the message is completed and successful" suggests it is attached to the generation, right?) Or will we also use this event after a successful action, in which case it's the comment that needs to be updated? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Happy to call it AgentMessageSuccessEvent 👍 The idea is for this one to be the final success event of the agent meaning that everything is ready and done. Since we have |
||
type: "agent_success"; | ||
created: number; | ||
configurationId: string; | ||
generationId: string; | ||
|
@@ -217,8 +219,9 @@ export async function* runAgent( | |
| AgentErrorEvent | ||
| AgentActionEvent | ||
| AgentActionSuccessEvent | ||
| AgentGenerationTokensEvent | ||
| GenerationTokensEvent | ||
| AgentGenerationSuccessEvent | ||
| AgentSuccessEvent | ||
> { | ||
yield { | ||
type: "agent_error", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,147 +3,23 @@ import { | |
AgentActionSuccessEvent, | ||
AgentErrorEvent, | ||
AgentGenerationSuccessEvent, | ||
AgentGenerationTokensEvent, | ||
AgentMessageNewEvent, | ||
AgentSuccessEvent, | ||
} from "@app/lib/api/assistant/agent"; | ||
import { Authenticator } from "@app/lib/auth"; | ||
import { CoreAPI } from "@app/lib/core_api"; | ||
import { front_sequelize } from "@app/lib/databases"; | ||
import { AgentMessage, Message, UserMessage } from "@app/lib/models"; | ||
import { Err, Ok, Result } from "@app/lib/result"; | ||
import { generateModelSId } from "@app/lib/utils"; | ||
import logger from "@app/logger/logger"; | ||
import { isRetrievalActionType } from "@app/types/assistant/actions/retrieval"; | ||
import { | ||
AgentMessageType, | ||
ConversationType, | ||
isAgentMention, | ||
isAgentMessageType, | ||
isUserMessageType, | ||
Mention, | ||
UserMessageContext, | ||
UserMessageType, | ||
} from "@app/types/assistant/conversation"; | ||
|
||
import { renderRetrievalActionForModel } from "./actions/retrieval"; | ||
|
||
/** | ||
* Model rendering of conversations. | ||
*/ | ||
|
||
// A single message in the simplified form that `renderConversationForModel` builds for the model. | ||
export type ModelMessageType = { | ||
// Kind of author. `renderConversationForModel` emits "agent" and "user" itself; | ||
// "action" is presumably what `renderRetrievalActionForModel` produces — TODO confirm. | ||
role: "action" | "agent" | "user"; | ||
// Display name of the author: the agent configuration name for agent messages, | ||
// the username from the message context for user messages. | ||
name: string; | ||
// Text of the message; this is the field that gets tokenized when counting tokens. | ||
content: string; | ||
}; | ||
|
||
// Simplified conversation handed to the model: the success value of `renderConversationForModel`. | ||
export type ModelConversationType = { | ||
// Messages retained for the model (as built by `renderConversationForModel`, oldest first). | ||
messages: ModelMessageType[]; | ||
}; | ||
|
||
// Transforms a conversation into a simplified format that we feed the model as context. | ||
// Only the last version of each entry in `conversation.content` is rendered. Rendered messages | ||
// are then selected, starting from the most recent, as long as their token counts fit within | ||
// `allowedTokenCount` tokens. | ||
// | ||
// Parameters: | ||
// - `conversation`: the conversation to render. | ||
// - `model`: provider and model ids, forwarded to `CoreAPI.tokenize` to count tokens. | ||
// - `allowedTokenCount`: upper bound on the summed token count of the selected messages. | ||
// | ||
// Returns `Ok({ messages })` with the selected messages in their original order, or an `Err` | ||
// if an action cannot be rendered or if tokenizing any message fails. | ||
export async function renderConversationForModel({ | ||
conversation, | ||
model, | ||
allowedTokenCount, | ||
}: { | ||
conversation: ConversationType; | ||
model: { providerId: string; modelId: string }; | ||
allowedTokenCount: number; | ||
}): Promise<Result<ModelConversationType, Error>> { | ||
const messages = []; | ||
|
||
let retrievalFound = false; | ||
|
||
// Render all messages and all actions but only keep the latest retrieval action. | ||
// We walk from the newest entry to the oldest and `unshift`, so `messages` ends up oldest first. | ||
for (let i = conversation.content.length - 1; i >= 0; i--) { | ||
const versions = conversation.content[i]; | ||
// Only the last version of each entry is considered. | ||
const m = versions[versions.length - 1]; | ||
|
||
if (isAgentMessageType(m)) { | ||
if (m.action) { | ||
if (isRetrievalActionType(m.action) && !retrievalFound) { | ||
messages.unshift(renderRetrievalActionForModel(m.action)); | ||
retrievalFound = true; | ||
} else { | ||
// NOTE(review): this branch is also taken for a retrieval action once `retrievalFound` | ||
// is true, so an older retrieval action aborts rendering with this error instead of | ||
// being skipped — confirm this is intended given the "only keep the latest" comment. | ||
return new Err( | ||
new Error( | ||
"Unsupported action type during conversation model rendering" | ||
) | ||
); | ||
} | ||
} | ||
// Because of `unshift`, the agent's text is placed before its action in `messages`. | ||
if (m.message) { | ||
messages.unshift({ | ||
role: "agent" as const, | ||
name: m.configuration.name, | ||
content: m.message, | ||
}); | ||
} | ||
} | ||
if (isUserMessageType(m)) { | ||
messages.unshift({ | ||
role: "user" as const, | ||
name: m.context.username, | ||
content: m.message, | ||
}); | ||
} | ||
} | ||
|
||
// Counts the tokens of `message.content` for the given model using `CoreAPI.tokenize`. | ||
// Returns the number of tokens, or an `Err` wrapping the tokenization error. | ||
async function tokenCountForMessage( | ||
message: ModelMessageType, | ||
model: { providerId: string; modelId: string } | ||
): Promise<Result<number, Error>> { | ||
const res = await CoreAPI.tokenize({ | ||
text: message.content, | ||
providerId: model.providerId, | ||
modelId: model.modelId, | ||
}); | ||
|
||
if (res.isErr()) { | ||
return new Err(new Error(`Error tokenizing model message: ${res.error}`)); | ||
} | ||
|
||
return new Ok(res.value.tokens.length); | ||
} | ||
|
||
// Start time, used only to log how long token counting took. | ||
const now = Date.now(); | ||
|
||
// Tokenize every message concurrently: one `CoreAPI.tokenize` call per message, all issued at once. | ||
const tokenCountRes = await Promise.all( | ||
messages.map((m) => { | ||
return tokenCountForMessage(m, model); | ||
}) | ||
); | ||
|
||
logger.info( | ||
{ | ||
messageCount: messages.length, | ||
elapsed: Date.now() - now, | ||
}, | ||
"[ASSISTANT_STATS] message token counts for model conversation rendering" | ||
); | ||
|
||
// Go backward and accumulate as much as we can within allowedTokenCount. | ||
// NOTE(review): a message that does not fit is skipped but the loop keeps going, so older, | ||
// smaller messages can still be selected and the result may not be a contiguous suffix of | ||
// the conversation — confirm whether the loop should stop at the first message that does not fit. | ||
const selected = []; | ||
let tokensUsed = 0; | ||
for (let i = messages.length - 1; i >= 0; i--) { | ||
const r = tokenCountRes[i]; | ||
// A tokenization failure is only surfaced when the loop reaches that message. | ||
if (r.isErr()) { | ||
return new Err(r.error); | ||
} | ||
const c = r.value; | ||
if (tokensUsed + c <= allowedTokenCount) { | ||
tokensUsed += c; | ||
selected.unshift(messages[i]); | ||
} | ||
} | ||
|
||
return new Ok({ | ||
messages: selected, | ||
}); | ||
} | ||
import { GenerationTokensEvent } from "./generation"; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. NIT: relative import |
||
|
||
/** | ||
* Conversation API | ||
|
@@ -176,8 +52,9 @@ export async function* postUserMessage( | |
| AgentErrorEvent | ||
| AgentActionEvent | ||
| AgentActionSuccessEvent | ||
| AgentGenerationTokensEvent | ||
| GenerationTokensEvent | ||
| AgentGenerationSuccessEvent | ||
| AgentSuccessEvent | ||
> { | ||
const user = auth.user(); | ||
|
||
|
@@ -320,8 +197,9 @@ export async function* retryAgentMessage( | |
| AgentErrorEvent | ||
| AgentActionEvent | ||
| AgentActionSuccessEvent | ||
| AgentGenerationTokensEvent | ||
| GenerationTokensEvent | ||
| AgentGenerationSuccessEvent | ||
| AgentSuccessEvent | ||
> { | ||
yield { | ||
type: "agent_error", | ||
|
@@ -354,8 +232,9 @@ export async function* editUserMessage( | |
| AgentErrorEvent | ||
| AgentActionEvent | ||
| AgentActionSuccessEvent | ||
| AgentGenerationTokensEvent | ||
| GenerationTokensEvent | ||
| AgentGenerationSuccessEvent | ||
| AgentSuccessEvent | ||
> { | ||
yield { | ||
type: "agent_error", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: relative import