diff --git a/tests/acceptance/my_experiment.py b/tests/acceptance/my_experiment.py index 5588eb7..1dd0a5f 100644 --- a/tests/acceptance/my_experiment.py +++ b/tests/acceptance/my_experiment.py @@ -34,15 +34,11 @@ def run(self) -> str: def main(): cmdline = parser.parse_args() - # make sure we are using softmaxAC - if cmdline.part != "softmaxAC": - raise ValueError(f"Unknown part: {cmdline.part}") - # do some rng control random.seed(cmdline.seed) # extract configs from the database - exp = ExperimentDefinition("softmaxAC", cmdline.version) + exp = ExperimentDefinition(cmdline.part, cmdline.version) # TODO: don't overwrite this exp.get_results_path = lambda *args, **kwargs: cmdline.results_path # overwrite results path config = exp.get_config(cmdline.config_id) diff --git a/tests/acceptance/test_softmaxAC_mc.py b/tests/acceptance/test_acceptance.py similarity index 54% rename from tests/acceptance/test_softmaxAC_mc.py rename to tests/acceptance/test_acceptance.py index 9869bd3..7d0deb0 100644 --- a/tests/acceptance/test_softmaxAC_mc.py +++ b/tests/acceptance/test_acceptance.py @@ -5,33 +5,34 @@ from ml_experiment.experiment_definition import ExperimentDefinition from ml_experiment.Scheduler import LocalRunConfig, Scheduler +DATABASE_NAME = "MyHyperparameterDatabase" @pytest.fixture def base_path(request): - """Overwrite the __main__.__file__ to be the path to the current file. This allows _utils.get_experiment_name to look at ./tests/acceptance/this_file.py rather than ./.venv/bin/pytest.""" + """Overwrite the __main__.__file__ to be the path to the current file. This allows _utils.get_experiment_name to look at ./tests/acceptance/this_file.py rather than the default of ./.venv/bin/pytest.""" import __main__ __main__.__file__ = request.path.__fspath__() def write_database(tmp_path, alphas: list[float], taus: list[float]): - # make table writer - softmaxAC = DefinitionPart("softmaxAC", base=str(tmp_path)) + # make database writer + db_writer = DefinitionPart(DATABASE_NAME, base=str(tmp_path)) # add properties to sweep - softmaxAC.add_sweepable_property("alpha", alphas) - softmaxAC.add_sweepable_property("tau", taus) + db_writer.add_sweepable_property("alpha", alphas) + db_writer.add_sweepable_property("tau", taus) # add properties that are static - softmaxAC.add_property("n_step", 1) - softmaxAC.add_property("tiles", 4) - softmaxAC.add_property("tilings", 16) - softmaxAC.add_property("total_steps", 100000) - softmaxAC.add_property("episode_cutoff", 5000) - softmaxAC.add_property("seed", 10) + db_writer.add_property("n_step", 1) + db_writer.add_property("tiles", 4) + db_writer.add_property("tilings", 16) + db_writer.add_property("total_steps", 100000) + db_writer.add_property("episode_cutoff", 5000) + db_writer.add_property("seed", 10) # write the properties to the database - softmaxAC.commit() + db_writer.commit() - return softmaxAC + return db_writer def test_read_database(tmp_path, base_path): """ @@ -68,7 +69,7 @@ def test_read_database(tmp_path, base_path): # make Experiment object (versions start at 0) softmaxAC_mc = ExperimentDefinition( - part_name="softmaxAC", version=0, base=str(tmp_path) + part_name=DATABASE_NAME, version=0, base=str(tmp_path) ) # get the configuration ids @@ -104,7 +105,9 @@ def test_run_tasks(tmp_path): for a in alphas for t in taus ) - expected_configs = {i : config for i, config in enumerate(partial_configs)} + # filter out the tau == 5 configs + configs_with_tau5 = {i : config for i, config in enumerate(partial_configs)} + expected_configs = {k:v for k,v in configs_with_tau5.items() if v["tau"] != 5} # set experiment file name experiment_file_name = f"tests/{exp_name}/my_experiment.py" @@ -113,9 +116,9 @@ def test_run_tasks(tmp_path): results_path = os.path.join(tmp_path, "results", f"{exp_name}") # write experiment definition to table - db = write_database(tmp_path, alphas, taus) + db_writer = write_database(tmp_path, alphas, taus) - assert db.name == "softmaxAC" + assert db_writer.name == DATABASE_NAME assert os.path.exists(os.path.join(results_path, "metadata.db")) # get number of tasks to run in parallel @@ -138,30 +141,62 @@ def test_run_tasks(tmp_path): base = str(tmp_path), ) - # run all the tasks - sched = sched.get_all_runs() + def pred(part_name, version, config_id, seed) -> bool: + """Fn to filter out tasks where the output is True. + In this case we check whether tau = 5. + + Note: The ExperimentDefinition base path is coming from the closure. + Note: We have to make an ExperimentDefinition here to get the config values, or pass one in via closure. + + Args: + part_name (str): part name + version (int): version number + config_id (int): configuration id + seed (int): seed + + Returns: + bool: True if tau == 5, False otherwise + """ + # make database reader to confirm config values + db_reader = ExperimentDefinition(part_name, version, base=str(tmp_path)) + config = db_reader.get_config(config_id) + return config["tau"] == 5 + + # get all the runs, including tau == 5 + all_sched = sched.get_all_runs() + + # run the filtered tasks + sched = all_sched.filter(pred) sched.run(run_conf) # make sure there are the correct amount of runs assert len(sched.all_runs) == len(expected_configs.keys()) # check that the output files were created - for runspec in sched.all_runs: + for runspec in all_sched.all_runs: # sanity check: make sure the runspec uses the hardcoded part, version, and seed - assert runspec.part_name == "softmaxAC" + assert runspec.part_name == DATABASE_NAME assert runspec.version == version_num assert runspec.seed == seed_num - # get the expected output - expected_config = expected_configs[runspec.config_id] - expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})" - - # check that the output file was created - output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt") - with open(output_path, "r") as f: - output = f.read() - assert output.strip() == expected_output + # in this case we filtered out the runs + if pred(*runspec): + # check that the output file was not created + output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt") + assert not os.path.exists(output_path) + + # in this case we did not filter out the runs, and we should have an output file + else: + # get the expected output + expected_config = expected_configs[runspec.config_id] + expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})" + + # check that the output file was created + output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt") + with open(output_path, "r") as f: + output = f.read() + assert output.strip() == expected_output