Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent some race conditions #3167

Merged
merged 36 commits into develop from bugfix/PRD-704-concurrent-runs-race-conditions
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
250e713
Add pipeline and run unique constraints
schustmi Nov 4, 2024
a8cec79
Add model unique constraint
schustmi Nov 4, 2024
3965539
Cleanup unused methods
schustmi Nov 4, 2024
f120877
Increment MV version server side
schustmi Nov 4, 2024
f5ec6b9
Fix log message
schustmi Nov 4, 2024
40661e8
Use name instead of number for model version unique constraint
schustmi Nov 4, 2024
c3b8363
Improve model version unique constraints
schustmi Nov 4, 2024
a4c6875
Separate exceptions
schustmi Nov 4, 2024
6e35f2d
Add more migrations
schustmi Nov 5, 2024
e9ef50e
Convert new exception to 500 status code
schustmi Nov 5, 2024
31bbd11
Format
schustmi Nov 5, 2024
524034f
Fix mypy
schustmi Nov 6, 2024
ccdc499
Docstrings
schustmi Nov 6, 2024
56f6bfd
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 6, 2024
d7b4821
Improve logic for model version creation
schustmi Nov 6, 2024
3446f04
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 6, 2024
d6e42ab
Add missing break
schustmi Nov 6, 2024
f1732a9
Add jitter to exponential backoff
schustmi Nov 6, 2024
aeadcf1
Dont check for DB columns in integrity errors
schustmi Nov 6, 2024
666c836
Add workspace to model name unique constraint
schustmi Nov 6, 2024
6c214e0
Docstring
schustmi Nov 6, 2024
b4670a5
Cast to str
schustmi Nov 6, 2024
24cf449
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 6, 2024
2700828
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 7, 2024
8b0de42
Fix alembic order
schustmi Nov 7, 2024
b50779d
Implement migration of duplicate values before DB migration
schustmi Nov 7, 2024
7be917e
Auto-update of Starter template
actions-user Nov 7, 2024
f729aab
Auto-update of E2E template
actions-user Nov 7, 2024
a87d05c
Auto-update of NLP template
actions-user Nov 7, 2024
38f624a
Add logs during migration
schustmi Nov 7, 2024
0a70a83
Merge branch 'bugfix/PRD-704-concurrent-runs-race-conditions' of gith…
schustmi Nov 7, 2024
f5dd244
Type annotations in migration
schustmi Nov 7, 2024
92522d3
Formatting
schustmi Nov 7, 2024
1df5ecc
More mypy
schustmi Nov 7, 2024
d3b68c3
Merge branch 'develop' into bugfix/PRD-704-concurrent-runs-race-condi…
schustmi Nov 11, 2024
a744ce5
Update alembic order
schustmi Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/zenml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ class EntityExistsError(ZenMLBaseException):
"""Raised when trying to register an entity that already exists."""


class EntityCreationError(ZenMLBaseException):
    """Raised when failing to create an entity.

    This is a direct subclass of `ZenMLBaseException`, not of
    `EntityExistsError`, so handlers that only catch `EntityExistsError`
    will not catch it.
    """


avishniakov marked this conversation as resolved.
Show resolved Hide resolved
class ActionExistsError(EntityExistsError):
"""Raised when registering an action with a name that already exists."""

Expand Down
78 changes: 16 additions & 62 deletions src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Model user facing interface to pass into pipeline or step."""

import datetime
import time
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -28,7 +27,6 @@

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator

from zenml.constants import MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
from zenml.enums import MetadataResourceTypes, ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger
Expand Down Expand Up @@ -528,14 +526,6 @@ def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data["suppress_class_validation_warnings"] = True
return data

def _validate_config_in_runtime(self) -> "ModelVersionResponse":
    """Resolve the model version that this configuration refers to.

    This is a thin wrapper that delegates to
    `_get_or_create_model_version` and hands back its result unchanged.

    Returns:
        The model version based on configuration.
    """
    model_version = self._get_or_create_model_version()
    return model_version

def _get_or_create_model(self) -> "ModelResponse":
"""This method should get or create a model from Model Control Plane.

Expand Down Expand Up @@ -679,16 +669,6 @@ def _get_or_create_model_version(
if isinstance(self.version, str):
self.version = format_name_template(self.version)

zenml_client = Client()
model_version_request = ModelVersionRequest(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
name=str(self.version) if self.version else None,
description=self.description,
model=model.id,
tags=self.tags,
)
mv_request = ModelVersionRequest.model_validate(model_version_request)
try:
if self.version or self.model_version_id:
model_version = self._get_model_version()
Expand Down Expand Up @@ -718,60 +698,34 @@ def _get_or_create_model_version(
" as an example. You can explore model versions using "
f"`zenml model version list -n {self.name}` CLI command."
)
retries_made = 0
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
try:
model_version = (
zenml_client.zen_store.create_model_version(
model_version=mv_request
)
)
break
except EntityExistsError as e:
if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1:
raise RuntimeError(
f"Failed to create model version "
f"`{self.version if self.version else 'new'}` "
f"in model `{self.name}`. Retried {retries_made} times. "
"This could be driven by exceptionally high concurrency of "
"pipeline runs. Please, reach out to us on ZenML Slack for support."
) from e
# smoothed exponential back-off, it will go as 0.2, 0.3,
# 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
sleep = 0.2 * 1.5**i
logger.debug(
f"Failed to create new model version for "
f"model `{self.name}`. Retrying in {sleep}..."
)
time.sleep(sleep)
retries_made += 1
self.version = model_version.name

client = Client()
model_version_request = ModelVersionRequest(
user=client.active_user.id,
workspace=client.active_workspace.id,
name=str(self.version) if self.version else None,
description=self.description,
model=model.id,
tags=self.tags,
)
model_version = client.zen_store.create_model_version(
model_version=model_version_request
)

self._created_model_version = True

logger.info(
"Created new model version `%s` for model `%s`.",
self.version,
model_version.name,
self.name,
)

self.version = model_version.name
self.model_version_id = model_version.id
self._model_id = model_version.model.id
self._number = model_version.number
return model_version

def _merge(self, model: "Model") -> None:
    """Fill unset fields of this model with values from another model.

    For every descriptive field, the value already set on `self` wins; the
    value of `model` is only used when the own value is falsy. Tags are
    combined instead of replaced.

    Args:
        model: The model whose values are used as fallbacks.
    """
    self.license = self.license or model.license
    self.description = self.description or model.description
    self.audience = self.audience or model.audience
    self.use_cases = self.use_cases or model.use_cases
    self.limitations = self.limitations or model.limitations
    self.trade_offs = self.trade_offs or model.trade_offs
    self.ethics = self.ethics or model.ethics
    if model.tags is not None:
        # De-duplicate through a dict rather than a set: dicts keep
        # insertion order, so the merged tag list is deterministic (own
        # tags first, then new tags of `model`) instead of depending on
        # per-process string hashing.
        self.tags = list(dict.fromkeys([*(self.tags or []), *model.tags]))

def __hash__(self) -> int:
"""Get hash of the `Model`.

Expand Down
2 changes: 2 additions & 0 deletions src/zenml/zen_server/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
SubscriptionUpgradeRequiredError,
ValidationError,
ZenKeyError,
EntityCreationError
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -93,6 +94,7 @@ class ErrorModel(BaseModel):
# 422 Unprocessable Entity
(ValueError, 422),
# 500 Internal Server Error
(EntityCreationError, 500),
(RuntimeError, 500),
# 501 Not Implemented,
(NotImplementedError, 501),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Add model version unique constraints [7f562e0c04a4].

Revision ID: 7f562e0c04a4
Revises: 904464ea4041
Create Date: 2024-11-04 11:10:43.454981

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "7f562e0c04a4"
down_revision = "904464ea4041"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Adds two unique constraints to the `model_version` table: one over
    (`name`, `model_id`) and one over (`number`, `model_id`).
    """
    # NOTE(review): presumably this fails on databases that already contain
    # rows violating either constraint - confirm whether existing duplicates
    # need to be cleaned up before the constraints are created.
    unique_constraints = (
        ("unique_version_for_model_id", ["name", "model_id"]),
        ("unique_version_number_for_model_id", ["number", "model_id"]),
    )

    with op.batch_alter_table("model_version", schema=None) as batch_op:
        for constraint_name, columns in unique_constraints:
            batch_op.create_unique_constraint(constraint_name, columns)


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Drops the two unique constraints on the `model_version` table again.
    """
    # Listed in the reverse order of their creation in `upgrade`.
    constraint_names = (
        "unique_version_number_for_model_id",
        "unique_version_for_model_id",
    )

    with op.batch_alter_table("model_version", schema=None) as batch_op:
        for constraint_name in constraint_names:
            batch_op.drop_constraint(constraint_name, type_="unique")
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Add pipeline, model and run unique constraints [904464ea4041].

Revision ID: 904464ea4041
Revises: c22561cbb3a9
Create Date: 2024-11-04 10:27:05.450092

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "904464ea4041"
down_revision = "c22561cbb3a9"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Adds one unique constraint each to the `pipeline`, `pipeline_run` and
    `model` tables.
    """
    # NOTE(review): presumably this fails on databases that already contain
    # duplicate rows - confirm whether duplicates must be migrated first.
    # NOTE(review): the model constraint covers `name` only, while the
    # pipeline and run constraints also include `workspace_id` - confirm
    # that model names are meant to be unique across all workspaces.
    unique_constraints = (
        (
            "pipeline",
            "unique_pipeline_name_in_workspace",
            ["name", "workspace_id"],
        ),
        (
            "pipeline_run",
            "unique_run_name_in_workspace",
            ["name", "workspace_id"],
        ),
        ("model", "unique_model_name", ["name"]),
    )

    for table_name, constraint_name, columns in unique_constraints:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.create_unique_constraint(constraint_name, columns)


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Drops the unique constraints on the `model`, `pipeline_run` and
    `pipeline` tables again.
    """
    # Listed in the reverse order of their creation in `upgrade`.
    dropped_constraints = (
        ("model", "unique_model_name"),
        ("pipeline_run", "unique_run_name_in_workspace"),
        ("pipeline", "unique_pipeline_name_in_workspace"),
    )

    for table_name, constraint_name in dropped_constraints:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.drop_constraint(constraint_name, type_="unique")
25 changes: 24 additions & 1 deletion src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from uuid import UUID

from pydantic import ConfigDict
from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column
from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column, UniqueConstraint
from sqlmodel import Field, Relationship

from zenml.enums import MetadataResourceTypes, TaggableResourceTypes
Expand Down Expand Up @@ -61,6 +61,12 @@ class ModelSchema(NamedSchema, table=True):
"""SQL Model for model."""

__tablename__ = "model"
__table_args__ = (
UniqueConstraint(
"name",
name="unique_model_name",
),
)

workspace_id: UUID = build_foreign_key_field(
source=__tablename__,
Expand Down Expand Up @@ -219,6 +225,23 @@ class ModelVersionSchema(NamedSchema, table=True):
"""SQL Model for model version."""

__tablename__ = MODEL_VERSION_TABLENAME
__table_args__ = (
# We need two unique constraints here:
# - The first to ensure that each model version for a
# model has a unique version number
# - The second one to ensure that explicit names given by
# users are unique
UniqueConstraint(
"number",
"model_id",
name="unique_version_number_for_model_id",
),
UniqueConstraint(
"name",
"model_id",
name="unique_version_for_model_id",
),
)

workspace_id: UUID = build_foreign_key_field(
source=__tablename__,
Expand Down
5 changes: 5 additions & 0 deletions src/zenml/zen_stores/schemas/pipeline_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class PipelineRunSchema(NamedSchema, table=True):
"orchestrator_run_id",
name="unique_orchestrator_run_id_for_deployment_id",
),
UniqueConstraint(
"name",
"workspace_id",
name="unique_run_name_in_workspace",
),
)

# Fields
Expand Down
10 changes: 8 additions & 2 deletions src/zenml/zen_stores/schemas/pipeline_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, List, Optional
from uuid import UUID

from sqlalchemy import TEXT, Column
from sqlalchemy import TEXT, Column, UniqueConstraint
from sqlmodel import Field, Relationship

from zenml.enums import TaggableResourceTypes
Expand Down Expand Up @@ -50,7 +50,13 @@ class PipelineSchema(NamedSchema, table=True):
"""SQL Model for pipelines."""

__tablename__ = "pipeline"

__table_args__ = (
UniqueConstraint(
"name",
"workspace_id",
name="unique_pipeline_name_in_workspace",
),
)
# Fields
description: Optional[str] = Field(sa_column=Column(TEXT, nullable=True))

Expand Down
Loading
Loading