Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.x] [Security Assistant] Product documentation tool (#199694) #203019

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion x-pack/plugins/elastic_assistant/kibana.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
"ml",
"taskManager",
"licensing",
"llmTasks",
"inference",
"productDocBase",
"spaces",
"security"
]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export const createMockClients = () => {
getSpaceId: jest.fn(),
getCurrentUser: jest.fn(),
inference: jest.fn(),
llmTasks: jest.fn(),
},
savedObjectsClient: core.savedObjects.client,

Expand Down Expand Up @@ -145,6 +146,7 @@ const createElasticAssistantRequestContextMock = (
getServerBasePath: jest.fn(),
getSpaceId: jest.fn().mockReturnValue('default'),
inference: { getClient: jest.fn() },
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
core: clients.core,
telemetry: clients.elasticAssistant.telemetry,
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ensureProductDocumentationInstalled } from './helpers';
import { loggerMock } from '@kbn/logging-mocks';

const mockLogger = loggerMock.create();
const mockProductDocManager = {
getStatus: jest.fn(),
install: jest.fn(),
uninstall: jest.fn(),
update: jest.fn(),
};

describe('helpers', () => {
describe('ensureProductDocumentationInstalled', () => {
beforeEach(() => {
jest.clearAllMocks();
});

it('should install product documentation if not installed', async () => {
mockProductDocManager.getStatus.mockResolvedValue({ status: 'uninstalled' });
mockProductDocManager.install.mockResolvedValue(null);

await ensureProductDocumentationInstalled(mockProductDocManager, mockLogger);

expect(mockProductDocManager.getStatus).toHaveBeenCalled();
expect(mockLogger.debug).toHaveBeenCalledWith(
'Installing product documentation for AIAssistantService'
);
expect(mockProductDocManager.install).toHaveBeenCalled();
expect(mockLogger.debug).toHaveBeenNthCalledWith(
2,
'Successfully installed product documentation for AIAssistantService'
);
});

it('should not install product documentation if already installed', async () => {
mockProductDocManager.getStatus.mockResolvedValue({ status: 'installed' });

await ensureProductDocumentationInstalled(mockProductDocManager, mockLogger);

expect(mockProductDocManager.getStatus).toHaveBeenCalled();
expect(mockProductDocManager.install).not.toHaveBeenCalled();
expect(mockLogger.debug).not.toHaveBeenCalledWith(
'Installing product documentation for AIAssistantService'
);
});
it('should log a warning if install fails', async () => {
mockProductDocManager.getStatus.mockResolvedValue({ status: 'not_installed' });
mockProductDocManager.install.mockRejectedValue(new Error('Install failed'));

await ensureProductDocumentationInstalled(mockProductDocManager, mockLogger);

expect(mockProductDocManager.getStatus).toHaveBeenCalled();
expect(mockProductDocManager.install).toHaveBeenCalled();

expect(mockLogger.warn).toHaveBeenCalledWith(
'Failed to install product documentation for AIAssistantService: Install failed'
);
});

it('should log a warning if getStatus fails', async () => {
mockProductDocManager.getStatus.mockRejectedValue(new Error('Status check failed'));

await ensureProductDocumentationInstalled(mockProductDocManager, mockLogger);

expect(mockProductDocManager.getStatus).toHaveBeenCalled();
expect(mockLogger.warn).toHaveBeenCalledWith(
'Failed to get status of product documentation installation for AIAssistantService: Status check failed'
);
expect(mockProductDocManager.install).not.toHaveBeenCalled();
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-ser
import type { MlPluginSetup } from '@kbn/ml-plugin/server';
import { DeleteByQueryRequest } from '@elastic/elasticsearch/lib/api/types';
import { i18n } from '@kbn/i18n';
import { ProductDocBaseStartContract } from '@kbn/product-doc-base-plugin/server';
import type { Logger } from '@kbn/logging';
import { getResourceName } from '.';
import { knowledgeBaseIngestPipeline } from '../ai_assistant_data_clients/knowledge_base/ingest_pipeline';
import { GetElser } from '../types';
Expand Down Expand Up @@ -141,3 +143,25 @@ const ESQL_QUERY_GENERATION_TITLE = i18n.translate(
defaultMessage: 'ES|QL Query Generation',
}
);

export const ensureProductDocumentationInstalled = async (
productDocManager: ProductDocBaseStartContract['management'],
logger: Logger
) => {
try {
const { status } = await productDocManager.getStatus();
if (status !== 'installed') {
logger.debug(`Installing product documentation for AIAssistantService`);
try {
await productDocManager.install();
logger.debug(`Successfully installed product documentation for AIAssistantService`);
} catch (e) {
logger.warn(`Failed to install product documentation for AIAssistantService: ${e.message}`);
}
}
} catch (e) {
logger.warn(
`Failed to get status of product documentation installation for AIAssistantService: ${e.message}`
);
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ describe('AI Assistant Service', () => {
kibanaVersion: '8.8.0',
ml,
taskManager: taskManagerMock.createSetup(),
productDocManager: Promise.resolve({
getStatus: jest.fn(),
install: jest.fn(),
update: jest.fn(),
uninstall: jest.fn(),
}),
};
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type { TaskManagerSetupContract } from '@kbn/task-manager-plugin/server';
import type { MlPluginSetup } from '@kbn/ml-plugin/server';
import { Subject } from 'rxjs';
import { LicensingApiRequestHandlerContext } from '@kbn/licensing-plugin/server';
import { ProductDocBaseStartContract } from '@kbn/product-doc-base-plugin/server';
import { attackDiscoveryFieldMap } from '../lib/attack_discovery/persistence/field_maps_configuration/field_maps_configuration';
import { defendInsightsFieldMap } from '../ai_assistant_data_clients/defend_insights/field_maps_configuration';
import { getDefaultAnonymizationFields } from '../../common/anonymization';
Expand All @@ -35,7 +36,12 @@ import {
} from '../ai_assistant_data_clients/knowledge_base';
import { AttackDiscoveryDataClient } from '../lib/attack_discovery/persistence';
import { DefendInsightsDataClient } from '../ai_assistant_data_clients/defend_insights';
import { createGetElserId, createPipeline, pipelineExists } from './helpers';
import {
createGetElserId,
createPipeline,
ensureProductDocumentationInstalled,
pipelineExists,
} from './helpers';
import { hasAIAssistantLicense } from '../routes/helpers';

const TOTAL_FIELDS_LIMIT = 2500;
Expand All @@ -51,6 +57,7 @@ export interface AIAssistantServiceOpts {
ml: MlPluginSetup;
taskManager: TaskManagerSetupContract;
pluginStop$: Subject<void>;
productDocManager: Promise<ProductDocBaseStartContract['management']>;
}

export interface CreateAIAssistantClientParams {
Expand Down Expand Up @@ -87,6 +94,7 @@ export class AIAssistantService {
private initPromise: Promise<InitializationPromise>;
private isKBSetupInProgress: boolean = false;
private hasInitializedV2KnowledgeBase: boolean = false;
private productDocManager?: ProductDocBaseStartContract['management'];

constructor(private readonly options: AIAssistantServiceOpts) {
this.initialized = false;
Expand Down Expand Up @@ -129,6 +137,13 @@ export class AIAssistantService {
this.initPromise,
this.installAndUpdateSpaceLevelResources.bind(this)
);
options.productDocManager
.then((productDocManager) => {
this.productDocManager = productDocManager;
})
.catch((error) => {
this.options.logger.warn(`Failed to initialize productDocManager: ${error.message}`);
});
}

public isInitialized() {
Expand Down Expand Up @@ -183,6 +198,11 @@ export class AIAssistantService {
this.options.logger.debug(`Initializing resources for AIAssistantService`);
const esClient = await this.options.elasticsearchClientPromise;

if (this.productDocManager) {
// install product documentation without blocking other resources
void ensureProductDocumentationInstalled(this.productDocManager, this.options.logger);
}

await this.conversationsDataStream.install({
esClient,
logger: this.options.logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { PublicMethodsOf } from '@kbn/utility-types';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { AnalyticsServiceSetup } from '@kbn/core-analytics-server';
import { TelemetryParams } from '@kbn/langchain/server/tracers/telemetry/telemetry_tracer';
import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server';
import { ResponseBody } from '../types';
import type { AssistantTool } from '../../../types';
import { AIAssistantKnowledgeBaseDataClient } from '../../../ai_assistant_data_clients/knowledge_base';
Expand Down Expand Up @@ -45,10 +46,11 @@ export interface AgentExecutorParams<T extends boolean> {
dataClients?: AssistantDataClients;
esClient: ElasticsearchClient;
langChainMessages: BaseMessage[];
llmTasks?: LlmTasksPluginStart;
llmType?: string;
isOssModel?: boolean;
logger: Logger;
inference: InferenceServerStart;
logger: Logger;
onNewReplacements?: (newReplacements: Replacements) => void;
replacements: Replacements;
isStream?: T;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
esClient,
inference,
langChainMessages,
llmTasks,
llmType,
isOssModel,
logger: parentLogger,
Expand Down Expand Up @@ -106,6 +107,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
inference,
isEnabledKnowledgeBase,
kbDataClient: dataClients?.kbDataClient,
llmTasks,
logger,
onNewReplacements,
replacements,
Expand Down
3 changes: 3 additions & 0 deletions x-pack/plugins/elastic_assistant/server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ export class ElasticAssistantPlugin
elasticsearchClientPromise: core
.getStartServices()
.then(([{ elasticsearch }]) => elasticsearch.client.asInternalUser),
productDocManager: core
.getStartServices()
.then(([_, { productDocBase }]) => productDocBase.management),
pluginStop$: this.pluginStop$,
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ const mockContext = {
getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures),
logger: loggingSystemMock.createLogger(),
telemetry: { ...coreMock.createSetup().analytics, reportEvent },
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
getCurrentUser: () => ({
username: 'user',
email: 'email',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ export const chatCompleteRoute = (
try {
telemetry = ctx.elasticAssistant.telemetry;
const inference = ctx.elasticAssistant.inference;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;

// Perform license and authenticated user checks
const checkResponse = performChecks({
Expand Down Expand Up @@ -217,6 +219,7 @@ export const chatCompleteRoute = (
response,
telemetry,
responseLanguage: request.body.responseLanguage,
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
});
} catch (err) {
const error = transformError(err as Error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ export const postEvaluateRoute = (
const esClient = ctx.core.elasticsearch.client.asCurrentUser;

const inference = ctx.elasticAssistant.inference;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;

// Data clients
const anonymizationFieldsDataClient =
Expand Down Expand Up @@ -280,6 +282,7 @@ export const postEvaluateRoute = (
connectorId: connector.id,
size,
telemetry: ctx.elasticAssistant.telemetry,
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
};

const tools: StructuredTool[] = assistantTools.flatMap(
Expand Down
4 changes: 4 additions & 0 deletions x-pack/plugins/elastic_assistant/server/routes/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { ActionsClient } from '@kbn/actions-plugin/server';
import { AssistantFeatureKey } from '@kbn/elastic-assistant-common/impl/capabilities';
import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server';
import { INVOKE_ASSISTANT_SUCCESS_EVENT } from '../lib/telemetry/event_based_telemetry';
import { AIAssistantKnowledgeBaseDataClient } from '../ai_assistant_data_clients/knowledge_base';
import { FindResponse } from '../ai_assistant_data_clients/find';
Expand Down Expand Up @@ -215,6 +216,7 @@ export interface LangChainExecuteParams {
telemetry: AnalyticsServiceSetup;
actionTypeId: string;
connectorId: string;
llmTasks?: LlmTasksPluginStart;
inference: InferenceServerStart;
isOssModel?: boolean;
conversationId?: string;
Expand Down Expand Up @@ -246,6 +248,7 @@ export const langChainExecute = async ({
isOssModel,
context,
actionsClient,
llmTasks,
inference,
request,
logger,
Expand Down Expand Up @@ -301,6 +304,7 @@ export const langChainExecute = async ({
conversationId,
connectorId,
esClient,
llmTasks,
inference,
isStream,
llmType: getLlmType(actionTypeId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const mockContext = {
actions: {
getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient),
},
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
getRegisteredTools: jest.fn(() => []),
getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures),
logger: loggingSystemMock.createLogger(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ export const postActionsConnectorExecuteRoute = (
// get the actions plugin start contract from the request context:
const actions = ctx.elasticAssistant.actions;
const inference = ctx.elasticAssistant.inference;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;
const actionsClient = await actions.getActionsClientWithRequest(request);
const connectors = await actionsClient.getBulk({ ids: [connectorId] });
const connector = connectors.length > 0 ? connectors[0] : undefined;
Expand Down Expand Up @@ -150,6 +152,7 @@ export const postActionsConnectorExecuteRoute = (
response,
telemetry,
systemPrompt,
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
});
} catch (err) {
logger.error(err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export class RequestContextFactory implements IRequestContextFactory {
getRegisteredFeatures: (pluginName: string) => {
return appContextService.getRegisteredFeatures(pluginName);
},

llmTasks: startPlugins.llmTasks,
inference: startPlugins.inference,

telemetry: core.analytics,
Expand Down
Loading
Loading