diff --git a/x-pack/plugins/ml/common/types/trained_models.ts b/x-pack/plugins/ml/common/types/trained_models.ts index f4ed52ff21f52..25d7e231bf166 100644 --- a/x-pack/plugins/ml/common/types/trained_models.ts +++ b/x-pack/plugins/ml/common/types/trained_models.ts @@ -5,14 +5,25 @@ * 2.0. */ import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; -import type { TrainedModelType } from '@kbn/ml-trained-models-utils'; +import type { + InferenceInferenceEndpointInfo, + MlInferenceConfigCreateContainer, +} from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; +import type { + ModelDefinitionResponse, + ModelState, + TrainedModelType, +} from '@kbn/ml-trained-models-utils'; +import { + BUILT_IN_MODEL_TAG, + ELASTIC_MODEL_TAG, + TRAINED_MODEL_TYPE, +} from '@kbn/ml-trained-models-utils'; import type { DataFrameAnalyticsConfig, FeatureImportanceBaseline, TotalFeatureImportance, } from '@kbn/ml-data-frame-analytics-utils'; -import type { IndexName, IndicesIndexState } from '@elastic/elasticsearch/lib/api/types'; -import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils'; import type { XOR } from './common'; import type { MlSavedObjectType } from './saved_objects'; @@ -95,33 +106,12 @@ export type PutTrainedModelConfig = { >; // compressed_definition and definition are mutually exclusive export type TrainedModelConfigResponse = estypes.MlTrainedModelConfig & { - /** - * Associated pipelines. Extends response from the ES endpoint. - */ - pipelines?: Record | null; - origin_job_exists?: boolean; - - metadata?: { - analytics_config: DataFrameAnalyticsConfig; + metadata?: estypes.MlTrainedModelConfig['metadata'] & { + analytics_config?: DataFrameAnalyticsConfig; input: unknown; total_feature_importance?: TotalFeatureImportance[]; feature_importance_baseline?: FeatureImportanceBaseline; - model_aliases?: string[]; } & Record; - model_id: string; - model_type: TrainedModelType; - tags: string[]; - version: string; - inference_config?: Record; - indices?: Array>; - /** - * Whether the model has inference services - */ - hasInferenceServices?: boolean; - /** - * Inference services associated with the model - */ - inference_apis?: InferenceAPIConfigResponse[]; }; export interface PipelineDefinition { @@ -309,3 +299,125 @@ export interface ModelDownloadState { total_parts: number; downloaded_parts: number; } + +export type Stats = Omit; + +/** + * Additional properties for all items in the Trained models table + * */ +interface BaseModelItem { + type?: string[]; + tags: string[]; + /** + * Whether the model has inference services + */ + hasInferenceServices?: boolean; + /** + * Inference services associated with the model + */ + inference_apis?: InferenceInferenceEndpointInfo[]; + /** + * Associated pipelines. Extends response from the ES endpoint. + */ + pipelines?: Record; + /** + * Indices with associated pipelines that have inference processors utilizing the model deployments. + */ + indices?: string[]; +} + +/** Common properties for existing NLP models and NLP model download configs */ +interface BaseNLPModelItem extends BaseModelItem { + disclaimer?: string; + recommended?: boolean; + supported?: boolean; + state: ModelState | undefined; + downloadState?: ModelDownloadState; +} + +/** Model available for download */ +export type ModelDownloadItem = BaseNLPModelItem & + Omit & { + putModelConfig?: object; + softwareLicense?: string; + }; +/** Trained NLP model, i.e. pytorch model returned by the trained_models API */ +export type NLPModelItem = BaseNLPModelItem & + TrainedModelItem & { + stats: Stats & { deployment_stats: TrainedModelDeploymentStatsResponse[] }; + /** + * Description of the current model state + */ + stateDescription?: string; + /** + * Deployment ids extracted from the deployment stats + */ + deployment_ids: string[]; + }; + +export function isBaseNLPModelItem(item: unknown): item is BaseNLPModelItem { + return ( + typeof item === 'object' && + item !== null && + 'type' in item && + Array.isArray(item.type) && + item.type.includes(TRAINED_MODEL_TYPE.PYTORCH) + ); +} + +export function isNLPModelItem(item: unknown): item is NLPModelItem { + return isExistingModel(item) && item.model_type === TRAINED_MODEL_TYPE.PYTORCH; +} + +export const isElasticModel = (item: TrainedModelConfigResponse) => + item.tags.includes(ELASTIC_MODEL_TAG); + +export type ExistingModelBase = TrainedModelConfigResponse & BaseModelItem; + +/** Any model returned by the trained_models API, e.g. lang_ident, elser, dfa model */ +export type TrainedModelItem = ExistingModelBase & { stats: Stats }; + +/** Trained DFA model */ +export type DFAModelItem = Omit & { + origin_job_exists?: boolean; + inference_config?: Pick; + metadata?: estypes.MlTrainedModelConfig['metadata'] & { + analytics_config: DataFrameAnalyticsConfig; + input: unknown; + total_feature_importance?: TotalFeatureImportance[]; + feature_importance_baseline?: FeatureImportanceBaseline; + } & Record; +}; + +export type TrainedModelWithPipelines = TrainedModelItem & { + pipelines: Record; +}; + +export function isExistingModel(item: unknown): item is TrainedModelItem { + return ( + typeof item === 'object' && + item !== null && + 'model_type' in item && + 'create_time' in item && + !!item.create_time + ); +} + +export function isDFAModelItem(item: unknown): item is DFAModelItem { + return isExistingModel(item) && item.model_type === TRAINED_MODEL_TYPE.TREE_ENSEMBLE; +} + +export function isModelDownloadItem(item: TrainedModelUIItem): item is ModelDownloadItem { + return 'putModelConfig' in item && !!item.type?.includes(TRAINED_MODEL_TYPE.PYTORCH); +} + +export const isBuiltInModel = (item: TrainedModelConfigResponse | TrainedModelUIItem) => + item.tags.includes(BUILT_IN_MODEL_TAG); +/** + * This type represents a union of different model entities: + * - Any existing trained model returned by the API, e.g., lang_ident_model_1, DFA models, etc. + * - Hosted model configurations available for download, e.g., ELSER or E5 + * - NLP models already downloaded into Elasticsearch + * - DFA models + */ +export type TrainedModelUIItem = TrainedModelItem | ModelDownloadItem | NLPModelItem | DFAModelItem; diff --git a/x-pack/plugins/ml/public/application/components/ml_inference/add_inference_pipeline_flyout.tsx b/x-pack/plugins/ml/public/application/components/ml_inference/add_inference_pipeline_flyout.tsx index 1d58dce866449..5bd47702ed3f0 100644 --- a/x-pack/plugins/ml/public/application/components/ml_inference/add_inference_pipeline_flyout.tsx +++ b/x-pack/plugins/ml/public/application/components/ml_inference/add_inference_pipeline_flyout.tsx @@ -20,7 +20,7 @@ import { import { i18n } from '@kbn/i18n'; import { extractErrorProperties } from '@kbn/ml-error-utils'; -import type { ModelItem } from '../../model_management/models_list'; +import type { DFAModelItem } from '../../../../common/types/trained_models'; import type { AddInferencePipelineSteps } from './types'; import { ADD_INFERENCE_PIPELINE_STEPS } from './constants'; import { AddInferencePipelineFooter } from '../shared'; @@ -39,7 +39,7 @@ import { useFetchPipelines } from './hooks/use_fetch_pipelines'; export interface AddInferencePipelineFlyoutProps { onClose: () => void; - model: ModelItem; + model: DFAModelItem; } export const AddInferencePipelineFlyout: FC = ({ diff --git a/x-pack/plugins/ml/public/application/components/ml_inference/components/processor_configuration.tsx b/x-pack/plugins/ml/public/application/components/ml_inference/components/processor_configuration.tsx index cd8bdf52166e0..0803cd98679a8 100644 --- a/x-pack/plugins/ml/public/application/components/ml_inference/components/processor_configuration.tsx +++ b/x-pack/plugins/ml/public/application/components/ml_inference/components/processor_configuration.tsx @@ -25,7 +25,7 @@ import { import { i18n } from '@kbn/i18n'; import { FormattedMessage } from '@kbn/i18n-react'; import { CodeEditor } from '@kbn/code-editor'; -import type { ModelItem } from '../../../model_management/models_list'; +import type { DFAModelItem } from '../../../../../common/types/trained_models'; import { EDIT_MESSAGE, CANCEL_EDIT_MESSAGE, @@ -56,9 +56,9 @@ interface Props { condition?: string; fieldMap: MlInferenceState['fieldMap']; handleAdvancedConfigUpdate: (configUpdate: Partial) => void; - inferenceConfig: ModelItem['inference_config']; - modelInferenceConfig: ModelItem['inference_config']; - modelInputFields: ModelItem['input']; + inferenceConfig: DFAModelItem['inference_config']; + modelInferenceConfig: DFAModelItem['inference_config']; + modelInputFields: DFAModelItem['input']; modelType?: InferenceModelTypes; setHasUnsavedChanges: React.Dispatch>; tag?: string; diff --git a/x-pack/plugins/ml/public/application/components/ml_inference/state.ts b/x-pack/plugins/ml/public/application/components/ml_inference/state.ts index 787a2335717df..26bfe934eb46b 100644 --- a/x-pack/plugins/ml/public/application/components/ml_inference/state.ts +++ b/x-pack/plugins/ml/public/application/components/ml_inference/state.ts @@ -6,10 +6,10 @@ */ import { getAnalysisType } from '@kbn/ml-data-frame-analytics-utils'; +import type { DFAModelItem } from '../../../../common/types/trained_models'; import type { MlInferenceState } from './types'; -import type { ModelItem } from '../../model_management/models_list'; -export const getModelType = (model: ModelItem): string | undefined => { +export const getModelType = (model: DFAModelItem): string | undefined => { const analysisConfig = model.metadata?.analytics_config?.analysis; return analysisConfig !== undefined ? getAnalysisType(analysisConfig) : undefined; }; @@ -54,13 +54,17 @@ export const getDefaultOnFailureConfiguration = (): MlInferenceState['onFailure' }, ]; -export const getInitialState = (model: ModelItem): MlInferenceState => { +export const getInitialState = (model: DFAModelItem): MlInferenceState => { const modelType = getModelType(model); let targetField; if (modelType !== undefined) { targetField = model.inference_config - ? `ml.inference.${model.inference_config[modelType].results_field}` + ? `ml.inference.${ + model.inference_config[ + modelType as keyof Exclude + ]!.results_field + }` : undefined; } diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/components/analytics_selector/analytics_id_selector.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/components/analytics_selector/analytics_id_selector.tsx index bf786436919a9..9fe4da68aa6f8 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/components/analytics_selector/analytics_id_selector.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/components/analytics_selector/analytics_id_selector.tsx @@ -154,6 +154,7 @@ export function AnalyticsIdSelector({ async function fetchAnalyticsModels() { setIsLoading(true); try { + // FIXME should if fetch all trained models? const response = await trainedModelsApiService.getTrainedModels(); setTrainedModels(response); } catch (e) { diff --git a/x-pack/plugins/ml/public/application/model_management/add_model_flyout.tsx b/x-pack/plugins/ml/public/application/model_management/add_model_flyout.tsx index 5a92a67962579..24c8ce0915234 100644 --- a/x-pack/plugins/ml/public/application/model_management/add_model_flyout.tsx +++ b/x-pack/plugins/ml/public/application/model_management/add_model_flyout.tsx @@ -30,12 +30,12 @@ import { FormattedMessage } from '@kbn/i18n-react'; import React, { type FC, useMemo, useState } from 'react'; import { groupBy } from 'lodash'; import { ElandPythonClient } from '@kbn/inference_integration_flyout'; +import type { ModelDownloadItem } from '../../../common/types/trained_models'; import { usePermissionCheck } from '../capabilities/check_capabilities'; import { useMlKibana } from '../contexts/kibana'; -import type { ModelItem } from './models_list'; export interface AddModelFlyoutProps { - modelDownloads: ModelItem[]; + modelDownloads: ModelDownloadItem[]; onClose: () => void; onSubmit: (modelId: string) => void; } @@ -138,7 +138,7 @@ export const AddModelFlyout: FC = ({ onClose, onSubmit, mod }; interface ClickToDownloadTabContentProps { - modelDownloads: ModelItem[]; + modelDownloads: ModelDownloadItem[]; onModelDownload: (modelId: string) => void; } diff --git a/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/create_pipeline_for_model_flyout.tsx b/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/create_pipeline_for_model_flyout.tsx index 21fac6f6a28f8..580341800f3b5 100644 --- a/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/create_pipeline_for_model_flyout.tsx +++ b/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/create_pipeline_for_model_flyout.tsx @@ -21,7 +21,7 @@ import { i18n } from '@kbn/i18n'; import { extractErrorProperties } from '@kbn/ml-error-utils'; import type { SupportedPytorchTasksType } from '@kbn/ml-trained-models-utils'; -import type { ModelItem } from '../models_list'; +import type { TrainedModelItem } from '../../../../common/types/trained_models'; import type { AddInferencePipelineSteps } from '../../components/ml_inference/types'; import { ADD_INFERENCE_PIPELINE_STEPS } from '../../components/ml_inference/constants'; import { AddInferencePipelineFooter } from '../../components/shared'; @@ -40,7 +40,7 @@ import { useTestTrainedModelsContext } from '../test_models/test_trained_models_ export interface CreatePipelineForModelFlyoutProps { onClose: (refreshList?: boolean) => void; - model: ModelItem; + model: TrainedModelItem; } export const CreatePipelineForModelFlyout: FC = ({ diff --git a/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/state.ts b/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/state.ts index 603a542e7964f..586537222c3c5 100644 --- a/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/state.ts +++ b/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/state.ts @@ -7,8 +7,8 @@ import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; import type { IngestInferenceProcessor } from '@elastic/elasticsearch/lib/api/types'; +import type { TrainedModelItem } from '../../../../common/types/trained_models'; import { getDefaultOnFailureConfiguration } from '../../components/ml_inference/state'; -import type { ModelItem } from '../models_list'; export interface InferecePipelineCreationState { creatingPipeline: boolean; @@ -26,7 +26,7 @@ export interface InferecePipelineCreationState { } export const getInitialState = ( - model: ModelItem, + model: TrainedModelItem, initialPipelineConfig: estypes.IngestPipeline | undefined ): InferecePipelineCreationState => ({ creatingPipeline: false, diff --git a/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/test_trained_model.tsx b/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/test_trained_model.tsx index 46ec8a6060ac5..ba25e3b26f920 100644 --- a/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/test_trained_model.tsx +++ b/x-pack/plugins/ml/public/application/model_management/create_pipeline_for_model/test_trained_model.tsx @@ -12,13 +12,13 @@ import { i18n } from '@kbn/i18n'; import { FormattedMessage } from '@kbn/i18n-react'; import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; -import type { ModelItem } from '../models_list'; +import type { TrainedModelItem } from '../../../../common/types/trained_models'; import { TestTrainedModelContent } from '../test_models/test_trained_model_content'; import { useMlKibana } from '../../contexts/kibana'; import { type InferecePipelineCreationState } from './state'; interface ContentProps { - model: ModelItem; + model: TrainedModelItem; handlePipelineConfigUpdate: (configUpdate: Partial) => void; externalPipelineConfig?: estypes.IngestPipeline; } diff --git a/x-pack/plugins/ml/public/application/model_management/delete_models_modal.tsx b/x-pack/plugins/ml/public/application/model_management/delete_models_modal.tsx index 0f5c515c22776..7afad711521dc 100644 --- a/x-pack/plugins/ml/public/application/model_management/delete_models_modal.tsx +++ b/x-pack/plugins/ml/public/application/model_management/delete_models_modal.tsx @@ -22,14 +22,15 @@ import { EuiSpacer, } from '@elastic/eui'; import { isPopulatedObject } from '@kbn/ml-is-populated-object'; +import type { TrainedModelItem, TrainedModelUIItem } from '../../../common/types/trained_models'; +import { isExistingModel } from '../../../common/types/trained_models'; import { type WithRequired } from '../../../common/types/common'; import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models'; import { useToastNotificationService } from '../services/toast_notification_service'; import { DeleteSpaceAwareItemCheckModal } from '../components/delete_space_aware_item_check_modal'; -import { type ModelItem } from './models_list'; interface DeleteModelsModalProps { - models: ModelItem[]; + models: TrainedModelUIItem[]; onClose: (refreshList?: boolean) => void; } @@ -42,11 +43,14 @@ export const DeleteModelsModal: FC = ({ models, onClose const modelIds = models.map((m) => m.model_id); - const modelsWithPipelines = models.filter((m) => isPopulatedObject(m.pipelines)) as Array< - WithRequired - >; + const modelsWithPipelines = models.filter( + (m): m is WithRequired => + isExistingModel(m) && isPopulatedObject(m.pipelines) + ); - const modelsWithInferenceAPIs = models.filter((m) => m.hasInferenceServices); + const modelsWithInferenceAPIs = models.filter( + (m): m is TrainedModelItem => isExistingModel(m) && !!m.hasInferenceServices + ); const inferenceAPIsIDs: string[] = modelsWithInferenceAPIs.flatMap((model) => { return (model.inference_apis ?? []).map((inference) => inference.inference_id); diff --git a/x-pack/plugins/ml/public/application/model_management/deployment_setup.tsx b/x-pack/plugins/ml/public/application/model_management/deployment_setup.tsx index 87fff2bf3eb75..c5b38feb4c799 100644 --- a/x-pack/plugins/ml/public/application/model_management/deployment_setup.tsx +++ b/x-pack/plugins/ml/public/application/model_management/deployment_setup.tsx @@ -42,9 +42,11 @@ import { css } from '@emotion/react'; import { toMountPoint } from '@kbn/react-kibana-mount'; import { dictionaryValidator } from '@kbn/ml-validators'; import type { NLPSettings } from '../../../common/constants/app'; -import type { TrainedModelDeploymentStatsResponse } from '../../../common/types/trained_models'; +import type { + NLPModelItem, + TrainedModelDeploymentStatsResponse, +} from '../../../common/types/trained_models'; import { type CloudInfo, getNewJobLimits } from '../services/ml_server_info'; -import type { ModelItem } from './models_list'; import type { MlStartTrainedModelDeploymentRequestNew } from './deployment_params_mapper'; import { DeploymentParamsMapper } from './deployment_params_mapper'; @@ -645,7 +647,7 @@ export const DeploymentSetup: FC = ({ }; interface StartDeploymentModalProps { - model: ModelItem; + model: NLPModelItem; startModelDeploymentDocUrl: string; onConfigChange: (config: DeploymentParamsUI) => void; onClose: () => void; @@ -845,7 +847,7 @@ export const getUserInputModelDeploymentParamsProvider = nlpSettings: NLPSettings ) => ( - model: ModelItem, + model: NLPModelItem, initialParams?: TrainedModelDeploymentStatsResponse, deploymentIds?: string[] ): Promise => { diff --git a/x-pack/plugins/ml/public/application/model_management/expanded_row.tsx b/x-pack/plugins/ml/public/application/model_management/expanded_row.tsx index eaec5776c3b81..4fc5fde65948a 100644 --- a/x-pack/plugins/ml/public/application/model_management/expanded_row.tsx +++ b/x-pack/plugins/ml/public/application/model_management/expanded_row.tsx @@ -26,18 +26,23 @@ import { FormattedMessage } from '@kbn/i18n-react'; import { FIELD_FORMAT_IDS } from '@kbn/field-formats-plugin/common'; import { isPopulatedObject } from '@kbn/ml-is-populated-object'; import { isDefined } from '@kbn/ml-is-defined'; -import { TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils'; +import { MODEL_STATE, TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils'; import { dynamic } from '@kbn/shared-ux-utility'; import { InferenceApi } from './inference_api_tab'; -import type { ModelItemFull } from './models_list'; import { ModelPipelines } from './pipelines'; import { AllocatedModels } from '../memory_usage/nodes_overview/allocated_models'; -import type { AllocatedModel, TrainedModelStat } from '../../../common/types/trained_models'; +import type { + AllocatedModel, + NLPModelItem, + TrainedModelItem, + TrainedModelStat, +} from '../../../common/types/trained_models'; import { useFieldFormatter } from '../contexts/kibana/use_field_formatter'; import { useEnabledFeatures } from '../contexts/ml'; +import { isNLPModelItem } from '../../../common/types/trained_models'; interface ExpandedRowProps { - item: ModelItemFull; + item: TrainedModelItem; } const JobMap = dynamic(async () => ({ @@ -169,8 +174,14 @@ export const ExpandedRow: FC = ({ item }) => { license_level, ]); + const hideColumns = useMemo(() => { + return showNodeInfo ? ['model_id'] : ['model_id', 'node_name']; + }, [showNodeInfo]); + const deploymentStatItems = useMemo(() => { - const deploymentStats = stats.deployment_stats; + if (!isNLPModelItem(item)) return []; + + const deploymentStats = (stats as NLPModelItem['stats'])!.deployment_stats; const modelSizeStats = stats.model_size_stats; if (!deploymentStats || !modelSizeStats) return []; @@ -227,11 +238,7 @@ export const ExpandedRow: FC = ({ item }) => { }; }); }); - }, [stats]); - - const hideColumns = useMemo(() => { - return showNodeInfo ? ['model_id'] : ['model_id', 'node_name']; - }, [showNodeInfo]); + }, [stats, item]); const tabs = useMemo(() => { return [ @@ -319,9 +326,7 @@ export const ExpandedRow: FC = ({ item }) => { @@ -528,7 +533,9 @@ export const ExpandedRow: FC = ({ item }) => { ]); const initialSelectedTab = - item.state === 'started' ? tabs.find((t) => t.id === 'stats') : tabs[0]; + isNLPModelItem(item) && item.state === MODEL_STATE.STARTED + ? tabs.find((t) => t.id === 'stats') + : tabs[0]; return ( void; onConfirm: (deploymentIds: string[]) => void; } @@ -220,7 +220,7 @@ export const StopModelDeploymentsConfirmDialog: FC) => - async (forceStopModel: ModelItem): Promise => { + async (forceStopModel: NLPModelItem): Promise => { return new Promise(async (resolve, reject) => { try { const modalSession = overlays.openModal( diff --git a/x-pack/plugins/ml/public/application/model_management/get_model_state.tsx b/x-pack/plugins/ml/public/application/model_management/get_model_state.tsx index d8bf2b8084a6a..75f8f9faa7a91 100644 --- a/x-pack/plugins/ml/public/application/model_management/get_model_state.tsx +++ b/x-pack/plugins/ml/public/application/model_management/get_model_state.tsx @@ -5,40 +5,18 @@ * 2.0. */ -import React from 'react'; -import { DEPLOYMENT_STATE, MODEL_STATE, type ModelState } from '@kbn/ml-trained-models-utils'; import { EuiBadge, - EuiHealth, - EuiLoadingSpinner, - type EuiHealthProps, EuiFlexGroup, EuiFlexItem, + EuiHealth, + EuiLoadingSpinner, EuiText, + type EuiHealthProps, } from '@elastic/eui'; import { i18n } from '@kbn/i18n'; -import type { ModelItem } from './models_list'; - -/** - * Resolves result model state based on the state of each deployment. - * - * If at least one deployment is in the STARTED state, the model state is STARTED. - * Then if none of the deployments are in the STARTED state, but at least one is in the STARTING state, the model state is STARTING. - * If all deployments are in the STOPPING state, the model state is STOPPING. - */ -export const getModelDeploymentState = (model: ModelItem): ModelState | undefined => { - if (!model.stats?.deployment_stats?.length) return; - - if (model.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED)) { - return MODEL_STATE.STARTED; - } - if (model.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTING)) { - return MODEL_STATE.STARTING; - } - if (model.stats?.deployment_stats?.every((v) => v.state === DEPLOYMENT_STATE.STOPPING)) { - return MODEL_STATE.STOPPING; - } -}; +import { MODEL_STATE, type ModelState } from '@kbn/ml-trained-models-utils'; +import React from 'react'; export const getModelStateColor = ( state: ModelState | undefined diff --git a/x-pack/plugins/ml/public/application/model_management/inference_api_tab.tsx b/x-pack/plugins/ml/public/application/model_management/inference_api_tab.tsx index dc86c359bb1aa..3f55871a93e44 100644 --- a/x-pack/plugins/ml/public/application/model_management/inference_api_tab.tsx +++ b/x-pack/plugins/ml/public/application/model_management/inference_api_tab.tsx @@ -16,10 +16,10 @@ import { EuiTitle, } from '@elastic/eui'; import { FormattedMessage } from '@kbn/i18n-react'; -import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils'; +import type { InferenceInferenceEndpointInfo } from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; export interface InferenceAPITabProps { - inferenceApis: InferenceAPIConfigResponse[]; + inferenceApis: InferenceInferenceEndpointInfo[]; } export const InferenceApi: FC = ({ inferenceApis }) => { diff --git a/x-pack/plugins/ml/public/application/model_management/model_actions.tsx b/x-pack/plugins/ml/public/application/model_management/model_actions.tsx index 133698b0e72f1..1fe008871b0ef 100644 --- a/x-pack/plugins/ml/public/application/model_management/model_actions.tsx +++ b/x-pack/plugins/ml/public/application/model_management/model_actions.tsx @@ -8,18 +8,28 @@ import type { Action } from '@elastic/eui/src/components/basic_table/action_types'; import { i18n } from '@kbn/i18n'; import { isPopulatedObject } from '@kbn/ml-is-populated-object'; -import { EuiToolTip, useIsWithinMaxBreakpoint } from '@elastic/eui'; -import React, { useCallback, useMemo, useEffect, useState } from 'react'; -import { - BUILT_IN_MODEL_TAG, - DEPLOYMENT_STATE, - TRAINED_MODEL_TYPE, -} from '@kbn/ml-trained-models-utils'; +import { useIsWithinMaxBreakpoint } from '@elastic/eui'; +import React, { useMemo, useEffect, useState } from 'react'; +import { DEPLOYMENT_STATE } from '@kbn/ml-trained-models-utils'; import { MODEL_STATE } from '@kbn/ml-trained-models-utils/src/constants/trained_models'; import { getAnalysisType, type DataFrameAnalysisConfigType, } from '@kbn/ml-data-frame-analytics-utils'; +import useMountedState from 'react-use/lib/useMountedState'; +import type { + DFAModelItem, + NLPModelItem, + TrainedModelItem, + TrainedModelUIItem, +} from '../../../common/types/trained_models'; +import { + isBuiltInModel, + isDFAModelItem, + isExistingModel, + isModelDownloadItem, + isNLPModelItem, +} from '../../../common/types/trained_models'; import { useEnabledFeatures, useMlServerInfo } from '../contexts/ml'; import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models'; import { getUserConfirmationProvider } from './force_stop_dialog'; @@ -27,8 +37,7 @@ import { useToastNotificationService } from '../services/toast_notification_serv import { getUserInputModelDeploymentParamsProvider } from './deployment_setup'; import { useMlKibana, useMlLocator, useNavigateToPath } from '../contexts/kibana'; import { ML_PAGES } from '../../../common/constants/locator'; -import { isTestable, isDfaTrainedModel } from './test_models'; -import type { ModelItem } from './models_list'; +import { isTestable } from './test_models'; import { usePermissionCheck } from '../capabilities/check_capabilities'; import { useCloudCheck } from '../components/node_available_warning/hooks'; @@ -44,16 +53,17 @@ export function useModelActions({ onModelDownloadRequest, }: { isLoading: boolean; - onDfaTestAction: (model: ModelItem) => void; - onTestAction: (model: ModelItem) => void; - onModelsDeleteRequest: (models: ModelItem[]) => void; - onModelDeployRequest: (model: ModelItem) => void; + onDfaTestAction: (model: DFAModelItem) => void; + onTestAction: (model: TrainedModelItem) => void; + onModelsDeleteRequest: (models: TrainedModelUIItem[]) => void; + onModelDeployRequest: (model: DFAModelItem) => void; onModelDownloadRequest: (modelId: string) => void; onLoading: (isLoading: boolean) => void; fetchModels: () => Promise; modelAndDeploymentIds: string[]; -}): Array> { +}): Array> { const isMobileLayout = useIsWithinMaxBreakpoint('l'); + const isMounted = useMountedState(); const { services: { @@ -95,23 +105,19 @@ export function useModelActions({ const trainedModelsApiService = useTrainedModelsApiService(); useEffect(() => { - let isMounted = true; mlApi .hasPrivileges({ cluster: ['manage_ingest_pipelines'], }) .then((result) => { - if (isMounted) { + if (isMounted()) { setCanManageIngestPipelines( result.hasPrivileges === undefined || result.hasPrivileges.cluster?.manage_ingest_pipelines === true ); } }); - return () => { - isMounted = false; - }; - }, [mlApi]); + }, [mlApi, isMounted]); const getUserConfirmation = useMemo( () => getUserConfirmationProvider(overlays, startServices), @@ -131,12 +137,7 @@ export function useModelActions({ [overlays, startServices, startModelDeploymentDocUrl, cloudInfo, showNodeInfo, nlpSettings] ); - const isBuiltInModel = useCallback( - (item: ModelItem) => item.tags.includes(BUILT_IN_MODEL_TAG), - [] - ); - - return useMemo>>( + return useMemo>>( () => [ { name: i18n.translate('xpack.ml.trainedModels.modelsList.viewTrainingDataNameActionLabel', { @@ -150,10 +151,10 @@ export function useModelActions({ ), icon: 'visTable', type: 'icon', - available: (item) => !!item.metadata?.analytics_config?.id, - enabled: (item) => item.origin_job_exists === true, + available: (item) => isDFAModelItem(item) && !!item.metadata?.analytics_config?.id, + enabled: (item) => isDFAModelItem(item) && item.origin_job_exists === true, onClick: async (item) => { - if (item.metadata?.analytics_config === undefined) return; + if (!isDFAModelItem(item) || item.metadata?.analytics_config === undefined) return; const analysisType = getAnalysisType( item.metadata?.analytics_config.analysis @@ -185,7 +186,7 @@ export function useModelActions({ icon: 'graphApp', type: 'icon', isPrimary: true, - available: (item) => !!item.metadata?.analytics_config?.id, + available: (item) => isDFAModelItem(item) && !!item.metadata?.analytics_config?.id, onClick: async (item) => { const path = await urlLocator.getUrl({ page: ML_PAGES.DATA_FRAME_ANALYTICS_MAP, @@ -216,15 +217,14 @@ export function useModelActions({ }, available: (item) => { return ( - item.model_type === TRAINED_MODEL_TYPE.PYTORCH && - !!item.state && + isNLPModelItem(item) && item.state !== MODEL_STATE.DOWNLOADING && item.state !== MODEL_STATE.NOT_DOWNLOADED ); }, onClick: async (item) => { const modelDeploymentParams = await getUserInputModelDeploymentParams( - item, + item as NLPModelItem, undefined, modelAndDeploymentIds ); @@ -277,11 +277,13 @@ export function useModelActions({ type: 'icon', isPrimary: false, available: (item) => - item.model_type === TRAINED_MODEL_TYPE.PYTORCH && + isNLPModelItem(item) && canStartStopTrainedModels && !isLoading && !!item.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED), onClick: async (item) => { + if (!isNLPModelItem(item)) return; + const deploymentIdToUpdate = item.deployment_ids[0]; const targetDeployment = item.stats!.deployment_stats.find( @@ -345,7 +347,7 @@ export function useModelActions({ type: 'icon', isPrimary: false, available: (item) => - item.model_type === TRAINED_MODEL_TYPE.PYTORCH && + isNLPModelItem(item) && canStartStopTrainedModels && // Deployment can be either started, starting, or exist in a failed state (item.state === MODEL_STATE.STARTED || item.state === MODEL_STATE.STARTING) && @@ -358,6 +360,8 @@ export function useModelActions({ )), enabled: (item) => !isLoading, onClick: async (item) => { + if (!isNLPModelItem(item)) return; + const requireForceStop = isPopulatedObject(item.pipelines); const hasMultipleDeployments = item.deployment_ids.length > 1; @@ -423,7 +427,10 @@ export function useModelActions({ // @ts-ignore type: isMobileLayout ? 'icon' : 'button', isPrimary: true, - available: (item) => canCreateTrainedModels && item.state === MODEL_STATE.NOT_DOWNLOADED, + available: (item) => + canCreateTrainedModels && + isModelDownloadItem(item) && + item.state === MODEL_STATE.NOT_DOWNLOADED, enabled: (item) => !isLoading, onClick: async (item) => { onModelDownloadRequest(item.model_id); @@ -431,28 +438,9 @@ export function useModelActions({ }, { name: (model) => { - const hasDeployments = model.state === MODEL_STATE.STARTED; - return ( - - <> - {i18n.translate('xpack.ml.trainedModels.modelsList.deployModelActionLabel', { - defaultMessage: 'Deploy model', - })} - - - ); + return i18n.translate('xpack.ml.trainedModels.modelsList.deployModelActionLabel', { + defaultMessage: 'Deploy model', + }); }, description: i18n.translate('xpack.ml.trainedModels.modelsList.deployModelActionLabel', { defaultMessage: 'Deploy model', @@ -462,23 +450,18 @@ export function useModelActions({ type: 'icon', isPrimary: false, onClick: (model) => { - onModelDeployRequest(model); + onModelDeployRequest(model as DFAModelItem); }, available: (item) => { - return ( - isDfaTrainedModel(item) && - !isBuiltInModel(item) && - !item.putModelConfig && - canManageIngestPipelines - ); + return isDFAModelItem(item) && canManageIngestPipelines; }, enabled: (item) => { - return canStartStopTrainedModels && item.state !== MODEL_STATE.STARTED; + return canStartStopTrainedModels; }, }, { name: (model) => { - return model.state === MODEL_STATE.DOWNLOADING ? ( + return isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING ? ( <> {i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', { defaultMessage: 'Cancel', @@ -492,33 +475,33 @@ export function useModelActions({ ); }, - description: (model: ModelItem) => { - const hasDeployments = model.deployment_ids.length > 0; - const { hasInferenceServices } = model; - - if (model.state === MODEL_STATE.DOWNLOADING) { + description: (model: TrainedModelUIItem) => { + if (isModelDownloadItem(model) && model.state === MODEL_STATE.DOWNLOADING) { return i18n.translate('xpack.ml.trainedModels.modelsList.cancelDownloadActionLabel', { defaultMessage: 'Cancel download', }); - } else if (hasInferenceServices) { - return i18n.translate( - 'xpack.ml.trainedModels.modelsList.deleteDisabledWithInferenceServicesTooltip', - { - defaultMessage: 'Model is used by the _inference API', - } - ); - } else if (hasDeployments) { - return i18n.translate( - 'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip', - { - defaultMessage: 'Model has started deployments', - } - ); - } else { - return i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', { - defaultMessage: 'Delete model', - }); + } else if (isNLPModelItem(model)) { + const hasDeployments = model.deployment_ids?.length ?? 0 > 0; + const { hasInferenceServices } = model; + if (hasInferenceServices) { + return i18n.translate( + 'xpack.ml.trainedModels.modelsList.deleteDisabledWithInferenceServicesTooltip', + { + defaultMessage: 'Model is used by the _inference API', + } + ); + } else if (hasDeployments) { + return i18n.translate( + 'xpack.ml.trainedModels.modelsList.deleteDisabledWithDeploymentsTooltip', + { + defaultMessage: 'Model has started deployments', + } + ); + } } + return i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', { + defaultMessage: 'Delete model', + }); }, 'data-test-subj': 'mlModelsTableRowDeleteAction', icon: 'trash', @@ -530,16 +513,17 @@ export function useModelActions({ onModelsDeleteRequest([model]); }, available: (item) => { - const hasZeroPipelines = Object.keys(item.pipelines ?? {}).length === 0; - return ( - canDeleteTrainedModels && - !isBuiltInModel(item) && - !item.putModelConfig && - (hasZeroPipelines || canManageIngestPipelines) - ); + if (!canDeleteTrainedModels || isBuiltInModel(item)) return false; + + if (isModelDownloadItem(item)) { + return !!item.downloadState; + } else { + const hasZeroPipelines = Object.keys(item.pipelines ?? {}).length === 0; + return hasZeroPipelines || canManageIngestPipelines; + } }, enabled: (item) => { - return item.state !== MODEL_STATE.STARTED; + return !isNLPModelItem(item) || item.state !== MODEL_STATE.STARTED; }, }, { @@ -556,9 +540,9 @@ export function useModelActions({ isPrimary: true, available: (item) => isTestable(item, true), onClick: (item) => { - if (isDfaTrainedModel(item) && !isBuiltInModel(item)) { + if (isDFAModelItem(item)) { onDfaTestAction(item); - } else { + } else if (isExistingModel(item)) { onTestAction(item); } }, @@ -579,19 +563,20 @@ export function useModelActions({ isPrimary: true, available: (item) => { return ( - item?.metadata?.analytics_config !== undefined || - (Array.isArray(item.indices) && item.indices.length > 0) + isDFAModelItem(item) || + (isExistingModel(item) && Array.isArray(item.indices) && item.indices.length > 0) ); }, onClick: async (item) => { - let indexPatterns: string[] | undefined = item?.indices - ?.map((o) => Object.keys(o)) - .flat(); + if (!isDFAModelItem(item) || !isExistingModel(item)) return; - if (item?.metadata?.analytics_config?.dest?.index !== undefined) { + let indexPatterns: string[] | undefined = item.indices; + + if (isDFAModelItem(item) && item?.metadata?.analytics_config?.dest?.index !== undefined) { const destIndex = item.metadata.analytics_config.dest?.index; indexPatterns = [destIndex]; } + const path = await urlLocator.getUrl({ page: ML_PAGES.DATA_DRIFT_CUSTOM, pageState: indexPatterns ? { comparison: indexPatterns.join(',') } : {}, @@ -612,7 +597,6 @@ export function useModelActions({ fetchModels, getUserConfirmation, getUserInputModelDeploymentParams, - isBuiltInModel, isLoading, modelAndDeploymentIds, navigateToPath, diff --git a/x-pack/plugins/ml/public/application/model_management/models_list.tsx b/x-pack/plugins/ml/public/application/model_management/models_list.tsx index d66ab1ab3db16..9547e7c6473bd 100644 --- a/x-pack/plugins/ml/public/application/model_management/models_list.tsx +++ b/x-pack/plugins/ml/public/application/model_management/models_list.tsx @@ -29,33 +29,29 @@ import type { EuiTableSelectionType } from '@elastic/eui/src/components/basic_ta import { i18n } from '@kbn/i18n'; import { FormattedMessage } from '@kbn/i18n-react'; import { useTimefilter } from '@kbn/ml-date-picker'; -import { isDefined } from '@kbn/ml-is-defined'; import { isPopulatedObject } from '@kbn/ml-is-populated-object'; import { useStorage } from '@kbn/ml-local-storage'; -import { - BUILT_IN_MODEL_TAG, - BUILT_IN_MODEL_TYPE, - ELASTIC_MODEL_TAG, - ELASTIC_MODEL_TYPE, - ELSER_ID_V1, - MODEL_STATE, - type ModelState, -} from '@kbn/ml-trained-models-utils'; +import { ELSER_ID_V1, MODEL_STATE } from '@kbn/ml-trained-models-utils'; import type { ListingPageUrlState } from '@kbn/ml-url-state'; import { usePageUrlState } from '@kbn/ml-url-state'; import { dynamic } from '@kbn/shared-ux-utility'; -import { cloneDeep, groupBy, isEmpty, memoize } from 'lodash'; +import { cloneDeep, isEmpty } from 'lodash'; import type { FC } from 'react'; import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import useMountedState from 'react-use/lib/useMountedState'; import { ML_PAGES } from '../../../common/constants/locator'; import { ML_ELSER_CALLOUT_DISMISSED } from '../../../common/types/storage'; import type { - ModelDownloadState, - ModelPipelines, - TrainedModelConfigResponse, - TrainedModelDeploymentStatsResponse, - TrainedModelStat, + DFAModelItem, + NLPModelItem, + TrainedModelItem, + TrainedModelUIItem, +} from '../../../common/types/trained_models'; +import { + isBaseNLPModelItem, + isBuiltInModel, + isModelDownloadItem, + isNLPModelItem, } from '../../../common/types/trained_models'; import { AddInferencePipelineFlyout } from '../components/ml_inference'; import { SavedObjectsWarning } from '../components/saved_objects_warning'; @@ -70,41 +66,11 @@ import { useTrainedModelsApiService } from '../services/ml_api_service/trained_m import { useToastNotificationService } from '../services/toast_notification_service'; import { ModelsTableToConfigMapping } from './config_mapping'; import { DeleteModelsModal } from './delete_models_modal'; -import { getModelDeploymentState, getModelStateColor } from './get_model_state'; +import { getModelStateColor } from './get_model_state'; import { useModelActions } from './model_actions'; import { TestDfaModelsFlyout } from './test_dfa_models_flyout'; import { TestModelAndPipelineCreationFlyout } from './test_models'; -type Stats = Omit; - -export type ModelItem = TrainedModelConfigResponse & { - type?: string[]; - stats?: Stats & { deployment_stats: TrainedModelDeploymentStatsResponse[] }; - pipelines?: ModelPipelines['pipelines'] | null; - origin_job_exists?: boolean; - deployment_ids: string[]; - putModelConfig?: object; - state: ModelState | undefined; - /** - * Description of the current model state - */ - stateDescription?: string; - recommended?: boolean; - supported: boolean; - /** - * Model name, e.g. elser - */ - modelName?: string; - os?: string; - arch?: string; - softwareLicense?: string; - licenseUrl?: string; - downloadState?: ModelDownloadState; - disclaimer?: string; -}; - -export type ModelItemFull = Required; - interface PageUrlState { pageKey: typeof ML_PAGES.TRAINED_MODELS_MANAGE; pageUrlState: ListingPageUrlState; @@ -185,120 +151,29 @@ export const ModelsList: FC = ({ const [isInitialized, setIsInitialized] = useState(false); const [isLoading, setIsLoading] = useState(false); - const [items, setItems] = useState([]); - const [selectedModels, setSelectedModels] = useState([]); - const [modelsToDelete, setModelsToDelete] = useState([]); - const [modelToDeploy, setModelToDeploy] = useState(); + const [items, setItems] = useState([]); + const [selectedModels, setSelectedModels] = useState([]); + const [modelsToDelete, setModelsToDelete] = useState([]); + const [modelToDeploy, setModelToDeploy] = useState(); const [itemIdToExpandedRowMap, setItemIdToExpandedRowMap] = useState>( {} ); - const [modelToTest, setModelToTest] = useState(null); - const [dfaModelToTest, setDfaModelToTest] = useState(null); + const [modelToTest, setModelToTest] = useState(null); + const [dfaModelToTest, setDfaModelToTest] = useState(null); const [isAddModelFlyoutVisible, setIsAddModelFlyoutVisible] = useState(false); - const isBuiltInModel = useCallback( - (item: ModelItem) => item.tags.includes(BUILT_IN_MODEL_TAG), - [] - ); - - const isElasticModel = useCallback( - (item: ModelItem) => item.tags.includes(ELASTIC_MODEL_TAG), - [] - ); - // List of downloaded/existing models - const existingModels = useMemo(() => { - return items.filter((i) => !i.putModelConfig); + const existingModels = useMemo>(() => { + return items.filter((i): i is NLPModelItem | DFAModelItem => !isModelDownloadItem(i)); }, [items]); - /** - * Fetch of model definitions available for download needs to happen only once - */ - const getTrainedModelDownloads = memoize(trainedModelsApiService.getTrainedModelDownloads); - /** * Fetches trained models. */ const fetchModelsData = useCallback(async () => { setIsLoading(true); try { - const response = await trainedModelsApiService.getTrainedModels(undefined, { - with_pipelines: true, - with_indices: true, - }); - - const newItems: ModelItem[] = []; - const expandedItemsToRefresh = []; - - for (const model of response) { - const tableItem: ModelItem = { - ...model, - // Extract model types - ...(typeof model.inference_config === 'object' - ? { - type: [ - model.model_type, - ...Object.keys(model.inference_config), - ...(isBuiltInModel(model as ModelItem) ? [BUILT_IN_MODEL_TYPE] : []), - ...(isElasticModel(model as ModelItem) ? [ELASTIC_MODEL_TYPE] : []), - ], - } - : {}), - } as ModelItem; - newItems.push(tableItem); - - if (itemIdToExpandedRowMap[model.model_id]) { - expandedItemsToRefresh.push(tableItem); - } - } - - // Need to fetch stats for all models to enable/disable actions - // TODO combine fetching models definitions and stats into a single function - await fetchModelsStats(newItems); - - let resultItems = newItems; - // don't add any of the built-in models (e.g. elser) if NLP is disabled - if (isNLPEnabled) { - const idMap = new Map( - resultItems.map((model) => [model.model_id, model]) - ); - /** - * Fetches model definitions available for download - */ - const forDownload = await getTrainedModelDownloads(); - - const notDownloaded: ModelItem[] = forDownload - .filter(({ model_id: modelId, hidden, recommended, supported, disclaimer }) => { - if (idMap.has(modelId)) { - const model = idMap.get(modelId)!; - if (recommended) { - model.recommended = true; - } - model.supported = supported; - model.disclaimer = disclaimer; - } - return !idMap.has(modelId) && !hidden; - }) - .map((modelDefinition) => { - return { - model_id: modelDefinition.model_id, - type: modelDefinition.type, - tags: modelDefinition.type?.includes(ELASTIC_MODEL_TAG) ? [ELASTIC_MODEL_TAG] : [], - putModelConfig: modelDefinition.config, - description: modelDefinition.description, - state: MODEL_STATE.NOT_DOWNLOADED, - recommended: !!modelDefinition.recommended, - modelName: modelDefinition.modelName, - os: modelDefinition.os, - arch: modelDefinition.arch, - softwareLicense: modelDefinition.license, - licenseUrl: modelDefinition.licenseUrl, - supported: modelDefinition.supported, - disclaimer: modelDefinition.disclaimer, - } as ModelItem; - }); - resultItems = [...resultItems, ...notDownloaded]; - } + const resultItems = await trainedModelsApiService.getTrainedModelsList(); setItems((prevItems) => { // Need to merge existing items with new items @@ -307,7 +182,7 @@ export const ModelsList: FC = ({ const prevItem = prevItems.find((i) => i.model_id === item.model_id); return { ...item, - ...(prevItem?.state === MODEL_STATE.DOWNLOADING + ...(isBaseNLPModelItem(prevItem) && prevItem?.state === MODEL_STATE.DOWNLOADING ? { state: prevItem.state, downloadState: prevItem.downloadState, @@ -322,7 +197,7 @@ export const ModelsList: FC = ({ return Object.fromEntries( Object.keys(prev).map((modelId) => { const item = resultItems.find((i) => i.model_id === modelId); - return item ? [modelId, ] : []; + return item ? [modelId, ] : []; }) ); }); @@ -365,51 +240,6 @@ export const ModelsList: FC = ({ }; }, [existingModels]); - /** - * Fetches models stats and update the original object - */ - const fetchModelsStats = useCallback(async (models: ModelItem[]) => { - try { - if (models) { - const { trained_model_stats: modelsStatsResponse } = - await trainedModelsApiService.getTrainedModelStats(); - - const groupByModelId = groupBy(modelsStatsResponse, 'model_id'); - - models.forEach((model) => { - const modelStats = groupByModelId[model.model_id]; - model.stats = { - ...(model.stats ?? {}), - ...modelStats[0], - deployment_stats: modelStats.map((d) => d.deployment_stats).filter(isDefined), - }; - - // Extract deployment ids from deployment stats - model.deployment_ids = modelStats - .map((v) => v.deployment_stats?.deployment_id) - .filter(isDefined); - - model.state = getModelDeploymentState(model); - model.stateDescription = model.stats.deployment_stats.reduce((acc, c) => { - if (acc) return acc; - return c.reason ?? ''; - }, ''); - }); - } - - return true; - } catch (error) { - displayErrorToast( - error, - i18n.translate('xpack.ml.trainedModels.modelsList.fetchModelStatsErrorMessage', { - defaultMessage: 'Error loading trained models statistics', - }) - ); - return false; - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); - const downLoadStatusFetchInProgress = useRef(false); const abortedDownload = useRef(new Set()); @@ -432,7 +262,7 @@ export const ModelsList: FC = ({ if (isMounted()) { setItems((prevItems) => { return prevItems.map((item) => { - if (!item.type?.includes('pytorch')) { + if (!isBaseNLPModelItem(item)) { return item; } const newItem = cloneDeep(item); @@ -493,7 +323,9 @@ export const ModelsList: FC = ({ if (type) { acc.add(type); } - acc.add(item.model_type); + if (item.model_type) { + acc.add(item.model_type); + } return acc; }, new Set()); return [...result] @@ -504,15 +336,15 @@ export const ModelsList: FC = ({ })); }, [existingModels]); - const modelAndDeploymentIds = useMemo( - () => [ + const modelAndDeploymentIds = useMemo(() => { + const nlpModels = existingModels.filter(isNLPModelItem); + return [ ...new Set([ - ...existingModels.flatMap((v) => v.deployment_ids), - ...existingModels.map((i) => i.model_id), + ...nlpModels.flatMap((v) => v.deployment_ids), + ...nlpModels.map((i) => i.model_id), ]), - ], - [existingModels] - ); + ]; + }, [existingModels]); const onModelDownloadRequest = useCallback( async (modelId: string) => { @@ -550,22 +382,22 @@ export const ModelsList: FC = ({ onModelDownloadRequest, }); - const toggleDetails = async (item: ModelItem) => { + const toggleDetails = async (item: TrainedModelUIItem) => { const itemIdToExpandedRowMapValues = { ...itemIdToExpandedRowMap }; if (itemIdToExpandedRowMapValues[item.model_id]) { delete itemIdToExpandedRowMapValues[item.model_id]; } else { - itemIdToExpandedRowMapValues[item.model_id] = ; + itemIdToExpandedRowMapValues[item.model_id] = ; } setItemIdToExpandedRowMap(itemIdToExpandedRowMapValues); }; - const columns: Array> = [ + const columns: Array> = [ { isExpander: true, align: 'center', - render: (item: ModelItem) => { - if (!item.stats) { + render: (item: TrainedModelUIItem) => { + if (isModelDownloadItem(item) || !item.stats) { return null; } return ( @@ -588,38 +420,38 @@ export const ModelsList: FC = ({ }, { name: modelIdColumnName, - sortable: ({ model_id: modelId }: ModelItem) => modelId, + sortable: ({ model_id: modelId }: TrainedModelUIItem) => modelId, truncateText: false, textOnly: false, 'data-test-subj': 'mlModelsTableColumnId', - render: ({ - description, - model_id: modelId, - recommended, - supported, - type, - disclaimer, - }: ModelItem) => { + render: (item: TrainedModelUIItem) => { + const { description, model_id: modelId, type } = item; + const isTechPreview = description?.includes('(Tech Preview)'); let descriptionText = description?.replace('(Tech Preview)', ''); - if (disclaimer) { - descriptionText += '. ' + disclaimer; - } + let tooltipContent = null; - const tooltipContent = - supported === false ? ( - - ) : recommended === false ? ( - - ) : null; + if (isBaseNLPModelItem(item)) { + const { disclaimer, recommended, supported } = item; + if (disclaimer) { + descriptionText += '. ' + disclaimer; + } + + tooltipContent = + supported === false ? ( + + ) : recommended === false ? ( + + ) : null; + } return ( @@ -675,7 +507,10 @@ export const ModelsList: FC = ({ }), truncateText: false, width: '150px', - render: ({ state, downloadState }: ModelItem) => { + render: (item: TrainedModelUIItem) => { + if (!isBaseNLPModelItem(item)) return null; + + const { state, downloadState } = item; const config = getModelStateColor(state); if (!config) return null; @@ -776,7 +611,7 @@ export const ModelsList: FC = ({ const isSelectionAllowed = canDeleteTrainedModels; - const selection: EuiTableSelectionType | undefined = isSelectionAllowed + const selection: EuiTableSelectionType | undefined = isSelectionAllowed ? { selectableMessage: (selectable, item) => { if (selectable) { @@ -784,31 +619,28 @@ export const ModelsList: FC = ({ defaultMessage: 'Select a model', }); } - if (isPopulatedObject(item.pipelines)) { + // TODO support multiple model downloads with selection + if (!isModelDownloadItem(item) && isPopulatedObject(item.pipelines)) { return i18n.translate('xpack.ml.trainedModels.modelsList.disableSelectableMessage', { defaultMessage: 'Model has associated pipelines', }); } - if (isBuiltInModel(item)) { return i18n.translate('xpack.ml.trainedModels.modelsList.builtInModelMessage', { defaultMessage: 'Built-in model', }); } - return ''; }, selectable: (item) => - !isPopulatedObject(item.pipelines) && - !isBuiltInModel(item) && - !(isElasticModel(item) && !item.state), + !isModelDownloadItem(item) && !isPopulatedObject(item.pipelines) && !isBuiltInModel(item), onSelectionChange: (selectedItems) => { setSelectedModels(selectedItems); }, } : undefined; - const { onTableChange, pagination, sorting } = useTableSettings( + const { onTableChange, pagination, sorting } = useTableSettings( items.length, pageState, updatePageState, @@ -847,7 +679,7 @@ export const ModelsList: FC = ({ return items; } else { // by default show only deployed models or recommended for download - return items.filter((item) => item.create_time || item.recommended); + return items.filter((item) => !isModelDownloadItem(item) || item.recommended); } }, [items, pageState.showAll]); @@ -896,7 +728,7 @@ export const ModelsList: FC = ({
- + tableLayout={'auto'} responsiveBreakpoint={'xl'} allowNeutralSort={false} @@ -952,7 +784,7 @@ export const ModelsList: FC = ({ { modelsToDelete.forEach((model) => { - if (model.state === MODEL_STATE.DOWNLOADING) { + if (isBaseNLPModelItem(model) && model.state === MODEL_STATE.DOWNLOADING) { abortedDownload.current.add(model.model_id); } }); @@ -996,7 +828,7 @@ export const ModelsList: FC = ({ ) : null} {isAddModelFlyoutVisible ? ( i.state === MODEL_STATE.NOT_DOWNLOADED)} + modelDownloads={items.filter(isModelDownloadItem)} onClose={setIsAddModelFlyoutVisible.bind(null, false)} onSubmit={(modelId) => { onModelDownloadRequest(modelId); diff --git a/x-pack/plugins/ml/public/application/model_management/pipelines/pipelines.tsx b/x-pack/plugins/ml/public/application/model_management/pipelines/pipelines.tsx index d144bf2aaf558..384a0736ed6f2 100644 --- a/x-pack/plugins/ml/public/application/model_management/pipelines/pipelines.tsx +++ b/x-pack/plugins/ml/public/application/model_management/pipelines/pipelines.tsx @@ -17,14 +17,14 @@ import { EuiAccordion, } from '@elastic/eui'; import { FormattedMessage } from '@kbn/i18n-react'; +import type { TrainedModelItem } from '../../../../common/types/trained_models'; import { useMlKibana } from '../../contexts/kibana'; -import type { ModelItem } from '../models_list'; import { ProcessorsStats } from './expanded_row'; -export type IngestStatsResponse = Exclude['ingest']; +export type IngestStatsResponse = Exclude['ingest']; interface ModelPipelinesProps { - pipelines: ModelItem['pipelines']; + pipelines: TrainedModelItem['pipelines']; ingestStats: IngestStatsResponse; } diff --git a/x-pack/plugins/ml/public/application/model_management/test_dfa_models_flyout.tsx b/x-pack/plugins/ml/public/application/model_management/test_dfa_models_flyout.tsx index 86ddd16e620ad..4593413154bd5 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_dfa_models_flyout.tsx +++ b/x-pack/plugins/ml/public/application/model_management/test_dfa_models_flyout.tsx @@ -9,14 +9,13 @@ import type { FC } from 'react'; import React, { useMemo } from 'react'; import { EuiFlyout, EuiFlyoutBody, EuiFlyoutHeader, EuiSpacer, EuiTitle } from '@elastic/eui'; import { FormattedMessage } from '@kbn/i18n-react'; - +import type { DFAModelItem } from '../../../common/types/trained_models'; import { TestPipeline } from '../components/ml_inference/components/test_pipeline'; import { getInitialState } from '../components/ml_inference/state'; -import type { ModelItem } from './models_list'; import { TEST_PIPELINE_MODE } from '../components/ml_inference/types'; interface Props { - model: ModelItem; + model: DFAModelItem; onClose: () => void; } diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/index.ts b/x-pack/plugins/ml/public/application/model_management/test_models/index.ts index 4b238f477092e..209704581f489 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/index.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/index.ts @@ -6,4 +6,4 @@ */ export { TestModelAndPipelineCreationFlyout } from './test_model_and_pipeline_creation_flyout'; -export { isTestable, isDfaTrainedModel } from './utils'; +export { isTestable } from './utils'; diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/test_flyout.tsx b/x-pack/plugins/ml/public/application/model_management/test_models/test_flyout.tsx index b8bc2d706b8c0..3b8d3cc7bdea9 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/test_flyout.tsx +++ b/x-pack/plugins/ml/public/application/model_management/test_models/test_flyout.tsx @@ -7,15 +7,13 @@ import type { FC } from 'react'; import React from 'react'; - import { FormattedMessage } from '@kbn/i18n-react'; import { EuiFlyout, EuiFlyoutBody, EuiFlyoutHeader, EuiSpacer, EuiTitle } from '@elastic/eui'; - -import { type ModelItem } from '../models_list'; +import type { TrainedModelItem } from '../../../../common/types/trained_models'; import { TestTrainedModelContent } from './test_trained_model_content'; interface Props { - model: ModelItem; + model: TrainedModelItem; onClose: () => void; } export const TestTrainedModelFlyout: FC = ({ model, onClose }) => ( diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/test_model_and_pipeline_creation_flyout.tsx b/x-pack/plugins/ml/public/application/model_management/test_models/test_model_and_pipeline_creation_flyout.tsx index 240c2545f3d8e..f78f12cf88211 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/test_model_and_pipeline_creation_flyout.tsx +++ b/x-pack/plugins/ml/public/application/model_management/test_models/test_model_and_pipeline_creation_flyout.tsx @@ -7,17 +7,16 @@ import type { FC } from 'react'; import React, { useState } from 'react'; - +import type { TrainedModelItem } from '../../../../common/types/trained_models'; import { type TestTrainedModelsContextType, TestTrainedModelsContext, } from './test_trained_models_context'; -import type { ModelItem } from '../models_list'; import { TestTrainedModelFlyout } from './test_flyout'; import { CreatePipelineForModelFlyout } from '../create_pipeline_for_model/create_pipeline_for_model_flyout'; interface Props { - model: ModelItem; + model: TrainedModelItem; onClose: (refreshList?: boolean) => void; } export const TestModelAndPipelineCreationFlyout: FC = ({ model, onClose }) => { diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/test_trained_model_content.tsx b/x-pack/plugins/ml/public/application/model_management/test_models/test_trained_model_content.tsx index da4c496700687..3c829c8f7cd49 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/test_trained_model_content.tsx +++ b/x-pack/plugins/ml/public/application/model_management/test_models/test_trained_model_content.tsx @@ -12,14 +12,15 @@ import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils'; import { FormattedMessage } from '@kbn/i18n-react'; import { EuiFormRow, EuiSelect, EuiSpacer, EuiTab, EuiTabs, useEuiPaddingSize } from '@elastic/eui'; import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; +import type { TrainedModelItem } from '../../../../common/types/trained_models'; +import { isNLPModelItem } from '../../../../common/types/trained_models'; import { SelectedModel } from './selected_model'; -import { type ModelItem } from '../models_list'; import { INPUT_TYPE } from './models/inference_base'; import { useTestTrainedModelsContext } from './test_trained_models_context'; import { type InferecePipelineCreationState } from '../create_pipeline_for_model/state'; interface ContentProps { - model: ModelItem; + model: TrainedModelItem; handlePipelineConfigUpdate?: (configUpdate: Partial) => void; externalPipelineConfig?: estypes.IngestPipeline; } @@ -29,7 +30,9 @@ export const TestTrainedModelContent: FC = ({ handlePipelineConfigUpdate, externalPipelineConfig, }) => { - const [deploymentId, setDeploymentId] = useState(model.deployment_ids[0]); + const [deploymentId, setDeploymentId] = useState( + isNLPModelItem(model) ? model.deployment_ids[0] : model.model_id + ); const mediumPadding = useEuiPaddingSize('m'); const [inputType, setInputType] = useState(INPUT_TYPE.TEXT); @@ -46,8 +49,7 @@ export const TestTrainedModelContent: FC = ({ }, [model, createPipelineFlyoutOpen]); return ( <> - {' '} - {model.deployment_ids.length > 1 ? ( + {isNLPModelItem(model) && model.deployment_ids.length > 1 ? ( <> ({ + path: `${ML_INTERNAL_BASE_PATH}/trained_models_list`, + method: 'GET', + version: '1', + }); + }, + /** * Fetches usage information for trained inference models. * @param modelId - Model ID, collection of Model IDs or Model ID pattern. diff --git a/x-pack/plugins/ml/server/models/data_frame_analytics/analytics_manager.ts b/x-pack/plugins/ml/server/models/data_frame_analytics/analytics_manager.ts index e720f12fa4dd5..04a14c7f235ff 100644 --- a/x-pack/plugins/ml/server/models/data_frame_analytics/analytics_manager.ts +++ b/x-pack/plugins/ml/server/models/data_frame_analytics/analytics_manager.ts @@ -51,7 +51,12 @@ export class AnalyticsManager { private readonly _enabledFeatures: MlFeatures, cloud: CloudSetup ) { - this._modelsProvider = modelsProvider(this._client, this._mlClient, cloud); + this._modelsProvider = modelsProvider( + this._client, + this._mlClient, + cloud, + this._enabledFeatures + ); } private async initData() { diff --git a/x-pack/plugins/ml/public/application/model_management/get_model_state.test.tsx b/x-pack/plugins/ml/server/models/model_management/get_model_state.test.tsx similarity index 94% rename from x-pack/plugins/ml/public/application/model_management/get_model_state.test.tsx rename to x-pack/plugins/ml/server/models/model_management/get_model_state.test.tsx index 1431b2da0439c..16c30395d1b15 100644 --- a/x-pack/plugins/ml/public/application/model_management/get_model_state.test.tsx +++ b/x-pack/plugins/ml/server/models/model_management/get_model_state.test.tsx @@ -5,9 +5,9 @@ * 2.0. */ -import { getModelDeploymentState } from './get_model_state'; import { MODEL_STATE } from '@kbn/ml-trained-models-utils'; -import type { ModelItem } from './models_list'; +import type { NLPModelItem } from '../../../common/types/trained_models'; +import { getModelDeploymentState } from './get_model_state'; describe('getModelDeploymentState', () => { it('returns STARTED if any deployment is in STARTED state', () => { @@ -37,7 +37,7 @@ describe('getModelDeploymentState', () => { }, ], }, - } as unknown as ModelItem; + } as unknown as NLPModelItem; const result = getModelDeploymentState(model); expect(result).toEqual(MODEL_STATE.STARTED); }); @@ -69,7 +69,7 @@ describe('getModelDeploymentState', () => { }, ], }, - } as unknown as ModelItem; + } as unknown as NLPModelItem; const result = getModelDeploymentState(model); expect(result).toEqual(MODEL_STATE.STARTING); }); @@ -96,7 +96,7 @@ describe('getModelDeploymentState', () => { }, ], }, - } as unknown as ModelItem; + } as unknown as NLPModelItem; const result = getModelDeploymentState(model); expect(result).toEqual(MODEL_STATE.STOPPING); }); @@ -112,7 +112,7 @@ describe('getModelDeploymentState', () => { deployment_stats: [], }, - } as unknown as ModelItem; + } as unknown as NLPModelItem; const result = getModelDeploymentState(model); expect(result).toEqual(undefined); }); diff --git a/x-pack/plugins/ml/server/models/model_management/get_model_state.ts b/x-pack/plugins/ml/server/models/model_management/get_model_state.ts new file mode 100644 index 0000000000000..2ee2bf8cb4532 --- /dev/null +++ b/x-pack/plugins/ml/server/models/model_management/get_model_state.ts @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { DEPLOYMENT_STATE, MODEL_STATE, type ModelState } from '@kbn/ml-trained-models-utils'; +import type { NLPModelItem } from '../../../common/types/trained_models'; + +/** + * Resolves result model state based on the state of each deployment. + * + * If at least one deployment is in the STARTED state, the model state is STARTED. + * Then if none of the deployments are in the STARTED state, but at least one is in the STARTING state, the model state is STARTING. + * If all deployments are in the STOPPING state, the model state is STOPPING. + */ +export const getModelDeploymentState = (model: NLPModelItem): ModelState | undefined => { + if (!model.stats?.deployment_stats?.length) return; + + if (model.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED)) { + return MODEL_STATE.STARTED; + } + if (model.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTING)) { + return MODEL_STATE.STARTING; + } + if (model.stats?.deployment_stats?.every((v) => v.state === DEPLOYMENT_STATE.STOPPING)) { + return MODEL_STATE.STOPPING; + } +}; diff --git a/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts b/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts index 0b9b93720234d..0a73dfa3053db 100644 --- a/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts +++ b/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts @@ -6,45 +6,54 @@ */ import { modelsProvider } from './models_provider'; -import { type IScopedClusterClient } from '@kbn/core/server'; import { cloudMock } from '@kbn/cloud-plugin/server/mocks'; import type { MlClient } from '../../lib/ml_client'; import downloadTasksResponse from './__mocks__/mock_download_tasks.json'; +import type { MlFeatures } from '../../../common/constants/app'; +import { mlLog } from '../../lib/log'; +import { errors } from '@elastic/elasticsearch'; +import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks'; +import type { ExistingModelBase } from '../../../common/types/trained_models'; +import type { InferenceInferenceEndpointInfo } from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; + +jest.mock('../../lib/log'); describe('modelsProvider', () => { - const mockClient = { - asInternalUser: { - transport: { - request: jest.fn().mockResolvedValue({ - _nodes: { - total: 1, - successful: 1, - failed: 0, - }, - cluster_name: 'default', - nodes: { - yYmqBqjpQG2rXsmMSPb9pQ: { - name: 'node-0', - roles: ['ml'], - attributes: {}, - os: { - name: 'Linux', - arch: 'amd64', - }, - }, - }, - }), - }, - tasks: { - list: jest.fn().mockResolvedValue({ tasks: [] }), + const mockClient = elasticsearchClientMock.createScopedClusterClient(); + + mockClient.asInternalUser.transport.request.mockResolvedValue({ + _nodes: { + total: 1, + successful: 1, + failed: 0, + }, + cluster_name: 'default', + nodes: { + yYmqBqjpQG2rXsmMSPb9pQ: { + name: 'node-0', + roles: ['ml'], + attributes: {}, + os: { + name: 'Linux', + arch: 'amd64', + }, }, }, - } as unknown as jest.Mocked; + }); + + mockClient.asInternalUser.tasks.list.mockResolvedValue({ tasks: [] }); const mockMlClient = {} as unknown as jest.Mocked; const mockCloud = cloudMock.createSetup(); - const modelService = modelsProvider(mockClient, mockMlClient, mockCloud); + + const enabledMlFeatures: MlFeatures = { + ad: false, + dfa: true, + nlp: true, + }; + + const modelService = modelsProvider(mockClient, mockMlClient, mockCloud, enabledMlFeatures); afterEach(() => { jest.clearAllMocks(); @@ -122,7 +131,7 @@ describe('modelsProvider', () => { test('provides a list of models with default model as recommended', async () => { mockCloud.cloudId = undefined; - (mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({ + mockClient.asInternalUser.transport.request.mockResolvedValueOnce({ _nodes: { total: 1, successful: 1, @@ -218,7 +227,7 @@ describe('modelsProvider', () => { test('provides a default version if there is no recommended', async () => { mockCloud.cloudId = undefined; - (mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({ + mockClient.asInternalUser.transport.request.mockResolvedValueOnce({ _nodes: { total: 1, successful: 1, @@ -261,7 +270,7 @@ describe('modelsProvider', () => { test('provides a default version if there is no recommended', async () => { mockCloud.cloudId = undefined; - (mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({ + mockClient.asInternalUser.transport.request.mockResolvedValueOnce({ _nodes: { total: 1, successful: 1, @@ -292,9 +301,7 @@ describe('modelsProvider', () => { expect(result).toEqual({}); }); test('provides download status for all models', async () => { - (mockClient.asInternalUser.tasks.list as jest.Mock).mockResolvedValueOnce( - downloadTasksResponse - ); + mockClient.asInternalUser.tasks.list.mockResolvedValueOnce(downloadTasksResponse); const result = await modelService.getModelsDownloadStatus(); expect(result).toEqual({ '.elser_model_2': { downloaded_parts: 0, total_parts: 418 }, @@ -302,4 +309,124 @@ describe('modelsProvider', () => { }); }); }); + + describe('#assignInferenceEndpoints', () => { + let trainedModels: ExistingModelBase[]; + + const inferenceServices = [ + { + service: 'elser', + model_id: 'elser_test', + service_settings: { model_id: '.elser_model_2' }, + }, + { service: 'open_api_01', service_settings: {} }, + ] as InferenceInferenceEndpointInfo[]; + + beforeEach(() => { + trainedModels = [ + { model_id: '.elser_model_2' }, + { model_id: 'model2' }, + ] as ExistingModelBase[]; + + mockClient.asInternalUser.inference.get.mockResolvedValue({ + endpoints: inferenceServices, + }); + + jest.clearAllMocks(); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('when the user has required privileges', () => { + beforeEach(() => { + mockClient.asCurrentUser.inference.get.mockResolvedValue({ + endpoints: inferenceServices, + }); + }); + + test('should populate inference services for trained models', async () => { + // act + await modelService.assignInferenceEndpoints(trainedModels, false); + + // assert + expect(mockClient.asCurrentUser.inference.get).toHaveBeenCalledWith({ + inference_id: '_all', + }); + + expect(mockClient.asInternalUser.inference.get).not.toHaveBeenCalled(); + + expect(trainedModels[0].inference_apis).toEqual([ + { + model_id: 'elser_test', + service: 'elser', + service_settings: { model_id: '.elser_model_2' }, + }, + ]); + expect(trainedModels[0].hasInferenceServices).toBe(true); + + expect(trainedModels[1].inference_apis).toEqual(undefined); + expect(trainedModels[1].hasInferenceServices).toBe(false); + + expect(mlLog.error).not.toHaveBeenCalled(); + }); + }); + + describe('when the user does not have required privileges', () => { + beforeEach(() => { + mockClient.asCurrentUser.inference.get.mockRejectedValue( + new errors.ResponseError( + elasticsearchClientMock.createApiResponse({ + statusCode: 403, + body: { message: 'not allowed' }, + }) + ) + ); + }); + + test('should retry with internal user if an error occurs', async () => { + await modelService.assignInferenceEndpoints(trainedModels, false); + + // assert + expect(mockClient.asCurrentUser.inference.get).toHaveBeenCalledWith({ + inference_id: '_all', + }); + + expect(mockClient.asInternalUser.inference.get).toHaveBeenCalledWith({ + inference_id: '_all', + }); + + expect(trainedModels[0].inference_apis).toEqual(undefined); + expect(trainedModels[0].hasInferenceServices).toBe(true); + + expect(trainedModels[1].inference_apis).toEqual(undefined); + expect(trainedModels[1].hasInferenceServices).toBe(false); + + expect(mlLog.error).not.toHaveBeenCalled(); + }); + }); + + test('should not retry on any other error than 403', async () => { + const notFoundError = new errors.ResponseError( + elasticsearchClientMock.createApiResponse({ + statusCode: 404, + body: { message: 'not found' }, + }) + ); + + mockClient.asCurrentUser.inference.get.mockRejectedValue(notFoundError); + + await modelService.assignInferenceEndpoints(trainedModels, false); + + // assert + expect(mockClient.asCurrentUser.inference.get).toHaveBeenCalledWith({ + inference_id: '_all', + }); + + expect(mockClient.asInternalUser.inference.get).not.toHaveBeenCalled(); + + expect(mlLog.error).toHaveBeenCalledWith(notFoundError); + }); + }); }); diff --git a/x-pack/plugins/ml/server/models/model_management/models_provider.ts b/x-pack/plugins/ml/server/models/model_management/models_provider.ts index 1c175cee26d14..0f302363f66ea 100644 --- a/x-pack/plugins/ml/server/models/model_management/models_provider.ts +++ b/x-pack/plugins/ml/server/models/model_management/models_provider.ts @@ -8,10 +8,11 @@ import Boom from '@hapi/boom'; import type { IScopedClusterClient } from '@kbn/core/server'; import { JOB_MAP_NODE_TYPES, type MapElements } from '@kbn/ml-data-frame-analytics-utils'; -import { flatten } from 'lodash'; +import { flatten, groupBy, isEmpty } from 'lodash'; import type { InferenceInferenceEndpoint, InferenceTaskType, + MlGetTrainedModelsRequest, TasksTaskInfo, TransformGetTransformTransformSummary, } from '@elastic/elasticsearch/lib/api/types'; @@ -24,22 +25,50 @@ import type { } from '@elastic/elasticsearch/lib/api/types'; import { ELASTIC_MODEL_DEFINITIONS, + ELASTIC_MODEL_TAG, + MODEL_STATE, type GetModelDownloadConfigOptions, type ModelDefinitionResponse, + ELASTIC_MODEL_TYPE, + BUILT_IN_MODEL_TYPE, } from '@kbn/ml-trained-models-utils'; import type { CloudSetup } from '@kbn/cloud-plugin/server'; import type { ElasticCuratedModelName } from '@kbn/ml-trained-models-utils'; -import type { ModelDownloadState, PipelineDefinition } from '../../../common/types/trained_models'; +import { isDefined } from '@kbn/ml-is-defined'; +import { DEFAULT_TRAINED_MODELS_PAGE_SIZE } from '../../../common/constants/trained_models'; +import type { MlFeatures } from '../../../common/constants/app'; +import type { + DFAModelItem, + ExistingModelBase, + ModelDownloadItem, + NLPModelItem, + TrainedModelItem, + TrainedModelUIItem, + TrainedModelWithPipelines, +} from '../../../common/types/trained_models'; +import { isBuiltInModel, isExistingModel } from '../../../common/types/trained_models'; +import { + isDFAModelItem, + isElasticModel, + isNLPModelItem, + type ModelDownloadState, + type PipelineDefinition, + type TrainedModelConfigResponse, +} from '../../../common/types/trained_models'; import type { MlClient } from '../../lib/ml_client'; import type { MLSavedObjectService } from '../../saved_objects'; +import { filterForEnabledFeatureModels } from '../../routes/trained_models'; +import { mlLog } from '../../lib/log'; +import { getModelDeploymentState } from './get_model_state'; export type ModelService = ReturnType; export const modelsProvider = ( client: IScopedClusterClient, mlClient: MlClient, - cloud: CloudSetup -) => new ModelsProvider(client, mlClient, cloud); + cloud: CloudSetup, + enabledFeatures: MlFeatures +) => new ModelsProvider(client, mlClient, cloud, enabledFeatures); interface ModelMapResult { ingestPipelines: Map | null>; @@ -66,7 +95,8 @@ export class ModelsProvider { constructor( private _client: IScopedClusterClient, private _mlClient: MlClient, - private _cloud: CloudSetup + private _cloud: CloudSetup, + private _enabledFeatures: MlFeatures ) {} private async initTransformData() { @@ -110,6 +140,291 @@ export class ModelsProvider { return `${elementOriginalId}-${nodeType}`; } + /** + * Assigns inference endpoints to trained models + * @param trainedModels + * @param asInternal + */ + async assignInferenceEndpoints(trainedModels: ExistingModelBase[], asInternal: boolean = false) { + const esClient = asInternal ? this._client.asInternalUser : this._client.asCurrentUser; + + try { + // Check if model is used by an inference service + const { endpoints } = await esClient.inference.get({ + inference_id: '_all', + }); + + const inferenceAPIMap = groupBy( + endpoints, + (endpoint) => endpoint.service === 'elser' && endpoint.service_settings.model_id + ); + + for (const model of trainedModels) { + const inferenceApis = inferenceAPIMap[model.model_id]; + model.hasInferenceServices = !!inferenceApis; + if (model.hasInferenceServices && !asInternal) { + model.inference_apis = inferenceApis; + } + } + } catch (e) { + if (!asInternal && e.statusCode === 403) { + // retry with internal user to get an indicator if models has associated inference services, without mentioning the names + await this.assignInferenceEndpoints(trainedModels, true); + } else { + mlLog.error(e); + } + } + } + + /** + * Assigns trained model stats to trained models + * @param trainedModels + */ + async assignModelStats(trainedModels: ExistingModelBase[]): Promise { + const { trained_model_stats: modelsStatsResponse } = await this._mlClient.getTrainedModelsStats( + { + size: DEFAULT_TRAINED_MODELS_PAGE_SIZE, + } + ); + + const groupByModelId = groupBy(modelsStatsResponse, 'model_id'); + + return trainedModels.map((model) => { + const modelStats = groupByModelId[model.model_id]; + + const completeModelItem: TrainedModelItem = { + ...model, + // @ts-ignore FIXME: fix modelStats type + stats: { + ...modelStats[0], + ...(isNLPModelItem(model) + ? { deployment_stats: modelStats.map((d) => d.deployment_stats).filter(isDefined) } + : {}), + }, + }; + + if (isNLPModelItem(completeModelItem)) { + // Extract deployment ids from deployment stats + completeModelItem.deployment_ids = modelStats + .map((v) => v.deployment_stats?.deployment_id) + .filter(isDefined); + + completeModelItem.state = getModelDeploymentState(completeModelItem); + + completeModelItem.stateDescription = completeModelItem.stats.deployment_stats.reduce( + (acc, c) => { + if (acc) return acc; + return c.reason ?? ''; + }, + '' + ); + } + + return completeModelItem; + }); + } + + /** + * Merges the list of models with the list of models available for download. + */ + async includeModelDownloads(resultItems: TrainedModelUIItem[]): Promise { + const idMap = new Map( + resultItems.map((model) => [model.model_id, model]) + ); + /** + * Fetches model definitions available for download + */ + const forDownload = await this.getModelDownloads(); + + const notDownloaded: TrainedModelUIItem[] = forDownload + .filter(({ model_id: modelId, hidden, recommended, supported, disclaimer }) => { + if (idMap.has(modelId)) { + const model = idMap.get(modelId)! as NLPModelItem; + if (recommended) { + model.recommended = true; + } + model.supported = supported; + model.disclaimer = disclaimer; + } + return !idMap.has(modelId) && !hidden; + }) + .map((modelDefinition) => { + return { + model_id: modelDefinition.model_id, + type: modelDefinition.type, + tags: modelDefinition.type?.includes(ELASTIC_MODEL_TAG) ? [ELASTIC_MODEL_TAG] : [], + putModelConfig: modelDefinition.config, + description: modelDefinition.description, + state: MODEL_STATE.NOT_DOWNLOADED, + recommended: !!modelDefinition.recommended, + modelName: modelDefinition.modelName, + os: modelDefinition.os, + arch: modelDefinition.arch, + softwareLicense: modelDefinition.license, + licenseUrl: modelDefinition.licenseUrl, + supported: modelDefinition.supported, + disclaimer: modelDefinition.disclaimer, + } as ModelDownloadItem; + }); + + // show model downloads first + return [...notDownloaded, ...resultItems]; + } + + /** + * Assigns pipelines to trained models + */ + async assignPipelines(trainedModels: TrainedModelItem[]): Promise { + // For each model create a dict with model aliases and deployment ids for faster lookup + const modelToAliasesAndDeployments: Record> = Object.fromEntries( + trainedModels.map((model) => [ + model.model_id, + new Set([ + model.model_id, + ...(model.metadata?.model_aliases ?? []), + ...(isNLPModelItem(model) ? model.deployment_ids : []), + ]), + ]) + ); + + // Set of unique model ids, aliases, and deployment ids. + const modelIdsAndAliases: string[] = Object.values(modelToAliasesAndDeployments).flatMap((s) => + Array.from(s) + ); + + try { + // Get all pipelines first in one call: + const modelPipelinesMap = await this.getModelsPipelines(modelIdsAndAliases); + + trainedModels.forEach((model) => { + const modelAliasesAndDeployments = modelToAliasesAndDeployments[model.model_id]; + // Check model pipelines map for any pipelines associated with the model + for (const [modelEntityId, pipelines] of modelPipelinesMap) { + if (modelAliasesAndDeployments.has(modelEntityId)) { + // Merge pipeline definitions into the model + model.pipelines = model.pipelines + ? Object.assign(model.pipelines, pipelines) + : pipelines; + } + } + }); + } catch (e) { + // the user might not have required permissions to fetch pipelines + // log the error to the debug log as this might be a common situation and + // we don't need to fill kibana's log with these messages. + mlLog.debug(e); + } + } + + /** + * Assigns indices to trained models + */ + async assignModelIndices(trainedModels: TrainedModelItem[]): Promise { + // Get a list of all uniquer pipeline ids to retrieve mapping with indices + const pipelineIds = new Set( + trainedModels + .filter((model): model is TrainedModelWithPipelines => isDefined(model.pipelines)) + .flatMap((model) => Object.keys(model.pipelines)) + ); + + const pipelineToIndicesMap = await this.getPipelineToIndicesMap(pipelineIds); + + trainedModels.forEach((model) => { + if (!isEmpty(model.pipelines)) { + model.indices = Object.entries(pipelineToIndicesMap) + .filter(([pipelineId]) => !isEmpty(model.pipelines?.[pipelineId])) + .flatMap(([_, indices]) => indices); + } + }); + } + + /** + * Assign a check for each DFA model if origin job exists + */ + async assignDFAJobCheck(trainedModels: DFAModelItem[]): Promise { + try { + const dfaJobIds = trainedModels + .map((model) => { + const id = model.metadata?.analytics_config?.id; + if (id) { + return `${id}*`; + } + }) + .filter(isDefined); + + if (dfaJobIds.length > 0) { + const { data_frame_analytics: jobs } = await this._mlClient.getDataFrameAnalytics({ + id: dfaJobIds.join(','), + allow_no_match: true, + }); + + trainedModels.forEach((model) => { + const dfaId = model?.metadata?.analytics_config?.id; + if (dfaId !== undefined) { + // if this is a dfa model, set origin_job_exists + model.origin_job_exists = jobs.find((job) => job.id === dfaId) !== undefined; + } + }); + } + } catch (e) { + return; + } + } + + /** + * Returns a complete list of entities for the Trained Models UI + */ + async getTrainedModelList(): Promise { + const resp = await this._mlClient.getTrainedModels({ + size: 1000, + } as MlGetTrainedModelsRequest); + + let resultItems: TrainedModelUIItem[] = []; + + // Filter models based on enabled features + const filteredModels = filterForEnabledFeatureModels( + resp.trained_model_configs, + this._enabledFeatures + ) as TrainedModelConfigResponse[]; + + const formattedModels = filteredModels.map((model) => { + return { + ...model, + // Extract model types + type: [ + model.model_type, + ...(isBuiltInModel(model) ? [BUILT_IN_MODEL_TYPE] : []), + ...(isElasticModel(model) ? [ELASTIC_MODEL_TYPE] : []), + ...(typeof model.inference_config === 'object' + ? Object.keys(model.inference_config) + : []), + ].filter(isDefined), + }; + }); + + // Update inference endpoints info + await this.assignInferenceEndpoints(formattedModels); + + // Assign model stats + resultItems = await this.assignModelStats(formattedModels); + + if (this._enabledFeatures.nlp) { + resultItems = await this.includeModelDownloads(resultItems); + } + + const existingModels = resultItems.filter(isExistingModel); + + // Assign pipelines to existing models + await this.assignPipelines(existingModels); + + // Assign indices + await this.assignModelIndices(existingModels); + + await this.assignDFAJobCheck(resultItems.filter(isDFAModelItem)); + + return resultItems; + } + /** * Simulates the effect of the pipeline on given document. * @@ -170,12 +485,13 @@ export class ModelsProvider { } /** - * Retrieves the map of model ids and aliases with associated pipelines. + * Retrieves the map of model ids and aliases with associated pipelines, + * where key is a model, alias or deployment id, and value is a map of pipeline ids and pipeline definitions. * @param modelIds - Array of models ids and model aliases. */ async getModelsPipelines(modelIds: string[]) { - const modelIdsMap = new Map | null>( - modelIds.map((id: string) => [id, null]) + const modelIdsMap = new Map>( + modelIds.map((id: string) => [id, {}]) ); try { @@ -208,6 +524,53 @@ export class ModelsProvider { return modelIdsMap; } + /** + * Match pipelines to indices based on the default_pipeline setting in the index settings. + */ + async getPipelineToIndicesMap(pipelineIds: Set): Promise> { + const pipelineIdsToDestinationIndices: Record = {}; + + let indicesPermissions; + let indicesSettings; + + try { + indicesSettings = await this._client.asInternalUser.indices.getSettings(); + const hasPrivilegesResponse = await this._client.asCurrentUser.security.hasPrivileges({ + index: [ + { + names: Object.keys(indicesSettings), + privileges: ['read'], + }, + ], + }); + indicesPermissions = hasPrivilegesResponse.index; + } catch (e) { + // Possible that the user doesn't have permissions to view + if (e.meta?.statusCode !== 403) { + mlLog.error(e); + } + return pipelineIdsToDestinationIndices; + } + + // From list of model pipelines, find all indices that have pipeline set as index.default_pipeline + for (const [indexName, { settings }] of Object.entries(indicesSettings)) { + const defaultPipeline = settings?.index?.default_pipeline; + if ( + defaultPipeline && + pipelineIds.has(defaultPipeline) && + indicesPermissions[indexName]?.read === true + ) { + if (Array.isArray(pipelineIdsToDestinationIndices[defaultPipeline])) { + pipelineIdsToDestinationIndices[defaultPipeline].push(indexName); + } else { + pipelineIdsToDestinationIndices[defaultPipeline] = [indexName]; + } + } + } + + return pipelineIdsToDestinationIndices; + } + /** * Retrieves the network map and metadata of model ids, pipelines, and indices that are tied to the model ids. * @param modelIds - Array of models ids and model aliases. @@ -229,7 +592,6 @@ export class ModelsProvider { }; let pipelinesResponse; - let indicesSettings; try { pipelinesResponse = await this.getModelsPipelines([modelId]); @@ -264,44 +626,8 @@ export class ModelsProvider { } if (withIndices === true) { - const pipelineIdsToDestinationIndices: Record = {}; - - let indicesPermissions; - try { - indicesSettings = await this._client.asInternalUser.indices.getSettings(); - const hasPrivilegesResponse = await this._client.asCurrentUser.security.hasPrivileges({ - index: [ - { - names: Object.keys(indicesSettings), - privileges: ['read'], - }, - ], - }); - indicesPermissions = hasPrivilegesResponse.index; - } catch (e) { - // Possible that the user doesn't have permissions to view - // If so, gracefully exit - if (e.meta?.statusCode !== 403) { - // eslint-disable-next-line no-console - console.error(e); - } - return result; - } - - // 2. From list of model pipelines, find all indices that have pipeline set as index.default_pipeline - for (const [indexName, { settings }] of Object.entries(indicesSettings)) { - if ( - settings?.index?.default_pipeline && - pipelineIds.has(settings.index.default_pipeline) && - indicesPermissions[indexName]?.read === true - ) { - if (Array.isArray(pipelineIdsToDestinationIndices[settings.index.default_pipeline])) { - pipelineIdsToDestinationIndices[settings.index.default_pipeline].push(indexName); - } else { - pipelineIdsToDestinationIndices[settings.index.default_pipeline] = [indexName]; - } - } - } + const pipelineIdsToDestinationIndices: Record = + await this.getPipelineToIndicesMap(pipelineIds); // 3. Grab index information for all the indices found, and add their info to the map for (const [pipelineId, indexIds] of Object.entries(pipelineIdsToDestinationIndices)) { diff --git a/x-pack/plugins/ml/server/plugin.ts b/x-pack/plugins/ml/server/plugin.ts index 01f3ec7e5d131..f21d61509d1a9 100644 --- a/x-pack/plugins/ml/server/plugin.ts +++ b/x-pack/plugins/ml/server/plugin.ts @@ -222,7 +222,8 @@ export class MlServerPlugin getDataViews, () => this.auditService, () => this.isMlReady, - this.compatibleModuleType + this.compatibleModuleType, + this.enabledFeatures ); const routeInit: RouteInitialization = { diff --git a/x-pack/plugins/ml/server/routes/inference_models.ts b/x-pack/plugins/ml/server/routes/inference_models.ts index 866398ac56ce9..8318fadee8ebd 100644 --- a/x-pack/plugins/ml/server/routes/inference_models.ts +++ b/x-pack/plugins/ml/server/routes/inference_models.ts @@ -19,7 +19,7 @@ import { ML_INTERNAL_BASE_PATH } from '../../common/constants/app'; import { syncSavedObjectsFactory } from '../saved_objects'; export function inferenceModelRoutes( - { router, routeGuard }: RouteInitialization, + { router, routeGuard, getEnabledFeatures }: RouteInitialization, cloud: CloudSetup ) { router.versioned @@ -48,7 +48,12 @@ export function inferenceModelRoutes( async ({ client, mlClient, request, response, mlSavedObjectService }) => { try { const { inferenceId, taskType } = request.params; - const body = await modelsProvider(client, mlClient, cloud).createInferenceEndpoint( + const body = await modelsProvider( + client, + mlClient, + cloud, + getEnabledFeatures() + ).createInferenceEndpoint( inferenceId, taskType as InferenceTaskType, request.body as InferenceInferenceEndpoint diff --git a/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts b/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts index f8305cff189ed..1c2ec984fc286 100644 --- a/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts +++ b/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts @@ -65,8 +65,6 @@ export const optionalModelIdSchema = schema.object({ export const getInferenceQuerySchema = schema.object({ size: schema.maybe(schema.string()), - with_pipelines: schema.maybe(schema.string()), - with_indices: schema.maybe(schema.oneOf([schema.string(), schema.boolean()])), include: schema.maybe(schema.string()), }); diff --git a/x-pack/plugins/ml/server/routes/trained_models.test.ts b/x-pack/plugins/ml/server/routes/trained_models.test.ts deleted file mode 100644 index ca3eb19e757c6..0000000000000 --- a/x-pack/plugins/ml/server/routes/trained_models.test.ts +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { errors } from '@elastic/elasticsearch'; -import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks'; -import type { TrainedModelConfigResponse } from '../../common/types/trained_models'; -import { populateInferenceServicesProvider } from './trained_models'; -import { mlLog } from '../lib/log'; - -jest.mock('../lib/log'); - -describe('populateInferenceServicesProvider', () => { - const client = elasticsearchClientMock.createScopedClusterClient(); - - let trainedModels: TrainedModelConfigResponse[]; - - const inferenceServices = [ - { - service: 'elser', - model_id: 'elser_test', - service_settings: { model_id: '.elser_model_2' }, - }, - { service: 'open_api_01', model_id: 'open_api_model', service_settings: {} }, - ]; - - beforeEach(() => { - trainedModels = [ - { model_id: '.elser_model_2' }, - { model_id: 'model2' }, - ] as TrainedModelConfigResponse[]; - - client.asInternalUser.transport.request.mockResolvedValue({ endpoints: inferenceServices }); - - jest.clearAllMocks(); - }); - - afterEach(() => { - jest.clearAllMocks(); - }); - - describe('when the user has required privileges', () => { - beforeEach(() => { - client.asCurrentUser.transport.request.mockResolvedValue({ endpoints: inferenceServices }); - }); - - test('should populate inference services for trained models', async () => { - const populateInferenceServices = populateInferenceServicesProvider(client); - // act - await populateInferenceServices(trainedModels, false); - - // assert - expect(client.asCurrentUser.transport.request).toHaveBeenCalledWith({ - method: 'GET', - path: '/_inference/_all', - }); - - expect(client.asInternalUser.transport.request).not.toHaveBeenCalled(); - - expect(trainedModels[0].inference_apis).toEqual([ - { - model_id: 'elser_test', - service: 'elser', - service_settings: { model_id: '.elser_model_2' }, - }, - ]); - expect(trainedModels[0].hasInferenceServices).toBe(true); - - expect(trainedModels[1].inference_apis).toEqual(undefined); - expect(trainedModels[1].hasInferenceServices).toBe(false); - - expect(mlLog.error).not.toHaveBeenCalled(); - }); - }); - - describe('when the user does not have required privileges', () => { - beforeEach(() => { - client.asCurrentUser.transport.request.mockRejectedValue( - new errors.ResponseError( - elasticsearchClientMock.createApiResponse({ - statusCode: 403, - body: { message: 'not allowed' }, - }) - ) - ); - }); - - test('should retry with internal user if an error occurs', async () => { - const populateInferenceServices = populateInferenceServicesProvider(client); - await populateInferenceServices(trainedModels, false); - - // assert - expect(client.asCurrentUser.transport.request).toHaveBeenCalledWith({ - method: 'GET', - path: '/_inference/_all', - }); - - expect(client.asInternalUser.transport.request).toHaveBeenCalledWith({ - method: 'GET', - path: '/_inference/_all', - }); - - expect(trainedModels[0].inference_apis).toEqual(undefined); - expect(trainedModels[0].hasInferenceServices).toBe(true); - - expect(trainedModels[1].inference_apis).toEqual(undefined); - expect(trainedModels[1].hasInferenceServices).toBe(false); - - expect(mlLog.error).not.toHaveBeenCalled(); - }); - }); - - test('should not retry on any other error than 403', async () => { - const notFoundError = new errors.ResponseError( - elasticsearchClientMock.createApiResponse({ - statusCode: 404, - body: { message: 'not found' }, - }) - ); - - client.asCurrentUser.transport.request.mockRejectedValue(notFoundError); - - const populateInferenceServices = populateInferenceServicesProvider(client); - await populateInferenceServices(trainedModels, false); - - // assert - expect(client.asCurrentUser.transport.request).toHaveBeenCalledWith({ - method: 'GET', - path: '/_inference/_all', - }); - - expect(client.asInternalUser.transport.request).not.toHaveBeenCalled(); - - expect(mlLog.error).toHaveBeenCalledWith(notFoundError); - }); -}); diff --git a/x-pack/plugins/ml/server/routes/trained_models.ts b/x-pack/plugins/ml/server/routes/trained_models.ts index 2782c4be18207..adedb37b4a7a5 100644 --- a/x-pack/plugins/ml/server/routes/trained_models.ts +++ b/x-pack/plugins/ml/server/routes/trained_models.ts @@ -6,21 +6,18 @@ */ import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; -import { groupBy } from 'lodash'; +import type { CloudSetup } from '@kbn/cloud-plugin/server'; import { schema } from '@kbn/config-schema'; import type { ErrorType } from '@kbn/ml-error-utils'; -import type { CloudSetup } from '@kbn/cloud-plugin/server'; -import type { - ElasticCuratedModelName, - ElserVersion, - InferenceAPIConfigResponse, -} from '@kbn/ml-trained-models-utils'; -import { isDefined } from '@kbn/ml-is-defined'; -import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; +import type { ElasticCuratedModelName, ElserVersion } from '@kbn/ml-trained-models-utils'; +import { TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils'; +import { ML_INTERNAL_BASE_PATH, type MlFeatures } from '../../common/constants/app'; import { DEFAULT_TRAINED_MODELS_PAGE_SIZE } from '../../common/constants/trained_models'; -import { type MlFeatures, ML_INTERNAL_BASE_PATH } from '../../common/constants/app'; -import type { RouteInitialization } from '../types'; +import { type TrainedModelConfigResponse } from '../../common/types/trained_models'; import { wrapError } from '../client/error_wrapper'; +import { modelsProvider } from '../models/model_management'; +import type { RouteInitialization } from '../types'; +import { forceQuerySchema } from './schemas/anomaly_detectors_schema'; import { createIngestPipelineSchema, curatedModelsParamsSchema, @@ -39,70 +36,57 @@ import { threadingParamsQuerySchema, updateDeploymentParamsSchema, } from './schemas/inference_schema'; -import type { PipelineDefinition } from '../../common/types/trained_models'; -import { type TrainedModelConfigResponse } from '../../common/types/trained_models'; -import { mlLog } from '../lib/log'; -import { forceQuerySchema } from './schemas/anomaly_detectors_schema'; -import { modelsProvider } from '../models/model_management'; export function filterForEnabledFeatureModels< T extends TrainedModelConfigResponse | estypes.MlTrainedModelConfig >(models: T[], enabledFeatures: MlFeatures) { let filteredModels = models; if (enabledFeatures.nlp === false) { - filteredModels = filteredModels.filter((m) => m.model_type === 'tree_ensemble'); + filteredModels = filteredModels.filter((m) => m.model_type !== TRAINED_MODEL_TYPE.PYTORCH); } - if (enabledFeatures.dfa === false) { - filteredModels = filteredModels.filter((m) => m.model_type !== 'tree_ensemble'); + filteredModels = filteredModels.filter( + (m) => m.model_type !== TRAINED_MODEL_TYPE.TREE_ENSEMBLE + ); } - return filteredModels; } -export const populateInferenceServicesProvider = (client: IScopedClusterClient) => { - return async function populateInferenceServices( - trainedModels: TrainedModelConfigResponse[], - asInternal: boolean = false - ) { - const esClient = asInternal ? client.asInternalUser : client.asCurrentUser; - - try { - // Check if model is used by an inference service - const { endpoints } = await esClient.transport.request<{ - endpoints: InferenceAPIConfigResponse[]; - }>({ - method: 'GET', - path: `/_inference/_all`, - }); - - const inferenceAPIMap = groupBy( - endpoints, - (endpoint) => endpoint.service === 'elser' && endpoint.service_settings.model_id - ); - - for (const model of trainedModels) { - const inferenceApis = inferenceAPIMap[model.model_id]; - model.hasInferenceServices = !!inferenceApis; - if (model.hasInferenceServices && !asInternal) { - model.inference_apis = inferenceApis; - } - } - } catch (e) { - if (!asInternal && e.statusCode === 403) { - // retry with internal user to get an indicator if models has associated inference services, without mentioning the names - await populateInferenceServices(trainedModels, true); - } else { - mlLog.error(e); - } - } - }; -}; - export function trainedModelsRoutes( { router, routeGuard, getEnabledFeatures }: RouteInitialization, cloud: CloudSetup ) { + router.versioned + .get({ + path: `${ML_INTERNAL_BASE_PATH}/trained_models_list`, + access: 'internal', + security: { + authz: { + requiredPrivileges: ['ml:canGetTrainedModels'], + }, + }, + summary: 'Get trained models list', + description: + 'Retrieves a complete list of trained models with stats, pipelines, and indices.', + }) + .addVersion( + { + version: '1', + validate: false, + }, + routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, request, response }) => { + try { + const modelsClient = modelsProvider(client, mlClient, cloud, getEnabledFeatures()); + const models = await modelsClient.getTrainedModelList(); + return response.ok({ + body: models, + }); + } catch (e) { + return response.customError(wrapError(e)); + } + }) + ); + router.versioned .get({ path: `${ML_INTERNAL_BASE_PATH}/trained_models/{modelId?}`, @@ -128,14 +112,7 @@ export function trainedModelsRoutes( routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, request, response }) => { try { const { modelId } = request.params; - const { - with_pipelines: withPipelines, - with_indices: withIndicesRaw, - ...getTrainedModelsRequestParams - } = request.query; - - const withIndices = - request.query.with_indices === 'true' || request.query.with_indices === true; + const { ...getTrainedModelsRequestParams } = request.query; const resp = await mlClient.getTrainedModels({ ...getTrainedModelsRequestParams, @@ -146,126 +123,8 @@ export function trainedModelsRoutes( // @ts-ignore const result = resp.trained_model_configs as TrainedModelConfigResponse[]; - const populateInferenceServices = populateInferenceServicesProvider(client); - await populateInferenceServices(result, false); - - try { - if (withPipelines) { - // Also need to retrieve the list of deployment IDs from stats - const stats = await mlClient.getTrainedModelsStats({ - ...(modelId ? { model_id: modelId } : {}), - size: 10000, - }); - - const modelDeploymentsMap = stats.trained_model_stats.reduce((acc, curr) => { - if (!curr.deployment_stats) return acc; - // @ts-ignore elasticsearch-js client is missing deployment_id - const deploymentId = curr.deployment_stats.deployment_id; - if (acc[curr.model_id]) { - acc[curr.model_id].push(deploymentId); - } else { - acc[curr.model_id] = [deploymentId]; - } - return acc; - }, {} as Record); - - const modelIdsAndAliases: string[] = Array.from( - new Set([ - ...result - .map(({ model_id: id, metadata }) => { - return [id, ...(metadata?.model_aliases ?? [])]; - }) - .flat(), - ...Object.values(modelDeploymentsMap).flat(), - ]) - ); - const modelsClient = modelsProvider(client, mlClient, cloud); - - const modelsPipelinesAndIndices = await Promise.all( - modelIdsAndAliases.map(async (modelIdOrAlias) => { - return { - modelIdOrAlias, - result: await modelsClient.getModelsPipelinesAndIndicesMap(modelIdOrAlias, { - withIndices, - }), - }; - }) - ); - - for (const model of result) { - const modelAliases = model.metadata?.model_aliases ?? []; - const modelMap = modelsPipelinesAndIndices.find( - (d) => d.modelIdOrAlias === model.model_id - )?.result; - - const allRelatedModels = modelsPipelinesAndIndices - .filter( - (m) => - [ - model.model_id, - ...modelAliases, - ...(modelDeploymentsMap[model.model_id] ?? []), - ].findIndex((alias) => alias === m.modelIdOrAlias) > -1 - ) - .map((r) => r?.result) - .filter(isDefined); - const ingestPipelinesFromModelAliases = allRelatedModels - .map((r) => r?.ingestPipelines) - .filter(isDefined) as Array>>; - - model.pipelines = ingestPipelinesFromModelAliases.reduce< - Record - >((allPipelines, modelsToPipelines) => { - for (const [, pipelinesObj] of modelsToPipelines?.entries()) { - Object.entries(pipelinesObj).forEach(([pipelineId, pipelineInfo]) => { - allPipelines[pipelineId] = pipelineInfo; - }); - } - return allPipelines; - }, {}); - - if (modelMap && withIndices) { - model.indices = modelMap.indices; - } - } - } - } catch (e) { - // the user might not have required permissions to fetch pipelines - // log the error to the debug log as this might be a common situation and - // we don't need to fill kibana's log with these messages. - mlLog.debug(e); - } - const filteredModels = filterForEnabledFeatureModels(result, getEnabledFeatures()); - try { - const jobIds = filteredModels - .map((model) => { - const id = model.metadata?.analytics_config?.id; - if (id) { - return `${id}*`; - } - }) - .filter((id) => id !== undefined); - - if (jobIds.length) { - const { data_frame_analytics: jobs } = await mlClient.getDataFrameAnalytics({ - id: jobIds.join(','), - allow_no_match: true, - }); - - filteredModels.forEach((model) => { - const dfaId = model?.metadata?.analytics_config?.id; - if (dfaId !== undefined) { - // if this is a dfa model, set origin_job_exists - model.origin_job_exists = jobs.find((job) => job.id === dfaId) !== undefined; - } - }); - } - } catch (e) { - // Swallow error to prevent blocking trained models result - } - return response.ok({ body: filteredModels, }); @@ -367,9 +226,12 @@ export function trainedModelsRoutes( routeGuard.fullLicenseAPIGuard(async ({ client, request, mlClient, response }) => { try { const { modelId } = request.params; - const result = await modelsProvider(client, mlClient, cloud).getModelsPipelines( - modelId.split(',') - ); + const result = await modelsProvider( + client, + mlClient, + cloud, + getEnabledFeatures() + ).getModelsPipelines(modelId.split(',')); return response.ok({ body: [...result].map(([id, pipelines]) => ({ model_id: id, pipelines })), }); @@ -396,9 +258,14 @@ export function trainedModelsRoutes( version: '1', validate: false, }, - routeGuard.fullLicenseAPIGuard(async ({ client, request, mlClient, response }) => { + routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, response }) => { try { - const body = await modelsProvider(client, mlClient, cloud).getPipelines(); + const body = await modelsProvider( + client, + mlClient, + cloud, + getEnabledFeatures() + ).getPipelines(); return response.ok({ body, }); @@ -432,10 +299,12 @@ export function trainedModelsRoutes( routeGuard.fullLicenseAPIGuard(async ({ client, request, mlClient, response }) => { try { const { pipeline, pipelineName } = request.body; - const body = await modelsProvider(client, mlClient, cloud).createInferencePipeline( - pipeline!, - pipelineName - ); + const body = await modelsProvider( + client, + mlClient, + cloud, + getEnabledFeatures() + ).createInferencePipeline(pipeline!, pipelineName); return response.ok({ body, }); @@ -517,7 +386,12 @@ export function trainedModelsRoutes( if (withPipelines) { // first we need to delete pipelines, otherwise ml api return an error - await modelsProvider(client, mlClient, cloud).deleteModelPipelines(modelId.split(',')); + await modelsProvider( + client, + mlClient, + cloud, + getEnabledFeatures() + ).deleteModelPipelines(modelId.split(',')); } const body = await mlClient.deleteTrainedModel({ @@ -773,7 +647,12 @@ export function trainedModelsRoutes( }, routeGuard.fullLicenseAPIGuard(async ({ response, mlClient, client }) => { try { - const body = await modelsProvider(client, mlClient, cloud).getModelDownloads(); + const body = await modelsProvider( + client, + mlClient, + cloud, + getEnabledFeatures() + ).getModelDownloads(); return response.ok({ body, @@ -809,7 +688,7 @@ export function trainedModelsRoutes( try { const { version } = request.query; - const body = await modelsProvider(client, mlClient, cloud).getELSER( + const body = await modelsProvider(client, mlClient, cloud, getEnabledFeatures()).getELSER( version ? { version: Number(version) as ElserVersion } : undefined ); @@ -847,10 +726,12 @@ export function trainedModelsRoutes( async ({ client, mlClient, request, response, mlSavedObjectService }) => { try { const { modelId } = request.params; - const body = await modelsProvider(client, mlClient, cloud).installElasticModel( - modelId, - mlSavedObjectService - ); + const body = await modelsProvider( + client, + mlClient, + cloud, + getEnabledFeatures() + ).installElasticModel(modelId, mlSavedObjectService); return response.ok({ body, @@ -882,7 +763,12 @@ export function trainedModelsRoutes( routeGuard.fullLicenseAPIGuard( async ({ client, mlClient, request, response, mlSavedObjectService }) => { try { - const body = await modelsProvider(client, mlClient, cloud).getModelsDownloadStatus(); + const body = await modelsProvider( + client, + mlClient, + cloud, + getEnabledFeatures() + ).getModelsDownloadStatus(); return response.ok({ body, @@ -920,10 +806,14 @@ export function trainedModelsRoutes( routeGuard.fullLicenseAPIGuard( async ({ client, mlClient, request, response, mlSavedObjectService }) => { try { - const body = await modelsProvider(client, mlClient, cloud).getCuratedModelConfig( - request.params.modelName as ElasticCuratedModelName, - { version: request.query.version as ElserVersion } - ); + const body = await modelsProvider( + client, + mlClient, + cloud, + getEnabledFeatures() + ).getCuratedModelConfig(request.params.modelName as ElasticCuratedModelName, { + version: request.query.version as ElserVersion, + }); return response.ok({ body, diff --git a/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts b/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts index 36d639066f97a..04f12d82688e1 100644 --- a/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts +++ b/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts @@ -12,6 +12,7 @@ import type { GetModelDownloadConfigOptions, ModelDefinitionResponse, } from '@kbn/ml-trained-models-utils'; +import type { MlFeatures } from '../../../common/constants/app'; import type { MlInferTrainedModelRequest, MlStopTrainedModelDeploymentRequest, @@ -59,7 +60,8 @@ export interface TrainedModelsProvider { export function getTrainedModelsProvider( getGuards: GetGuards, - cloud: CloudSetup + cloud: CloudSetup, + enabledFeatures: MlFeatures ): TrainedModelsProvider { return { trainedModelsProvider(request: KibanaRequest, savedObjectsClient: SavedObjectsClientContract) { @@ -134,7 +136,9 @@ export function getTrainedModelsProvider( .isFullLicense() .hasMlCapabilities(['canGetTrainedModels']) .ok(async ({ scopedClient, mlClient }) => { - return modelsProvider(scopedClient, mlClient, cloud).getELSER(params); + return modelsProvider(scopedClient, mlClient, cloud, enabledFeatures).getELSER( + params + ); }); }, async getCuratedModelConfig(...params: GetCuratedModelConfigParams) { @@ -142,7 +146,12 @@ export function getTrainedModelsProvider( .isFullLicense() .hasMlCapabilities(['canGetTrainedModels']) .ok(async ({ scopedClient, mlClient }) => { - return modelsProvider(scopedClient, mlClient, cloud).getCuratedModelConfig(...params); + return modelsProvider( + scopedClient, + mlClient, + cloud, + enabledFeatures + ).getCuratedModelConfig(...params); }); }, async installElasticModel(modelId: string) { @@ -150,10 +159,12 @@ export function getTrainedModelsProvider( .isFullLicense() .hasMlCapabilities(['canGetTrainedModels']) .ok(async ({ scopedClient, mlClient, mlSavedObjectService }) => { - return modelsProvider(scopedClient, mlClient, cloud).installElasticModel( - modelId, - mlSavedObjectService - ); + return modelsProvider( + scopedClient, + mlClient, + cloud, + enabledFeatures + ).installElasticModel(modelId, mlSavedObjectService); }); }, }; diff --git a/x-pack/plugins/ml/server/shared_services/shared_services.ts b/x-pack/plugins/ml/server/shared_services/shared_services.ts index f7b0d42409221..ed71319cfb6a1 100644 --- a/x-pack/plugins/ml/server/shared_services/shared_services.ts +++ b/x-pack/plugins/ml/server/shared_services/shared_services.ts @@ -19,7 +19,7 @@ import type { CloudSetup } from '@kbn/cloud-plugin/server'; import type { PluginStart as DataViewsPluginStart } from '@kbn/data-views-plugin/server'; import type { SecurityPluginSetup } from '@kbn/security-plugin/server'; import type { FieldFormatsStart } from '@kbn/field-formats-plugin/server'; -import type { CompatibleModule } from '../../common/constants/app'; +import type { CompatibleModule, MlFeatures } from '../../common/constants/app'; import type { MlLicense } from '../../common/license'; import { licenseChecks } from './license_checks'; @@ -113,7 +113,8 @@ export function createSharedServices( getDataViews: () => DataViewsPluginStart, getAuditService: () => CoreAuditService | null, isMlReady: () => Promise, - compatibleModuleType: CompatibleModule | null + compatibleModuleType: CompatibleModule | null, + enabledFeatures: MlFeatures ): { sharedServicesProviders: SharedServices; internalServicesProviders: MlServicesProviders; @@ -191,7 +192,7 @@ export function createSharedServices( ...getResultsServiceProvider(getGuards), ...getMlSystemProvider(getGuards, mlLicense, getSpaces, cloud, resolveMlCapabilities), ...getAlertingServiceProvider(getGuards), - ...getTrainedModelsProvider(getGuards, cloud), + ...getTrainedModelsProvider(getGuards, cloud, enabledFeatures), }, /** * Services providers for ML internal usage diff --git a/x-pack/plugins/translations/translations/fr-FR.json b/x-pack/plugins/translations/translations/fr-FR.json index f88ffac9e5bce..8833313d33b62 100644 --- a/x-pack/plugins/translations/translations/fr-FR.json +++ b/x-pack/plugins/translations/translations/fr-FR.json @@ -31761,7 +31761,6 @@ "xpack.ml.trainedModels.modelsList.expandRow": "Développer", "xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "La suppression {modelsCount, plural, one {du modèle} other {des modèles}} a échoué", "xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage": "Erreur lors du chargement des modèles entraînés", - "xpack.ml.trainedModels.modelsList.fetchModelStatsErrorMessage": "Erreur lors du chargement des statistiques des modèles entraînés", "xpack.ml.trainedModels.modelsList.forceStopDialog.cancelText": "Annuler", "xpack.ml.trainedModels.modelsList.forceStopDialog.confirmText": "Arrêt", "xpack.ml.trainedModels.modelsList.forceStopDialog.hasInferenceServicesWarning": "Ce modèle est utilisé par l'API _inference", diff --git a/x-pack/plugins/translations/translations/ja-JP.json b/x-pack/plugins/translations/translations/ja-JP.json index 44d6cd0653bfa..5637ba0f16fc2 100644 --- a/x-pack/plugins/translations/translations/ja-JP.json +++ b/x-pack/plugins/translations/translations/ja-JP.json @@ -31732,7 +31732,6 @@ "xpack.ml.trainedModels.modelsList.expandRow": "拡張", "xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "{modelsCount, plural, other {モデル}}を削除できませんでした", "xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage": "学習済みモデルの読み込みエラー", - "xpack.ml.trainedModels.modelsList.fetchModelStatsErrorMessage": "学習済みモデル統計情報の読み込みエラー", "xpack.ml.trainedModels.modelsList.forceStopDialog.cancelText": "キャンセル", "xpack.ml.trainedModels.modelsList.forceStopDialog.confirmText": "終了", "xpack.ml.trainedModels.modelsList.forceStopDialog.hasInferenceServicesWarning": "モデルは_inference APIによって使用されます。", diff --git a/x-pack/plugins/translations/translations/zh-CN.json b/x-pack/plugins/translations/translations/zh-CN.json index f8fc385b8e339..c300413636870 100644 --- a/x-pack/plugins/translations/translations/zh-CN.json +++ b/x-pack/plugins/translations/translations/zh-CN.json @@ -31794,7 +31794,6 @@ "xpack.ml.trainedModels.modelsList.expandRow": "展开", "xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "{modelsCount, plural, other {# 个模型}}删除失败", "xpack.ml.trainedModels.modelsList.fetchFailedErrorMessage": "加载已训练模型时出错", - "xpack.ml.trainedModels.modelsList.fetchModelStatsErrorMessage": "加载已训练模型统计信息时出错", "xpack.ml.trainedModels.modelsList.forceStopDialog.cancelText": "取消", "xpack.ml.trainedModels.modelsList.forceStopDialog.confirmText": "停止点", "xpack.ml.trainedModels.modelsList.forceStopDialog.hasInferenceServicesWarning": "此模型由 _inference API 使用", diff --git a/x-pack/test/api_integration/apis/ml/trained_models/get_models.ts b/x-pack/test/api_integration/apis/ml/trained_models/get_models.ts index a3953a87b82b2..654f55c7e1254 100644 --- a/x-pack/test/api_integration/apis/ml/trained_models/get_models.ts +++ b/x-pack/test/api_integration/apis/ml/trained_models/get_models.ts @@ -16,61 +16,27 @@ export default ({ getService }: FtrProviderContext) => { const esDeleteAllIndices = getService('esDeleteAllIndices'); describe('GET trained_models', () => { - let testModelIds: string[] = []; - before(async () => { await ml.api.initSavedObjects(); await ml.testResources.setKibanaTimeZoneToUTC(); - testModelIds = await ml.api.createTestTrainedModels('regression', 5, true); - await ml.api.createModelAlias('dfa_regression_model_n_0', 'dfa_regression_model_alias'); - await ml.api.createIngestPipeline('dfa_regression_model_alias'); - - // Creating an indices that are tied to modelId: dfa_regression_model_n_1 - await ml.api.createIndex(`user-index_dfa_regression_model_n_1`, undefined, { - index: { default_pipeline: `pipeline_dfa_regression_model_n_1` }, - }); + await ml.api.createTestTrainedModels('regression', 5, true); }); after(async () => { await esDeleteAllIndices('user-index_dfa*'); - - // delete created ingest pipelines - await Promise.all( - ['dfa_regression_model_alias', ...testModelIds].map((modelId) => - ml.api.deleteIngestPipeline(modelId) - ) - ); await ml.testResources.cleanMLSavedObjects(); await ml.api.cleanMlIndices(); }); - it('returns all trained models with associated pipelines including aliases', async () => { + it('returns all trained models', async () => { const { body, status } = await supertest - .get(`/internal/ml/trained_models?with_pipelines=true`) + .get(`/internal/ml/trained_models`) .auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER)) .set(getCommonRequestHeader('1')); ml.api.assertResponseStatusCode(200, status, body); // Created models + system model expect(body.length).to.eql(6); - - const sampleModel = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0'); - - expect(Object.keys(sampleModel.pipelines).length).to.eql(2); - }); - - it('returns models without pipeline in case user does not have required permission', async () => { - const { body, status } = await supertest - .get(`/internal/ml/trained_models?with_pipelines=true&with_indices=true`) - .auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER)) - .set(getCommonRequestHeader('1')); - ml.api.assertResponseStatusCode(200, status, body); - - // Created models + system model - expect(body.length).to.eql(6); - const sampleModel = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0'); - - expect(sampleModel.pipelines).to.eql(undefined); }); it('returns trained model by id', async () => { @@ -84,58 +50,6 @@ export default ({ getService }: FtrProviderContext) => { const sampleModel = body[0]; expect(sampleModel.model_id).to.eql('dfa_regression_model_n_1'); - expect(sampleModel.pipelines).to.eql(undefined); - expect(sampleModel.indices).to.eql(undefined); - }); - - it('returns trained model by id with_pipelines=true,with_indices=false', async () => { - const { body, status } = await supertest - .get( - `/internal/ml/trained_models/dfa_regression_model_n_1?with_pipelines=true&with_indices=false` - ) - .auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER)) - .set(getCommonRequestHeader('1')); - ml.api.assertResponseStatusCode(200, status, body); - - expect(body.length).to.eql(1); - const sampleModel = body[0]; - - expect(sampleModel.model_id).to.eql('dfa_regression_model_n_1'); - expect(Object.keys(sampleModel.pipelines).length).to.eql( - 1, - `Expected number of pipelines for dfa_regression_model_n_1 to be ${1} (got ${ - Object.keys(sampleModel.pipelines).length - })` - ); - expect(sampleModel.indices).to.eql( - undefined, - `Expected indices for dfa_regression_model_n_1 to be undefined (got ${sampleModel.indices})` - ); - }); - - it('returns trained model by id with_pipelines=true,with_indices=true', async () => { - const { body, status } = await supertest - .get( - `/internal/ml/trained_models/dfa_regression_model_n_1?with_pipelines=true&with_indices=true` - ) - .auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER)) - .set(getCommonRequestHeader('1')); - ml.api.assertResponseStatusCode(200, status, body); - - const sampleModel = body[0]; - expect(sampleModel.model_id).to.eql('dfa_regression_model_n_1'); - expect(Object.keys(sampleModel.pipelines).length).to.eql( - 1, - `Expected number of pipelines for dfa_regression_model_n_1 to be ${1} (got ${ - Object.keys(sampleModel.pipelines).length - })` - ); - expect(sampleModel.indices.length).to.eql( - 1, - `Expected number of indices for dfa_regression_model_n_1 to be ${1} (got ${ - sampleModel.indices.length - })` - ); }); it('returns 404 if requested trained model does not exist', async () => { diff --git a/x-pack/test/api_integration/apis/ml/trained_models/index.ts b/x-pack/test/api_integration/apis/ml/trained_models/index.ts index c9bf98545e2b4..319899ec9a693 100644 --- a/x-pack/test/api_integration/apis/ml/trained_models/index.ts +++ b/x-pack/test/api_integration/apis/ml/trained_models/index.ts @@ -9,6 +9,7 @@ import { FtrProviderContext } from '../../../ftr_provider_context'; export default function ({ loadTestFile }: FtrProviderContext) { describe('trained models', function () { + loadTestFile(require.resolve('./trained_models_list')); loadTestFile(require.resolve('./get_models')); loadTestFile(require.resolve('./get_model_stats')); loadTestFile(require.resolve('./get_model_pipelines')); diff --git a/x-pack/test/api_integration/apis/ml/trained_models/trained_models_list.ts b/x-pack/test/api_integration/apis/ml/trained_models/trained_models_list.ts new file mode 100644 index 0000000000000..1feac44b13ca8 --- /dev/null +++ b/x-pack/test/api_integration/apis/ml/trained_models/trained_models_list.ts @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import expect from '@kbn/expect'; +import { FtrProviderContext } from '../../../ftr_provider_context'; +import { USER } from '../../../../functional/services/ml/security_common'; +import { getCommonRequestHeader } from '../../../../functional/services/ml/common_api'; + +export default ({ getService }: FtrProviderContext) => { + const supertest = getService('supertestWithoutAuth'); + const ml = getService('ml'); + const esDeleteAllIndices = getService('esDeleteAllIndices'); + + describe('GET trained_models_list', () => { + let testModelIds: string[] = []; + + before(async () => { + await ml.api.initSavedObjects(); + await ml.testResources.setKibanaTimeZoneToUTC(); + testModelIds = await ml.api.createTestTrainedModels('regression', 5, true); + await ml.api.createModelAlias('dfa_regression_model_n_0', 'dfa_regression_model_alias'); + await ml.api.createIngestPipeline('dfa_regression_model_alias'); + + // Creating an index that is tied to modelId: dfa_regression_model_n_1 + await ml.api.createIndex(`user-index_dfa_regression_model_n_1`, undefined, { + index: { default_pipeline: `pipeline_dfa_regression_model_n_1` }, + }); + }); + + after(async () => { + await esDeleteAllIndices('user-index_dfa*'); + + // delete created ingest pipelines + await Promise.all( + ['dfa_regression_model_alias', ...testModelIds].map((modelId) => + ml.api.deleteIngestPipeline(modelId) + ) + ); + await ml.testResources.cleanMLSavedObjects(); + await ml.api.cleanMlIndices(); + }); + + it('returns a formatted list of trained model with stats, associated pipelines and indices', async () => { + const { body, status } = await supertest + .get(`/internal/ml/trained_models_list`) + .auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER)) + .set(getCommonRequestHeader('1')); + ml.api.assertResponseStatusCode(200, status, body); + + // Created models + system model + model downloads + expect(body.length).to.eql(10); + + const dfaRegressionN0 = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0'); + + expect(Object.keys(dfaRegressionN0.pipelines).length).to.eql(2); + + const dfaRegressionN1 = body.find((v: any) => v.model_id === 'dfa_regression_model_n_1'); + expect(Object.keys(dfaRegressionN1.pipelines).length).to.eql( + 1, + `Expected number of pipelines for dfa_regression_model_n_1 to be ${1} (got ${ + Object.keys(dfaRegressionN1.pipelines).length + })` + ); + expect(dfaRegressionN1.indices.length).to.eql( + 1, + `Expected number of indices for dfa_regression_model_n_1 to be ${1} (got ${ + dfaRegressionN1.indices.length + })` + ); + }); + + it('returns models without pipeline in case user does not have required permission', async () => { + const { body, status } = await supertest + .get(`/internal/ml/trained_models_list`) + .auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER)) + .set(getCommonRequestHeader('1')); + ml.api.assertResponseStatusCode(200, status, body); + + expect(body.length).to.eql(10); + const sampleModel = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0'); + expect(sampleModel.pipelines).to.eql(undefined); + }); + + it('returns an error for unauthorized user', async () => { + const { body, status } = await supertest + .get(`/internal/ml/trained_models_list`) + .auth(USER.ML_UNAUTHORIZED, ml.securityCommon.getPasswordForUser(USER.ML_UNAUTHORIZED)) + .set(getCommonRequestHeader('1')); + ml.api.assertResponseStatusCode(403, status, body); + }); + }); +}; diff --git a/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts b/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts index 7977f17bf5f65..c0d4af068832e 100644 --- a/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts +++ b/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts @@ -17,6 +17,8 @@ export default function ({ getService }: FtrProviderContext) { id: model.name, })); + const modelAllSpaces = SUPPORTED_TRAINED_MODELS.TINY_ELSER; + describe('trained models', function () { // 'Created at' will be different on each run, // so we will just assert that the value is in the expected timestamp format. @@ -91,6 +93,10 @@ export default function ({ getService }: FtrProviderContext) { await ml.api.importTrainedModel(model.id, model.name); } + // Assign model to all spaces + await ml.api.updateTrainedModelSpaces(modelAllSpaces.name, ['*'], ['default']); + await ml.api.assertTrainedModelSpaces(modelAllSpaces.name, ['*']); + await ml.api.createTestTrainedModels('classification', 15, true); await ml.api.createTestTrainedModels('regression', 15); @@ -173,9 +179,10 @@ export default function ({ getService }: FtrProviderContext) { await ml.securityUI.logout(); }); - it('should not be able to delete a model assigned to all spaces, and show a warning copy explaining the situation', async () => { - await ml.testExecution.logTestStep('should select the model named elser_model_2'); - await ml.trainedModels.selectModel('.elser_model_2'); + it.skip('should not be able to delete a model assigned to all spaces, and show a warning copy explaining the situation', async () => { + await ml.testExecution.logTestStep('should select a model'); + await ml.trainedModelsTable.filterWithSearchString(modelAllSpaces.name, 1); + await ml.trainedModels.selectModel(modelAllSpaces.name); await ml.testExecution.logTestStep('should attempt to delete the model'); await ml.trainedModels.clickBulkDelete(); @@ -493,6 +500,11 @@ export default function ({ getService }: FtrProviderContext) { await ml.trainedModelsTable.assertStatsTabContent(); await ml.trainedModelsTable.assertPipelinesTabContent(false); }); + } + + describe('supports actions for an imported model', function () { + // It's enough to test the actions for one model + const model = trainedModels[trainedModels.length - 1]; it(`starts deployment of the imported model ${model.id}`, async () => { await ml.trainedModelsTable.startDeploymentWithParams(model.id, { @@ -513,7 +525,7 @@ export default function ({ getService }: FtrProviderContext) { it(`deletes the imported model ${model.id}`, async () => { await ml.trainedModelsTable.deleteModel(model.id); }); - } + }); }); });