Multiple output versions for a step outputs #3072

Merged
Changes from 23 commits
38 commits
9bdb5f1
multi-versioned outputs
avishniakov Oct 10, 2024
4dd7e90
more artifact save types
avishniakov Oct 10, 2024
eb7f691
fix for the llm template
avishniakov Oct 10, 2024
49f0f24
Auto-update of LLM Finetuning template
actions-user Oct 10, 2024
5d81b5d
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Oct 10, 2024
ac9a7bd
Auto-update of Starter template
actions-user Oct 10, 2024
9607df5
Auto-update of E2E template
actions-user Oct 10, 2024
fc19dc3
fix tests notation
avishniakov Oct 10, 2024
9f04da1
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' of…
avishniakov Oct 10, 2024
feb3f53
Refactor artifact saving logic to use save types
avishniakov Oct 17, 2024
5557837
Refactor artifact saving logic to use save types
avishniakov Oct 17, 2024
eac4aea
Remove unneeded TODO
avishniakov Oct 17, 2024
e62d8eb
Refactor artifact saving logic to use list instead of set for output …
avishniakov Oct 17, 2024
f4bf134
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Oct 17, 2024
9a8cf20
Refactor artifact saving logic to use outputs instead of saved_artifa…
avishniakov Oct 17, 2024
ecb9ff0
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Oct 17, 2024
5683081
Auto-update of LLM Finetuning template
actions-user Oct 17, 2024
f02fee1
Refactor artifact saving logic to use 'external' as the default save …
avishniakov Oct 18, 2024
0a2dfa1
Refactor artifact saving logic to use outputs instead of saved_artifa…
avishniakov Oct 18, 2024
77b6a79
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Oct 18, 2024
a704f95
Refactor StepRunRequestFactory to correctly assign output artifacts
avishniakov Oct 18, 2024
e991fd5
Merge branch 'feature/PRD-663-multiple-output-versions-for-a-step' of…
avishniakov Oct 18, 2024
a2d735c
mypy
avishniakov Oct 18, 2024
56a37a8
Remove unused arg
schustmi Oct 30, 2024
b5c8ad0
Improve input resolution
schustmi Oct 30, 2024
3d2443b
Rename save type
schustmi Oct 30, 2024
aef08fc
Only apply artifact config to step outputs
schustmi Oct 30, 2024
239122e
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
schustmi Oct 30, 2024
eaf2958
Fix alembic order
schustmi Oct 30, 2024
fc0c141
Fix DB migration
schustmi Oct 30, 2024
1ecefd8
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
schustmi Oct 31, 2024
9a39cd4
Fix alembic order
schustmi Oct 31, 2024
3193c20
Refactor artifact saving in cacheable_multiple_versioned_producer
avishniakov Nov 5, 2024
3797a07
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Nov 5, 2024
e3e9320
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Nov 6, 2024
5cf1089
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Nov 6, 2024
58e75f2
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
schustmi Nov 6, 2024
d0e2633
Merge branch 'develop' into feature/PRD-663-multiple-output-versions-…
avishniakov Nov 6, 2024
2 changes: 1 addition & 1 deletion .github/workflows/update-templates-to-examples.yml
@@ -261,7 +261,7 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
ref-zenml: ${{ github.ref }}
- ref-template: 2024.09.24 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py
+ ref-template: 2024.10.10 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py
- name: Clean-up
run: |
rm -rf ./local_checkout
2 changes: 1 addition & 1 deletion examples/llm_finetuning/.copier-answers.yml
@@ -1,5 +1,5 @@
# Changes here will be overwritten by Copier
- _commit: 2024.09.24
+ _commit: 2024.08.29-1-g7af7693
_src_path: gh:zenml-io/template-llm-finetuning
bf16: true
cuda_version: cuda11.8
2 changes: 1 addition & 1 deletion examples/llm_finetuning/steps/log_metadata.py
@@ -34,7 +34,7 @@ def log_metadata_from_step_artifact(

context = get_step_context()
metadata_dict: Dict[str, Any] = (
- context.pipeline_run.steps[step_name].outputs[artifact_name].load()
+ context.pipeline_run.steps[step_name].outputs[artifact_name][0].load()
)

metadata = {artifact_name: metadata_dict}
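
The added `[0]` reflects that each entry of `outputs` now holds a list of artifact versions rather than a single one. A minimal sketch of a defensive accessor for callers that expect exactly one version; `FakeArtifactVersion` and `single_output` are hypothetical names used for illustration, not ZenML APIs:

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class FakeArtifactVersion:
    """Hypothetical stand-in for an artifact version response."""

    value: Any

    def load(self) -> Any:
        return self.value


def single_output(
    outputs: Dict[str, List[FakeArtifactVersion]], name: str
) -> FakeArtifactVersion:
    """Return the only version stored under `name`, or raise if ambiguous."""
    versions = outputs[name]
    if len(versions) != 1:
        raise ValueError(
            f"Expected exactly one version for '{name}', got {len(versions)}."
        )
    return versions[0]


# Made-up data: one output name holding a single version.
outputs = {"metadata": [FakeArtifactVersion({"accuracy": 0.9})]}
metadata_dict = single_output(outputs, "metadata").load()
```

Indexing with `[0]` directly, as the template does, silently picks the first version; raising on ambiguity is a design choice of this sketch only.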
3 changes: 2 additions & 1 deletion src/zenml/artifacts/external_artifact.py
@@ -23,6 +23,7 @@
ExternalArtifactConfiguration,
)
from zenml.config.source import Source
+ from zenml.enums import ArtifactSaveType
from zenml.logger import get_logger
from zenml.materializers.base_materializer import BaseMaterializer

@@ -114,7 +115,7 @@ def upload_by_value(self) -> UUID:
materializer=self.materializer,
uri=uri,
has_custom_name=False,
- manual_save=False,
+ save_type=ArtifactSaveType.EXTERNAL,
)

# To avoid duplicate uploads, switch to referencing the uploaded
19 changes: 10 additions & 9 deletions src/zenml/artifacts/utils.py
@@ -42,6 +42,7 @@
MODEL_METADATA_YAML_FILE_NAME,
)
from zenml.enums import (
+ ArtifactSaveType,
ArtifactType,
ExecutionStatus,
MetadataResourceTypes,
@@ -100,7 +101,7 @@ def save_artifact(
uri: Optional[str] = None,
is_model_artifact: bool = False,
is_deployment_artifact: bool = False,
- manual_save: bool = True,
+ save_type: ArtifactSaveType = ArtifactSaveType.MANUAL,
) -> "ArtifactVersionResponse":
"""Upload and publish an artifact.

@@ -122,8 +123,7 @@ def save_artifact(
`custom_artifacts/{name}/{version}`.
is_model_artifact: If the artifact is a model artifact.
is_deployment_artifact: If the artifact is a deployment artifact.
- manual_save: If this function is called manually and should therefore
- link the artifact to the current step run.
+ save_type: The type of save operation that created the artifact version.

Returns:
The saved artifact response.
@@ -150,7 +150,7 @@ def save_artifact(
if not uri.startswith(artifact_store.path):
uri = os.path.join(artifact_store.path, uri)

- if manual_save:
+ if save_type == ArtifactSaveType.MANUAL:
# This check is only necessary for manual saves as we already check
# it when creating the directory for step output artifacts
_check_if_artifact_with_given_uri_already_registered(
@@ -224,6 +224,7 @@ def _create_version(
artifact_store_id=artifact_store.id,
visualizations=visualizations,
has_custom_name=has_custom_name,
+ save_type=save_type,
)
try:
return client.zen_store.create_artifact_version(
@@ -245,7 +246,7 @@ def _create_version(
resource_type=MetadataResourceTypes.ARTIFACT_VERSION,
)

- if manual_save:
+ if save_type == ArtifactSaveType.MANUAL:
_link_artifact_version_to_the_step_and_model(
artifact_version=response,
is_model_artifact=is_model_artifact,
@@ -326,6 +327,7 @@ def _create_version(
workspace=Client().active_workspace.id,
artifact_store_id=artifact_store.id,
has_custom_name=has_custom_name,
+ save_type=ArtifactSaveType.PREEXISTING,
)
try:
return client.zen_store.create_artifact_version(
@@ -625,7 +627,8 @@ def get_artifacts_versions_of_pipeline_run(
artifact_versions: List["ArtifactVersionResponse"] = []
for step in pipeline_run.steps.values():
if not only_produced or step.status == ExecutionStatus.COMPLETED:
- artifact_versions.extend(step.outputs.values())
+ for output in step.outputs.values():
+ artifact_versions.extend(output)
return artifact_versions
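
Because each output name now maps to a list, collecting every artifact version of a run means flattening a dict of lists. A sketch of the same flattening with `itertools.chain`, where plain strings stand in for artifact versions and the data is made up:

```python
from itertools import chain
from typing import Dict, List

# Hypothetical step outputs: output name -> list of versions.
step_outputs: Dict[str, List[str]] = {
    "model": ["model_v1", "model_v2"],
    "metrics": ["metrics_v1"],
}

# Equivalent to the nested loop in the diff: append every list in turn.
artifact_versions: List[str] = list(
    chain.from_iterable(step_outputs.values())
)
```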


@@ -787,9 +790,7 @@ def _link_artifact_version_to_the_step_and_model(
client.zen_store.update_run_step(
step_run_id=step_run.id,
step_run_update=StepRunUpdate(
- saved_artifact_versions={
- artifact_version.artifact.name: artifact_version.id
- }
+ outputs={artifact_version.artifact.name: artifact_version.id}
),
)
error_message = "model"
2 changes: 1 addition & 1 deletion src/zenml/cli/base.py
@@ -91,7 +91,7 @@ def copier_github_url(self) -> str:
),
llm_finetuning=ZenMLProjectTemplateLocation(
github_url="zenml-io/template-llm-finetuning",
- github_tag="2024.09.24", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
+ github_tag="2024.10.10", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
),
)

10 changes: 7 additions & 3 deletions src/zenml/enums.py
@@ -38,11 +38,15 @@ class StepRunInputArtifactType(StrEnum):
MANUAL = "manual" # manually loaded via `zenml.load_artifact()`


- class StepRunOutputArtifactType(StrEnum):
- """All possible types of a step run output artifact."""
+ class ArtifactSaveType(StrEnum):
+ """All possible method types of how artifact versions can be saved."""

- DEFAULT = "default" # output of the current step
+ STEP_OUTPUT = "default" # output of the current step
MANUAL = "manual" # manually saved via `zenml.save_artifact()`
+ PREEXISTING = "preexisting" # register via `zenml.register_artifact()`
+ EXTERNAL = (
+ "external" # saved via `zenml.ExternalArtifact.upload_by_value()`
+ )


class VisualizationType(StrEnum):
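
How the save type takes over from the old `manual_save` flag can be sketched with a stand-in enum. The member values mirror this diff as shown; `post_save_actions` is a hypothetical helper that only restates the two `save_type == ArtifactSaveType.MANUAL` conditions from `save_artifact`, and the action names are illustrative labels:

```python
from enum import Enum
from typing import List


class ArtifactSaveType(str, Enum):
    """Stand-in for the enum in this diff (ZenML uses its own StrEnum)."""

    STEP_OUTPUT = "default"
    MANUAL = "manual"
    PREEXISTING = "preexisting"
    EXTERNAL = "external"


def post_save_actions(save_type: ArtifactSaveType) -> List[str]:
    """List the extra work that is performed only for manual saves."""
    if save_type == ArtifactSaveType.MANUAL:
        return ["check_uri_not_registered", "link_to_step_and_model"]
    return []


manual_actions = post_save_actions(ArtifactSaveType.MANUAL)
external_actions = post_save_actions(ArtifactSaveType.EXTERNAL)
```

An enum leaves room for further save types without adding more boolean flags to the `save_artifact` signature.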
@@ -87,41 +87,44 @@ def visualize(
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
for _, artifact_view in object.outputs.items():
# filter out anything but model artifacts
if artifact_view.type == ArtifactType.MODEL:
logdir = os.path.dirname(artifact_view.uri)

# first check if a TensorBoard server is already running for
# the same logdir location and use that one
running_server = self.find_running_tensorboard_server(logdir)
if running_server:
self.visualize_tensorboard(running_server.port, height)
return

if sys.platform == "win32":
# Daemon service functionality is currently not supported
# on Windows
print(
"You can run:\n"
f"[italic green] tensorboard --logdir {logdir}"
"[/italic green]\n"
"...to visualize the TensorBoard logs for your trained model."
for output in object.outputs.values():
for artifact_view in output:
# filter out anything but model artifacts
if artifact_view.type == ArtifactType.MODEL:
logdir = os.path.dirname(artifact_view.uri)

# first check if a TensorBoard server is already running for
# the same logdir location and use that one
running_server = self.find_running_tensorboard_server(
logdir
)
else:
# start a new TensorBoard server
service = TensorboardService(
TensorboardServiceConfig(
logdir=logdir,
name=f"zenml-tensorboard-{logdir}",
if running_server:
self.visualize_tensorboard(running_server.port, height)
return

if sys.platform == "win32":
# Daemon service functionality is currently not supported
# on Windows
print(
"You can run:\n"
f"[italic green] tensorboard --logdir {logdir}"
"[/italic green]\n"
"...to visualize the TensorBoard logs for your trained model."
)
)
service.start(timeout=60)
if service.endpoint.status.port:
self.visualize_tensorboard(
service.endpoint.status.port, height
else:
# start a new TensorBoard server
service = TensorboardService(
TensorboardServiceConfig(
logdir=logdir,
name=f"zenml-tensorboard-{logdir}",
)
)
return
service.start(timeout=60)
if service.endpoint.status.port:
self.visualize_tensorboard(
service.endpoint.status.port, height
)
return

def visualize_tensorboard(
self,
@@ -154,31 +157,34 @@ def stop(
Args:
object: StepRunResponseModel fetched from get_step().
"""
for _, artifact_view in object.outputs.items():
# filter out anything but model artifacts
if artifact_view.type == ArtifactType.MODEL:
logdir = os.path.dirname(artifact_view.uri)

# first check if a TensorBoard server is already running for
# the same logdir location and use that one
running_server = self.find_running_tensorboard_server(logdir)
if not running_server:
return
for output in object.outputs.values():
for artifact_view in output:
# filter out anything but model artifacts
if artifact_view.type == ArtifactType.MODEL:
logdir = os.path.dirname(artifact_view.uri)

# first check if a TensorBoard server is already running for
# the same logdir location and use that one
running_server = self.find_running_tensorboard_server(
logdir
)
if not running_server:
return

logger.debug(
"Stopping tensorboard server with PID '%d' ...",
running_server.pid,
)
try:
p = psutil.Process(running_server.pid)
except psutil.Error:
logger.error(
"Could not find process for PID '%d' ...",
logger.debug(
"Stopping tensorboard server with PID '%d' ...",
running_server.pid,
)
continue
p.kill()
return
try:
p = psutil.Process(running_server.pid)
except psutil.Error:
logger.error(
"Could not find process for PID '%d' ...",
running_server.pid,
)
continue
p.kill()
return


def get_step(pipeline_name: str, step_name: str) -> "StepRunResponse":
40 changes: 23 additions & 17 deletions src/zenml/lineage_graph/lineage_graph.py
@@ -80,22 +80,25 @@ def generate_step_nodes_and_edges(self, step: "StepRunResponse") -> None:
self.add_step_node(step, step_id)

# Add nodes and edges for all output artifacts
- for artifact_name, artifact_version in step.outputs.items():
- artifact_version_id = ARTIFACT_PREFIX + str(artifact_version.id)
- if step.status == ExecutionStatus.CACHED:
- artifact_status = ArtifactNodeStatus.CACHED
- elif step.status == ExecutionStatus.COMPLETED:
- artifact_status = ArtifactNodeStatus.CREATED
- else:
- artifact_status = ArtifactNodeStatus.UNKNOWN
- self.add_artifact_node(
- artifact=artifact_version,
- id=artifact_version_id,
- name=artifact_name,
- step_id=str(step_id),
- status=artifact_status,
- )
- self.add_edge(step_id, artifact_version_id)
+ for artifact_name, output in step.outputs.items():
+ for artifact_version in output:
+ artifact_version_id = ARTIFACT_PREFIX + str(
+ artifact_version.id
+ )
+ if step.status == ExecutionStatus.CACHED:
+ artifact_status = ArtifactNodeStatus.CACHED
+ elif step.status == ExecutionStatus.COMPLETED:
+ artifact_status = ArtifactNodeStatus.CREATED
+ else:
+ artifact_status = ArtifactNodeStatus.UNKNOWN
+ self.add_artifact_node(
+ artifact=artifact_version,
+ id=artifact_version_id,
+ name=artifact_name,
+ step_id=str(step_id),
+ status=artifact_status,
+ )
+ self.add_edge(step_id, artifact_version_id)

# Add nodes and edges for all input artifacts
for artifact_name, artifact_version in step.inputs.items():
@@ -186,7 +189,10 @@ def add_step_node(
parameters=step.config.parameters,
configuration=step_config,
inputs={k: v.uri for k, v in step.inputs.items()},
- outputs={k: v.uri for k, v in step.outputs.items()},
+ outputs={
+ k: [av.uri for av in v]
+ for k, v in step.outputs.items()
+ },
metadata=[
(m.key, str(m.value), str(m.type))
for m in step.run_metadata.values()
4 changes: 2 additions & 2 deletions src/zenml/lineage_graph/node/step_node.py
@@ -29,8 +29,8 @@ class StepNodeDetails(BaseNodeDetails):
entrypoint_name: str
parameters: Dict[str, Any]
configuration: Dict[str, Any]
- inputs: Dict[str, Any]
- outputs: Dict[str, Any]
+ inputs: Dict[str, str] # (key, uri)
+ outputs: Dict[str, List[str]] # (key, [uris,...])
metadata: List[Tuple[str, str, str]] # (key, value, type)
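
The tightened annotations describe the payload built in `add_step_node`: one URI per input and a list of URIs per output. A sketch with a hypothetical `FakeVersion` class in place of the real artifact version response; the URIs are made up:

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeVersion:
    """Hypothetical stand-in exposing only the `uri` attribute."""

    uri: str


inputs = {"data": FakeVersion("s3://bucket/data/1")}
outputs = {
    "model": [
        FakeVersion("s3://bucket/model/1"),
        FakeVersion("s3://bucket/model/2"),
    ],
}

# Inputs stay one-to-one; outputs collect every version's URI per name.
input_uris: Dict[str, str] = {k: v.uri for k, v in inputs.items()}
output_uris: Dict[str, List[str]] = {
    k: [av.uri for av in v] for k, v in outputs.items()
}
```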


17 changes: 16 additions & 1 deletion src/zenml/models/v2/core/artifact_version.py
@@ -28,7 +28,7 @@

from zenml.config.source import Source, SourceWithValidator
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
- from zenml.enums import ArtifactType, GenericFilterOps
+ from zenml.enums import ArtifactSaveType, ArtifactType, GenericFilterOps
from zenml.logger import get_logger
from zenml.models.v2.base.filter import StrFilter
from zenml.models.v2.base.scoped import (
@@ -95,6 +95,9 @@ class ArtifactVersionRequest(WorkspaceScopedRequest):
visualizations: Optional[List["ArtifactVisualizationRequest"]] = Field(
default=None, title="Visualizations of the artifact."
)
+ save_type: ArtifactSaveType = Field(
+ title="The save type of the artifact version.",
+ )

@field_validator("version")
@classmethod
@@ -156,6 +159,9 @@ class ArtifactVersionResponseBody(WorkspaceScopedResponseBody):
title="The ID of the pipeline run that generated this artifact version.",
default=None,
)
+ save_type: ArtifactSaveType = Field(
+ title="The save type of the artifact version.",
+ )

@field_validator("version")
@classmethod
@@ -276,6 +282,15 @@ def producer_pipeline_run_id(self) -> Optional[UUID]:
"""
return self.get_body().producer_pipeline_run_id

+ @property
+ def save_type(self) -> ArtifactSaveType:
+ """The `save_type` property.
+
+ Returns:
+ the value of the property.
+ """
+ return self.get_body().save_type
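
The new `save_type` property follows the same delegation pattern as its neighbours: the value lives on the response body and the response exposes it read-only. A stripped-down sketch with hypothetical frozen dataclasses; the real models are Pydantic-based, and a plain string stands in for the enum here:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Body:
    """Hypothetical response body holding the stored value."""

    save_type: str  # stand-in for ArtifactSaveType


@dataclass(frozen=True)
class Response:
    """Hypothetical response that forwards properties to its body."""

    body: Body

    def get_body(self) -> Body:
        return self.body

    @property
    def save_type(self) -> str:
        """Forward the value stored on the body."""
        return self.get_body().save_type


response = Response(body=Body(save_type="manual"))
```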

@property
def artifact_store_id(self) -> Optional[UUID]:
"""The `artifact_store_id` property.