Skip to content

Commit dd3208d

Browse files
author
targarg
committed
ODSC-64654 register model artifact reference
1 parent b15138c commit dd3208d

File tree

5 files changed

+189
-3
lines changed

5 files changed

+189
-3
lines changed

ads/model/datascience_model.py

+20
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,26 @@ def restore_model(
14051405
restore_model_for_hours_specified=restore_model_for_hours_specified,
14061406
)
14071407

1408+
def register_model_artifact_reference(self, bucket_uri_list: List[str]) -> None:
    """
    Registers model artifact references against a model.

    Can be used for any model that does not have a model artifact yet.
    Instead of uploading an artifact, provide the list of Object Storage
    URIs that contain the artifact files.

    Parameters
    ----------
    bucket_uri_list: List[str]
        The list of OCI Object Storage URIs where model artifacts are present.
        Example: [`oci://<bucket_name>@<namespace>/prefix/`, `oci://<bucket_name>@<namespace>/prefix/`].

    Returns
    -------
    None
    """
    # The actual registration is performed by the object held in `self.dsc_model`;
    # this method only forwards the URIs to it.
    self.dsc_model.register_model_artifact_reference(
        bucket_uri_list=bucket_uri_list
    )
1427+
14081428
def download_artifact(
14091429
self,
14101430
target_dir: str,

ads/model/service/oci_datascience_model.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
ExportModelArtifactDetails,
2727
ImportModelArtifactDetails,
2828
UpdateModelDetails,
29-
WorkRequest,
29+
WorkRequest, RegisterModelArtifactReferenceDetails, OSSModelArtifactReferenceDetails, ModelArtifactReferenceDetails,
3030
)
3131
from oci.exceptions import ServiceError
3232

@@ -449,6 +449,48 @@ def export_model_artifact(self, bucket_uri: str, region: str = None):
449449
progress_bar_description="Exporting model artifacts."
450450
)
451451

452+
@check_for_model_id(
    msg="Model needs to be saved to the Model Catalog before the artifact can be registered against it."
)
def register_model_artifact_reference(self, bucket_uri_list: List[str]) -> None:
    """
    Registers model artifact references against a model.

    Can be used for any model that does not have a model artifact yet.
    Instead of uploading an artifact, provide the list of Object Storage
    URIs that contain the artifact files.

    Parameters
    ----------
    bucket_uri_list: List[str]
        The list of OCI Object Storage URIs where model artifacts are present.
        Example: [`oci://<bucket_name>@<namespace>/prefix/`, `oci://<bucket_name>@<namespace>/prefix/`].
        A single URI string is also accepted and treated as a one-element list.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If no bucket URI is provided.
    """
    # A bare string is iterable, so without this guard every character would be
    # parsed as a separate URI.
    if isinstance(bucket_uri_list, str):
        bucket_uri_list = [bucket_uri_list]

    # Fail before issuing a remote call that has nothing to register.
    if not bucket_uri_list:
        raise ValueError(
            "At least one Object Storage URI must be provided in `bucket_uri_list`."
        )

    model_artifact_references = []
    for bucket_uri in bucket_uri_list:
        bucket_details = ObjectStorageDetails.from_path(bucket_uri)

        reference = OSSModelArtifactReferenceDetails()
        reference.namespace = bucket_details.namespace
        reference.bucket_name = bucket_details.bucket

        # Strip the surrounding slashes first, so that a path consisting only of
        # slashes leaves the prefix unset instead of setting it to "".
        prefix = (bucket_details.filepath or "").strip("/")
        if prefix:
            reference.prefix = prefix

        model_artifact_references.append(reference)

    register_details = RegisterModelArtifactReferenceDetails()
    register_details.model_artifact_references = model_artifact_references

    # The id of the work request is read from the response headers.
    work_request_id = self.client.register_model_artifact_reference(
        model_id=self.id,
        register_model_artifact_reference_details=register_details,
    ).headers["opc-work-request-id"]

    # Show progress of model artifact references being registered
    DataScienceWorkRequest(work_request_id).wait_work_request(
        progress_bar_description="Registering model artifact references."
    )
493+
452494
@check_for_model_id(
453495
msg="Model needs to be saved to the Model Catalog before it can be updated."
454496
)

docs/source/user_guide/model_catalog/model_catalog.rst

+42-1
Original file line numberDiff line numberDiff line change
@@ -1553,4 +1553,45 @@ In the next example, the model that was stored in the model catalog as part of t
15531553
Restore Archived Model
15541554
**********************
15551555

1556-
The ``.restore_model()`` method of Model catalog restores the model for a specified number of hours. Restored models can be downloaded for 1-240 hours, defaulting to 24 hours.
1556+
The ``.restore_model()`` method of Model catalog restores the model for a specified number of hours. Restored models can be downloaded for 1-240 hours, defaulting to 24 hours.
1557+
1558+
Register Model Artifact Reference
1559+
*********************************
1560+
1561+
The ``.register_model_artifact_reference()`` method of Model catalog registers, against the model, references to the OCI Object Storage locations where the artifact files are present.
1562+
1563+
By using this API, you can avoid uploading or exporting large model artifacts; instead, you simply provide references to the OCI Object Storage locations where your artifacts are present. The OCI Data Science service reads the artifact files directly from those locations when you create a deployment of the model.
1564+
1565+
The input to this method is a list of bucket URIs (``bucket_uri_list``). The syntax of each bucket URI is:
1566+
1567+
oci://<bucket_name>@<namespace>/<path>/
1568+
1569+
Example -
1570+
1571+
.. code-block:: python3
1572+
1573+
model.register_model_artifact_reference(
1574+
bucket_uri_list = ["oci://<bucket_name>@<namespace>/<path>/"]
1575+
)
1576+
1577+
Important Points:
1578+
1579+
1. The buckets provided should all be in the same region and have versioning enabled.
1580+
1581+
2. The <path> is optional. If the files you want to use for this model are under a path within the bucket, specify that path in the bucket_uri; otherwise omit it, as shown below:
1582+
1583+
oci://<bucket_name>@<namespace>/
1584+
1585+
3. The location specified by bucket_uri should contain at least one object.
1586+
1587+
4. Make sure that the buckets provided have the following IAM policy configured, which allows the Data Science service to read artifact files from those Object Storage buckets in your tenancy. An administrator must configure these policies in `IAM <https://docs.oracle.com/iaas/Content/Identity/home1.htm>`_ in the Console.
1588+
1589+
.. parsed-literal::
1590+
1591+
allow any-user to read object-family in compartment <compartment> where ALL {target.bucket.name= '<bucket_name>', request.principal.type = /\*datasciencemodel\*/}
1592+
1593+
For a more granular policy, add a filter on project_id as shown below. This grants access to the bucket only to models in the Data Science project specified in the filter.
1594+
1595+
.. parsed-literal::
1596+
1597+
allow any-user to read object-family in compartment <compartment> where ALL {target.bucket.name= '<bucket_name>', request.principal.type = /\*datasciencemodel\*/, request.principal.project_id = '<project_ocid>'}

tests/unitary/default_setup/model/test_datascience_model.py

+18
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,24 @@ def test_upload_artifact(self):
817817
)
818818
mock_upload.assert_called()
819819

820+
@patch.object(OCIDataScienceModel, "register_model_artifact_reference")
def test_register_model_artifact_reference(self, mock_register_model_artifact_reference):
    """Tests that the bucket URIs reach the patched `OCIDataScienceModel` method unchanged.

    `OCIDataScienceModel.register_model_artifact_reference` is patched, the method
    of the same name is invoked on `self.mock_dsc_model`, and the patched method is
    expected to be called exactly once with the same `bucket_uri_list` keyword.
    """
    uris = [
        "oci://bucket1@namespace1/prefix1/",
        "oci://bucket2@namespace2/prefix2/",
    ]

    self.mock_dsc_model.register_model_artifact_reference(bucket_uri_list=uris)

    mock_register_model_artifact_reference.assert_called_once_with(
        bucket_uri_list=uris
    )
837+
820838
def test_download_artifact(self):
821839
"""Tests downloading artifacts from the model catalog."""
822840
# Artifact size greater than 2GB

tests/unitary/default_setup/model/test_oci_datascience_model.py

+66-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ExportModelArtifactDetails,
1313
ImportModelArtifactDetails,
1414
Model,
15-
ModelProvenance,
15+
ModelProvenance, RegisterModelArtifactReferenceDetails, OSSModelArtifactReferenceDetails,
1616
)
1717
from oci.exceptions import ServiceError
1818
from oci.response import Response
@@ -136,6 +136,12 @@ def setup_class(cls):
136136
headers={"opc-work-request-id": "work_request_id"},
137137
request=None,
138138
)
139+
cls.mock_register_model_artifact_reference_response = Response(
140+
data=None,
141+
status=None,
142+
headers={"opc-work-request-id": "work_request_id"},
143+
request=None,
144+
)
139145

140146
def setup_method(self):
141147
self.mock_model = OCIDataScienceModel(**OCI_MODEL_PAYLOAD)
@@ -173,6 +179,9 @@ def mock_client(self):
173179
mock_client.export_model_artifact = MagicMock(
174180
return_value=self.mock_export_artifact_response
175181
)
182+
mock_client.register_model_artifact_reference = MagicMock(
183+
return_value=self.mock_register_model_artifact_reference_response
184+
)
176185
return mock_client
177186

178187
def test_create_fail(self):
@@ -463,6 +472,62 @@ def test_export_model_artifact(
463472
progress_bar_description="Exporting model artifacts."
464473
)
465474

475+
@patch(
    "ads.model.service.oci_datascience_model.DataScienceWorkRequest.wait_work_request"
)
@patch("ads.model.service.oci_datascience_model.DataScienceWorkRequest.__init__")
def test_register_model_artifact_reference(
    self,
    mock_data_science_work_request,
    mock_wait_work_request,
    mock_client,
):
    """Tests register model artifact reference for a model in model catalog."""
    # Each row: (input URI, expected namespace, expected bucket name, expected prefix).
    # The rows cover a trailing slash, a nested prefix, a bucket root with a
    # trailing slash and a bucket root without one.
    cases = [
        ("oci://bucket1@namespace1/prefix1/", "namespace1", "bucket1", "prefix1"),
        (
            "oci://bucket2@namespace2/prefix2/subPrefix2",
            "namespace2",
            "bucket2",
            "prefix2/subPrefix2",
        ),
        ("oci://bucket3@namespace3/", "namespace3", "bucket3", None),
        ("oci://bucket4@namespace4", "namespace4", "bucket4", None),
    ]

    expected_references = []
    for _, namespace, bucket_name, prefix in cases:
        reference = OSSModelArtifactReferenceDetails()
        reference.namespace = namespace
        reference.bucket_name = bucket_name
        reference.prefix = prefix
        expected_references.append(reference)

    expected_details = RegisterModelArtifactReferenceDetails()
    expected_details.model_artifact_references = expected_references

    # `__init__` is patched, so it has to return None like a real initializer.
    mock_data_science_work_request.return_value = None
    with patch.object(OCIDataScienceModel, "client", mock_client):
        self.mock_model.register_model_artifact_reference(
            bucket_uri_list=[uri for uri, *_ in cases]
        )
        mock_client.register_model_artifact_reference.assert_called_with(
            model_id=self.mock_model.id,
            register_model_artifact_reference_details=expected_details,
        )
        mock_data_science_work_request.assert_called_with("work_request_id")
        mock_wait_work_request.assert_called_with(
            progress_bar_description="Registering model artifact references."
        )
530+
466531
def test_is_model_by_reference(self):
467532
"""Test to check if model is created by reference using custom metadata information"""
468533

0 commit comments

Comments
 (0)