Skip to content

Commit

Permalink
Same for pipeline run links
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Nov 6, 2024
1 parent b161354 commit 5b44c9b
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 185 deletions.
18 changes: 0 additions & 18 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6480,9 +6480,6 @@ def list_model_version_artifact_links(
logical_operator: LogicalOperators = LogicalOperators.AND,
created: Optional[Union[datetime, str]] = None,
updated: Optional[Union[datetime, str]] = None,
workspace_id: Optional[Union[UUID, str]] = None,
user_id: Optional[Union[UUID, str]] = None,
model_id: Optional[Union[UUID, str]] = None,
model_version_id: Optional[Union[UUID, str]] = None,
artifact_version_id: Optional[Union[UUID, str]] = None,
artifact_name: Optional[str] = None,
Expand All @@ -6502,9 +6499,6 @@ def list_model_version_artifact_links(
logical_operator: Which logical operator to use [and, or]
created: Use to filter by time of creation
updated: Use the last updated date for filtering
workspace_id: Use the workspace id for filtering
user_id: Use the user id for filtering
model_id: Use the model id for filtering
model_version_id: Use the model version id for filtering
artifact_version_id: Use the artifact id for filtering
artifact_name: Use the artifact name for filtering
Expand All @@ -6527,9 +6521,6 @@ def list_model_version_artifact_links(
size=size,
created=created,
updated=updated,
workspace_id=workspace_id,
user_id=user_id,
model_id=model_id,
model_version_id=model_version_id,
artifact_version_id=artifact_version_id,
artifact_name=artifact_name,
Expand Down Expand Up @@ -6601,9 +6592,6 @@ def list_model_version_pipeline_run_links(
logical_operator: LogicalOperators = LogicalOperators.AND,
created: Optional[Union[datetime, str]] = None,
updated: Optional[Union[datetime, str]] = None,
workspace_id: Optional[Union[UUID, str]] = None,
user_id: Optional[Union[UUID, str]] = None,
model_id: Optional[Union[UUID, str]] = None,
model_version_id: Optional[Union[UUID, str]] = None,
pipeline_run_id: Optional[Union[UUID, str]] = None,
pipeline_run_name: Optional[str] = None,
Expand All @@ -6619,9 +6607,6 @@ def list_model_version_pipeline_run_links(
logical_operator: Which logical operator to use [and, or]
created: Use to filter by time of creation
updated: Use the last updated date for filtering
workspace_id: Use the workspace id for filtering
user_id: Use the user id for filtering
model_id: Use the model id for filtering
model_version_id: Use the model version id for filtering
pipeline_run_id: Use the pipeline run id for filtering
pipeline_run_name: Use the pipeline run name for filtering
Expand All @@ -6640,9 +6625,6 @@ def list_model_version_pipeline_run_links(
size=size,
created=created,
updated=updated,
workspace_id=workspace_id,
user_id=user_id,
model_id=model_id,
model_version_id=model_version_id,
pipeline_run_id=pipeline_run_id,
pipeline_run_name=pipeline_run_name,
Expand Down
6 changes: 0 additions & 6 deletions src/zenml/models/v2/core/model_version_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,11 @@ class ModelVersionArtifactFilter(BaseFilter):
"only_model_artifacts",
"only_deployment_artifacts",
"has_custom_name",
"model_id",
"model_version_id",
"updated",
"id",
]

model_id: Optional[Union[UUID, str]] = Field(
default=None,
description="Filter by model ID",
union_mode="left_to_right",
)
model_version_id: Optional[Union[UUID, str]] = Field(
default=None,
description="Filter by model version ID",
Expand Down
45 changes: 6 additions & 39 deletions src/zenml/models/v2/core/model_version_pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,19 @@
from zenml.models.v2.base.base import (
BaseDatedResponseBody,
BaseIdentifiedResponse,
BaseRequest,
BaseResponseMetadata,
BaseResponseResources,
)
from zenml.models.v2.base.filter import StrFilter
from zenml.models.v2.base.scoped import (
WorkspaceScopedFilter,
WorkspaceScopedRequest,
)
from zenml.models.v2.base.filter import BaseFilter, StrFilter
from zenml.models.v2.core.pipeline_run import PipelineRunResponse

# ------------------ Request Model ------------------


class ModelVersionPipelineRunRequest(WorkspaceScopedRequest):
class ModelVersionPipelineRunRequest(BaseRequest):
"""Request model for links between model versions and pipeline runs."""

model: UUID
model_version: UUID
pipeline_run: UUID

Expand All @@ -62,7 +58,6 @@ class ModelVersionPipelineRunRequest(WorkspaceScopedRequest):
class ModelVersionPipelineRunResponseBody(BaseDatedResponseBody):
"""Response body for links between model versions and pipeline runs."""

model: UUID
model_version: UUID
pipeline_run: PipelineRunResponse

Expand All @@ -88,16 +83,6 @@ class ModelVersionPipelineRunResponse(
):
"""Response model for links between model versions and pipeline runs."""

# Body and metadata properties
@property
def model(self) -> UUID:
"""The `model` property.
Returns:
the value of the property.
"""
return self.get_body().model

@property
def model_version(self) -> UUID:
"""The `model_version` property.
Expand All @@ -120,39 +105,21 @@ def pipeline_run(self) -> "PipelineRunResponse":
# ------------------ Filter Model ------------------


class ModelVersionPipelineRunFilter(WorkspaceScopedFilter):
class ModelVersionPipelineRunFilter(BaseFilter):
"""Model version pipeline run links filter model."""

FILTER_EXCLUDE_FIELDS = [
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
*BaseFilter.FILTER_EXCLUDE_FIELDS,
"pipeline_run_name",
"user",
]
CLI_EXCLUDE_FIELDS = [
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
"model_id",
*BaseFilter.CLI_EXCLUDE_FIELDS,
"model_version_id",
"user_id",
"workspace_id",
"updated",
"id",
]

workspace_id: Optional[Union[UUID, str]] = Field(
default=None,
description="The workspace of the Model Version",
union_mode="left_to_right",
)
user_id: Optional[Union[UUID, str]] = Field(
default=None,
description="The user of the Model Version",
union_mode="left_to_right",
)
model_id: Optional[Union[UUID, str]] = Field(
default=None,
description="Filter by model ID",
union_mode="left_to_right",
)
model_version_id: Optional[Union[UUID, str]] = Field(
default=None,
description="Filter by model version ID",
Expand Down
3 changes: 0 additions & 3 deletions src/zenml/orchestrators/step_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,10 +540,7 @@ def link_pipeline_run_to_model_version(
client = Client()
client.zen_store.create_model_version_pipeline_run_link(
ModelVersionPipelineRunRequest(
user=client.active_user.id,
workspace=client.active_workspace.id,
pipeline_run=pipeline_run.id,
model=model_version.model.id,
model_version=model_version.id,
)
)
Expand Down
30 changes: 30 additions & 0 deletions src/zenml/zen_server/routers/model_versions_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ModelVersionArtifactResponse,
ModelVersionFilter,
ModelVersionPipelineRunFilter,
ModelVersionPipelineRunRequest,
ModelVersionPipelineRunResponse,
ModelVersionResponse,
ModelVersionUpdate,
Expand Down Expand Up @@ -320,6 +321,35 @@ def delete_all_model_version_artifact_links(
)


@router.post(
"",
responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version_pipeline_run_link(
model_version_pipeline_run_link: ModelVersionPipelineRunRequest,
_: AuthContext = Security(authorize),
) -> ModelVersionPipelineRunResponse:
"""Create a new model version to pipeline run link.
Args:
model_version_pipeline_run_link: The model version to pipeline run link to create.
Returns:
- If Model Version to Pipeline Run Link already exists - returns the existing link.
- Otherwise, returns the newly created model version to pipeline run link.
"""
model_version = zen_store().get_model_version(
model_version_pipeline_run_link.model_version, hydrate=False
)
verify_permission_for_model(model_version, action=Action.UPDATE)

mv = zen_store().create_model_version_pipeline_run_link(
model_version_pipeline_run_link
)
return mv


@model_version_pipeline_runs_router.get(
"",
response_model=Page[ModelVersionPipelineRunResponse],
Expand Down
66 changes: 0 additions & 66 deletions src/zenml/zen_server/routers/workspaces_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@
ComponentResponse,
ModelRequest,
ModelResponse,
ModelVersionPipelineRunRequest,
ModelVersionPipelineRunResponse,
ModelVersionRequest,
ModelVersionResponse,
Page,
Expand Down Expand Up @@ -1440,70 +1438,6 @@ def create_model_version(
)


@router.post(
WORKSPACES
+ "/{workspace_name_or_id}"
+ MODEL_VERSIONS
+ "/{model_version_id}"
+ RUNS,
response_model=ModelVersionPipelineRunResponse,
responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version_pipeline_run_link(
workspace_name_or_id: Union[str, UUID],
model_version_id: UUID,
model_version_pipeline_run_link: ModelVersionPipelineRunRequest,
auth_context: AuthContext = Security(authorize),
) -> ModelVersionPipelineRunResponse:
"""Create a new model version to pipeline run link.
Args:
workspace_name_or_id: Name or ID of the workspace.
model_version_id: ID of the model version.
model_version_pipeline_run_link: The model version to pipeline run link to create.
auth_context: Authentication context.
Returns:
- If Model Version to Pipeline Run Link already exists - returns the existing link.
- Otherwise, returns the newly created model version to pipeline run link.
Raises:
IllegalOperationError: If the workspace or user specified in the
model version does not match the current workspace or authenticated
user.
"""
workspace = zen_store().get_workspace(workspace_name_or_id)
if str(model_version_id) != str(
model_version_pipeline_run_link.model_version
):
raise IllegalOperationError(
f"The model version id in your path `{model_version_id}` does not "
f"match the model version specified in the request model "
f"`{model_version_pipeline_run_link.model_version}`"
)

if model_version_pipeline_run_link.workspace != workspace.id:
raise IllegalOperationError(
"Creating model versions outside of the workspace scope "
f"of this endpoint `{workspace_name_or_id}` is "
f"not supported."
)
if model_version_pipeline_run_link.user != auth_context.user.id:
raise IllegalOperationError(
"Creating models for a user other than yourself "
"is not supported."
)

model_version = zen_store().get_model_version(model_version_id)
verify_permission_for_model(model_version, action=Action.UPDATE)

mv = zen_store().create_model_version_pipeline_run_link(
model_version_pipeline_run_link
)
return mv


@router.post(
WORKSPACES + "/{workspace_name_or_id}" + SERVICES,
response_model=ServiceResponse,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Simplify model version artifacts2 [5591a5051f51].
Revision ID: 5591a5051f51
Revises: ec6307720f92
Create Date: 2024-11-06 16:51:01.346080
"""
import sqlmodel
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '5591a5051f51'
down_revision = 'ec6307720f92'
branch_labels = None
depends_on = None


def upgrade() -> None:
"""Upgrade database schema and/or data, creating a new revision."""
# ### commands auto generated by Alembic - please adjust! ###


# ### end Alembic commands ###


def downgrade() -> None:
"""Downgrade database schema and/or data back to the previous revision."""
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('model_versions_runs', schema=None) as batch_op:
batch_op.add_column(sa.Column('user_id', sa.CHAR(length=32), nullable=True))
batch_op.add_column(sa.Column('workspace_id', sa.CHAR(length=32), nullable=False))
batch_op.add_column(sa.Column('model_id', sa.CHAR(length=32), nullable=False))
batch_op.create_foreign_key('fk_model_versions_runs_user_id_user', 'user', ['user_id'], ['id'], ondelete='SET NULL')
batch_op.create_foreign_key('fk_model_versions_runs_workspace_id_workspace', 'workspace', ['workspace_id'], ['id'], ondelete='CASCADE')
batch_op.create_foreign_key('fk_model_versions_runs_model_id_model', 'model', ['model_id'], ['id'], ondelete='CASCADE')

# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Simplify model version artifacts [ec6307720f92].
"""Simplify model version links [ec6307720f92].
Revision ID: ec6307720f92
Revises: c22561cbb3a9
Expand Down Expand Up @@ -86,6 +86,14 @@ def upgrade() -> None:
batch_op.drop_column("workspace_id")
batch_op.drop_column("is_model_artifact")

with op.batch_alter_table('model_versions_runs', schema=None) as batch_op:
batch_op.drop_constraint('fk_model_versions_runs_model_id_model', type_='foreignkey')
batch_op.drop_constraint('fk_model_versions_runs_workspace_id_workspace', type_='foreignkey')
batch_op.drop_constraint('fk_model_versions_runs_user_id_user', type_='foreignkey')
batch_op.drop_column('model_id')
batch_op.drop_column('workspace_id')
batch_op.drop_column('user_id')

# ### end Alembic commands ###


Expand Down
4 changes: 2 additions & 2 deletions src/zenml/zen_stores/rest_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3802,10 +3802,10 @@ def create_model_version_pipeline_run_link(
- Otherwise, returns the newly created model version to pipeline
run link.
"""
return self._create_workspace_scoped_resource(
return self._create_resource(
resource=model_version_pipeline_run_link,
response_model=ModelVersionPipelineRunResponse,
route=f"{MODEL_VERSIONS}/{model_version_pipeline_run_link.model_version}{RUNS}",
route=MODEL_VERSION_PIPELINE_RUNS,
)

def list_model_version_pipeline_run_links(
Expand Down
Loading

0 comments on commit 5b44c9b

Please sign in to comment.