diff --git a/front/components/assistant/conversation/AgentMessage.tsx b/front/components/assistant/conversation/AgentMessage.tsx index 856b1d39e4c7..f8dad26cf10d 100644 --- a/front/components/assistant/conversation/AgentMessage.tsx +++ b/front/components/assistant/conversation/AgentMessage.tsx @@ -8,16 +8,18 @@ import { EyeIcon, Spinner, } from "@dust-tt/sparkle"; -import { useCallback, useEffect, useState } from "react"; +import { useCallback, useContext, useEffect, useState } from "react"; import { AgentAction } from "@app/components/assistant/conversation/AgentAction"; import { ConversationMessage } from "@app/components/assistant/conversation/ConversationMessage"; +import { GenerationContext } from "@app/components/assistant/conversation/GenerationContextProvider"; import { RenderMessageMarkdown } from "@app/components/assistant/RenderMessageMarkdown"; import { useEventSource } from "@app/hooks/useEventSource"; import { AgentActionEvent, AgentActionSuccessEvent, AgentErrorEvent, + AgentGenerationCancelledEvent, AgentGenerationSuccessEvent, AgentMessageSuccessEvent, } from "@app/lib/api/assistant/agent"; @@ -52,6 +54,7 @@ export function AgentMessage({ switch (streamedAgentMessage.status) { case "succeeded": case "failed": + case "cancelled": return false; case "created": return true; @@ -92,6 +95,7 @@ export function AgentMessage({ | AgentActionSuccessEvent | GenerationTokensEvent | AgentGenerationSuccessEvent + | AgentGenerationCancelledEvent | AgentMessageSuccessEvent; } = JSON.parse(eventStr); @@ -116,6 +120,13 @@ export function AgentMessage({ return { ...m, content: event.text }; }); break; + + case "agent_generation_cancelled": + setStreamedAgentMessage((m) => { + return { ...m, status: "cancelled" }; + }); + break; + case "agent_message_success": { setStreamedAgentMessage(event.message); break; @@ -141,6 +152,7 @@ export function AgentMessage({ switch (message.status) { case "succeeded": case "failed": + case "cancelled": return message; case "created": return streamedAgentMessage; @@ -167,6 +179,26 @@ export function AgentMessage({ } }, [agentMessageToRender.content, agentMessageToRender.status]); + // GenerationContext: to know if we are generating or not + const generationContext = useContext(GenerationContext); + if (!generationContext) { + throw new Error( + "AgentMessage must be used within a GenerationContextProvider" + ); + } + useEffect(() => { + const isInArray = generationContext.generatingMessageIds.includes( + message.sId + ); + if (agentMessageToRender.status === "created" && !isInArray) { + generationContext.setGeneratingMessageIds((s) => [...s, message.sId]); + } else if (agentMessageToRender.status !== "created" && isInArray) { + generationContext.setGeneratingMessageIds((s) => + s.filter((id) => id !== message.sId) + ); + } + }, [agentMessageToRender.status, generationContext, message.sId]); + const buttons = message.status === "failed" ? [] diff --git a/front/components/assistant/conversation/Conversation.tsx b/front/components/assistant/conversation/Conversation.tsx index 55df56fab33d..fe9776809052 100644 --- a/front/components/assistant/conversation/Conversation.tsx +++ b/front/components/assistant/conversation/Conversation.tsx @@ -3,6 +3,7 @@ import { useCallback, useEffect, useRef } from "react"; import { AgentMessage } from "@app/components/assistant/conversation/AgentMessage"; import { UserMessage } from "@app/components/assistant/conversation/UserMessage"; import { useEventSource } from "@app/hooks/useEventSource"; +import { AgentGenerationCancelledEvent } from "@app/lib/api/assistant/agent"; import { AgentMessageNewEvent, ConversationTitleEvent, @@ -115,6 +116,7 @@ export default function Conversation({ data: | UserMessageNewEvent | AgentMessageNewEvent + | AgentGenerationCancelledEvent | ConversationTitleEvent; } = JSON.parse(eventStr); @@ -125,6 +127,7 @@ export default function Conversation({ switch (event.type) { case "user_message_new": case "agent_message_new": + case "agent_generation_cancelled": void mutateConversation(); break; case "conversation_title": { diff --git a/front/components/assistant/conversation/GenerationContextProvider.tsx b/front/components/assistant/conversation/GenerationContextProvider.tsx new file mode 100644 index 000000000000..ed93655861db --- /dev/null +++ b/front/components/assistant/conversation/GenerationContextProvider.tsx @@ -0,0 +1,30 @@ +import { createContext, useState } from "react"; + +type GenerationContextType = { + generatingMessageIds: string[]; + setGeneratingMessageIds: React.Dispatch>; +}; + +export const GenerationContext = createContext< + GenerationContextType | undefined +>(undefined); + +export const GenerationContextProvider = ({ + children, +}: { + children: React.ReactNode; +}) => { + const [generatingMessageIds, setGeneratingMessageIds] = useState( + [] + ); + return ( + + {children} + + ); +}; diff --git a/front/components/assistant/conversation/InputBar.tsx b/front/components/assistant/conversation/InputBar.tsx index 2094a5d13a48..498ff8bd682b 100644 --- a/front/components/assistant/conversation/InputBar.tsx +++ b/front/components/assistant/conversation/InputBar.tsx @@ -1,4 +1,10 @@ -import { Avatar, IconButton, PaperAirplaneIcon } from "@dust-tt/sparkle"; +import { + Avatar, + Button, + IconButton, + PaperAirplaneIcon, + StopIcon, +} from "@dust-tt/sparkle"; import { Transition } from "@headlessui/react"; import { createContext, @@ -14,6 +20,7 @@ import { import * as ReactDOMServer from "react-dom/server"; import { AssistantPicker } from "@app/components/assistant/AssistantPicker"; +import { GenerationContext } from "@app/components/assistant/conversation/GenerationContextProvider"; import { compareAgentsForSort } from "@app/lib/assistant"; import { useAgentConfigurations } from "@app/lib/swr"; import { classNames } from "@app/lib/utils"; @@ -753,13 +760,64 @@ export function FixedAssistantInputBar({ owner, onSubmit, stickyMentions, + conversationId, }: { owner: WorkspaceType; onSubmit: (input: string, mentions: MentionType[]) => void; stickyMentions?: AgentMention[]; + conversationId: string | null; }) { + const [isProcessing, setIsProcessing] = useState(false); + + // GenerationContext: to know if we are generating or not + const generationContext = useContext(GenerationContext); + if (!generationContext) { + throw new Error( + "FixedAssistantInputBar must be used within a GenerationContextProvider" + ); + } + + const handleStopGeneration = async () => { + if (!conversationId) { + return; + } + setIsProcessing(true); // we don't set it back to false immediately cause it takes a bit of time to cancel + await fetch( + `/api/w/${owner.sId}/assistant/conversations/${conversationId}/cancel`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + action: "cancel", + messageIds: generationContext.generatingMessageIds, + }), + } + ); + }; + + useEffect(() => { + if (isProcessing && generationContext.generatingMessageIds.length === 0) { + setIsProcessing(false); + } + }, [isProcessing, generationContext.generatingMessageIds.length]); + return (
+ {generationContext.generatingMessageIds.length > 0 && ( +
+
+ )} +
{ @@ -349,6 +358,15 @@ export async function* runAgent( }; return; + case "generation_cancel": + yield { + type: "agent_generation_cancelled", + created: event.created, + configurationId: configuration.sId, + messageId: agentMessage.sId, + }; + return; + case "generation_success": yield { type: "agent_generation_success", diff --git a/front/lib/api/assistant/conversation.ts b/front/lib/api/assistant/conversation.ts index ab4f0fa11b09..3c331e6b367a 100644 --- a/front/lib/api/assistant/conversation.ts +++ b/front/lib/api/assistant/conversation.ts @@ -9,6 +9,7 @@ import { AgentActionEvent, AgentActionSuccessEvent, AgentErrorEvent, + AgentGenerationCancelledEvent, AgentGenerationSuccessEvent, AgentMessageSuccessEvent, runAgent, @@ -649,6 +650,7 @@ export async function* postUserMessage( | AgentActionSuccessEvent | GenerationTokensEvent | AgentGenerationSuccessEvent + | AgentGenerationCancelledEvent | AgentMessageSuccessEvent | ConversationTitleEvent, void @@ -981,6 +983,7 @@ export async function* editUserMessage( | AgentActionSuccessEvent | GenerationTokensEvent | AgentGenerationSuccessEvent + | AgentGenerationCancelledEvent | AgentMessageSuccessEvent, void > { @@ -1324,6 +1327,7 @@ export async function* retryAgentMessage( | AgentActionSuccessEvent | GenerationTokensEvent | AgentGenerationSuccessEvent + | AgentGenerationCancelledEvent | AgentMessageSuccessEvent, void > { @@ -1522,6 +1526,7 @@ async function* streamRunAgentEvents( | AgentActionSuccessEvent | GenerationTokensEvent | AgentGenerationSuccessEvent + | AgentGenerationCancelledEvent | AgentMessageSuccessEvent, void >, @@ -1533,9 +1538,11 @@ async function* streamRunAgentEvents( | AgentActionSuccessEvent | GenerationTokensEvent | AgentGenerationSuccessEvent + | AgentGenerationCancelledEvent | AgentMessageSuccessEvent, void > { + let content = ""; for await (const event of eventStream) { switch (event.type) { case "agent_error": @@ -1588,12 +1595,25 @@ async function* streamRunAgentEvents( yield event; break; + case "agent_generation_cancelled": + if (agentMessageRow.status !== "cancelled") { + await agentMessageRow.update({ + status: "cancelled", + content: content, + }); + yield event; + } + break; + // All other events that won't impact the database and are related to actions or tokens // generation. case "retrieval_params": case "dust_app_run_params": case "dust_app_run_block": + yield event; + break; case "generation_tokens": + content += event.text; yield event; break; diff --git a/front/lib/api/assistant/generation.ts b/front/lib/api/assistant/generation.ts index 974330d6048c..a357454d6b72 100644 --- a/front/lib/api/assistant/generation.ts +++ b/front/lib/api/assistant/generation.ts @@ -16,6 +16,7 @@ import { } from "@app/lib/assistant"; import { Authenticator } from "@app/lib/auth"; import { CoreAPI } from "@app/lib/core_api"; +import { redisClient } from "@app/lib/redis"; import { Err, Ok, Result } from "@app/lib/result"; import logger from "@app/logger/logger"; import { isDustAppRunActionType } from "@app/types/assistant/actions/dust_app_run"; @@ -34,6 +35,7 @@ import { } from "@app/types/assistant/conversation"; import { renderDustAppRunActionForModel } from "./actions/dust_app_run"; +const CANCELLATION_CHECK_INTERVAL = 500; /** * Model rendering of conversations. @@ -236,6 +238,13 @@ export type GenerationSuccessEvent = { text: string; }; +export type GenerationCancelEvent = { + type: "generation_cancel"; + created: number; + configurationId: string; + messageId: string; +}; + // Construct the full prompt from the agent configuration. // - Meta data about the agent and current time. // - Insructions from the agent configuration (in case of generation) @@ -272,7 +281,10 @@ export async function* runGeneration( userMessage: UserMessageType, agentMessage: AgentMessageType ): AsyncGenerator< - GenerationErrorEvent | GenerationTokensEvent | GenerationSuccessEvent, + | GenerationErrorEvent + | GenerationTokensEvent + | GenerationSuccessEvent + | GenerationCancelEvent, void > { const owner = auth.workspace(); @@ -394,17 +406,29 @@ export async function* runGeneration( const { eventStream } = res.value; - for await (const event of eventStream) { - if (event.type === "tokens") { - yield { - type: "generation_tokens", - created: Date.now(), - configurationId: configuration.sId, - messageId: agentMessage.sId, - text: event.content.tokens.text, - }; + let shouldYieldCancel = false; + let lastCheckCancellation = Date.now(); + const redis = await redisClient(); + + const _checkCancellation = async () => { + try { + const cancelled = await redis.get( + `assistant:generation:cancelled:${agentMessage.sId}` + ); + if (cancelled === "1") { + shouldYieldCancel = true; + await redis.set( + `assistant:generation:cancelled:${agentMessage.sId}`, + 0 + ); + } + } catch (error) { + console.error("Error checking cancellation:", error); + return false; } + }; + for await (const event of eventStream) { if (event.type === "error") { yield { type: "generation_error", @@ -419,6 +443,35 @@ export async function* runGeneration( return; } + const currentTimestamp = Date.now(); + if ( + currentTimestamp - lastCheckCancellation >= + CANCELLATION_CHECK_INTERVAL + ) { + void _checkCancellation(); // Trigger the async function without awaiting + lastCheckCancellation = currentTimestamp; + } + + if (shouldYieldCancel) { + yield { + type: "generation_cancel", + created: Date.now(), + configurationId: configuration.sId, + messageId: agentMessage.sId, + }; + return; + } + + if (event.type === "tokens") { + yield { + type: "generation_tokens", + created: Date.now(), + configurationId: configuration.sId, + messageId: agentMessage.sId, + text: event.content.tokens.text, + }; + } + if (event.type === "block_execution") { const e = event.content.execution[0][0]; if (e.error) { @@ -451,4 +504,5 @@ export async function* runGeneration( } } } + await redis.quit(); } diff --git a/front/lib/api/assistant/pubsub.ts b/front/lib/api/assistant/pubsub.ts index dc3b90513f17..3501b5a78705 100644 --- a/front/lib/api/assistant/pubsub.ts +++ b/front/lib/api/assistant/pubsub.ts @@ -2,6 +2,7 @@ import { AgentActionEvent, AgentActionSuccessEvent, AgentErrorEvent, + AgentGenerationCancelledEvent, AgentGenerationSuccessEvent, AgentMessageSuccessEvent, } from "@app/lib/api/assistant/agent"; @@ -87,6 +88,7 @@ async function handleUserMessageEvents( | AgentActionSuccessEvent | GenerationTokensEvent | AgentGenerationSuccessEvent + | AgentGenerationCancelledEvent | AgentMessageSuccessEvent | ConversationTitleEvent, void @@ -123,6 +125,7 @@ async function handleUserMessageEvents( case "agent_action_success": case "generation_tokens": case "agent_generation_success": + case "agent_generation_cancelled": case "agent_message_success": { const pubsubChannel = getMessageChannelId(event.messageId); await redis.xAdd(pubsubChannel, "*", { @@ -218,6 +221,7 @@ export async function retryAgentMessageWithPubSub( case "agent_action_success": case "generation_tokens": case "agent_generation_success": + case "agent_generation_cancelled": case "agent_message_success": { const pubsubChannel = getMessageChannelId(event.messageId); await redis.xAdd(pubsubChannel, "*", { @@ -226,6 +230,7 @@ export async function retryAgentMessageWithPubSub( await redis.expire(pubsubChannel, 60 * 10); break; } + default: ((event: never) => { logger.error("Unknown event type", event); @@ -296,6 +301,16 @@ export async function* getConversationEvents( } } +export async function cancelMessageGenerationEvent( + messageIds: string[] +): Promise { + const redis = await redisClient(); + messageIds.forEach(async (messageId) => { + await redis.set(`assistant:generation:cancelled:${messageId}`, 1); + }); + await redis.quit(); +} + export async function* getMessagesEvents( messageId: string, lastEventId: string | null @@ -306,6 +321,7 @@ export async function* getMessagesEvents( | AgentErrorEvent | AgentActionEvent | AgentActionSuccessEvent + | AgentGenerationCancelledEvent | GenerationTokensEvent | AgentGenerationSuccessEvent; }, diff --git a/front/pages/api/w/[wId]/assistant/conversations/[cId]/cancel.ts b/front/pages/api/w/[wId]/assistant/conversations/[cId]/cancel.ts new file mode 100644 index 000000000000..5629aae97b97 --- /dev/null +++ b/front/pages/api/w/[wId]/assistant/conversations/[cId]/cancel.ts @@ -0,0 +1,111 @@ +import { isLeft } from "fp-ts/lib/Either"; +import * as t from "io-ts"; +import * as reporter from "io-ts-reporters"; +import { NextApiRequest, NextApiResponse } from "next"; + +import { getConversation } from "@app/lib/api/assistant/conversation"; +import { cancelMessageGenerationEvent } from "@app/lib/api/assistant/pubsub"; +import { Authenticator, getSession } from "@app/lib/auth"; +import { ReturnedAPIErrorType } from "@app/lib/error"; +import { apiError, withLogging } from "@app/logger/withlogging"; + +export type PostMessageEventResponseBody = { + success: true; +}; +const PostMessageEventBodySchema = t.type({ + action: t.literal("cancel"), + messageIds: t.array(t.string), +}); + +async function handler( + req: NextApiRequest, + res: NextApiResponse +): Promise { + const session = await getSession(req, res); + const auth = await Authenticator.fromSession( + session, + req.query.wId as string + ); + + const owner = auth.workspace(); + if (!owner) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "workspace_not_found", + message: "The workspace you're trying to modify was not found.", + }, + }); + } + + if (!auth.user()) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "workspace_user_not_found", + message: "Could not find the user of the current session.", + }, + }); + } + + if (!auth.isUser()) { + return apiError(req, res, { + status_code: 403, + api_error: { + type: "workspace_auth_error", + message: + "Only users of the current workspace can access conversations.", + }, + }); + } + if (!(typeof req.query.cId === "string")) { + return apiError(req, res, { + status_code: 400, + api_error: { + type: "invalid_request_error", + message: "Invalid query parameters, `cId` (string) is required.", + }, + }); + } + const conversationId = req.query.cId; + const conversation = await getConversation(auth, conversationId); + + if (!conversation) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "conversation_not_found", + message: "The conversation you're trying to access was not found.", + }, + }); + } + + switch (req.method) { + case "POST": + const bodyValidation = PostMessageEventBodySchema.decode(req.body); + if (isLeft(bodyValidation)) { + const pathError = reporter.formatValidationErrors(bodyValidation.left); + + return apiError(req, res, { + status_code: 400, + api_error: { + type: "invalid_request_error", + message: `Invalid request body: ${pathError}`, + }, + }); + } + await cancelMessageGenerationEvent(bodyValidation.right.messageIds); + return res.status(200).json({ success: true }); + + default: + return apiError(req, res, { + status_code: 405, + api_error: { + type: "method_not_supported_error", + message: "The method passed is not supported, POST is expected.", + }, + }); + } +} + +export default withLogging(handler, true); diff --git a/front/pages/api/w/[wId]/assistant/conversations/[cId]/events.ts b/front/pages/api/w/[wId]/assistant/conversations/[cId]/events.ts index 77ade1c47e51..a2d72b57086a 100644 --- a/front/pages/api/w/[wId]/assistant/conversations/[cId]/events.ts +++ b/front/pages/api/w/[wId]/assistant/conversations/[cId]/events.ts @@ -103,7 +103,6 @@ async function handler( res.status(200).end(); return; - default: return apiError(req, res, { status_code: 405, diff --git a/front/pages/w/[wId]/assistant/[cId]/index.tsx b/front/pages/w/[wId]/assistant/[cId]/index.tsx index 91e154b79118..a0070084e7b6 100644 --- a/front/pages/w/[wId]/assistant/[cId]/index.tsx +++ b/front/pages/w/[wId]/assistant/[cId]/index.tsx @@ -4,6 +4,7 @@ import { useEffect, useState } from "react"; import Conversation from "@app/components/assistant/conversation/Conversation"; import { ConversationTitle } from "@app/components/assistant/conversation/ConversationTitle"; +import { GenerationContextProvider } from "@app/components/assistant/conversation/GenerationContextProvider"; import { FixedAssistantInputBar } from "@app/components/assistant/conversation/InputBar"; import { AssistantSidebarMenu } from "@app/components/assistant/conversation/SidebarMenu"; import AppLayout from "@app/components/sparkle/AppLayout"; @@ -118,37 +119,40 @@ export default function AssistantConversation({ }; return ( - + { + void handdleDeleteConversation(); + }} + /> + } + navChildren={ + + } + > + { - void handdleDeleteConversation(); - }} + onStickyMentionsChange={setStickyMentions} /> - } - navChildren={ - - } - > - - - + + + ); } diff --git a/front/pages/w/[wId]/assistant/new.tsx b/front/pages/w/[wId]/assistant/new.tsx index 9fd17ffe9f70..f3d4dd21be38 100644 --- a/front/pages/w/[wId]/assistant/new.tsx +++ b/front/pages/w/[wId]/assistant/new.tsx @@ -16,6 +16,7 @@ import { useEffect, useState } from "react"; import Conversation from "@app/components/assistant/conversation/Conversation"; import { ConversationTitle } from "@app/components/assistant/conversation/ConversationTitle"; +import { GenerationContextProvider } from "@app/components/assistant/conversation/GenerationContextProvider"; import { FixedAssistantInputBar, InputBarContext, @@ -151,208 +152,214 @@ export default function AssistantNew({ return ( - + + ) + } + navChildren={ + - ) - } - navChildren={ - - } - > - {!conversation ? ( -
- - - {/* FEATURED AGENTS */} - - - - {isBuilder && ( - <> - - Dust comes with multiple assistants, each with a - specific set of skills. -
- Create assistants tailored for your needs. -
- - )} - {!isBuilder && ( - <> - - Dust is not just a single assistant, it’s a full team at - your service. -
- Each member has a set of specific set skills. -
- - Meet some of your assistants team: - - - )} -
-
-
- {displayedAgents.map((agent) => ( - { - void handleSubmit( - `Hi :mention[${agent.name}]{sId=${agent.sId}}, what can you help me with?`, - [ - { - configurationId: agent.sId, - }, - ] - ); - }} - > - - - ))} + } + > + {!conversation ? ( +
+ + + {/* FEATURED AGENTS */} + + + + {isBuilder && ( + <> + + Dust comes with multiple assistants, each with a + specific set of skills. +
+ Create assistants tailored for your needs. +
+ + )} + {!isBuilder && ( + <> + + Dust is not just a single assistant, it’s a full team + at your service. +
+ Each member has a set of specific set skills. +
+ + Meet some of your assistants team: + + + )} +
+ -
- - {activeAgents.length > 4 && ( -
- ) : ( - + ) : ( + + )} + + - )} - - - + + ); } diff --git a/front/types/assistant/conversation.ts b/front/types/assistant/conversation.ts index f6d63235c916..2fc4590ff5ad 100644 --- a/front/types/assistant/conversation.ts +++ b/front/types/assistant/conversation.ts @@ -83,7 +83,11 @@ export function isUserMessageType( */ export type AgentActionType = RetrievalActionType | DustAppRunActionType; -export type AgentMessageStatus = "created" | "succeeded" | "failed"; +export type AgentMessageStatus = + | "created" + | "succeeded" + | "failed" + | "cancelled"; /** * Both `action` and `message` are optional (we could have a no-op agent basically).