diff --git a/.github/workflows/tests-with-minimum-versions.yml b/.github/workflows/tests-with-minimum-versions.yml index c692acf4d6..f40ec0ed14 100644 --- a/.github/workflows/tests-with-minimum-versions.yml +++ b/.github/workflows/tests-with-minimum-versions.yml @@ -70,7 +70,7 @@ jobs: run: | # Install dependencies with minimum versions. pip uninstall -y alembic cmaes packaging sqlalchemy plotly scikit-learn pillow - pip install alembic==1.5.0 cmaes==0.10.0 packaging==20.0 sqlalchemy==1.3.0 tqdm==4.27.0 colorlog==0.3 PyYAML==5.1 'pillow<10.4.0' + pip install alembic==1.5.0 cmaes==0.10.0 packaging==20.0 sqlalchemy==1.4.2 tqdm==4.27.0 colorlog==0.3 PyYAML==5.1 'pillow<10.4.0' pip uninstall -y matplotlib pandas scipy # Python v3.6 was dropped at NumPy v1.20.3. if [ "${{ matrix.python-version }}" = "3.7" ]; then diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 6558fb52cf..27dd840ef5 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -9,6 +9,7 @@ import logging import os import random +import sqlite3 import time from typing import Any from typing import Callable @@ -43,6 +44,8 @@ import alembic.migration as alembic_migration import alembic.script as alembic_script import sqlalchemy + import sqlalchemy.dialects.mysql as sqlalchemy_dialects_mysql + import sqlalchemy.dialects.sqlite as sqlalchemy_dialects_sqlite import sqlalchemy.exc as sqlalchemy_exc import sqlalchemy.orm as sqlalchemy_orm import sqlalchemy.sql.functions as sqlalchemy_sql_functions @@ -55,6 +58,8 @@ alembic_script = _LazyImport("alembic.script") sqlalchemy = _LazyImport("sqlalchemy") + sqlalchemy_dialects_mysql = _LazyImport("sqlalchemy.dialects.mysql") + sqlalchemy_dialects_sqlite = _LazyImport("sqlalchemy.dialects.sqlite") sqlalchemy_exc = _LazyImport("sqlalchemy.exc") sqlalchemy_orm = _LazyImport("sqlalchemy.orm") sqlalchemy_sql_functions = _LazyImport("sqlalchemy.sql.functions") @@ -720,14 +725,36 @@ def _set_trial_user_attr_without_commit( trial = models.TrialModel.find_or_raise_by_id(trial_id, session) self.check_trial_is_updatable(trial_id, trial.state) - attribute = models.TrialUserAttributeModel.find_by_trial_and_key(trial, key, session) - if attribute is None: - attribute = models.TrialUserAttributeModel( - trial_id=trial_id, key=key, value_json=json.dumps(value) + if self.engine.name == "mysql": + mysql_insert_stmt = sqlalchemy_dialects_mysql.insert( + models.TrialUserAttributeModel + ).values(trial_id=trial_id, key=key, value_json=json.dumps(value)) + mysql_upsert_stmt = mysql_insert_stmt.on_duplicate_key_update( + value_json=mysql_insert_stmt.inserted.value_json ) - session.add(attribute) + session.execute(mysql_upsert_stmt) + elif self.engine.name == "sqlite" and sqlite3.sqlite_version_info >= (3, 24, 0): + sqlite_insert_stmt = sqlalchemy_dialects_sqlite.insert( + models.TrialUserAttributeModel + ).values(trial_id=trial_id, key=key, value_json=json.dumps(value)) + sqlite_upsert_stmt = sqlite_insert_stmt.on_conflict_do_update( + index_elements=[ + models.TrialUserAttributeModel.trial_id, + models.TrialUserAttributeModel.key, + ], + set_=dict(value_json=sqlite_insert_stmt.excluded.value_json), + ) + session.execute(sqlite_upsert_stmt) else: - attribute.value_json = json.dumps(value) + # TODO(porink0424): Add support for other databases, e.g., PostgreSQL. + attribute = models.TrialUserAttributeModel.find_by_trial_and_key(trial, key, session) + if attribute is None: + attribute = models.TrialUserAttributeModel( + trial_id=trial_id, key=key, value_json=json.dumps(value) + ) + session.add(attribute) + else: + attribute.value_json = json.dumps(value) def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None: with _create_scoped_session(self.scoped_session, True) as session: diff --git a/pyproject.toml b/pyproject.toml index dd1a17777c..4f7d027323 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "colorlog", "numpy", "packaging>=20.0", - "sqlalchemy>=1.3.0", + "sqlalchemy>=1.4.2", "tqdm", "PyYAML", # Only used in `optuna/cli.py`. ]