Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle message during tool call #4692

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions gui/src/redux/slices/sessionSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ export const sessionSlice = createSlice({
}: PayloadAction<{
index: number;
editorState: JSONContent;
cancelsToolId: string | undefined;
}>,
) => {
const { index, editorState } = payload;
const { index, editorState, cancelsToolId } = payload;

if (state.history.length && index < state.history.length) {
// Resubmission - update input message, truncate history after resubmit with new empty response message
Expand All @@ -209,7 +210,7 @@ export const sessionSlice = createSlice({
state.curCheckpointIndex = Math.floor(index / 2);
} else {
// New input/response messages
state.history = state.history.concat([
const newMessages: ChatHistoryItemWithMessageId[] = [
{
message: {
id: uuidv4(),
Expand All @@ -227,7 +228,22 @@ export const sessionSlice = createSlice({
},
contextItems: [],
},
]);
];

if (cancelsToolId) {
newMessages.unshift({
message: {
role: "tool",
content:
"The user cancelled this tool call and is giving further instructions/feedback below.",
toolCallId: cancelsToolId,
id: uuidv4(),
},
contextItems: [],
});
}

state.history = state.history.concat(newMessages);

state.curCheckpointIndex = Math.floor((state.history.length - 1) / 2); // TODO this feels really fragile
}
Expand Down Expand Up @@ -294,7 +310,6 @@ export const sessionSlice = createSlice({
state.streamAborter = new AbortController();
},
streamUpdate: (state, action: PayloadAction<ChatMessage[]>) => {

if (state.history.length) {
function toolCallDeltaToState(
toolCallDelta: ToolCallDelta,
Expand Down Expand Up @@ -332,7 +347,7 @@ export const sessionSlice = createSlice({
id: uuidv4(),
},
contextItems: [],
})
});
continue;
}

Expand All @@ -345,7 +360,7 @@ export const sessionSlice = createSlice({
!(!lastMessage.toolCalls?.length && !lastMessage.content) &&
// And there's a difference in tool call presence
(lastMessage.toolCalls?.length ?? 0) !==
(message.toolCalls?.length ?? 0))
(message.toolCalls?.length ?? 0))
) {
// Create a new message
const historyItem: ChatHistoryItemWithMessageId = {
Expand Down Expand Up @@ -495,9 +510,9 @@ export const sessionSlice = createSlice({
state.allSessionMetadata = state.allSessionMetadata.map((session) =>
session.sessionId === payload.sessionId
? {
...session,
...payload,
}
...session,
...payload,
}
: session,
);
if (payload.title && payload.sessionId === state.id) {
Expand Down Expand Up @@ -532,8 +547,9 @@ export const sessionSlice = createSlice({
payload.rangeInFileWithContents.filepath,
);

const lineNums = `(${payload.rangeInFileWithContents.range.start.line + 1
}-${payload.rangeInFileWithContents.range.end.line + 1})`;
const lineNums = `(${
payload.rangeInFileWithContents.range.start.line + 1
}-${payload.rangeInFileWithContents.range.end.line + 1})`;

contextItems.push({
name: `${fileName} ${lineNums}`,
Expand Down Expand Up @@ -718,9 +734,9 @@ function addPassthroughCases(
) {
thunks.forEach((thunk) => {
builder
.addCase(thunk.fulfilled, (state, action) => { })
.addCase(thunk.rejected, (state, action) => { })
.addCase(thunk.pending, (state, action) => { });
.addCase(thunk.fulfilled, (state, action) => {})
.addCase(thunk.rejected, (state, action) => {})
.addCase(thunk.pending, (state, action) => {});
});
}

Expand Down
113 changes: 0 additions & 113 deletions gui/src/redux/thunks/gatherContext.ts

This file was deleted.

3 changes: 1 addition & 2 deletions gui/src/redux/thunks/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
export { exitEditMode } from "./exitEditMode";
export { gatherContext } from "./gatherContext";
export { streamThunkWrapper } from "./streamThunkWrapper";
export { resetStateForNewMessage } from "./resetStateForNewMessage";
export { streamNormalInput } from "./streamNormalInput";
export { streamResponseThunk } from "./streamResponse";
export { streamResponseAfterToolCall } from "./streamResponseAfterToolCall";
export { streamSlashCommand } from "./streamSlashCommand";
export { streamThunkWrapper } from "./streamThunkWrapper";
110 changes: 95 additions & 15 deletions gui/src/redux/thunks/streamResponse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@ import {
import { constructMessages } from "core/llm/constructMessages";
import { renderChatMessage } from "core/util/messageContent";
import posthog from "posthog-js";
import * as URI from "uri-js";
import { v4 as uuidv4 } from "uuid";
import resolveEditorContent from "../../components/mainInput/resolveInput";
import { selectDefaultModel } from "../slices/configSlice";
import {
cancelToolCall,
submitEditorAndInitAtIndex,
updateHistoryItemAtIndex,
} from "../slices/sessionSlice";
import { ThunkApiType } from "../store";
import { gatherContext } from "./gatherContext";
import { resetStateForNewMessage } from "./resetStateForNewMessage";
import { streamNormalInput } from "./streamNormalInput";
import { streamSlashCommand } from "./streamSlashCommand";
Expand Down Expand Up @@ -71,29 +73,107 @@ export const streamResponseThunk = createAsyncThunk<
await dispatch(
streamThunkWrapper(async () => {
const state = getState();
const useTools = state.ui.useTools;
const defaultModel = selectDefaultModel(state);
const slashCommands = state.config.config.slashCommands || [];
const inputIndex = index ?? state.session.history.length; // Either given index or concat to end

const defaultModel = selectDefaultModel(state);
if (!defaultModel) {
throw new Error("No chat model selected");
}

const useTools = state.ui.useTools;
const slashCommands = state.config.config.slashCommands || [];
const insertIndex = index ?? state.session.history.length; // Either given index or concat to end
let userMessageIndex = insertIndex;
const lastItem = state.session.history[insertIndex - 1];

let cancelsToolId: string | undefined = undefined;
if (
lastItem &&
lastItem.message.role === "assistant" &&
lastItem.message.toolCalls?.length &&
lastItem.toolCallState?.toolCallId
) {
cancelsToolId = lastItem.toolCallState.toolCallId;
userMessageIndex++;
dispatch(cancelToolCall());
}

dispatch(
submitEditorAndInitAtIndex({ index: inputIndex, editorState }),
submitEditorAndInitAtIndex({
index: insertIndex,
editorState,
cancelsToolId,
}),
);
resetStateForNewMessage();

const result = await dispatch(
gatherContext({
const defaultContextProviders =
state.config.config.experimental?.defaultContext ?? [];

// Resolve context providers and construct new history
let [selectedContextItems, selectedCode, content] =
await resolveEditorContent({
editorState,
modifiers,
promptPreamble,
}),
);
const unwrapped = unwrapResult(result);
const { selectedContextItems, selectedCode, content } = unwrapped;
ideMessenger: extra.ideMessenger,
defaultContextProviders,
dispatch,
selectedModelTitle: defaultModel.title,
});

// Automatically use currently open file
if (!modifiers.noContext) {
const usingFreeTrial = defaultModel.provider === "free-trial";

const currentFileResponse = await extra.ideMessenger.request(
"context/getContextItems",
{
name: "currentFile",
query: "non-mention-usage",
fullInput: "",
selectedCode: [],
selectedModelTitle: defaultModel.title,
},
);
if (currentFileResponse.status === "success") {
const items = currentFileResponse.content;
if (items.length > 0) {
const currentFile = items[0];
const uri = currentFile.uri?.value;

// don't add the file if it's already in the context items
if (
uri &&
!selectedContextItems.find(
(item) => item.uri?.value && URI.equal(item.uri.value, uri),
)
) {
// Limit to 1000 lines if using free trial
if (usingFreeTrial) {
currentFile.content = currentFile.content
.split("\n")
.slice(0, 1000)
.join("\n");
if (!currentFile.content.endsWith("```")) {
currentFile.content += "\n```";
}
}
currentFile.id = {
providerTitle: "file",
itemId: uri,
};
selectedContextItems.unshift(currentFile);
}
}
}
}

if (promptPreamble) {
if (typeof content === "string") {
content = promptPreamble + content;
} else if (content[0].type === "text") {
content[0].text = promptPreamble + content[0].text;
}
}

// symbols for both context items AND selected codeblocks
const filesForSymbols = [
Expand All @@ -106,7 +186,7 @@ export const streamResponseThunk = createAsyncThunk<

dispatch(
updateHistoryItemAtIndex({
index: inputIndex,
index: userMessageIndex,
updates: {
message: {
role: "user",
Expand Down Expand Up @@ -153,7 +233,7 @@ export const streamResponseThunk = createAsyncThunk<
messages,
slashCommand,
input: commandInput,
historyIndex: inputIndex,
historyIndex: userMessageIndex,
selectedCode,
contextItems: [],
}),
Expand Down
Loading