Skip to content

Commit

Permalink
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' in…
Browse files Browse the repository at this point in the history
…to feature/PRD-668-better-input-artifacts-typing
  • Loading branch information
avishniakov authored Nov 6, 2024
2 parents ad7b3da + d0e2633 commit aeafc20
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 70 deletions.
7 changes: 7 additions & 0 deletions docs/book/user-guide/llmops-guide/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ In this guide, we'll explore various aspects of working with LLMs in ZenML, incl
* [Finetuning embeddings with Sentence Transformers](finetuning-embeddings/finetuning-embeddings-with-sentence-transformers.md)
* [Evaluating finetuned embeddings](finetuning-embeddings/evaluating-finetuned-embeddings.md)
* [Finetuning LLMs with ZenML](finetuning-llms/finetuning-llms.md)
* [Finetuning in 100 lines of code](finetuning-llms/finetuning-100-loc.md)
* [Why and when to finetune LLMs](finetuning-llms/why-and-when-to-finetune-llms.md)
* [Starter choices with finetuning](finetuning-llms/starter-choices-for-finetuning-llms.md)
* [Finetuning with 🤗 Accelerate](finetuning-llms/finetuning-with-accelerate.md)
* [Evaluation for finetuning](finetuning-llms/evaluation-for-finetuning.md)
* [Deploying finetuned models](finetuning-llms/deploying-finetuned-models.md)
* [Next steps](finetuning-llms/next-steps.md)

To follow along with the examples and tutorials in this guide, ensure you have a Python environment set up with ZenML installed. Familiarity with the concepts covered in the [Starter Guide](../starter-guide/README.md) and [Production Guide](../production-guide/README.md) is recommended.

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -478,5 +478,6 @@ module = [
"langchain_community.*",
"vllm.*",
"numba.*",
"uvloop.*",
]
ignore_missing_imports = true
183 changes: 124 additions & 59 deletions src/zenml/artifacts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,132 @@
# ----------


def _save_artifact_visualizations(
data: Any, materializer: "BaseMaterializer"
) -> List[ArtifactVisualizationRequest]:
"""Save artifact visualizations.
Args:
data: The data for which to save the visualizations.
materializer: The materializer that should be used to generate and
save the visualizations.
Returns:
List of requests for the saved visualizations.
"""
try:
visualizations = materializer.save_visualizations(data)
except Exception as e:
logger.warning("Failed to save artifact visualizations: %s", e)
return []

return [
ArtifactVisualizationRequest(
type=type,
uri=uri,
)
for uri, type in visualizations.items()
]


def _store_artifact_data_and_prepare_request(
data: Any,
name: str,
uri: str,
materializer_class: Type["BaseMaterializer"],
save_type: ArtifactSaveType,
version: Optional[Union[int, str]] = None,
tags: Optional[List[str]] = None,
store_metadata: bool = True,
store_visualizations: bool = True,
has_custom_name: bool = True,
metadata: Optional[Dict[str, "MetadataType"]] = None,
) -> ArtifactVersionRequest:
"""Store artifact data and prepare a request to the server.
Args:
data: The artifact data.
name: The artifact name.
uri: The artifact URI.
materializer_class: The materializer class to use for storing the
artifact data.
save_type: Save type of the artifact version.
version: The artifact version.
tags: Tags for the artifact version.
store_metadata: Whether to store metadata for the artifact version.
store_visualizations: Whether to store visualizations for the artifact
version.
has_custom_name: Whether the artifact has a custom name.
metadata: Metadata to store for the artifact version. This will be
ignored if `store_metadata` is set to `False`.
Returns:
Artifact version request for the artifact data that was stored.
"""
artifact_store = Client().active_stack.artifact_store
artifact_store.makedirs(uri)

materializer = materializer_class(uri=uri, artifact_store=artifact_store)
materializer.uri = materializer.uri.replace("\\", "/")

data_type = type(data)
materializer.validate_save_type_compatibility(data_type)
materializer.save(data)

visualizations = (
_save_artifact_visualizations(data=data, materializer=materializer)
if store_visualizations
else None
)

combined_metadata: Dict[str, "MetadataType"] = {}
if store_metadata:
try:
combined_metadata = materializer.extract_full_metadata(data)
except Exception as e:
logger.warning("Failed to extract materializer metadata: %s", e)

# Update with user metadata to potentially overwrite values coming from
# the materializer
combined_metadata.update(metadata or {})

artifact_version_request = ArtifactVersionRequest(
artifact_name=name,
version=version,
tags=tags,
type=materializer.ASSOCIATED_ARTIFACT_TYPE,
uri=materializer.uri,
materializer=source_utils.resolve(materializer.__class__),
data_type=source_utils.resolve(data_type),
user=Client().active_user.id,
workspace=Client().active_workspace.id,
artifact_store_id=artifact_store.id,
visualizations=visualizations,
has_custom_name=has_custom_name,
save_type=save_type,
metadata=validate_metadata(combined_metadata)
if combined_metadata
else None,
)

return artifact_version_request


def save_artifact(
data: Any,
name: str,
version: Optional[Union[int, str]] = None,
tags: Optional[List[str]] = None,
extract_metadata: bool = True,
include_visualizations: bool = True,
has_custom_name: bool = True,
user_metadata: Optional[Dict[str, "MetadataType"]] = None,
materializer: Optional["MaterializerClassOrSource"] = None,
uri: Optional[str] = None,
is_model_artifact: bool = False,
is_deployment_artifact: bool = False,
# TODO: remove these once external artifact does not use this function anymore
save_type: ArtifactSaveType = ArtifactSaveType.MANUAL,
has_custom_name: bool = True,
) -> "ArtifactVersionResponse":
"""Upload and publish an artifact.
Expand All @@ -108,8 +220,6 @@ def save_artifact(
tags: Tags to associate with the artifact.
extract_metadata: If artifact metadata should be extracted and returned.
include_visualizations: If artifact visualizations should be generated.
has_custom_name: If the artifact name is custom and should be listed in
the dashboard "Artifacts" tab.
user_metadata: User-provided metadata to store with the artifact.
materializer: The materializer to use for saving the artifact to the
artifact store.
Expand All @@ -119,6 +229,8 @@ def save_artifact(
is_model_artifact: If the artifact is a model artifact.
is_deployment_artifact: If the artifact is a deployment artifact.
save_type: The type of save operation that created the artifact version.
has_custom_name: If the artifact name is custom and should be listed in
the dashboard "Artifacts" tab.
Returns:
The saved artifact response.
Expand All @@ -129,11 +241,8 @@ def save_artifact(
from zenml.utils import source_utils

client = Client()

# Get the current artifact store
artifact_store = client.active_stack.artifact_store

# Build and check the artifact URI
if not uri:
uri = os.path.join("custom_artifacts", name, str(uuid4()))
if not uri.startswith(artifact_store.path):
Expand All @@ -147,9 +256,7 @@ def save_artifact(
uri=uri,
name=name,
)
artifact_store.makedirs(uri)

# Find and initialize the right materializer class
if isinstance(materializer, type):
materializer_class = materializer
elif materializer:
Expand All @@ -158,61 +265,19 @@ def save_artifact(
)
else:
materializer_class = materializer_registry[type(data)]
materializer_object = materializer_class(uri)

# Force URIs to have forward slashes
materializer_object.uri = materializer_object.uri.replace("\\", "/")

# Save the artifact to the artifact store
data_type = type(data)
materializer_object.validate_save_type_compatibility(data_type)
materializer_object.save(data)

# Save visualizations of the artifact
visualizations: List[ArtifactVisualizationRequest] = []
if include_visualizations:
try:
vis_data = materializer_object.save_visualizations(data)
for vis_uri, vis_type in vis_data.items():
vis_model = ArtifactVisualizationRequest(
type=vis_type,
uri=vis_uri,
)
visualizations.append(vis_model)
except Exception as e:
logger.warning(
f"Failed to save visualization for output artifact '{name}': "
f"{e}"
)

# Save metadata of the artifact
artifact_metadata: Dict[str, "MetadataType"] = {}
if extract_metadata:
try:
artifact_metadata = materializer_object.extract_full_metadata(data)
artifact_metadata.update(user_metadata or {})
except Exception as e:
logger.warning(
f"Failed to extract metadata for output artifact '{name}': {e}"
)

artifact_version_request = ArtifactVersionRequest(
artifact_name=name,
artifact_version_request = _store_artifact_data_and_prepare_request(
data=data,
name=name,
uri=uri,
materializer_class=materializer_class,
save_type=save_type,
version=version,
tags=tags,
type=materializer_object.ASSOCIATED_ARTIFACT_TYPE,
save_type=save_type,
uri=materializer_object.uri,
materializer=source_utils.resolve(materializer_object.__class__),
data_type=source_utils.resolve(data_type),
user=Client().active_user.id,
workspace=Client().active_workspace.id,
artifact_store_id=artifact_store.id,
visualizations=visualizations,
store_metadata=extract_metadata,
store_visualizations=include_visualizations,
has_custom_name=has_custom_name,
metadata=validate_metadata(artifact_metadata)
if artifact_metadata
else None,
metadata=user_metadata,
)
artifact_version = client.zen_store.create_artifact_version(
artifact_version=artifact_version_request
Expand Down
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
ARTIFACT_VERSIONS = "/artifact_versions"
ARTIFACT_VISUALIZATIONS = "/artifact_visualizations"
AUTH = "/auth"
BATCH = "/batch"
CODE_REFERENCES = "/code_references"
CODE_REPOSITORIES = "/code_repositories"
COMPONENT_TYPES = "/component-types"
Expand Down
22 changes: 13 additions & 9 deletions src/zenml/orchestrators/step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
)

from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact
from zenml.artifacts.utils import save_artifact
from zenml.artifacts.utils import _store_artifact_data_and_prepare_request
from zenml.client import Client
from zenml.config.step_configurations import StepConfiguration
from zenml.config.step_run_info import StepRunInfo
from zenml.constants import (
Expand Down Expand Up @@ -537,7 +538,7 @@ def _store_output_artifacts(
The IDs of the published output artifacts.
"""
step_context = get_step_context()
output_artifacts: Dict[str, "ArtifactVersionResponse"] = {}
artifact_requests = []

for output_name, return_value in output_data.items():
data_type = type(return_value)
Expand Down Expand Up @@ -598,22 +599,25 @@ def _store_output_artifacts(
# Get full set of tags
tags = step_context.get_output_tags(output_name)

artifact = save_artifact(
artifact_request = _store_artifact_data_and_prepare_request(
name=artifact_name,
data=return_value,
materializer=materializer_class,
materializer_class=materializer_class,
uri=uri,
extract_metadata=artifact_metadata_enabled,
include_visualizations=artifact_visualization_enabled,
store_metadata=artifact_metadata_enabled,
store_visualizations=artifact_visualization_enabled,
has_custom_name=has_custom_name,
version=version,
tags=tags,
user_metadata=user_metadata,
save_type=ArtifactSaveType.STEP_OUTPUT,
metadata=user_metadata,
)
output_artifacts[output_name] = artifact
artifact_requests.append(artifact_request)

return output_artifacts
responses = Client().zen_store.batch_create_artifact_versions(
artifact_requests
)
return dict(zip(output_data.keys(), responses))

def load_and_run_hook(
self,
Expand Down
44 changes: 43 additions & 1 deletion src/zenml/zen_server/rbac/endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""High-level helper functions to write endpoints with RBAC."""

from typing import Any, Callable, TypeVar, Union
from typing import Any, Callable, List, TypeVar, Union
from uuid import UUID

from pydantic import BaseModel
Expand Down Expand Up @@ -96,6 +96,48 @@ def verify_permissions_and_create_entity(
return created


def verify_permissions_and_batch_create_entity(
batch: List[AnyRequest],
resource_type: ResourceType,
create_method: Callable[[List[AnyRequest]], List[AnyResponse]],
) -> List[AnyResponse]:
"""Verify permissions and create a batch of entities if authorized.
Args:
batch: The batch to create.
resource_type: The resource type of the entities to create.
create_method: The method to create the entities.
Raises:
IllegalOperationError: If the request model has a different owner then
the currently authenticated user.
RuntimeError: If the resource type is usage-tracked.
Returns:
The created entities.
"""
auth_context = get_auth_context()
assert auth_context

for request_model in batch:
if isinstance(request_model, UserScopedRequest):
if request_model.user != auth_context.user.id:
raise IllegalOperationError(
f"Not allowed to create resource '{resource_type}' for a "
"different user."
)

verify_permission(resource_type=resource_type, action=Action.CREATE)

if resource_type in REPORTABLE_RESOURCES:
raise RuntimeError(
"Batch requests are currently not possible with usage-tracked features."
)

created = create_method(batch)
return created


def verify_permissions_and_get_entity(
id: UUIDOrStr,
get_method: Callable[[UUIDOrStr], AnyResponse],
Expand Down
Loading

0 comments on commit aeafc20

Please sign in to comment.