From 8284398023648e850a8ca038bce8be6cc85cc51f Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Wed, 18 Oct 2023 12:06:03 -0600 Subject: [PATCH] [Security solution] Fix OpenAI token reporting (#169156) --- .../impl/assistant/api.test.tsx | 17 +++++--- .../impl/assistant/api.tsx | 7 +-- .../server/__mocks__/action_result_data.ts | 5 ++- .../elastic_assistant/server/lib/executor.ts | 43 +++++++++++++++++++ .../langchain/llm/actions_client_llm.test.ts | 5 ++- .../lib/langchain/llm/actions_client_llm.ts | 3 +- .../server/routes/evaluate/post_evaluate.ts | 1 + .../post_actions_connector_execute.test.ts | 41 +++++++++++++++++- .../routes/post_actions_connector_execute.ts | 9 ++++ .../schemas/post_actions_connector_execute.ts | 1 + .../stack_connectors/common/bedrock/schema.ts | 4 +- .../stack_connectors/common/openai/schema.ts | 12 +++++- .../connector_types/bedrock/bedrock.test.ts | 4 +- .../server/connector_types/bedrock/bedrock.ts | 2 +- .../connector_types/openai/openai.test.ts | 8 +++- .../server/connector_types/openai/openai.ts | 10 ++++- .../tests/actions/connector_types/bedrock.ts | 2 +- 17 files changed, 147 insertions(+), 27 deletions(-) create mode 100644 x-pack/plugins/elastic_assistant/server/lib/executor.ts diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx index 33dc820f449fa..e8feefbfd2533 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx @@ -54,7 +54,7 @@ describe('API tests', () => { expect(mockHttp.fetch).toHaveBeenCalledWith( '/internal/elastic_assistant/actions/connector/foo/_execute', { - body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}', + body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":true}', headers: { 'Content-Type': 'application/json' }, method: 'POST', signal: undefined, @@ -72,12 +72,15 @@ describe('API tests', () => { await fetchConnectorExecuteAction(testProps); - expect(mockHttp.fetch).toHaveBeenCalledWith('/api/actions/connector/foo/_execute', { - body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}', - headers: { 'Content-Type': 'application/json' }, - method: 'POST', - signal: undefined, - }); + expect(mockHttp.fetch).toHaveBeenCalledWith( + '/internal/elastic_assistant/actions/connector/foo/_execute', + { + body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}', + headers: { 'Content-Type': 'application/json' }, + method: 'POST', + signal: undefined, + } + ); }); it('returns API_ERROR when the response status is not ok', async () => { diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx index c7c1254656d61..8ccb2e72cfee9 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx @@ -59,19 +59,16 @@ export const fetchConnectorExecuteAction = async ({ subActionParams: body, subAction: 'invokeAI', }, + assistantLangChain, }; try { - const path = assistantLangChain - ? `/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute` - : `/api/actions/connector/${apiConfig?.connectorId}/_execute`; - const response = await http.fetch<{ connector_id: string; status: string; data: string; service_message?: string; - }>(path, { + }>(`/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/x-pack/plugins/elastic_assistant/server/__mocks__/action_result_data.ts b/x-pack/plugins/elastic_assistant/server/__mocks__/action_result_data.ts index 17aa4b83ca67b..dbc095a334cea 100644 --- a/x-pack/plugins/elastic_assistant/server/__mocks__/action_result_data.ts +++ b/x-pack/plugins/elastic_assistant/server/__mocks__/action_result_data.ts @@ -5,4 +5,7 @@ * 2.0. */ -export const mockActionResponse = 'Yes, your name is Andrew. How can I assist you further, Andrew?'; +export const mockActionResponse = { + message: 'Yes, your name is Andrew. How can I assist you further, Andrew?', + usage: { prompt_tokens: 4, completion_tokens: 10, total_tokens: 14 }, +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.ts new file mode 100644 index 0000000000000..936e3781731d8 --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.ts @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { get } from 'lodash/fp'; +import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { KibanaRequest } from '@kbn/core-http-server'; +import { RequestBody } from './langchain/types'; + +interface Props { + actions: ActionsPluginStart; + connectorId: string; + request: KibanaRequest; +} +interface StaticResponse { + connector_id: string; + data: string; + status: string; +} + +export const executeAction = async ({ + actions, + request, + connectorId, +}: Props): Promise => { + const actionsClient = await actions.getActionsClientWithRequest(request); + const actionResult = await actionsClient.execute({ + actionId: connectorId, + params: request.body.params, + }); + const content = get('data.message', actionResult); + if (typeof content === 'string') { + return { + connector_id: connectorId, + data: content, // the response from the actions framework + status: 'ok', + }; + } + throw new Error('Unexpected action result'); +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/llm/actions_client_llm.test.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/llm/actions_client_llm.test.ts index b5f8fa7e88c74..5c27cdef4d3e1 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/llm/actions_client_llm.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/llm/actions_client_llm.test.ts @@ -51,6 +51,7 @@ const mockRequest: KibanaRequest = { }, subAction: 'invokeAI', }, + assistantLangChain: true, }, } as KibanaRequest; @@ -72,7 +73,7 @@ describe('ActionsClientLlm', () => { await actionsClientLlm._call(prompt); // ignore the result - expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse); + expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse.message); }); }); @@ -141,7 +142,7 @@ describe('ActionsClientLlm', () => { }); it('rejects with the expected error the message has invalid content', async () => { - const invalidContent = 1234; + const invalidContent = { message: 1234 }; mockExecute.mockImplementation(() => ({ data: invalidContent, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/llm/actions_client_llm.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/llm/actions_client_llm.ts index e4403b64d6e0d..f499452e1d764 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/llm/actions_client_llm.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/llm/actions_client_llm.ts @@ -92,9 +92,8 @@ export class ActionsClientLlm extends LLM { `${LLM_TYPE}: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}` ); } - // TODO: handle errors from the connector - const content = get('data', actionResult); + const content = get('data.message', actionResult); if (typeof content !== 'string') { throw new Error( diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index b65822524f1cd..1b533e49c4cfe 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -105,6 +105,7 @@ export const postEvaluateRoute = ( messages: [], }, }, + assistantLangChain: true, }, }; diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts index fa0afb540dc30..507246670833c 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts @@ -19,6 +19,13 @@ import { coreMock } from '@kbn/core/server/mocks'; jest.mock('../lib/build_response', () => ({ buildResponse: jest.fn().mockImplementation((x) => x), })); +jest.mock('../lib/executor', () => ({ + executeAction: jest.fn().mockImplementation((x) => ({ + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', + })), +})); jest.mock('../lib/langchain/execute_custom_llm_chain', () => ({ callAgentExecutor: jest.fn().mockImplementation( @@ -82,6 +89,7 @@ const mockRequest = { }, subAction: 'invokeAI', }, + assistantLangChain: true, }, }; @@ -97,7 +105,38 @@ describe('postActionsConnectorExecuteRoute', () => { jest.clearAllMocks(); }); - it('returns the expected response', async () => { + it('returns the expected response when assistantLangChain=false', async () => { + const mockRouter = { + post: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + assistantLangChain: false, + }, + }, + mockResponse + ); + + expect(result).toEqual({ + body: { + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', + }, + }); + }), + }; + + await postActionsConnectorExecuteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected response when assistantLangChain=true', async () => { const mockRouter = { post: jest.fn().mockImplementation(async (_, handler) => { const result = await handler(mockContext, mockRequest, mockResponse); diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 5303796d1c983..8da820288ae1b 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -7,6 +7,7 @@ import { IRouter, Logger } from '@kbn/core/server'; import { transformError } from '@kbn/securitysolution-es-utils'; +import { executeAction } from '../lib/executor'; import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants'; import { getLangChainMessages } from '../lib/langchain/helpers'; import { buildResponse } from '../lib/build_response'; @@ -41,6 +42,14 @@ export const postActionsConnectorExecuteRoute = ( // get the actions plugin start contract from the request context: const actions = (await context.elasticAssistant).actions; + // if not langchain, call execute action directly and return the response: + if (!request.body.assistantLangChain) { + const result = await executeAction({ actions, request, connectorId }); + return response.ok({ + body: result, + }); + } + // get a scoped esClient for assistant memory const esClient = (await context.core).elasticsearch.client.asCurrentUser; diff --git a/x-pack/plugins/elastic_assistant/server/schemas/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/schemas/post_actions_connector_execute.ts index b30ccd94e105b..7a8d52e725722 100644 --- a/x-pack/plugins/elastic_assistant/server/schemas/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/schemas/post_actions_connector_execute.ts @@ -34,6 +34,7 @@ export const PostActionsConnectorExecuteBody = t.type({ ]), subAction: t.string, }), + assistantLangChain: t.boolean, }); export type PostActionsConnectorExecuteBodyInputs = t.TypeOf< diff --git a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts index ac23ed9667ada..64699253c709f 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts @@ -34,7 +34,9 @@ export const InvokeAIActionParamsSchema = schema.object({ model: schema.maybe(schema.string()), }); -export const InvokeAIActionResponseSchema = schema.string(); +export const InvokeAIActionResponseSchema = schema.object({ + message: schema.string(), +}); export const RunActionResponseSchema = schema.object( { diff --git a/x-pack/plugins/stack_connectors/common/openai/schema.ts b/x-pack/plugins/stack_connectors/common/openai/schema.ts index fa14aa61fa5b3..fd0b872ab9f36 100644 --- a/x-pack/plugins/stack_connectors/common/openai/schema.ts +++ b/x-pack/plugins/stack_connectors/common/openai/schema.ts @@ -44,7 +44,17 @@ export const InvokeAIActionParamsSchema = schema.object({ temperature: schema.maybe(schema.number()), }); -export const InvokeAIActionResponseSchema = schema.string(); +export const InvokeAIActionResponseSchema = schema.object({ + message: schema.string(), + usage: schema.object( + { + prompt_tokens: schema.number(), + completion_tokens: schema.number(), + total_tokens: schema.number(), + }, + { unknowns: 'ignore' } + ), +}); // Execute action schema export const StreamActionParamsSchema = schema.object({ diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts index 7ee8fd54833c7..dcd3d70f9b4ff 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts @@ -109,7 +109,7 @@ describe('BedrockConnector', () => { stop_sequences: ['\n\nHuman:'], }), }); - expect(response).toEqual(mockResponseString); + expect(response.message).toEqual(mockResponseString); }); it('Properly formats messages from user, assistant, and system', async () => { @@ -148,7 +148,7 @@ describe('BedrockConnector', () => { stop_sequences: ['\n\nHuman:'], }), }); - expect(response).toEqual(mockResponseString); + expect(response.message).toEqual(mockResponseString); }); it('errors during API calls are properly handled', async () => { diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index 6510731f8ad7e..0e1235312a52c 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -150,6 +150,6 @@ export class BedrockConnector extends SubActionConnector { }; const res = await this.runApi({ body: JSON.stringify(req), model }); - return res.completion.trim(); + return { message: res.completion.trim() }; } } diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts index 00f3b67aafb97..0a4a6a2931d8d 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts @@ -37,6 +37,11 @@ describe('OpenAIConnector', () => { index: 0, }, ], + usage: { + prompt_tokens: 4, + completion_tokens: 5, + total_tokens: 9, + }, }, }; beforeEach(() => { @@ -273,7 +278,8 @@ describe('OpenAIConnector', () => { 'content-type': 'application/json', }, }); - expect(response).toEqual(mockResponseString); + expect(response.message).toEqual(mockResponseString); + expect(response.usage.total_tokens).toEqual(9); }); it('errors during API calls are properly handled', async () => { diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts index 21c7bc4abdcc0..7413ba56090a1 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts @@ -192,9 +192,15 @@ export class OpenAIConnector extends SubActionConnector { if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) { const result = res.choices[0].message.content.trim(); - return result; + return { message: result, usage: res.usage }; } - return 'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.'; + return { + message: + 'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.', + ...(res.usage + ? { usage: res.usage } + : { usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 } }), + }; } } diff --git a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts index 4983d19d36b69..67053bef7801b 100644 --- a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts +++ b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts @@ -404,7 +404,7 @@ export default function bedrockTest({ getService }: FtrProviderContext) { expect(body).to.eql({ status: 'ok', connector_id: bedrockActionId, - data: bedrockSuccessResponse.completion, + data: { message: bedrockSuccessResponse.completion }, }); }); });