Skip to content

Commit

Permalink
Prevent some race conditions (#3167)
Browse files Browse the repository at this point in the history
* Add pipeline and run unique constraints

* Add model unique constraint

* Cleanup unused methods

* Increment MV version server side

* Fix log message

* Use name instead of number for model version unique constraint

* Improve model version unique constraints

* Separate exceptions

* Add more migrations

* Convert new exception to 500 status code

* Format

* Fix mypy

* Docstrings

* Improve logic for model version creation

* Add missing break

* Add jitter to exponential backoff

* Don't check for DB columns in integrity errors

* Add workspace to model name unique constraint

* Docstring

* Cast to str

* Fix alembic order

* Implement migration of duplicate values before DB migration

* Auto-update of Starter template

* Auto-update of E2E template

* Auto-update of NLP template

* Add logs during migration

* Type annotations in migration

* Formatting

* More mypy

* Update alembic order

---------

Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
schustmi and actions-user authored Nov 11, 2024
1 parent a5d4531 commit c6dc2be
Show file tree
Hide file tree
Showing 8 changed files with 467 additions and 186 deletions.
4 changes: 4 additions & 0 deletions src/zenml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ class EntityExistsError(ZenMLBaseException):
"""Raised when trying to register an entity that already exists."""


class EntityCreationError(ZenMLBaseException, RuntimeError):
    """Raised when failing to create an entity.

    Besides `ZenMLBaseException`, this also derives from `RuntimeError`,
    so handlers written for `RuntimeError` keep catching it.
    """


class ActionExistsError(EntityExistsError):
"""Raised when registering an action with a name that already exists."""

Expand Down
78 changes: 16 additions & 62 deletions src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Model user facing interface to pass into pipeline or step."""

import datetime
import time
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -28,7 +27,6 @@

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator

from zenml.constants import MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
from zenml.enums import MetadataResourceTypes, ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger
Expand Down Expand Up @@ -527,14 +525,6 @@ def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data["suppress_class_validation_warnings"] = True
return data

def _validate_config_in_runtime(self) -> "ModelVersionResponse":
    """Resolve this configuration to a model version at runtime.

    Returns:
        The model version obtained from `_get_or_create_model_version`.
    """
    model_version = self._get_or_create_model_version()
    return model_version

def _get_or_create_model(self) -> "ModelResponse":
"""This method should get or create a model from Model Control Plane.
Expand Down Expand Up @@ -678,16 +668,6 @@ def _get_or_create_model_version(
if isinstance(self.version, str):
self.version = format_name_template(self.version)

zenml_client = Client()
model_version_request = ModelVersionRequest(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
name=str(self.version) if self.version else None,
description=self.description,
model=model.id,
tags=self.tags,
)
mv_request = ModelVersionRequest.model_validate(model_version_request)
try:
if self.version or self.model_version_id:
model_version = self._get_model_version()
Expand Down Expand Up @@ -717,60 +697,34 @@ def _get_or_create_model_version(
" as an example. You can explore model versions using "
f"`zenml model version list -n {self.name}` CLI command."
)
retries_made = 0
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
try:
model_version = (
zenml_client.zen_store.create_model_version(
model_version=mv_request
)
)
break
except EntityExistsError as e:
if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1:
raise RuntimeError(
f"Failed to create model version "
f"`{self.version if self.version else 'new'}` "
f"in model `{self.name}`. Retried {retries_made} times. "
"This could be driven by exceptionally high concurrency of "
"pipeline runs. Please, reach out to us on ZenML Slack for support."
) from e
# smoothed exponential back-off, it will go as 0.2, 0.3,
# 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
sleep = 0.2 * 1.5**i
logger.debug(
f"Failed to create new model version for "
f"model `{self.name}`. Retrying in {sleep}..."
)
time.sleep(sleep)
retries_made += 1
self.version = model_version.name

client = Client()
model_version_request = ModelVersionRequest(
user=client.active_user.id,
workspace=client.active_workspace.id,
name=str(self.version) if self.version else None,
description=self.description,
model=model.id,
tags=self.tags,
)
model_version = client.zen_store.create_model_version(
model_version=model_version_request
)

self._created_model_version = True

logger.info(
"Created new model version `%s` for model `%s`.",
self.version,
model_version.name,
self.name,
)

self.version = model_version.name
self.model_version_id = model_version.id
self._model_id = model_version.model.id
self._number = model_version.number
return model_version

def _merge(self, model: "Model") -> None:
    """Fill unset descriptive fields of this model from another model.

    For every descriptive field, the value already set on `self` wins;
    the value of `model` is only used when the own value is falsy. Tags
    are combined into the union of both tag collections.

    Args:
        model: The model to take missing values and extra tags from.
    """
    descriptive_fields = (
        "license",
        "description",
        "audience",
        "use_cases",
        "limitations",
        "trade_offs",
        "ethics",
    )
    for field_name in descriptive_fields:
        # `or` keeps the own value and only reads the other model's
        # attribute when the own one is empty.
        merged_value = getattr(self, field_name) or getattr(model, field_name)
        setattr(self, field_name, merged_value)

    if model.tags is None:
        # Nothing to merge, own tags stay untouched (even if `None`).
        return

    own_tags = set(self.tags or [])
    self.tags = list(own_tags.union(set(model.tags)))

def __hash__(self) -> int:
"""Get hash of the `Model`.
Expand Down
2 changes: 2 additions & 0 deletions src/zenml/zen_server/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
CredentialsNotValid,
DoesNotExistException,
DuplicateRunNameError,
EntityCreationError,
EntityExistsError,
IllegalOperationError,
MethodNotAllowedError,
Expand Down Expand Up @@ -93,6 +94,7 @@ class ErrorModel(BaseModel):
# 422 Unprocessable Entity
(ValueError, 422),
# 500 Internal Server Error
(EntityCreationError, 500),
(RuntimeError, 500),
# 501 Not Implemented,
(NotImplementedError, 501),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
"""Add pipeline, model and run unique constraints [904464ea4041].
Revision ID: 904464ea4041
Revises: b557b2871693
Create Date: 2024-11-04 10:27:05.450092
"""

from collections import defaultdict
from typing import Any, Dict, Set

import sqlalchemy as sa
from alembic import op

from zenml.logger import get_logger

logger = get_logger(__name__)

# revision identifiers, used by Alembic.
revision = "904464ea4041"
down_revision = "b557b2871693"
branch_labels = None
depends_on = None


def _deduplicate_name(name: str, id_: Any, taken: Set[str]) -> str:
    """Build a replacement for a clashing name that is not yet in use.

    Args:
        name: The name that clashes with another row.
        id_: ID of the row being renamed. Its string form is used as the
            suffix of the new name.
        taken: Names that are already in use in the same scope.

    Returns:
        `name` followed by `_` and the first six characters of the ID. If
        that is taken as well, the ID prefix is extended character by
        character, and once the ID is exhausted a counter is appended.
    """
    suffix = str(id_)
    prefix_length = 6
    candidate = f"{name}_{suffix[:prefix_length]}"

    while candidate in taken:
        prefix_length += 1
        if prefix_length <= len(suffix):
            candidate = f"{name}_{suffix[:prefix_length]}"
        else:
            # The whole ID is already part of the name, fall back to a
            # counter so the loop is guaranteed to terminate.
            candidate = f"{name}_{suffix}_{prefix_length - len(suffix)}"

    return candidate


def resolve_duplicate_entities() -> None:
    """Resolve duplicate entities.

    Renames rows so that the unique constraints added by this migration
    can be created:

    - pipeline runs, pipelines and models must have a unique name per
      workspace
    - model versions must have a unique name and a unique number per model

    The first row seen with a given value keeps it. Every later row gets
    a new value that is checked against all values in use in the same
    scope, so the migration itself cannot introduce a new clash.
    """
    connection = op.get_bind()
    meta = sa.MetaData()
    meta.reflect(
        bind=connection,
        only=("pipeline_run", "pipeline", "model", "model_version"),
    )

    # Remove duplicate names for runs, pipelines and models
    for table_name in ["pipeline_run", "pipeline", "model"]:
        table = sa.Table(table_name, meta)
        result = connection.execute(
            sa.select(table.c.id, table.c.name, table.c.workspace_id)
        ).all()
        existing: Dict[str, Set[str]] = defaultdict(set)
        duplicates = []

        # First collect all names so that a replacement name can be
        # checked against every name in the workspace, not only the ones
        # that happened to be returned earlier by the query.
        for id_, name, workspace_id in result:
            names_in_workspace = existing[workspace_id]

            if name in names_in_workspace:
                duplicates.append((id_, name, workspace_id))
            else:
                names_in_workspace.add(name)

        for id_, name, workspace_id in duplicates:
            names_in_workspace = existing[workspace_id]
            new_name = _deduplicate_name(name, id_, names_in_workspace)
            logger.warning(
                "Migrating %s name from %s to %s to resolve duplicate name.",
                table_name,
                name,
                new_name,
            )
            connection.execute(
                sa.update(table)
                .where(table.c.id == id_)
                .values(name=new_name)
            )
            names_in_workspace.add(new_name)

    # Remove duplicate names and version numbers for model versions
    model_version_table = sa.Table("model_version", meta)
    result = connection.execute(
        sa.select(
            model_version_table.c.id,
            model_version_table.c.name,
            model_version_table.c.number,
            model_version_table.c.model_id,
        )
    ).all()

    existing_names: Dict[str, Set[str]] = defaultdict(set)
    existing_numbers: Dict[str, Set[int]] = defaultdict(set)

    needs_update = []

    for id_, name, number, model_id in result:
        names_for_model = existing_names[model_id]
        numbers_for_model = existing_numbers[model_id]

        needs_new_name = name in names_for_model
        needs_new_number = number in numbers_for_model

        if needs_new_name or needs_new_number:
            needs_update.append(
                (id_, name, number, model_id, needs_new_name, needs_new_number)
            )

        names_for_model.add(name)
        numbers_for_model.add(number)

    for (
        id_,
        name,
        number,
        model_id,
        needs_new_name,
        needs_new_number,
    ) in needs_update:
        values: Dict[str, Any] = {}

        names_for_model = existing_names[model_id]
        numbers_for_model = existing_numbers[model_id]

        is_numeric_version = str(number) == name
        next_numeric_version = max(numbers_for_model) + 1

        if is_numeric_version:
            # No matter if the name or number clashes, we need to update both.
            # The new number is also used as the name, so skip numbers
            # that a user already picked as an explicit version name.
            while str(next_numeric_version) in names_for_model:
                next_numeric_version += 1

            values["number"] = next_numeric_version
            values["name"] = str(next_numeric_version)
            numbers_for_model.add(next_numeric_version)
            names_for_model.add(values["name"])
            logger.warning(
                "Migrating model version %s to %s to resolve duplicate name.",
                name,
                values["name"],
            )
        else:
            if needs_new_name:
                values["name"] = _deduplicate_name(
                    name, id_, names_for_model
                )
                names_for_model.add(values["name"])
                logger.warning(
                    "Migrating model version %s to %s to resolve duplicate name.",
                    name,
                    values["name"],
                )

            if needs_new_number:
                values["number"] = next_numeric_version
                numbers_for_model.add(next_numeric_version)

        connection.execute(
            sa.update(model_version_table)
            .where(model_version_table.c.id == id_)
            .values(**values)
        )


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision."""
    # Existing duplicates have to be renamed first, otherwise creating the
    # unique constraints below fails.
    resolve_duplicate_entities()

    # Tables whose `name` must be unique within a workspace, in the order
    # in which the constraints are created.
    workspace_scoped_constraints = (
        ("pipeline", "unique_pipeline_name_in_workspace"),
        ("pipeline_run", "unique_run_name_in_workspace"),
        ("model", "unique_model_name_in_workspace"),
    )
    for table_name, constraint_name in workspace_scoped_constraints:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.create_unique_constraint(
                constraint_name, ["name", "workspace_id"]
            )

    # Model versions get two constraints in one batch: unique name and
    # unique number, both scoped to the model.
    model_version_constraints = (
        ("unique_version_for_model_id", ["name", "model_id"]),
        ("unique_version_number_for_model_id", ["number", "model_id"]),
    )
    with op.batch_alter_table("model_version", schema=None) as batch_op:
        for constraint_name, columns in model_version_constraints:
            batch_op.create_unique_constraint(constraint_name, columns)


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision."""
    # Both model version constraints are dropped within a single batch.
    with op.batch_alter_table("model_version", schema=None) as batch_op:
        for constraint_name in (
            "unique_version_number_for_model_id",
            "unique_version_for_model_id",
        ):
            batch_op.drop_constraint(constraint_name, type_="unique")

    # Remaining constraints, dropped in reverse order of their creation.
    workspace_scoped_constraints = (
        ("model", "unique_model_name_in_workspace"),
        ("pipeline_run", "unique_run_name_in_workspace"),
        ("pipeline", "unique_pipeline_name_in_workspace"),
    )
    for table_name, constraint_name in workspace_scoped_constraints:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.drop_constraint(constraint_name, type_="unique")
26 changes: 25 additions & 1 deletion src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from uuid import UUID

from pydantic import ConfigDict
from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column
from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column, UniqueConstraint
from sqlmodel import Field, Relationship

from zenml.enums import MetadataResourceTypes, TaggableResourceTypes
Expand Down Expand Up @@ -62,6 +62,13 @@ class ModelSchema(NamedSchema, table=True):
"""SQL Model for model."""

__tablename__ = "model"
__table_args__ = (
UniqueConstraint(
"name",
"workspace_id",
name="unique_model_name_in_workspace",
),
)

workspace_id: UUID = build_foreign_key_field(
source=__tablename__,
Expand Down Expand Up @@ -220,6 +227,23 @@ class ModelVersionSchema(NamedSchema, table=True):
"""SQL Model for model version."""

__tablename__ = MODEL_VERSION_TABLENAME
__table_args__ = (
# We need two unique constraints here:
# - The first to ensure that each model version for a
# model has a unique version number
# - The second one to ensure that explicit names given by
# users are unique
UniqueConstraint(
"number",
"model_id",
name="unique_version_number_for_model_id",
),
UniqueConstraint(
"name",
"model_id",
name="unique_version_for_model_id",
),
)

workspace_id: UUID = build_foreign_key_field(
source=__tablename__,
Expand Down
Loading

0 comments on commit c6dc2be

Please sign in to comment.