diff --git a/tests/acceptance/my_experiment.py b/tests/acceptance/my_experiment.py new file mode 100644 index 0000000..27637b7 --- /dev/null +++ b/tests/acceptance/my_experiment.py @@ -0,0 +1,68 @@ +import argparse +import random + +from ml_experiment.ExperimentDefinition import ExperimentDefinition + +parser = argparse.ArgumentParser() +parser.add_argument("--part", type=str, required=True) +parser.add_argument("--config-id", type=int, required=True) +parser.add_argument("--seed", type=int, required=True) +parser.add_argument("--version", type=int, required=True) + +# this is an extra argument for testing +# we need to be able to find the metadata database +# TODO: remove this +parser.add_argument("--results_path", type=str, required=True) + +class SoftmaxAC: + def __init__( + self, + alpha: float, + tau: float, + nstep: float, + tiles: int, + tilings: int, + ) -> None: + self.alpha = alpha + self.tau = tau + self.nstep = nstep + self.tiles = tiles + self.tilings = tilings + self.name = "SoftmaxAC" + + def run(self) -> str: + return f"{self.name}({self.alpha}, {self.tau}, {self.nstep}, {self.tiles}, {self.tilings})" + + +def main(): + cmdline = parser.parse_args() + + # make sure we are using softmaxAC + if cmdline.part != "softmaxAC": + raise ValueError(f"Unknown part: {cmdline.part}") + + # do some rng control + random.seed(cmdline.seed) + + # extract configs from the database + exp = ExperimentDefinition("softmaxAC", cmdline.version) + # TODO: don't overwrite this + exp.get_results_path = lambda *args, **kwargs: cmdline.results_path # overwrite results path + config = exp.get_config(cmdline.config_id) + + # make our dummy agent + alpha = config["alpha"] + tau = config["tau"] + n_step = config["n_step"] + tiles = config["tiles"] + tilings = config["tilings"] + agent = SoftmaxAC(alpha, tau, n_step, tiles, tilings) + + # run the agent + output = agent.run() + + # test output + assert output == f"SoftmaxAC({alpha}, {tau}, {n_step}, {tiles}, {tilings})" + +if __name__ == "__main__": + main() diff --git a/tests/acceptance/test_softmaxAC_mc.py b/tests/acceptance/test_softmaxAC_mc.py index 96be810..7566f77 100644 --- a/tests/acceptance/test_softmaxAC_mc.py +++ b/tests/acceptance/test_softmaxAC_mc.py @@ -1,25 +1,56 @@ +import os +import subprocess + from ml_experiment.ExperimentDefinition import ExperimentDefinition from ml_experiment.DefinitionPart import DefinitionPart +from ml_experiment.Scheduler import LocalRunConfig, RunSpec, Scheduler + +def write_database(results_path, alphas: list[float], taus: list[float]): + # make table writer + softmaxAC = DefinitionPart("softmaxAC") + # TODO: don't overwrite this + softmaxAC.get_results_path = lambda *args, **kwargs: results_path -def init_softmaxAC_mc(tmp_path, alphas: list[float], taus: list[float]): - softmaxAC = DefinitionPart("softmaxAC-mc", base=str(tmp_path)) + # add properties to sweep softmaxAC.add_sweepable_property("alpha", alphas) softmaxAC.add_sweepable_property("tau", taus) + + # add properties that are static softmaxAC.add_property("n_step", 1) softmaxAC.add_property("tiles", 4) softmaxAC.add_property("tilings", 16) softmaxAC.add_property("total_steps", 100000) softmaxAC.add_property("episode_cutoff", 5000) + softmaxAC.add_property("seed", 10) + + # write the properties to the database softmaxAC.commit() + return softmaxAC + + +## TODO: remove this and use the actual scheduler +# overwrite the run_single function +class StubScheduler(Scheduler): -def test_read_configs(tmp_path): + # allows us to force the results path to be in a specific spot + def __init__(self, results_path: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self.results_path = results_path + + # adding the results path to the command + def _run_single(self: Scheduler, r: RunSpec) -> None: + subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version), '--results_path', self.results_path]) + + +def test_read_database(tmp_path): """ Test that we can retrieve the configurations from the experiment definition. """ # expected outputs + results_path = os.path.join(tmp_path, "results", "temp") alphas = [0.05, 0.01] taus = [10.0, 20.0, 5.0] partial_configs = ( @@ -31,6 +62,7 @@ def test_read_configs(tmp_path): "tilings": 16, "total_steps": 100000, "episode_cutoff": 5000, + "seed": 10, } for a in alphas for t in taus @@ -44,10 +76,14 @@ def test_read_configs(tmp_path): ) # write experiment definition to table - init_softmaxAC_mc(tmp_path, alphas, taus) + write_database(results_path, alphas, taus) # make Experiment object (versions start at 0) - softmaxAC_mc = ExperimentDefinition(part_name="softmaxAC-mc", version=0, base=str(tmp_path)) + softmaxAC_mc = ExperimentDefinition( + part_name="softmaxAC", version=0 + ) + # TODO: don't overwrite this + softmaxAC_mc.get_results_path = lambda *args, **kwargs: results_path # TODO: This can't be the intended way to get the configurations? num_configs = len(alphas) * len(taus) @@ -56,3 +92,46 @@ def test_read_configs(tmp_path): for cid, expected_config in zip(config_ids, expected_configs, strict=True): config = softmaxAC_mc.get_config(cid) assert config == expected_config + + +def test_run_tasks(tmp_path): + """Make sure that the scheduler runs all the tasks, and that they return the correct results.""" + # setup + alphas = [0.05, 0.01] + taus = [10.0, 20.0, 5.0] + results_path = os.path.join(tmp_path, "results", "temp") + db = write_database(results_path, alphas, taus) + + assert db.name == "softmaxAC" + assert os.path.exists(os.path.join(results_path, "metadata.db")) + + # get number of tasks to run in parallel + try: + import multiprocessing + + ntasks = multiprocessing.cpu_count() - 1 + except (ImportError, NotImplementedError): + ntasks = 1 + + # initialize run config + run_conf = LocalRunConfig(tasks_in_parallel=ntasks, log_path=".logs/") + + # set up experiment file + experiment_file_name = "tests/acceptance/my_experiment.py" + + # set up scheduler + sched = StubScheduler( + exp_name="temp", + entry=experiment_file_name, + seeds=[10], + version=0, + results_path=results_path, + base = str(tmp_path), + ) + + # run all the tasks + ( + sched + .get_all_runs() + .run(run_conf) + )