diff --git a/ml_experiment/Scheduler.py b/ml_experiment/Scheduler.py index 8d8d5a4..de40ec1 100644 --- a/ml_experiment/Scheduler.py +++ b/ml_experiment/Scheduler.py @@ -41,6 +41,7 @@ def __init__(self, exp_name: str, seeds: list[int], entry: str, version: Version self.seeds = seeds self.entry = entry self.base_path = base or os.getcwd() + self.results_path = os.path.join(self.base_path, 'results', self.exp_name) self.version = version if version is not None else -1 self.all_runs = set[RunSpec]() # TODO: polars dataframe! @@ -51,11 +52,11 @@ def __repr__(self): return f'Scheduler({self.exp_name}, {self.seeds}, {self.version}, {self.all_runs})' def get_all_runs(self) -> Self: - res_path = os.path.join(self.base_path, 'results', self.exp_name, 'metadata.db') - meta = MetadataTableRegistry() - with sqlite3.connect(res_path) as con: + table_path = os.path.join(self.results_path, 'metadata.db') + + with sqlite3.connect(table_path) as con: cur = con.cursor() parts = meta.get_parts(cur) resloved_ver = self._resolve_version(parts, cur, meta) @@ -97,7 +98,7 @@ def _run_local(self, c: LocalRunConfig) -> None: def _run_single(self, r: RunSpec) -> None: - subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version)]) + subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version), '--results-path', self.results_path]) def _resolve_version( diff --git a/ml_experiment/metadata/__init__.py b/ml_experiment/metadata/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index d2c6248..de6775f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,7 @@ [tool] -[tool.setuptools] -packages = ['ml_experiment'] +[tool.setuptools.packages.find] +where = ["."] # list of folders that contain the packages (["."] by default) +include = ["ml_experiment*"] # package names should match these glob patterns (["*"] by default) [tool.commitizen] name = "cz_conventional_commits" diff --git a/tests/acceptance/my_experiment.py b/tests/acceptance/my_experiment.py index d956dc1..1d863f0 100644 --- a/tests/acceptance/my_experiment.py +++ b/tests/acceptance/my_experiment.py @@ -9,7 +9,7 @@ parser.add_argument("--config-id", type=int, required=True) parser.add_argument("--seed", type=int, required=True) parser.add_argument("--version", type=int, required=True) -parser.add_argument("--results_path", type=str, required=True) +parser.add_argument("--results-path", type=str, required=True) class SoftmaxAC: def __init__( diff --git a/tests/acceptance/test_softmaxAC_mc.py b/tests/acceptance/test_softmaxAC_mc.py index 37eebc7..e1f6184 100644 --- a/tests/acceptance/test_softmaxAC_mc.py +++ b/tests/acceptance/test_softmaxAC_mc.py @@ -1,10 +1,9 @@ import os import pytest -import subprocess from ml_experiment.ExperimentDefinition import ExperimentDefinition from ml_experiment.DefinitionPart import DefinitionPart -from ml_experiment.Scheduler import LocalRunConfig, RunSpec, Scheduler +from ml_experiment.Scheduler import LocalRunConfig, Scheduler @pytest.fixture @@ -34,20 +33,6 @@ def write_database(tmp_path, alphas: list[float], taus: list[float]): return softmaxAC - -# overwrite the run_single function -class StubScheduler(Scheduler): - - # allows us to force the results path to be in a specific spot - def __init__(self, results_path: str, *args, **kwargs): - super().__init__(*args, **kwargs) - self.results_path = results_path - - # adding the results path to the command - def _run_single(self: Scheduler, r: RunSpec) -> None: - subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version), '--results_path', self.results_path]) - - def test_read_database(tmp_path, base_path): """ Test that we can retrieve the configurations from the experiment definition. @@ -145,12 +130,11 @@ def test_run_tasks(tmp_path): run_conf = LocalRunConfig(tasks_in_parallel=ntasks, log_path=".logs/") # set up scheduler - sched = StubScheduler( + sched = Scheduler( exp_name=exp_name, entry=experiment_file_name, seeds=[seed_num], version=version_num, - results_path=results_path, base = str(tmp_path), )