[Security AI] Bedrock prompt tuning and inference corrections (elastic#209011)

(cherry picked from commit 0d415a6)
stephmilovic committed Feb 3, 2025
1 parent 58c53c6 commit 2975769
Showing 12 changed files with 86 additions and 42 deletions.
@@ -50,7 +50,7 @@ const AS_PLAIN_TEXT: EuiComboBoxSingleSelectionShape = { asPlainText: true };
*/
export const EvaluationSettings: React.FC = React.memo(() => {
const { actionTypeRegistry, http, setTraceOptions, toasts, traceOptions } = useAssistantContext();
- const { data: connectors } = useLoadConnectors({ http });
+ const { data: connectors } = useLoadConnectors({ http, inferenceEnabled: true });
const { mutate: performEvaluation, isLoading: isPerformingEvaluation } = usePerformEvaluation({
http,
toasts,
@@ -6,7 +6,7 @@
*/

export { promptType } from './src/saved_object_mappings';
- export { getPrompt, getPromptsByGroupId } from './src/get_prompt';
+ export { getPrompt, getPromptsByGroupId, resolveProviderAndModel } from './src/get_prompt';
export {
type PromptArray,
type Prompt,
@@ -129,15 +129,15 @@ export const getPrompt = async ({
return prompt;
};

- const resolveProviderAndModel = async ({
+ export const resolveProviderAndModel = async ({
providedProvider,
providedModel,
connectorId,
actionsClient,
providedConnector,
}: {
- providedProvider: string | undefined;
- providedModel: string | undefined;
+ providedProvider?: string;
+ providedModel?: string;
connectorId: string;
actionsClient: PublicMethodsOf<ActionsClient>;
providedConnector?: Connector;
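With resolveProviderAndModel now exported and its provider/model arguments optional, callers outside get_prompt can look up which underlying provider backs a connector. A minimal usage sketch under those assumptions; the connector id is a placeholder and actionsClient is assumed to be in scope from the request context:

import { resolveProviderAndModel } from '@kbn/security-ai-prompts';

// Resolve the underlying LLM provider for a connector when no explicit
// provider or model is supplied, e.g. an Elastic Inference connector
// whose backing model is Bedrock.
const { provider } = await resolveProviderAndModel({
  connectorId: 'my-connector-id', // placeholder
  actionsClient, // PublicMethodsOf<ActionsClient>, assumed in scope
});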
@@ -127,6 +127,10 @@ export const getDefaultAssistantGraph = ({
value: (x: boolean, y?: boolean) => y ?? x,
default: () => contentReferencesEnabled,
},
+ provider: {
+   value: (x: string, y?: string) => y ?? x,
+   default: () => '',
+ },
};

// Default node parameters
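The new provider channel follows the same reducer pattern as the graph's existing channels: an incoming update wins only when it is defined, otherwise the current value is kept. A standalone sketch of that behavior:

// Reducer semantics: (current, update) => update ?? current.
const providerChannel = {
  value: (x: string, y?: string): string => y ?? x,
  default: (): string => '',
};

providerChannel.value('bedrock', undefined); // 'bedrock' (no update, keep current)
providerChannel.value('', 'gemini'); // 'gemini' (update provided, take it)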
@@ -56,6 +56,7 @@ describe('streamGraph', () => {
input: 'input',
responseLanguage: 'English',
llmType: 'openai',
+ provider: 'openai',
connectorId: '123',
},
logger: mockLogger,
@@ -291,6 +292,7 @@
inputs: {
...requestArgs.inputs,
llmType: 'gemini',
+ provider: 'gemini',
},
});

@@ -306,6 +308,7 @@
inputs: {
...requestArgs.inputs,
llmType: 'bedrock',
+ provider: 'bedrock',
},
});

@@ -136,17 +136,21 @@ export const streamGraph = async ({

// Stream is from openai functions agent
let finalMessage = '';
- const stream = assistantGraph.streamEvents(inputs, {
-   callbacks: [
-     apmTracer,
-     ...(traceOptions?.tracers ?? []),
-     ...(telemetryTracer ? [telemetryTracer] : []),
-   ],
-   runName: DEFAULT_ASSISTANT_GRAPH_ID,
-   streamMode: 'values',
-   tags: traceOptions?.tags ?? [],
-   version: 'v1',
- });
+ const stream = assistantGraph.streamEvents(
+   inputs,
+   {
+     callbacks: [
+       apmTracer,
+       ...(traceOptions?.tracers ?? []),
+       ...(telemetryTracer ? [telemetryTracer] : []),
+     ],
+     runName: DEFAULT_ASSISTANT_GRAPH_ID,
+     streamMode: 'values',
+     tags: traceOptions?.tags ?? [],
+     version: 'v1',
+   },
+   inputs?.provider === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
+ );

const pushStreamUpdate = async () => {
for await (const { event, data, tags } of stream) {
@@ -155,8 +159,6 @@
const chunk = data?.chunk;
const msg = chunk.message;
if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
- // I don't think we hit this anymore because of our check for AGENT_NODE_TAG
- // however, no harm to keep it in
/* empty */
} else if (!didEnd) {
push({ payload: msg.content, type: 'content' });
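The new third argument narrows the event stream for Bedrock to the graph node named 'Summarizer', so intermediate agent and tool events never reach the push loop above. A condensed sketch of the call shape, assuming LangChain's streamEvents(input, config, options) signature with its standard includeNames filter:

// Only Bedrock narrows the stream; other providers emit all events.
const config = {
  runName: DEFAULT_ASSISTANT_GRAPH_ID,
  streamMode: 'values' as const,
  version: 'v1' as const,
};
const streamOptions =
  inputs?.provider === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined;
const stream = assistantGraph.streamEvents(inputs, config, streamOptions);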
@@ -20,12 +20,15 @@ import {
} from 'langchain/agents';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { savedObjectsClientMock } from '@kbn/core-saved-objects-api-server-mocks';
+ import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
jest.mock('./graph');
jest.mock('./helpers');
jest.mock('langchain/agents');
jest.mock('@kbn/langchain/server/tracers/apm');
jest.mock('@kbn/langchain/server/tracers/telemetry');
+ jest.mock('@kbn/security-ai-prompts');
const getDefaultAssistantGraphMock = getDefaultAssistantGraph as jest.Mock;
+ const resolveProviderAndModelMock = resolveProviderAndModel as jest.Mock;
describe('callAssistantGraph', () => {
const mockDataClients = {
anonymizationFieldsDataClient: {
@@ -83,6 +86,9 @@
jest.clearAllMocks();
(mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockResolvedValue(true);
getDefaultAssistantGraphMock.mockReturnValue({});
+ resolveProviderAndModelMock.mockResolvedValue({
+   provider: 'bedrock',
+ });
(invokeGraph as jest.Mock).mockResolvedValue({
output: 'test-output',
traceData: {},
@@ -224,5 +230,23 @@
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
expect(createToolCallingAgent).not.toHaveBeenCalled();
});
+ it('does not call resolveProviderAndModel when llmType === openai', async () => {
+   const params = { ...defaultParams, llmType: 'openai' };
+   await callAssistantGraph(params);
+
+   expect(resolveProviderAndModelMock).not.toHaveBeenCalled();
+ });
+ it('calls resolveProviderAndModel when llmType === inference', async () => {
+   const params = { ...defaultParams, llmType: 'inference' };
+   await callAssistantGraph(params);
+
+   expect(resolveProviderAndModelMock).toHaveBeenCalled();
+ });
+ it('calls resolveProviderAndModel when llmType === undefined', async () => {
+   const params = { ...defaultParams, llmType: undefined };
+   await callAssistantGraph(params);
+
+   expect(resolveProviderAndModelMock).toHaveBeenCalled();
+ });
});
});
@@ -15,6 +15,7 @@ import {
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
import { pruneContentReferences, MessageMetadata } from '@kbn/elastic-assistant-common';
+ import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
import { promptGroupId } from '../../../prompt/local_prompt_object';
import { getModelOrOss } from '../../../prompt/helpers';
import { getPrompt, promptDictionary } from '../../../prompt';
@@ -183,6 +184,13 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
logger
)
: undefined;
+ const { provider } =
+   !llmType || llmType === 'inference'
+     ? await resolveProviderAndModel({
+         connectorId,
+         actionsClient,
+       })
+     : { provider: llmType };
const assistantGraph = getDefaultAssistantGraph({
agentRunnable,
dataClients,
@@ -205,6 +213,7 @@
isStream,
isOssModel,
input: latestMessage[0]?.content as string,
+ provider: provider ?? '',
};

if (isStream) {
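The motivation for the lookup: an Elastic Inference connector reports llmType 'inference' even when the model behind it is Bedrock, so Bedrock-specific behavior downstream was previously skipped for those connectors. A hypothetical helper distilling the rule, assuming connectorId and actionsClient are in scope:

// Hypothetical distillation of the resolution step above:
//   'openai' | 'bedrock' | 'gemini' -> used directly, no lookup
//   'inference' or undefined        -> resolved from the connector itself
const resolveProvider = async (llmType?: string): Promise<string | undefined> =>
  !llmType || llmType === 'inference'
    ? (await resolveProviderAndModel({ connectorId, actionsClient })).provider
    : llmType;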
@@ -21,8 +21,7 @@ interface ModelInputParams extends NodeParamsBase {
*/
export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);

- const hasRespondStep = state.isStream && (state.isOssModel || state.llmType === 'bedrock');
+ const hasRespondStep = state.isStream && (state.isOssModel || state.provider === 'bedrock');

return {
hasRespondStep,
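This one-line change is the inference correction in action: when llmType is 'inference' but the resolved provider is 'bedrock', the old check skipped the respond step, while the new one enables it. A self-contained illustration (state values hypothetical):

const state = { isStream: true, isOssModel: false, llmType: 'inference', provider: 'bedrock' };

// Old check: misses Bedrock behind an inference connector.
const before = state.isStream && (state.isOssModel || state.llmType === 'bedrock'); // false
// New check: keys off the resolved provider.
const after = state.isStream && (state.isOssModel || state.provider === 'bedrock'); // true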
@@ -25,6 +25,7 @@ export interface GraphInputs {
isStream?: boolean;
isOssModel?: boolean;
input: string;
+ provider: string;
responseLanguage?: string;
}

@@ -37,6 +38,7 @@ export interface AgentState extends AgentStateBase {
isStream: boolean;
isOssModel: boolean;
llmType: string;
+ provider: string;
responseLanguage: string;
connectorId: string;
conversation: ConversationResponse | undefined;
@@ -20,7 +20,7 @@ const BASE_GEMINI_PROMPT =
const KB_CATCH =
'If the knowledge base tool gives empty results, do your best to answer the question from the perspective of an expert security analyst.';
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH} {include_citations_prompt_placeholder}`;
- export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from NaturalLanguageESQLTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
+ export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response. ALWAYS return the exact response from NaturalLanguageESQLTool verbatim in the final response, without adding further description.`;
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;

export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} You have access to the following tools:
@@ -26,7 +26,7 @@ import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/
import { getDefaultArguments } from '@kbn/langchain/server';
import { StructuredTool } from '@langchain/core/tools';
import {
- createOpenAIFunctionsAgent,
+ createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
@@ -331,26 +331,27 @@
savedObjectsClient,
});

- const agentRunnable = isOpenAI
-   ? await createOpenAIFunctionsAgent({
-       llm,
-       tools,
-       prompt: formatPrompt(defaultSystemPrompt),
-       streamRunnable: false,
-     })
-   : llmType && ['bedrock', 'gemini'].includes(llmType)
-   ? createToolCallingAgent({
-       llm,
-       tools,
-       prompt: formatPrompt(defaultSystemPrompt),
-       streamRunnable: false,
-     })
-   : await createStructuredChatAgent({
-       llm,
-       tools,
-       prompt: formatPromptStructured(defaultSystemPrompt),
-       streamRunnable: false,
-     });
+ const agentRunnable =
+   isOpenAI || llmType === 'inference'
+     ? await createOpenAIToolsAgent({
+         llm,
+         tools,
+         prompt: formatPrompt(defaultSystemPrompt),
+         streamRunnable: false,
+       })
+     : llmType && ['bedrock', 'gemini'].includes(llmType)
+     ? createToolCallingAgent({
+         llm,
+         tools,
+         prompt: formatPrompt(defaultSystemPrompt),
+         streamRunnable: false,
+       })
+     : await createStructuredChatAgent({
+         llm,
+         tools,
+         prompt: formatPromptStructured(defaultSystemPrompt),
+         streamRunnable: false,
+       });

return {
connectorId: connector.id,
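The evaluation route now sends inference connectors down the same path as OpenAI, and swaps the functions-based agent for the tools-based one (OpenAI's functions API has been superseded by its tools API). A hypothetical helper making the nested ternary's routing explicit:

type AgentKind = 'openai-tools' | 'tool-calling' | 'structured-chat';

// Hypothetical distillation of the agent selection above.
const pickAgent = (isOpenAI: boolean, llmType?: string): AgentKind =>
  isOpenAI || llmType === 'inference'
    ? 'openai-tools' // createOpenAIToolsAgent
    : llmType && ['bedrock', 'gemini'].includes(llmType)
    ? 'tool-calling' // createToolCallingAgent
    : 'structured-chat'; // createStructuredChatAgent

pickAgent(false, 'inference'); // 'openai-tools'
pickAgent(false, 'bedrock'); // 'tool-calling'
pickAgent(false, undefined); // 'structured-chat'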
