diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 52bf969b2fa..c8150d179ed 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -85,7 +85,6 @@ def link_artifact_version_to_model_version( client = Client() client.zen_store.create_model_version_artifact_link( ModelVersionArtifactRequest( - user=client.active_user.id, artifact_version=artifact_version.id, model_version=model_version.id, ) diff --git a/src/zenml/models/v2/core/model_version_artifact.py b/src/zenml/models/v2/core/model_version_artifact.py index 875d10597fa..3949e6e1f27 100644 --- a/src/zenml/models/v2/core/model_version_artifact.py +++ b/src/zenml/models/v2/core/model_version_artifact.py @@ -22,11 +22,11 @@ from zenml.models.v2.base.base import ( BaseDatedResponseBody, BaseIdentifiedResponse, + BaseRequest, BaseResponseMetadata, BaseResponseResources, ) -from zenml.models.v2.base.filter import StrFilter -from zenml.models.v2.base.scoped import UserScopedFilter, UserScopedRequest +from zenml.models.v2.base.filter import BaseFilter, StrFilter if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement @@ -37,7 +37,7 @@ # ------------------ Request Model ------------------ -class ModelVersionArtifactRequest(UserScopedRequest): +class ModelVersionArtifactRequest(BaseRequest): """Request model for links between model versions and artifacts.""" model_version: UUID @@ -109,12 +109,12 @@ def artifact_version(self) -> "ArtifactVersionResponse": # ------------------ Filter Model ------------------ -class ModelVersionArtifactFilter(UserScopedFilter): +class ModelVersionArtifactFilter(BaseFilter): """Model version pipeline run links filter model.""" # Artifact name and type are not DB fields and need to be handled separately FILTER_EXCLUDE_FIELDS = [ - *UserScopedFilter.FILTER_EXCLUDE_FIELDS, + *BaseFilter.FILTER_EXCLUDE_FIELDS, "artifact_name", "only_data_artifacts", "only_model_artifacts", @@ -123,23 +123,17 @@ class ModelVersionArtifactFilter(UserScopedFilter): "user", ] CLI_EXCLUDE_FIELDS = [ - *UserScopedFilter.CLI_EXCLUDE_FIELDS, + *BaseFilter.CLI_EXCLUDE_FIELDS, "only_data_artifacts", "only_model_artifacts", "only_deployment_artifacts", "has_custom_name", "model_id", "model_version_id", - "user_id", "updated", "id", ] - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="The user of the Model Version", - union_mode="left_to_right", - ) model_id: Optional[Union[UUID, str]] = Field( default=None, description="Filter by model ID", diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index ddde60e35f9..eb587e51d8c 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -29,6 +29,7 @@ ) from zenml.models import ( ModelVersionArtifactFilter, + ModelVersionArtifactRequest, ModelVersionArtifactResponse, ModelVersionFilter, ModelVersionPipelineRunFilter, @@ -198,6 +199,34 @@ def delete_model_version( ) +@model_version_artifacts_router.post( + "", + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_model_version_artifact_link( + model_version_artifact_link: ModelVersionArtifactRequest, + _: AuthContext = Security(authorize), +) -> ModelVersionArtifactResponse: + """Create a new model version to artifact link. + + Args: + model_version_artifact_link: The model version to artifact link to create. + + Returns: + The created model version to artifact link. + """ + model_version = zen_store().get_model_version( + model_version_artifact_link.model_version + ) + verify_permission_for_model(model_version, action=Action.UPDATE) + + mv = zen_store().create_model_version_artifact_link( + model_version_artifact_link + ) + return mv + + @model_version_artifacts_router.get( "", response_model=Page[ModelVersionArtifactResponse], diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 7b9782e87c3..0af5a3f991e 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -20,7 +20,6 @@ from zenml.constants import ( API, - ARTIFACTS, CODE_REPOSITORIES, GET_OR_CREATE, MODEL_VERSIONS, @@ -54,8 +53,6 @@ ComponentResponse, ModelRequest, ModelResponse, - ModelVersionArtifactRequest, - ModelVersionArtifactResponse, ModelVersionPipelineRunRequest, ModelVersionPipelineRunResponse, ModelVersionRequest, @@ -1443,67 +1440,6 @@ def create_model_version( ) -@router.post( - WORKSPACES - + "/{workspace_name_or_id}" - + MODEL_VERSIONS - + "/{model_version_id}" - + ARTIFACTS, - response_model=ModelVersionArtifactResponse, - responses={401: error_response, 409: error_response, 422: error_response}, -) -@handle_exceptions -def create_model_version_artifact_link( - workspace_name_or_id: Union[str, UUID], - model_version_id: UUID, - model_version_artifact_link: ModelVersionArtifactRequest, - auth_context: AuthContext = Security(authorize), -) -> ModelVersionArtifactResponse: - """Create a new model version to artifact link. - - Args: - workspace_name_or_id: Name or ID of the workspace. - model_version_id: ID of the model version. - model_version_artifact_link: The model version to artifact link to create. - auth_context: Authentication context. - - Returns: - The created model version to artifact link. - - Raises: - IllegalOperationError: If the workspace or user specified in the - model version does not match the current workspace or authenticated - user. - """ - workspace = zen_store().get_workspace(workspace_name_or_id) - if str(model_version_id) != str(model_version_artifact_link.model_version): - raise IllegalOperationError( - f"The model version id in your path `{model_version_id}` does not " - f"match the model version specified in the request model " - f"`{model_version_artifact_link.model_version}`" - ) - - if model_version_artifact_link.workspace != workspace.id: - raise IllegalOperationError( - "Creating model version to artifact links outside of the workspace scope " - f"of this endpoint `{workspace_name_or_id}` is " - f"not supported." - ) - if model_version_artifact_link.user != auth_context.user.id: - raise IllegalOperationError( - "Creating model to artifact links for a user other than yourself " - "is not supported." - ) - - model_version = zen_store().get_model_version(model_version_id) - verify_permission_for_model(model_version, action=Action.UPDATE) - - mv = zen_store().create_model_version_artifact_link( - model_version_artifact_link - ) - return mv - - @router.post( WORKSPACES + "/{workspace_name_or_id}" diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 0b4836c3f7c..8bf8eec4964 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -3722,10 +3722,10 @@ def create_model_version_artifact_link( Returns: The newly created model version to artifact link. """ - return self._create_workspace_scoped_resource( + return self._create_resource( resource=model_version_artifact_link, response_model=ModelVersionArtifactResponse, - route=f"{MODEL_VERSIONS}/{model_version_artifact_link.model_version}{ARTIFACTS}", + route=MODEL_VERSION_ARTIFACTS, ) def list_model_version_artifact_links( diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 94928d20090..8c1f72ec0e9 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -21,7 +21,11 @@ from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column from sqlmodel import Field, Relationship -from zenml.enums import MetadataResourceTypes, TaggableResourceTypes +from zenml.enums import ( + ArtifactType, + MetadataResourceTypes, + TaggableResourceTypes, +) from zenml.models import ( BaseResponseMetadata, ModelRequest, @@ -348,11 +352,14 @@ def to_model( artifact_name = artifact_link.artifact_version.artifact.name artifact_version = str(artifact_link.artifact_version.version) artifact_version_id = artifact_link.artifact_version.id - if artifact_link.is_model_artifact: + if artifact_link.artifact_version.type == ArtifactType.MODEL.value: model_artifact_ids.setdefault(artifact_name, {}).update( {str(artifact_version): artifact_version_id} ) - elif artifact_link.is_deployment_artifact: + if ( + artifact_link.artifact_version.type + == ArtifactType.SERVICE.value + ): deployment_artifact_ids.setdefault(artifact_name, {}).update( {str(artifact_version): artifact_version_id} ) @@ -448,18 +455,6 @@ class ModelVersionArtifactSchema(BaseSchema, table=True): __tablename__ = "model_versions_artifacts" - user_id: Optional[UUID] = build_foreign_key_field( - source=__tablename__, - target=UserSchema.__tablename__, - source_column="user_id", - target_column="id", - ondelete="SET NULL", - nullable=True, - ) - user: Optional["UserSchema"] = Relationship( - back_populates="model_versions_artifacts_links" - ) - model_version_id: UUID = build_foreign_key_field( source=__tablename__, target=ModelVersionSchema.__tablename__, @@ -505,7 +500,6 @@ def from_request( The converted schema. """ return cls( - user_id=model_version_artifact_request.user, model_version_id=model_version_artifact_request.model_version, artifact_version_id=model_version_artifact_request.artifact_version, ) diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index 9f0c0f0cdd8..0d139780c79 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -44,7 +44,6 @@ EventSourceSchema, FlavorSchema, ModelSchema, - ModelVersionArtifactSchema, ModelVersionPipelineRunSchema, ModelVersionSchema, OAuthDeviceSchema, @@ -144,9 +143,6 @@ class UserSchema(NamedSchema, table=True): model_versions: List["ModelVersionSchema"] = Relationship( back_populates="user", ) - model_versions_artifacts_links: List["ModelVersionArtifactSchema"] = ( - Relationship(back_populates="user") - ) model_versions_pipeline_runs_links: List[ "ModelVersionPipelineRunSchema" ] = Relationship(back_populates="user")