diff --git a/tests/acceptance/my_experiment.py b/tests/acceptance/my_experiment.py index 27637b7..d956dc1 100644 --- a/tests/acceptance/my_experiment.py +++ b/tests/acceptance/my_experiment.py @@ -1,4 +1,5 @@ import argparse +import os import random from ml_experiment.ExperimentDefinition import ExperimentDefinition @@ -8,10 +9,6 @@ parser.add_argument("--config-id", type=int, required=True) parser.add_argument("--seed", type=int, required=True) parser.add_argument("--version", type=int, required=True) - -# this is an extra argument for testing -# we need to be able to find the metadata database -# TODO: remove this parser.add_argument("--results_path", type=str, required=True) class SoftmaxAC: @@ -61,8 +58,10 @@ def main(): # run the agent output = agent.run() - # test output - assert output == f"SoftmaxAC({alpha}, {tau}, {n_step}, {tiles}, {tilings})" + # write the output to a file + output_path = os.path.join(cmdline.results_path, f"output_{cmdline.config_id}.txt") + with open(output_path, "w") as f: + f.write(output) if __name__ == "__main__": main() diff --git a/tests/acceptance/test_softmaxAC_mc.py b/tests/acceptance/test_softmaxAC_mc.py index 7566f77..1bf30c4 100644 --- a/tests/acceptance/test_softmaxAC_mc.py +++ b/tests/acceptance/test_softmaxAC_mc.py @@ -9,7 +9,8 @@ def write_database(results_path, alphas: list[float], taus: list[float]): # make table writer softmaxAC = DefinitionPart("softmaxAC") - # TODO: don't overwrite this + + # overwrite the results path to be in the temp directory softmaxAC.get_results_path = lambda *args, **kwargs: results_path # add properties to sweep @@ -30,7 +31,6 @@ def write_database(results_path, alphas: list[float], taus: list[float]): return softmaxAC -## TODO: remove this and use the actual scheduler # overwrite the run_single function class StubScheduler(Scheduler): @@ -82,10 +82,11 @@ def test_read_database(tmp_path): softmaxAC_mc = ExperimentDefinition( part_name="softmaxAC", version=0 ) - # TODO: don't overwrite this + + # overwrite the results path to be in the temp directory softmaxAC_mc.get_results_path = lambda *args, **kwargs: results_path - # TODO: This can't be the intended way to get the configurations? + # get the configuration ids num_configs = len(alphas) * len(taus) config_ids = list(range(num_configs)) @@ -99,6 +100,27 @@ def test_run_tasks(tmp_path): # setup alphas = [0.05, 0.01] taus = [10.0, 20.0, 5.0] + seed_num = 10 + version_num = 0 + + # expected outputs + partial_configs = ( + { + "alpha": a, + "tau": t, + "n_step": 1, + "tiles": 4, + "tilings": 16, + "total_steps": 100000, + "episode_cutoff": 5000, + "seed": 10, + } + for a in alphas + for t in taus + ) + expected_configs = {i : config for i, config in enumerate(partial_configs)} + + # write experiment definition to table results_path = os.path.join(tmp_path, "results", "temp") db = write_database(results_path, alphas, taus) @@ -123,15 +145,37 @@ def test_run_tasks(tmp_path): sched = StubScheduler( exp_name="temp", entry=experiment_file_name, - seeds=[10], - version=0, + seeds=[seed_num], + version=version_num, results_path=results_path, base = str(tmp_path), ) # run all the tasks - ( - sched - .get_all_runs() - .run(run_conf) - ) + sched = sched.get_all_runs() + sched.run(run_conf) + + # make sure there are the correct amount of runs + assert len(sched.all_runs) == len(expected_configs.keys()) + + # check that the output files were created + for runspec in sched.all_runs: + + # sanity check: make sure the runspec uses the hardcoded part, version, and seed + assert runspec.part_name == "softmaxAC" + assert runspec.version == version_num + assert runspec.seed == seed_num + + # get the expected output + expected_config = expected_configs[runspec.config_id] + expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})" + + # check that the output file was created + output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt") + with open(output_path, "r") as f: + output = f.read() + assert output.strip() == expected_output + + + +