Skip to content

Commit

Permalink
Post User Message in public API answers with agentMessages (#4543)
Browse files Browse the repository at this point in the history
* Post User Message in public API answers with agentMessages

* Apply feedback
  • Loading branch information
PopDaph authored Apr 4, 2024
1 parent e1a08a5 commit 3ef4baf
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 137 deletions.
238 changes: 149 additions & 89 deletions front/lib/api/assistant/pubsub.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,17 @@ export async function postUserMessageWithPubSub(
content: string;
mentions: MentionType[];
context: UserMessageContext;
}
): Promise<Result<UserMessageType, PubSubError>> {
},
{ resolveAfterFullGeneration }: { resolveAfterFullGeneration: boolean }
): Promise<
Result<
{
userMessage: UserMessageType;
agentMessages?: AgentMessageType[];
},
PubSubError
>
> {
let maxPerTimeframe: number | undefined = undefined;
let timeframeSeconds: number | undefined = undefined;
let rateLimitKey: string | undefined = "";
Expand Down Expand Up @@ -91,7 +100,11 @@ export async function postUserMessageWithPubSub(
mentions,
context,
});
return handleUserMessageEvents(conversation, postMessageEvents);
return handleUserMessageEvents(
conversation,
postMessageEvents,
resolveAfterFullGeneration
);
}

export async function editUserMessageWithPubSub(
Expand All @@ -107,14 +120,22 @@ export async function editUserMessageWithPubSub(
content: string;
mentions: MentionType[];
}
): Promise<Result<UserMessageType, PubSubError>> {
): Promise<
Result<
{
userMessage: UserMessageType;
agentMessages?: AgentMessageType[];
},
PubSubError
>
> {
const editMessageEvents = editUserMessage(auth, {
conversation,
message,
content,
mentions,
});
return handleUserMessageEvents(conversation, editMessageEvents);
return handleUserMessageEvents(conversation, editMessageEvents, false);
}

async function handleUserMessageEvents(
Expand All @@ -132,107 +153,146 @@ async function handleUserMessageEvents(
| AgentMessageSuccessEvent
| ConversationTitleEvent,
void
>,
resolveAfterFullGeneration = false
): Promise<
Result<
{
userMessage: UserMessageType;
agentMessages?: AgentMessageType[];
},
PubSubError
>
): Promise<Result<UserMessageType, PubSubError>> {
const promise: Promise<Result<UserMessageType, PubSubError>> = new Promise(
(resolve) => {
void wakeLock(async () => {
const redis = await redisClient();
let didResolve = false;
try {
for await (const event of messageEventGenerator) {
switch (event.type) {
case "user_message_new":
case "agent_message_new":
case "conversation_title": {
const pubsubChannel = getConversationChannelId(
conversation.sId
);
await redis.xAdd(pubsubChannel, "*", {
payload: JSON.stringify(event),
});
await redis.expire(pubsubChannel, 60 * 10);
if (event.type === "user_message_new") {
didResolve = true;
resolve(new Ok(event.message));
}
break;
}
case "retrieval_params":
case "dust_app_run_params":
case "dust_app_run_block":
case "tables_query_params":
case "tables_query_output":
case "agent_error":
case "agent_action_success":
case "generation_tokens":
case "agent_generation_success":
case "agent_generation_cancelled":
case "agent_message_success": {
const pubsubChannel = getMessageChannelId(event.messageId);
await redis.xAdd(pubsubChannel, "*", {
payload: JSON.stringify(event),
});
await redis.expire(pubsubChannel, 60 * 10);
break;
}
case "user_message_error": {
// We resolve the promise with an error as we were not able to
// create the user message. This is possible for a variety of
// reason and will get turned into a 400 in the API route calling
// `{post/edit}UserMessageWithPubSub`, except for the case of used
// up messages for the test plan, handled separately
> {
const promise: Promise<
Result<
{
userMessage: UserMessageType;
agentMessages?: AgentMessageType[];
},
PubSubError
>
> = new Promise((resolve) => {
void wakeLock(async () => {
const redis = await redisClient();
let didResolve = false;

didResolve = true;
if (event.error.code === "plan_message_limit_exceeded") {
let userMessage: UserMessageType | undefined = undefined;
const agentMessages: AgentMessageType[] = [];
try {
for await (const event of messageEventGenerator) {
switch (event.type) {
case "user_message_new":
case "agent_message_new":
case "conversation_title": {
const pubsubChannel = getConversationChannelId(conversation.sId);
await redis.xAdd(pubsubChannel, "*", {
payload: JSON.stringify(event),
});
await redis.expire(pubsubChannel, 60 * 10);
if (event.type === "user_message_new") {
userMessage = event.message;
if (!resolveAfterFullGeneration) {
didResolve = true;
resolve(
new Err({
status_code: 403,
api_error: {
type: "plan_message_limit_exceeded",
message: event.error.message,
},
new Ok({
userMessage,
})
);
}
}
break;
}
case "retrieval_params":
case "dust_app_run_params":
case "dust_app_run_block":
case "tables_query_params":
case "tables_query_output":
case "agent_error":
case "agent_action_success":
case "generation_tokens":
case "agent_generation_success":
case "agent_generation_cancelled":
case "agent_message_success": {
const pubsubChannel = getMessageChannelId(event.messageId);
await redis.xAdd(pubsubChannel, "*", {
payload: JSON.stringify(event),
});
await redis.expire(pubsubChannel, 60 * 10);

if (
event.type === "agent_message_success" &&
resolveAfterFullGeneration
) {
agentMessages.push(event.message);
}
break;
}
case "user_message_error": {
// We resolve the promise with an error as we were not able to
// create the user message. This is possible for a variety of
// reason and will get turned into a 400 in the API route calling
// `{post/edit}UserMessageWithPubSub`, except for the case of used
// up messages for the test plan, handled separately

didResolve = true;
if (event.error.code === "plan_message_limit_exceeded") {
resolve(
new Err({
status_code: 400,
status_code: 403,
api_error: {
type: "invalid_request_error",
type: "plan_message_limit_exceeded",
message: event.error.message,
},
})
);
break;
}

default:
((event: never) => {
logger.error("Unknown event type", event);
})(event);
return null;
resolve(
new Err({
status_code: 400,
api_error: {
type: "invalid_request_error",
message: event.error.message,
},
})
);
break;
}
}
} catch (e) {
logger.error({ error: e }, "Error Posting message");
} finally {
await redis.quit();
if (!didResolve) {
resolve(
new Err({
status_code: 500,
api_error: {
type: "internal_server_error",
message: `Never got the user_message_new event for ${conversation.sId}`,
},
})
);

default:
((event: never) => {
logger.error("Unknown event type", event);
})(event);
return null;
}
}
});
}
);
if (resolveAfterFullGeneration && userMessage && !didResolve) {
didResolve = true;
resolve(
new Ok({
userMessage,
agentMessages,
})
);
}
} catch (e) {
logger.error({ error: e }, "Error Posting message");
} finally {
await redis.quit();
if (!didResolve) {
resolve(
new Err({
status_code: 500,
api_error: {
type: "internal_server_error",
message: `Never got the resolved event for ${conversation.sId} (resolveAfterFullGeneration: ${resolveAfterFullGeneration})`,
},
})
);
}
}
});
});

return promise;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import type { UserMessageType, WithAPIErrorReponse } from "@dust-tt/types";
import type {
AgentMessageType,
UserMessageType,
WithAPIErrorReponse,
} from "@dust-tt/types";
import { PublicPostMessagesRequestBodySchema } from "@dust-tt/types";
import { isLeft } from "fp-ts/lib/Either";
import * as reporter from "io-ts-reporters";
Expand All @@ -11,11 +15,12 @@ import { apiError, withLogging } from "@app/logger/withlogging";

export type PostMessagesResponseBody = {
message: UserMessageType;
agentMessages?: AgentMessageType[];
};

async function handler(
req: NextApiRequest,
res: NextApiResponse<WithAPIErrorReponse<{ message: UserMessageType }>>
res: NextApiResponse<WithAPIErrorReponse<PostMessagesResponseBody>>
): Promise<void> {
const keyRes = await getAPIKey(req);
if (keyRes.isErr()) {
Expand Down Expand Up @@ -64,19 +69,26 @@ async function handler(
});
}

const { content, context, mentions } = bodyValidation.right;
const { content, context, mentions, isSync } = bodyValidation.right;

const messageRes = await postUserMessageWithPubSub(auth, {
conversation,
content,
mentions,
context,
});
const messageRes = await postUserMessageWithPubSub(
auth,
{
conversation,
content,
mentions,
context,
},
{ resolveAfterFullGeneration: isSync === true }
);
if (messageRes.isErr()) {
return apiError(req, res, messageRes.error);
}

res.status(200).json({ message: messageRes.value });
res.status(200).json({
message: messageRes.value.userMessage,
agentMessages: messageRes.value.agentMessages ?? undefined,
});
return;

default:
Expand Down
30 changes: 17 additions & 13 deletions front/pages/api/v1/w/[wId]/assistant/conversations/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async function handler(
});
}

const { title, visibility, message, contentFragment } =
const { title, visibility, message, contentFragment, isSync } =
bodyValidation.right;

if (contentFragment) {
Expand Down Expand Up @@ -135,24 +135,28 @@ async function handler(
// before returning the conversation along with the message.
// PostUserMessageWithPubSub returns swiftly since it only waits for the
// initial message creation event (or error)
const messageRes = await postUserMessageWithPubSub(auth, {
conversation,
content: message.content,
mentions: message.mentions,
context: {
timezone: message.context.timezone,
username: message.context.username,
fullName: message.context.fullName,
email: message.context.email,
profilePictureUrl: message.context.profilePictureUrl,
const messageRes = await postUserMessageWithPubSub(
auth,
{
conversation,
content: message.content,
mentions: message.mentions,
context: {
timezone: message.context.timezone,
username: message.context.username,
fullName: message.context.fullName,
email: message.context.email,
profilePictureUrl: message.context.profilePictureUrl,
},
},
});
{ resolveAfterFullGeneration: isSync === true }
);

if (messageRes.isErr()) {
return apiError(req, res, messageRes.error);
}

newMessage = messageRes.value;
newMessage = messageRes.value.userMessage;
}

if (newContentFragment || newMessage) {
Expand Down
Loading

0 comments on commit 3ef4baf

Please sign in to comment.