Skip to content

Commit

Permalink
Merge pull request optuna#5029 from Alnusjaponica/implement-study-nam…
Browse files Browse the repository at this point in the history
…e-cli

Add `optuna study-names` cli
  • Loading branch information
not522 authored Nov 6, 2023
2 parents 16e07c7 + e152d69 commit e5835b8
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 3 deletions.
42 changes: 41 additions & 1 deletion optuna/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,16 @@ def get_string(self, value_type: ValueType, width: int) -> str:
return f"{value:<{width}}"


def _dump_value(records: List[Dict[str, Any]], header: List[str]) -> str:
values = []
for record in records:
row = []
for column_name in header:
row.append(str(record.get(column_name, "")))
values.append(" ".join(row))
return "\n".join(values)


def _dump_table(records: List[Dict[str, Any]], header: List[str]) -> str:
rows = []
for record in records:
Expand Down Expand Up @@ -233,7 +243,12 @@ def _format_output(
else:
values, header = _convert_to_dict([records], columns, flatten)

if output_format == "table":
if output_format == "value":
if isinstance(records, list):
return _dump_value(values, header).strip()
else:
return str(values[0]).strip()
elif output_format == "table":
return _dump_table(values, header).strip()
elif output_format == "json":
if isinstance(records, list):
Expand Down Expand Up @@ -382,6 +397,30 @@ def take_action(self, parsed_args: Namespace) -> int:
return 0


class _StudyNames(_BaseCommand):
"""Get all study names stored in a specified storage"""

def add_arguments(self, parser: ArgumentParser) -> None:
parser.add_argument(
"-f",
"--format",
type=str,
choices=("value", "json", "table", "yaml"),
default="value",
help="Output format.",
)

def take_action(self, parsed_args: Namespace) -> int:
storage = _get_storage(parsed_args.storage, parsed_args.storage_class)
all_study_names = optuna.get_all_study_names(storage)
records = []
record_key = ("name", "")
for study_name in all_study_names:
records.append({record_key: study_name})
print(_format_output(records, [record_key], parsed_args.format, flatten=False))
return 0


class _Studies(_BaseCommand):
"""Show a list of studies."""

Expand Down Expand Up @@ -917,6 +956,7 @@ def take_action(self, parsed_args: Namespace) -> int:
"create-study": _CreateStudy,
"delete-study": _DeleteStudy,
"study set-user-attr": _StudySetUserAttribute,
"study-names": _StudyNames,
"studies": _Studies,
"trials": _Trials,
"best-trial": _BestTrial,
Expand Down
72 changes: 70 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,11 @@ def _parse_output(output: str, output_format: str) -> Any:
For table format, a list of dict formatted rows.
For JSON or YAML format, a list or a dict corresponding to ``output``.
"""

if output_format == "table":
if output_format == "value":
# Currently, _parse_output with output_format="value" is used only for
# `study-names` command.
return [{"name": values} for values in output.split(os.linesep)]
elif output_format == "table":
rows = output.split(os.linesep)
assert all(len(rows[0]) == len(row) for row in rows)
# Check ruled lines.
Expand Down Expand Up @@ -297,6 +300,71 @@ def test_study_set_user_attr_command() -> None:
assert all(study_user_attrs[k] == v for k, v in example_attrs.items())


@pytest.mark.skip_coverage
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_study_names_command(output_format: Optional[str]) -> None:
with StorageSupplier("sqlite") as storage:
assert isinstance(storage, RDBStorage)
storage_url = str(storage.engine.url)

expected_study_names = ["study-names-test1", "study-names-test2"]
expected_column_name = "name"

# Create a study.
command = [
"optuna",
"create-study",
"--storage",
storage_url,
"--study-name",
expected_study_names[0],
]
subprocess.check_output(command)

# Get study names.
command = ["optuna", "study-names", "--storage", storage_url]
if output_format is not None:
command += ["--format", output_format]
output = str(subprocess.check_output(command).decode().strip())
study_names = _parse_output(output, output_format or "value")

# Check user_attrs are not printed.
assert len(study_names) == 1
assert study_names[0]["name"] == expected_study_names[0]

# Create another study.
command = [
"optuna",
"create-study",
"--storage",
storage_url,
"--study-name",
expected_study_names[1],
]
subprocess.check_output(command)

# Get study names.
command = ["optuna", "study-names", "--storage", storage_url]
if output_format is not None:
command += ["--format", output_format]
output = str(subprocess.check_output(command).decode().strip())
study_names = _parse_output(output, output_format or "value")

assert len(study_names) == 2
for i, study_name in enumerate(study_names):
assert list(study_name.keys()) == [expected_column_name]
assert study_name["name"] == expected_study_names[i]


@pytest.mark.skip_coverage
def test_study_names_command_without_storage_url() -> None:
with pytest.raises(subprocess.CalledProcessError):
subprocess.check_output(
["optuna", "study-names", "--study-name", "dummy_study"],
env={k: v for k, v in os.environ.items() if k != "OPTUNA_STORAGE"},
)


@pytest.mark.skip_coverage
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_studies_command(output_format: Optional[str]) -> None:
Expand Down

0 comments on commit e5835b8

Please sign in to comment.