From b90846533759ea82a34e1b885d5b3b05b4d49b9d Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 6 Nov 2024 14:43:41 +0100 Subject: [PATCH 1/3] Add artifact version batch request (#3164) * Add batch request endpoint for artifact version creation (cherry picked from commit 78fd508ccb6288e618324810441ac62ad110146e) * Fix some errors (cherry picked from commit c392d93ab2789d3780a5551e0c1c54f7a5935ac1) * Fix method call after merge --- src/zenml/artifacts/utils.py | 178 ++++++++++++------ src/zenml/constants.py | 1 + src/zenml/orchestrators/step_runner.py | 23 ++- src/zenml/zen_server/rbac/endpoint_utils.py | 44 ++++- .../routers/artifact_version_endpoints.py | 28 ++- src/zenml/zen_stores/rest_zen_store.py | 52 +++++ src/zenml/zen_stores/sql_zen_store.py | 16 ++ src/zenml/zen_stores/zen_store_interface.py | 13 ++ 8 files changed, 285 insertions(+), 70 deletions(-) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 5175ab456cb..6a111752b56 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -82,6 +82,114 @@ # ---------- +def _save_artifact_visualizations( + data: Any, materializer: "BaseMaterializer" +) -> List[ArtifactVisualizationRequest]: + """Save artifact visualizations. + + Args: + data: The data for which to save the visualizations. + materializer: The materializer that should be used to generate and + save the visualizations. + + Returns: + List of requests for the saved visualizations. + """ + try: + visualizations = materializer.save_visualizations(data) + except Exception as e: + logger.warning("Failed to save artifact visualizations: %s", e) + return [] + + return [ + ArtifactVisualizationRequest( + type=type, + uri=uri, + ) + for uri, type in visualizations.items() + ] + + +def _store_artifact_data_and_prepare_request( + data: Any, + name: str, + uri: str, + materializer_class: Type["BaseMaterializer"], + version: Optional[Union[int, str]] = None, + tags: Optional[List[str]] = None, + store_metadata: bool = True, + store_visualizations: bool = True, + has_custom_name: bool = True, + metadata: Optional[Dict[str, "MetadataType"]] = None, +) -> ArtifactVersionRequest: + """Store artifact data and prepare a request to the server. + + Args: + data: The artifact data. + name: The artifact name. + uri: The artifact URI. + materializer_class: The materializer class to use for storing the + artifact data. + version: The artifact version. + tags: Tags for the artifact version. + store_metadata: Whether to store metadata for the artifact version. + store_visualizations: Whether to store visualizations for the artifact + version. + has_custom_name: Whether the artifact has a custom name. + metadata: Metadata to store for the artifact version. This will be + ignored if `store_metadata` is set to `False`. + + Returns: + Artifact version request for the artifact data that was stored. + """ + artifact_store = Client().active_stack.artifact_store + artifact_store.makedirs(uri) + + materializer = materializer_class(uri=uri, artifact_store=artifact_store) + materializer.uri = materializer.uri.replace("\\", "/") + + data_type = type(data) + materializer.validate_save_type_compatibility(data_type) + materializer.save(data) + + visualizations = ( + _save_artifact_visualizations(data=data, materializer=materializer) + if store_visualizations + else None + ) + + combined_metadata: Dict[str, "MetadataType"] = {} + if store_metadata: + try: + combined_metadata = materializer.extract_full_metadata(data) + except Exception as e: + logger.warning("Failed to extract materializer metadata: %s", e) + + # Update with user metadata to potentially overwrite values coming from + # the materializer + combined_metadata.update(metadata or {}) + + artifact_version_request = ArtifactVersionRequest( + artifact_name=name, + version=version, + tags=tags, + type=materializer.ASSOCIATED_ARTIFACT_TYPE, + uri=materializer.uri, + materializer=source_utils.resolve(materializer.__class__), + data_type=source_utils.resolve(data_type), + user=Client().active_user.id, + workspace=Client().active_workspace.id, + artifact_store_id=artifact_store.id, + visualizations=visualizations, + has_custom_name=has_custom_name, + metadata=validate_metadata(combined_metadata) + if combined_metadata + else None, + ) + + return artifact_version_request + + def save_artifact( data: Any, name: str, @@ -89,13 +197,14 @@ def save_artifact( tags: Optional[List[str]] = None, extract_metadata: bool = True, include_visualizations: bool = True, - has_custom_name: bool = True, user_metadata: Optional[Dict[str, "MetadataType"]] = None, materializer: Optional["MaterializerClassOrSource"] = None, uri: Optional[str] = None, is_model_artifact: bool = False, is_deployment_artifact: bool = False, + # TODO: remove these once external artifact does not use this function anymore manual_save: bool = True, + has_custom_name: bool = True, ) -> "ArtifactVersionResponse": """Upload and publish an artifact. @@ -107,8 +216,6 @@ def save_artifact( tags: Tags to associate with the artifact. extract_metadata: If artifact metadata should be extracted and returned. include_visualizations: If artifact visualizations should be generated. - has_custom_name: If the artifact name is custom and should be listed in - the dashboard "Artifacts" tab. user_metadata: User-provided metadata to store with the artifact. materializer: The materializer to use for saving the artifact to the artifact store. @@ -119,6 +226,8 @@ def save_artifact( is_deployment_artifact: If the artifact is a deployment artifact. manual_save: If this function is called manually and should therefore link the artifact to the current step run. + has_custom_name: If the artifact name is custom and should be listed in + the dashboard "Artifacts" tab. Returns: The saved artifact response. @@ -129,11 +238,8 @@ def save_artifact( from zenml.utils import source_utils client = Client() - - # Get the current artifact store artifact_store = client.active_stack.artifact_store - # Build and check the artifact URI if not uri: uri = os.path.join("custom_artifacts", name, str(uuid4())) if not uri.startswith(artifact_store.path): @@ -147,9 +253,7 @@ def save_artifact( uri=uri, name=name, ) - artifact_store.makedirs(uri) - # Find and initialize the right materializer class if isinstance(materializer, type): materializer_class = materializer elif materializer: @@ -158,60 +262,18 @@ def save_artifact( ) else: materializer_class = materializer_registry[type(data)] - materializer_object = materializer_class(uri) - - # Force URIs to have forward slashes - materializer_object.uri = materializer_object.uri.replace("\\", "/") - - # Save the artifact to the artifact store - data_type = type(data) - materializer_object.validate_save_type_compatibility(data_type) - materializer_object.save(data) - # Save visualizations of the artifact - visualizations: List[ArtifactVisualizationRequest] = [] - if include_visualizations: - try: - vis_data = materializer_object.save_visualizations(data) - for vis_uri, vis_type in vis_data.items(): - vis_model = ArtifactVisualizationRequest( - type=vis_type, - uri=vis_uri, - ) - visualizations.append(vis_model) - except Exception as e: - logger.warning( - f"Failed to save visualization for output artifact '{name}': " - f"{e}" - ) - - # Save metadata of the artifact - artifact_metadata: Dict[str, "MetadataType"] = {} - if extract_metadata: - try: - artifact_metadata = materializer_object.extract_full_metadata(data) - artifact_metadata.update(user_metadata or {}) - except Exception as e: - logger.warning( - f"Failed to extract metadata for output artifact '{name}': {e}" - ) - - artifact_version_request = ArtifactVersionRequest( - artifact_name=name, + artifact_version_request = _store_artifact_data_and_prepare_request( + data=data, + name=name, + uri=uri, + materializer_class=materializer_class, version=version, tags=tags, - type=materializer_object.ASSOCIATED_ARTIFACT_TYPE, - uri=materializer_object.uri, - materializer=source_utils.resolve(materializer_object.__class__), - data_type=source_utils.resolve(data_type), - user=Client().active_user.id, - workspace=Client().active_workspace.id, - artifact_store_id=artifact_store.id, - visualizations=visualizations, + store_metadata=extract_metadata, + store_visualizations=include_visualizations, has_custom_name=has_custom_name, - metadata=validate_metadata(artifact_metadata) - if artifact_metadata - else None, + metadata=user_metadata, ) artifact_version = client.zen_store.create_artifact_version( artifact_version=artifact_version_request diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 1474574cbe7..e8fb57a078a 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -338,6 +338,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ARTIFACT_VERSIONS = "/artifact_versions" ARTIFACT_VISUALIZATIONS = "/artifact_visualizations" AUTH = "/auth" +BATCH = "/batch" CODE_REFERENCES = "/code_references" CODE_REPOSITORIES = "/code_repositories" COMPONENT_TYPES = "/component-types" diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 84f9525b40e..b5cccc43f88 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -28,7 +28,8 @@ ) from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact -from zenml.artifacts.utils import save_artifact +from zenml.artifacts.utils import _store_artifact_data_and_prepare_request +from zenml.client import Client from zenml.config.step_configurations import StepConfiguration from zenml.config.step_run_info import StepRunInfo from zenml.constants import ( @@ -534,7 +535,7 @@ def _store_output_artifacts( The IDs of the published output artifacts. """ step_context = get_step_context() - output_artifacts: Dict[str, "ArtifactVersionResponse"] = {} + artifact_requests = [] for output_name, return_value in output_data.items(): data_type = type(return_value) @@ -595,22 +596,24 @@ def _store_output_artifacts( # Get full set of tags tags = step_context.get_output_tags(output_name) - artifact = save_artifact( + artifact_request = _store_artifact_data_and_prepare_request( name=artifact_name, data=return_value, - materializer=materializer_class, + materializer_class=materializer_class, uri=uri, - extract_metadata=artifact_metadata_enabled, - include_visualizations=artifact_visualization_enabled, + store_metadata=artifact_metadata_enabled, + store_visualizations=artifact_visualization_enabled, has_custom_name=has_custom_name, version=version, tags=tags, - user_metadata=user_metadata, - manual_save=False, + metadata=user_metadata, ) - output_artifacts[output_name] = artifact + artifact_requests.append(artifact_request) - return output_artifacts + responses = Client().zen_store.batch_create_artifact_versions( + artifact_requests + ) + return dict(zip(output_data.keys(), responses)) def load_and_run_hook( self, diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index 1f8abe8d6ea..81c956d3543 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """High-level helper functions to write endpoints with RBAC.""" -from typing import Any, Callable, TypeVar, Union +from typing import Any, Callable, List, TypeVar, Union from uuid import UUID from pydantic import BaseModel @@ -96,6 +96,48 @@ def verify_permissions_and_create_entity( return created +def verify_permissions_and_batch_create_entity( + batch: List[AnyRequest], + resource_type: ResourceType, + create_method: Callable[[List[AnyRequest]], List[AnyResponse]], +) -> List[AnyResponse]: + """Verify permissions and create a batch of entities if authorized. + + Args: + batch: The batch to create. + resource_type: The resource type of the entities to create. + create_method: The method to create the entities. + + Raises: + IllegalOperationError: If the request model has a different owner then + the currently authenticated user. + RuntimeError: If the resource type is usage-tracked. + + Returns: + The created entities. + """ + auth_context = get_auth_context() + assert auth_context + + for request_model in batch: + if isinstance(request_model, UserScopedRequest): + if request_model.user != auth_context.user.id: + raise IllegalOperationError( + f"Not allowed to create resource '{resource_type}' for a " + "different user." + ) + + verify_permission(resource_type=resource_type, action=Action.CREATE) + + if resource_type in REPORTABLE_RESOURCES: + raise RuntimeError( + "Batch requests are currently not possible with usage-tracked features." + ) + + created = create_method(batch) + return created + + def verify_permissions_and_get_entity( id: UUIDOrStr, get_method: Callable[[UUIDOrStr], AnyResponse], diff --git a/src/zenml/zen_server/routers/artifact_version_endpoints.py b/src/zenml/zen_server/routers/artifact_version_endpoints.py index 1a65a5ada8e..eda381c233d 100644 --- a/src/zenml/zen_server/routers/artifact_version_endpoints.py +++ b/src/zenml/zen_server/routers/artifact_version_endpoints.py @@ -13,12 +13,13 @@ # permissions and limitations under the License. """Endpoint definitions for artifact versions.""" +from typing import List from uuid import UUID from fastapi import APIRouter, Depends, Security from zenml.artifacts.utils import load_artifact_visualization -from zenml.constants import API, ARTIFACT_VERSIONS, VERSION_1, VISUALIZE +from zenml.constants import API, ARTIFACT_VERSIONS, BATCH, VERSION_1, VISUALIZE from zenml.models import ( ArtifactVersionFilter, ArtifactVersionRequest, @@ -30,6 +31,7 @@ from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_batch_create_entity, verify_permissions_and_create_entity, verify_permissions_and_delete_entity, verify_permissions_and_get_entity, @@ -118,6 +120,30 @@ def create_artifact_version( ) +@artifact_version_router.post( + BATCH, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def batch_create_artifact_version( + artifact_versions: List[ArtifactVersionRequest], + _: AuthContext = Security(authorize), +) -> List[ArtifactVersionResponse]: + """Create a batch of artifact versions. + + Args: + artifact_versions: The artifact versions to create. + + Returns: + The created artifact versions. + """ + return verify_permissions_and_batch_create_entity( + batch=artifact_versions, + resource_type=ResourceType.ARTIFACT_VERSION, + create_method=zen_store().batch_create_artifact_versions, + ) + + @artifact_version_router.get( "/{artifact_version_id}", response_model=ArtifactVersionResponse, diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index dcead29fe06..0b4836c3f7c 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -57,6 +57,7 @@ ARTIFACT_VERSIONS, ARTIFACT_VISUALIZATIONS, ARTIFACTS, + BATCH, CODE_REFERENCES, CODE_REPOSITORIES, CONFIG, @@ -991,6 +992,23 @@ def create_artifact_version( route=ARTIFACT_VERSIONS, ) + def batch_create_artifact_versions( + self, artifact_versions: List[ArtifactVersionRequest] + ) -> List[ArtifactVersionResponse]: + """Creates a batch of artifact versions. + + Args: + artifact_versions: The artifact versions to create. + + Returns: + The created artifact versions. + """ + return self._batch_create_resources( + resources=artifact_versions, + response_model=ArtifactVersionResponse, + route=ARTIFACT_VERSIONS, + ) + def get_artifact_version( self, artifact_version_id: UUID, hydrate: bool = True ) -> ArtifactVersionResponse: @@ -4518,6 +4536,40 @@ def _create_resource( return response_model.model_validate(response_body) + def _batch_create_resources( + self, + resources: List[AnyRequest], + response_model: Type[AnyResponse], + route: str, + params: Optional[Dict[str, Any]] = None, + ) -> List[AnyResponse]: + """Create a new batch of resources. + + Args: + resources: The resources to create. + response_model: The response model of an individual resource. + route: The resource REST route to use. + params: Optional query parameters to pass to the endpoint. + + Returns: + List of response models. + """ + json_data = [ + resource.model_dump(mode="json") for resource in resources + ] + response = self._request( + "POST", + self.url + API + VERSION_1 + route + BATCH, + json=json_data, + params=params, + ) + assert isinstance(response, list) + + return [ + response_model.model_validate(model_data) + for model_data in response + ] + def _create_workspace_scoped_resource( self, resource: AnyWorkspaceScopedRequest, diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 7001f0b1529..4d6b735f7fd 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -2915,6 +2915,22 @@ def create_artifact_version( include_metadata=True, include_resources=True ) + def batch_create_artifact_versions( + self, artifact_versions: List[ArtifactVersionRequest] + ) -> List[ArtifactVersionResponse]: + """Creates a batch of artifact versions. + + Args: + artifact_versions: The artifact versions to create. + + Returns: + The created artifact versions. + """ + return [ + self.create_artifact_version(artifact_version) + for artifact_version in artifact_versions + ] + def get_artifact_version( self, artifact_version_id: UUID, hydrate: bool = True ) -> ArtifactVersionResponse: diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 6f0cb7b496d..ea2e53a06ca 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -663,6 +663,19 @@ def create_artifact_version( The created artifact version. """ + @abstractmethod + def batch_create_artifact_versions( + self, artifact_versions: List[ArtifactVersionRequest] + ) -> List[ArtifactVersionResponse]: + """Creates a batch of artifact versions. + + Args: + artifact_versions: The artifact versions to create. + + Returns: + The created artifact versions. + """ + @abstractmethod def get_artifact_version( self, artifact_version_id: UUID, hydrate: bool = True From d9f06a923537efc398567545c0ee3d9b944b42e1 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Wed, 6 Nov 2024 15:56:30 +0100 Subject: [PATCH 2/3] add missing section (#3172) --- docs/book/user-guide/llmops-guide/README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/book/user-guide/llmops-guide/README.md b/docs/book/user-guide/llmops-guide/README.md index 106c8b266bd..4197eebddc7 100644 --- a/docs/book/user-guide/llmops-guide/README.md +++ b/docs/book/user-guide/llmops-guide/README.md @@ -32,6 +32,13 @@ In this guide, we'll explore various aspects of working with LLMs in ZenML, incl * [Finetuning embeddings with Sentence Transformers](finetuning-embeddings/finetuning-embeddings-with-sentence-transformers.md) * [Evaluating finetuned embeddings](finetuning-embeddings/evaluating-finetuned-embeddings.md) * [Finetuning LLMs with ZenML](finetuning-llms/finetuning-llms.md) + * [Finetuning in 100 lines of code](finetuning-llms/finetuning-100-loc.md) + * [Why and when to finetune LLMs](finetuning-llms/why-and-when-to-finetune-llms.md) + * [Starter choices with finetuning](finetuning-llms/starter-choices-for-finetuning-llms.md) + * [Finetuning with 🤗 Accelerate](finetuning-llms/finetuning-with-accelerate.md) + * [Evaluation for finetuning](finetuning-llms/evaluation-for-finetuning.md) + * [Deploying finetuned models](finetuning-llms/deploying-finetuned-models.md) + * [Next steps](finetuning-llms/next-steps.md) To follow along with the examples and tutorials in this guide, ensure you have a Python environment set up with ZenML installed. Familiarity with the concepts covered in the [Starter Guide](../starter-guide/README.md) and [Production Guide](../production-guide/README.md) is recommended. From aeeb315af429879b85cca330e8d59625e478e2f2 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 6 Nov 2024 17:00:04 +0100 Subject: [PATCH 3/3] fix uvloop mypy (#3174) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 07948240e91..994407ddf7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -478,5 +478,6 @@ module = [ "langchain_community.*", "vllm.*", "numba.*", + "uvloop.*", ] ignore_missing_imports = true