Skip to content

Commit

Permalink
fix (metadata): Fix broken CLI arg - trails-folder
Browse files Browse the repository at this point in the history
Fixes: #45

add a test for it too
  • Loading branch information
Jacob-Stevens-Haas committed May 1, 2024
1 parent d71e6e0 commit 15fdafc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
4 changes: 3 additions & 1 deletion mitosis/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def create_step(
for step_name in all_steps.keys()
]

folder = _disk.locate_trial_folder(trials_folder=args.folder, proj_file=args.config)
folder = _disk.locate_trial_folder(
trials_folder=args.trials_folder, proj_file=args.config
)
return {
"steps": exp_steps,
"debug": args.debug,
Expand Down
24 changes: 22 additions & 2 deletions mitosis/tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import os
import sys
from argparse import Namespace
from io import StringIO
Expand Down Expand Up @@ -28,7 +29,7 @@ def test_legacy_module(params, eval_params):
module="mitosis.tests.mock_legacy",
debug=True,
config="pyproject.toml",
folder=None,
trials_folder=None,
eval_param=eval_params,
param=params,
)
Expand All @@ -45,7 +46,7 @@ def test_experiment_arg():
module=None,
debug=True,
config="mitosis/tests/test_pyproject.toml",
folder=None,
trials_folder=None,
eval_param=["data.extra=True"],
param=["data.length=test", "fit_eval.metric=test"],
)
Expand All @@ -56,6 +57,25 @@ def test_experiment_arg():
assert id(result["steps"][1].lookup) == id(meth_config)


def test_folder_arg(tmp_path):
parser = _create_parser()
args = parser.parse_args(["-m", "mitosis.tests.mock_legacy", "-F", "foo"])

@contextlib.contextmanager
def change_cwd(new_pth: os.PathLike) -> Generator[None, None, None]:
temp = os.getcwd()
try:
os.chdir(new_pth)
yield
finally:
os.chdir(temp)

with change_cwd(tmp_path):
result = _process_cl_args(args)["trials_folder"]
expected = tmp_path / "foo"
assert result == expected


def test_argparse_options():
parser = _create_parser()
args = parser.parse_args(
Expand Down

0 comments on commit 15fdafc

Please sign in to comment.