From 5b44c9b57b49daeeb52ac784494bc557f1ec9ec5 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 6 Nov 2024 16:52:12 +0100 Subject: [PATCH] Same for pipeline run links --- src/zenml/client.py | 18 ----- .../models/v2/core/model_version_artifact.py | 6 -- .../v2/core/model_version_pipeline_run.py | 45 ++----------- src/zenml/orchestrators/step_run_utils.py | 3 - .../routers/model_versions_endpoints.py | 30 +++++++++ .../routers/workspaces_endpoints.py | 66 ------------------- ...51f51_simplify_model_version_artifacts2.py | 39 +++++++++++ ...307720f92_simplify_model_version_links.py} | 10 ++- src/zenml/zen_stores/rest_zen_store.py | 4 +- src/zenml/zen_stores/schemas/model_schemas.py | 41 ------------ src/zenml/zen_stores/schemas/user_schemas.py | 3 - .../zen_stores/schemas/workspace_schemas.py | 6 -- 12 files changed, 86 insertions(+), 185 deletions(-) create mode 100644 src/zenml/zen_stores/migrations/versions/5591a5051f51_simplify_model_version_artifacts2.py rename src/zenml/zen_stores/migrations/versions/{ec6307720f92_simplify_model_version_artifacts.py => ec6307720f92_simplify_model_version_links.py} (84%) diff --git a/src/zenml/client.py b/src/zenml/client.py index 2a16cab63e6..0dfa8efba56 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -6480,9 +6480,6 @@ def list_model_version_artifact_links( logical_operator: LogicalOperators = LogicalOperators.AND, created: Optional[Union[datetime, str]] = None, updated: Optional[Union[datetime, str]] = None, - workspace_id: Optional[Union[UUID, str]] = None, - user_id: Optional[Union[UUID, str]] = None, - model_id: Optional[Union[UUID, str]] = None, model_version_id: Optional[Union[UUID, str]] = None, artifact_version_id: Optional[Union[UUID, str]] = None, artifact_name: Optional[str] = None, @@ -6502,9 +6499,6 @@ def list_model_version_artifact_links( logical_operator: Which logical operator to use [and, or] created: Use to filter by time of creation updated: Use the last updated date for filtering - workspace_id: Use the workspace id for filtering - user_id: Use the user id for filtering - model_id: Use the model id for filtering model_version_id: Use the model version id for filtering artifact_version_id: Use the artifact id for filtering artifact_name: Use the artifact name for filtering @@ -6527,9 +6521,6 @@ def list_model_version_artifact_links( size=size, created=created, updated=updated, - workspace_id=workspace_id, - user_id=user_id, - model_id=model_id, model_version_id=model_version_id, artifact_version_id=artifact_version_id, artifact_name=artifact_name, @@ -6601,9 +6592,6 @@ def list_model_version_pipeline_run_links( logical_operator: LogicalOperators = LogicalOperators.AND, created: Optional[Union[datetime, str]] = None, updated: Optional[Union[datetime, str]] = None, - workspace_id: Optional[Union[UUID, str]] = None, - user_id: Optional[Union[UUID, str]] = None, - model_id: Optional[Union[UUID, str]] = None, model_version_id: Optional[Union[UUID, str]] = None, pipeline_run_id: Optional[Union[UUID, str]] = None, pipeline_run_name: Optional[str] = None, @@ -6619,9 +6607,6 @@ def list_model_version_pipeline_run_links( logical_operator: Which logical operator to use [and, or] created: Use to filter by time of creation updated: Use the last updated date for filtering - workspace_id: Use the workspace id for filtering - user_id: Use the user id for filtering - model_id: Use the model id for filtering model_version_id: Use the model version id for filtering pipeline_run_id: Use the pipeline run id for filtering pipeline_run_name: Use the pipeline run name for filtering @@ -6640,9 +6625,6 @@ def list_model_version_pipeline_run_links( size=size, created=created, updated=updated, - workspace_id=workspace_id, - user_id=user_id, - model_id=model_id, model_version_id=model_version_id, pipeline_run_id=pipeline_run_id, pipeline_run_name=pipeline_run_name, diff --git a/src/zenml/models/v2/core/model_version_artifact.py b/src/zenml/models/v2/core/model_version_artifact.py index 3949e6e1f27..075a4eaad40 100644 --- a/src/zenml/models/v2/core/model_version_artifact.py +++ b/src/zenml/models/v2/core/model_version_artifact.py @@ -128,17 +128,11 @@ class ModelVersionArtifactFilter(BaseFilter): "only_model_artifacts", "only_deployment_artifacts", "has_custom_name", - "model_id", "model_version_id", "updated", "id", ] - model_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Filter by model ID", - union_mode="left_to_right", - ) model_version_id: Optional[Union[UUID, str]] = Field( default=None, description="Filter by model version ID", diff --git a/src/zenml/models/v2/core/model_version_pipeline_run.py b/src/zenml/models/v2/core/model_version_pipeline_run.py index eb294d5f92f..6181c2ffbb1 100644 --- a/src/zenml/models/v2/core/model_version_pipeline_run.py +++ b/src/zenml/models/v2/core/model_version_pipeline_run.py @@ -23,23 +23,19 @@ from zenml.models.v2.base.base import ( BaseDatedResponseBody, BaseIdentifiedResponse, + BaseRequest, BaseResponseMetadata, BaseResponseResources, ) -from zenml.models.v2.base.filter import StrFilter -from zenml.models.v2.base.scoped import ( - WorkspaceScopedFilter, - WorkspaceScopedRequest, -) +from zenml.models.v2.base.filter import BaseFilter, StrFilter from zenml.models.v2.core.pipeline_run import PipelineRunResponse # ------------------ Request Model ------------------ -class ModelVersionPipelineRunRequest(WorkspaceScopedRequest): +class ModelVersionPipelineRunRequest(BaseRequest): """Request model for links between model versions and pipeline runs.""" - model: UUID model_version: UUID pipeline_run: UUID @@ -62,7 +58,6 @@ class ModelVersionPipelineRunRequest(WorkspaceScopedRequest): class ModelVersionPipelineRunResponseBody(BaseDatedResponseBody): """Response body for links between model versions and pipeline runs.""" - model: UUID model_version: UUID pipeline_run: PipelineRunResponse @@ -88,16 +83,6 @@ class ModelVersionPipelineRunResponse( ): """Response model for links between model versions and pipeline runs.""" - # Body and metadata properties - @property - def model(self) -> UUID: - """The `model` property. - - Returns: - the value of the property. - """ - return self.get_body().model - @property def model_version(self) -> UUID: """The `model_version` property. @@ -120,39 +105,21 @@ def pipeline_run(self) -> "PipelineRunResponse": # ------------------ Filter Model ------------------ -class ModelVersionPipelineRunFilter(WorkspaceScopedFilter): +class ModelVersionPipelineRunFilter(BaseFilter): """Model version pipeline run links filter model.""" FILTER_EXCLUDE_FIELDS = [ - *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, + *BaseFilter.FILTER_EXCLUDE_FIELDS, "pipeline_run_name", "user", ] CLI_EXCLUDE_FIELDS = [ - *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS, - "model_id", + *BaseFilter.CLI_EXCLUDE_FIELDS, "model_version_id", - "user_id", - "workspace_id", "updated", "id", ] - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="The workspace of the Model Version", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="The user of the Model Version", - union_mode="left_to_right", - ) - model_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Filter by model ID", - union_mode="left_to_right", - ) model_version_id: Optional[Union[UUID, str]] = Field( default=None, description="Filter by model version ID", diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index f28a5123ae5..f2ba33d71ad 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -540,10 +540,7 @@ def link_pipeline_run_to_model_version( client = Client() client.zen_store.create_model_version_pipeline_run_link( ModelVersionPipelineRunRequest( - user=client.active_user.id, - workspace=client.active_workspace.id, pipeline_run=pipeline_run.id, - model=model_version.model.id, model_version=model_version.id, ) ) diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index eb587e51d8c..6965308d58e 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -33,6 +33,7 @@ ModelVersionArtifactResponse, ModelVersionFilter, ModelVersionPipelineRunFilter, + ModelVersionPipelineRunRequest, ModelVersionPipelineRunResponse, ModelVersionResponse, ModelVersionUpdate, @@ -320,6 +321,35 @@ def delete_all_model_version_artifact_links( ) +@router.post( + "", + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_model_version_pipeline_run_link( + model_version_pipeline_run_link: ModelVersionPipelineRunRequest, + _: AuthContext = Security(authorize), +) -> ModelVersionPipelineRunResponse: + """Create a new model version to pipeline run link. + + Args: + model_version_pipeline_run_link: The model version to pipeline run link to create. + + Returns: + - If Model Version to Pipeline Run Link already exists - returns the existing link. + - Otherwise, returns the newly created model version to pipeline run link. + """ + model_version = zen_store().get_model_version( + model_version_pipeline_run_link.model_version, hydrate=False + ) + verify_permission_for_model(model_version, action=Action.UPDATE) + + mv = zen_store().create_model_version_pipeline_run_link( + model_version_pipeline_run_link + ) + return mv + + @model_version_pipeline_runs_router.get( "", response_model=Page[ModelVersionPipelineRunResponse], diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 0af5a3f991e..db2e1b796ad 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -53,8 +53,6 @@ ComponentResponse, ModelRequest, ModelResponse, - ModelVersionPipelineRunRequest, - ModelVersionPipelineRunResponse, ModelVersionRequest, ModelVersionResponse, Page, @@ -1440,70 +1438,6 @@ def create_model_version( ) -@router.post( - WORKSPACES - + "/{workspace_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_id}" - + RUNS, - response_model=ModelVersionPipelineRunResponse, - responses={401: error_response, 409: error_response, 422: error_response}, -) -@handle_exceptions -def create_model_version_pipeline_run_link( - workspace_name_or_id: Union[str, UUID], - model_version_id: UUID, - model_version_pipeline_run_link: ModelVersionPipelineRunRequest, - auth_context: AuthContext = Security(authorize), -) -> ModelVersionPipelineRunResponse: - """Create a new model version to pipeline run link. - - Args: - workspace_name_or_id: Name or ID of the workspace. - model_version_id: ID of the model version. - model_version_pipeline_run_link: The model version to pipeline run link to create. - auth_context: Authentication context. - - Returns: - - If Model Version to Pipeline Run Link already exists - returns the existing link. - - Otherwise, returns the newly created model version to pipeline run link. - - Raises: - IllegalOperationError: If the workspace or user specified in the - model version does not match the current workspace or authenticated - user. - """ - workspace = zen_store().get_workspace(workspace_name_or_id) - if str(model_version_id) != str( - model_version_pipeline_run_link.model_version - ): - raise IllegalOperationError( - f"The model version id in your path `{model_version_id}` does not " - f"match the model version specified in the request model " - f"`{model_version_pipeline_run_link.model_version}`" - ) - - if model_version_pipeline_run_link.workspace != workspace.id: - raise IllegalOperationError( - "Creating model versions outside of the workspace scope " - f"of this endpoint `{workspace_name_or_id}` is " - f"not supported." - ) - if model_version_pipeline_run_link.user != auth_context.user.id: - raise IllegalOperationError( - "Creating models for a user other than yourself " - "is not supported." - ) - - model_version = zen_store().get_model_version(model_version_id) - verify_permission_for_model(model_version, action=Action.UPDATE) - - mv = zen_store().create_model_version_pipeline_run_link( - model_version_pipeline_run_link - ) - return mv - - @router.post( WORKSPACES + "/{workspace_name_or_id}" + SERVICES, response_model=ServiceResponse, diff --git a/src/zenml/zen_stores/migrations/versions/5591a5051f51_simplify_model_version_artifacts2.py b/src/zenml/zen_stores/migrations/versions/5591a5051f51_simplify_model_version_artifacts2.py new file mode 100644 index 00000000000..d05a3d6aae0 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/5591a5051f51_simplify_model_version_artifacts2.py @@ -0,0 +1,39 @@ +"""Simplify model version artifacts2 [5591a5051f51]. + +Revision ID: 5591a5051f51 +Revises: ec6307720f92 +Create Date: 2024-11-06 16:51:01.346080 + +""" +import sqlmodel +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '5591a5051f51' +down_revision = 'ec6307720f92' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('model_versions_runs', schema=None) as batch_op: + batch_op.add_column(sa.Column('user_id', sa.CHAR(length=32), nullable=True)) + batch_op.add_column(sa.Column('workspace_id', sa.CHAR(length=32), nullable=False)) + batch_op.add_column(sa.Column('model_id', sa.CHAR(length=32), nullable=False)) + batch_op.create_foreign_key('fk_model_versions_runs_user_id_user', 'user', ['user_id'], ['id'], ondelete='SET NULL') + batch_op.create_foreign_key('fk_model_versions_runs_workspace_id_workspace', 'workspace', ['workspace_id'], ['id'], ondelete='CASCADE') + batch_op.create_foreign_key('fk_model_versions_runs_model_id_model', 'model', ['model_id'], ['id'], ondelete='CASCADE') + + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/migrations/versions/ec6307720f92_simplify_model_version_artifacts.py b/src/zenml/zen_stores/migrations/versions/ec6307720f92_simplify_model_version_links.py similarity index 84% rename from src/zenml/zen_stores/migrations/versions/ec6307720f92_simplify_model_version_artifacts.py rename to src/zenml/zen_stores/migrations/versions/ec6307720f92_simplify_model_version_links.py index 1ff6b561b76..247d8a52191 100644 --- a/src/zenml/zen_stores/migrations/versions/ec6307720f92_simplify_model_version_artifacts.py +++ b/src/zenml/zen_stores/migrations/versions/ec6307720f92_simplify_model_version_links.py @@ -1,4 +1,4 @@ -"""Simplify model version artifacts [ec6307720f92]. +"""Simplify model version links [ec6307720f92]. Revision ID: ec6307720f92 Revises: c22561cbb3a9 @@ -86,6 +86,14 @@ def upgrade() -> None: batch_op.drop_column("workspace_id") batch_op.drop_column("is_model_artifact") + with op.batch_alter_table('model_versions_runs', schema=None) as batch_op: + batch_op.drop_constraint('fk_model_versions_runs_model_id_model', type_='foreignkey') + batch_op.drop_constraint('fk_model_versions_runs_workspace_id_workspace', type_='foreignkey') + batch_op.drop_constraint('fk_model_versions_runs_user_id_user', type_='foreignkey') + batch_op.drop_column('model_id') + batch_op.drop_column('workspace_id') + batch_op.drop_column('user_id') + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 8bf8eec4964..8d1f7903d99 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -3802,10 +3802,10 @@ def create_model_version_pipeline_run_link( - Otherwise, returns the newly created model version to pipeline run link. """ - return self._create_workspace_scoped_resource( + return self._create_resource( resource=model_version_pipeline_run_link, response_model=ModelVersionPipelineRunResponse, - route=f"{MODEL_VERSIONS}/{model_version_pipeline_run_link.model_version}{RUNS}", + route=MODEL_VERSION_PIPELINE_RUNS, ) def list_model_version_pipeline_run_links( diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 8c1f72ec0e9..24886efdc5b 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -108,10 +108,6 @@ class ModelSchema(NamedSchema, table=True): back_populates="model", sa_relationship_kwargs={"cascade": "delete"}, ) - pipeline_run_links: List["ModelVersionPipelineRunSchema"] = Relationship( - back_populates="model", - sa_relationship_kwargs={"cascade": "delete"}, - ) @classmethod def from_request(cls, model_request: ModelRequest) -> "ModelSchema": @@ -538,39 +534,6 @@ class ModelVersionPipelineRunSchema(BaseSchema, table=True): __tablename__ = "model_versions_runs" - workspace_id: UUID = build_foreign_key_field( - source=__tablename__, - target=WorkspaceSchema.__tablename__, - source_column="workspace_id", - target_column="id", - ondelete="CASCADE", - nullable=False, - ) - workspace: "WorkspaceSchema" = Relationship( - back_populates="model_versions_pipeline_runs_links" - ) - - user_id: Optional[UUID] = build_foreign_key_field( - source=__tablename__, - target=UserSchema.__tablename__, - source_column="user_id", - target_column="id", - ondelete="SET NULL", - nullable=True, - ) - user: Optional["UserSchema"] = Relationship( - back_populates="model_versions_pipeline_runs_links" - ) - - model_id: UUID = build_foreign_key_field( - source=__tablename__, - target=ModelSchema.__tablename__, - source_column="model_id", - target_column="id", - ondelete="CASCADE", - nullable=False, - ) - model: "ModelSchema" = Relationship(back_populates="pipeline_run_links") model_version_id: UUID = build_foreign_key_field( source=__tablename__, target=ModelVersionSchema.__tablename__, @@ -616,9 +579,6 @@ def from_request( The converted schema. """ return cls( - workspace_id=model_version_pipeline_run_request.workspace, - user_id=model_version_pipeline_run_request.user, - model_id=model_version_pipeline_run_request.model, model_version_id=model_version_pipeline_run_request.model_version, pipeline_run_id=model_version_pipeline_run_request.pipeline_run, ) @@ -645,7 +605,6 @@ def to_model( body=ModelVersionPipelineRunResponseBody( created=self.created, updated=self.updated, - model=self.model_id, model_version=self.model_version_id, pipeline_run=self.pipeline_run.to_model(), ), diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index 0d139780c79..0aaa8bf7ac0 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -143,9 +143,6 @@ class UserSchema(NamedSchema, table=True): model_versions: List["ModelVersionSchema"] = Relationship( back_populates="user", ) - model_versions_pipeline_runs_links: List[ - "ModelVersionPipelineRunSchema" - ] = Relationship(back_populates="user") auth_devices: List["OAuthDeviceSchema"] = Relationship( back_populates="user", sa_relationship_kwargs={"cascade": "delete"}, diff --git a/src/zenml/zen_stores/schemas/workspace_schemas.py b/src/zenml/zen_stores/schemas/workspace_schemas.py index 0c87813de4b..4443a3da397 100644 --- a/src/zenml/zen_stores/schemas/workspace_schemas.py +++ b/src/zenml/zen_stores/schemas/workspace_schemas.py @@ -141,12 +141,6 @@ class WorkspaceSchema(NamedSchema, table=True): back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, ) - model_versions_pipeline_runs_links: List[ - "ModelVersionPipelineRunSchema" - ] = Relationship( - back_populates="workspace", - sa_relationship_kwargs={"cascade": "delete"}, - ) @classmethod def from_request(cls, workspace: WorkspaceRequest) -> "WorkspaceSchema":