From e3c4832363a5661993ffbaedea6de8f3835c3236 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 12 Oct 2023 16:41:57 +0900 Subject: [PATCH 1/4] Add `optuna study-name` cli --- optuna/cli.py | 25 ++++++++++++++++++ tests/test_cli.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/optuna/cli.py b/optuna/cli.py index ffba0e3724..43b4210317 100644 --- a/optuna/cli.py +++ b/optuna/cli.py @@ -382,6 +382,30 @@ def take_action(self, parsed_args: Namespace) -> int: return 0 +class _StudyNames(_BaseCommand): + """Get all study names stored in a specified storage""" + + def add_arguments(self, parser: ArgumentParser) -> None: + parser.add_argument( + "-f", + "--format", + type=str, + choices=("json", "table", "yaml"), + default="table", + help="Output format.", + ) + + def take_action(self, parsed_args: Namespace) -> int: + storage = _get_storage(parsed_args.storage, parsed_args.storage_class) + all_study_names = optuna.get_all_study_names(storage) + records = [] + record_key = ("name", "") + for study_name in all_study_names: + records.append({record_key: study_name}) + print(_format_output(records, [record_key], parsed_args.format, flatten=False)) + return 0 + + class _Studies(_BaseCommand): """Show a list of studies.""" @@ -917,6 +941,7 @@ def take_action(self, parsed_args: Namespace) -> int: "create-study": _CreateStudy, "delete-study": _DeleteStudy, "study set-user-attr": _StudySetUserAttribute, + "study-names": _StudyNames, "studies": _Studies, "trials": _Trials, "best-trial": _BestTrial, diff --git a/tests/test_cli.py b/tests/test_cli.py index 79c64ecd64..0cbb67bd09 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -297,6 +297,71 @@ def test_study_set_user_attr_command() -> None: assert all(study_user_attrs[k] == v for k, v in example_attrs.items()) +@pytest.mark.skip_coverage +@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml")) +def test_study_names_command(output_format: Optional[str]) -> None: + with StorageSupplier("sqlite") as storage: + assert isinstance(storage, RDBStorage) + storage_url = str(storage.engine.url) + + expected_study_names = ["study-names-test1", "study-names-test2"] + expected_column_name = "name" + + # Create a study. + command = [ + "optuna", + "create-study", + "--storage", + storage_url, + "--study-name", + expected_study_names[0], + ] + subprocess.check_output(command) + + # Get study names. + command = ["optuna", "study-names", "--storage", storage_url] + if output_format is not None: + command += ["--format", output_format] + output = str(subprocess.check_output(command).decode().strip()) + study_names = _parse_output(output, output_format or "table") + + # Check user_attrs are not printed. + assert len(study_names) == 1 + assert study_names[0]["name"] == expected_study_names[0] + + # Create another study. + command = [ + "optuna", + "create-study", + "--storage", + storage_url, + "--study-name", + expected_study_names[1], + ] + subprocess.check_output(command) + + # Get study names. + command = ["optuna", "study-names", "--storage", storage_url] + if output_format is not None: + command += ["--format", output_format] + output = str(subprocess.check_output(command).decode().strip()) + study_names = _parse_output(output, output_format or "table") + + assert len(study_names) == 2 + for i, study_name in enumerate(study_names): + assert list(study_name.keys()) == [expected_column_name] + assert study_name["name"] == expected_study_names[i] + + +@pytest.mark.skip_coverage +def test_study_names_command_without_storage_url() -> None: + with pytest.raises(subprocess.CalledProcessError): + subprocess.check_output( + ["optuna", "study-names", "--study-name", "dummy_study"], + env={k: v for k, v in os.environ.items() if k != "OPTUNA_STORAGE"}, + ) + + @pytest.mark.skip_coverage @pytest.mark.parametrize("output_format", (None, "table", "json", "yaml")) def test_studies_command(output_format: Optional[str]) -> None: From e95c13a15e98675dc1545340b994227922876d7d Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 26 Oct 2023 11:03:35 +0900 Subject: [PATCH 2/4] implement value format and make it default to study-names --- optuna/cli.py | 21 ++++++++++++++++++--- tests/test_cli.py | 11 +++++++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/optuna/cli.py b/optuna/cli.py index 43b4210317..90f60c9fe3 100644 --- a/optuna/cli.py +++ b/optuna/cli.py @@ -186,6 +186,16 @@ def get_string(self, value_type: ValueType, width: int) -> str: return f"{value:<{width}}" +def _dump_value(records: List[Dict[str, Any]], header: List[str]) -> str: + values = [] + for record in records: + row = [] + for column_name in header: + row.append(str(record.get(column_name, ""))) + values.append(" ".join(row)) + return os.linesep.join(values) + + def _dump_table(records: List[Dict[str, Any]], header: List[str]) -> str: rows = [] for record in records: @@ -233,7 +243,12 @@ def _format_output( else: values, header = _convert_to_dict([records], columns, flatten) - if output_format == "table": + if output_format == "value": + if isinstance(records, list): + return _dump_value(values, header).strip() + else: + return str(values[0]).strip() + elif output_format == "table": return _dump_table(values, header).strip() elif output_format == "json": if isinstance(records, list): @@ -390,8 +405,8 @@ def add_arguments(self, parser: ArgumentParser) -> None: "-f", "--format", type=str, - choices=("json", "table", "yaml"), - default="table", + choices=("value", "json", "table", "yaml"), + default="value", help="Output format.", ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 0cbb67bd09..dec66e166d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -71,8 +71,11 @@ def _parse_output(output: str, output_format: str) -> Any: For table format, a list of dict formatted rows. For JSON or YAML format, a list or a dict corresponding to ``output``. """ - - if output_format == "table": + if output_format == "value": + # Currently, _parse_output with output_format="value" is used only for + # `study-names` command. + return [{"name": values} for values in output.split(os.linesep)] + elif output_format == "table": rows = output.split(os.linesep) assert all(len(rows[0]) == len(row) for row in rows) # Check ruled lines. @@ -323,7 +326,7 @@ def test_study_names_command(output_format: Optional[str]) -> None: if output_format is not None: command += ["--format", output_format] output = str(subprocess.check_output(command).decode().strip()) - study_names = _parse_output(output, output_format or "table") + study_names = _parse_output(output, output_format or "value") # Check user_attrs are not printed. assert len(study_names) == 1 @@ -345,7 +348,7 @@ def test_study_names_command(output_format: Optional[str]) -> None: if output_format is not None: command += ["--format", output_format] output = str(subprocess.check_output(command).decode().strip()) - study_names = _parse_output(output, output_format or "table") + study_names = _parse_output(output, output_format or "value") assert len(study_names) == 2 for i, study_name in enumerate(study_names): From 08db792a465ba3f477476aec912ce6e027246fc2 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 26 Oct 2023 16:39:36 +0900 Subject: [PATCH 3/4] Replace os.linesep with \n --- optuna/cli.py | 2 +- tests/test_cli.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/optuna/cli.py b/optuna/cli.py index 90f60c9fe3..d9bd5b1d24 100644 --- a/optuna/cli.py +++ b/optuna/cli.py @@ -193,7 +193,7 @@ def _dump_value(records: List[Dict[str, Any]], header: List[str]) -> str: for column_name in header: row.append(str(record.get(column_name, ""))) values.append(" ".join(row)) - return os.linesep.join(values) + return "\n".join(values) def _dump_table(records: List[Dict[str, Any]], header: List[str]) -> str: diff --git a/tests/test_cli.py b/tests/test_cli.py index dec66e166d..4764ce53e8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -74,9 +74,9 @@ def _parse_output(output: str, output_format: str) -> Any: if output_format == "value": # Currently, _parse_output with output_format="value" is used only for # `study-names` command. - return [{"name": values} for values in output.split(os.linesep)] + return [{"name": values} for values in output.split("\n")] elif output_format == "table": - rows = output.split(os.linesep) + rows = output.split("\n") assert all(len(rows[0]) == len(row) for row in rows) # Check ruled lines. assert rows[0] == rows[2] == rows[-1] From e152d69bd61aa5354647372fc5e50c8c91fee7a0 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 2 Nov 2023 13:39:36 +0900 Subject: [PATCH 4/4] Replace "\n" with os.linesep --- tests/test_cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 4764ce53e8..dec66e166d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -74,9 +74,9 @@ def _parse_output(output: str, output_format: str) -> Any: if output_format == "value": # Currently, _parse_output with output_format="value" is used only for # `study-names` command. - return [{"name": values} for values in output.split("\n")] + return [{"name": values} for values in output.split(os.linesep)] elif output_format == "table": - rows = output.split("\n") + rows = output.split(os.linesep) assert all(len(rows[0]) == len(row) for row in rows) # Check ruled lines. assert rows[0] == rows[2] == rows[-1]