Skip to content

Commit

Permalink
adding a parameter to control the related entity behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
bcdurak committed Nov 13, 2024
1 parent dca5913 commit b767269
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 18 deletions.
57 changes: 40 additions & 17 deletions src/zenml/utils/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@


@overload
def log_metadata(metadata: Dict[str, MetadataType]) -> None: ...
def log_metadata(
*,
metadata: Dict[str, MetadataType],
log_related_entities: Optional[bool] = True,
) -> None: ...


@overload
def log_metadata(
*,
metadata: Dict[str, MetadataType],
artifact_version_id: UUID,
log_related_entities: Optional[bool] = True,
) -> None: ...


Expand All @@ -44,6 +49,7 @@ def log_metadata(
metadata: Dict[str, MetadataType],
artifact_name: str,
artifact_version: Optional[str] = None,
log_related_entities: Optional[bool] = True,
) -> None: ...


Expand All @@ -52,6 +58,7 @@ def log_metadata(
*,
metadata: Dict[str, MetadataType],
model_version_id: UUID,
log_related_entities: Optional[bool] = True,
) -> None: ...


Expand All @@ -61,6 +68,7 @@ def log_metadata(
metadata: Dict[str, MetadataType],
model_name: str,
model_version: str,
log_related_entities: Optional[bool] = True,
) -> None: ...


Expand All @@ -69,6 +77,7 @@ def log_metadata(
*,
metadata: Dict[str, MetadataType],
step_id: UUID,
log_related_entities: Optional[bool] = True,
) -> None: ...


Expand All @@ -77,6 +86,7 @@ def log_metadata(
*,
metadata: Dict[str, MetadataType],
run_id_name_or_prefix: Union[UUID, str],
log_related_entities: Optional[bool] = True,
) -> None: ...


Expand All @@ -86,6 +96,7 @@ def log_metadata(
metadata: Dict[str, MetadataType],
step_name: str,
run_id_name_or_prefix: Union[UUID, str],
log_related_entities: Optional[bool] = True,
) -> None: ...


Expand All @@ -103,6 +114,8 @@ def log_metadata(
model_version_id: Optional[UUID] = None,
model_name: Optional[str] = None,
model_version: Optional[str] = None,
# Parameter to adjust whether we log to all related entities
log_related_entities: Optional[bool] = True,
) -> None:
"""Logs metadata for various resource types in a generalized way.
Expand All @@ -117,6 +130,8 @@ def log_metadata(
model_version_id: The ID of the model version.
model_name: The name of the model.
model_version: The version of the model
log_related_entities: Flag to decide whether we should log the same
metadata for related entities.
Raises:
ValueError: If no identifiers are provided and the function is not
Expand All @@ -130,29 +145,37 @@ def log_metadata(
run = client.get_pipeline_run(run_id_name_or_prefix)
step = run.steps[step_name]

resources = [
(run.id, MetadataResourceTypes.PIPELINE_RUN),
(step.id, MetadataResourceTypes.STEP_RUN),
]
if step.model_version:
resources.append(
(step.model_version.id, MetadataResourceTypes.MODEL_VERSION)
)
resources = [(step.id, MetadataResourceTypes.STEP_RUN)]

if log_related_entities:
resources.append((run.id, MetadataResourceTypes.PIPELINE_RUN))
if step.model_version:
resources.append(
(
step.model_version.id,
MetadataResourceTypes.MODEL_VERSION,
)
)
client.create_run_metadata(metadata=metadata, resources=resources)
# If a step is identified by id, fetch it directly through the client,
# follow a similar procedure and log metadata for its pipeline and model
# as well.
elif step_id is not None:
step = client.get_run_step(step_id)
resources = [(step_id, MetadataResourceTypes.STEP_RUN)]

resources = [
(step.pipeline_run_id, MetadataResourceTypes.PIPELINE_RUN),
(step.id, MetadataResourceTypes.STEP_RUN),
]
if step.model_version:
if log_related_entities:
step = client.get_run_step(step_id)
resources.append(
(step.model_version.id, MetadataResourceTypes.MODEL_VERSION)
(step.pipeline_run_id, MetadataResourceTypes.PIPELINE_RUN)
)

if step.model_version:
resources.append(
(
step.model_version.id,
MetadataResourceTypes.MODEL_VERSION,
)
)
client.create_run_metadata(metadata=metadata, resources=resources)

# If a pipeline run id is identified, we need to log metadata to it and its
Expand All @@ -162,7 +185,7 @@ def log_metadata(

resources = [(run.id, MetadataResourceTypes.PIPELINE_RUN)]

if run.model_version:
if log_related_entities and run.model_version is not None:
resources.append(
(run.model_version.id, MetadataResourceTypes.MODEL_VERSION)
)
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/zen_stores/schemas/run_metadata_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class RunMetadataResourceSchema(SQLModel, table=True):
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
overlaps="run_metadata,step_run,artifact_version,model_version",
)
),
)
step_run: List["StepRunSchema"] = Relationship(
back_populates="run_metadata",
Expand Down

0 comments on commit b767269

Please sign in to comment.