Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New log_metadata function, new oneof filtering, additional run_metadata filtering #3182

Merged
merged 72 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 71 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
697a93f
Initial commit, nuking all metadata responses and seeing what breaks
AlexejPenner Oct 17, 2024
733a6c8
Removed last remnant of LazyLoader
AlexejPenner Oct 17, 2024
3ce54ef
Merge branch 'develop' into feature/better-metadata
AlexejPenner Oct 18, 2024
ae71757
Reintroducing the lazy loaders.
AlexejPenner Oct 18, 2024
01b5179
Merge branch 'feature/better-metadata' of github.com:zenml-io/zenml i…
AlexejPenner Oct 18, 2024
7d0ff82
Add LazyRunMetadataResponse to EntrypointFunctionDefinition
avishniakov Oct 18, 2024
d7a9f83
Test for lazy loaders works now
AlexejPenner Oct 18, 2024
bccb4d2
Merge branch 'develop' into feature/better-metadata
AlexejPenner Oct 18, 2024
9a0e0b2
Fixed tests, reformatted
AlexejPenner Oct 21, 2024
145b90b
Use updated template
AlexejPenner Oct 21, 2024
1e1991a
Auto-update of Starter template
actions-user Oct 21, 2024
adab934
Merge branch 'develop' into feature/better-metadata
AlexejPenner Oct 21, 2024
d83628a
Updated more templates
AlexejPenner Oct 21, 2024
6d13071
Merge branch 'feature/better-metadata' of github.com:zenml-io/zenml i…
AlexejPenner Oct 21, 2024
c4febf3
Fixed failing test
AlexejPenner Oct 21, 2024
5aef8ab
Fixed step run schemas
AlexejPenner Oct 21, 2024
0b66f07
Auto-update of E2E template
actions-user Oct 21, 2024
4b2434a
Auto-update of NLP template
actions-user Oct 21, 2024
8f4af6e
Fixed tests, removed additional .value access
AlexejPenner Oct 21, 2024
cc6902b
Merge branch 'feature/better-metadata' of github.com:zenml-io/zenml i…
AlexejPenner Oct 21, 2024
edba625
Further fixing
AlexejPenner Oct 21, 2024
7d5cfb7
Merge branch 'develop' into feature/better-metadata
AlexejPenner Oct 21, 2024
c2b6955
Fixed linting issues
AlexejPenner Oct 21, 2024
e2bd53a
Merge branch 'develop' into feature/better-metadata
AlexejPenner Oct 21, 2024
a582836
Merge branch 'feature/better-metadata' of github.com:zenml-io/zenml i…
AlexejPenner Oct 22, 2024
4f82ade
Merge branch 'develop' into feature/better-metadata
AlexejPenner Oct 22, 2024
58293bb
Merge branch 'feature/better-metadata' of github.com:zenml-io/zenml i…
AlexejPenner Oct 22, 2024
8f6d305
Reformatted
AlexejPenner Oct 22, 2024
6b18322
Linted, formatted and tested again
AlexejPenner Oct 22, 2024
8b3a1bd
Typing
AlexejPenner Oct 22, 2024
b34f18b
Merge branch 'develop' into feature/better-metadata
AlexejPenner Oct 28, 2024
5cc7b44
Maybe fix everything
schustmi Oct 28, 2024
c368dec
Apply some feedback
schustmi Oct 28, 2024
62e8d6e
merged develop
bcdurak Oct 29, 2024
050f5b5
resolved conflicts
bcdurak Nov 4, 2024
74c1a42
new operation
bcdurak Nov 5, 2024
53dc8e8
new log_metadata function
bcdurak Nov 5, 2024
68a455c
changes to the base filters
bcdurak Nov 5, 2024
4af4165
new filters
bcdurak Nov 6, 2024
fdf8945
adding log_metadata to __all__
bcdurak Nov 6, 2024
39f5bf8
checkpoint with float casting
bcdurak Nov 6, 2024
1c051ec
adding tests
bcdurak Nov 6, 2024
e284808
final touches and formatting
bcdurak Nov 6, 2024
d5bbf72
formatting
bcdurak Nov 6, 2024
3a0d4c8
moved the utils
bcdurak Nov 6, 2024
5b3b217
modified log metadata function
bcdurak Nov 6, 2024
3d5a9f0
checkpoint
bcdurak Nov 6, 2024
e3079a3
deprecating the old functions
bcdurak Nov 6, 2024
c3e69c2
linting and final fixes
bcdurak Nov 6, 2024
2d4c723
better error message
bcdurak Nov 6, 2024
206340c
merged develop
bcdurak Nov 7, 2024
2debd9e
merged develop
bcdurak Nov 8, 2024
7e20409
merged develop
bcdurak Nov 8, 2024
fbd0200
fixing the client method
bcdurak Nov 8, 2024
ec7dc02
better error message
bcdurak Nov 8, 2024
1fafb7e
consistent creation\
bcdurak Nov 8, 2024
ad4a4f7
merged develop
bcdurak Nov 8, 2024
d90f55d
adjusting tests
bcdurak Nov 8, 2024
e0db418
linting
bcdurak Nov 8, 2024
14dfdea
changes for step metadata
bcdurak Nov 8, 2024
d89358d
more test adjustments
bcdurak Nov 8, 2024
7d90305
testing unit tests
bcdurak Nov 8, 2024
b060987
linting
bcdurak Nov 8, 2024
43a7034
fixing more tests
bcdurak Nov 8, 2024
28ecdc1
fixing more tests
bcdurak Nov 8, 2024
e0c5e4f
more test fixes
bcdurak Nov 8, 2024
6edc16e
fixing the test
bcdurak Nov 11, 2024
030d530
fixing per comments
bcdurak Nov 11, 2024
929fba4
added validation, constant error message
bcdurak Nov 11, 2024
be79553
merged develop
bcdurak Nov 11, 2024
e07b777
Merge branch 'develop' into feature/best-metadata
bcdurak Nov 11, 2024
c1bcb00
linting
bcdurak Nov 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/zenml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from zenml.pipelines import get_pipeline_context, pipeline
from zenml.steps import step, get_step_context
from zenml.steps.utils import log_step_metadata
from zenml.utils.metadata_utils import log_metadata
from zenml.entrypoints import entrypoint

__all__ = [
Expand All @@ -56,6 +57,7 @@
"get_pipeline_context",
"get_step_context",
"load_artifact",
"log_metadata",
"log_artifact_metadata",
"log_model_metadata",
"log_step_metadata",
Expand Down
6 changes: 5 additions & 1 deletion src/zenml/artifacts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def log_artifact_metadata(
not provided, when being called inside a step that produces an
artifact named `artifact_name`, the metadata will be associated to
the corresponding newly created artifact. Or, if not provided when
being called outside of a step, or in a step that does not produce
being called outside a step, or in a step that does not produce
any artifact named `artifact_name`, the metadata will be associated
to the latest version of that artifact.

Expand All @@ -417,6 +417,10 @@ def log_artifact_metadata(
called inside a step with a single output, or, if neither an
artifact nor an output with the given name exists.
"""
logger.warning(
"The `log_artifact_metadata` function is deprecated and will soon be "
"removed. Please use `log_metadata` instead."
)
try:
step_context = get_step_context()
in_step_outputs = (artifact_name in step_context._outputs) or (
bcdurak marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
7 changes: 6 additions & 1 deletion src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3796,6 +3796,7 @@ def list_pipeline_runs(
templatable: Optional[bool] = None,
tag: Optional[str] = None,
user: Optional[Union[UUID, str]] = None,
run_metadata: Optional[Dict[str, str]] = None,
pipeline: Optional[Union[UUID, str]] = None,
code_repository: Optional[Union[UUID, str]] = None,
model: Optional[Union[UUID, str]] = None,
Expand Down Expand Up @@ -3835,6 +3836,7 @@ def list_pipeline_runs(
templatable: If the runs should be templatable or not.
tag: Tag to filter by.
user: The name/ID of the user to filter by.
run_metadata: The run_metadata of the run to filter by.
pipeline: The name/ID of the pipeline to filter by.
code_repository: Filter by code repository name/ID.
model: Filter by model name/ID.
Expand Down Expand Up @@ -3874,6 +3876,7 @@ def list_pipeline_runs(
tag=tag,
unlisted=unlisted,
user=user,
run_metadata=run_metadata,
pipeline=pipeline,
code_repository=code_repository,
stack=stack,
Expand Down Expand Up @@ -4194,7 +4197,7 @@ def get_artifact_version(
),
)
except RuntimeError:
pass # Cannot link to step run if called outside of a step
pass # Cannot link to step run if called outside a step
return artifact

def list_artifact_versions(
Expand Down Expand Up @@ -4222,6 +4225,7 @@ def list_artifact_versions(
user: Optional[Union[UUID, str]] = None,
model: Optional[Union[UUID, str]] = None,
pipeline_run: Optional[Union[UUID, str]] = None,
run_metadata: Optional[Dict[str, str]] = None,
tag: Optional[str] = None,
hydrate: bool = False,
) -> Page[ArtifactVersionResponse]:
Expand Down Expand Up @@ -4253,6 +4257,7 @@ def list_artifact_versions(
user: Filter by user name or ID.
model: Filter by model name or ID.
pipeline_run: Filter by pipeline run name or ID.
run_metadata: Filter by run metadata.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.

Expand Down
1 change: 1 addition & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class GenericFilterOps(StrEnum):
CONTAINS = "contains"
STARTSWITH = "startswith"
ENDSWITH = "endswith"
ONEOF = "oneof"
GTE = "gte"
GT = "gt"
LTE = "lte"
Expand Down
5 changes: 5 additions & 0 deletions src/zenml/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def log_model_metadata(
ValueError: If no model name/version is provided and the function is not
called inside a step with configured `model` in decorator.
"""
logger.warning(
"The `log_model_metadata` function is deprecated and will soon be "
"removed. Please use `log_metadata` instead."
)

if model_name and model_version:
from zenml import Model

bcdurak marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
121 changes: 113 additions & 8 deletions src/zenml/models/v2/base/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Base filter model definitions."""

import json
from abc import ABC, abstractmethod
from datetime import datetime
from typing import (
Expand All @@ -36,7 +37,7 @@
field_validator,
model_validator,
)
from sqlalchemy import asc, desc
from sqlalchemy import Float, and_, asc, cast, desc
from sqlmodel import SQLModel

from zenml.constants import (
Expand All @@ -63,6 +64,11 @@

AnyQuery = TypeVar("AnyQuery", bound=Any)

ONEOF_ERROR = (
"When you are using the 'oneof:' filtering make sure that the "
"provided value is a json formatted list."
)


class Filter(BaseModel, ABC):
"""Filter for all fields.
Expand Down Expand Up @@ -171,8 +177,20 @@ class StrFilter(Filter):
GenericFilterOps.STARTSWITH,
GenericFilterOps.CONTAINS,
GenericFilterOps.ENDSWITH,
GenericFilterOps.ONEOF,
GenericFilterOps.GT,
GenericFilterOps.GTE,
GenericFilterOps.LT,
GenericFilterOps.LTE,
]

@model_validator(mode="after")
def check_value_if_operation_oneof(self) -> "StrFilter":
if self.operation == GenericFilterOps.ONEOF:
if not isinstance(self.value, list):
raise ValueError(ONEOF_ERROR)
return self

def generate_query_conditions_from_column(self, column: Any) -> Any:
"""Generate query conditions for a string column.

Expand All @@ -181,6 +199,9 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:

Returns:
A list of query conditions.

Raises:
ValueError: the comparison of the column to a numeric value fails.
"""
if self.operation == GenericFilterOps.CONTAINS:
return column.like(f"%{self.value}%")
Expand All @@ -190,6 +211,40 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
return column.endswith(f"{self.value}")
if self.operation == GenericFilterOps.NOT_EQUALS:
return column != self.value
if self.operation == GenericFilterOps.ONEOF:
bcdurak marked this conversation as resolved.
Show resolved Hide resolved
return column.in_(self.value)
if self.operation in {
GenericFilterOps.GT,
GenericFilterOps.LT,
GenericFilterOps.GTE,
GenericFilterOps.LTE,
}:
try:
numeric_column = cast(column, Float)

assert self.value is not None

if self.operation == GenericFilterOps.GT:
return and_(
numeric_column, numeric_column > float(self.value)
)
if self.operation == GenericFilterOps.LT:
return and_(
numeric_column, numeric_column < float(self.value)
)
if self.operation == GenericFilterOps.GTE:
return and_(
numeric_column, numeric_column >= float(self.value)
)
if self.operation == GenericFilterOps.LTE:
return and_(
numeric_column, numeric_column <= float(self.value)
)
except Exception as e:
raise ValueError(
f"Failed to compare the column '{column}' to the "
f"value '{self.value}' (must be numeric): {e}"
)

return column == self.value

Expand All @@ -211,6 +266,9 @@ def _remove_hyphens_from_value(cls, value: Any) -> Any:
if isinstance(value, str):
return value.replace("-", "")

if isinstance(value, list):
return [str(v).replace("-", "") for v in value]
bcdurak marked this conversation as resolved.
Show resolved Hide resolved

return value

def generate_query_conditions_from_column(self, column: Any) -> Any:
Expand Down Expand Up @@ -588,6 +646,10 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]:

Returns:
A tuple of the filter value and the operator.

Raises:
ValueError: when we try to use the `oneof` operator with the wrong
value.
"""
operator = GenericFilterOps.EQUALS # Default operator
if isinstance(value, str):
Expand All @@ -598,6 +660,15 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]:
):
value = split_value[1]
operator = GenericFilterOps(split_value[0])

if operator == operator.ONEOF:
try:
value = json.loads(value)
bcdurak marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(value, list):
raise ValueError
except ValueError:
raise ValueError(ONEOF_ERROR)

return value, operator

def generate_name_or_id_query_conditions(
Expand Down Expand Up @@ -648,8 +719,8 @@ def generate_name_or_id_query_conditions(

return or_(*conditions)

@staticmethod
def generate_custom_query_conditions_for_column(
self,
value: Any,
table: Type[SQLModel],
column: str,
Expand Down Expand Up @@ -833,16 +904,17 @@ def define_filter(

# Create str filters
if self.is_str_field(column):
return StrFilter(
operation=GenericFilterOps(operator),
return self._define_str_filter(
operator=GenericFilterOps(operator),
column=column,
value=value,
)

# Handle unsupported datatypes
logger.warning(
f"The Datatype {self._model_class.model_fields[column].annotation} might "
"not be supported for filtering. Defaulting to a string filter."
f"The Datatype {self._model_class.model_fields[column].annotation} "
"might not be supported for filtering. Defaulting to a string "
"filter."
)
return StrFilter(
operation=GenericFilterOps(operator),
Expand Down Expand Up @@ -1032,8 +1104,9 @@ def _define_uuid_filter(
"Invalid value passed as UUID query parameter."
) from e

# Cast the value to string for further comparisons.
value = str(value)
# For equality checks, ensure that the value is a valid UUID.
if operator == GenericFilterOps.ONEOF and not isinstance(value, list):
raise ValueError(ONEOF_ERROR)

# Generate the filter.
uuid_filter = UUIDFilter(
Expand All @@ -1043,6 +1116,38 @@ def _define_uuid_filter(
)
return uuid_filter

@staticmethod
def _define_str_filter(
column: str, value: Any, operator: GenericFilterOps
) -> StrFilter:
"""Define a str filter for a given column.

Args:
column: The column to filter on.
value: The UUID value by which to filter.
operator: The operator to use for filtering.

Returns:
A Filter object.

Raises:
ValueError: If the value is not a proper value.
"""
# For equality checks, ensure that the value is a valid UUID.
if operator == GenericFilterOps.ONEOF and not isinstance(value, list):
raise ValueError(
"If you are using `oneof:` as a filtering op, the value needs "
"to be a json formatted list string."
bcdurak marked this conversation as resolved.
Show resolved Hide resolved
)

# Generate the filter.
str_filter = StrFilter(
operation=GenericFilterOps(operator),
column=column,
value=value,
)
return str_filter

@staticmethod
def _define_bool_filter(
column: str, value: Any, operator: GenericFilterOps
Expand Down
23 changes: 23 additions & 0 deletions src/zenml/models/v2/core/artifact_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
"user",
"model",
"pipeline_run",
"run_metadata",
]
artifact_id: Optional[Union[UUID, str]] = Field(
default=None,
Expand Down Expand Up @@ -545,6 +546,10 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
description="Name/ID of a pipeline run that is associated with this "
"artifact version.",
)
run_metadata: Optional[Dict[str, str]] = Field(
default=None,
description="The run_metadata to filter the artifact versions by.",
)

model_config = ConfigDict(protected_namespaces=())

Expand All @@ -564,6 +569,7 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]:
ModelSchema,
ModelVersionArtifactSchema,
PipelineRunSchema,
RunMetadataSchema,
StepRunInputArtifactSchema,
StepRunOutputArtifactSchema,
StepRunSchema,
Expand Down Expand Up @@ -645,6 +651,23 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]:
)
custom_filters.append(pipeline_run_filter)

if self.run_metadata is not None:
from zenml.enums import MetadataResourceTypes

for key, value in self.run_metadata.items():
additional_filter = and_(
RunMetadataSchema.resource_id == ArtifactVersionSchema.id,
RunMetadataSchema.resource_type
== MetadataResourceTypes.ARTIFACT_VERSION,
RunMetadataSchema.key == key,
self.generate_custom_query_conditions_for_column(
value=value,
table=RunMetadataSchema,
column="value",
),
)
custom_filters.append(additional_filter)

return custom_filters


Expand Down
Loading
Loading