Skip to content

Commit

Permalink
Merge pull request #24 from panahiparham/scheduler-path
Browse files: browse the repository at this point in the history
Have scheduler pass results path to experiment file
  • Loading branch information
yasuiniko authored Nov 1, 2024
2 parents 2971444 + cffc428 commit f84ae8d
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 25 deletions.
9 changes: 5 additions & 4 deletions ml_experiment/Scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, exp_name: str, seeds: list[int], entry: str, version: Version
self.seeds = seeds
self.entry = entry
self.base_path = base or os.getcwd()
self.results_path = os.path.join(self.base_path, 'results', self.exp_name)
self.version = version if version is not None else -1

self.all_runs = set[RunSpec]() # TODO: polars dataframe!
Expand All @@ -51,11 +52,11 @@ def __repr__(self):
return f'Scheduler({self.exp_name}, {self.seeds}, {self.version}, {self.all_runs})'

def get_all_runs(self) -> Self:
res_path = os.path.join(self.base_path, 'results', self.exp_name, 'metadata.db')

meta = MetadataTableRegistry()

with sqlite3.connect(res_path) as con:
table_path = os.path.join(self.results_path, 'metadata.db')

with sqlite3.connect(table_path) as con:
cur = con.cursor()
parts = meta.get_parts(cur)
resloved_ver = self._resolve_version(parts, cur, meta)
Expand Down Expand Up @@ -97,7 +98,7 @@ def _run_local(self, c: LocalRunConfig) -> None:


def _run_single(self, r: RunSpec) -> None:
subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version)])
subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version), '--results-path', self.results_path])


def _resolve_version(
Expand Down
Empty file.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[tool]
[tool.setuptools]
packages = ['ml_experiment']
[tool.setuptools.packages.find]
where = ["."] # list of folders that contain the packages (["."] by default)
include = ["ml_experiment*"] # package names should match these glob patterns (["*"] by default)

[tool.commitizen]
name = "cz_conventional_commits"
Expand Down
2 changes: 1 addition & 1 deletion tests/acceptance/my_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
parser.add_argument("--config-id", type=int, required=True)
parser.add_argument("--seed", type=int, required=True)
parser.add_argument("--version", type=int, required=True)
parser.add_argument("--results_path", type=str, required=True)
parser.add_argument("--results-path", type=str, required=True)

class SoftmaxAC:
def __init__(
Expand Down
20 changes: 2 additions & 18 deletions tests/acceptance/test_softmaxAC_mc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
import pytest
import subprocess

from ml_experiment.ExperimentDefinition import ExperimentDefinition
from ml_experiment.DefinitionPart import DefinitionPart
from ml_experiment.Scheduler import LocalRunConfig, RunSpec, Scheduler
from ml_experiment.Scheduler import LocalRunConfig, Scheduler


@pytest.fixture
Expand Down Expand Up @@ -34,20 +33,6 @@ def write_database(tmp_path, alphas: list[float], taus: list[float]):

return softmaxAC


# override the _run_single method
class StubScheduler(Scheduler):

# allows us to force the results path to be in a specific spot
def __init__(self, results_path: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.results_path = results_path

# adding the results path to the command
def _run_single(self: Scheduler, r: RunSpec) -> None:
subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version), '--results_path', self.results_path])


def test_read_database(tmp_path, base_path):
"""
Test that we can retrieve the configurations from the experiment definition.
Expand Down Expand Up @@ -145,12 +130,11 @@ def test_run_tasks(tmp_path):
run_conf = LocalRunConfig(tasks_in_parallel=ntasks, log_path=".logs/")

# set up scheduler
sched = StubScheduler(
sched = Scheduler(
exp_name=exp_name,
entry=experiment_file_name,
seeds=[seed_num],
version=version_num,
results_path=results_path,
base = str(tmp_path),
)

Expand Down

0 comments on commit f84ae8d

Please sign in to comment.