From 17624a4f1c498a839560fafed861e6f4489675f5 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Thu, 5 Dec 2024 14:27:29 -0700 Subject: [PATCH] [Security Assistant] Abort signal fix (#203041) (cherry picked from commit b3b2c1745ac846b377d8e234526aa965196e4de9) --- .../chat_vertex/chat_vertex.ts | 1 + .../language_models/chat_vertex/connection.ts | 1 + .../graphs/default_assistant_graph/graph.ts | 17 ++++++++-- .../graphs/default_assistant_graph/helpers.ts | 5 +-- .../graphs/default_assistant_graph/index.ts | 1 + .../default_assistant_graph/nodes/respond.ts | 5 ++- .../nodes/run_agent.ts | 31 ++++++++++--------- 7 files changed, 38 insertions(+), 23 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts index 7cea2d421a9da..5c7a9ef918da3 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts @@ -93,6 +93,7 @@ export class ActionsClientChatVertexAI extends ChatVertexAI { tools: data?.tools, temperature: this.temperature, ...systemInstruction, + signal: options?.signal, }, }, }; diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts index 8ce776890acfa..442e6b079db9b 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts @@ -82,6 +82,7 @@ export class ActionsClientChatConnection extends ChatConnection { tools: data?.tools, temperature: this.temperature, ...systemInstruction, + signal: options?.signal, }, }, }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 4688caa176b56..10ecebb5e3f9b 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -34,6 +34,7 @@ export interface GetDefaultAssistantGraphParams { dataClients?: AssistantDataClients; createLlmInstance: () => BaseChatModel; logger: Logger; + signal?: AbortSignal; tools: StructuredTool[]; replacements: Replacements; } @@ -45,6 +46,8 @@ export const getDefaultAssistantGraph = ({ dataClients, createLlmInstance, logger, + // some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model + signal, tools, replacements, }: GetDefaultAssistantGraphParams) => { @@ -137,11 +140,19 @@ export const getDefaultAssistantGraph = ({ }) ) .addNode(NodeType.AGENT, (state: AgentState) => - runAgent({ ...nodeParams, state, agentRunnable, kbDataClient: dataClients?.kbDataClient }) + runAgent({ + ...nodeParams, + config: { signal }, + state, + agentRunnable, + kbDataClient: dataClients?.kbDataClient, + }) + ) + .addNode(NodeType.TOOLS, (state: AgentState) => + executeTools({ ...nodeParams, config: { signal }, state, tools }) ) - .addNode(NodeType.TOOLS, (state: AgentState) => executeTools({ ...nodeParams, state, tools })) .addNode(NodeType.RESPOND, (state: AgentState) => - respond({ ...nodeParams, state, model: createLlmInstance() }) + respond({ ...nodeParams, config: { signal }, state, model: createLlmInstance() }) ) .addNode(NodeType.MODEL_INPUT, (state: AgentState) => modelInput({ ...nodeParams, state })) .addEdge(START, NodeType.MODEL_INPUT) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 0126692b5b6a5..a4b36dfa8dc22 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -160,10 +160,7 @@ export const streamGraph = async ({ finalMessage += msg.content; } } else if (event.event === 'on_llm_end' && !didEnd) { - const generations = event.data.output?.generations[0]; - if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { - handleStreamEnd(generations[0]?.text ?? finalMessage); - } + handleStreamEnd(event.data.output?.generations[0][0]?.text ?? finalMessage); } } } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 4ddd3eae11624..60c229b46e61c 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -173,6 +173,7 @@ export const callAssistantGraph: AgentExecutor = async ({ // we need to pass it like this or streaming does not work for bedrock createLlmInstance, logger, + signal: abortSignal, tools, replacements, }); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts index bfd62ee7aab21..76d449373488f 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts @@ -7,6 +7,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { StringWithAutocomplete } from '@langchain/core/dist/utils/types'; +import { RunnableConfig } from '@langchain/core/runnables'; import { AGENT_NODE_TAG } from './run_agent'; import { AgentState, NodeParamsBase } from '../types'; import { NodeType } from '../constants'; @@ -14,9 +15,11 @@ import { NodeType } from '../constants'; export interface RespondParams extends NodeParamsBase { state: AgentState; model: BaseChatModel; + config?: RunnableConfig; } export async function respond({ + config, logger, state, model, @@ -34,7 +37,7 @@ export async function respond({ const responseMessage = await model // use AGENT_NODE_TAG to identify as agent node for stream parsing - .withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] }) + .withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG], signal: config?.signal }) .invoke([userMessage]); return { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts index 053254a1d99b3..952b97287c3ca 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts @@ -43,21 +43,22 @@ export async function runAgent({ logger.debug(() => `${NodeType.AGENT}: Node state:\n${JSON.stringify(state, null, 2)}`); const knowledgeHistory = await kbDataClient?.getRequiredKnowledgeBaseDocumentEntries(); - - const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke( - { - ...state, - knowledge_history: `${KNOWLEDGE_HISTORY_PREFIX}\n${ - knowledgeHistory?.length - ? JSON.stringify(knowledgeHistory.map((e) => e.text)) - : NO_KNOWLEDGE_HISTORY - }`, - // prepend any user prompt (gemini) - input: formatLatestUserMessage(state.input, state.llmType), - chat_history: state.messages, // TODO: Message de-dupe with ...state spread - }, - config - ); + const agentOutcome = await agentRunnable + .withConfig({ tags: [AGENT_NODE_TAG], signal: config?.signal }) + .invoke( + { + ...state, + knowledge_history: `${KNOWLEDGE_HISTORY_PREFIX}\n${ + knowledgeHistory?.length + ? JSON.stringify(knowledgeHistory.map((e) => e.text)) + : NO_KNOWLEDGE_HISTORY + }`, + // prepend any user prompt (gemini) + input: formatLatestUserMessage(state.input, state.llmType), + chat_history: state.messages, // TODO: Message de-dupe with ...state spread + }, + config + ); return { agentOutcome,