Skip to content

Commit

Permalink
Fix MySQL warning when filtering pipelines by latest run (#3051)
Browse files Browse the repository at this point in the history
* Remove MySQL warning when filtering pipelines by latest run

* Linting
  • Loading branch information
schustmi authored Oct 8, 2024
1 parent 8ea6053 commit 47a07bd
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,6 @@ def filter_and_paginate(
ValueError: if the filtered page number is out of bounds.
RuntimeError: if the schema does not have a `to_model` method.
"""
query = query.distinct()
query = filter_model.apply_filter(query=query, table=table)
query = query.distinct()

Expand Down Expand Up @@ -4108,7 +4107,8 @@ def list_pipelines(
Returns:
A list of all pipelines matching the filter criteria.
"""
query = select(PipelineSchema)
query: Union[Select[Any], SelectOfScalar[Any]] = select(PipelineSchema)
_custom_conversion: Optional[Callable[[Any], PipelineResponse]] = None

column, operand = pipeline_filter_model.sorting_params
if column == SORT_PIPELINES_BY_LATEST_RUN_KEY:
Expand Down Expand Up @@ -4141,21 +4141,35 @@ def list_pipelines(
sort_clause = asc

query = (
query.where(PipelineSchema.id == max_date_subquery.c.id)
# We need to include the subquery in the select here to
# make this query work with the distinct statement. This
# result will be removed in the custom conversion function
# applied later
select(PipelineSchema, max_date_subquery.c.run_or_created)
.where(PipelineSchema.id == max_date_subquery.c.id)
.order_by(sort_clause(max_date_subquery.c.run_or_created))
# We always add the `id` column as a tiebreaker to ensure a
# stable, repeatable order of items, otherwise subsequent
# pages might contain the same items.
.order_by(col(PipelineSchema.id))
)

def _custom_conversion(row: Any) -> PipelineResponse:
return cast(
PipelineResponse,
row[0].to_model(
include_metadata=hydrate, include_resources=True
),
)

with Session(self.engine) as session:
return self.filter_and_paginate(
session=session,
query=query,
table=PipelineSchema,
filter_model=pipeline_filter_model,
hydrate=hydrate,
custom_schema_to_model_conversion=_custom_conversion,
)

def count_pipelines(self, filter_model: Optional[PipelineFilter]) -> int:
Expand Down

0 comments on commit 47a07bd

Please sign in to comment.