Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security solution] Fix OpenAI token reporting #169156

Merged
merged 5 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ describe('API tests', () => {
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}',
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":true}',
headers: { 'Content-Type': 'application/json' },
method: 'POST',
signal: undefined,
Expand All @@ -72,12 +72,15 @@ describe('API tests', () => {

await fetchConnectorExecuteAction(testProps);

expect(mockHttp.fetch).toHaveBeenCalledWith('/api/actions/connector/foo/_execute', {
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}',
headers: { 'Content-Type': 'application/json' },
method: 'POST',
signal: undefined,
});
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}',
headers: { 'Content-Type': 'application/json' },
method: 'POST',
signal: undefined,
}
);
});

it('returns API_ERROR when the response status is not ok', async () => {
Expand Down
7 changes: 2 additions & 5 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,16 @@ export const fetchConnectorExecuteAction = async ({
subActionParams: body,
subAction: 'invokeAI',
},
assistantLangChain,
};

try {
const path = assistantLangChain
? `/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute`
: `/api/actions/connector/${apiConfig?.connectorId}/_execute`;

const response = await http.fetch<{
connector_id: string;
status: string;
data: string;
service_message?: string;
}>(path, {
}>(`/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@
* 2.0.
*/

export const mockActionResponse = 'Yes, your name is Andrew. How can I assist you further, Andrew?';
// Mock of a successful OpenAI `invokeAI` sub-action response: the connector now
// returns an object with the assistant `message` plus the OpenAI token `usage`
// block (prompt/completion/total), rather than a bare string.
export const mockActionResponse = {
  message: 'Yes, your name is Andrew. How can I assist you further, Andrew?',
  usage: { prompt_tokens: 4, completion_tokens: 10, total_tokens: 14 },
};
43 changes: 43 additions & 0 deletions x-pack/plugins/elastic_assistant/server/lib/executor.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { get } from 'lodash/fp';
import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { KibanaRequest } from '@kbn/core-http-server';
import { RequestBody } from './langchain/types';

/** Arguments for {@link executeAction}. */
interface Props {
  /** Actions plugin start contract, used to obtain a request-scoped actions client. */
  actions: ActionsPluginStart;
  /** ID of the connector to execute. */
  connectorId: string;
  /** Incoming Kibana request whose body carries the sub-action params. */
  request: KibanaRequest<unknown, unknown, RequestBody>;
}
/** Response shape returned to the client for the non-LangChain (direct) path. */
interface StaticResponse {
  connector_id: string;
  data: string;
  status: string;
}

/**
 * Executes the connector action directly via the actions framework (the
 * non-LangChain path) and returns the assistant message as a string.
 *
 * The `invokeAI` sub-action now responds with `{ message, usage }` rather than
 * a bare string, so the message is read from `data.message`.
 *
 * @throws Error when the action framework reports an error status, or when the
 *   result does not contain a string `data.message`.
 */
export const executeAction = async ({
  actions,
  request,
  connectorId,
}: Props): Promise<StaticResponse> => {
  const actionsClient = await actions.getActionsClientWithRequest(request);
  const actionResult = await actionsClient.execute({
    actionId: connectorId,
    params: request.body.params,
  });

  // Surface connector failures with their real message/serviceMessage instead
  // of masking them behind the generic "Unexpected action result" error below.
  if (actionResult.status === 'error') {
    throw new Error(
      `Action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
    );
  }

  const content = get('data.message', actionResult);
  if (typeof content === 'string') {
    return {
      connector_id: connectorId,
      data: content, // the response from the actions framework
      status: 'ok',
    };
  }
  throw new Error('Unexpected action result');
};
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const mockRequest: KibanaRequest<unknown, unknown, RequestBody> = {
},
subAction: 'invokeAI',
},
assistantLangChain: true,
},
} as KibanaRequest<unknown, unknown, RequestBody>;

Expand All @@ -72,7 +73,7 @@ describe('ActionsClientLlm', () => {

await actionsClientLlm._call(prompt); // ignore the result

expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse);
expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse.message);
});
});

Expand Down Expand Up @@ -141,7 +142,7 @@ describe('ActionsClientLlm', () => {
});

it('rejects with the expected error the message has invalid content', async () => {
const invalidContent = 1234;
const invalidContent = { message: 1234 };

mockExecute.mockImplementation(() => ({
data: invalidContent,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ export class ActionsClientLlm extends LLM {
`${LLM_TYPE}: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
);
}

// TODO: handle errors from the connector
const content = get('data', actionResult);
const content = get('data.message', actionResult);

if (typeof content !== 'string') {
throw new Error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ export const postEvaluateRoute = (
messages: [],
},
},
assistantLangChain: true,
},
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ import { coreMock } from '@kbn/core/server/mocks';
jest.mock('../lib/build_response', () => ({
buildResponse: jest.fn().mockImplementation((x) => x),
}));
jest.mock('../lib/executor', () => ({
executeAction: jest.fn().mockImplementation((x) => ({
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
})),
}));

jest.mock('../lib/langchain/execute_custom_llm_chain', () => ({
callAgentExecutor: jest.fn().mockImplementation(
Expand Down Expand Up @@ -82,6 +89,7 @@ const mockRequest = {
},
subAction: 'invokeAI',
},
assistantLangChain: true,
},
};

Expand All @@ -97,7 +105,38 @@ describe('postActionsConnectorExecuteRoute', () => {
jest.clearAllMocks();
});

it('returns the expected response', async () => {
it('returns the expected response when assistantLangChain=false', async () => {
const mockRouter = {
post: jest.fn().mockImplementation(async (_, handler) => {
const result = await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
assistantLangChain: false,
},
},
mockResponse
);

expect(result).toEqual({
body: {
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
},
});
}),
};

await postActionsConnectorExecuteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});

it('returns the expected response when assistantLangChain=true', async () => {
const mockRouter = {
post: jest.fn().mockImplementation(async (_, handler) => {
const result = await handler(mockContext, mockRequest, mockResponse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import { IRouter, Logger } from '@kbn/core/server';
import { transformError } from '@kbn/securitysolution-es-utils';
import { executeAction } from '../lib/executor';
import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants';
import { getLangChainMessages } from '../lib/langchain/helpers';
import { buildResponse } from '../lib/build_response';
Expand Down Expand Up @@ -41,6 +42,14 @@ export const postActionsConnectorExecuteRoute = (
// get the actions plugin start contract from the request context:
const actions = (await context.elasticAssistant).actions;

// if not langchain, call execute action directly and return the response:
if (!request.body.assistantLangChain) {
const result = await executeAction({ actions, request, connectorId });
return response.ok({
body: result,
});
}

Comment on lines +45 to +52
Copy link
Member

@spong spong Oct 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did assistantLangChain need to be pushed server side for this fix? @andrew-goldstein intentionally kept the split in logic between calling the actions framework, and the new langchain implementation at the client API layer for separation of concerns. There's no intent to continue supporting an actions 'pass through' once this new API is validated (at least AFAIK), so unless this is somehow required for the fix I'd be hesitant to push this logic into this route.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this is going to be necessary for the upcoming streaming work, so all good! 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, decided to move it new because I wanted to parse the response to continue returning a string from this endpoint, and I knew I'd be moving it for streaming anyways.

// get a scoped esClient for assistant memory
const esClient = (await context.core).elasticsearch.client.asCurrentUser;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export const PostActionsConnectorExecuteBody = t.type({
]),
subAction: t.string,
}),
assistantLangChain: t.boolean,
});

export type PostActionsConnectorExecuteBodyInputs = t.TypeOf<
Expand Down
4 changes: 3 additions & 1 deletion x-pack/plugins/stack_connectors/common/bedrock/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ export const InvokeAIActionParamsSchema = schema.object({
model: schema.maybe(schema.string()),
});

export const InvokeAIActionResponseSchema = schema.string();
// Bedrock `invokeAI` responses are now an object carrying the completion text
// in `message`, aligning with the OpenAI connector's response shape.
export const InvokeAIActionResponseSchema = schema.object({
  message: schema.string(),
});

export const RunActionResponseSchema = schema.object(
{
Expand Down
12 changes: 11 additions & 1 deletion x-pack/plugins/stack_connectors/common/openai/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,17 @@ export const InvokeAIActionParamsSchema = schema.object({
temperature: schema.maybe(schema.number()),
});

export const InvokeAIActionResponseSchema = schema.string();
// OpenAI `invokeAI` responses carry the assistant text in `message` plus the
// token accounting reported by OpenAI. `unknowns: 'ignore'` tolerates any
// extra fields OpenAI may add to the usage object.
export const InvokeAIActionResponseSchema = schema.object({
  message: schema.string(),
  usage: schema.object(
    {
      prompt_tokens: schema.number(),
      completion_tokens: schema.number(),
      total_tokens: schema.number(),
    },
    { unknowns: 'ignore' }
  ),
});

// Execute action schema
export const StreamActionParamsSchema = schema.object({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ describe('BedrockConnector', () => {
stop_sequences: ['\n\nHuman:'],
}),
});
expect(response).toEqual(mockResponseString);
expect(response.message).toEqual(mockResponseString);
});

it('Properly formats messages from user, assistant, and system', async () => {
Expand Down Expand Up @@ -148,7 +148,7 @@ describe('BedrockConnector', () => {
stop_sequences: ['\n\nHuman:'],
}),
});
expect(response).toEqual(mockResponseString);
expect(response.message).toEqual(mockResponseString);
});

it('errors during API calls are properly handled', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,6 @@ export class BedrockConnector extends SubActionConnector<Config, Secrets> {
};

const res = await this.runApi({ body: JSON.stringify(req), model });
return res.completion.trim();
return { message: res.completion.trim() };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ describe('OpenAIConnector', () => {
index: 0,
},
],
usage: {
prompt_tokens: 4,
completion_tokens: 5,
total_tokens: 9,
},
},
};
beforeEach(() => {
Expand Down Expand Up @@ -273,7 +278,8 @@ describe('OpenAIConnector', () => {
'content-type': 'application/json',
},
});
expect(response).toEqual(mockResponseString);
expect(response.message).toEqual(mockResponseString);
expect(response.usage.total_tokens).toEqual(9);
});

it('errors during API calls are properly handled', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,15 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {

if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
const result = res.choices[0].message.content.trim();
return result;
return { message: result, usage: res.usage };
}

return 'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.';
return {
message:
'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.',
...(res.usage
? { usage: res.usage }
: { usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 } }),
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ export default function bedrockTest({ getService }: FtrProviderContext) {
expect(body).to.eql({
status: 'ok',
connector_id: bedrockActionId,
data: bedrockSuccessResponse.completion,
data: { message: bedrockSuccessResponse.completion },
});
});
});
Expand Down
Loading