Skip to content

Commit

Permalink
consistent creation\
Browse files Browse the repository at this point in the history
  • Loading branch information
bcdurak committed Nov 8, 2024
1 parent ec7dc02 commit 1fafb7e
Showing 1 changed file with 46 additions and 22 deletions.
68 changes: 46 additions & 22 deletions src/zenml/utils/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,6 @@ def log_metadata(
# Initialize the client
client = Client()

# Initialize a batch of request to avoid duplications
metadata_batch: Dict[MetadataResourceTypes, Set[UUID]] = {
MetadataResourceTypes.PIPELINE_RUN: set(),
MetadataResourceTypes.STEP_RUN: set(),
MetadataResourceTypes.ARTIFACT_VERSION: set(),
MetadataResourceTypes.MODEL_VERSION: set(),
}

# If a step name is provided, we need a run_id_name_or_prefix and will log
# metadata for the steps pipeline and model accordingly.
if step_name is not None and run_id_name_or_prefix is not None:
Expand All @@ -140,11 +132,21 @@ def log_metadata(
)
step_model = run_model.steps[step_name]

metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add(run_model.id)
metadata_batch[MetadataResourceTypes.STEP_RUN].add(step_model.id)
client.create_run_metadata(
metadata=metadata,
resource_id=run_model.id,
resource_type=MetadataResourceTypes.PIPELINE_RUN,
)
client.create_run_metadata(
metadata=metadata,
resource_id=step_model.id,
resource_type=MetadataResourceTypes.STEP_RUN,
)
if step_model.model_version:
metadata_batch[MetadataResourceTypes.MODEL_VERSION].add(
step_model.model_version.id
client.create_run_metadata(
metadata=metadata,
resource_id=step_model.model_version.id,
resource_type=MetadataResourceTypes.MODEL_VERSION,
)

# If a step is identified by id, fetch it directly through the client,
Expand All @@ -155,11 +157,21 @@ def log_metadata(
run_model = client.get_pipeline_run(
name_id_or_prefix=step_model.pipeline_run_id
)
metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add(run_model.id)
metadata_batch[MetadataResourceTypes.STEP_RUN].add(step_model.id)
client.create_run_metadata(
metadata=metadata,
resource_id=run_model.id,
resource_type=MetadataResourceTypes.PIPELINE_RUN,
)
client.create_run_metadata(
metadata=metadata,
resource_id=step_model.id,
resource_type=MetadataResourceTypes.STEP_RUN,
)
if step_model.model_version:
metadata_batch[MetadataResourceTypes.MODEL_VERSION].add(
step_model.model_version.id
client.create_run_metadata(
metadata=metadata,
resource_id=step_model.model_version.id,
resource_type=MetadataResourceTypes.MODEL_VERSION,
)

# If a pipeline run id is identified, we need to log metadata to it and its
Expand All @@ -168,28 +180,40 @@ def log_metadata(
run_model = client.get_pipeline_run(
name_id_or_prefix=run_id_name_or_prefix
)
client.create_run_metadata(
metadata=metadata,
resource_id=run_model.id,
resource_type=MetadataResourceTypes.PIPELINE_RUN,
)
if run_model.model_version:
metadata_batch[MetadataResourceTypes.MODEL_VERSION].add(
run_model.model_version.id
client.create_run_metadata(
metadata=metadata,
resource_id=run_model.model_version.id,
resource_type=MetadataResourceTypes.MODEL_VERSION,
)
metadata_batch[MetadataResourceTypes.PIPELINE_RUN].add(run_model.id)

# If the user provides a model name and version, we use to model abstraction
# to fetch the model version and attach the corresponding metadata to it.
elif model_name is not None and model_version is not None:
from zenml import Model

mv = Model(name=model_name, version=model_version)
metadata_batch[MetadataResourceTypes.MODEL_VERSION].add(mv.id)
client.create_run_metadata(
metadata=metadata,
resource_id=mv.id,
resource_type=MetadataResourceTypes.MODEL_VERSION,
)

# If the user provides a model version id, we use the client to fetch it and
# attach the metadata to it.
elif model_version_id is not None:
model_version_id = client.get_model_version(
model_version_name_or_number_or_id=model_version_id
).id
metadata_batch[MetadataResourceTypes.MODEL_VERSION].add(
model_version_id
client.create_run_metadata(
metadata=metadata,
resource_id=model_version_id,
resource_type=MetadataResourceTypes.MODEL_VERSION,
)

# If the user provides an artifact name, there are two possibilities. If
Expand Down

0 comments on commit 1fafb7e

Please sign in to comment.