Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runAgent implementation, finalized postUserMessage modulo configuration #1317

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions front/lib/api/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -332,16 +332,9 @@ export async function* runRetrieval(

const c = configuration.action;
if (!isRetrievalConfiguration(c)) {
return yield {
type: "retrieval_error",
created: Date.now(),
configurationId: configuration.sId,
messageId: agentMessage.sId,
error: {
code: "internal_server_error",
message: "Unexpected action configuration received in `runRetrieval`",
},
};
throw new Error(
"Unexpected action configuration received in `runRetrieval`"
);
}

const paramsRes = await generateRetrievalParams(
Expand Down
92 changes: 92 additions & 0 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ import { runAction } from "@app/lib/actions/server";
import {
RetrievalDocumentsEvent,
RetrievalParamsEvent,
runRetrieval,
} from "@app/lib/api/assistant/actions/retrieval";
import {
GenerationTokensEvent,
renderConversationForModel,
runGeneration,
} from "@app/lib/api/assistant/generation";
import { Authenticator } from "@app/lib/auth";
import { Err, Ok, Result } from "@app/lib/result";
import { generateModelSId } from "@app/lib/utils";
import { isRetrievalConfiguration } from "@app/types/assistant/actions/retrieval";
import {
AgentActionConfigurationType,
AgentActionSpecification,
Expand Down Expand Up @@ -216,6 +219,95 @@ export async function* runAgent(
| AgentGenerationSuccessEvent
| AgentMessageSuccessEvent
> {
// First run the action if a configuration is present.
if (configuration.action !== null) {
if (isRetrievalConfiguration(configuration.action)) {
const eventStream = runRetrieval(
auth,
configuration,
conversation,
userMessage,
agentMessage
);

for await (const event of eventStream) {
if (event.type === "retrieval_params") {
yield event;
}
if (event.type === "retrieval_documents") {
yield event;
}
if (event.type === "retrieval_error") {
yield {
type: "agent_error",
created: event.created,
configurationId: configuration.sId,
messageId: agentMessage.sId,
error: {
code: event.error.code,
message: event.error.message,
},
};
}
if (event.type === "retrieval_success") {
yield {
type: "agent_action_success",
created: event.created,
configurationId: configuration.sId,
messageId: agentMessage.sId,
action: event.action,
};

// We stitch the action into the agent message. The conversation is expected to include
// the agentMessage object, updating this object will update the conversation as well.
agentMessage.action = event.action;
}
}
} else {
throw new Error(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: would be nicer to throw early instead of indenting the whole code inside the if (easier to read / understand if I can pop that branch off my mental stack)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future with more actions this construct will make more sense?

"runAgent implementation missing for action configuration"
);
}

// Then run the generation if a configuration is present.
if (configuration.generation !== null) {
const eventStream = runGeneration(
auth,
configuration,
conversation,
userMessage,
agentMessage
);

for await (const event of eventStream) {
if (event.type === "generation_tokens") {
yield event;
}
if (event.type === "generation_error") {
yield {
type: "agent_error",
created: event.created,
configurationId: configuration.sId,
messageId: agentMessage.sId,
error: {
code: event.error.code,
message: event.error.message,
},
};
}
if (event.type === "generation_success") {
yield {
type: "agent_generation_success",
created: event.created,
configurationId: configuration.sId,
messageId: agentMessage.sId,
text: event.text,
};
}
}
}
}

yield {
type: "agent_error",
created: Date.now(),
Expand Down
147 changes: 73 additions & 74 deletions front/lib/api/assistant/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,80 +177,88 @@ export async function* postUserMessage(
message: userMessage,
};

for (let i = 0; i < agentMessages.length; i++) {
const agentMessage = agentMessages[i];
const agentMessageRow = agentMessageRows[i];
await Promise.allSettled(
agentMessages.map(async function* (agentMessage, i) {
//for (let i = 0; i < agentMessages.length; i++) {
//const agentMessage = agentMessages[i];
const agentMessageRow = agentMessageRows[i];

yield {
type: "agent_message_new",
created: Date.now(),
configurationId: agentMessage.configuration.sId,
messageId: agentMessage.sId,
message: agentMessage,
};

const eventStream = runAgent(
auth,
agentMessage.configuration,
conversation,
userMessage,
agentMessage
);
yield {
type: "agent_message_new",
created: Date.now(),
configurationId: agentMessage.configuration.sId,
messageId: agentMessage.sId,
message: agentMessage,
};

for await (const event of eventStream) {
if (event.type === "agent_error") {
// Store error in database.
await agentMessageRow.update({
status: "failed",
errorCode: event.error.code,
errorMessage: event.error.message,
});
yield event;
}
// For each agent we stitch the conversation to add the user message and only that agent message
// so that it can be used to prompt the agent.
const eventStream = runAgent(
auth,
agentMessage.configuration,
{
...conversation,
content: [...conversation.content, [userMessage], [agentMessage]],
},
userMessage,
agentMessage
);

if (event.type === "agent_action_success") {
// Store action in database.
if (event.action.type === "retrieval_action") {
for await (const event of eventStream) {
if (event.type === "agent_error") {
// Store error in database.
await agentMessageRow.update({
agentRetrievalActionId: event.action.id,
status: "failed",
errorCode: event.error.code,
errorMessage: event.error.message,
});
} else {
throw new Error(
`Action type ${event.action.type} agent_action_success handling not implemented`
);
yield event;
}
yield event;
}

if (event.type === "agent_generation_success") {
// Store message in database.
await agentMessageRow.update({
message: event.text,
});
yield event;
}
if (event.type === "agent_action_success") {
// Store action in database.
if (event.action.type === "retrieval_action") {
await agentMessageRow.update({
agentRetrievalActionId: event.action.id,
});
} else {
throw new Error(
`Action type ${event.action.type} agent_action_success handling not implemented`
);
}
yield event;
}

if (event.type === "agent_message_success") {
// Update status in database.
await agentMessageRow.update({
status: "succeeded",
});
yield event;
}
if (event.type === "agent_generation_success") {
// Store message in database.
await agentMessageRow.update({
message: event.text,
});
yield event;
}

if (event.type === "agent_message_success") {
// Update status in database.
await agentMessageRow.update({
status: "succeeded",
});
yield event;
}

// All other events that won't impact the database and are related to actions or tokens
// generation.
if (
[
"retrieval_params",
"retrieval_documents",
"generation_tokens",
].includes(event.type)
) {
yield event;
// All other events that won't impact the database and are related to actions or tokens
// generation.
if (
[
"retrieval_params",
"retrieval_documents",
"generation_tokens",
].includes(event.type)
) {
yield event;
}
}
}
}
})
);
}

// This method is in charge of re-running an agent interaction (generating a new
Expand Down Expand Up @@ -298,16 +306,7 @@ export async function* editUserMessage(
message: UserMessageType;
content: string;
}
): AsyncGenerator<
| UserMessageNewEvent
| AgentMessageNewEvent
| AgentErrorEvent
| AgentActionEvent
| AgentActionSuccessEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentMessageSuccessEvent
> {
): AsyncGenerator<UserMessageNewEvent | AgentErrorEvent> {
yield {
type: "agent_error",
created: Date.now(),
Expand Down