Skip to content

Commit

Permalink
Merge pull request optuna#5703 from porink0424/fix/introduce-upsert-i…
Browse files Browse the repository at this point in the history
…n-set_trial_user_attr

Introduce `UPSERT` in `set_trial_user_attr`
  • Loading branch information
not522 authored Oct 28, 2024
2 parents a50cee4 + 98c6e06 commit c014b9d
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests-with-minimum-versions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:
run: |
# Install dependencies with minimum versions.
pip uninstall -y alembic cmaes packaging sqlalchemy plotly scikit-learn pillow
pip install alembic==1.5.0 cmaes==0.10.0 packaging==20.0 sqlalchemy==1.3.0 tqdm==4.27.0 colorlog==0.3 PyYAML==5.1 'pillow<10.4.0'
pip install alembic==1.5.0 cmaes==0.10.0 packaging==20.0 sqlalchemy==1.4.2 tqdm==4.27.0 colorlog==0.3 PyYAML==5.1 'pillow<10.4.0'
pip uninstall -y matplotlib pandas scipy
# Python v3.6 was dropped at NumPy v1.20.3.
if [ "${{ matrix.python-version }}" = "3.7" ]; then
Expand Down
39 changes: 33 additions & 6 deletions optuna/storages/_rdb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import os
import random
import sqlite3
import time
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -43,6 +44,8 @@
import alembic.migration as alembic_migration
import alembic.script as alembic_script
import sqlalchemy
import sqlalchemy.dialects.mysql as sqlalchemy_dialects_mysql
import sqlalchemy.dialects.sqlite as sqlalchemy_dialects_sqlite
import sqlalchemy.exc as sqlalchemy_exc
import sqlalchemy.orm as sqlalchemy_orm
import sqlalchemy.sql.functions as sqlalchemy_sql_functions
Expand All @@ -55,6 +58,8 @@
alembic_script = _LazyImport("alembic.script")

sqlalchemy = _LazyImport("sqlalchemy")
sqlalchemy_dialects_mysql = _LazyImport("sqlalchemy.dialects.mysql")
sqlalchemy_dialects_sqlite = _LazyImport("sqlalchemy.dialects.sqlite")
sqlalchemy_exc = _LazyImport("sqlalchemy.exc")
sqlalchemy_orm = _LazyImport("sqlalchemy.orm")
sqlalchemy_sql_functions = _LazyImport("sqlalchemy.sql.functions")
Expand Down Expand Up @@ -720,14 +725,36 @@ def _set_trial_user_attr_without_commit(
trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
self.check_trial_is_updatable(trial_id, trial.state)

attribute = models.TrialUserAttributeModel.find_by_trial_and_key(trial, key, session)
if attribute is None:
attribute = models.TrialUserAttributeModel(
trial_id=trial_id, key=key, value_json=json.dumps(value)
if self.engine.name == "mysql":
mysql_insert_stmt = sqlalchemy_dialects_mysql.insert(
models.TrialUserAttributeModel
).values(trial_id=trial_id, key=key, value_json=json.dumps(value))
mysql_upsert_stmt = mysql_insert_stmt.on_duplicate_key_update(
value_json=mysql_insert_stmt.inserted.value_json
)
session.add(attribute)
session.execute(mysql_upsert_stmt)
elif self.engine.name == "sqlite" and sqlite3.sqlite_version_info >= (3, 24, 0):
sqlite_insert_stmt = sqlalchemy_dialects_sqlite.insert(
models.TrialUserAttributeModel
).values(trial_id=trial_id, key=key, value_json=json.dumps(value))
sqlite_upsert_stmt = sqlite_insert_stmt.on_conflict_do_update(
index_elements=[
models.TrialUserAttributeModel.trial_id,
models.TrialUserAttributeModel.key,
],
set_=dict(value_json=sqlite_insert_stmt.excluded.value_json),
)
session.execute(sqlite_upsert_stmt)
else:
attribute.value_json = json.dumps(value)
# TODO(porink0424): Add support for other databases, e.g., PostgreSQL.
attribute = models.TrialUserAttributeModel.find_by_trial_and_key(trial, key, session)
if attribute is None:
attribute = models.TrialUserAttributeModel(
trial_id=trial_id, key=key, value_json=json.dumps(value)
)
session.add(attribute)
else:
attribute.value_json = json.dumps(value)

def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None:
with _create_scoped_session(self.scoped_session, True) as session:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"colorlog",
"numpy",
"packaging>=20.0",
"sqlalchemy>=1.3.0",
"sqlalchemy>=1.4.2",
"tqdm",
"PyYAML", # Only used in `optuna/cli.py`.
]
Expand Down

0 comments on commit c014b9d

Please sign in to comment.