Skip to content

Commit

Permalink
Merge pull request optuna#5709 from porink0424/fix/remove-unnecessary…
Browse files Browse the repository at this point in the history
…-distribution-compatibility-check

Reduce `SELECT` statements by removing unnecessary distribution compatibility check in `set_trial_param()`
  • Loading branch information
eukaryo authored Oct 18, 2024
2 parents 74e3618 + 9a04ada commit c2e0b3b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 24 deletions.
25 changes: 6 additions & 19 deletions optuna/storages/_rdb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,27 +588,14 @@ def _set_trial_param_without_commit(
trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
self.check_trial_is_updatable(trial_id, trial.state)

trial_param = models.TrialParamModel.find_by_trial_and_param_name(
trial, param_name, session
trial_param = models.TrialParamModel(
trial_id=trial_id,
param_name=param_name,
param_value=param_value_internal,
distribution_json=distributions.distribution_to_json(distribution),
)

if trial_param is not None:
# Raise error in case distribution is incompatible.
distributions.check_distribution_compatibility(
distributions.json_to_distribution(trial_param.distribution_json), distribution
)

trial_param.param_value = param_value_internal
trial_param.distribution_json = distributions.distribution_to_json(distribution)
else:
trial_param = models.TrialParamModel(
trial_id=trial_id,
param_name=param_name,
param_value=param_value_internal,
distribution_json=distributions.distribution_to_json(distribution),
)

trial_param.check_and_add(session)
trial_param.check_and_add(session)

def _check_and_set_param_distribution(
self,
Expand Down
5 changes: 0 additions & 5 deletions tests/storages_tests/test_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,6 @@ def test_set_trial_param(storage_mode: str) -> None:
# Check set_param breaks neither get_trial nor get_trial_params.
assert storage.get_trial(trial_id_1).params == {"x": 0.5, "y": "Meguro"}
assert storage.get_trial_params(trial_id_1) == {"x": 0.5, "y": "Meguro"}
# Duplicated registration should overwrite.
storage.set_trial_param(trial_id_1, "x", 0.6, distribution_x)
assert storage.get_trial_param(trial_id_1, "x") == 0.6
assert storage.get_trial(trial_id_1).params == {"x": 0.6, "y": "Meguro"}
assert storage.get_trial_params(trial_id_1) == {"x": 0.6, "y": "Meguro"}

# Set params to another trial.
storage.set_trial_param(trial_id_2, "x", 0.3, distribution_x)
Expand Down

0 comments on commit c2e0b3b

Please sign in to comment.