From 5895dd809f8cb1650c1d9e96b770077a4bff4c5d Mon Sep 17 00:00:00 2001 From: Niko Yasui Date: Tue, 5 Nov 2024 15:27:39 -0700 Subject: [PATCH] test: add filtering to the acceptance test --- tests/acceptance/test_acceptance.py | 64 ++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/tests/acceptance/test_acceptance.py b/tests/acceptance/test_acceptance.py index 16309c4..7d0deb0 100644 --- a/tests/acceptance/test_acceptance.py +++ b/tests/acceptance/test_acceptance.py @@ -105,7 +105,9 @@ def test_run_tasks(tmp_path): for a in alphas for t in taus ) - expected_configs = {i : config for i, config in enumerate(partial_configs)} + # filter out the tau == 5 configs + configs_with_tau5 = {i : config for i, config in enumerate(partial_configs)} + expected_configs = {k:v for k,v in configs_with_tau5.items() if v["tau"] != 5} # set experiment file name experiment_file_name = f"tests/{exp_name}/my_experiment.py" @@ -114,9 +116,9 @@ def test_run_tasks(tmp_path): results_path = os.path.join(tmp_path, "results", f"{exp_name}") # write experiment definition to table - db = write_database(tmp_path, alphas, taus) + db_writer = write_database(tmp_path, alphas, taus) - assert db.name == DATABASE_NAME + assert db_writer.name == DATABASE_NAME assert os.path.exists(os.path.join(results_path, "metadata.db")) # get number of tasks to run in parallel @@ -139,30 +141,62 @@ def test_run_tasks(tmp_path): base = str(tmp_path), ) - # run all the tasks - sched = sched.get_all_runs() + def pred(part_name, version, config_id, seed) -> bool: + """Fn to filter out tasks where the output is True. + In this case we check whether tau = 5. + + Note: The ExperimentDefinition base path is coming from the closure. + Note: We have to make an ExperimentDefinition here to get the config values, or pass one in via closure. + + Args: + part_name (str): part name + version (int): version number + config_id (int): configuration id + seed (int): seed + + Returns: + bool: True if tau == 5, False otherwise + """ + # make database reader to confirm config values + db_reader = ExperimentDefinition(part_name, version, base=str(tmp_path)) + config = db_reader.get_config(config_id) + return config["tau"] == 5 + + # get all the runs, including tau == 5 + all_sched = sched.get_all_runs() + + # run the filtered tasks + sched = all_sched.filter(pred) sched.run(run_conf) # make sure there are the correct amount of runs assert len(sched.all_runs) == len(expected_configs.keys()) # check that the output files were created - for runspec in sched.all_runs: + for runspec in all_sched.all_runs: # sanity check: make sure the runspec uses the hardcoded part, version, and seed assert runspec.part_name == DATABASE_NAME assert runspec.version == version_num assert runspec.seed == seed_num - # get the expected output - expected_config = expected_configs[runspec.config_id] - expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})" - - # check that the output file was created - output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt") - with open(output_path, "r") as f: - output = f.read() - assert output.strip() == expected_output + # in this case we filtered out the runs + if pred(*runspec): + # check that the output file was not created + output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt") + assert not os.path.exists(output_path) + + # in this case we did not filter out the runs, and we should have an output file + else: + # get the expected output + expected_config = expected_configs[runspec.config_id] + expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})" + + # check that the output file was created + output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt") + with open(output_path, "r") as f: + output = f.read() + assert output.strip() == expected_output