Skip to content

Commit

Permalink
Button to stop Agent Generation (#2018)
Browse files Browse the repository at this point in the history
* Button to stop Agent Generation

* Absolutely no need for that

* Don't await check on redis

* Stop generation for the whole convo

* Stop generation for conversation, put button in FixedInputBar

* Apply first feedback

* New /cancel route to post cancellation

* Keep provider top level component

* fix bad rebase

* Loading state for stop generation button
  • Loading branch information
PopDaph authored Oct 10, 2023
1 parent 30b5c0c commit 15ccc36
Show file tree
Hide file tree
Showing 13 changed files with 589 additions and 233 deletions.
34 changes: 33 additions & 1 deletion front/components/assistant/conversation/AgentMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ import {
EyeIcon,
Spinner,
} from "@dust-tt/sparkle";
import { useCallback, useEffect, useState } from "react";
import { useCallback, useContext, useEffect, useState } from "react";

import { AgentAction } from "@app/components/assistant/conversation/AgentAction";
import { ConversationMessage } from "@app/components/assistant/conversation/ConversationMessage";
import { GenerationContext } from "@app/components/assistant/conversation/GenerationContextProvider";
import { RenderMessageMarkdown } from "@app/components/assistant/RenderMessageMarkdown";
import { useEventSource } from "@app/hooks/useEventSource";
import {
AgentActionEvent,
AgentActionSuccessEvent,
AgentErrorEvent,
AgentGenerationCancelledEvent,
AgentGenerationSuccessEvent,
AgentMessageSuccessEvent,
} from "@app/lib/api/assistant/agent";
Expand Down Expand Up @@ -52,6 +54,7 @@ export function AgentMessage({
switch (streamedAgentMessage.status) {
case "succeeded":
case "failed":
case "cancelled":
return false;
case "created":
return true;
Expand Down Expand Up @@ -92,6 +95,7 @@ export function AgentMessage({
| AgentActionSuccessEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentGenerationCancelledEvent
| AgentMessageSuccessEvent;
} = JSON.parse(eventStr);

Expand All @@ -116,6 +120,13 @@ export function AgentMessage({
return { ...m, content: event.text };
});
break;

case "agent_generation_cancelled":
setStreamedAgentMessage((m) => {
return { ...m, status: "cancelled" };
});
break;

case "agent_message_success": {
setStreamedAgentMessage(event.message);
break;
Expand All @@ -141,6 +152,7 @@ export function AgentMessage({
switch (message.status) {
case "succeeded":
case "failed":
case "cancelled":
return message;
case "created":
return streamedAgentMessage;
Expand All @@ -167,6 +179,26 @@ export function AgentMessage({
}
}, [agentMessageToRender.content, agentMessageToRender.status]);

// GenerationContext: to know if we are generating or not
const generationContext = useContext(GenerationContext);
if (!generationContext) {
throw new Error(
"AgentMessage must be used within a GenerationContextProvider"
);
}
useEffect(() => {
const isInArray = generationContext.generatingMessageIds.includes(
message.sId
);
if (agentMessageToRender.status === "created" && !isInArray) {
generationContext.setGeneratingMessageIds((s) => [...s, message.sId]);
} else if (agentMessageToRender.status !== "created" && isInArray) {
generationContext.setGeneratingMessageIds((s) =>
s.filter((id) => id !== message.sId)
);
}
}, [agentMessageToRender.status, generationContext, message.sId]);

const buttons =
message.status === "failed"
? []
Expand Down
3 changes: 3 additions & 0 deletions front/components/assistant/conversation/Conversation.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { useCallback, useEffect, useRef } from "react";
import { AgentMessage } from "@app/components/assistant/conversation/AgentMessage";
import { UserMessage } from "@app/components/assistant/conversation/UserMessage";
import { useEventSource } from "@app/hooks/useEventSource";
import { AgentGenerationCancelledEvent } from "@app/lib/api/assistant/agent";
import {
AgentMessageNewEvent,
ConversationTitleEvent,
Expand Down Expand Up @@ -115,6 +116,7 @@ export default function Conversation({
data:
| UserMessageNewEvent
| AgentMessageNewEvent
| AgentGenerationCancelledEvent
| ConversationTitleEvent;
} = JSON.parse(eventStr);

Expand All @@ -125,6 +127,7 @@ export default function Conversation({
switch (event.type) {
case "user_message_new":
case "agent_message_new":
case "agent_generation_cancelled":
void mutateConversation();
break;
case "conversation_title": {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { createContext, useState } from "react";

type GenerationContextType = {
generatingMessageIds: string[];
setGeneratingMessageIds: React.Dispatch<React.SetStateAction<string[]>>;
};

export const GenerationContext = createContext<
GenerationContextType | undefined
>(undefined);

export const GenerationContextProvider = ({
children,
}: {
children: React.ReactNode;
}) => {
const [generatingMessageIds, setGeneratingMessageIds] = useState<string[]>(
[]
);
return (
<GenerationContext.Provider
value={{
generatingMessageIds: generatingMessageIds,
setGeneratingMessageIds: setGeneratingMessageIds,
}}
>
{children}
</GenerationContext.Provider>
);
};
60 changes: 59 additions & 1 deletion front/components/assistant/conversation/InputBar.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import { Avatar, IconButton, PaperAirplaneIcon } from "@dust-tt/sparkle";
import {
Avatar,
Button,
IconButton,
PaperAirplaneIcon,
StopIcon,
} from "@dust-tt/sparkle";
import { Transition } from "@headlessui/react";
import {
createContext,
Expand All @@ -14,6 +20,7 @@ import {
import * as ReactDOMServer from "react-dom/server";

import { AssistantPicker } from "@app/components/assistant/AssistantPicker";
import { GenerationContext } from "@app/components/assistant/conversation/GenerationContextProvider";
import { compareAgentsForSort } from "@app/lib/assistant";
import { useAgentConfigurations } from "@app/lib/swr";
import { classNames } from "@app/lib/utils";
Expand Down Expand Up @@ -753,13 +760,64 @@ export function FixedAssistantInputBar({
owner,
onSubmit,
stickyMentions,
conversationId,
}: {
owner: WorkspaceType;
onSubmit: (input: string, mentions: MentionType[]) => void;
stickyMentions?: AgentMention[];
conversationId: string | null;
}) {
const [isProcessing, setIsProcessing] = useState<boolean>(false);

// GenerationContext: to know if we are generating or not
const generationContext = useContext(GenerationContext);
if (!generationContext) {
throw new Error(
"FixedAssistantInputBar must be used within a GenerationContextProvider"
);
}

const handleStopGeneration = async () => {
if (!conversationId) {
return;
}
setIsProcessing(true); // we don't set it back to false immediately cause it takes a bit of time to cancel
await fetch(
`/api/w/${owner.sId}/assistant/conversations/${conversationId}/cancel`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
action: "cancel",
messageIds: generationContext.generatingMessageIds,
}),
}
);
};

useEffect(() => {
if (isProcessing && generationContext.generatingMessageIds.length === 0) {
setIsProcessing(false);
}
}, [isProcessing, generationContext.generatingMessageIds.length]);

return (
<div className="4xl:px-0 fixed bottom-0 left-0 right-0 z-20 flex-initial px-2 lg:left-80">
{generationContext.generatingMessageIds.length > 0 && (
<div className="flex justify-center pb-4">
<Button
className="mt-4"
variant="tertiary"
label={isProcessing ? "Stopping generation..." : "Stop generation"}
icon={StopIcon}
onClick={handleStopGeneration}
disabled={isProcessing}
/>
</div>
)}

<div className="mx-auto max-w-4xl pb-8">
<AssistantInputBar
owner={owner}
Expand Down
18 changes: 18 additions & 0 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ export type AgentGenerationSuccessEvent = {
text: string;
};

// Event sent to stop the generation.
export type AgentGenerationCancelledEvent = {
type: "agent_generation_cancelled";
created: number;
configurationId: string;
messageId: string;
};

// Event sent once the message is completed and successful.
export type AgentMessageSuccessEvent = {
type: "agent_message_success";
Expand All @@ -211,6 +219,7 @@ export async function* runAgent(
| AgentActionSuccessEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentGenerationCancelledEvent
| AgentMessageSuccessEvent,
void
> {
Expand Down Expand Up @@ -349,6 +358,15 @@ export async function* runAgent(
};
return;

case "generation_cancel":
yield {
type: "agent_generation_cancelled",
created: event.created,
configurationId: configuration.sId,
messageId: agentMessage.sId,
};
return;

case "generation_success":
yield {
type: "agent_generation_success",
Expand Down
20 changes: 20 additions & 0 deletions front/lib/api/assistant/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
AgentActionEvent,
AgentActionSuccessEvent,
AgentErrorEvent,
AgentGenerationCancelledEvent,
AgentGenerationSuccessEvent,
AgentMessageSuccessEvent,
runAgent,
Expand Down Expand Up @@ -649,6 +650,7 @@ export async function* postUserMessage(
| AgentActionSuccessEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentGenerationCancelledEvent
| AgentMessageSuccessEvent
| ConversationTitleEvent,
void
Expand Down Expand Up @@ -981,6 +983,7 @@ export async function* editUserMessage(
| AgentActionSuccessEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentGenerationCancelledEvent
| AgentMessageSuccessEvent,
void
> {
Expand Down Expand Up @@ -1324,6 +1327,7 @@ export async function* retryAgentMessage(
| AgentActionSuccessEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentGenerationCancelledEvent
| AgentMessageSuccessEvent,
void
> {
Expand Down Expand Up @@ -1522,6 +1526,7 @@ async function* streamRunAgentEvents(
| AgentActionSuccessEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentGenerationCancelledEvent
| AgentMessageSuccessEvent,
void
>,
Expand All @@ -1533,9 +1538,11 @@ async function* streamRunAgentEvents(
| AgentActionSuccessEvent
| GenerationTokensEvent
| AgentGenerationSuccessEvent
| AgentGenerationCancelledEvent
| AgentMessageSuccessEvent,
void
> {
let content = "";
for await (const event of eventStream) {
switch (event.type) {
case "agent_error":
Expand Down Expand Up @@ -1588,12 +1595,25 @@ async function* streamRunAgentEvents(
yield event;
break;

case "agent_generation_cancelled":
if (agentMessageRow.status !== "cancelled") {
await agentMessageRow.update({
status: "cancelled",
content: content,
});
yield event;
}
break;

// All other events that won't impact the database and are related to actions or tokens
// generation.
case "retrieval_params":
case "dust_app_run_params":
case "dust_app_run_block":
yield event;
break;
case "generation_tokens":
content += event.text;
yield event;
break;

Expand Down
Loading

0 comments on commit 15ccc36

Please sign in to comment.