Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Nov 6, 2024
1 parent 3929feb commit 58917fb
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 99 deletions.
1 change: 0 additions & 1 deletion src/zenml/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def link_artifact_version_to_model_version(
client = Client()
client.zen_store.create_model_version_artifact_link(
ModelVersionArtifactRequest(
user=client.active_user.id,
artifact_version=artifact_version.id,
model_version=model_version.id,
)
Expand Down
18 changes: 6 additions & 12 deletions src/zenml/models/v2/core/model_version_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from zenml.models.v2.base.base import (
BaseDatedResponseBody,
BaseIdentifiedResponse,
BaseRequest,
BaseResponseMetadata,
BaseResponseResources,
)
from zenml.models.v2.base.filter import StrFilter
from zenml.models.v2.base.scoped import UserScopedFilter, UserScopedRequest
from zenml.models.v2.base.filter import BaseFilter, StrFilter

if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
Expand All @@ -37,7 +37,7 @@
# ------------------ Request Model ------------------


class ModelVersionArtifactRequest(UserScopedRequest):
class ModelVersionArtifactRequest(BaseRequest):
"""Request model for links between model versions and artifacts."""

model_version: UUID
Expand Down Expand Up @@ -109,12 +109,12 @@ def artifact_version(self) -> "ArtifactVersionResponse":
# ------------------ Filter Model ------------------


class ModelVersionArtifactFilter(UserScopedFilter):
class ModelVersionArtifactFilter(BaseFilter):
"""Model version pipeline run links filter model."""

# Artifact name and type are not DB fields and need to be handled separately
FILTER_EXCLUDE_FIELDS = [
*UserScopedFilter.FILTER_EXCLUDE_FIELDS,
*BaseFilter.FILTER_EXCLUDE_FIELDS,
"artifact_name",
"only_data_artifacts",
"only_model_artifacts",
Expand All @@ -123,23 +123,17 @@ class ModelVersionArtifactFilter(UserScopedFilter):
"user",
]
CLI_EXCLUDE_FIELDS = [
*UserScopedFilter.CLI_EXCLUDE_FIELDS,
*BaseFilter.CLI_EXCLUDE_FIELDS,
"only_data_artifacts",
"only_model_artifacts",
"only_deployment_artifacts",
"has_custom_name",
"model_id",
"model_version_id",
"user_id",
"updated",
"id",
]

user_id: Optional[Union[UUID, str]] = Field(
default=None,
description="The user of the Model Version",
union_mode="left_to_right",
)
model_id: Optional[Union[UUID, str]] = Field(
default=None,
description="Filter by model ID",
Expand Down
29 changes: 29 additions & 0 deletions src/zenml/zen_server/routers/model_versions_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from zenml.models import (
ModelVersionArtifactFilter,
ModelVersionArtifactRequest,
ModelVersionArtifactResponse,
ModelVersionFilter,
ModelVersionPipelineRunFilter,
Expand Down Expand Up @@ -198,6 +199,34 @@ def delete_model_version(
)


@model_version_artifacts_router.post(
"",
responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version_artifact_link(
model_version_artifact_link: ModelVersionArtifactRequest,
_: AuthContext = Security(authorize),
) -> ModelVersionArtifactResponse:
"""Create a new model version to artifact link.
Args:
model_version_artifact_link: The model version to artifact link to create.
Returns:
The created model version to artifact link.
"""
model_version = zen_store().get_model_version(
model_version_artifact_link.model_version
)
verify_permission_for_model(model_version, action=Action.UPDATE)

mv = zen_store().create_model_version_artifact_link(
model_version_artifact_link
)
return mv


@model_version_artifacts_router.get(
"",
response_model=Page[ModelVersionArtifactResponse],
Expand Down
64 changes: 0 additions & 64 deletions src/zenml/zen_server/routers/workspaces_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from zenml.constants import (
API,
ARTIFACTS,
CODE_REPOSITORIES,
GET_OR_CREATE,
MODEL_VERSIONS,
Expand Down Expand Up @@ -54,8 +53,6 @@
ComponentResponse,
ModelRequest,
ModelResponse,
ModelVersionArtifactRequest,
ModelVersionArtifactResponse,
ModelVersionPipelineRunRequest,
ModelVersionPipelineRunResponse,
ModelVersionRequest,
Expand Down Expand Up @@ -1443,67 +1440,6 @@ def create_model_version(
)


@router.post(
WORKSPACES
+ "/{workspace_name_or_id}"
+ MODEL_VERSIONS
+ "/{model_version_id}"
+ ARTIFACTS,
response_model=ModelVersionArtifactResponse,
responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def create_model_version_artifact_link(
workspace_name_or_id: Union[str, UUID],
model_version_id: UUID,
model_version_artifact_link: ModelVersionArtifactRequest,
auth_context: AuthContext = Security(authorize),
) -> ModelVersionArtifactResponse:
"""Create a new model version to artifact link.
Args:
workspace_name_or_id: Name or ID of the workspace.
model_version_id: ID of the model version.
model_version_artifact_link: The model version to artifact link to create.
auth_context: Authentication context.
Returns:
The created model version to artifact link.
Raises:
IllegalOperationError: If the workspace or user specified in the
model version does not match the current workspace or authenticated
user.
"""
workspace = zen_store().get_workspace(workspace_name_or_id)
if str(model_version_id) != str(model_version_artifact_link.model_version):
raise IllegalOperationError(
f"The model version id in your path `{model_version_id}` does not "
f"match the model version specified in the request model "
f"`{model_version_artifact_link.model_version}`"
)

if model_version_artifact_link.workspace != workspace.id:
raise IllegalOperationError(
"Creating model version to artifact links outside of the workspace scope "
f"of this endpoint `{workspace_name_or_id}` is "
f"not supported."
)
if model_version_artifact_link.user != auth_context.user.id:
raise IllegalOperationError(
"Creating model to artifact links for a user other than yourself "
"is not supported."
)

model_version = zen_store().get_model_version(model_version_id)
verify_permission_for_model(model_version, action=Action.UPDATE)

mv = zen_store().create_model_version_artifact_link(
model_version_artifact_link
)
return mv


@router.post(
WORKSPACES
+ "/{workspace_name_or_id}"
Expand Down
4 changes: 2 additions & 2 deletions src/zenml/zen_stores/rest_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3722,10 +3722,10 @@ def create_model_version_artifact_link(
Returns:
The newly created model version to artifact link.
"""
return self._create_workspace_scoped_resource(
return self._create_resource(
resource=model_version_artifact_link,
response_model=ModelVersionArtifactResponse,
route=f"{MODEL_VERSIONS}/{model_version_artifact_link.model_version}{ARTIFACTS}",
route=MODEL_VERSION_ARTIFACTS,
)

def list_model_version_artifact_links(
Expand Down
26 changes: 10 additions & 16 deletions src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column
from sqlmodel import Field, Relationship

from zenml.enums import MetadataResourceTypes, TaggableResourceTypes
from zenml.enums import (
ArtifactType,
MetadataResourceTypes,
TaggableResourceTypes,
)
from zenml.models import (
BaseResponseMetadata,
ModelRequest,
Expand Down Expand Up @@ -348,11 +352,14 @@ def to_model(
artifact_name = artifact_link.artifact_version.artifact.name
artifact_version = str(artifact_link.artifact_version.version)
artifact_version_id = artifact_link.artifact_version.id
if artifact_link.is_model_artifact:
if artifact_link.artifact_version.type == ArtifactType.MODEL.value:
model_artifact_ids.setdefault(artifact_name, {}).update(
{str(artifact_version): artifact_version_id}
)
elif artifact_link.is_deployment_artifact:
if (
artifact_link.artifact_version.type
== ArtifactType.SERVICE.value
):
deployment_artifact_ids.setdefault(artifact_name, {}).update(
{str(artifact_version): artifact_version_id}
)
Expand Down Expand Up @@ -448,18 +455,6 @@ class ModelVersionArtifactSchema(BaseSchema, table=True):

__tablename__ = "model_versions_artifacts"

user_id: Optional[UUID] = build_foreign_key_field(
source=__tablename__,
target=UserSchema.__tablename__,
source_column="user_id",
target_column="id",
ondelete="SET NULL",
nullable=True,
)
user: Optional["UserSchema"] = Relationship(
back_populates="model_versions_artifacts_links"
)

model_version_id: UUID = build_foreign_key_field(
source=__tablename__,
target=ModelVersionSchema.__tablename__,
Expand Down Expand Up @@ -505,7 +500,6 @@ def from_request(
The converted schema.
"""
return cls(
user_id=model_version_artifact_request.user,
model_version_id=model_version_artifact_request.model_version,
artifact_version_id=model_version_artifact_request.artifact_version,
)
Expand Down
4 changes: 0 additions & 4 deletions src/zenml/zen_stores/schemas/user_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
EventSourceSchema,
FlavorSchema,
ModelSchema,
ModelVersionArtifactSchema,
ModelVersionPipelineRunSchema,
ModelVersionSchema,
OAuthDeviceSchema,
Expand Down Expand Up @@ -144,9 +143,6 @@ class UserSchema(NamedSchema, table=True):
model_versions: List["ModelVersionSchema"] = Relationship(
back_populates="user",
)
model_versions_artifacts_links: List["ModelVersionArtifactSchema"] = (
Relationship(back_populates="user")
)
model_versions_pipeline_runs_links: List[
"ModelVersionPipelineRunSchema"
] = Relationship(back_populates="user")
Expand Down

0 comments on commit 58917fb

Please sign in to comment.