diff --git a/docker-compose.yml b/docker-compose.yml
index 514308639d09..be3846fa08e4 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -16,6 +16,11 @@ services:
       - dustvolume:/qdrant
     ports:
       - 6334:6334
+      - 6333:6333
+  redis:
+    image: redis
+    ports:
+      - 6379:6379
 
 volumes:
diff --git a/front/lib/api/assistant/pubsub.ts b/front/lib/api/assistant/pubsub.ts
new file mode 100644
index 000000000000..81d112517416
--- /dev/null
+++ b/front/lib/api/assistant/pubsub.ts
@@ -0,0 +1,196 @@
+import {
+  AgentActionEvent,
+  AgentActionSuccessEvent,
+  AgentErrorEvent,
+  AgentGenerationSuccessEvent,
+} from "@app/lib/api/assistant/agent";
+import {
+  AgentMessageNewEvent,
+  postUserMessage,
+  UserMessageNewEvent,
+} from "@app/lib/api/assistant/conversation";
+import { GenerationTokensEvent } from "@app/lib/api/assistant/generation";
+import { Authenticator } from "@app/lib/auth";
+import { redisClient } from "@app/lib/redis";
+import logger from "@app/logger/logger";
+import {
+  ConversationType,
+  Mention,
+  UserMessageContext,
+} from "@app/types/assistant/conversation";
+
+/**
+ * Posts a user message to a conversation and republishes every event emitted
+ * by `postUserMessage` on a Redis stream: conversation-level events go to the
+ * conversation stream, message-level events go to the stream of the message
+ * identified by `event.messageId`.
+ *
+ * The Redis connection is closed in all cases, including when publishing
+ * throws.
+ *
+ * @returns `null` when an unknown event type is encountered (processing stops
+ * there), `undefined` once the event stream is exhausted.
+ */
+export async function postUserMessageWithPubSub(
+  auth: Authenticator,
+  {
+    conversation,
+    message,
+    mentions,
+    context,
+  }: {
+    conversation: ConversationType;
+    message: string;
+    mentions: Mention[];
+    context: UserMessageContext;
+  }
+) {
+  const redis = await redisClient();
+  try {
+    for await (const event of postUserMessage(auth, {
+      conversation,
+      message,
+      mentions,
+      context,
+    })) {
+      switch (event.type) {
+        case "user_message_new":
+        case "agent_message_new": {
+          const pubsubChannel = getConversationChannelId(conversation.sId);
+          await redis.xAdd(pubsubChannel, "*", {
+            payload: JSON.stringify(event),
+          });
+          break;
+        }
+        case "retrieval_params":
+        case "agent_error":
+        case "agent_action_success":
+        case "retrieval_documents":
+        case "generation_tokens":
+        case "agent_generation_success":
+        case "agent_message_success": {
+          const pubsubChannel = getMessageChannelId(event.messageId);
+          await redis.xAdd(pubsubChannel, "*", {
+            payload: JSON.stringify(event),
+          });
+          break;
+        }
+
+        default:
+          // Exhaustiveness check: `event` is `never` here if every event type
+          // is handled above.
+          ((blockParent: never) => {
+            logger.error("Unknown event type", blockParent);
+          })(event);
+          return null;
+      }
+    }
+  } finally {
+    await redis.quit();
+  }
+}
+
+/**
+ * Streams the conversation-level events published on the Redis stream of
+ * `conversationId`, starting after `lastEventId` (or from the beginning of the
+ * stream when it is `null`).
+ *
+ * The generator completes when a blocking read returns nothing for 60 seconds;
+ * the Redis connection is closed when it completes or is closed by the consumer.
+ */
+export async function* getConversationEvents(
+  conversationId: string,
+  lastEventId: string | null
+): AsyncGenerator<{
+  eventId: string;
+  data: UserMessageNewEvent | AgentMessageNewEvent;
+}> {
+  const redis = await redisClient();
+  const pubsubChannel = getConversationChannelId(conversationId);
+
+  try {
+    while (true) {
+      const events = await redis.xRead(
+        { key: pubsubChannel, id: lastEventId ? lastEventId : "0-0" },
+        // weird, xread does not return on new message when count is = 1. Anything over 1 works.
+        // NOTE(review): COUNT is nevertheless 1 below; confirm the intended value.
+        { COUNT: 1, BLOCK: 60 * 1000 }
+      );
+      if (!events) {
+        return;
+      }
+      for (const event of events) {
+        for (const message of event.messages) {
+          const payloadStr = message.message["payload"];
+          const messageId = message.id;
+          const payload = JSON.parse(payloadStr);
+          lastEventId = messageId;
+          yield {
+            eventId: messageId,
+            data: payload,
+          };
+        }
+      }
+    }
+  } finally {
+    await redis.quit();
+  }
+}
+
+/**
+ * Reads the message-level events published on the Redis stream of `messageId`,
+ * starting after `lastEventId` (or from the beginning of the stream when it is
+ * `null`).
+ *
+ * NOTE(review): unlike `getConversationEvents`, this performs a single
+ * blocking read (up to 60 seconds) and then completes; confirm whether callers
+ * are expected to call it again with the last `eventId` they received.
+ */
+export async function* getMessagesEvents(
+  messageId: string,
+  lastEventId: string | null
+): AsyncGenerator<{
+  eventId: string;
+  data:
+    | AgentErrorEvent
+    | AgentActionEvent
+    | AgentActionSuccessEvent
+    | GenerationTokensEvent
+    | AgentGenerationSuccessEvent;
+}> {
+  const pubsubChannel = getMessageChannelId(messageId);
+  const client = await redisClient();
+  try {
+    const events = await client.xRead(
+      { key: pubsubChannel, id: lastEventId ? lastEventId : "0-0" },
+      { COUNT: 1, BLOCK: 60 * 1000 }
+    );
+    if (!events) {
+      return;
+    }
+    for (const event of events) {
+      for (const message of event.messages) {
+        const payloadStr = message.message["payload"];
+        const eventId = message.id;
+        const payload = JSON.parse(payloadStr);
+        yield {
+          eventId,
+          data: payload,
+        };
+      }
+    }
+  } finally {
+    // Always release the connection, even if the consumer stops early.
+    await client.quit();
+  }
+}
+
+/** Redis stream key for the conversation-level events of a conversation. */
+function getConversationChannelId(channelId: string) {
+  return `conversation-${channelId}`;
+}
+
+/** Redis stream key for the events of a single message. */
+function getMessageChannelId(messageId: string) {
+  return `message-${messageId}`;
+}
diff --git a/front/lib/error.ts b/front/lib/error.ts
index 62c19b39f30e..1094048be3af 100644
--- a/front/lib/error.ts
+++ b/front/lib/error.ts
@@ -41,7 +41,8 @@ export type APIErrorType =
   | "extracted_event_not_found"
   | "connector_update_error"
   | "connector_update_unauthorized"
-  | "connector_oauth_target_mismatch";
+  | "connector_oauth_target_mismatch"
+  | "conversation_not_found";
 
 export type APIError = {
   type: APIErrorType;
diff --git a/front/lib/redis.ts b/front/lib/redis.ts
new file mode 100644
index 000000000000..7e2966714378
--- /dev/null
+++ b/front/lib/redis.ts
@@ -0,0 +1,24 @@
+import { createClient } from "redis";
+
+/**
+ * Creates a new Redis client connected to the server at `REDIS_URI`.
+ *
+ * Every call opens a new connection; the caller is responsible for closing it
+ * with `quit()`.
+ *
+ * @throws Error if the `REDIS_URI` environment variable is not defined.
+ */
+export async function redisClient() {
+  const { REDIS_URI } = process.env;
+  if (!REDIS_URI) {
+    throw new Error("REDIS_URI is not defined");
+  }
+  const client = createClient({
+    url: REDIS_URI,
+  });
+  client.on("error", (err) => console.log("Redis Client Error", err));
+
+  await client.connect();
+
+  return client;
+}
diff --git a/front/package-lock.json b/front/package-lock.json
index ad42f2aa5ec9..910a484cdb91 100644
--- a/front/package-lock.json
+++ b/front/package-lock.json
@@ -48,6 +48,7 @@
         "react-markdown": "^8.0.7",
         "react-p5": "^1.3.35",
         "react-textarea-autosize": "^8.4.0",
+        "redis": "^4.6.8",
         "remark-gfm": "^3.0.1",
         "sequelize": "^6.31.0",
         "showdown": "^2.1.0",
@@ -2133,6 +2134,59 @@
       "version": "1.1.0",
       "license": "BSD-3-Clause"
     },
+    "node_modules/@redis/bloom": {
+      "version": "1.2.0",
+
"resolved": "https://registry.npmjs.org/@redis/bloom/-/bloom-1.2.0.tgz", + "integrity": "sha512-HG2DFjYKbpNmVXsa0keLHp/3leGJz1mjh09f2RLGGLQZzSHpkmZWuwJbAvo3QcRY8p80m5+ZdXZdYOSBLlp7Cg==", + "peerDependencies": { + "@redis/client": "^1.0.0" + } + }, + "node_modules/@redis/client": { + "version": "1.5.9", + "resolved": "https://registry.npmjs.org/@redis/client/-/client-1.5.9.tgz", + "integrity": "sha512-SffgN+P1zdWJWSXBvJeynvEnmnZrYmtKSRW00xl8pOPFOMJjxRR9u0frSxJpPR6Y4V+k54blJjGW7FgxbTI7bQ==", + "dependencies": { + "cluster-key-slot": "1.1.2", + "generic-pool": "3.9.0", + "yallist": "4.0.0" + }, + "engines": { + "node": ">=14" + } + }, + "node_modules/@redis/graph": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@redis/graph/-/graph-1.1.0.tgz", + "integrity": "sha512-16yZWngxyXPd+MJxeSr0dqh2AIOi8j9yXKcKCwVaKDbH3HTuETpDVPcLujhFYVPtYrngSco31BUcSa9TH31Gqg==", + "peerDependencies": { + "@redis/client": "^1.0.0" + } + }, + "node_modules/@redis/json": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@redis/json/-/json-1.0.4.tgz", + "integrity": "sha512-LUZE2Gdrhg0Rx7AN+cZkb1e6HjoSKaeeW8rYnt89Tly13GBI5eP4CwDVr+MY8BAYfCg4/N15OUrtLoona9uSgw==", + "peerDependencies": { + "@redis/client": "^1.0.0" + } + }, + "node_modules/@redis/search": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@redis/search/-/search-1.1.3.tgz", + "integrity": "sha512-4Dg1JjvCevdiCBTZqjhKkGoC5/BcB7k9j99kdMnaXFXg8x4eyOIVg9487CMv7/BUVkFLZCaIh8ead9mU15DNng==", + "peerDependencies": { + "@redis/client": "^1.0.0" + } + }, + "node_modules/@redis/time-series": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@redis/time-series/-/time-series-1.0.5.tgz", + "integrity": "sha512-IFjIgTusQym2B5IZJG3XKr5llka7ey84fw/NOYqESP5WUfQs9zz1ww/9+qoz4ka/S6KcGBodzlCeZ5UImKbscg==", + "peerDependencies": { + "@redis/client": "^1.0.0" + } + }, "node_modules/@rushstack/eslint-patch": { "version": "1.3.2", "dev": true, @@ -4118,6 +4172,14 @@ "node": ">=6" } }, + 
"node_modules/cluster-key-slot": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/cluster-key-slot/-/cluster-key-slot-1.1.2.tgz", + "integrity": "sha512-RMr0FhtfXemyinomL4hrWcYJxmX6deFdCxpJzhDttxgO1+bcCnkk+9drydLVDmAMG7NE6aN/fl4F7ucU/90gAA==", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/co": { "version": "4.6.0", "dev": true, @@ -5962,6 +6024,14 @@ "node": ">=12" } }, + "node_modules/generic-pool": { + "version": "3.9.0", + "resolved": "https://registry.npmjs.org/generic-pool/-/generic-pool-3.9.0.tgz", + "integrity": "sha512-hymDOu5B53XvN4QT9dBmZxPX4CWhBPPLguTZ9MMFeFa/Kg0xWVfylOVNlJji/E7yTZWFd/q9GO5TxDLq156D7g==", + "engines": { + "node": ">= 4" + } + }, "node_modules/gensync": { "version": "1.0.0-beta.2", "license": "MIT", @@ -10666,6 +10736,19 @@ "node": ">= 12.13.0" } }, + "node_modules/redis": { + "version": "4.6.8", + "resolved": "https://registry.npmjs.org/redis/-/redis-4.6.8.tgz", + "integrity": "sha512-S7qNkPUYrsofQ0ztWlTHSaK0Qqfl1y+WMIxrzeAGNG+9iUZB4HGeBgkHxE6uJJ6iXrkvLd1RVJ2nvu6H1sAzfQ==", + "dependencies": { + "@redis/bloom": "1.2.0", + "@redis/client": "1.5.9", + "@redis/graph": "1.1.0", + "@redis/json": "1.0.4", + "@redis/search": "1.1.3", + "@redis/time-series": "1.0.5" + } + }, "node_modules/refractor": { "version": "4.8.1", "license": "MIT", diff --git a/front/package.json b/front/package.json index d55502f25095..9b3685651bc8 100644 --- a/front/package.json +++ b/front/package.json @@ -56,6 +56,7 @@ "react-markdown": "^8.0.7", "react-p5": "^1.3.35", "react-textarea-autosize": "^8.4.0", + "redis": "^4.6.8", "remark-gfm": "^3.0.1", "sequelize": "^6.31.0", "showdown": "^2.1.0", diff --git a/front/pages/api/v1/w/[wId]/assistant/[cId]/messages/index.ts b/front/pages/api/v1/w/[wId]/assistant/[cId]/messages/index.ts new file mode 100644 index 000000000000..5f4cecc33438 --- /dev/null +++ b/front/pages/api/v1/w/[wId]/assistant/[cId]/messages/index.ts @@ -0,0 +1,110 @@ +import { NextApiRequest, NextApiResponse } from "next"; + 
+import { postUserMessageWithPubSub } from "@app/lib/api/assistant/pubsub";
+import { Authenticator, getAPIKey } from "@app/lib/auth";
+import { ReturnedAPIErrorType } from "@app/lib/error";
+import { Conversation } from "@app/lib/models";
+import logger from "@app/logger/logger";
+import { apiError, withLogging } from "@app/logger/withlogging";
+
+/**
+ * `POST /api/v1/w/[wId]/assistant/[cId]/messages`
+ *
+ * Posts a user message to an existing conversation. The message is processed
+ * in the background (its events are published through Redis) and the handler
+ * answers 200 as soon as processing has been started.
+ *
+ * Only system API keys used on their own workspace are accepted.
+ */
+async function handler(
+  req: NextApiRequest,
+  res: NextApiResponse<ReturnedAPIErrorType>
+): Promise<void> {
+  const keyRes = await getAPIKey(req);
+  if (keyRes.isErr()) {
+    return apiError(req, res, keyRes.error);
+  }
+
+  const { auth, keyWorkspaceId } = await Authenticator.fromKey(
+    keyRes.value,
+    req.query.wId as string
+  );
+
+  if (!keyRes.value.isSystem) {
+    return apiError(req, res, {
+      status_code: 400,
+      api_error: {
+        type: "invalid_request_error",
+        message:
+          "The Assitant API is only accessible by system API Key. Ping us at team@dust.tt if you want access to it.",
+      },
+    });
+  }
+
+  if (keyWorkspaceId !== req.query.wId) {
+    return apiError(req, res, {
+      status_code: 400,
+      api_error: {
+        type: "invalid_request_error",
+        message: "The Assistant API is only available on your own workspace.",
+      },
+    });
+  }
+
+  // NOTE(review): the lookup is by `sId` only; confirm whether it should also
+  // be restricted to the workspace of the API key.
+  const conv = await Conversation.findOne({
+    where: {
+      sId: req.query.cId as string,
+    },
+  });
+  if (!conv) {
+    return apiError(req, res, {
+      status_code: 404,
+      api_error: {
+        type: "conversation_not_found",
+        message: "Conversation not found.",
+      },
+    });
+  }
+
+  switch (req.method) {
+    case "POST": {
+      // no time for actual io-ts parsing right now, so here is the expected structure.
+      // Will handle proper parsing later.
+      const payload = req.body as
+        | {
+            message: string;
+            context: {
+              timezone: string;
+              username: string;
+              fullName: string;
+              email: string;
+              profilePictureUrl: string;
+            };
+          }
+        | null
+        | undefined;
+
+      // Minimal shape check so that a malformed body yields a 400 instead of a
+      // TypeError (and a 500) when reading `payload.context.*` below.
+      if (
+        !payload ||
+        typeof payload.message !== "string" ||
+        !payload.context ||
+        typeof payload.context !== "object"
+      ) {
+        return apiError(req, res, {
+          status_code: 400,
+          api_error: {
+            type: "invalid_request_error",
+            message:
+              "The request body must contain a `message` string and a `context` object.",
+          },
+        });
+      }
+
+      // Not awaiting this promise on purpose.
+      // We want to answer "OK" to the client ASAP and process the events in the background.
+      // Failures are logged since nothing else observes the promise.
+      void postUserMessageWithPubSub(auth, {
+        conversation: {
+          id: conv.id,
+          created: conv.created.getTime(),
+          sId: conv.sId,
+          title: conv.title,
+          // not sure how to provide the content here for now.
+          content: [],
+          visibility: conv.visibility,
+        },
+        message: payload.message,
+        mentions: [],
+        context: {
+          timezone: payload.context.timezone,
+          username: payload.context.username,
+          fullName: payload.context.fullName,
+          email: payload.context.email,
+          profilePictureUrl: payload.context.profilePictureUrl,
+        },
+      }).catch((err) => {
+        logger.error("Error while processing user message", err);
+      });
+      res.status(200).end();
+      return;
+    }
+
+    default:
+      return apiError(req, res, {
+        status_code: 405,
+        api_error: {
+          type: "method_not_supported_error",
+          message: "The method passed is not supported, POST is expected.",
+        },
+      });
+  }
+}
+
+export default withLogging(handler);
diff --git a/front/pages/api/v1/w/[wId]/assistant/conversation/[cId]/sub.ts b/front/pages/api/v1/w/[wId]/assistant/conversation/[cId]/sub.ts
new file mode 100644
index 000000000000..e7c9eecb9aac
--- /dev/null
+++ b/front/pages/api/v1/w/[wId]/assistant/conversation/[cId]/sub.ts
@@ -0,0 +1,98 @@
+import { NextApiRequest, NextApiResponse } from "next";
+
+import { getConversationEvents } from "@app/lib/api/assistant/pubsub";
+import { Authenticator, getAPIKey } from "@app/lib/auth";
+import { ReturnedAPIErrorType } from "@app/lib/error";
+import { Conversation } from "@app/lib/models";
+import { apiError, withLogging } from "@app/logger/withlogging";
+
+/**
+ * `GET /api/v1/w/[wId]/assistant/conversation/[cId]/sub`
+ *
+ * Streams the events of a conversation to the client as server-sent events.
+ * Each event is sent as one `data:` field holding the JSON-encoded
+ * `{ eventId, data }` object yielded by `getConversationEvents`.
+ *
+ * Only system API keys used on their own workspace are accepted.
+ */
+async function handler(
+  req: NextApiRequest,
+  res: NextApiResponse<ReturnedAPIErrorType>
+): Promise<void> {
+  const keyRes = await getAPIKey(req);
+  if (keyRes.isErr()) {
+    return apiError(req, res, keyRes.error);
+  }
+
+  if (!keyRes.value.isSystem) {
+    return apiError(req, res, {
+      status_code: 400,
+      api_error: {
+        type: "invalid_request_error",
+        message:
+          "The Assitant API is only accessible by system API Key. Ping us at team@dust.tt if you want access to it.",
+      },
+    });
+  }
+  const { keyWorkspaceId } = await Authenticator.fromKey(
+    keyRes.value,
+    req.query.wId as string
+  );
+
+  if (keyWorkspaceId !== req.query.wId) {
+    return apiError(req, res, {
+      status_code: 400,
+      api_error: {
+        type: "invalid_request_error",
+        message: "The Assistant API is only available on your own workspace.",
+      },
+    });
+  }
+
+  const conv = await Conversation.findOne({
+    where: {
+      sId: req.query.cId as string,
+    },
+  });
+  if (!conv) {
+    return apiError(req, res, {
+      status_code: 404,
+      api_error: {
+        type: "conversation_not_found",
+        message: "Conversation not found.",
+      },
+    });
+  }
+
+  switch (req.method) {
+    case "GET": {
+      res.writeHead(200, {
+        "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache",
+        Connection: "keep-alive",
+      });
+
+      for await (const event of getConversationEvents(conv.sId, null)) {
+        // Server-sent events framing: a `data:` field terminated by a blank
+        // line. Without it, `text/event-stream` clients never see an event.
+        res.write(`data: ${JSON.stringify(event)}\n\n`);
+        // @ts-expect-error we need to flush for streaming but TS thinks flush() does not exists.
+        res.flush();
+      }
+
+      res.end();
+      return;
+    }
+
+    default:
+      return apiError(req, res, {
+        status_code: 405,
+        api_error: {
+          type: "method_not_supported_error",
+          message: "The method passed is not supported, GET is expected.",
+        },
+      });
+  }
+}
+
+export default withLogging(handler);
diff --git a/front/pages/api/v1/w/[wId]/assistant/conversation/index.ts b/front/pages/api/v1/w/[wId]/assistant/conversation/index.ts
new file mode 100644
index 000000000000..fb9075d9321b
--- /dev/null
+++ b/front/pages/api/v1/w/[wId]/assistant/conversation/index.ts
@@ -0,0 +1,83 @@
+import { NextApiRequest, NextApiResponse } from "next";
+
+import { Authenticator, getAPIKey } from "@app/lib/auth";
+import { ReturnedAPIErrorType } from "@app/lib/error";
+import { Conversation } from "@app/lib/models";
+import { generateModelSId } from "@app/lib/utils";
+import { apiError, withLogging } from "@app/logger/withlogging";
+import { ConversationType } from "@app/types/assistant/conversation";
+
+/**
+ * `POST /api/v1/w/[wId]/assistant/conversation`
+ *
+ * Creates a conversation from the `title` and `visibility` fields of the
+ * request body and returns it with an empty `content`.
+ *
+ * Only system API keys used on their own workspace are accepted.
+ */
+async function handler(
+  req: NextApiRequest,
+  res: NextApiResponse<ConversationType | ReturnedAPIErrorType>
+): Promise<void> {
+  const keyRes = await getAPIKey(req);
+  if (keyRes.isErr()) {
+    return apiError(req, res, keyRes.error);
+  }
+
+  const { keyWorkspaceId } = await Authenticator.fromKey(
+    keyRes.value,
+    req.query.wId as string
+  );
+
+  if (!keyRes.value.isSystem) {
+    return apiError(req, res, {
+      status_code: 400,
+      api_error: {
+        type: "invalid_request_error",
+        message:
+          "The Assitant API is only accessible by system API Key. Ping us at team@dust.tt if you want access to it.",
+      },
+    });
+  }
+
+  if (keyWorkspaceId !== req.query.wId) {
+    return apiError(req, res, {
+      status_code: 400,
+      api_error: {
+        type: "invalid_request_error",
+        message: "The Assistant API is only available on your own workspace.",
+      },
+    });
+  }
+
+  switch (req.method) {
+    case "POST": {
+      const conv = await Conversation.create({
+        sId: generateModelSId(),
+        title: req.body.title,
+        created: new Date(),
+        visibility: req.body.visibility,
+      });
+      res.status(200).json({
+        id: conv.id,
+        created: conv.created.getTime(),
+        sId: conv.sId,
+        title: conv.title,
+        visibility: conv.visibility,
+        content: [],
+      });
+      return;
+    }
+
+    default:
+      return apiError(req, res, {
+        status_code: 405,
+        api_error: {
+          type: "method_not_supported_error",
+          message: "The method passed is not supported, POST is expected.",
+        },
+      });
+  }
+}
+
+export default withLogging(handler);