|
12 | 12 | ExportModelArtifactDetails,
|
13 | 13 | ImportModelArtifactDetails,
|
14 | 14 | Model,
|
15 |
| - ModelProvenance, |
| 15 | + ModelProvenance, RegisterModelArtifactReferenceDetails, OSSModelArtifactReferenceDetails, |
16 | 16 | )
|
17 | 17 | from oci.exceptions import ServiceError
|
18 | 18 | from oci.response import Response
|
@@ -136,6 +136,12 @@ def setup_class(cls):
|
136 | 136 | headers={"opc-work-request-id": "work_request_id"},
|
137 | 137 | request=None,
|
138 | 138 | )
|
| 139 | + cls.mock_register_model_artifact_reference_response = Response( |
| 140 | + data=None, |
| 141 | + status=None, |
| 142 | + headers={"opc-work-request-id": "work_request_id"}, |
| 143 | + request=None, |
| 144 | + ) |
139 | 145 |
|
140 | 146 | def setup_method(self):
|
141 | 147 | self.mock_model = OCIDataScienceModel(**OCI_MODEL_PAYLOAD)
|
@@ -173,6 +179,9 @@ def mock_client(self):
|
173 | 179 | mock_client.export_model_artifact = MagicMock(
|
174 | 180 | return_value=self.mock_export_artifact_response
|
175 | 181 | )
|
| 182 | + mock_client.register_model_artifact_reference = MagicMock( |
| 183 | + return_value=self.mock_register_model_artifact_reference_response |
| 184 | + ) |
176 | 185 | return mock_client
|
177 | 186 |
|
178 | 187 | def test_create_fail(self):
|
@@ -463,6 +472,62 @@ def test_export_model_artifact(
|
463 | 472 | progress_bar_description="Exporting model artifacts."
|
464 | 473 | )
|
465 | 474 |
|
| 475 | + @patch( |
| 476 | + "ads.model.service.oci_datascience_model.DataScienceWorkRequest.wait_work_request" |
| 477 | + ) |
| 478 | + @patch("ads.model.service.oci_datascience_model.DataScienceWorkRequest.__init__") |
| 479 | + def test_register_model_artifact_reference( |
| 480 | + self, |
| 481 | + mock_data_science_work_request, |
| 482 | + mock_wait_work_request, |
| 483 | + mock_client, |
| 484 | + ): |
| 485 | + """Tests register model artifact reference for a model in model catalog.""" |
| 486 | + test_bucket_uri_1 = "oci://bucket1@namespace1/prefix1/" |
| 487 | + test_bucket_uri_2 = "oci://bucket2@namespace2/prefix2/subPrefix2" |
| 488 | + test_bucket_uri_3 = "oci://bucket3@namespace3/" |
| 489 | + test_bucket_uri_4 = "oci://bucket4@namespace4" |
| 490 | + |
| 491 | + model_artifact_reference_details_1 = OSSModelArtifactReferenceDetails() |
| 492 | + model_artifact_reference_details_1.namespace = 'namespace1' |
| 493 | + model_artifact_reference_details_1.bucket_name = 'bucket1' |
| 494 | + model_artifact_reference_details_1.prefix = 'prefix1' |
| 495 | + |
| 496 | + model_artifact_reference_details_2 = OSSModelArtifactReferenceDetails() |
| 497 | + model_artifact_reference_details_2.namespace = 'namespace2' |
| 498 | + model_artifact_reference_details_2.bucket_name = 'bucket2' |
| 499 | + model_artifact_reference_details_2.prefix = 'prefix2/subPrefix2' |
| 500 | + |
| 501 | + model_artifact_reference_details_3 = OSSModelArtifactReferenceDetails() |
| 502 | + model_artifact_reference_details_3.namespace = 'namespace3' |
| 503 | + model_artifact_reference_details_3.bucket_name = 'bucket3' |
| 504 | + model_artifact_reference_details_3.prefix = None |
| 505 | + |
| 506 | + model_artifact_reference_details_4 = OSSModelArtifactReferenceDetails() |
| 507 | + model_artifact_reference_details_4.namespace = 'namespace4' |
| 508 | + model_artifact_reference_details_4.bucket_name = 'bucket4' |
| 509 | + model_artifact_reference_details_4.prefix = None |
| 510 | + |
| 511 | + model_artifact_reference_details_list = [model_artifact_reference_details_1, model_artifact_reference_details_2, |
| 512 | + model_artifact_reference_details_3, model_artifact_reference_details_4] |
| 513 | + |
| 514 | + register_model_artifact_reference_details = RegisterModelArtifactReferenceDetails() |
| 515 | + register_model_artifact_reference_details.model_artifact_references = model_artifact_reference_details_list |
| 516 | + |
| 517 | + mock_data_science_work_request.return_value = None |
| 518 | + with patch.object(OCIDataScienceModel, "client", mock_client): |
| 519 | + self.mock_model.register_model_artifact_reference( |
| 520 | + bucket_uri_list=[test_bucket_uri_1, test_bucket_uri_2, test_bucket_uri_3, test_bucket_uri_4] |
| 521 | + ) |
| 522 | + mock_client.register_model_artifact_reference.assert_called_with( |
| 523 | + model_id=self.mock_model.id, |
| 524 | + register_model_artifact_reference_details=register_model_artifact_reference_details |
| 525 | + ) |
| 526 | + mock_data_science_work_request.assert_called_with("work_request_id") |
| 527 | + mock_wait_work_request.assert_called_with( |
| 528 | + progress_bar_description="Registering model artifact references." |
| 529 | + ) |
| 530 | + |
466 | 531 | def test_is_model_by_reference(self):
|
467 | 532 | """Test to check if model is created by reference using custom metadata information"""
|
468 | 533 |
|
|
0 commit comments