Skip to content

Commit

Permalink
[Security solution] Fix OpenAI token reporting (#169156)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored Oct 18, 2023
1 parent 7df3f96 commit 8284398
Show file tree
Hide file tree
Showing 17 changed files with 147 additions and 27 deletions.
17 changes: 10 additions & 7 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ describe('API tests', () => {
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}',
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":true}',
headers: { 'Content-Type': 'application/json' },
method: 'POST',
signal: undefined,
Expand All @@ -72,12 +72,15 @@ describe('API tests', () => {

await fetchConnectorExecuteAction(testProps);

expect(mockHttp.fetch).toHaveBeenCalledWith('/api/actions/connector/foo/_execute', {
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}',
headers: { 'Content-Type': 'application/json' },
method: 'POST',
signal: undefined,
});
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}',
headers: { 'Content-Type': 'application/json' },
method: 'POST',
signal: undefined,
}
);
});

it('returns API_ERROR when the response status is not ok', async () => {
Expand Down
7 changes: 2 additions & 5 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,16 @@ export const fetchConnectorExecuteAction = async ({
subActionParams: body,
subAction: 'invokeAI',
},
assistantLangChain,
};

try {
const path = assistantLangChain
? `/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute`
: `/api/actions/connector/${apiConfig?.connectorId}/_execute`;

const response = await http.fetch<{
connector_id: string;
status: string;
data: string;
service_message?: string;
}>(path, {
}>(`/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@
* 2.0.
*/

export const mockActionResponse = 'Yes, your name is Andrew. How can I assist you further, Andrew?';
export const mockActionResponse = {
message: 'Yes, your name is Andrew. How can I assist you further, Andrew?',
usage: { prompt_tokens: 4, completion_tokens: 10, total_tokens: 14 },
};
43 changes: 43 additions & 0 deletions x-pack/plugins/elastic_assistant/server/lib/executor.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { get } from 'lodash/fp';
import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { KibanaRequest } from '@kbn/core-http-server';
import { RequestBody } from './langchain/types';

interface Props {
actions: ActionsPluginStart;
connectorId: string;
request: KibanaRequest<unknown, unknown, RequestBody>;
}
interface StaticResponse {
connector_id: string;
data: string;
status: string;
}

export const executeAction = async ({
actions,
request,
connectorId,
}: Props): Promise<StaticResponse> => {
const actionsClient = await actions.getActionsClientWithRequest(request);
const actionResult = await actionsClient.execute({
actionId: connectorId,
params: request.body.params,
});
const content = get('data.message', actionResult);
if (typeof content === 'string') {
return {
connector_id: connectorId,
data: content, // the response from the actions framework
status: 'ok',
};
}
throw new Error('Unexpected action result');
};
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const mockRequest: KibanaRequest<unknown, unknown, RequestBody> = {
},
subAction: 'invokeAI',
},
assistantLangChain: true,
},
} as KibanaRequest<unknown, unknown, RequestBody>;

Expand All @@ -72,7 +73,7 @@ describe('ActionsClientLlm', () => {

await actionsClientLlm._call(prompt); // ignore the result

expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse);
expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse.message);
});
});

Expand Down Expand Up @@ -141,7 +142,7 @@ describe('ActionsClientLlm', () => {
});

it('rejects with the expected error the message has invalid content', async () => {
const invalidContent = 1234;
const invalidContent = { message: 1234 };

mockExecute.mockImplementation(() => ({
data: invalidContent,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ export class ActionsClientLlm extends LLM {
`${LLM_TYPE}: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
);
}

// TODO: handle errors from the connector
const content = get('data', actionResult);
const content = get('data.message', actionResult);

if (typeof content !== 'string') {
throw new Error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ export const postEvaluateRoute = (
messages: [],
},
},
assistantLangChain: true,
},
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ import { coreMock } from '@kbn/core/server/mocks';
jest.mock('../lib/build_response', () => ({
buildResponse: jest.fn().mockImplementation((x) => x),
}));
jest.mock('../lib/executor', () => ({
executeAction: jest.fn().mockImplementation((x) => ({
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
})),
}));

jest.mock('../lib/langchain/execute_custom_llm_chain', () => ({
callAgentExecutor: jest.fn().mockImplementation(
Expand Down Expand Up @@ -82,6 +89,7 @@ const mockRequest = {
},
subAction: 'invokeAI',
},
assistantLangChain: true,
},
};

Expand All @@ -97,7 +105,38 @@ describe('postActionsConnectorExecuteRoute', () => {
jest.clearAllMocks();
});

it('returns the expected response', async () => {
it('returns the expected response when assistantLangChain=false', async () => {
const mockRouter = {
post: jest.fn().mockImplementation(async (_, handler) => {
const result = await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
assistantLangChain: false,
},
},
mockResponse
);

expect(result).toEqual({
body: {
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
},
});
}),
};

await postActionsConnectorExecuteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});

it('returns the expected response when assistantLangChain=true', async () => {
const mockRouter = {
post: jest.fn().mockImplementation(async (_, handler) => {
const result = await handler(mockContext, mockRequest, mockResponse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import { IRouter, Logger } from '@kbn/core/server';
import { transformError } from '@kbn/securitysolution-es-utils';
import { executeAction } from '../lib/executor';
import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants';
import { getLangChainMessages } from '../lib/langchain/helpers';
import { buildResponse } from '../lib/build_response';
Expand Down Expand Up @@ -41,6 +42,14 @@ export const postActionsConnectorExecuteRoute = (
// get the actions plugin start contract from the request context:
const actions = (await context.elasticAssistant).actions;

// if not langchain, call execute action directly and return the response:
if (!request.body.assistantLangChain) {
const result = await executeAction({ actions, request, connectorId });
return response.ok({
body: result,
});
}

// get a scoped esClient for assistant memory
const esClient = (await context.core).elasticsearch.client.asCurrentUser;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export const PostActionsConnectorExecuteBody = t.type({
]),
subAction: t.string,
}),
assistantLangChain: t.boolean,
});

export type PostActionsConnectorExecuteBodyInputs = t.TypeOf<
Expand Down
4 changes: 3 additions & 1 deletion x-pack/plugins/stack_connectors/common/bedrock/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ export const InvokeAIActionParamsSchema = schema.object({
model: schema.maybe(schema.string()),
});

export const InvokeAIActionResponseSchema = schema.string();
export const InvokeAIActionResponseSchema = schema.object({
message: schema.string(),
});

export const RunActionResponseSchema = schema.object(
{
Expand Down
12 changes: 11 additions & 1 deletion x-pack/plugins/stack_connectors/common/openai/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,17 @@ export const InvokeAIActionParamsSchema = schema.object({
temperature: schema.maybe(schema.number()),
});

export const InvokeAIActionResponseSchema = schema.string();
export const InvokeAIActionResponseSchema = schema.object({
message: schema.string(),
usage: schema.object(
{
prompt_tokens: schema.number(),
completion_tokens: schema.number(),
total_tokens: schema.number(),
},
{ unknowns: 'ignore' }
),
});

// Execute action schema
export const StreamActionParamsSchema = schema.object({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ describe('BedrockConnector', () => {
stop_sequences: ['\n\nHuman:'],
}),
});
expect(response).toEqual(mockResponseString);
expect(response.message).toEqual(mockResponseString);
});

it('Properly formats messages from user, assistant, and system', async () => {
Expand Down Expand Up @@ -148,7 +148,7 @@ describe('BedrockConnector', () => {
stop_sequences: ['\n\nHuman:'],
}),
});
expect(response).toEqual(mockResponseString);
expect(response.message).toEqual(mockResponseString);
});

it('errors during API calls are properly handled', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,6 @@ export class BedrockConnector extends SubActionConnector<Config, Secrets> {
};

const res = await this.runApi({ body: JSON.stringify(req), model });
return res.completion.trim();
return { message: res.completion.trim() };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ describe('OpenAIConnector', () => {
index: 0,
},
],
usage: {
prompt_tokens: 4,
completion_tokens: 5,
total_tokens: 9,
},
},
};
beforeEach(() => {
Expand Down Expand Up @@ -273,7 +278,8 @@ describe('OpenAIConnector', () => {
'content-type': 'application/json',
},
});
expect(response).toEqual(mockResponseString);
expect(response.message).toEqual(mockResponseString);
expect(response.usage.total_tokens).toEqual(9);
});

it('errors during API calls are properly handled', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,15 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {

if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
const result = res.choices[0].message.content.trim();
return result;
return { message: result, usage: res.usage };
}

return 'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.';
return {
message:
'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.',
...(res.usage
? { usage: res.usage }
: { usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 } }),
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ export default function bedrockTest({ getService }: FtrProviderContext) {
expect(body).to.eql({
status: 'ok',
connector_id: bedrockActionId,
data: bedrockSuccessResponse.completion,
data: { message: bedrockSuccessResponse.completion },
});
});
});
Expand Down

0 comments on commit 8284398

Please sign in to comment.