Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent some race conditions #3167

Merged
merged 36 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
250e713
Add pipeline and run unique constraints
schustmi Nov 4, 2024
a8cec79
Add model unique constraint
schustmi Nov 4, 2024
3965539
Cleanup unused methods
schustmi Nov 4, 2024
f120877
Increment MV version server side
schustmi Nov 4, 2024
f5ec6b9
Fix log message
schustmi Nov 4, 2024
40661e8
Use name instead of number for model version unique constraint
schustmi Nov 4, 2024
c3b8363
Improve model version unique constraints
schustmi Nov 4, 2024
a4c6875
Separate exceptions
schustmi Nov 4, 2024
6e35f2d
Add more migrations
schustmi Nov 5, 2024
e9ef50e
Convert new exception to 500 status code
schustmi Nov 5, 2024
31bbd11
Format
schustmi Nov 5, 2024
524034f
Fix mypy
schustmi Nov 6, 2024
ccdc499
Docstrings
schustmi Nov 6, 2024
56f6bfd
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 6, 2024
d7b4821
Improve logic for model version creation
schustmi Nov 6, 2024
3446f04
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 6, 2024
d6e42ab
Add missing break
schustmi Nov 6, 2024
f1732a9
Add jitter to exponential backoff
schustmi Nov 6, 2024
aeadcf1
Dont check for DB columns in integrity errors
schustmi Nov 6, 2024
666c836
Add workspace to model name unique constraint
schustmi Nov 6, 2024
6c214e0
Docstring
schustmi Nov 6, 2024
b4670a5
Cast to str
schustmi Nov 6, 2024
24cf449
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 6, 2024
2700828
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 7, 2024
8b0de42
Fix alembic order
schustmi Nov 7, 2024
b50779d
Implement migration of duplicate values before DB migration
schustmi Nov 7, 2024
7be917e
Auto-update of Starter template
actions-user Nov 7, 2024
f729aab
Auto-update of E2E template
actions-user Nov 7, 2024
a87d05c
Auto-update of NLP template
actions-user Nov 7, 2024
38f624a
Add logs during migration
schustmi Nov 7, 2024
0a70a83
Merge branch 'bugfix/PRD-704-concurrent-runs-race-conditions' of gith…
schustmi Nov 7, 2024
f5dd244
Type annotations in migration
schustmi Nov 7, 2024
92522d3
Formatting
schustmi Nov 7, 2024
1df5ecc
More mypy
schustmi Nov 7, 2024
d3b68c3
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 11, 2024
a744ce5
Update alembic order
schustmi Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/zenml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ class EntityExistsError(ZenMLBaseException):
"""Raised when trying to register an entity that already exists."""


class EntityCreationError(ZenMLBaseException, RuntimeError):
    """Raised when failing to create an entity.

    The class also derives from `RuntimeError`, so code that catches
    `RuntimeError` catches this exception as well.
    """


class ActionExistsError(EntityExistsError):
"""Raised when registering an action with a name that already exists."""

Expand Down
78 changes: 16 additions & 62 deletions src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Model user facing interface to pass into pipeline or step."""

import datetime
import time
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -28,7 +27,6 @@

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator

from zenml.constants import MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
from zenml.enums import MetadataResourceTypes, ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger
Expand Down Expand Up @@ -527,14 +525,6 @@ def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data["suppress_class_validation_warnings"] = True
return data

def _validate_config_in_runtime(self) -> "ModelVersionResponse":
    """Validate that config doesn't conflict with runtime environment.

    This is a thin wrapper that delegates to
    `_get_or_create_model_version` and returns its result unchanged.

    Returns:
        The model version based on configuration.
    """
    return self._get_or_create_model_version()

def _get_or_create_model(self) -> "ModelResponse":
"""This method should get or create a model from Model Control Plane.

Expand Down Expand Up @@ -678,16 +668,6 @@ def _get_or_create_model_version(
if isinstance(self.version, str):
self.version = format_name_template(self.version)

zenml_client = Client()
model_version_request = ModelVersionRequest(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
name=str(self.version) if self.version else None,
description=self.description,
model=model.id,
tags=self.tags,
)
mv_request = ModelVersionRequest.model_validate(model_version_request)
try:
if self.version or self.model_version_id:
model_version = self._get_model_version()
Expand Down Expand Up @@ -717,60 +697,34 @@ def _get_or_create_model_version(
" as an example. You can explore model versions using "
f"`zenml model version list -n {self.name}` CLI command."
)
retries_made = 0
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
try:
model_version = (
zenml_client.zen_store.create_model_version(
model_version=mv_request
)
)
break
except EntityExistsError as e:
if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1:
raise RuntimeError(
f"Failed to create model version "
f"`{self.version if self.version else 'new'}` "
f"in model `{self.name}`. Retried {retries_made} times. "
"This could be driven by exceptionally high concurrency of "
"pipeline runs. Please, reach out to us on ZenML Slack for support."
) from e
# smoothed exponential back-off, it will go as 0.2, 0.3,
# 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
sleep = 0.2 * 1.5**i
logger.debug(
f"Failed to create new model version for "
f"model `{self.name}`. Retrying in {sleep}..."
)
time.sleep(sleep)
retries_made += 1
self.version = model_version.name

client = Client()
model_version_request = ModelVersionRequest(
user=client.active_user.id,
workspace=client.active_workspace.id,
name=str(self.version) if self.version else None,
description=self.description,
model=model.id,
tags=self.tags,
)
model_version = client.zen_store.create_model_version(
model_version=model_version_request
)

self._created_model_version = True

logger.info(
"Created new model version `%s` for model `%s`.",
self.version,
model_version.name,
self.name,
)

self.version = model_version.name
self.model_version_id = model_version.id
self._model_id = model_version.model.id
self._number = model_version.number
return model_version

def _merge(self, model: "Model") -> None:
self.license = self.license or model.license
self.description = self.description or model.description
self.audience = self.audience or model.audience
self.use_cases = self.use_cases or model.use_cases
self.limitations = self.limitations or model.limitations
self.trade_offs = self.trade_offs or model.trade_offs
self.ethics = self.ethics or model.ethics
if model.tags is not None:
self.tags = list(
{t for t in self.tags or []}.union(set(model.tags))
)

def __hash__(self) -> int:
"""Get hash of the `Model`.

Expand Down
2 changes: 2 additions & 0 deletions src/zenml/zen_server/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
CredentialsNotValid,
DoesNotExistException,
DuplicateRunNameError,
EntityCreationError,
EntityExistsError,
IllegalOperationError,
MethodNotAllowedError,
Expand Down Expand Up @@ -93,6 +94,7 @@ class ErrorModel(BaseModel):
# 422 Unprocessable Entity
(ValueError, 422),
# 500 Internal Server Error
(EntityCreationError, 500),
(RuntimeError, 500),
# 501 Not Implemented,
(NotImplementedError, 501),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
"""Add pipeline, model and run unique constraints [904464ea4041].

Revision ID: 904464ea4041
Revises: b557b2871693
Create Date: 2024-11-04 10:27:05.450092

"""

from collections import defaultdict
from typing import Any, Dict, Set

import sqlalchemy as sa
from alembic import op

from zenml.logger import get_logger

# Module-level logger; the data migration below logs a warning for every
# entity it renames.
logger = get_logger(__name__)

# revision identifiers, used by Alembic.
# `revision` is the ID of this migration and `down_revision` the ID of the
# migration it revises (see the module docstring).
revision = "904464ea4041"
down_revision = "b557b2871693"
branch_labels = None
depends_on = None


def _unique_suffixed_name(name: str, id_: str, taken: Set[str]) -> str:
    """Build a replacement name that is not contained in `taken`.

    The replacement is the original name followed by an underscore and a
    prefix of the entity ID. The prefix starts at six characters and is
    extended one character at a time while the candidate is already taken.
    If even the full ID yields a taken name, that last candidate is
    returned regardless.

    Args:
        name: The duplicated name.
        id_: The ID of the entity that is being renamed.
        taken: Names that the replacement must not collide with.

    Returns:
        The replacement name.
    """
    prefix_length = 6
    candidate = f"{name}_{id_[:prefix_length]}"
    while candidate in taken and prefix_length < len(id_):
        prefix_length += 1
        candidate = f"{name}_{id_[:prefix_length]}"
    return candidate


def resolve_duplicate_entities() -> None:
    """Resolve duplicate entities.

    Renames pipeline runs, pipelines and models whose name occurs more than
    once in a workspace, and renames/renumbers model versions whose name or
    number occurs more than once for a model. The first row encountered
    keeps its value; every later duplicate is updated in the database and a
    warning is logged for each rename.
    """
    connection = op.get_bind()
    meta = sa.MetaData()
    meta.reflect(
        bind=connection,
        only=("pipeline_run", "pipeline", "model", "model_version"),
    )

    # Remove duplicate names for runs, pipelines and models
    for table_name in ["pipeline_run", "pipeline", "model"]:
        table = sa.Table(table_name, meta)
        result = connection.execute(
            sa.select(table.c.id, table.c.name, table.c.workspace_id)
        ).all()

        # Every name per workspace, collected up front so that a generated
        # replacement can never equal the name of another row, no matter
        # where that row appears in the result.
        all_names: Dict[str, Set[str]] = defaultdict(set)
        for _, name, workspace_id in result:
            all_names[workspace_id].add(name)

        # Names already claimed by the rows processed so far.
        existing: Dict[str, Set[str]] = defaultdict(set)

        for id_, name, workspace_id in result:
            names_in_workspace = existing[workspace_id]

            if name not in names_in_workspace:
                names_in_workspace.add(name)
                continue

            new_name = _unique_suffixed_name(
                name, id_, taken=all_names[workspace_id]
            )
            logger.warning(
                "Migrating %s name from %s to %s to resolve duplicate name.",
                table_name,
                name,
                new_name,
            )
            connection.execute(
                sa.update(table)
                .where(table.c.id == id_)
                .values(name=new_name)
            )
            names_in_workspace.add(new_name)
            all_names[workspace_id].add(new_name)

    # Remove duplicate names and version numbers for model versions
    model_version_table = sa.Table("model_version", meta)
    result = connection.execute(
        sa.select(
            model_version_table.c.id,
            model_version_table.c.name,
            model_version_table.c.number,
            model_version_table.c.model_id,
        )
    ).all()

    existing_names: Dict[str, Set[str]] = defaultdict(set)
    existing_numbers: Dict[str, Set[int]] = defaultdict(set)

    needs_update = []

    # First pass: collect all names/numbers per model and remember the rows
    # that clash with an earlier row. Updates are deferred so that new
    # numbers can be chosen above every number that exists for the model.
    for id_, name, number, model_id in result:
        names_for_model = existing_names[model_id]
        numbers_for_model = existing_numbers[model_id]

        needs_new_name = name in names_for_model
        needs_new_number = number in numbers_for_model

        if needs_new_name or needs_new_number:
            needs_update.append(
                (id_, name, number, model_id, needs_new_name, needs_new_number)
            )

        names_for_model.add(name)
        numbers_for_model.add(number)

    # Second pass: apply the renames/renumberings.
    for (
        id_,
        name,
        number,
        model_id,
        needs_new_name,
        needs_new_number,
    ) in needs_update:
        values: Dict[str, Any] = {}

        names_for_model = existing_names[model_id]
        numbers_for_model = existing_numbers[model_id]

        is_numeric_version = str(number) == name
        next_numeric_version = max(numbers_for_model) + 1

        if is_numeric_version:
            # The new number doubles as the new name, so skip numbers whose
            # string form is already used as a version name.
            while str(next_numeric_version) in names_for_model:
                next_numeric_version += 1

            # No matter if the name or number clashes, we need to update both
            values["number"] = next_numeric_version
            values["name"] = str(next_numeric_version)
            numbers_for_model.add(next_numeric_version)
            names_for_model.add(values["name"])
            logger.warning(
                "Migrating model version %s to %s to resolve duplicate name.",
                name,
                values["name"],
            )
        else:
            if needs_new_name:
                values["name"] = _unique_suffixed_name(
                    name, id_, taken=names_for_model
                )
                # Record the new name so later renames cannot reuse it.
                names_for_model.add(values["name"])
                logger.warning(
                    "Migrating model version %s to %s to resolve duplicate name.",
                    name,
                    values["name"],
                )

            if needs_new_number:
                values["number"] = next_numeric_version
                numbers_for_model.add(next_numeric_version)

        connection.execute(
            sa.update(model_version_table)
            .where(model_version_table.c.id == id_)
            .values(**values)
        )


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Existing duplicates are resolved first; afterwards the unique
    constraints are created table by table.
    """
    resolve_duplicate_entities()

    # (table, ((constraint name, columns), ...)) in the order of creation.
    # All constraints of a table are created within one batch operation.
    constraints_by_table = (
        (
            "pipeline",
            (("unique_pipeline_name_in_workspace", ["name", "workspace_id"]),),
        ),
        (
            "pipeline_run",
            (("unique_run_name_in_workspace", ["name", "workspace_id"]),),
        ),
        (
            "model",
            (("unique_model_name_in_workspace", ["name", "workspace_id"]),),
        ),
        (
            "model_version",
            (
                ("unique_version_for_model_id", ["name", "model_id"]),
                ("unique_version_number_for_model_id", ["number", "model_id"]),
            ),
        ),
    )

    for table_name, constraints in constraints_by_table:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for constraint_name, columns in constraints:
                batch_op.create_unique_constraint(constraint_name, columns)


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Drops the unique constraints for model versions, models, pipeline runs
    and pipelines, in that order.
    """
    # (table, (constraint names, ...)) in the order of removal. All
    # constraints of a table are dropped within one batch operation.
    constraints_by_table = (
        (
            "model_version",
            (
                "unique_version_number_for_model_id",
                "unique_version_for_model_id",
            ),
        ),
        ("model", ("unique_model_name_in_workspace",)),
        ("pipeline_run", ("unique_run_name_in_workspace",)),
        ("pipeline", ("unique_pipeline_name_in_workspace",)),
    )

    for table_name, constraint_names in constraints_by_table:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for constraint_name in constraint_names:
                batch_op.drop_constraint(constraint_name, type_="unique")
26 changes: 25 additions & 1 deletion src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from uuid import UUID

from pydantic import ConfigDict
from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column
from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column, UniqueConstraint
from sqlmodel import Field, Relationship

from zenml.enums import MetadataResourceTypes, TaggableResourceTypes
Expand Down Expand Up @@ -62,6 +62,13 @@ class ModelSchema(NamedSchema, table=True):
"""SQL Model for model."""

__tablename__ = "model"
__table_args__ = (
UniqueConstraint(
"name",
"workspace_id",
name="unique_model_name_in_workspace",
),
)

workspace_id: UUID = build_foreign_key_field(
source=__tablename__,
Expand Down Expand Up @@ -220,6 +227,23 @@ class ModelVersionSchema(NamedSchema, table=True):
"""SQL Model for model version."""

__tablename__ = MODEL_VERSION_TABLENAME
__table_args__ = (
# We need two unique constraints here:
# - The first to ensure that each model version for a
# model has a unique version number
# - The second one to ensure that explicit names given by
# users are unique
UniqueConstraint(
"number",
"model_id",
name="unique_version_number_for_model_id",
),
UniqueConstraint(
"name",
"model_id",
name="unique_version_for_model_id",
),
)

workspace_id: UUID = build_foreign_key_field(
source=__tablename__,
Expand Down
Loading
Loading