From 2227ed0069a8ca9b8d36d9b38e7d5093d94db70a Mon Sep 17 00:00:00 2001 From: porink0424 Date: Wed, 16 Oct 2024 16:45:53 +0900 Subject: [PATCH 1/2] Remove unnecessary distribution compatibility check --- optuna/storages/_rdb/storage.py | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 25d66e6ede..3cb02189cb 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -588,27 +588,14 @@ def _set_trial_param_without_commit( trial = models.TrialModel.find_or_raise_by_id(trial_id, session) self.check_trial_is_updatable(trial_id, trial.state) - trial_param = models.TrialParamModel.find_by_trial_and_param_name( - trial, param_name, session + trial_param = models.TrialParamModel( + trial_id=trial_id, + param_name=param_name, + param_value=param_value_internal, + distribution_json=distributions.distribution_to_json(distribution), ) - if trial_param is not None: - # Raise error in case distribution is incompatible. - distributions.check_distribution_compatibility( - distributions.json_to_distribution(trial_param.distribution_json), distribution - ) - - trial_param.param_value = param_value_internal - trial_param.distribution_json = distributions.distribution_to_json(distribution) - else: - trial_param = models.TrialParamModel( - trial_id=trial_id, - param_name=param_name, - param_value=param_value_internal, - distribution_json=distributions.distribution_to_json(distribution), - ) - - trial_param.check_and_add(session) + trial_param.check_and_add(session) def _check_and_set_param_distribution( self, From 9a04adac981bd7292c63e54212a736f8ba7e6ec3 Mon Sep 17 00:00:00 2001 From: porink0424 Date: Wed, 16 Oct 2024 17:02:41 +0900 Subject: [PATCH 2/2] Fix test cases --- tests/storages_tests/test_storages.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/storages_tests/test_storages.py b/tests/storages_tests/test_storages.py index bf9e3e5da7..5d7cbebd19 100644 --- a/tests/storages_tests/test_storages.py +++ b/tests/storages_tests/test_storages.py @@ -486,11 +486,6 @@ def test_set_trial_param(storage_mode: str) -> None: # Check set_param breaks neither get_trial nor get_trial_params. assert storage.get_trial(trial_id_1).params == {"x": 0.5, "y": "Meguro"} assert storage.get_trial_params(trial_id_1) == {"x": 0.5, "y": "Meguro"} - # Duplicated registration should overwrite. - storage.set_trial_param(trial_id_1, "x", 0.6, distribution_x) - assert storage.get_trial_param(trial_id_1, "x") == 0.6 - assert storage.get_trial(trial_id_1).params == {"x": 0.6, "y": "Meguro"} - assert storage.get_trial_params(trial_id_1) == {"x": 0.6, "y": "Meguro"} # Set params to another trial. storage.set_trial_param(trial_id_2, "x", 0.3, distribution_x)