From c6dc2be9d28d6e5d6601d89bbf33e0bdc5cf64b1 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 11 Nov 2024 10:16:23 +0100 Subject: [PATCH] Prevent some race conditions (#3167) * Add pipeline and run unique constraints * Add model unique constraint * Cleanup unused methods * Increment MV version server side * Fix log message * Use name instead of number for model version unique constraint * Improve model version unique constraints * Separate exceptions * Add more migrations * Convert new exception to 500 status code * Format * Fix mypy * Docstrings * Improve logic for model version creation * Add missing break * Add jitter to exponential backoff * Dont check for DB columns in integrity errors * Add workspace to model name unique constraint * Docstring * Cast to str * Fix alembic order * Implement migration of duplicate values before DB migration * Auto-update of Starter template * Auto-update of E2E template * Auto-update of NLP template * Add logs during migration * Type annotations in migration * Formatting * More mypy * Update alembic order --------- Co-authored-by: GitHub Actions --- src/zenml/exceptions.py | 4 + src/zenml/model/model.py | 78 +--- src/zenml/zen_server/exceptions.py | 2 + ...d_pipeline_model_run_unique_constraints.py | 192 ++++++++++ src/zenml/zen_stores/schemas/model_schemas.py | 26 +- .../schemas/pipeline_run_schemas.py | 5 + .../zen_stores/schemas/pipeline_schemas.py | 10 +- src/zenml/zen_stores/sql_zen_store.py | 336 +++++++++++------- 8 files changed, 467 insertions(+), 186 deletions(-) create mode 100644 src/zenml/zen_stores/migrations/versions/904464ea4041_add_pipeline_model_run_unique_constraints.py diff --git a/src/zenml/exceptions.py b/src/zenml/exceptions.py index 9cc6cc44d57..f255ca1fae5 100644 --- a/src/zenml/exceptions.py +++ b/src/zenml/exceptions.py @@ -168,6 +168,10 @@ class EntityExistsError(ZenMLBaseException): """Raised when trying to register an entity that already exists.""" +class EntityCreationError(ZenMLBaseException, RuntimeError): + """Raised when failing to create an entity.""" + + class ActionExistsError(EntityExistsError): """Raised when registering an action with a name that already exists.""" diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index 04183b89e3a..05c0045ca66 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -14,7 +14,6 @@ """Model user facing interface to pass into pipeline or step.""" import datetime -import time from typing import ( TYPE_CHECKING, Any, @@ -28,7 +27,6 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator -from zenml.constants import MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION from zenml.enums import MetadataResourceTypes, ModelStages from zenml.exceptions import EntityExistsError from zenml.logger import get_logger @@ -527,14 +525,6 @@ def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]: data["suppress_class_validation_warnings"] = True return data - def _validate_config_in_runtime(self) -> "ModelVersionResponse": - """Validate that config doesn't conflict with runtime environment. - - Returns: - The model version based on configuration. - """ - return self._get_or_create_model_version() - def _get_or_create_model(self) -> "ModelResponse": """This method should get or create a model from Model Control Plane. @@ -678,16 +668,6 @@ def _get_or_create_model_version( if isinstance(self.version, str): self.version = format_name_template(self.version) - zenml_client = Client() - model_version_request = ModelVersionRequest( - user=zenml_client.active_user.id, - workspace=zenml_client.active_workspace.id, - name=str(self.version) if self.version else None, - description=self.description, - model=model.id, - tags=self.tags, - ) - mv_request = ModelVersionRequest.model_validate(model_version_request) try: if self.version or self.model_version_id: model_version = self._get_model_version() @@ -717,60 +697,34 @@ def _get_or_create_model_version( " as an example. You can explore model versions using " f"`zenml model version list -n {self.name}` CLI command." ) - retries_made = 0 - for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION): - try: - model_version = ( - zenml_client.zen_store.create_model_version( - model_version=mv_request - ) - ) - break - except EntityExistsError as e: - if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1: - raise RuntimeError( - f"Failed to create model version " - f"`{self.version if self.version else 'new'}` " - f"in model `{self.name}`. Retried {retries_made} times. " - "This could be driven by exceptionally high concurrency of " - "pipeline runs. Please, reach out to us on ZenML Slack for support." - ) from e - # smoothed exponential back-off, it will go as 0.2, 0.3, - # 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ... - sleep = 0.2 * 1.5**i - logger.debug( - f"Failed to create new model version for " - f"model `{self.name}`. Retrying in {sleep}..." - ) - time.sleep(sleep) - retries_made += 1 - self.version = model_version.name + + client = Client() + model_version_request = ModelVersionRequest( + user=client.active_user.id, + workspace=client.active_workspace.id, + name=str(self.version) if self.version else None, + description=self.description, + model=model.id, + tags=self.tags, + ) + model_version = client.zen_store.create_model_version( + model_version=model_version_request + ) + self._created_model_version = True logger.info( "Created new model version `%s` for model `%s`.", - self.version, + model_version.name, self.name, ) + self.version = model_version.name self.model_version_id = model_version.id self._model_id = model_version.model.id self._number = model_version.number return model_version - def _merge(self, model: "Model") -> None: - self.license = self.license or model.license - self.description = self.description or model.description - self.audience = self.audience or model.audience - self.use_cases = self.use_cases or model.use_cases - self.limitations = self.limitations or model.limitations - self.trade_offs = self.trade_offs or model.trade_offs - self.ethics = self.ethics or model.ethics - if model.tags is not None: - self.tags = list( - {t for t in self.tags or []}.union(set(model.tags)) - ) - def __hash__(self) -> int: """Get hash of the `Model`. diff --git a/src/zenml/zen_server/exceptions.py b/src/zenml/zen_server/exceptions.py index 91c8116c243..30b631934bb 100644 --- a/src/zenml/zen_server/exceptions.py +++ b/src/zenml/zen_server/exceptions.py @@ -23,6 +23,7 @@ CredentialsNotValid, DoesNotExistException, DuplicateRunNameError, + EntityCreationError, EntityExistsError, IllegalOperationError, MethodNotAllowedError, @@ -93,6 +94,7 @@ class ErrorModel(BaseModel): # 422 Unprocessable Entity (ValueError, 422), # 500 Internal Server Error + (EntityCreationError, 500), (RuntimeError, 500), # 501 Not Implemented, (NotImplementedError, 501), diff --git a/src/zenml/zen_stores/migrations/versions/904464ea4041_add_pipeline_model_run_unique_constraints.py b/src/zenml/zen_stores/migrations/versions/904464ea4041_add_pipeline_model_run_unique_constraints.py new file mode 100644 index 00000000000..8a9441a157e --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/904464ea4041_add_pipeline_model_run_unique_constraints.py @@ -0,0 +1,192 @@ +"""Add pipeline, model and run unique constraints [904464ea4041]. + +Revision ID: 904464ea4041 +Revises: b557b2871693 +Create Date: 2024-11-04 10:27:05.450092 + +""" + +from collections import defaultdict +from typing import Any, Dict, Set + +import sqlalchemy as sa +from alembic import op + +from zenml.logger import get_logger + +logger = get_logger(__name__) + +# revision identifiers, used by Alembic. +revision = "904464ea4041" +down_revision = "b557b2871693" +branch_labels = None +depends_on = None + + +def resolve_duplicate_entities() -> None: + """Resolve duplicate entities.""" + connection = op.get_bind() + meta = sa.MetaData() + meta.reflect( + bind=connection, + only=("pipeline_run", "pipeline", "model", "model_version"), + ) + + # Remove duplicate names for runs, pipelines and models + for table_name in ["pipeline_run", "pipeline", "model"]: + table = sa.Table(table_name, meta) + result = connection.execute( + sa.select(table.c.id, table.c.name, table.c.workspace_id) + ).all() + existing: Dict[str, Set[str]] = defaultdict(set) + + for id_, name, workspace_id in result: + names_in_workspace = existing[workspace_id] + + if name in names_in_workspace: + new_name = f"{name}_{id_[:6]}" + logger.warning( + "Migrating %s name from %s to %s to resolve duplicate name.", + table_name, + name, + new_name, + ) + connection.execute( + sa.update(table) + .where(table.c.id == id_) + .values(name=new_name) + ) + names_in_workspace.add(new_name) + else: + names_in_workspace.add(name) + + # Remove duplicate names and version numbers for model versions + model_version_table = sa.Table("model_version", meta) + result = connection.execute( + sa.select( + model_version_table.c.id, + model_version_table.c.name, + model_version_table.c.number, + model_version_table.c.model_id, + ) + ).all() + + existing_names: Dict[str, Set[str]] = defaultdict(set) + existing_numbers: Dict[str, Set[int]] = defaultdict(set) + + needs_update = [] + + for id_, name, number, model_id in result: + names_for_model = existing_names[model_id] + numbers_for_model = existing_numbers[model_id] + + needs_new_name = name in names_for_model + needs_new_number = number in numbers_for_model + + if needs_new_name or needs_new_number: + needs_update.append( + (id_, name, number, model_id, needs_new_name, needs_new_number) + ) + + names_for_model.add(name) + numbers_for_model.add(number) + + for ( + id_, + name, + number, + model_id, + needs_new_name, + needs_new_number, + ) in needs_update: + values: Dict[str, Any] = {} + + is_numeric_version = str(number) == name + next_numeric_version = max(existing_numbers[model_id]) + 1 + + if is_numeric_version: + # No matter if the name or number clashes, we need to update both + values["number"] = next_numeric_version + values["name"] = str(next_numeric_version) + existing_numbers[model_id].add(next_numeric_version) + logger.warning( + "Migrating model version %s to %s to resolve duplicate name.", + name, + values["name"], + ) + else: + if needs_new_name: + values["name"] = f"{name}_{id_[:6]}" + logger.warning( + "Migrating model version %s to %s to resolve duplicate name.", + name, + values["name"], + ) + + if needs_new_number: + values["number"] = next_numeric_version + existing_numbers[model_id].add(next_numeric_version) + + connection.execute( + sa.update(model_version_table) + .where(model_version_table.c.id == id_) + .values(**values) + ) + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + + resolve_duplicate_entities() + + with op.batch_alter_table("pipeline", schema=None) as batch_op: + batch_op.create_unique_constraint( + "unique_pipeline_name_in_workspace", ["name", "workspace_id"] + ) + + with op.batch_alter_table("pipeline_run", schema=None) as batch_op: + batch_op.create_unique_constraint( + "unique_run_name_in_workspace", ["name", "workspace_id"] + ) + + with op.batch_alter_table("model", schema=None) as batch_op: + batch_op.create_unique_constraint( + "unique_model_name_in_workspace", ["name", "workspace_id"] + ) + + with op.batch_alter_table("model_version", schema=None) as batch_op: + batch_op.create_unique_constraint( + "unique_version_for_model_id", ["name", "model_id"] + ) + batch_op.create_unique_constraint( + "unique_version_number_for_model_id", ["number", "model_id"] + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("model_version", schema=None) as batch_op: + batch_op.drop_constraint( + "unique_version_number_for_model_id", type_="unique" + ) + batch_op.drop_constraint("unique_version_for_model_id", type_="unique") + + with op.batch_alter_table("model", schema=None) as batch_op: + batch_op.drop_constraint( + "unique_model_name_in_workspace", type_="unique" + ) + + with op.batch_alter_table("pipeline_run", schema=None) as batch_op: + batch_op.drop_constraint( + "unique_run_name_in_workspace", type_="unique" + ) + + with op.batch_alter_table("pipeline", schema=None) as batch_op: + batch_op.drop_constraint( + "unique_pipeline_name_in_workspace", type_="unique" + ) + + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 9072025b46a..37cec2c5513 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -19,7 +19,7 @@ from uuid import UUID from pydantic import ConfigDict -from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column +from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column, UniqueConstraint from sqlmodel import Field, Relationship from zenml.enums import MetadataResourceTypes, TaggableResourceTypes @@ -62,6 +62,13 @@ class ModelSchema(NamedSchema, table=True): """SQL Model for model.""" __tablename__ = "model" + __table_args__ = ( + UniqueConstraint( + "name", + "workspace_id", + name="unique_model_name_in_workspace", + ), + ) workspace_id: UUID = build_foreign_key_field( source=__tablename__, @@ -220,6 +227,23 @@ class ModelVersionSchema(NamedSchema, table=True): """SQL Model for model version.""" __tablename__ = MODEL_VERSION_TABLENAME + __table_args__ = ( + # We need two unique constraints here: + # - The first to ensure that each model version for a + # model has a unique version number + # - The second one to ensure that explicit names given by + # users are unique + UniqueConstraint( + "number", + "model_id", + name="unique_version_number_for_model_id", + ), + UniqueConstraint( + "name", + "model_id", + name="unique_version_for_model_id", + ), + ) workspace_id: UUID = build_foreign_key_field( source=__tablename__, diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 66312f4a351..6028451acf2 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -72,6 +72,11 @@ class PipelineRunSchema(NamedSchema, table=True): "orchestrator_run_id", name="unique_orchestrator_run_id_for_deployment_id", ), + UniqueConstraint( + "name", + "workspace_id", + name="unique_run_name_in_workspace", + ), ) # Fields diff --git a/src/zenml/zen_stores/schemas/pipeline_schemas.py b/src/zenml/zen_stores/schemas/pipeline_schemas.py index d3604483980..1f287720ee6 100644 --- a/src/zenml/zen_stores/schemas/pipeline_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_schemas.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, List, Optional from uuid import UUID -from sqlalchemy import TEXT, Column +from sqlalchemy import TEXT, Column, UniqueConstraint from sqlmodel import Field, Relationship from zenml.enums import TaggableResourceTypes @@ -50,7 +50,13 @@ class PipelineSchema(NamedSchema, table=True): """SQL Model for pipelines.""" __tablename__ = "pipeline" - + __table_args__ = ( + UniqueConstraint( + "name", + "workspace_id", + name="unique_pipeline_name_in_workspace", + ), + ) # Fields description: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 0e03d8a0a0a..f9c314f774b 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -18,6 +18,7 @@ import logging import math import os +import random import re import sys import time @@ -126,6 +127,7 @@ ActionExistsError, AuthorizationException, BackupSecretsStoreNotConfiguredError, + EntityCreationError, EntityExistsError, EventSourceExistsError, IllegalOperationError, @@ -372,6 +374,25 @@ ZENML_SQLITE_DB_FILENAME = "zenml.db" +def exponential_backoff_with_jitter( + attempt: int, base_duration: float = 0.05 +) -> float: + """Exponential backoff with jitter. + + Implemented the `Full jitter` algorithm described in + https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + + Args: + attempt: The backoff attempt. + base_duration: The backoff base duration. + + Returns: + The backoff duration. + """ + exponential_backoff = base_duration * 1.5**attempt + return random.uniform(0, exponential_backoff) + + class SQLDatabaseDriver(StrEnum): """SQL database drivers supported by the SQL ZenML store.""" @@ -2782,7 +2803,9 @@ def create_artifact_version( artifact_version: The artifact version to create. Raises: - EntityExistsError: If the artifact version already exists. + EntityExistsError: If an artifact version with the same name + already exists. + EntityCreationError: If the artifact version creation failed. Returns: The created artifact version. @@ -2823,7 +2846,7 @@ def create_artifact_version( artifact_version_id = artifact_version_schema.id except IntegrityError: if remaining_tries == 0: - raise EntityExistsError( + raise EntityCreationError( f"Failed to create version for artifact " f"{artifact_schema.name}. This is most likely " "caused by multiple parallel requests that try " @@ -2831,12 +2854,14 @@ def create_artifact_version( "database." ) else: - # Exponential backoff to account for heavy - # parallelization - sleep_duration = 0.05 * 1.5 ** ( + attempt = ( MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - remaining_tries ) + sleep_duration = exponential_backoff_with_jitter( + attempt=attempt + ) + logger.debug( "Failed to create artifact version %s " "(version %s) due to an integrity error. " @@ -4229,20 +4254,6 @@ def create_pipeline( EntityExistsError: If an identical pipeline already exists. """ with Session(self.engine) as session: - # Check if pipeline with the given name already exists - existing_pipeline = session.exec( - select(PipelineSchema) - .where(PipelineSchema.name == pipeline.name) - .where(PipelineSchema.workspace_id == pipeline.workspace) - ).first() - if existing_pipeline is not None: - raise EntityExistsError( - f"Unable to create pipeline in workspace " - f"'{pipeline.workspace}': A pipeline with this name " - "already exists." - ) - - # Create the pipeline new_pipeline = PipelineSchema.from_request(pipeline) if pipeline.tags: @@ -4253,7 +4264,14 @@ def create_pipeline( ) session.add(new_pipeline) - session.commit() + try: + session.commit() + except IntegrityError: + raise EntityExistsError( + f"Unable to create pipeline in workspace " + f"'{pipeline.workspace}': A pipeline with the name " + f"{pipeline.name} already exists." + ) session.refresh(new_pipeline) return new_pipeline.to_model( @@ -5097,6 +5115,26 @@ def delete_event_source(self, event_source_id: UUID) -> None: # ----------------------------- Pipeline runs ----------------------------- + def _pipeline_run_exists(self, workspace_id: UUID, name: str) -> bool: + """Check if a pipeline name with a certain name exists. + + Args: + workspace_id: The workspace to check. + name: The run name. + + Returns: + If a pipeline run with the given name exists. + """ + with Session(self.engine) as session: + return ( + session.exec( + select(PipelineRunSchema.id) + .where(PipelineRunSchema.workspace_id == workspace_id) + .where(PipelineRunSchema.name == name) + ).first() + is not None + ) + def create_run( self, pipeline_run: PipelineRunRequest ) -> PipelineRunResponse: @@ -5112,18 +5150,6 @@ def create_run( EntityExistsError: If a run with the same name already exists. """ with Session(self.engine) as session: - # Check if pipeline run with same name already exists. - existing_domain_run = session.exec( - select(PipelineRunSchema).where( - PipelineRunSchema.name == pipeline_run.name - ) - ).first() - if existing_domain_run is not None: - raise EntityExistsError( - f"Unable to create pipeline run: A pipeline run with name " - f"'{pipeline_run.name}' already exists." - ) - # Create the pipeline run new_run = PipelineRunSchema.from_request(pipeline_run) @@ -5135,7 +5161,22 @@ def create_run( ) session.add(new_run) - session.commit() + try: + session.commit() + except IntegrityError: + if self._pipeline_run_exists( + workspace_id=pipeline_run.workspace, name=pipeline_run.name + ): + raise EntityExistsError( + f"Unable to create pipeline run: A pipeline run with " + f"name '{pipeline_run.name}' already exists." + ) + else: + raise EntityExistsError( + "Unable to create pipeline run: A pipeline run with " + "the same deployment_id and orchestrator_run_id " + "already exists." + ) return new_run.to_model( include_metadata=True, include_resources=True @@ -5327,21 +5368,14 @@ def get_or_create_run( if pre_creation_hook: pre_creation_hook() return self.create_run(pipeline_run), True - except (EntityExistsError, IntegrityError) as create_error: - # Creating the run failed with an - # - IntegrityError: This happens when we violated a unique - # constraint, which in turn means a run with the same - # deployment_id and orchestrator_run_id exists. We now fetch and - # return that run. - # - EntityExistsError: This happens when a run with the same name - # already exists. This could be either a different run (in which - # case we want to fail) or a run created by a step of the same - # pipeline run (in which case we want to return it). - # Note: The IntegrityError might also be raised when other unique - # constraints get violated. The only other such constraint is the - # primary key constraint on the run ID, which means we randomly - # generated an existing UUID. In this case the call below will fail, - # but the chance of that happening is so low we don't handle it. + except EntityExistsError as create_error: + # Creating the run failed because + # - a run with the same deployment_id and orchestrator_run_id + # exists. We now fetch and return that run. + # - a run with the same name already exists. This could be either a + # different run (in which case we want to fail) or a run created + # by a step of the same pipeline run (in which case we want to + # return it). try: return ( self._get_run_by_orchestrator_run_id( @@ -5351,18 +5385,11 @@ def get_or_create_run( False, ) except KeyError: - if isinstance(create_error, EntityExistsError): - # There was a run with the same name which does not share - # the deployment_id and orchestrator_run_id -> We fail with - # the error that run names must be unique. - raise create_error from None - - # This should never happen as the run creation failed with an - # IntegrityError which means a run with the deployment_id and - # orchestrator_run_id exists. - raise RuntimeError( - f"Failed to get or create run: {create_error}" - ) + # We should only get here if the run creation failed because + # of a name conflict. We raise the error that happened during + # creation in any case to forward the error message to the + # user. + raise create_error def list_runs( self, @@ -10000,19 +10027,10 @@ def create_model(self, model: ModelRequest) -> ModelResponse: The newly created model. Raises: - EntityExistsError: If a workspace with the given name already exists. + EntityExistsError: If a model with the given name already exists. """ validate_name(model) with Session(self.engine) as session: - existing_model = session.exec( - select(ModelSchema).where(ModelSchema.name == model.name) - ).first() - if existing_model is not None: - raise EntityExistsError( - f"Unable to create model {model.name}: " - "A model with this name already exists." - ) - model_schema = ModelSchema.from_request(model) session.add(model_schema) @@ -10022,7 +10040,14 @@ def create_model(self, model: ModelRequest) -> ModelResponse: resource_id=model_schema.id, resource_type=TaggableResourceTypes.MODEL, ) - session.commit() + try: + session.commit() + except IntegrityError: + raise EntityExistsError( + f"Unable to create model {model.name}: " + "A model with this name already exists." + ) + return model_schema.to_model( include_metadata=True, include_resources=True ) @@ -10158,6 +10183,50 @@ def update_model( # ----------------------------- Model Versions ----------------------------- + def _get_next_numeric_version_for_model( + self, session: Session, model_id: UUID + ) -> int: + """Get the next numeric version for a model. + + Args: + session: DB session. + model_id: ID of the model for which to get the next numeric + version. + + Returns: + The next numeric version. + """ + current_max_version = session.exec( + select(func.max(ModelVersionSchema.number)).where( + ModelVersionSchema.model_id == model_id + ) + ).first() + + if current_max_version is None: + return 1 + else: + return int(current_max_version) + 1 + + def _model_version_exists(self, model_id: UUID, version: str) -> bool: + """Check if a model version with a certain version exists. + + Args: + model_id: The model ID of the version. + version: The version name. + + Returns: + If a model version with the given version name exists. + """ + with Session(self.engine) as session: + return ( + session.exec( + select(ModelVersionSchema.id) + .where(ModelVersionSchema.model_id == model_id) + .where(ModelVersionSchema.name == version) + ).first() + is not None + ) + @track_decorator(AnalyticsEvent.CREATED_MODEL_VERSION) def create_model_version( self, model_version: ModelVersionRequest @@ -10172,70 +10241,95 @@ def create_model_version( Raises: ValueError: If `number` is not None during model version creation. - EntityExistsError: If a workspace with the given name already exists. + EntityExistsError: If a model version with the given name already + exists. + EntityCreationError: If the model version creation failed. """ if model_version.number is not None: raise ValueError( "`number` field must be None during model version creation." ) - with Session(self.engine) as session: - model_version_ = model_version.model_copy() - model = self.get_model(model_version_.model) - def _check(tolerance: int = 0) -> None: - query = session.exec( - select(ModelVersionSchema) - .where(ModelVersionSchema.model_id == model.id) - .where(ModelVersionSchema.name == model_version_.name) - ) - existing_model_version = query.fetchmany(tolerance + 1) - if ( - existing_model_version is not None - and len(existing_model_version) > tolerance - ): - raise EntityExistsError( - f"Unable to create model version {model_version_.name}: " - f"A model version with this name already exists in {model.name} model." - ) + model = self.get_model(model_version.model) - _check() - all_versions = session.exec( - select(ModelVersionSchema) - .where(ModelVersionSchema.model_id == model.id) - .order_by(ModelVersionSchema.number.desc()) # type: ignore[attr-defined] - ).first() + has_custom_name = model_version.name is not None + if has_custom_name: + validate_name(model_version) - model_version_.number = ( - all_versions.number + 1 if all_versions else 1 - ) + model_version_id = None - if model_version_.name is None: - model_version_.name = str(model_version_.number) - else: - validate_name(model_version_) + remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION + while remaining_tries > 0: + remaining_tries -= 1 + try: + with Session(self.engine) as session: + model_version.number = ( + self._get_next_numeric_version_for_model( + session=session, + model_id=model.id, + ) + ) + if not has_custom_name: + model_version.name = str(model_version.number) - model_version_schema = ModelVersionSchema.from_request( - model_version_ - ) - session.add(model_version_schema) + model_version_schema = ModelVersionSchema.from_request( + model_version + ) + session.add(model_version_schema) + session.commit() - if model_version_.tags: - self._attach_tags_to_resource( - tag_names=model_version_.tags, - resource_id=model_version_schema.id, - resource_type=TaggableResourceTypes.MODEL_VERSION, - ) - try: - _check(1) - session.commit() - except EntityExistsError as e: - session.rollback() - raise e + model_version_id = model_version_schema.id + break + except IntegrityError: + if has_custom_name and self._model_version_exists( + model_id=model.id, version=cast(str, model_version.name) + ): + # We failed not because of a version number conflict, + # but because the user requested a version name that + # is already taken -> We don't retry anymore but fail + # immediately. + raise EntityExistsError( + f"Unable to create model version " + f"{model.name} (version " + f"{model_version.name}): A model with the " + "same name and version already exists." + ) + elif remaining_tries == 0: + raise EntityCreationError( + f"Failed to create version for model " + f"{model.name}. This is most likely " + "caused by multiple parallel requests that try " + "to create versions for this model in the " + "database." + ) + else: + attempt = ( + MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION + - remaining_tries + ) + sleep_duration = exponential_backoff_with_jitter( + attempt=attempt + ) + logger.debug( + "Failed to create model version %s " + "(version %s) due to an integrity error. " + "Retrying in %f seconds.", + model.name, + model_version.number, + sleep_duration, + ) + time.sleep(sleep_duration) - return model_version_schema.to_model( - include_metadata=True, include_resources=True + assert model_version_id + if model_version.tags: + self._attach_tags_to_resource( + tag_names=model_version.tags, + resource_id=model_version_id, + resource_type=TaggableResourceTypes.MODEL_VERSION, ) + return self.get_model_version(model_version_id) + def get_model_version( self, model_version_id: UUID, hydrate: bool = True ) -> ModelVersionResponse: