diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/executor.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/executor.test.ts deleted file mode 100644 index a01ac3d126e59..0000000000000 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/executor.test.ts +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { executeAction, Props } from './executor'; -import { PassThrough } from 'stream'; -import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; -import { loggerMock } from '@kbn/logging-mocks'; -import * as ParseStream from './parse_stream'; - -const onLlmResponse = jest.fn(async () => {}); // We need it to be a promise, or it'll crash because of missing `.catch` -const connectorId = 'testConnectorId'; -const mockLogger = loggerMock.create(); -const testProps: Omit = { - params: { - subAction: 'invokeAI', - subActionParams: { messages: [{ content: 'hello', role: 'user' }] }, - }, - actionTypeId: '.bedrock', - connectorId, - actionsClient: actionsClientMock.create(), - onLlmResponse, - logger: mockLogger, -}; - -const handleStreamStorageSpy = jest.spyOn(ParseStream, 'handleStreamStorage'); - -describe('executeAction', () => { - beforeEach(() => { - jest.clearAllMocks(); - }); - it('should execute an action and return a StaticResponse when the response from the actions framework is a string', async () => { - testProps.actionsClient.execute = jest.fn().mockResolvedValue({ - data: { - message: 'Test message', - }, - }); - - const result = await executeAction({ ...testProps }); - - expect(result).toEqual({ - connector_id: connectorId, - data: 'Test message', - status: 'ok', - }); - expect(onLlmResponse).toHaveBeenCalledWith('Test message'); - }); - - it('should execute an action and return a Readable object when the response from the actions framework is a stream', async () => { - const readableStream = new PassThrough(); - const actionsClient = actionsClientMock.create(); - actionsClient.execute.mockImplementationOnce( - jest.fn().mockResolvedValue({ - status: 'ok', - data: readableStream, - }) - ); - - const result = await executeAction({ ...testProps, actionsClient }); - - expect(JSON.stringify(result)).toStrictEqual( - JSON.stringify(readableStream.pipe(new PassThrough())) - ); - - expect(handleStreamStorageSpy).toHaveBeenCalledWith({ - actionTypeId: '.bedrock', - onMessageSent: onLlmResponse, - logger: mockLogger, - responseStream: readableStream, - }); - }); - - it('should throw an error if the actions client fails to execute the action', async () => { - const actionsClient = actionsClientMock.create(); - actionsClient.execute.mockRejectedValue(new Error('Failed to execute action')); - testProps.actionsClient = actionsClient; - - await expect(executeAction({ ...testProps, actionsClient })).rejects.toThrowError( - 'Failed to execute action' - ); - }); - - it('should throw an error when the response from the actions framework is null or undefined', async () => { - const actionsClient = actionsClientMock.create(); - actionsClient.execute.mockImplementationOnce( - jest.fn().mockResolvedValue({ - data: null, - }) - ); - testProps.actionsClient = actionsClient; - - try { - await executeAction({ ...testProps, actionsClient }); - } catch (e) { - expect(e.message).toBe('Action result status is error: result is not streamable'); - } - }); - - it('should throw an error if action result status is "error"', async () => { - const actionsClient = actionsClientMock.create(); - actionsClient.execute.mockImplementationOnce( - jest.fn().mockResolvedValue({ - status: 'error', - message: 'Error message', - serviceMessage: 'Service error message', - }) - ); - testProps.actionsClient = actionsClient; - - await expect( - executeAction({ - ...testProps, - actionsClient, - connectorId: '12345', - }) - ).rejects.toThrowError('Action result status is error: Error message - Service error message'); - }); - - it('should throw an error if content of response data is not a string or streamable', async () => { - const actionsClient = actionsClientMock.create(); - actionsClient.execute.mockImplementationOnce( - jest.fn().mockResolvedValue({ - status: 'ok', - data: { - message: 12345, - }, - }) - ); - testProps.actionsClient = actionsClient; - - await expect( - executeAction({ - ...testProps, - - actionsClient, - connectorId: '12345', - }) - ).rejects.toThrowError('Action result status is error: result is not streamable'); - }); -}); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/executor.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/executor.ts deleted file mode 100644 index bd25a77808dbe..0000000000000 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/executor.ts +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { get } from 'lodash/fp'; -import { ActionsClient } from '@kbn/actions-plugin/server'; -import { PassThrough, Readable } from 'stream'; -import { Logger } from '@kbn/core/server'; -import { PublicMethodsOf } from '@kbn/utility-types'; -import { handleStreamStorage } from './parse_stream'; - -export interface Props { - onLlmResponse?: (content: string) => Promise; - abortSignal?: AbortSignal; - actionsClient: PublicMethodsOf; - connectorId: string; - params: InvokeAIActionsParams; - actionTypeId: string; - logger: Logger; -} -export interface StaticResponse { - connector_id: string; - data: string; - status: string; -} - -interface InvokeAIActionsParams { - subActionParams: { - messages: Array<{ role: string; content: string }>; - model?: string; - n?: number; - stop?: string | string[] | null; - stopSequences?: string[]; - temperature?: number; - }; - subAction: 'invokeAI' | 'invokeStream'; -} - -export const executeAction = async ({ - onLlmResponse, - actionsClient, - params, - connectorId, - actionTypeId, - logger, - abortSignal, -}: Props): Promise => { - const actionResult = await actionsClient.execute({ - actionId: connectorId, - params: { - subAction: params.subAction, - subActionParams: { - ...params.subActionParams, - signal: abortSignal, - }, - }, - }); - - if (actionResult.status === 'error') { - throw new Error( - `Action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` - ); - } - const content = get('data.message', actionResult); - if (typeof content === 'string') { - if (onLlmResponse) { - await onLlmResponse(content); - } - return { - connector_id: connectorId, - data: content, // the response from the actions framework - status: 'ok', - }; - } - - const readable = get('data', actionResult) as Readable; - if (typeof readable?.read !== 'function') { - throw new Error('Action result status is error: result is not streamable'); - } - - // do not await, blocks stream for UI - handleStreamStorage({ - actionTypeId, - onMessageSent: onLlmResponse, - logger, - responseStream: readable, - abortSignal, - }).catch(() => {}); - - return readable.pipe(new PassThrough()); -}; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts index 32f2b808b41a1..f8f84b8c2cc0a 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts @@ -74,35 +74,24 @@ describe('streamGraph', () => { describe('OpenAI Function Agent streaming', () => { it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => { mockStreamEvents.mockReturnValue({ - next: jest - .fn() - .mockResolvedValueOnce({ - value: { - name: 'ActionsClientChatOpenAI', - event: 'on_llm_stream', - data: { chunk: { message: { content: 'content' } } }, - tags: [AGENT_NODE_TAG], - }, - done: false, - }) - .mockResolvedValueOnce({ - value: { - name: 'ActionsClientChatOpenAI', - event: 'on_llm_end', - data: { - output: { - generations: [ - [{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }], - ], - }, + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: { + output: { + generations: [ + [{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }], + ], }, - tags: [AGENT_NODE_TAG], }, - }) - .mockResolvedValue({ - done: true, - }), - return: jest.fn(), + tags: [AGENT_NODE_TAG], + }; + }, }); const response = await streamGraph(requestArgs); @@ -119,33 +108,22 @@ describe('streamGraph', () => { }); it('on_llm_end events with finish_reason != stop should not end the stream', async () => { mockStreamEvents.mockReturnValue({ - next: jest - .fn() - .mockResolvedValueOnce({ - value: { - name: 'ActionsClientChatOpenAI', - event: 'on_llm_stream', - data: { chunk: { message: { content: 'content' } } }, - tags: [AGENT_NODE_TAG], - }, - done: false, - }) - .mockResolvedValueOnce({ - value: { - name: 'ActionsClientChatOpenAI', - event: 'on_llm_end', - data: { - output: { - generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]], - }, + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: { + output: { + generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]], }, - tags: [AGENT_NODE_TAG], }, - }) - .mockResolvedValue({ - done: true, - }), - return: jest.fn(), + tags: [AGENT_NODE_TAG], + }; + }, }); const response = await streamGraph(requestArgs); @@ -158,33 +136,22 @@ describe('streamGraph', () => { }); it('on_llm_end events without a finish_reason should end the stream', async () => { mockStreamEvents.mockReturnValue({ - next: jest - .fn() - .mockResolvedValueOnce({ - value: { - name: 'ActionsClientChatOpenAI', - event: 'on_llm_stream', - data: { chunk: { message: { content: 'content' } } }, - tags: [AGENT_NODE_TAG], - }, - done: false, - }) - .mockResolvedValueOnce({ - value: { - name: 'ActionsClientChatOpenAI', - event: 'on_llm_end', - data: { - output: { - generations: [[{ generationInfo: {}, text: 'final message' }]], - }, + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: { + output: { + generations: [[{ generationInfo: {}, text: 'final message' }]], }, - tags: [AGENT_NODE_TAG], }, - }) - .mockResolvedValue({ - done: true, - }), - return: jest.fn(), + tags: [AGENT_NODE_TAG], + }; + }, }); const response = await streamGraph(requestArgs); @@ -201,33 +168,22 @@ describe('streamGraph', () => { }); it('on_llm_end events is called with chunks if there is no final text value', async () => { mockStreamEvents.mockReturnValue({ - next: jest - .fn() - .mockResolvedValueOnce({ - value: { - name: 'ActionsClientChatOpenAI', - event: 'on_llm_stream', - data: { chunk: { message: { content: 'content' } } }, - tags: [AGENT_NODE_TAG], - }, - done: false, - }) - .mockResolvedValueOnce({ - value: { - name: 'ActionsClientChatOpenAI', - event: 'on_llm_end', - data: { - output: { - generations: [[{ generationInfo: {}, text: '' }]], - }, + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: { + output: { + generations: [[{ generationInfo: {}, text: '' }]], }, - tags: [AGENT_NODE_TAG], }, - }) - .mockResolvedValue({ - done: true, - }), - return: jest.fn(), + tags: [AGENT_NODE_TAG], + }; + }, }); const response = await streamGraph(requestArgs); @@ -242,6 +198,28 @@ describe('streamGraph', () => { ); }); }); + it('on_llm_end does not call handleStreamEnd if generations is undefined', async () => { + mockStreamEvents.mockReturnValue({ + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: {}, + tags: [AGENT_NODE_TAG], + }; + }, + }); + + const response = await streamGraph(requestArgs); + + expect(response).toBe(mockResponseWithHeaders); + expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); + expect(mockOnLlmResponse).not.toHaveBeenCalled(); + }); }); describe('Tool Calling Agent and Structured Chat Agent streaming', () => { @@ -330,7 +308,7 @@ describe('streamGraph', () => { await expectConditions(response); }); - it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => { + it('should execute the graph in streaming mode - OpenAI + isOssModel = true', async () => { const mockAssistantGraphAsyncIterator = { streamEvents: () => mockAsyncIterator, } as unknown as DefaultAssistantGraph; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index f1a5413197632..73b7b43c2d036 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -9,7 +9,6 @@ import agent, { Span } from 'elastic-apm-node'; import type { Logger } from '@kbn/logging'; import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry'; import { streamFactory, StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; -import { transformError } from '@kbn/securitysolution-es-utils'; import type { KibanaRequest } from '@kbn/core-http-server'; import type { ExecuteConnectorRequestBody, TraceData } from '@kbn/elastic-assistant-common'; import { APMTracer } from '@kbn/langchain/server/tracers/apm'; @@ -126,7 +125,6 @@ export const streamGraph = async ({ // Stream is from openai functions agent let finalMessage = ''; - let conversationId: string | undefined; const stream = assistantGraph.streamEvents(inputs, { callbacks: [ apmTracer, @@ -139,63 +137,37 @@ export const streamGraph = async ({ version: 'v1', }); - const processEvent = async () => { - try { - const { value, done } = await stream.next(); - if (done) return; - - const event = value; - // only process events that are part of the agent run - if ((event.tags || []).includes(AGENT_NODE_TAG)) { - if (event.name === 'ActionsClientChatOpenAI') { - if (event.event === 'on_llm_stream') { - const chunk = event.data?.chunk; - const msg = chunk.message; - if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) { - // I don't think we hit this anymore because of our check for AGENT_NODE_TAG - // however, no harm to keep it in - /* empty */ - } else if (!didEnd) { - push({ payload: msg.content, type: 'content' }); - finalMessage += msg.content; - } - } else if (event.event === 'on_llm_end' && !didEnd) { - const generation = event.data.output?.generations[0][0]; - if ( - // no finish_reason means the stream was aborted - !generation?.generationInfo?.finish_reason || - generation?.generationInfo?.finish_reason === 'stop' - ) { - handleStreamEnd( - generation?.text && generation?.text.length ? generation?.text : finalMessage - ); - } - } + for await (const { event, data, tags } of stream) { + if ((tags || []).includes(AGENT_NODE_TAG)) { + if (event === 'on_llm_stream') { + const chunk = data?.chunk; + const msg = chunk.message; + if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) { + // I don't think we hit this anymore because of our check for AGENT_NODE_TAG + // however, no harm to keep it in + /* empty */ + } else if (!didEnd) { + push({ payload: msg.content, type: 'content' }); + finalMessage += msg.content; } } - void processEvent(); - } catch (err) { - // if I throw an error here, it crashes the server. Not sure how to get around that. - // If I put await on this function the error works properly, but when there is not an error - // it waits for the entire stream to complete before resolving - const error = transformError(err); - - if (error.message === 'AbortError') { - // user aborted the stream, we must end it manually here - return handleStreamEnd(finalMessage); - } - logger.error(`Error streaming from LangChain: ${error.message}`); - if (conversationId) { - push({ payload: `Conversation id: ${conversationId}`, type: 'content' }); + if (event === 'on_llm_end' && !didEnd) { + const generation = data.output?.generations[0][0]; + if ( + // if generation is null, an error occurred - do nothing and let error handling complete the stream + generation != null && + // no finish_reason means the stream was aborted + (!generation?.generationInfo?.finish_reason || + generation?.generationInfo?.finish_reason === 'stop') + ) { + handleStreamEnd( + generation?.text && generation?.text.length ? generation?.text : finalMessage + ); + } } - push({ payload: error.message, type: 'content' }); - handleStreamEnd(error.message, true); } - }; - - // Start processing events, do not await! Return `responseWithHeaders` immediately - void processEvent(); + } return responseWithHeaders; }; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/parse_stream.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/parse_stream.test.ts deleted file mode 100644 index 959bb51c40949..0000000000000 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/parse_stream.test.ts +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -import { Readable, Transform } from 'stream'; -import { loggerMock } from '@kbn/logging-mocks'; -import { handleStreamStorage } from './parse_stream'; -import { EventStreamCodec } from '@smithy/eventstream-codec'; -import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; - -function createStreamMock() { - const transform: Transform = new Transform({}); - - return { - write: (data: unknown) => { - transform.push(data); - }, - fail: () => { - transform.emit('error', new Error('Stream failed')); - transform.end(); - }, - transform, - complete: () => { - transform.end(); - }, - }; -} -const mockLogger = loggerMock.create(); -const onMessageSent = jest.fn(); -describe('handleStreamStorage', () => { - beforeEach(() => { - jest.resetAllMocks(); - }); - let stream: ReturnType; - - const chunk = { - object: 'chat.completion.chunk', - choices: [ - { - delta: { - content: 'Single.', - }, - }, - ], - }; - let defaultProps = { - responseStream: jest.fn() as unknown as Readable, - actionTypeId: '.gen-ai', - onMessageSent, - logger: mockLogger, - }; - - describe('OpenAI stream', () => { - beforeEach(() => { - stream = createStreamMock(); - stream.write(`data: ${JSON.stringify(chunk)}`); - defaultProps = { - responseStream: stream.transform, - actionTypeId: '.gen-ai', - onMessageSent, - logger: mockLogger, - }; - }); - - it('saves the final string successful streaming event', async () => { - stream.complete(); - await handleStreamStorage(defaultProps); - expect(onMessageSent).toHaveBeenCalledWith('Single.'); - }); - it('saves the error message on a failed streaming event', async () => { - const tokenPromise = handleStreamStorage(defaultProps); - - stream.fail(); - await expect(tokenPromise).resolves.not.toThrow(); - expect(onMessageSent).toHaveBeenCalledWith( - `An error occurred while streaming the response:\n\nStream failed` - ); - }); - }); - describe('Bedrock stream', () => { - beforeEach(() => { - stream = createStreamMock(); - stream.write(encodeBedrockResponse('Simple.')); - defaultProps = { - responseStream: stream.transform, - actionTypeId: '.gen-ai', - onMessageSent, - logger: mockLogger, - }; - }); - - it('saves the final string successful streaming event', async () => { - stream.complete(); - await handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' }); - expect(onMessageSent).toHaveBeenCalledWith('Simple.'); - }); - it('saves the error message on a failed streaming event', async () => { - const tokenPromise = handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' }); - - stream.fail(); - await expect(tokenPromise).resolves.not.toThrow(); - expect(onMessageSent).toHaveBeenCalledWith( - `An error occurred while streaming the response:\n\nStream failed` - ); - }); - }); - describe('Gemini stream', () => { - beforeEach(() => { - stream = createStreamMock(); - const payload = { - candidates: [ - { - content: { - parts: [ - { - text: 'Single.', - }, - ], - }, - }, - ], - }; - stream.write(`data: ${JSON.stringify(payload)}`); - defaultProps = { - responseStream: stream.transform, - actionTypeId: '.gemini', - onMessageSent, - logger: mockLogger, - }; - }); - - it('saves the final string successful streaming event', async () => { - stream.complete(); - await handleStreamStorage(defaultProps); - expect(onMessageSent).toHaveBeenCalledWith('Single.'); - }); - it('saves the error message on a failed streaming event', async () => { - const tokenPromise = handleStreamStorage(defaultProps); - - stream.fail(); - await expect(tokenPromise).resolves.not.toThrow(); - expect(onMessageSent).toHaveBeenCalledWith( - `An error occurred while streaming the response:\n\nStream failed` - ); - }); - }); -}); - -function encodeBedrockResponse(completion: string) { - return new EventStreamCodec(toUtf8, fromUtf8).encode({ - headers: {}, - body: Uint8Array.from( - Buffer.from( - JSON.stringify({ - bytes: Buffer.from( - JSON.stringify({ - type: 'content_block_delta', - index: 0, - delta: { type: 'text_delta', text: completion }, - }) - ).toString('base64'), - }) - ) - ), - }); -} diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/parse_stream.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/parse_stream.ts deleted file mode 100644 index 3aef870be8116..0000000000000 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/parse_stream.ts +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { Readable } from 'stream'; -import { Logger } from '@kbn/core/server'; -import { parseBedrockStream, parseGeminiResponse } from '@kbn/langchain/server'; - -type StreamParser = ( - responseStream: Readable, - logger: Logger, - abortSignal?: AbortSignal, - tokenHandler?: (token: string) => void -) => Promise; - -export const handleStreamStorage = async ({ - abortSignal, - responseStream, - actionTypeId, - onMessageSent, - logger, -}: { - abortSignal?: AbortSignal; - responseStream: Readable; - actionTypeId: string; - onMessageSent?: (content: string) => void; - logger: Logger; -}): Promise => { - try { - const parser = - actionTypeId === '.bedrock' - ? parseBedrockStream - : actionTypeId === '.gemini' - ? parseGeminiStream - : parseOpenAIStream; - const parsedResponse = await parser(responseStream, logger, abortSignal); - if (onMessageSent) { - onMessageSent(parsedResponse); - } - } catch (e) { - if (onMessageSent) { - onMessageSent(`An error occurred while streaming the response:\n\n${e.message}`); - } - } -}; - -const parseOpenAIStream: StreamParser = async (stream, logger, abortSignal) => { - let responseBody = ''; - stream.on('data', (chunk) => { - responseBody += chunk.toString(); - }); - return new Promise((resolve, reject) => { - stream.on('end', () => { - resolve(parseOpenAIResponse(responseBody)); - }); - stream.on('error', (err) => { - reject(err); - }); - if (abortSignal) { - abortSignal.addEventListener('abort', () => { - stream.destroy(); - resolve(parseOpenAIResponse(responseBody)); - }); - } - }); -}; - -const parseOpenAIResponse = (responseBody: string) => - responseBody - .split('\n') - .filter((line) => { - return line.startsWith('data: ') && !line.endsWith('[DONE]'); - }) - .map((line) => { - return JSON.parse(line.replace('data: ', '')); - }) - .filter( - ( - line - ): line is { - choices: Array<{ - delta: { content?: string; function_call?: { name?: string; arguments: string } }; - }>; - } => { - return ( - 'object' in line && line.object === 'chat.completion.chunk' && line.choices.length > 0 - ); - } - ) - .reduce((prev, line) => { - const msg = line.choices[0].delta; - return prev + (msg.content || ''); - }, ''); - -export const parseGeminiStream: StreamParser = async (stream, logger, abortSignal) => { - let responseBody = ''; - stream.on('data', (chunk) => { - responseBody += chunk.toString(); - }); - return new Promise((resolve, reject) => { - stream.on('end', () => { - resolve(parseGeminiResponse(responseBody)); - }); - stream.on('error', (err) => { - reject(err); - }); - if (abortSignal) { - abortSignal.addEventListener('abort', () => { - stream.destroy(); - logger.info('Gemini stream parsing was aborted.'); - resolve(parseGeminiResponse(responseBody)); - }); - } - }); -}; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts index 9f4d0beb3caff..0f0ee48c94cf5 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts @@ -30,19 +30,7 @@ const actionsClient = actionsClientMock.create(); jest.mock('../lib/build_response', () => ({ buildResponse: jest.fn().mockImplementation((x) => x), })); -jest.mock('../lib/executor', () => ({ - executeAction: jest.fn().mockImplementation(async ({ connectorId }) => { - if (connectorId === 'mock-connector-id') { - return { - connector_id: 'mock-connector-id', - data: mockActionResponse, - status: 'ok', - }; - } else { - throw new Error('simulated error'); - } - }), -})); + const mockStream = jest.fn().mockImplementation(() => new PassThrough()); const mockLangChainExecute = langChainExecute as jest.Mock; const mockAppendAssistantMessageToConversation = appendAssistantMessageToConversation as jest.Mock;