Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scheduler #10

Merged
merged 3 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions ml_experiment/Scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from __future__ import annotations

import sys
import os
sys.path.append(os.getcwd())

import sqlite3
from typing import Self, Callable, NamedTuple
from dataclasses import dataclass
from itertools import product
from multiprocessing.pool import Pool
import subprocess
from ml_experiment.metadata.MetadataTableRegistry import MetadataTableRegistry

class RunSpec(NamedTuple):
    """Identify one run: a single configuration of one experiment part at one seed.

    The field order is significant: instances are unpacked positionally when
    they are handed to a ``Pred`` (``pred(*run_spec)``).
    """

    # Name of the experiment part; forwarded to the entry script as ``--part``.
    part_name: str
    # Metadata-table version the configuration id was read from (``--version``).
    version: int
    # Configuration id within that part/version (``--config-id``).
    config_id: int
    # Random seed for this run (``--seed``).
    seed: int


@dataclass
class RunConfig:
    """Base class for the run configurations accepted by ``Scheduler.run``.

    It declares no fields itself; concrete subclasses add the settings that a
    particular execution backend needs.
    """
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit does not contain the Slurm functionality, but the parent RunConfig is still defined here. I guess this is fine?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I think that's just fine!

I typically try to avoid defining generic interfaces before they are needed, but in this case we know that this will be needed in a follow-up PR.



@dataclass
class LocalRunConfig(RunConfig):
    """Run configuration for executing the scheduled runs on the local machine."""

    # Number of worker processes in the pool that launches the runs.
    tasks_in_parallel: int

    # NOTE(review): not read by any scheduler code visible in this file;
    # presumably the directory where run logs are meant to go -- confirm.
    log_path: str = ".logs/"



# Predicate over a run, called as pred(part_name, version, config_id, seed) -- the
# fields of a RunSpec in order. Scheduler.filter drops every run for which it
# returns True.
Pred = Callable[[str, int, int, int], bool]
# Which metadata version to use for each experiment part:
#   * int  -- that version for every part (-1 means "latest");
#   * dict -- per-part version keyed by part name; a missing key, None or -1
#             means "latest" for that part;
#   * None -- latest version for every part.
VersionSpec = int | dict[str, int | None] | None

class Scheduler:
def __init__(self, exp_name: str, seeds: list[int], entry: str, version: VersionSpec = None, base: str | None = None):
    """Set up a scheduler for one experiment.

    Args:
        exp_name: Experiment name; its metadata database is expected at
            ``<base>/results/<exp_name>/metadata.db``.
        seeds: Seeds that are crossed with every configuration id.
        entry: Path of the script launched for each run.
        version: Which metadata version(s) to schedule; ``None`` is stored
            as ``-1``, meaning "latest".
        base: Base directory of the project; falls back to the current
            working directory when empty or ``None``.

    Raises:
        AssertionError: If the metadata database does not exist.
    """
    self.exp_name = exp_name
    self.seeds = seeds
    self.entry = entry

    self.base_path = base if base else os.getcwd()
    self.version = -1 if version is None else version

    # Starts empty; populated by get_all_runs().
    self.all_runs: set[RunSpec] = set()  # TODO: polars dataframe!

    self._sanity_check()

def __repr__(self):
return f'Scheduler({self.exp_name}, {self.seeds}, {self.version}, {self.all_runs})'

def get_all_runs(self) -> Self:
    """Collect every (part, version, configuration id, seed) run of the experiment.

    Opens ``<base_path>/results/<exp_name>/metadata.db``, resolves the
    version to use for each part, and adds one ``RunSpec`` to
    ``self.all_runs`` for every combination of configuration id and seed.

    Returns:
        This scheduler, so calls can be chained.

    Raises:
        AssertionError: If no metadata table exists for a resolved
            part/version pair.
    """
    # Imported locally so this method does not depend on a file-level import.
    from contextlib import closing

    res_path = os.path.join(self.base_path, 'results', self.exp_name, 'metadata.db')
    meta = MetadataTableRegistry()

    # Using a sqlite3 connection as a context manager only commits or rolls
    # back the transaction; closing() is what actually releases the handle.
    with closing(sqlite3.connect(res_path)) as con, con:
        cur = con.cursor()
        parts = meta.get_parts(cur)
        resolved_ver = self._resolve_version(parts, cur, meta)

        for part_name, version in resolved_ver.items():
            table = meta.get_table(cur, part_name, version)
            # Raised explicitly (rather than asserted) so the check still
            # runs under ``python -O``; the exception type is unchanged.
            if table is None:
                raise AssertionError(
                    f'no metadata table for part {part_name!r} at version {version}'
                )

            config_ids = table.get_configuration_ids(cur)
            self.all_runs.update(
                RunSpec(part_name, version, config_id, seed)
                for config_id, seed in product(config_ids, self.seeds)
            )

    return self


def filter(self, already_exists: Pred) -> Scheduler:
    """Return a new scheduler holding only the runs that still need to be done.

    Args:
        already_exists: Predicate called with the unpacked fields of each
            run; runs for which it returns a true value are left out.

    Returns:
        A fresh ``Scheduler`` with the same settings as this one whose
        ``all_runs`` contains the remaining runs. This scheduler is not
        modified.
    """
    remaining = Scheduler(self.exp_name, self.seeds, self.entry, self.version, self.base_path)
    remaining.all_runs.update(
        run for run in self.all_runs if not already_exists(*run)
    )
    return remaining


def run(self, c: RunConfig) -> None:
    """Execute all runs in ``self.all_runs`` with the backend selected by ``c``.

    Args:
        c: Run configuration; only ``LocalRunConfig`` is supported here.

    Raises:
        ValueError: If ``c`` is not a supported configuration type.
    """
    if not isinstance(c, LocalRunConfig):
        raise ValueError('Unknown RunConfig type')

    self._run_local(c)


# ----------------------
# -- Internal Methods --
# ----------------------

def _run_local(self, c: LocalRunConfig) -> None:
    """Launch every run in ``self.all_runs`` through a local process pool.

    Args:
        c: Local run configuration; ``tasks_in_parallel`` sets the number of
            worker processes.
    """
    # map() blocks until every run has finished, so tearing the pool down on
    # exit from the ``with`` block cannot cut a run short; it only makes sure
    # the worker processes are released instead of lingering.
    with Pool(c.tasks_in_parallel) as pool:
        pool.map(self._run_single, self.all_runs)


def _run_single(self, r: RunSpec) -> None:
subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version)])
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the upcoming steps, we should work on a sample main file that uses these parameters to properly set up the ml_instrumentation library for data saving and loading.



def _resolve_version(
self,
parts: set[str],
cur: sqlite3.Cursor,
meta: MetadataTableRegistry,
) -> dict[str, int]:
if isinstance(self.version, int):
_r = {p: self.version for p in parts}
elif isinstance(self.version, dict):
_r = {}
for p in parts:
v = self.version.get(p, -1)
_r[p] = v if v is not None else -1
else:
_r = {p: -1 for p in parts}

for k, v in _r.items():
if v == -1:
table = meta.get_latest_version(cur, k)
assert table is not None
_r[k] = table.version

return _r

def _sanity_check(self):
res_path = os.path.join(self.base_path, 'results', self.exp_name, 'metadata.db')
assert os.path.exists(res_path), f'{self.exp_name} does not exist'


12 changes: 12 additions & 0 deletions ml_experiment/_utils/path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os

def get_experiment_name() -> str:
    """Return the name of the directory that contains the running main script.

    The experiment name is taken to be the parent directory of
    ``__main__.__file__``.
    """
    import __main__

    # Resolve through os.path instead of splitting on '/', so this also works
    # with Windows separators and when the script path is a bare relative
    # name such as 'main.py' (which has no '/' to split on).
    script_dir = os.path.dirname(os.path.abspath(__main__.__file__))
    return os.path.basename(script_dir)

def get_results_path(base_path: str) -> str:
    """Return the results directory of the running experiment.

    Args:
        base_path: Base directory of the project.

    Returns:
        ``<base_path>/results/<experiment name>``.
    """
    experiment = get_experiment_name()
    return os.path.join(base_path, 'results', experiment)
5 changes: 5 additions & 0 deletions ml_experiment/metadata/MetadataTableRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ def __init__(self):
self._tables: Dict[str, MetadataTable] = {}


def get_parts(self, cur: sqlite3.Cursor) -> set[str]:
    """Return the distinct part names that have a table in the database.

    Args:
        cur: Cursor on the metadata database.

    Returns:
        The set of part names derived from the table names.
    """
    table_names = sqlu.get_tables(cur)
    # Tables are named '<part_name>-v<version>', so only the last '-' segment
    # is the version suffix; splitting from the right keeps part names that
    # themselves contain '-' intact. Names without '-' are returned unchanged.
    return {name.rsplit('-', 1)[0] for name in table_names}


def get_table(self, cur: sqlite3.Cursor, part_name: str, version: int) -> MetadataTable | None:
table_name = f'{part_name}-v{version}'

Expand Down
Loading