From 631e88ceec5edba6feabbf3078d93d0ff7dd603f Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 14 Oct 2024 16:10:43 -0600 Subject: [PATCH] [Security Assistant] Fix error handling on new chat (#195507) (cherry picked from commit a15940d9b939dbf29f74dbde28a2a543b8849cc1) --- .../server/language_models/chat_openai.ts | 8 +++- .../chat_vertex/chat_vertex.ts | 7 +++- .../language_models/chat_vertex/connection.ts | 6 ++- .../server/language_models/gemini_chat.ts | 12 +++++- .../server/language_models/llm.ts | 6 ++- .../language_models/simple_chat_model.ts | 12 +++++- .../nodes/generate_chat_title.ts | 39 ++++++++++++------- .../e2e/ai_assistant/conversations.cy.ts | 11 +++--- .../cypress/tasks/assistant.ts | 9 +++-- 9 files changed, 77 insertions(+), 33 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts index c20de3be57e07..f679193c23f92 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts @@ -147,7 +147,13 @@ export class ActionsClientChatOpenAI extends ChatOpenAI { const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { - throw new Error(`${LLM_TYPE}: ${actionResult?.message} - ${actionResult?.serviceMessage}`); + const error = new Error( + `${LLM_TYPE}: ${actionResult?.message} - ${actionResult?.serviceMessage}` + ); + if (actionResult?.serviceMessage) { + error.name = actionResult?.serviceMessage; + } + throw error; } if (!this.streaming) { diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts index 5627abe717291..745c273c79583 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts @@ -98,11 +98,14 @@ export class ActionsClientChatVertexAI extends ChatVertexAI { }; const actionResult = await this.#actionsClient.execute(requestBody); - if (actionResult.status === 'error') { - throw new Error( + const error = new Error( `ActionsClientChatVertexAI: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` ); + if (actionResult?.serviceMessage) { + error.name = actionResult?.serviceMessage; + } + throw error; } const readable = get('data', actionResult) as Readable; diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts index dd3c1e1abdda0..8ce776890acfa 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts @@ -93,9 +93,13 @@ export class ActionsClientChatConnection extends ChatConnection { }; if (actionResult.status === 'error') { - throw new Error( + const error = new Error( `ActionsClientChatVertexAI: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` ); + if (actionResult?.serviceMessage) { + error.name = actionResult?.serviceMessage; + } + throw error; } if (actionResult.data.candidates && actionResult.data.candidates.length > 0) { diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts index 197360c2f06e6..700e26d5a0a14 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts @@ -87,9 +87,13 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { }; if (actionResult.status === 'error') { - throw new Error( + const error = new Error( `ActionsClientGeminiChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` ); + if (actionResult?.serviceMessage) { + error.name = actionResult?.serviceMessage; + } + throw error; } if (actionResult.data.candidates && actionResult.data.candidates.length > 0) { @@ -162,9 +166,13 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { - throw new Error( + const error = new Error( `ActionsClientGeminiChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` ); + if (actionResult?.serviceMessage) { + error.name = actionResult?.serviceMessage; + } + throw error; } const readable = get('data', actionResult) as Readable; diff --git a/x-pack/packages/kbn-langchain/server/language_models/llm.ts b/x-pack/packages/kbn-langchain/server/language_models/llm.ts index 8ebf62e8c31f0..2a634ccb490cf 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/llm.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/llm.ts @@ -108,9 +108,13 @@ export class ActionsClientLlm extends LLM { const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { - throw new Error( + const error = new Error( `${LLM_TYPE}: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` ); + if (actionResult?.serviceMessage) { + error.name = actionResult?.serviceMessage; + } + throw error; } const content = get('data.message', actionResult); diff --git a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts index 5133b1ae6543a..a66d088345b22 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts @@ -127,9 +127,13 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel { const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { - throw new Error( + const error = new Error( `ActionsClientSimpleChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` ); + if (actionResult?.serviceMessage) { + error.name = actionResult?.serviceMessage; + } + throw error; } if (!this.streaming) { @@ -217,9 +221,13 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel { const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { - throw new Error( + const error = new Error( `ActionsClientSimpleChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` ); + if (actionResult?.serviceMessage) { + error.name = actionResult?.serviceMessage; + } + throw error; } const readable = get('data', actionResult) as Readable; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts index 47a36ddf844b0..b01f9d3fabe9f 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts @@ -58,22 +58,31 @@ export async function generateChatTitle({ state, model, }: GenerateChatTitleParams): Promise> { - logger.debug( - () => `${NodeType.GENERATE_CHAT_TITLE}: Node state:\n${JSON.stringify(state, null, 2)}` - ); + try { + logger.debug( + () => `${NodeType.GENERATE_CHAT_TITLE}: Node state:\n${JSON.stringify(state, null, 2)}` + ); - const outputParser = new StringOutputParser(); - const graph = GENERATE_CHAT_TITLE_PROMPT(state.responseLanguage, state.llmType) - .pipe(model) - .pipe(outputParser); + const outputParser = new StringOutputParser(); + const graph = GENERATE_CHAT_TITLE_PROMPT(state.responseLanguage, state.llmType) + .pipe(model) + .pipe(outputParser); - const chatTitle = await graph.invoke({ - input: JSON.stringify(state.input, null, 2), - }); - logger.debug(`chatTitle: ${chatTitle}`); + const chatTitle = await graph.invoke({ + input: JSON.stringify(state.input, null, 2), + }); + logger.debug(`chatTitle: ${chatTitle}`); - return { - chatTitle, - lastNode: NodeType.GENERATE_CHAT_TITLE, - }; + return { + chatTitle, + lastNode: NodeType.GENERATE_CHAT_TITLE, + }; + } catch (e) { + return { + // generate a chat title if there is an error in order to complete the graph + // limit title to 60 characters + chatTitle: (e.name ?? e.message ?? e.toString()).slice(0, 60), + lastNode: NodeType.GENERATE_CHAT_TITLE, + }; + } } diff --git a/x-pack/test/security_solution_cypress/cypress/e2e/ai_assistant/conversations.cy.ts b/x-pack/test/security_solution_cypress/cypress/e2e/ai_assistant/conversations.cy.ts index 2b277e73cf24a..c91ee7de475e3 100644 --- a/x-pack/test/security_solution_cypress/cypress/e2e/ai_assistant/conversations.cy.ts +++ b/x-pack/test/security_solution_cypress/cypress/e2e/ai_assistant/conversations.cy.ts @@ -19,6 +19,7 @@ import { createNewChat, selectConversation, assertMessageSent, + assertConversationTitle, typeAndSendMessage, assertErrorResponse, selectRule, @@ -145,18 +146,16 @@ describe('AI Assistant Conversations', { tags: ['@ess', '@serverless'] }, () => assertConnectorSelected(bedrockConnectorAPIPayload.name); assertMessageSent('goodbye'); }); - // This test is flakey due to the issue linked below and will be skipped until it is fixed - it.skip('Only allows one conversation called "New chat" at a time', () => { + it('Correctly titles new conversations, and only allows one conversation called "New chat" at a time', () => { visitGetStartedPage(); openAssistant(); createNewChat(); assertNewConversation(false, 'New chat'); assertConnectorSelected(azureConnectorAPIPayload.name); typeAndSendMessage('hello'); - // TODO fix bug with new chat and error message - // https://github.com/elastic/kibana/issues/191025 - // assertMessageSent('hello'); - assertErrorResponse(); + assertMessageSent('hello'); + assertConversationTitle('Unexpected API Error: - Connection error.'); + updateConversationTitle('New chat'); selectConversation('Welcome'); createNewChat(); assertErrorToastShown('Error creating conversation with title New chat'); diff --git a/x-pack/test/security_solution_cypress/cypress/tasks/assistant.ts b/x-pack/test/security_solution_cypress/cypress/tasks/assistant.ts index 8a3bd3600591c..81491abd85f81 100644 --- a/x-pack/test/security_solution_cypress/cypress/tasks/assistant.ts +++ b/x-pack/test/security_solution_cypress/cypress/tasks/assistant.ts @@ -86,7 +86,7 @@ export const resetConversation = () => { export const selectConversation = (conversationName: string) => { cy.get(FLYOUT_NAV_TOGGLE).click(); cy.get(CONVERSATION_SELECT(conversationName)).click(); - cy.get(CONVERSATION_TITLE + ' h2').should('have.text', conversationName); + assertConversationTitle(conversationName); cy.get(FLYOUT_NAV_TOGGLE).click(); }; @@ -95,7 +95,7 @@ export const updateConversationTitle = (newTitle: string) => { cy.get(CONVERSATION_TITLE + ' input').clear(); cy.get(CONVERSATION_TITLE + ' input').type(newTitle); cy.get(CONVERSATION_TITLE + ' input').type('{enter}'); - cy.get(CONVERSATION_TITLE + ' h2').should('have.text', newTitle); + assertConversationTitle(newTitle); }; export const typeAndSendMessage = (message: string) => { @@ -171,9 +171,12 @@ export const assertNewConversation = (isWelcome: boolean, title: string) => { } else { cy.get(EMPTY_CONVO).should('be.visible'); } - cy.get(CONVERSATION_TITLE + ' h2').should('have.text', title); + assertConversationTitle(title); }; +export const assertConversationTitle = (title: string) => + cy.get(CONVERSATION_TITLE + ' h2').should('have.text', title); + export const assertSystemPromptSent = (message: string) => { cy.get(CONVERSATION_MESSAGE).eq(0).should('contain', message); };