From b9e2a523b6f67d11dbb85c7048ff9792f12471f1 Mon Sep 17 00:00:00 2001 From: Ashley McEntee <123661468+ashley-o0o@users.noreply.github.com> Date: Fri, 9 Aug 2024 12:13:24 -0400 Subject: [PATCH 01/31] Fix for untrue flag (#3049) --- .../__tests__/useNotebookImageData.spec.ts | 34 +++ .../detail/notebooks/useNotebookImageData.ts | 194 +++++++++++++----- 2 files changed, 182 insertions(+), 46 deletions(-) diff --git a/frontend/src/pages/projects/screens/detail/notebooks/__tests__/useNotebookImageData.spec.ts b/frontend/src/pages/projects/screens/detail/notebooks/__tests__/useNotebookImageData.spec.ts index 8685653586..0bc949321c 100644 --- a/frontend/src/pages/projects/screens/detail/notebooks/__tests__/useNotebookImageData.spec.ts +++ b/frontend/src/pages/projects/screens/detail/notebooks/__tests__/useNotebookImageData.spec.ts @@ -106,4 +106,38 @@ describe('getNotebookImageData', () => { const result = getNotebookImageData(notebook, images); expect(result?.imageAvailability).toBe(NotebookImageAvailability.DELETED); }); + + it('should fail when custom image shows unexpected Deleted flag', () => { + const imageName = 'jupyter-datascience-notebook'; + const tagName = '2024.1'; + const notebook = mockNotebookK8sResource({ + lastImageSelection: `${imageName}:${tagName}`, + image: `quay.io/opendatahub/${imageName}:${tagName}`, + }); + const images = [ + mockImageStreamK8sResource({ + tagName, + name: imageName, + }), + ]; + const result = getNotebookImageData(notebook, images); + expect(result?.imageAvailability).toBe(NotebookImageAvailability.ENABLED); + }); + + it('should test an image defined via sha', () => { + const imageName = 'jupyter-datascience-notebook'; + const imageSha = 'sha256:a138838e1c9acd7708462e420bf939e03296b97e9cf6c0aa0fd9a5d20361ab75'; + const notebook = mockNotebookK8sResource({ + lastImageSelection: `${imageName}:${imageSha}`, + image: `quay.io/opendatahub/${imageName}@${imageSha}`, + }); + const images = [ + mockImageStreamK8sResource({ + imageTag: `quay.io/opendatahub/${imageName}@${imageSha}`, + name: imageName, + }), + ]; + const result = getNotebookImageData(notebook, images); + expect(result?.imageAvailability).toBe(NotebookImageAvailability.ENABLED); + }); }); diff --git a/frontend/src/pages/projects/screens/detail/notebooks/useNotebookImageData.ts b/frontend/src/pages/projects/screens/detail/notebooks/useNotebookImageData.ts index 17f4078783..7792bc688b 100644 --- a/frontend/src/pages/projects/screens/detail/notebooks/useNotebookImageData.ts +++ b/frontend/src/pages/projects/screens/detail/notebooks/useNotebookImageData.ts @@ -24,59 +24,52 @@ export const getNotebookImageData = ( } const [imageName, versionName] = imageTag; - const [lastImageSelectionName] = + const [lastImageSelectionName, lastImageSelectionTag] = notebook.metadata.annotations?.['notebooks.opendatahub.io/last-image-selection']?.split(':') ?? []; - // Fallback for non internal registry clusters - const imageStream = - images.find((image) => image.metadata.name === imageName) || - images.find((image) => - image.spec.tags - ? image.spec.tags.find( - (version) => - version.from?.name === container.image && - image.metadata.name === lastImageSelectionName, - ) - : false, - ); - - // if the image stream is not found, consider it deleted - if (!imageStream) { - // Get the image display name from the notebook metadata if we can't find the image stream. (this is a fallback and could still be undefined) - const imageDisplayName = notebook.metadata.annotations?.['opendatahub.io/image-display-name']; - - return { - imageAvailability: NotebookImageAvailability.DELETED, - imageDisplayName, - }; + const notebookImageInternalRegistry = getNotebookImageInternalRegistry( + notebook, + images, + imageName, + versionName, + ); + if ( + notebookImageInternalRegistry && + notebookImageInternalRegistry.imageAvailability !== NotebookImageAvailability.DELETED + ) { + return notebookImageInternalRegistry; } - - const versions = imageStream.spec.tags || []; - const imageVersion = versions.find( - (version) => version.name === versionName || version.from?.name === container.image, + const notebookImageNoInternalRegistry = getNotebookImageNoInternalRegistry( + notebook, + images, + lastImageSelectionName, + container.image, ); - - // because the image stream was found, get its display name - const imageDisplayName = getImageStreamDisplayName(imageStream); - - // if the image version is not found, consider the image stream deleted - if (!imageVersion) { - return { - imageAvailability: NotebookImageAvailability.DELETED, - imageDisplayName, - }; + if ( + notebookImageNoInternalRegistry && + notebookImageNoInternalRegistry.imageAvailability !== NotebookImageAvailability.DELETED + ) { + return notebookImageNoInternalRegistry; + } + const notebookImageNoInternalRegistryNoSHA = getNotebookImageNoInternalRegistryNoSHA( + notebook, + images, + lastImageSelectionTag, + container.image, + ); + if ( + notebookImageNoInternalRegistryNoSHA && + notebookImageNoInternalRegistryNoSHA.imageAvailability !== NotebookImageAvailability.DELETED + ) { + return notebookImageNoInternalRegistryNoSHA; } - - // if the image stream exists and the image version exists, return the image data return { - imageStream, - imageVersion, - imageAvailability: - imageStream.metadata.labels?.['opendatahub.io/notebook-image'] === 'true' - ? NotebookImageAvailability.ENABLED - : NotebookImageAvailability.DISABLED, - imageDisplayName, + imageAvailability: NotebookImageAvailability.DELETED, + imageDisplayName: + notebookImageInternalRegistry?.imageDisplayName || + notebookImageNoInternalRegistry?.imageDisplayName || + notebookImageNoInternalRegistryNoSHA?.imageDisplayName, }; }; @@ -98,4 +91,113 @@ const useNotebookImageData = (notebook?: NotebookKind): NotebookImageData => { }, [images, notebook, loaded, loadError]); }; +const getNotebookImageInternalRegistry = ( + notebook: NotebookKind, + images: ImageStreamKind[], + imageName: string, + versionName: string, +): NotebookImageData[0] => { + const imageStream = images.find((image) => image.metadata.name === imageName); + + if (!imageStream) { + // Get the image display name from the notebook metadata if we can't find the image stream. (this is a fallback and could still be undefined) + return getDeletedImageData( + notebook.metadata.annotations?.['opendatahub.io/image-display-name'], + ); + } + + const versions = imageStream.spec.tags || []; + const imageVersion = versions.find((version) => version.name === versionName); + const imageDisplayName = getImageStreamDisplayName(imageStream); + if (!imageVersion) { + return getDeletedImageData(imageDisplayName); + } + return { + imageStream, + imageVersion, + imageAvailability: getImageAvailability(imageStream), + imageDisplayName, + }; +}; + +const getNotebookImageNoInternalRegistry = ( + notebook: NotebookKind, + images: ImageStreamKind[], + lastImageSelectionName: string, + containerImage: string, +): NotebookImageData[0] => { + const imageStream = images.find( + (image) => + image.metadata.name === lastImageSelectionName && + image.spec.tags?.find((version) => version.from?.name === containerImage), + ); + + if (!imageStream) { + // Get the image display name from the notebook metadata if we can't find the image stream. (this is a fallback and could still be undefined) + return getDeletedImageData( + notebook.metadata.annotations?.['opendatahub.io/image-display-name'], + ); + } + + const versions = imageStream.spec.tags || []; + const imageVersion = versions.find((version) => version.from?.name === containerImage); + const imageDisplayName = getImageStreamDisplayName(imageStream); + if (!imageVersion) { + return getDeletedImageData(imageDisplayName); + } + return { + imageStream, + imageVersion, + imageAvailability: getImageAvailability(imageStream), + imageDisplayName, + }; +}; + +const getNotebookImageNoInternalRegistryNoSHA = ( + notebook: NotebookKind, + images: ImageStreamKind[], + lastImageSelectionTag: string, + containerImage: string, +): NotebookImageData[0] => { + const imageStream = images.find((image) => + image.status?.tags?.find( + (version) => + version.tag === lastImageSelectionTag && + version.items?.find((item) => item.dockerImageReference === containerImage), + ), + ); + + if (!imageStream) { + // Get the image display name from the notebook metadata if we can't find the image stream. (this is a fallback and could still be undefined) + return getDeletedImageData( + notebook.metadata.annotations?.['opendatahub.io/image-display-name'], + ); + } + + const versions = imageStream.spec.tags || []; + const imageVersion = versions.find((version) => version.name === lastImageSelectionTag); + const imageDisplayName = getImageStreamDisplayName(imageStream); + if (!imageVersion) { + return getDeletedImageData(imageDisplayName); + } + return { + imageStream, + imageVersion, + imageAvailability: getImageAvailability(imageStream), + imageDisplayName, + }; +}; + +export const getImageAvailability = (imageStream: ImageStreamKind): NotebookImageAvailability => + imageStream.metadata.labels?.['opendatahub.io/notebook-image'] === 'true' + ? NotebookImageAvailability.ENABLED + : NotebookImageAvailability.DISABLED; + +export const getDeletedImageData = ( + imageDisplayName: string | undefined, +): NotebookImageData[0] => ({ + imageAvailability: NotebookImageAvailability.DELETED, + imageDisplayName, +}); + export default useNotebookImageData; From e42fdcd887ab5a3c562fdc74d60d02c9d87208f3 Mon Sep 17 00:00:00 2001 From: "Heiko W. Rupp" Date: Mon, 12 Aug 2024 12:04:16 +0200 Subject: [PATCH 02/31] RHOAIENG-9232 Create useTrackUser (#3024) --- .../analyticsTracking/segmentIOUtils.tsx | 16 +++++- .../analyticsTracking/trackingProperties.ts | 3 ++ .../analyticsTracking/useSegmentTracking.ts | 30 +++++------ .../analyticsTracking/useTrackUser.ts | 50 +++++++++++++++++++ 4 files changed, 82 insertions(+), 17 deletions(-) create mode 100644 frontend/src/concepts/analyticsTracking/useTrackUser.ts diff --git a/frontend/src/concepts/analyticsTracking/segmentIOUtils.tsx b/frontend/src/concepts/analyticsTracking/segmentIOUtils.tsx index 080ebeb379..f55556ba6b 100644 --- a/frontend/src/concepts/analyticsTracking/segmentIOUtils.tsx +++ b/frontend/src/concepts/analyticsTracking/segmentIOUtils.tsx @@ -86,6 +86,13 @@ export const firePageEvent = (): void => { } }; +// Stuff that gets send over as traits on an identify call. Must not include (anonymous) user Id. +type IdentifyTraits = { + isAdmin: boolean; + canCreateProjects: boolean; + clusterID: string; +}; + /* * This fires a call to associate further processing with the passed (anonymous) userId * in the properties. @@ -94,8 +101,13 @@ export const fireIdentifyEvent = (properties: IdentifyEventProperties): void => const clusterID = window.clusterID ?? ''; if (DEV_MODE) { /* eslint-disable-next-line no-console */ - console.log(`Identify event triggered`); + console.log(`Identify event triggered: ${JSON.stringify(properties)}`); } else if (window.analytics) { - window.analytics.identify(properties.anonymousID, { clusterID }); + const traits: IdentifyTraits = { + clusterID, + isAdmin: properties.isAdmin, + canCreateProjects: properties.canCreateProjects, + }; + window.analytics.identify(properties.anonymousID, traits); } }; diff --git a/frontend/src/concepts/analyticsTracking/trackingProperties.ts b/frontend/src/concepts/analyticsTracking/trackingProperties.ts index 81c26f44e9..17daeb3a16 100644 --- a/frontend/src/concepts/analyticsTracking/trackingProperties.ts +++ b/frontend/src/concepts/analyticsTracking/trackingProperties.ts @@ -3,7 +3,10 @@ export type ODHSegmentKey = { }; export type IdentifyEventProperties = { + isAdmin: boolean; anonymousID?: string; + userId?: string; + canCreateProjects: boolean; }; export const enum TrackingOutcome { diff --git a/frontend/src/concepts/analyticsTracking/useSegmentTracking.ts b/frontend/src/concepts/analyticsTracking/useSegmentTracking.ts index d609a9a2f9..2411a5b744 100644 --- a/frontend/src/concepts/analyticsTracking/useSegmentTracking.ts +++ b/frontend/src/concepts/analyticsTracking/useSegmentTracking.ts @@ -2,6 +2,7 @@ import React from 'react'; import { useAppContext } from '~/app/AppContext'; import { useAppSelector } from '~/redux/hooks'; import { fireIdentifyEvent, firePageEvent } from '~/concepts/analyticsTracking/segmentIOUtils'; +import { useTrackUser } from '~/concepts/analyticsTracking/useTrackUser'; import { useWatchSegmentKey } from './useWatchSegmentKey'; import { initSegment } from './initSegment'; @@ -10,28 +11,27 @@ export const useSegmentTracking = (): void => { const { dashboardConfig } = useAppContext(); const username = useAppSelector((state) => state.user); const clusterID = useAppSelector((state) => state.clusterID); + const [userProps, uPropsLoaded] = useTrackUser(username); React.useEffect(() => { - if (segmentKey && loaded && !loadError && username && clusterID) { - const computeUserId = async () => { - const anonymousIDBuffer = await crypto.subtle.digest( - 'SHA-1', - new TextEncoder().encode(username), - ); - const anonymousIDArray = Array.from(new Uint8Array(anonymousIDBuffer)); - return anonymousIDArray.map((b) => b.toString(16).padStart(2, '0')).join(''); - }; - + if (segmentKey && loaded && !loadError && username && clusterID && uPropsLoaded) { window.clusterID = clusterID; initSegment({ segmentKey, enabled: !dashboardConfig.spec.dashboardConfig.disableTracking, }).then(() => { - computeUserId().then((userId) => { - fireIdentifyEvent({ anonymousID: userId }); - firePageEvent(); - }); + fireIdentifyEvent(userProps); + firePageEvent(); }); } - }, [clusterID, loadError, loaded, segmentKey, username, dashboardConfig]); + }, [ + clusterID, + loadError, + loaded, + segmentKey, + username, + dashboardConfig, + userProps, + uPropsLoaded, + ]); }; diff --git a/frontend/src/concepts/analyticsTracking/useTrackUser.ts b/frontend/src/concepts/analyticsTracking/useTrackUser.ts new file mode 100644 index 0000000000..062334dd76 --- /dev/null +++ b/frontend/src/concepts/analyticsTracking/useTrackUser.ts @@ -0,0 +1,50 @@ +import React from 'react'; +import { useUser } from '~/redux/selectors'; +import { useAccessReview } from '~/api'; +import { AccessReviewResourceAttributes } from '~/k8sTypes'; +import { IdentifyEventProperties } from '~/concepts/analyticsTracking/trackingProperties'; + +export const useTrackUser = (username?: string): [IdentifyEventProperties, boolean] => { + const { isAdmin } = useUser(); + const [anonymousId, setAnonymousId] = React.useState(undefined); + + const [loaded, setLoaded] = React.useState(false); + const createReviewResource: AccessReviewResourceAttributes = { + group: 'project.openshift.io', + resource: 'projectrequests', + verb: 'create', + }; + const [allowCreate, acLoaded] = useAccessReview(createReviewResource); + + React.useEffect(() => { + const computeAnonymousUserId = async () => { + const anonymousIDBuffer = await crypto.subtle.digest( + 'SHA-1', + new TextEncoder().encode(username), + ); + const anonymousIDArray = Array.from(new Uint8Array(anonymousIDBuffer)); + const aId = anonymousIDArray.map((b) => b.toString(16).padStart(2, '0')).join(''); + return aId; + }; + + if (!anonymousId) { + computeAnonymousUserId().then((val) => { + setAnonymousId(val); + }); + } + if (acLoaded && anonymousId) { + setLoaded(true); + } + }, [username, anonymousId, acLoaded]); + + const props: IdentifyEventProperties = React.useMemo( + () => ({ + isAdmin, + canCreateProjects: allowCreate, + anonymousID: anonymousId, + }), + [isAdmin, allowCreate, anonymousId], + ); + + return [props, loaded]; +}; From dc622dbafd2052597cb4840d3ff1df7a7d39e45d Mon Sep 17 00:00:00 2001 From: Juntao Wang <37624318+DaoDaoNoCode@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:52:32 -0400 Subject: [PATCH 03/31] Add deployment modal for model registry (#3074) * Add deployment modal for model registry * address feedback --- frontend/src/__mocks__/mockModelArtifact.ts | 5 +- .../src/__mocks__/mockSecretK8sResource.ts | 10 +- .../modelRegistry/modelVersionDeployModal.ts | 17 + .../cypress/cypress/pages/modelServing.ts | 4 + .../modelRegistry/modelVersionDeploy.cy.ts | 326 ++++++++++++++++++ .../modelRegistry/modelVersionDetails.cy.ts | 2 +- .../context/ModelRegistryContext.tsx | 27 ++ .../ModelVersionDetailsHeaderActions.tsx | 10 +- .../ModelVersions/ModelVersionsTableRow.tsx | 10 +- .../RegisteredModelTableRow.tsx | 6 - .../useLabeledDataConnections.ts | 46 +++ .../usePrefillDeployModalFromModelRegistry.ts | 87 +++++ .../useProjectErrorForRegisteredModel.ts | 34 ++ .../useRegisteredModelDeployInfo.ts | 65 ++++ .../useProjectErrorForRegisteredModel.spec.ts | 97 ++++++ .../components/DeployRegisteredModelModal.tsx | 127 +++++++ .../customServingRuntimes/utils.ts | 14 + .../DataConnectionExistingField.tsx | 35 +- .../DataConnectionSection.tsx | 36 +- .../InferenceServiceFrameworkSection.tsx | 18 +- .../ManageInferenceServiceModal.tsx | 99 ++++-- .../InferenceServiceModal/ProjectSelector.tsx | 95 +++++ .../screens/projects/__tests__/utils.spec.ts | 19 +- .../kServeModal/ManageKServeModal.tsx | 172 +++++---- .../modelServing/screens/projects/utils.ts | 15 +- .../src/pages/modelServing/screens/types.ts | 13 +- 26 files changed, 1233 insertions(+), 156 deletions(-) create mode 100644 frontend/src/__tests__/cypress/cypress/pages/modelRegistry/modelVersionDeployModal.ts create mode 100644 frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDeploy.cy.ts create mode 100644 frontend/src/pages/modelRegistry/screens/RegisteredModels/useLabeledDataConnections.ts create mode 100644 frontend/src/pages/modelRegistry/screens/RegisteredModels/usePrefillDeployModalFromModelRegistry.ts create mode 100644 frontend/src/pages/modelRegistry/screens/RegisteredModels/useProjectErrorForRegisteredModel.ts create mode 100644 frontend/src/pages/modelRegistry/screens/RegisteredModels/useRegisteredModelDeployInfo.ts create mode 100644 frontend/src/pages/modelRegistry/screens/__tests__/useProjectErrorForRegisteredModel.spec.ts create mode 100644 frontend/src/pages/modelRegistry/screens/components/DeployRegisteredModelModal.tsx create mode 100644 frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/ProjectSelector.tsx diff --git a/frontend/src/__mocks__/mockModelArtifact.ts b/frontend/src/__mocks__/mockModelArtifact.ts index 3bf055eb33..78c75ae1ef 100644 --- a/frontend/src/__mocks__/mockModelArtifact.ts +++ b/frontend/src/__mocks__/mockModelArtifact.ts @@ -8,6 +8,9 @@ export const mockModelArtifact = (): ModelArtifact => ({ description: 'Description of model version', artifactType: 'model-artifact', customProperties: {}, + storageKey: 'test storage key', storagePath: 'test path', - uri: 'https://huggingface.io/mnist.onnx', + uri: 's3://test-bucket/demo-models/test-path?endpoint=test-endpoint&defaultRegion=test-region', + modelFormatName: 'test model format', + modelFormatVersion: 'test version 1', }); diff --git a/frontend/src/__mocks__/mockSecretK8sResource.ts b/frontend/src/__mocks__/mockSecretK8sResource.ts index 3a787ee31c..90baaae2a5 100644 --- a/frontend/src/__mocks__/mockSecretK8sResource.ts +++ b/frontend/src/__mocks__/mockSecretK8sResource.ts @@ -6,6 +6,8 @@ type MockResourceConfigType = { namespace?: string; displayName?: string; s3Bucket?: string; + endPoint?: string; + region?: string; uid?: string; }; @@ -13,7 +15,9 @@ export const mockSecretK8sResource = ({ name = 'test-secret', namespace = 'test-project', displayName = 'Test Secret', - s3Bucket = 'test-bucket', + s3Bucket = 'dGVzdC1idWNrZXQ=', + endPoint = 'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tLw==', + region = 'dXMtZWFzdC0x', uid = genUID('secret'), }: MockResourceConfigType): SecretKind => ({ kind: 'Secret', @@ -35,9 +39,9 @@ export const mockSecretK8sResource = ({ }, data: { AWS_ACCESS_KEY_ID: 'c2RzZA==', - AWS_DEFAULT_REGION: 'dXMtZWFzdC0x', + AWS_DEFAULT_REGION: region, AWS_S3_BUCKET: s3Bucket, - AWS_S3_ENDPOINT: 'aHR0cHM6Ly9zMy5hbWF6b25hd3MuY29tLw==', + AWS_S3_ENDPOINT: endPoint, AWS_SECRET_ACCESS_KEY: 'c2RzZA==', }, type: 'Opaque', diff --git a/frontend/src/__tests__/cypress/cypress/pages/modelRegistry/modelVersionDeployModal.ts b/frontend/src/__tests__/cypress/cypress/pages/modelRegistry/modelVersionDeployModal.ts new file mode 100644 index 0000000000..383ab38c3c --- /dev/null +++ b/frontend/src/__tests__/cypress/cypress/pages/modelRegistry/modelVersionDeployModal.ts @@ -0,0 +1,17 @@ +import { Modal } from '~/__tests__/cypress/cypress/pages/components/Modal'; + +class ModelVersionDeployModal extends Modal { + constructor() { + super('Deploy model'); + } + + findProjectSelector() { + return cy.findByTestId('deploy-model-project-selector'); + } + + selectProjectByName(name: string) { + this.findProjectSelector().findDropdownItem(name).click(); + } +} + +export const modelVersionDeployModal = new ModelVersionDeployModal(); diff --git a/frontend/src/__tests__/cypress/cypress/pages/modelServing.ts b/frontend/src/__tests__/cypress/cypress/pages/modelServing.ts index 18bd61e95a..3279f246e8 100644 --- a/frontend/src/__tests__/cypress/cypress/pages/modelServing.ts +++ b/frontend/src/__tests__/cypress/cypress/pages/modelServing.ts @@ -141,6 +141,10 @@ class InferenceServiceModal extends Modal { return this.find().findByTestId('field AWS_S3_BUCKET'); } + findLocationRegionInput() { + return this.find().findByTestId('field AWS_DEFAULT_REGION'); + } + findLocationPathInput() { return this.find().findByTestId('folder-path'); } diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDeploy.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDeploy.cy.ts new file mode 100644 index 0000000000..09eaf92631 --- /dev/null +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDeploy.cy.ts @@ -0,0 +1,326 @@ +/* eslint-disable camelcase */ +import { + mockDscStatus, + mockK8sResourceList, + mockProjectK8sResource, + mockSecretK8sResource, +} from '~/__mocks__'; +import { mockDashboardConfig } from '~/__mocks__/mockDashboardConfig'; +import { mockRegisteredModelList } from '~/__mocks__/mockRegisteredModelsList'; +import { + ProjectModel, + SecretModel, + ServiceModel, + ServingRuntimeModel, + TemplateModel, +} from '~/__tests__/cypress/cypress/utils/models'; +import { mockModelVersionList } from '~/__mocks__/mockModelVersionList'; +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import type { ModelVersion } from '~/concepts/modelRegistry/types'; +import { ModelState } from '~/concepts/modelRegistry/types'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; +import { modelRegistry } from '~/__tests__/cypress/cypress/pages/modelRegistry'; +import { mockModelRegistryService } from '~/__mocks__/mockModelRegistryService'; +import { modelVersionDeployModal } from '~/__tests__/cypress/cypress/pages/modelRegistry/modelVersionDeployModal'; +import { mockModelArtifactList } from '~/__mocks__/mockModelArtifactList'; +import { + mockInvalidTemplateK8sResource, + mockServingRuntimeTemplateK8sResource, +} from '~/__mocks__/mockServingRuntimeTemplateK8sResource'; +import { ServingRuntimePlatform } from '~/types'; +import { kserveModal } from '~/__tests__/cypress/cypress/pages/modelServing'; +import { mockModelArtifact } from '~/__mocks__/mockModelArtifact'; + +const MODEL_REGISTRY_API_VERSION = 'v1alpha3'; + +type HandlersProps = { + registeredModelsSize?: number; + modelVersions?: ModelVersion[]; + modelMeshInstalled?: boolean; + kServeInstalled?: boolean; +}; + +const registeredModelMocked = mockRegisteredModel({ name: 'test-1' }); +const modelVersionMocked = mockModelVersion({ + id: '1', + name: 'test model version', + state: ModelState.LIVE, +}); +const modelArtifactMocked = mockModelArtifact(); + +const initIntercepts = ({ + registeredModelsSize = 4, + modelVersions = [mockModelVersion({ id: '1', name: 'test model version' })], + modelMeshInstalled = true, + kServeInstalled = true, +}: HandlersProps) => { + cy.interceptOdh( + 'GET /api/config', + mockDashboardConfig({ + disableModelRegistry: false, + }), + ); + cy.interceptOdh( + 'GET /api/dsc/status', + mockDscStatus({ + installedComponents: { + kserve: kServeInstalled, + 'model-mesh': modelMeshInstalled, + 'model-registry-operator': true, + }, + }), + ); + + cy.interceptK8sList( + ServiceModel, + mockK8sResourceList([mockModelRegistryService({ name: 'modelregistry-sample' })]), + ); + + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models', + { path: { serviceName: 'modelregistry-sample', apiVersion: MODEL_REGISTRY_API_VERSION } }, + mockRegisteredModelList({ size: registeredModelsSize }), + ); + + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockModelVersionList({ + items: modelVersions, + }), + ); + + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + registeredModelMocked, + ); + + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 1, + }, + }, + modelVersionMocked, + ); + + cy.interceptK8sList( + ProjectModel, + mockK8sResourceList([ + mockProjectK8sResource({ + enableModelMesh: true, + k8sName: 'model-mesh-project', + displayName: 'Model mesh project', + }), + mockProjectK8sResource({ + enableModelMesh: false, + k8sName: 'kserve-project', + displayName: 'KServe project', + }), + mockProjectK8sResource({ k8sName: 'test-project', displayName: 'Test project' }), + ]), + ); + + cy.interceptOdh( + `GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId/artifacts`, + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 1, + }, + }, + mockModelArtifactList(), + ); + + cy.interceptK8sList( + TemplateModel, + mockK8sResourceList( + [ + mockServingRuntimeTemplateK8sResource({ + name: 'template-1', + displayName: 'Multi Platform', + platforms: [ServingRuntimePlatform.SINGLE, ServingRuntimePlatform.MULTI], + }), + mockServingRuntimeTemplateK8sResource({ + name: 'template-2', + displayName: 'Caikit', + platforms: [ServingRuntimePlatform.SINGLE], + }), + mockServingRuntimeTemplateK8sResource({ + name: 'template-3', + displayName: 'New OVMS Server', + platforms: [ServingRuntimePlatform.MULTI], + }), + mockServingRuntimeTemplateK8sResource({ + name: 'template-4', + displayName: 'Serving Runtime with No Annotations', + }), + mockInvalidTemplateK8sResource({}), + ], + { namespace: 'opendatahub' }, + ), + ); +}; + +describe('Deploy model version', () => { + it('Deploy model version on unsupported platform', () => { + initIntercepts({ kServeInstalled: false, modelMeshInstalled: false }); + cy.visit(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + const modelVersionRow = modelRegistry.getModelVersionRow('test model version'); + modelVersionRow.findKebabAction('Deploy').click(); + modelVersionDeployModal.selectProjectByName('Model mesh project'); + cy.findByText('Multi-model platform is not installed').should('exist'); + modelVersionDeployModal.selectProjectByName('KServe project'); + cy.findByText('Single-model platform is not installed').should('exist'); + }); + + it('Deploy model version on a project which platform is not selected', () => { + initIntercepts({}); + cy.visit(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + const modelVersionRow = modelRegistry.getModelVersionRow('test model version'); + modelVersionRow.findKebabAction('Deploy').click(); + modelVersionDeployModal.selectProjectByName('Test project'); + cy.findByText('Cannot deploy the model until you select a model serving platform').should( + 'exist', + ); + }); + + it('Deploy model version on a model mesh project that has no model servers', () => { + initIntercepts({}); + cy.visit(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + const modelVersionRow = modelRegistry.getModelVersionRow('test model version'); + modelVersionRow.findKebabAction('Deploy').click(); + cy.interceptK8sList(ServingRuntimeModel, mockK8sResourceList([])); + modelVersionDeployModal.selectProjectByName('Model mesh project'); + cy.findByText('Cannot deploy the model until you configure a model server').should('exist'); + }); + + it('Pre-fill deployment information on KServe modal', () => { + initIntercepts({}); + cy.interceptK8sList( + SecretModel, + mockK8sResourceList([ + mockSecretK8sResource({ + name: 'test-secret-not-match', + displayName: 'Test Secret Not Match', + namespace: 'kserve-project', + s3Bucket: 'dGVzdC1idWNrZXQ=', + endPoint: 'dGVzdC1lbmRwb2ludC1ub3QtbWF0Y2g=', // endpoint not match + region: 'dGVzdC1yZWdpb24=', + }), + ]), + ); + cy.visit(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + const modelVersionRow = modelRegistry.getModelVersionRow('test model version'); + modelVersionRow.findKebabAction('Deploy').click(); + modelVersionDeployModal.selectProjectByName('KServe project'); + + // Validate name input field + kserveModal + .findModelNameInput() + .should('contain.value', `${registeredModelMocked.name} - ${modelVersionMocked.name} - `); + + // Validate model framework section + kserveModal.findModelFrameworkSelect().should('be.disabled'); + cy.findByText('The source model format is').should('not.exist'); + kserveModal.findServingRuntimeTemplateDropdown().findDropdownItem('Multi Platform').click(); + kserveModal.findModelFrameworkSelect().should('be.enabled'); + cy.findByText( + `The source model format is ${modelArtifactMocked.modelFormatName} - ${modelArtifactMocked.modelFormatVersion}`, + ).should('exist'); + + // Validate data connection section + cy.findByText( + "We've auto-switched to create a new data connection and pre-filled the details for you.", + ).should('exist'); + kserveModal.findNewDataConnectionOption().should('be.checked'); + kserveModal.findLocationNameInput().should('have.value', modelArtifactMocked.storageKey); + kserveModal.findLocationBucketInput().should('have.value', 'test-bucket'); + kserveModal.findLocationRegionInput().should('have.value', 'test-region'); + kserveModal.findLocationEndpointInput().should('have.value', 'test-endpoint'); + kserveModal.findLocationPathInput().should('have.value', 'demo-models/test-path'); + }); + + it('One match data connection on KServe modal', () => { + initIntercepts({}); + cy.interceptK8sList( + SecretModel, + mockK8sResourceList([ + mockSecretK8sResource({ + namespace: 'kserve-project', + s3Bucket: 'dGVzdC1idWNrZXQ=', + endPoint: 'dGVzdC1lbmRwb2ludA==', + region: 'dGVzdC1yZWdpb24=', + }), + mockSecretK8sResource({ + name: 'test-secret-not-match', + displayName: 'Test Secret Not Match', + namespace: 'kserve-project', + s3Bucket: 'dGVzdC1idWNrZXQ=', + endPoint: 'dGVzdC1lbmRwb2ludC1ub3QtbWF0Y2g=', // endpoint not match + region: 'dGVzdC1yZWdpb24=', + }), + ]), + ); + + cy.visit(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + const modelVersionRow = modelRegistry.getModelVersionRow('test model version'); + modelVersionRow.findKebabAction('Deploy').click(); + modelVersionDeployModal.selectProjectByName('KServe project'); + + // Validate data connection section + kserveModal.findExistingDataConnectionOption().should('be.checked'); + kserveModal.findExistingConnectionSelect().should('contain.text', 'Test Secret'); + kserveModal.findLocationPathInput().should('have.value', 'demo-models/test-path'); + }); + + it('More than one match data connections on KServe modal', () => { + initIntercepts({}); + cy.interceptK8sList( + SecretModel, + mockK8sResourceList([ + mockSecretK8sResource({ + namespace: 'kserve-project', + s3Bucket: 'dGVzdC1idWNrZXQ=', + endPoint: 'dGVzdC1lbmRwb2ludA==', + region: 'dGVzdC1yZWdpb24=', + }), + mockSecretK8sResource({ + name: 'test-secret-2', + displayName: 'Test Secret 2', + namespace: 'kserve-project', + s3Bucket: 'dGVzdC1idWNrZXQ=', + endPoint: 'dGVzdC1lbmRwb2ludA==', + region: 'dGVzdC1yZWdpb24=', + }), + ]), + ); + + cy.visit(`/modelRegistry/modelregistry-sample/registeredModels/1/versions`); + const modelVersionRow = modelRegistry.getModelVersionRow('test model version'); + modelVersionRow.findKebabAction('Deploy').click(); + modelVersionDeployModal.selectProjectByName('KServe project'); + + // Validate data connection section + kserveModal.findExistingDataConnectionOption().should('be.checked'); + kserveModal.findExistingConnectionSelect().should('contain.text', 'Select...'); + kserveModal.findLocationPathInput().should('have.value', 'demo-models/test-path'); + }); +}); diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDetails.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDetails.cy.ts index ac956b5520..ff3057cb45 100644 --- a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDetails.cy.ts +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDetails.cy.ts @@ -143,7 +143,7 @@ describe('Model version details', () => { 'Label y', 'Label z', ]); - modelVersionDetails.findStorageLocation().contains('https://huggingface.io/mnist.onnx'); + modelVersionDetails.findStorageLocation().contains('s3://test-bucket/demo-models/test-path'); }); it('Switching model versions', () => { diff --git a/frontend/src/concepts/modelRegistry/context/ModelRegistryContext.tsx b/frontend/src/concepts/modelRegistry/context/ModelRegistryContext.tsx index 92d045a96b..40b42a1e84 100644 --- a/frontend/src/concepts/modelRegistry/context/ModelRegistryContext.tsx +++ b/frontend/src/concepts/modelRegistry/context/ModelRegistryContext.tsx @@ -1,10 +1,21 @@ import * as React from 'react'; import { SupportedArea, conditionalArea } from '~/concepts/areas'; +import { useTemplates } from '~/api'; +import { useDashboardNamespace } from '~/redux/selectors'; +import { useContextResourceData } from '~/utilities/useContextResourceData'; +import useTemplateOrder from '~/pages/modelServing/customServingRuntimes/useTemplateOrder'; +import { ContextResourceData, CustomWatchK8sResult } from '~/types'; +import { TemplateKind } from '~/k8sTypes'; +import { DEFAULT_CONTEXT_DATA, DEFAULT_LIST_WATCH_RESULT } from '~/utilities/const'; +import useTemplateDisablement from '~/pages/modelServing/customServingRuntimes/useTemplateDisablement'; import useModelRegistryAPIState, { ModelRegistryAPIState } from './useModelRegistryAPIState'; export type ModelRegistryContextType = { apiState: ModelRegistryAPIState; refreshAPIState: () => void; + servingRuntimeTemplates: CustomWatchK8sResult; + servingRuntimeTemplateOrder: ContextResourceData; + servingRuntimeTemplateDisablement: ContextResourceData; }; type ModelRegistryContextProviderProps = { @@ -16,12 +27,25 @@ export const ModelRegistryContext = React.createContext undefined, + servingRuntimeTemplates: DEFAULT_LIST_WATCH_RESULT, + servingRuntimeTemplateOrder: DEFAULT_CONTEXT_DATA, + servingRuntimeTemplateDisablement: DEFAULT_CONTEXT_DATA, }); export const ModelRegistryContextProvider = conditionalArea( SupportedArea.MODEL_REGISTRY, true, )(({ children, modelRegistryName }) => { + const { dashboardNamespace } = useDashboardNamespace(); + + const servingRuntimeTemplates = useTemplates(dashboardNamespace); + const servingRuntimeTemplateOrder = useContextResourceData( + useTemplateOrder(dashboardNamespace), + ); + const servingRuntimeTemplateDisablement = useContextResourceData( + useTemplateDisablement(dashboardNamespace), + ); + const hostPath = modelRegistryName ? `/api/service/modelregistry/${modelRegistryName}` : null; const [apiState, refreshAPIState] = useModelRegistryAPIState(hostPath); @@ -31,6 +55,9 @@ export const ModelRegistryContextProvider = conditionalArea {children} diff --git a/frontend/src/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsHeaderActions.tsx b/frontend/src/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsHeaderActions.tsx index 98a0216e98..59f01cf536 100644 --- a/frontend/src/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsHeaderActions.tsx +++ b/frontend/src/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsHeaderActions.tsx @@ -7,6 +7,7 @@ import { ModelVersion, ModelState } from '~/concepts/modelRegistry/types'; import { getPatchBodyForModelVersion } from '~/pages/modelRegistry/screens/utils'; import { ModelRegistrySelectorContext } from '~/concepts/modelRegistry/context/ModelRegistrySelectorContext'; import { modelVersionArchiveDetailsUrl } from '~/pages/modelRegistry/screens/routeUtils'; +import DeployRegisteredModelModal from '~/pages/modelRegistry/screens/components/DeployRegisteredModelModal'; interface ModelVersionsDetailsHeaderActionsProps { mv: ModelVersion; @@ -21,6 +22,7 @@ const ModelVersionsDetailsHeaderActions: React.FC(null); return ( @@ -48,9 +50,8 @@ const ModelVersionsDetailsHeaderActions: React.FC undefined} + onClick={() => setIsDeployModalOpen(true)} ref={tooltipRef} - isDisabled // TODO This feature is currently disabled but will be enabled in a future PR post-summit release. > Deploy @@ -65,6 +66,11 @@ const ModelVersionsDetailsHeaderActions: React.FC + setIsDeployModalOpen(false)} + isOpen={isDeployModalOpen} + modelVersion={mv} + /> setIsArchiveModalOpen(false)} onSubmit={() => diff --git a/frontend/src/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx b/frontend/src/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx index 3271fcb940..35401b761c 100644 --- a/frontend/src/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx +++ b/frontend/src/pages/modelRegistry/screens/ModelVersions/ModelVersionsTableRow.tsx @@ -14,6 +14,7 @@ import { ArchiveModelVersionModal } from '~/pages/modelRegistry/screens/componen import { ModelRegistryContext } from '~/concepts/modelRegistry/context/ModelRegistryContext'; import { getPatchBodyForModelVersion } from '~/pages/modelRegistry/screens/utils'; import { RestoreModelVersionModal } from '~/pages/modelRegistry/screens/components/RestoreModelVersionModal'; +import DeployRegisteredModelModal from '~/pages/modelRegistry/screens/components/DeployRegisteredModelModal'; type ModelVersionsTableRowProps = { modelVersion: ModelVersion; @@ -30,6 +31,7 @@ const ModelVersionsTableRow: React.FC = ({ const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); const [isArchiveModalOpen, setIsArchiveModalOpen] = React.useState(false); const [isRestoreModalOpen, setIsRestoreModalOpen] = React.useState(false); + const [isDeployModalOpen, setIsDeployModalOpen] = React.useState(false); const { apiState } = React.useContext(ModelRegistryContext); const actions = isArchiveRow @@ -42,8 +44,7 @@ const ModelVersionsTableRow: React.FC = ({ : [ { title: 'Deploy', - // TODO: Implement functionality for onClick. This will be added in another PR - onClick: () => undefined, + onClick: () => setIsDeployModalOpen(true), }, { title: 'Archive version', @@ -105,6 +106,11 @@ const ModelVersionsTableRow: React.FC = ({ isOpen={isArchiveModalOpen} modelVersionName={mv.name} /> + setIsDeployModalOpen(false)} + isOpen={isDeployModalOpen} + modelVersion={mv} + /> setIsRestoreModalOpen(false)} onSubmit={() => diff --git a/frontend/src/pages/modelRegistry/screens/RegisteredModels/RegisteredModelTableRow.tsx b/frontend/src/pages/modelRegistry/screens/RegisteredModels/RegisteredModelTableRow.tsx index 28d6bca4c4..8dde378d8a 100644 --- a/frontend/src/pages/modelRegistry/screens/RegisteredModels/RegisteredModelTableRow.tsx +++ b/frontend/src/pages/modelRegistry/screens/RegisteredModels/RegisteredModelTableRow.tsx @@ -42,12 +42,6 @@ const RegisteredModelTableRow: React.FC = ({ }, ] : [ - { - title: 'Deploy', - isDisabled: true, - // TODO: Implement functionality for onClick. This will be added in another PR - onClick: () => undefined, - }, { title: 'Archive model', onClick: () => setIsArchiveModalOpen(true), diff --git a/frontend/src/pages/modelRegistry/screens/RegisteredModels/useLabeledDataConnections.ts b/frontend/src/pages/modelRegistry/screens/RegisteredModels/useLabeledDataConnections.ts new file mode 100644 index 0000000000..e03e7ce376 --- /dev/null +++ b/frontend/src/pages/modelRegistry/screens/RegisteredModels/useLabeledDataConnections.ts @@ -0,0 +1,46 @@ +import React from 'react'; +import { ObjectStorageFields, uriToObjectStorageFields } from '~/concepts/modelRegistry/utils'; +import { LabeledDataConnection } from '~/pages/modelServing/screens/types'; +import { AwsKeys } from '~/pages/projects/dataConnections/const'; +import { convertAWSSecretData } from '~/pages/projects/screens/detail/data-connections/utils'; +import { DataConnection } from '~/pages/projects/types'; + +const useLabeledDataConnections = ( + modelArtifactUri: string | undefined, + dataConnections: DataConnection[] = [], +): { + dataConnections: LabeledDataConnection[]; + storageFields: ObjectStorageFields | null; +} => + React.useMemo(() => { + if (!modelArtifactUri) { + return { + dataConnections: dataConnections.map((dataConnection) => ({ dataConnection })), + storageFields: null, + }; + } + const storageFields = uriToObjectStorageFields(modelArtifactUri); + if (!storageFields) { + return { + dataConnections: dataConnections.map((dataConnection) => ({ dataConnection })), + storageFields, + }; + } + const labeledDataConnections = dataConnections.map((dataConnection) => { + const awsData = convertAWSSecretData(dataConnection); + const bucket = awsData.find((data) => data.key === AwsKeys.AWS_S3_BUCKET)?.value; + const endpoint = awsData.find((data) => data.key === AwsKeys.S3_ENDPOINT)?.value; + const region = awsData.find((data) => data.key === AwsKeys.DEFAULT_REGION)?.value; + if ( + bucket === storageFields.bucket && + endpoint === storageFields.endpoint && + (region === storageFields.region || !storageFields.region) + ) { + return { dataConnection, isRecommended: true }; + } + return { dataConnection }; + }); + return { dataConnections: labeledDataConnections, storageFields }; + }, [dataConnections, modelArtifactUri]); + +export default useLabeledDataConnections; diff --git a/frontend/src/pages/modelRegistry/screens/RegisteredModels/usePrefillDeployModalFromModelRegistry.ts b/frontend/src/pages/modelRegistry/screens/RegisteredModels/usePrefillDeployModalFromModelRegistry.ts new file mode 100644 index 0000000000..22b6e757ea --- /dev/null +++ b/frontend/src/pages/modelRegistry/screens/RegisteredModels/usePrefillDeployModalFromModelRegistry.ts @@ -0,0 +1,87 @@ +import { AlertVariant } from '@patternfly/react-core'; +import React from 'react'; +import { ProjectKind } from '~/k8sTypes'; +import useLabeledDataConnections from '~/pages/modelRegistry/screens/RegisteredModels/useLabeledDataConnections'; +import { RegisteredModelDeployInfo } from '~/pages/modelRegistry/screens/RegisteredModels/useRegisteredModelDeployInfo'; +import { + CreatingInferenceServiceObject, + InferenceServiceStorageType, + LabeledDataConnection, +} from '~/pages/modelServing/screens/types'; +import { AwsKeys, EMPTY_AWS_SECRET_DATA } from '~/pages/projects/dataConnections/const'; +import useDataConnections from '~/pages/projects/screens/detail/data-connections/useDataConnections'; +import { DataConnection, UpdateObjectAtPropAndValue } from '~/pages/projects/types'; + +const usePrefillDeployModalFromModelRegistry = ( + projectContext: { currentProject: ProjectKind; dataConnections: DataConnection[] } | undefined, + createData: CreatingInferenceServiceObject, + setCreateData: UpdateObjectAtPropAndValue, + registeredModelDeployInfo?: RegisteredModelDeployInfo, +): [LabeledDataConnection[], boolean, Error | undefined] => { + const [fetchedDataConnections, dataConnectionsLoaded, dataConnectionsLoadError] = + useDataConnections(projectContext ? undefined : createData.project); + const allDataConnections = projectContext?.dataConnections || fetchedDataConnections; + const { dataConnections, storageFields } = useLabeledDataConnections( + registeredModelDeployInfo?.modelArtifactUri, + allDataConnections, + ); + + React.useEffect(() => { + if (registeredModelDeployInfo) { + setCreateData('name', registeredModelDeployInfo.modelName); + const recommendedDataConnections = dataConnections.filter( + (dataConnection) => dataConnection.isRecommended, + ); + + if (!storageFields) { + setCreateData('storage', { + awsData: EMPTY_AWS_SECRET_DATA, + dataConnection: '', + path: '', + type: InferenceServiceStorageType.EXISTING_STORAGE, + }); + } else { + const prefilledAWSData = [ + { key: AwsKeys.NAME, value: registeredModelDeployInfo.modelArtifactStorageKey || '' }, + { key: AwsKeys.AWS_S3_BUCKET, value: storageFields.bucket }, + { key: AwsKeys.S3_ENDPOINT, value: storageFields.endpoint }, + { key: AwsKeys.DEFAULT_REGION, value: storageFields.region || '' }, + ...EMPTY_AWS_SECRET_DATA, + ]; + if (recommendedDataConnections.length === 0) { + setCreateData('storage', { + awsData: prefilledAWSData, + dataConnection: '', + path: storageFields.path, + type: InferenceServiceStorageType.NEW_STORAGE, + alert: { + type: AlertVariant.info, + title: + "We've auto-switched to create a new data connection and pre-filled the details for you.", + message: + 'Model location info is available in the registry but no matching data connection in the project. So we automatically switched the option to create a new data connection and prefilled the information.', + }, + }); + } else if (recommendedDataConnections.length === 1) { + setCreateData('storage', { + awsData: prefilledAWSData, + dataConnection: recommendedDataConnections[0].dataConnection.data.metadata.name, + path: storageFields.path, + type: InferenceServiceStorageType.EXISTING_STORAGE, + }); + } else { + setCreateData('storage', { + awsData: prefilledAWSData, + dataConnection: '', + path: storageFields.path, + type: InferenceServiceStorageType.EXISTING_STORAGE, + }); + } + } + } + }, [dataConnections, storageFields, registeredModelDeployInfo, setCreateData]); + + return [dataConnections, dataConnectionsLoaded, dataConnectionsLoadError]; +}; + +export default usePrefillDeployModalFromModelRegistry; diff --git a/frontend/src/pages/modelRegistry/screens/RegisteredModels/useProjectErrorForRegisteredModel.ts b/frontend/src/pages/modelRegistry/screens/RegisteredModels/useProjectErrorForRegisteredModel.ts new file mode 100644 index 0000000000..c43e4ba66a --- /dev/null +++ b/frontend/src/pages/modelRegistry/screens/RegisteredModels/useProjectErrorForRegisteredModel.ts @@ -0,0 +1,34 @@ +import useServingRuntimes from '~/pages/modelServing/useServingRuntimes'; +import { ServingRuntimePlatform } from '~/types'; + +const useProjectErrorForRegisteredModel = ( + projectName?: string, + platform?: ServingRuntimePlatform, +): Error | undefined => { + const [servingRuntimes, loaded, loadError] = useServingRuntimes(projectName); + + // If project is not selected, there is no error + if (!projectName) { + return undefined; + } + + // If the platform is not selected + if (!platform) { + return new Error('Cannot deploy the model until you select a model serving platform'); + } + + if (loadError) { + return loadError; + } + + // If the platform is MULTI but it doesn't have a server + if (platform === ServingRuntimePlatform.MULTI) { + if (loaded && servingRuntimes.length === 0) { + return new Error('Cannot deploy the model until you configure a model server'); + } + } + + return undefined; +}; + +export default useProjectErrorForRegisteredModel; diff --git a/frontend/src/pages/modelRegistry/screens/RegisteredModels/useRegisteredModelDeployInfo.ts b/frontend/src/pages/modelRegistry/screens/RegisteredModels/useRegisteredModelDeployInfo.ts new file mode 100644 index 0000000000..77d75f0ab7 --- /dev/null +++ b/frontend/src/pages/modelRegistry/screens/RegisteredModels/useRegisteredModelDeployInfo.ts @@ -0,0 +1,65 @@ +import React from 'react'; +import useModelArtifactsByVersionId from '~/concepts/modelRegistry/apiHooks/useModelArtifactsByVersionId'; +import useRegisteredModelById from '~/concepts/modelRegistry/apiHooks/useRegisteredModelById'; +import { ModelVersion } from '~/concepts/modelRegistry/types'; + +export type RegisteredModelDeployInfo = { + modelName: string; + modelFormat?: string; + modelArtifactUri?: string; + modelArtifactStorageKey?: string; +}; + +const useRegisteredModelDeployInfo = ( + modelVersion: ModelVersion, +): { + registeredModelDeployInfo: RegisteredModelDeployInfo; + loaded: boolean; + error: Error | undefined; +} => { + const [registeredModel, registeredModelLoaded, registeredModelError] = useRegisteredModelById( + modelVersion.registeredModelId, + ); + const [modelArtifactList, modelArtifactListLoaded, modelArtifactListError] = + useModelArtifactsByVersionId(modelVersion.id); + + const registeredModelDeployInfo = React.useMemo(() => { + const dateString = new Date().toISOString(); + const modelName = `${registeredModel?.name} - ${modelVersion.name} - ${dateString}`; + if (modelArtifactList.size === 0) { + return { + registeredModelDeployInfo: { + modelName, + }, + loaded: registeredModelLoaded && modelArtifactListLoaded, + error: registeredModelError || modelArtifactListError, + }; + } + const modelArtifact = modelArtifactList.items[0]; + return { + registeredModelDeployInfo: { + modelName, + modelFormat: modelArtifact.modelFormatName + ? `${modelArtifact.modelFormatName} - ${modelArtifact.modelFormatVersion}` + : undefined, + modelArtifactUri: modelArtifact.uri, + modelArtifactStorageKey: modelArtifact.storageKey, + }, + loaded: registeredModelLoaded && modelArtifactListLoaded, + error: registeredModelError || modelArtifactListError, + }; + }, [ + modelArtifactList.items, + modelArtifactList.size, + modelArtifactListError, + modelArtifactListLoaded, + modelVersion.name, + registeredModel?.name, + registeredModelError, + registeredModelLoaded, + ]); + + return registeredModelDeployInfo; +}; + +export default useRegisteredModelDeployInfo; diff --git a/frontend/src/pages/modelRegistry/screens/__tests__/useProjectErrorForRegisteredModel.spec.ts b/frontend/src/pages/modelRegistry/screens/__tests__/useProjectErrorForRegisteredModel.spec.ts new file mode 100644 index 0000000000..2c928fb2e4 --- /dev/null +++ b/frontend/src/pages/modelRegistry/screens/__tests__/useProjectErrorForRegisteredModel.spec.ts @@ -0,0 +1,97 @@ +import { k8sListResource } from '@openshift/dynamic-plugin-sdk-utils'; +import { mockDashboardConfig, mockK8sResourceList } from '~/__mocks__'; +import { mockServingRuntimeK8sResource } from '~/__mocks__/mockServingRuntimeK8sResource'; +import { testHook } from '~/__tests__/unit/testUtils/hooks'; +import { useAccessReview } from '~/api'; +import { useAppContext } from '~/app/AppContext'; +import { ServingRuntimeKind } from '~/k8sTypes'; +import useProjectErrorForRegisteredModel from '~/pages/modelRegistry/screens/RegisteredModels/useProjectErrorForRegisteredModel'; +import { ServingRuntimePlatform } from '~/types'; + +jest.mock('@openshift/dynamic-plugin-sdk-utils', () => ({ + k8sListResource: jest.fn(), +})); + +// Mock the API functions +jest.mock('~/api', () => ({ + ...jest.requireActual('~/api'), + useAccessReview: jest.fn(), +})); + +jest.mock('~/pages/modelServing/useServingPlatformStatuses', () => ({ + __esModule: true, + default: jest.fn(), +})); + +jest.mock('~/app/AppContext', () => ({ + __esModule: true, + useAppContext: jest.fn(), +})); + +const useAppContextMock = jest.mocked(useAppContext); +const k8sListResourceMock = jest.mocked(k8sListResource); +const useAccessReviewMock = jest.mocked(useAccessReview); + +describe('useProjectErrorForRegisteredModel', () => { + beforeEach(() => { + useAppContextMock.mockReturnValue({ + buildStatuses: [], + dashboardConfig: mockDashboardConfig({}), + storageClasses: [], + isRHOAI: false, + }); + useAccessReviewMock.mockReturnValue([true, true]); + }); + it('should return undefined when the project is not selected', async () => { + k8sListResourceMock.mockResolvedValue(mockK8sResourceList([])); + const renderResult = testHook(useProjectErrorForRegisteredModel)(undefined, undefined); + // wait for update + await renderResult.waitForNextUpdate(); + expect(renderResult).hookToStrictEqual(undefined); + }); + + it('should return undefined when only kServe is supported', async () => { + k8sListResourceMock.mockResolvedValue(mockK8sResourceList([])); + const renderResult = testHook(useProjectErrorForRegisteredModel)( + 'test-project', + ServingRuntimePlatform.SINGLE, + ); + // wait for update + await renderResult.waitForNextUpdate(); + expect(renderResult).hookToStrictEqual(undefined); + }); + + it('should return undefined when only modelMesh is supported with server deployed', async () => { + k8sListResourceMock.mockResolvedValue(mockK8sResourceList([mockServingRuntimeK8sResource({})])); + const renderResult = testHook(useProjectErrorForRegisteredModel)( + 'test-project', + ServingRuntimePlatform.MULTI, + ); + // wait for update + await renderResult.waitForNextUpdate(); + expect(renderResult).hookToStrictEqual(undefined); + }); + + it('should return error when only modelMesh is supported with no server deployed', async () => { + k8sListResourceMock.mockResolvedValue(mockK8sResourceList([])); + const renderResult = testHook(useProjectErrorForRegisteredModel)( + 'test-project', + ServingRuntimePlatform.MULTI, + ); + // wait for update + await renderResult.waitForNextUpdate(); + expect(renderResult).hookToStrictEqual( + new Error('Cannot deploy the model until you configure a model server'), + ); + }); + + it('should return error when platform is not selected', async () => { + k8sListResourceMock.mockResolvedValue(mockK8sResourceList([])); + const renderResult = testHook(useProjectErrorForRegisteredModel)('test-project', undefined); + // wait for update + await renderResult.waitForNextUpdate(); + expect(renderResult).hookToStrictEqual( + new Error('Cannot deploy the model until you select a model serving platform'), + ); + }); +}); diff --git a/frontend/src/pages/modelRegistry/screens/components/DeployRegisteredModelModal.tsx b/frontend/src/pages/modelRegistry/screens/components/DeployRegisteredModelModal.tsx new file mode 100644 index 0000000000..0fe3e70dce --- /dev/null +++ b/frontend/src/pages/modelRegistry/screens/components/DeployRegisteredModelModal.tsx @@ -0,0 +1,127 @@ +import React from 'react'; +import { Alert, Button, Form, Modal, Spinner } from '@patternfly/react-core'; +import { ModelVersion } from '~/concepts/modelRegistry/types'; +import { ProjectKind } from '~/k8sTypes'; +import useProjectErrorForRegisteredModel from '~/pages/modelRegistry/screens/RegisteredModels/useProjectErrorForRegisteredModel'; +import ProjectSelector from '~/pages/modelServing/screens/projects/InferenceServiceModal/ProjectSelector'; +import ManageKServeModal from '~/pages/modelServing/screens/projects/kServeModal/ManageKServeModal'; +import useServingPlatformStatuses from '~/pages/modelServing/useServingPlatformStatuses'; +import { getProjectModelServingPlatform } from '~/pages/modelServing/screens/projects/utils'; +import { ServingRuntimePlatform } from '~/types'; +import ManageInferenceServiceModal from '~/pages/modelServing/screens/projects/InferenceServiceModal/ManageInferenceServiceModal'; +import useRegisteredModelDeployInfo from '~/pages/modelRegistry/screens/RegisteredModels/useRegisteredModelDeployInfo'; +import { ModelRegistryContext } from '~/concepts/modelRegistry/context/ModelRegistryContext'; +import { getKServeTemplates } from '~/pages/modelServing/customServingRuntimes/utils'; +import useDataConnections from '~/pages/projects/screens/detail/data-connections/useDataConnections'; + +interface DeployRegisteredModelModalProps { + onCancel: () => void; + isOpen: boolean; + modelVersion: ModelVersion; +} + +const DeployRegisteredModelModal: React.FC = ({ + isOpen, + onCancel, + modelVersion, +}) => { + const { + servingRuntimeTemplates: [templates], + servingRuntimeTemplateOrder: { data: templateOrder }, + servingRuntimeTemplateDisablement: { data: templateDisablement }, + } = React.useContext(ModelRegistryContext); + + const [selectedProject, setSelectedProject] = React.useState(null); + const servingPlatformStatuses = useServingPlatformStatuses(); + const { platform, error: platformError } = getProjectModelServingPlatform( + selectedProject, + servingPlatformStatuses, + ); + const projectError = useProjectErrorForRegisteredModel(selectedProject?.metadata.name, platform); + const [dataConnections] = useDataConnections(selectedProject?.metadata.name); + const error = platformError || projectError; + + const { + registeredModelDeployInfo, + loaded, + error: deployInfoError, + } = useRegisteredModelDeployInfo(modelVersion); + + const onClose = React.useCallback(() => { + setSelectedProject(null); + onCancel(); + }, [onCancel]); + + if (!selectedProject || !platform) { + return ( + + Deploy + , + , + ]} + showClose + > +
+ {deployInfoError ? ( + + {deployInfoError.message} + + ) : !loaded ? ( + + ) : ( + + )} + +
+ ); + } + + return ( + <> + + } + /> + + } + /> + + ); +}; + +export default DeployRegisteredModelModal; diff --git a/frontend/src/pages/modelServing/customServingRuntimes/utils.ts b/frontend/src/pages/modelServing/customServingRuntimes/utils.ts index 6e0c34f66c..4cc18a7d7d 100644 --- a/frontend/src/pages/modelServing/customServingRuntimes/utils.ts +++ b/frontend/src/pages/modelServing/customServingRuntimes/utils.ts @@ -164,3 +164,17 @@ export const getAPIProtocolFromServingRuntime = ( ) ?? undefined ); }; + +export const getKServeTemplates = ( + templates: TemplateKind[], + templateOrder: string[], + templateDisablement: string[], +): TemplateKind[] => { + const templatesSorted = getSortedTemplates(templates, templateOrder); + const templatesEnabled = templatesSorted.filter((template) => + getTemplateEnabled(template, templateDisablement), + ); + return templatesEnabled.filter((template) => + getTemplateEnabledForPlatform(template, ServingRuntimePlatform.SINGLE), + ); +}; diff --git a/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/DataConnectionExistingField.tsx b/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/DataConnectionExistingField.tsx index 40541368ad..5543d4a9b8 100644 --- a/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/DataConnectionExistingField.tsx +++ b/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/DataConnectionExistingField.tsx @@ -1,9 +1,21 @@ import * as React from 'react'; -import { Button, FormGroup, Popover, Stack, StackItem } from '@patternfly/react-core'; +import { + Button, + Flex, + FlexItem, + FormGroup, + Label, + Popover, + Stack, + StackItem, +} from '@patternfly/react-core'; import { Select, SelectOption } from '@patternfly/react-core/deprecated'; import { OutlinedQuestionCircleIcon } from '@patternfly/react-icons'; -import { DataConnection, UpdateObjectAtPropAndValue } from '~/pages/projects/types'; -import { CreatingInferenceServiceObject } from '~/pages/modelServing/screens/types'; +import { UpdateObjectAtPropAndValue } from '~/pages/projects/types'; +import { + CreatingInferenceServiceObject, + LabeledDataConnection, +} from '~/pages/modelServing/screens/types'; import { filterOutConnectionsWithoutBucket } from '~/pages/modelServing/screens/projects/utils'; import { getDataConnectionDisplayName } from '~/pages/projects/screens/detail/data-connections/utils'; import DataConnectionFolderPathField from './DataConnectionFolderPathField'; @@ -11,7 +23,7 @@ import DataConnectionFolderPathField from './DataConnectionFolderPathField'; type DataConnectionExistingFieldType = { data: CreatingInferenceServiceObject; setData: UpdateObjectAtPropAndValue; - dataConnections: DataConnection[]; + dataConnections: LabeledDataConnection[]; }; const DataConnectionExistingField: React.FC = ({ @@ -64,10 +76,19 @@ const DataConnectionExistingField: React.FC = ( > {connectionsWithoutBucket.map((connection) => ( - {getDataConnectionDisplayName(connection)} + + {getDataConnectionDisplayName(connection.dataConnection)} + {connection.isRecommended && ( + + + + )} + ))} diff --git a/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/DataConnectionSection.tsx b/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/DataConnectionSection.tsx index 3df55fbe00..3852958739 100644 --- a/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/DataConnectionSection.tsx +++ b/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/DataConnectionSection.tsx @@ -1,12 +1,12 @@ import * as React from 'react'; import { Alert, FormGroup, Radio, Skeleton, Stack, StackItem } from '@patternfly/react-core'; -import { DataConnection, UpdateObjectAtPropAndValue } from '~/pages/projects/types'; +import { UpdateObjectAtPropAndValue } from '~/pages/projects/types'; import { CreatingInferenceServiceObject, InferenceServiceStorageType, + LabeledDataConnection, } from '~/pages/modelServing/screens/types'; import AWSField from '~/pages/projects/dataConnections/AWSField'; -import useDataConnections from '~/pages/projects/screens/detail/data-connections/useDataConnections'; import { AwsKeys } from '~/pages/projects/dataConnections/const'; import DataConnectionExistingField from './DataConnectionExistingField'; import DataConnectionFolderPathField from './DataConnectionFolderPathField'; @@ -14,19 +14,18 @@ import DataConnectionFolderPathField from './DataConnectionFolderPathField'; type DataConnectionSectionType = { data: CreatingInferenceServiceObject; setData: UpdateObjectAtPropAndValue; - dataConnectionContext?: DataConnection[]; + loaded: boolean; + loadError: Error | undefined; + dataConnections: LabeledDataConnection[]; }; const DataConnectionSection: React.FC = ({ data, setData, - dataConnectionContext, + loaded, + loadError, + dataConnections, }) => { - const [dataContext, loaded, loadError] = useDataConnections( - dataConnectionContext ? undefined : data.project, - ); - const dataConnections = dataConnectionContext || dataContext; - if (loadError) { return ( @@ -53,7 +52,7 @@ const DataConnectionSection: React.FC = ({ } body={ data.storage.type === InferenceServiceStorageType.EXISTING_STORAGE && - (!dataConnectionContext && !loaded && data.project !== '' ? ( + (!loaded && data.project !== '' ? ( ) : ( = ({ label="New data connection" isChecked={data.storage.type === InferenceServiceStorageType.NEW_STORAGE} onChange={() => - setData('storage', { ...data.storage, type: InferenceServiceStorageType.NEW_STORAGE }) + setData('storage', { + ...data.storage, + type: InferenceServiceStorageType.NEW_STORAGE, + alert: undefined, + }) } body={ data.storage.type === InferenceServiceStorageType.NEW_STORAGE && ( + {data.storage.alert && ( + + + {data.storage.alert.message} + + + )} ; modelContext?: SupportedModelFormats[]; + registeredModelFormat?: string; }; const InferenceServiceFrameworkSection: React.FC = ({ data, setData, modelContext, + registeredModelFormat, }) => { const [isOpen, setOpen] = React.useState(false); @@ -74,6 +83,13 @@ const InferenceServiceFrameworkSection: React.FC; })} + {registeredModelFormat && models.length !== 0 && ( + + + The source model format is {registeredModelFormat} + + + )} ); }; diff --git a/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/ManageInferenceServiceModal.tsx b/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/ManageInferenceServiceModal.tsx index fb9a503e3d..d6cd447b01 100644 --- a/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/ManageInferenceServiceModal.tsx +++ b/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/ManageInferenceServiceModal.tsx @@ -13,6 +13,8 @@ import { isAWSValid } from '~/pages/projects/screens/spawner/spawnerUtils'; import { AwsKeys } from '~/pages/projects/dataConnections/const'; import { getDisplayNameFromK8sResource, translateDisplayNameForK8s } from '~/concepts/k8s/utils'; import { containsOnlySlashes, isS3PathValid } from '~/utilities/string'; +import { RegisteredModelDeployInfo } from '~/pages/modelRegistry/screens/RegisteredModels/useRegisteredModelDeployInfo'; +import usePrefillDeployModalFromModelRegistry from '~/pages/modelRegistry/screens/RegisteredModels/usePrefillDeployModalFromModelRegistry'; import DataConnectionSection from './DataConnectionSection'; import ProjectSection from './ProjectSection'; import InferenceServiceFrameworkSection from './InferenceServiceFrameworkSection'; @@ -22,6 +24,9 @@ import InferenceServiceNameSection from './InferenceServiceNameSection'; type ManageInferenceServiceModalProps = { isOpen: boolean; onClose: (submit: boolean) => void; + registeredModelDeployInfo?: RegisteredModelDeployInfo; + shouldFormHidden?: boolean; + projectSection?: React.ReactNode; } & EitherOrNone< { editInfo?: InferenceServiceKind }, { @@ -38,6 +43,9 @@ const ManageInferenceServiceModal: React.FC = onClose, editInfo, projectContext, + projectSection, + registeredModelDeployInfo, + shouldFormHidden, }) => { const [createData, setCreateData, resetData] = useCreateInferenceServiceObject(editInfo); const [actionInProgress, setActionInProgress] = React.useState(false); @@ -48,6 +56,14 @@ const ManageInferenceServiceModal: React.FC = const currentProjectName = projectContext?.currentProject.metadata.name || ''; const currentServingRuntimeName = projectContext?.currentServingRuntime?.metadata.name || ''; + const [dataConnections, dataConnectionsLoaded, dataConnectionsLoadError] = + usePrefillDeployModalFromModelRegistry( + projectContext, + createData, + setCreateData, + registeredModelDeployInfo, + ); + React.useEffect(() => { setCreateData('project', currentProjectName); setCreateData('servingRuntimeName', currentServingRuntimeName); @@ -121,45 +137,54 @@ const ManageInferenceServiceModal: React.FC =
- - - - - - - - - - - - - - - + )} + {!shouldFormHidden && ( + <> + + + + + + + + + + + + + + + + )}
diff --git a/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/ProjectSelector.tsx b/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/ProjectSelector.tsx new file mode 100644 index 0000000000..9c6e48a960 --- /dev/null +++ b/frontend/src/pages/modelServing/screens/projects/InferenceServiceModal/ProjectSelector.tsx @@ -0,0 +1,95 @@ +import React from 'react'; +import { + Alert, + FormGroup, + MenuToggle, + Select, + SelectList, + SelectOption, + Stack, + StackItem, +} from '@patternfly/react-core'; +import { Link } from 'react-router-dom'; +import { getDisplayNameFromK8sResource } from '~/concepts/k8s/utils'; +import { byName, ProjectsContext } from '~/concepts/projects/ProjectsContext'; +import { ProjectKind } from '~/k8sTypes'; +import { ProjectSectionID } from '~/pages/projects/screens/detail/types'; + +type ProjectSelectorProps = { + selectedProject: ProjectKind | null; + setSelectedProject: (project: ProjectKind | null) => void; + error?: Error; +}; + +const ProjectSelector: React.FC = ({ + selectedProject, + setSelectedProject, + error, +}) => { + const [isOpen, setOpen] = React.useState(false); + + const { projects } = React.useContext(ProjectsContext); + + return ( + + + + + + {error && selectedProject && ( + + + + Go to {getDisplayNameFromK8sResource(selectedProject)} project page + + + + )} + + + ); +}; + +export default ProjectSelector; diff --git a/frontend/src/pages/modelServing/screens/projects/__tests__/utils.spec.ts b/frontend/src/pages/modelServing/screens/projects/__tests__/utils.spec.ts index 6b24665a42..e98de451ea 100644 --- a/frontend/src/pages/modelServing/screens/projects/__tests__/utils.spec.ts +++ b/frontend/src/pages/modelServing/screens/projects/__tests__/utils.spec.ts @@ -5,30 +5,33 @@ import { getProjectModelServingPlatform, getUrlFromKserveInferenceService, } from '~/pages/modelServing/screens/projects/utils'; -import { DataConnection } from '~/pages/projects/types'; -import { ServingPlatformStatuses } from '~/pages/modelServing/screens/types'; +import { LabeledDataConnection, ServingPlatformStatuses } from '~/pages/modelServing/screens/types'; import { ServingRuntimePlatform } from '~/types'; import { mockInferenceServiceK8sResource } from '~/__mocks__/mockInferenceServiceK8sResource'; describe('filterOutConnectionsWithoutBucket', () => { it('should return an empty array if input connections array is empty', () => { - const inputConnections: DataConnection[] = []; + const inputConnections: LabeledDataConnection[] = []; const result = filterOutConnectionsWithoutBucket(inputConnections); expect(result).toEqual([]); }); it('should filter out connections without an AWS_S3_BUCKET property', () => { const dataConnections = [ - mockDataConnection({ name: 'name1', s3Bucket: 'bucket1' }), - mockDataConnection({ name: 'name2', s3Bucket: '' }), - mockDataConnection({ name: 'name3', s3Bucket: 'bucket2' }), + { dataConnection: mockDataConnection({ name: 'name1', s3Bucket: 'bucket1' }) }, + { dataConnection: mockDataConnection({ name: 'name2', s3Bucket: '' }) }, + { dataConnection: mockDataConnection({ name: 'name3', s3Bucket: 'bucket2' }) }, ]; const result = filterOutConnectionsWithoutBucket(dataConnections); expect(result).toMatchObject([ - { data: { data: { Name: 'name1' } } }, - { data: { data: { Name: 'name3' } } }, + { + dataConnection: { data: { data: { Name: 'name1' } } }, + }, + { + dataConnection: { data: { data: { Name: 'name3' } } }, + }, ]); }); }); diff --git a/frontend/src/pages/modelServing/screens/projects/kServeModal/ManageKServeModal.tsx b/frontend/src/pages/modelServing/screens/projects/kServeModal/ManageKServeModal.tsx index 07a1a61d1b..0c598e02ed 100644 --- a/frontend/src/pages/modelServing/screens/projects/kServeModal/ManageKServeModal.tsx +++ b/frontend/src/pages/modelServing/screens/projects/kServeModal/ManageKServeModal.tsx @@ -45,6 +45,8 @@ import { containsOnlySlashes, isS3PathValid } from '~/utilities/string'; import AuthServingRuntimeSection from '~/pages/modelServing/screens/projects/ServingRuntimeModal/AuthServingRuntimeSection'; import { useAccessReview } from '~/api'; import { SupportedArea, useIsAreaAvailable } from '~/concepts/areas'; +import { RegisteredModelDeployInfo } from '~/pages/modelRegistry/screens/RegisteredModels/useRegisteredModelDeployInfo'; +import usePrefillDeployModalFromModelRegistry from '~/pages/modelRegistry/screens/RegisteredModels/usePrefillDeployModalFromModelRegistry'; import KServeAutoscalerReplicaSection from './KServeAutoscalerReplicaSection'; const accessReviewResource: AccessReviewResourceAttributes = { @@ -57,6 +59,9 @@ type ManageKServeModalProps = { isOpen: boolean; onClose: (submit: boolean) => void; servingRuntimeTemplates?: TemplateKind[]; + registeredModelDeployInfo?: RegisteredModelDeployInfo; + shouldFormHidden?: boolean; + projectSection?: React.ReactNode; } & EitherOrNone< { projectContext?: { @@ -79,6 +84,9 @@ const ManageKServeModal: React.FC = ({ servingRuntimeTemplates, projectContext, editInfo, + projectSection, + registeredModelDeployInfo, + shouldFormHidden, }) => { const [createDataServingRuntime, setCreateDataServingRuntime, resetDataServingRuntime, sizes] = useCreateServingRuntimeObject(editInfo?.servingRuntimeEditInfo); @@ -88,6 +96,13 @@ const ManageKServeModal: React.FC = ({ editInfo?.servingRuntimeEditInfo?.servingRuntime, editInfo?.secrets, ); + const [dataConnections, dataConnectionsLoaded, dataConnectionsLoadError] = + usePrefillDeployModalFromModelRegistry( + projectContext, + createDataInferenceService, + setCreateDataInferenceService, + registeredModelDeployInfo, + ); const isAuthorinoEnabled = useIsAreaAvailable(SupportedArea.K_SERVE_AUTH).status; const currentProjectName = projectContext?.currentProject.metadata.name; @@ -116,6 +131,18 @@ const ManageKServeModal: React.FC = ({ } }, [currentProjectName, setCreateDataInferenceService, isOpen]); + // Refresh model format selection when changing serving runtime template selection + // Don't affect the edit modal + React.useEffect(() => { + if (!editInfo?.servingRuntimeEditInfo?.servingRuntime) { + setCreateDataInferenceService('format', { name: '' }); + } + }, [ + createDataServingRuntime.servingRuntimeTemplateName, + editInfo?.servingRuntimeEditInfo?.servingRuntime, + setCreateDataInferenceService, + ]); + // Serving Runtime Validation const isDisabledServingRuntime = namespace === '' || actionInProgress; @@ -270,78 +297,87 @@ const ManageKServeModal: React.FC = ({
)} - - - - - - + )} - - - - - - - - - - - - - {isAuthorinoEnabled && ( - - - + /> + + + + + {isAuthorinoEnabled && ( + + + + )} + + + + + + )} - - - - -
diff --git a/frontend/src/pages/modelServing/screens/projects/utils.ts b/frontend/src/pages/modelServing/screens/projects/utils.ts index 8d2edeae8a..fcdcaa16d0 100644 --- a/frontend/src/pages/modelServing/screens/projects/utils.ts +++ b/frontend/src/pages/modelServing/screens/projects/utils.ts @@ -7,11 +7,7 @@ import { SecretKind, ServingRuntimeKind, } from '~/k8sTypes'; -import { - DataConnection, - NamespaceApplicationCase, - UpdateObjectAtPropAndValue, -} from '~/pages/projects/types'; +import { NamespaceApplicationCase, UpdateObjectAtPropAndValue } from '~/pages/projects/types'; import useGenericObjectState from '~/utilities/useGenericObjectState'; import { CreatingInferenceServiceObject, @@ -20,6 +16,7 @@ import { ServingPlatformStatuses, ServingRuntimeEditInfo, ModelServingSize, + LabeledDataConnection, } from '~/pages/modelServing/screens/types'; import { ServingRuntimePlatform } from '~/types'; import { DEFAULT_MODEL_SERVER_SIZES } from '~/pages/modelServing/screens/const'; @@ -536,8 +533,10 @@ export const isUrlInternalService = (url: string | undefined): boolean => url !== undefined && url.endsWith('.svc.cluster.local'); export const filterOutConnectionsWithoutBucket = ( - connections: DataConnection[], -): DataConnection[] => + connections: LabeledDataConnection[], +): LabeledDataConnection[] => connections.filter( - (obj) => isDataConnectionAWS(obj) && obj.data.data.AWS_S3_BUCKET.trim() !== '', + (obj) => + isDataConnectionAWS(obj.dataConnection) && + obj.dataConnection.data.data.AWS_S3_BUCKET.trim() !== '', ); diff --git a/frontend/src/pages/modelServing/screens/types.ts b/frontend/src/pages/modelServing/screens/types.ts index a8bf14fdb2..fea89c70c5 100644 --- a/frontend/src/pages/modelServing/screens/types.ts +++ b/frontend/src/pages/modelServing/screens/types.ts @@ -1,5 +1,6 @@ +import { AlertVariant } from '@patternfly/react-core'; import { SecretKind, ServingRuntimeKind } from '~/k8sTypes'; -import { EnvVariableDataEntry } from '~/pages/projects/types'; +import { DataConnection, EnvVariableDataEntry } from '~/pages/projects/types'; import { ContainerResources } from '~/types'; export enum PerformanceMetricType { @@ -78,6 +79,11 @@ export type InferenceServiceStorage = { path: string; dataConnection: string; awsData: EnvVariableDataEntry[]; + alert?: { + type: AlertVariant; + title: string; + message: string; + }; }; export type InferenceServiceFormat = { @@ -100,3 +106,8 @@ export type ServingPlatformStatuses = { installed: boolean; }; }; + +export type LabeledDataConnection = { + dataConnection: DataConnection; + isRecommended?: boolean; +}; From 807e556e9b64e2d86a71584f6078ac8c34dc8cb6 Mon Sep 17 00:00:00 2001 From: jpuzz0 <96431149+jpuzz0@users.noreply.github.com> Date: Mon, 12 Aug 2024 17:25:33 -0400 Subject: [PATCH 04/31] [RHOAIENG-8816] Automatic discovery of Accelerator profile handles AMD GPUs like they were NVIDIA (#3064) --- backend/src/utils/resourceUtils.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/backend/src/utils/resourceUtils.ts b/backend/src/utils/resourceUtils.ts index 041beb1beb..1bec2d7b93 100644 --- a/backend/src/utils/resourceUtils.ts +++ b/backend/src/utils/resourceUtils.ts @@ -677,8 +677,11 @@ export const cleanupGPU = async (fastify: KubeFastifyInstance): Promise => ) { // if gpu detected on cluster, create our default migrated-gpu const acceleratorDetected = await getDetectedAccelerators(fastify); + const hasNvidiaNodes = Object.keys(acceleratorDetected.total).some( + (nodeKey) => nodeKey === 'nvidia.com/gpu', + ); - if (acceleratorDetected.configured) { + if (acceleratorDetected.configured && hasNvidiaNodes) { const payload: AcceleratorProfileKind = { kind: 'AcceleratorProfile', apiVersion: 'dashboard.opendatahub.io/v1', From 68047f3fe47fc73cda07178499ec133676d1d8cd Mon Sep 17 00:00:00 2001 From: "Heiko W. Rupp" Date: Tue, 13 Aug 2024 14:49:58 +0200 Subject: [PATCH 05/31] RHOAIENG-9232 Create useTrackUser. (#3080) Addendum from review of PR3024 --- .../analyticsTracking/trackingProperties.ts | 1 - .../analyticsTracking/useSegmentTracking.ts | 5 +++-- .../concepts/analyticsTracking/useTrackUser.ts | 18 +++++++----------- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/frontend/src/concepts/analyticsTracking/trackingProperties.ts b/frontend/src/concepts/analyticsTracking/trackingProperties.ts index 17daeb3a16..61a56add1e 100644 --- a/frontend/src/concepts/analyticsTracking/trackingProperties.ts +++ b/frontend/src/concepts/analyticsTracking/trackingProperties.ts @@ -5,7 +5,6 @@ export type ODHSegmentKey = { export type IdentifyEventProperties = { isAdmin: boolean; anonymousID?: string; - userId?: string; canCreateProjects: boolean; }; diff --git a/frontend/src/concepts/analyticsTracking/useSegmentTracking.ts b/frontend/src/concepts/analyticsTracking/useSegmentTracking.ts index 2411a5b744..7c5473d832 100644 --- a/frontend/src/concepts/analyticsTracking/useSegmentTracking.ts +++ b/frontend/src/concepts/analyticsTracking/useSegmentTracking.ts @@ -12,13 +12,14 @@ export const useSegmentTracking = (): void => { const username = useAppSelector((state) => state.user); const clusterID = useAppSelector((state) => state.clusterID); const [userProps, uPropsLoaded] = useTrackUser(username); + const disableTrackingConfig = dashboardConfig.spec.dashboardConfig.disableTracking; React.useEffect(() => { if (segmentKey && loaded && !loadError && username && clusterID && uPropsLoaded) { window.clusterID = clusterID; initSegment({ segmentKey, - enabled: !dashboardConfig.spec.dashboardConfig.disableTracking, + enabled: !disableTrackingConfig, }).then(() => { fireIdentifyEvent(userProps); firePageEvent(); @@ -30,7 +31,7 @@ export const useSegmentTracking = (): void => { loaded, segmentKey, username, - dashboardConfig, + disableTrackingConfig, userProps, uPropsLoaded, ]); diff --git a/frontend/src/concepts/analyticsTracking/useTrackUser.ts b/frontend/src/concepts/analyticsTracking/useTrackUser.ts index 062334dd76..27429837c4 100644 --- a/frontend/src/concepts/analyticsTracking/useTrackUser.ts +++ b/frontend/src/concepts/analyticsTracking/useTrackUser.ts @@ -8,7 +8,6 @@ export const useTrackUser = (username?: string): [IdentifyEventProperties, boole const { isAdmin } = useUser(); const [anonymousId, setAnonymousId] = React.useState(undefined); - const [loaded, setLoaded] = React.useState(false); const createReviewResource: AccessReviewResourceAttributes = { group: 'project.openshift.io', resource: 'projectrequests', @@ -27,15 +26,12 @@ export const useTrackUser = (username?: string): [IdentifyEventProperties, boole return aId; }; - if (!anonymousId) { - computeAnonymousUserId().then((val) => { - setAnonymousId(val); - }); - } - if (acLoaded && anonymousId) { - setLoaded(true); - } - }, [username, anonymousId, acLoaded]); + computeAnonymousUserId().then((val) => { + setAnonymousId(val); + }); + // compute anonymousId only once + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); const props: IdentifyEventProperties = React.useMemo( () => ({ @@ -46,5 +42,5 @@ export const useTrackUser = (username?: string): [IdentifyEventProperties, boole [isAdmin, allowCreate, anonymousId], ); - return [props, loaded]; + return [props, acLoaded && !!anonymousId]; }; From 9ffed3cd68a0894dd3127e1c74e3a16f522c5fca Mon Sep 17 00:00:00 2001 From: Purva Naik Date: Tue, 13 Aug 2024 19:35:42 +0530 Subject: [PATCH 06/31] Use artofact downloader hook in fetching markdown html (#3053) --- frontend/src/__mocks__/mockArtifactStorage.ts | 25 ++ .../cypress/pages/pipelines/compareRuns.ts | 7 + .../cypress/cypress/support/commands/odh.ts | 9 + .../tests/mocked/pipelines/compareRuns.cy.ts | 310 +++++++++++------- .../__tests__/useArtifactStorage.spec.ts | 2 +- .../pipelines/apiHooks/useArtifactStorage.ts | 121 ++++--- .../markdown/MarkdownCompare.tsx | 9 +- .../markdown/useFetchMarkdownMaps.ts | 30 +- .../artifacts/ArtifactVisualization.tsx | 31 +- ...dhdashboardconfigs.opendatahub.io.crd.yaml | 2 + 10 files changed, 348 insertions(+), 198 deletions(-) create mode 100644 frontend/src/__mocks__/mockArtifactStorage.ts diff --git a/frontend/src/__mocks__/mockArtifactStorage.ts b/frontend/src/__mocks__/mockArtifactStorage.ts new file mode 100644 index 0000000000..a91d622aed --- /dev/null +++ b/frontend/src/__mocks__/mockArtifactStorage.ts @@ -0,0 +1,25 @@ +/* eslint-disable camelcase */ +import { ArtifactStorage } from '~/concepts/pipelines/types'; + +type MockArtifactStorageType = { + namespace?: string; + artifactId?: string; +}; + +export const mockArtifactStorage = ({ + namespace = 'test', + artifactId = '16', +}: MockArtifactStorageType): ArtifactStorage => ({ + artifact_id: artifactId, + storage_provider: 's3', + storage_path: + 'metrics-visualization-pipeline/5e873c64-39fa-4dd4-83db-eff0cdd1e274/html-visualization/html_artifact', + uri: 's3://aballant-pipelines/metrics-visualization-pipeline/5e873c64-39fa-4dd4-83db-eff0cdd1e274/html-visualization/html_artifact', + download_url: + 'https://test.s3.dualstack.us-east-1.amazonaws.com/metrics-visualization-pipeline/5e873c64-39fa-4dd4-83db-eff0cdd1e274/html-visualization/html_artifact?X-Amz-Algorithm=AWS4-HMAC-SHA256\u0026X-Amz-Credential=AKIAYQPE7PSILMBBLXMO%2F20240808%2Fus-east-1%2Fs3%2Faws4_request\u0026X-Amz-Date=20240808T070034Z\u0026X-Amz-Expires=15\u0026X-Amz-SignedHeaders=host\u0026response-content-disposition=attachment%3B%20filename%3D%22%22\u0026X-Amz-Signature=de39ee684dd606e75da3b07c1b9f0820f7442ea7a037ae1bffccea9e33610ea9', + namespace, + artifact_type: 'system.Markdownxw', + artifact_size: '61', + created_at: '2024-08-07T07:13:46.078Z', + last_updated_at: '2024-08-07T07:13:46.078Z', +}); diff --git a/frontend/src/__tests__/cypress/cypress/pages/pipelines/compareRuns.ts b/frontend/src/__tests__/cypress/cypress/pages/pipelines/compareRuns.ts index 5dd47265f8..caf7b326cf 100644 --- a/frontend/src/__tests__/cypress/cypress/pages/pipelines/compareRuns.ts +++ b/frontend/src/__tests__/cypress/cypress/pages/pipelines/compareRuns.ts @@ -178,6 +178,13 @@ class CompareRunsArtifactSelect extends Contextual { findArtifactContent(index = 0) { return this.find().findByTestId(`pipeline-run-artifact-content-${index}`); } + + findIframeContent(index = 0) { + return this.findArtifactContent(index) + .findByTestId('markdown-compare') + .its('0.contentDocument') + .its('body'); + } } class ConfusionMatrixArtifactSelect extends CompareRunsArtifactSelect { diff --git a/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts b/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts index 51838eb9a3..d742cba552 100644 --- a/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts +++ b/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts @@ -50,6 +50,7 @@ import type { } from '~/concepts/pipelines/kfTypes'; import type { GrpcResponse } from '~/__mocks__/mlmd/utils'; import type { BuildMockPipelinveVersionsType } from '~/__mocks__'; +import type { ArtifactStorage } from '~/concepts/pipelines/types'; type SuccessErrorResponse = { success: boolean; @@ -566,6 +567,14 @@ declare global { }, response: OdhResponse<{ notebook: NotebookKind; isRunning: boolean }>, ) => Cypress.Chainable) & + (( + type: 'GET /api/service/pipelines/:namespace/:serviceName/apis/v2beta1/artifacts/:artifactId', + options: { + query: { view: string }; + path: { namespace: string; serviceName: string; artifactId: number }; + }, + response: OdhResponse, + ) => Cypress.Chainable) & (( type: 'GET /api/storage/:namespace', options: { diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/pipelines/compareRuns.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/pipelines/compareRuns.cy.ts index 3cbe29a453..a3714565b2 100644 --- a/frontend/src/__tests__/cypress/cypress/tests/mocked/pipelines/compareRuns.cy.ts +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/pipelines/compareRuns.cy.ts @@ -24,6 +24,7 @@ import { compareRunsMetricsContent, } from '~/__tests__/cypress/cypress/pages/pipelines/compareRuns'; import { mockCancelledGoogleRpcStatus } from '~/__mocks__/mockGoogleRpcStatusKF'; +import { mockArtifactStorage } from '~/__mocks__/mockArtifactStorage'; import { initMlmdIntercepts } from './mlmdUtils'; const projectName = 'test-project-name'; @@ -82,7 +83,7 @@ const mockRun3 = buildMockRunKF({ describe('Compare runs', () => { beforeEach(() => { - initIntercepts(); + initIntercepts({}); }); it('zero runs in url', () => { @@ -210,122 +211,196 @@ describe('Compare runs', () => { }); describe('Metrics', () => { - beforeEach(() => { - initIntercepts(); - compareRunsGlobal.visit(projectName, mockExperiment.experiment_id, [ - mockRun.run_id, - mockRun2.run_id, - mockRun3.run_id, - ]); + describe('Metrics when s3endpoint feature flag is available and artifactapi is unavailable', () => { + beforeEach(() => { + initIntercepts({}); + + cy.interceptOdh( + 'GET /api/storage/:namespace/size', + { + query: { + key: 'metrics-visualization-pipeline/16dbff18-a3d5-4684-90ac-4e6198a9da0f/markdown-visualization/markdown_artifact', + }, + path: { namespace: projectName }, + }, + { body: 61 }, + ); + cy.interceptOdh( + 'GET /api/storage/:namespace', + { + query: { + key: 'metrics-visualization-pipeline/16dbff18-a3d5-4684-90ac-4e6198a9da0f/markdown-visualization/markdown_artifact', + }, + path: { namespace: projectName }, + }, + 'helloWorld', + ); + compareRunsGlobal.visit(projectName, mockExperiment.experiment_id, [ + mockRun.run_id, + mockRun2.run_id, + mockRun3.run_id, + ]); + }); + + it('shows empty state when the Runs list has no selections', () => { + compareRunsListTable.findSelectAllCheckbox().click(); // Uncheck all + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricsEmptyState() + .should('exist'); + compareRunsMetricsContent.findConfusionMatrixTab().click(); + compareRunsMetricsContent + .findConfusionMatrixTabContent() + .findConfusionMatrixEmptyState() + .should('exist'); + compareRunsMetricsContent.findRocCurveTab().click(); + compareRunsMetricsContent.findRocCurveTabContent().findRocCurveEmptyState().should('exist'); + compareRunsMetricsContent.findMarkdownTab().click(); + compareRunsMetricsContent.findMarkdownTabContent().findMarkdownEmptyState().should('exist'); + }); + + it('displays scalar metrics table data based on selections from Run list', () => { + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricsTable() + .should('exist'); + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricsColumnByName('Run name') + .should('exist'); + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricsColumnByName('Run 1') + .should('exist'); + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricsColumnByName('Run 2') + .should('exist'); + + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricsColumnByName('Execution name > Artifact name') + .should('exist'); + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricsColumnByName('digit-classification > metrics') + .should('exist'); + + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricName('accuracy') + .should('exist'); + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricCell('accuracy', 1) + .should('contain.text', '92'); + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricCell('accuracy', 2) + .should('contain.text', '92'); + + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricName('displayName') + .should('exist'); + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricCell('displayName', 1) + .should('contain.text', '"metrics"'); + compareRunsMetricsContent + .findScalarMetricsTabContent() + .findScalarMetricCell('displayName', 2) + .should('contain.text', '"metrics"'); + }); + + it('displays confusion matrix data based on selections from Run list', () => { + compareRunsMetricsContent.findConfusionMatrixTab().click(); + + const confusionMatrixCompare = compareRunsMetricsContent + .findConfusionMatrixTabContent() + .findConfusionMatrixSelect(mockRun.run_id); + + // check graph data + const graph = confusionMatrixCompare.findConfusionMatrixGraph(); + graph.checkLabels(['Setosa', 'Versicolour', 'Virginica']); + graph.checkCells([ + [38, 0, 0], + [2, 19, 9], + [1, 17, 19], + ]); + + // check expanded graph + confusionMatrixCompare.findExpandButton().click(); + compareRunsMetricsContent + .findConfusionMatrixTabContent() + .findExpandedConfusionMatrix() + .find() + .should('exist'); + }); + + it('display markdown based on selections from Run list', () => { + compareRunsMetricsContent.findMarkdownTab().click(); + let markDown = compareRunsMetricsContent + .findMarkdownTabContent() + .findMarkdownSelect(mockRun.run_id); + markDown.findArtifactContent().should('have.text', 'helloWorld'); + markDown = compareRunsMetricsContent + .findMarkdownTabContent() + .findMarkdownSelect(mockRun2.run_id); + markDown.findArtifactContent().should('have.text', 'helloWorld'); + markDown.findExpandButton().click(); + compareRunsMetricsContent.findMarkdownTabContent().findExpandedMarkdown().should('exist'); + }); + + it('displays ROC curve empty state when no artifacts are found', () => { + compareRunsMetricsContent.findRocCurveTab().click(); + const content = compareRunsMetricsContent.findRocCurveTabContent(); + content.findRocCurveSearchBar().type('invalid'); + content.findRocCurveTableEmptyState().should('exist'); + }); }); - - it('shows empty state when the Runs list has no selections', () => { - compareRunsListTable.findSelectAllCheckbox().click(); // Uncheck all - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricsEmptyState() - .should('exist'); - compareRunsMetricsContent.findConfusionMatrixTab().click(); - compareRunsMetricsContent - .findConfusionMatrixTabContent() - .findConfusionMatrixEmptyState() - .should('exist'); - compareRunsMetricsContent.findRocCurveTab().click(); - compareRunsMetricsContent.findRocCurveTabContent().findRocCurveEmptyState().should('exist'); - compareRunsMetricsContent.findMarkdownTab().click(); - compareRunsMetricsContent.findMarkdownTabContent().findMarkdownEmptyState().should('exist'); - }); - - it('displays scalar metrics table data based on selections from Run list', () => { - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricsTable() - .should('exist'); - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricsColumnByName('Run name') - .should('exist'); - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricsColumnByName('Run 1') - .should('exist'); - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricsColumnByName('Run 2') - .should('exist'); - - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricsColumnByName('Execution name > Artifact name') - .should('exist'); - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricsColumnByName('digit-classification > metrics') - .should('exist'); - - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricName('accuracy') - .should('exist'); - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricCell('accuracy', 1) - .should('contain.text', '92'); - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricCell('accuracy', 2) - .should('contain.text', '92'); - - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricName('displayName') - .should('exist'); - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricCell('displayName', 1) - .should('contain.text', '"metrics"'); - compareRunsMetricsContent - .findScalarMetricsTabContent() - .findScalarMetricCell('displayName', 2) - .should('contain.text', '"metrics"'); - }); - - it('displays confusion matrix data based on selections from Run list', () => { - compareRunsMetricsContent.findConfusionMatrixTab().click(); - - const confusionMatrixCompare = compareRunsMetricsContent - .findConfusionMatrixTabContent() - .findConfusionMatrixSelect(mockRun.run_id); - - // check graph data - const graph = confusionMatrixCompare.findConfusionMatrixGraph(); - graph.checkLabels(['Setosa', 'Versicolour', 'Virginica']); - graph.checkCells([ - [38, 0, 0], - [2, 19, 9], - [1, 17, 19], - ]); - - // check expanded graph - confusionMatrixCompare.findExpandButton().click(); - compareRunsMetricsContent - .findConfusionMatrixTabContent() - .findExpandedConfusionMatrix() - .find() - .should('exist'); - }); - - it('displays ROC curve empty state when no artifacts are found', () => { - compareRunsMetricsContent.findRocCurveTab().click(); - const content = compareRunsMetricsContent.findRocCurveTabContent(); - content.findRocCurveSearchBar().type('invalid'); - content.findRocCurveTableEmptyState().should('exist'); + describe('Metrics when artifactApi feature flag is available', () => { + beforeEach(() => { + initIntercepts({ disableArtifactAPI: false }); + compareRunsGlobal.visit(projectName, mockExperiment.experiment_id, [ + mockRun.run_id, + mockRun2.run_id, + mockRun3.run_id, + ]); + cy.interceptOdh( + 'GET /api/service/pipelines/:namespace/:serviceName/apis/v2beta1/artifacts/:artifactId', + { + query: { view: 'DOWNLOAD' }, + path: { namespace: projectName, serviceName: 'dspa', artifactId: 16 }, + }, + mockArtifactStorage({ namespace: projectName }), + ); + cy.intercept( + 'GET', + 'https://test.s3.dualstack.us-east-1.amazonaws.com/metrics-visualization-pipeline/5e873c64-39fa-4dd4-83db-eff0cdd1e274/html-visualization/html_artifact?X-Amz-Algorithm=AWS4-HMAC-SHA256\u0026X-Amz-Credential=AKIAYQPE7PSILMBBLXMO%2F20240808%2Fus-east-1%2Fs3%2Faws4_request\u0026X-Amz-Date=20240808T070034Z\u0026X-Amz-Expires=15\u0026X-Amz-SignedHeaders=host\u0026response-content-disposition=attachment%3B%20filename%3D%22%22\u0026X-Amz-Signature=de39ee684dd606e75da3b07c1b9f0820f7442ea7a037ae1bffccea9e33610ea9', + 'helloWorld', + ); + }); + + it('display markdown based on selections from Run list', () => { + compareRunsMetricsContent.findMarkdownTab().click(); + let markDown = compareRunsMetricsContent + .findMarkdownTabContent() + .findMarkdownSelect(mockRun.run_id); + markDown.findIframeContent().should('have.text', 'helloWorld'); + markDown = compareRunsMetricsContent + .findMarkdownTabContent() + .findMarkdownSelect(mockRun2.run_id); + markDown.findIframeContent().should('have.text', 'helloWorld'); + markDown.findExpandButton().click(); + compareRunsMetricsContent.findMarkdownTabContent().findExpandedMarkdown().should('exist'); + }); }); }); describe('No metrics', () => { beforeEach(() => { - initIntercepts(true); + initIntercepts({ noMetrics: true }); compareRunsGlobal.visit(projectName, mockExperiment.experiment_id, [ mockRun.run_id, mockRun2.run_id, @@ -374,10 +449,19 @@ describe('Compare runs', () => { }); }); -const initIntercepts = (noMetrics?: boolean) => { +type InterceptsType = { + noMetrics?: boolean; + disableArtifactAPI?: boolean; +}; + +const initIntercepts = ({ noMetrics, disableArtifactAPI = true }: InterceptsType) => { cy.interceptOdh( 'GET /api/config', - mockDashboardConfig({ disablePipelineExperiments: false, disableS3Endpoint: false }), + mockDashboardConfig({ + disablePipelineExperiments: false, + disableS3Endpoint: false, + disableArtifactsAPI: disableArtifactAPI, + }), ); cy.interceptK8sList( DataSciencePipelineApplicationModel, diff --git a/frontend/src/concepts/pipelines/apiHooks/__tests__/useArtifactStorage.spec.ts b/frontend/src/concepts/pipelines/apiHooks/__tests__/useArtifactStorage.spec.ts index 0110716012..164f425b05 100644 --- a/frontend/src/concepts/pipelines/apiHooks/__tests__/useArtifactStorage.spec.ts +++ b/frontend/src/concepts/pipelines/apiHooks/__tests__/useArtifactStorage.spec.ts @@ -88,7 +88,7 @@ describe('useArtifactStorage', () => { const storageObject = await result.current.getStorageObject(artifact); const storageObjectSize = await result.current.getStorageObjectSize(artifact); const storageObjectUrl = await result.current.getStorageObjectUrl(artifact); - expect(storageObject).toBe('hello world'); + expect(storageObject).toBe('http://rhoai.v1/namespace/45456'); expect(storageObjectSize).toBe(60); expect(storageObjectUrl).toBe('http://rhoai.v1/namespace/45456'); } diff --git a/frontend/src/concepts/pipelines/apiHooks/useArtifactStorage.ts b/frontend/src/concepts/pipelines/apiHooks/useArtifactStorage.ts index f220f5ffa1..3c2a2b5583 100644 --- a/frontend/src/concepts/pipelines/apiHooks/useArtifactStorage.ts +++ b/frontend/src/concepts/pipelines/apiHooks/useArtifactStorage.ts @@ -1,3 +1,4 @@ +import React from 'react'; import { Artifact } from '~/third_party/mlmd'; import { usePipelinesAPI } from '~/concepts/pipelines/context'; import { fetchStorageObject, fetchStorageObjectSize } from '~/services/storageService'; @@ -21,65 +22,77 @@ export type ArtifactType = export const useArtifactStorage = (): ArtifactType => { const s3EndpointAvailable = useIsAreaAvailable(SupportedArea.S3_ENDPOINT).status; const artifactApiAvailable = useIsAreaAvailable(SupportedArea.ARTIFACT_API).status; + const { api, namespace } = usePipelinesAPI(); - if (!s3EndpointAvailable && !artifactApiAvailable) { - return { enabled: false }; - } + const enabled = s3EndpointAvailable || artifactApiAvailable; - const getStorageObject = async (artifact: Artifact): Promise => { - if (artifactApiAvailable) { - return api - .getArtifact({}, artifact.getId(), 'DOWNLOAD') - .then((artifactStorage) => { - if (artifactStorage.download_url) { - return fetch(artifactStorage.download_url).then((downloadObject) => - downloadObject.text(), - ); - } - return Promise.reject(); - }) - .catch((e) => { - throw new Error(`Error fetching Storage object ${e}`); - }); - } + const getStorageObject = React.useCallback( + async (artifact: Artifact): Promise => { + if (artifactApiAvailable) { + return api + .getArtifact({}, artifact.getId(), 'DOWNLOAD') + .then((artifactStorage) => { + if (artifactStorage.download_url) { + return artifactStorage.download_url; + } + return Promise.reject(); + }) + .catch((e) => { + throw new Error(`Error fetching Storage object ${e}`); + }); + } - const path = artifact.getUri(); - const uriComponents = extractS3UriComponents(path); - if (uriComponents) { - return fetchStorageObject(namespace, uriComponents.path); - } - return Promise.reject(); - }; + const path = artifact.getUri(); + const uriComponents = extractS3UriComponents(path); + if (uriComponents) { + return fetchStorageObject(namespace, uriComponents.path); + } + return Promise.reject(); + }, + [api, artifactApiAvailable, namespace], + ); - const getStorageObjectSize = async (artifact: Artifact): Promise => { - if (artifactApiAvailable) { - return api - .getArtifact({}, artifact.getId()) - .then((artifactStorage) => Number(artifactStorage.artifact_size)) - .catch((e) => { - throw new Error(`Error fetching Storage size ${e}`); - }); - } - const path = artifact.getUri(); - const uriComponents = extractS3UriComponents(path); - if (uriComponents) { - return fetchStorageObjectSize(namespace, uriComponents.path); - } - return Promise.reject(); - }; + const getStorageObjectSize = React.useCallback( + async (artifact: Artifact): Promise => { + if (artifactApiAvailable) { + return api + .getArtifact({}, artifact.getId()) + .then((artifactStorage) => Number(artifactStorage.artifact_size)) + .catch((e) => { + throw new Error(`Error fetching Storage size ${e}`); + }); + } + const path = artifact.getUri(); + const uriComponents = extractS3UriComponents(path); + if (uriComponents) { + return fetchStorageObjectSize(namespace, uriComponents.path); + } + return Promise.reject(); + }, + [api, artifactApiAvailable, namespace], + ); - const getStorageObjectUrl = async (artifact: Artifact): Promise => { - if (artifactApiAvailable) { - return api - .getArtifact({}, artifact.getId(), 'DOWNLOAD') - .then((artifactStorage) => artifactStorage.download_url) - .catch((e) => { - throw new Error(`Error fetching Storage url ${e}`); - }); - } - return getArtifactUrlFromUri(artifact.getUri(), namespace); - }; + const getStorageObjectUrl = React.useCallback( + async (artifact: Artifact): Promise => { + if (artifactApiAvailable) { + return api + .getArtifact({}, artifact.getId(), 'DOWNLOAD') + .then((artifactStorage) => artifactStorage.download_url) + .catch((e) => { + throw new Error(`Error fetching Storage url ${e}`); + }); + } + return getArtifactUrlFromUri(artifact.getUri(), namespace); + }, + [api, artifactApiAvailable, namespace], + ); - return { enabled: true, getStorageObject, getStorageObjectSize, getStorageObjectUrl }; + return React.useMemo( + () => + enabled + ? { enabled: true, getStorageObject, getStorageObjectSize, getStorageObjectUrl } + : { enabled: false }, + [enabled, getStorageObject, getStorageObjectSize, getStorageObjectUrl], + ); }; diff --git a/frontend/src/concepts/pipelines/content/compareRuns/metricsSection/markdown/MarkdownCompare.tsx b/frontend/src/concepts/pipelines/content/compareRuns/metricsSection/markdown/MarkdownCompare.tsx index f365962ee5..c1f9ead5f8 100644 --- a/frontend/src/concepts/pipelines/content/compareRuns/metricsSection/markdown/MarkdownCompare.tsx +++ b/frontend/src/concepts/pipelines/content/compareRuns/metricsSection/markdown/MarkdownCompare.tsx @@ -16,6 +16,7 @@ import { MAX_STORAGE_OBJECT_SIZE } from '~/services/storageService'; import { bytesAsRoundedGiB } from '~/utilities/number'; import { PipelineRunKFv2 } from '~/concepts/pipelines/kfTypes'; import { CompareRunsNoMetrics } from '~/concepts/pipelines/content/compareRuns/CompareRunsNoMetrics'; +import { SupportedArea, useIsAreaAvailable } from '~/concepts/areas'; type MarkdownCompareProps = { configMap: Record; @@ -37,7 +38,7 @@ const MarkdownCompare: React.FC = ({ isEmpty, }) => { const [expandedGraph, setExpandedGraph] = React.useState(undefined); - + const isArtifactApiAvailable = useIsAreaAvailable(SupportedArea.ARTIFACT_API).status; if (!isLoaded) { return ( @@ -66,7 +67,11 @@ const MarkdownCompare: React.FC = ({ )} - + {isArtifactApiAvailable ? ( +