Skip to content

Commit

Permalink
Merge pull request optuna#5415 from eukaryo/v4pr-drop-create-study-in…
Browse files Browse the repository at this point in the history
…-ask

Drop implicit create-study in `ask` command
  • Loading branch information
HideakiImamura authored May 21, 2024
2 parents c38896c + 0640786 commit fe400d1
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
5 changes: 4 additions & 1 deletion optuna/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,10 @@ def take_action(self, parsed_args: Namespace) -> int:
)

except KeyError:
study = optuna.create_study(**create_study_kwargs)
raise KeyError(
"Implicit study creation within the 'ask' command was dropped in Optuna v4.0.0. "
"Please use the 'create-study' command beforehand."
)
trial = study.ask(fixed_distributions=search_space)

self.logger.info(f"Asked trial {trial.number} with parameters {trial.params}.")
Expand Down
51 changes: 51 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,9 @@ def test_ask(
with NamedTemporaryFilePool() as tf:
db_url = "sqlite:///{}".format(tf.name)

args = ["optuna", "create-study", "--storage", db_url, "--study-name", study_name]
subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

args = [
"optuna",
"ask",
Expand Down Expand Up @@ -1218,6 +1221,9 @@ def test_ask_flatten(
with NamedTemporaryFilePool() as tf:
db_url = "sqlite:///{}".format(tf.name)

args = ["optuna", "create-study", "--storage", db_url, "--study-name", study_name]
subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

args = [
"optuna",
"ask",
Expand Down Expand Up @@ -1261,6 +1267,9 @@ def test_ask_empty_search_space(output_format: str) -> None:
with NamedTemporaryFilePool() as tf:
db_url = "sqlite:///{}".format(tf.name)

args = ["optuna", "create-study", "--storage", db_url, "--study-name", study_name]
subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

args = [
"optuna",
"ask",
Expand Down Expand Up @@ -1294,6 +1303,9 @@ def test_ask_empty_search_space_flatten(output_format: str) -> None:
with NamedTemporaryFilePool() as tf:
db_url = "sqlite:///{}".format(tf.name)

args = ["optuna", "create-study", "--storage", db_url, "--study-name", study_name]
subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

args = [
"optuna",
"ask",
Expand Down Expand Up @@ -1331,6 +1343,9 @@ def test_ask_sampler_kwargs_without_sampler() -> None:
with NamedTemporaryFilePool() as tf:
db_url = "sqlite:///{}".format(tf.name)

args = ["optuna", "create-study", "--storage", db_url, "--study-name", study_name]
subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

args = [
"optuna",
"ask",
Expand All @@ -1349,6 +1364,36 @@ def test_ask_sampler_kwargs_without_sampler() -> None:
assert "`--sampler_kwargs` is set without `--sampler`." in error_message


@pytest.mark.skip_coverage
def test_ask_without_create_study_beforehand() -> None:
study_name = "test_study"
search_space = (
'{"x": {"name": "FloatDistribution", "attributes": {"low": 0.0, "high": 1.0}}, '
'"y": {"name": "CategoricalDistribution", "attributes": {"choices": ["foo"]}}}'
)

with NamedTemporaryFilePool() as tf:
db_url = "sqlite:///{}".format(tf.name)

args = [
"optuna",
"ask",
"--storage",
db_url,
"--study-name",
study_name,
"--search-space",
search_space,
]

result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
error_message = result.stderr.decode()
assert (
"Implicit study creation within the 'ask' command was dropped in Optuna v4.0.0."
in error_message
)


@pytest.mark.skip_coverage
@pytest.mark.parametrize(
"direction,directions,sampler,sampler_kwargs",
Expand Down Expand Up @@ -1421,6 +1466,9 @@ def test_tell() -> None:
with NamedTemporaryFilePool() as tf:
db_url = "sqlite:///{}".format(tf.name)

args = ["optuna", "create-study", "--storage", db_url, "--study-name", study_name]
subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

output: Any = subprocess.check_output(
[
"optuna",
Expand Down Expand Up @@ -1498,6 +1546,9 @@ def test_tell_with_nan() -> None:
with NamedTemporaryFilePool() as tf:
db_url = "sqlite:///{}".format(tf.name)

args = ["optuna", "create-study", "--storage", db_url, "--study-name", study_name]
subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

output: Any = subprocess.check_output(
[
"optuna",
Expand Down

0 comments on commit fe400d1

Please sign in to comment.