Skip to content

Commit

Permalink
Merge pull request optuna#5702 from porink0424/fix/reduce-select-in-s…
Browse files Browse the repository at this point in the history
…et_trial_param-by-passing-study_id

Reduce `SELECT` statements by passing `study_id` to `check_and_add` in `TrialParamModel`
  • Loading branch information
nabenabe0928 authored Oct 28, 2024
2 parents a89dcc3 + 5e6263d commit a50cee4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
10 changes: 4 additions & 6 deletions optuna/storages/_rdb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,19 +349,17 @@ class TrialParamModel(BaseModel):
TrialModel, backref=orm.backref("params", cascade="all, delete-orphan")
)

def check_and_add(self, session: orm.Session) -> None:
self._check_compatibility_with_previous_trial_param_distributions(session)
def check_and_add(self, session: orm.Session, study_id: int) -> None:
self._check_compatibility_with_previous_trial_param_distributions(session, study_id)
session.add(self)

def _check_compatibility_with_previous_trial_param_distributions(
self, session: orm.Session
self, session: orm.Session, study_id: int
) -> None:
trial = TrialModel.find_or_raise_by_id(self.trial_id, session)

previous_record = (
session.query(TrialParamModel)
.join(TrialModel)
.filter(TrialModel.study_id == trial.study_id)
.filter(TrialModel.study_id == study_id)
.filter(TrialParamModel.param_name == self.param_name)
.first()
)
Expand Down
4 changes: 2 additions & 2 deletions optuna/storages/_rdb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def _set_trial_param_without_commit(
distribution_json=distributions.distribution_to_json(distribution),
)

trial_param.check_and_add(session)
trial_param.check_and_add(session, trial.study_id)

def _check_and_set_param_distribution(
self,
Expand All @@ -616,7 +616,7 @@ def _check_and_set_param_distribution(
param_name=param_name,
param_value=param_value_internal,
distribution_json=distributions.distribution_to_json(distribution),
).check_and_add(session)
).check_and_add(session, study_id)

def get_trial_param(self, trial_id: int, param_name: str) -> float:
with _create_scoped_session(self.scoped_session) as session:
Expand Down

0 comments on commit a50cee4

Please sign in to comment.