Skip to content

Commit

Permalink
test: add filtering to the acceptance test
Browse files Browse the repository at this point in the history
  • Loading branch information
yasuiniko committed Nov 5, 2024
1 parent 0c7b93b commit 5895dd8
Showing 1 changed file with 49 additions and 15 deletions.
64 changes: 49 additions & 15 deletions tests/acceptance/test_acceptance.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def test_run_tasks(tmp_path):
for a in alphas
for t in taus
)
expected_configs = {i : config for i, config in enumerate(partial_configs)}
# filter out the tau == 5 configs
configs_with_tau5 = {i : config for i, config in enumerate(partial_configs)}
expected_configs = {k:v for k,v in configs_with_tau5.items() if v["tau"] != 5}

# set experiment file name
experiment_file_name = f"tests/{exp_name}/my_experiment.py"
Expand All @@ -114,9 +116,9 @@ def test_run_tasks(tmp_path):
results_path = os.path.join(tmp_path, "results", f"{exp_name}")

# write experiment definition to table
db = write_database(tmp_path, alphas, taus)
db_writer = write_database(tmp_path, alphas, taus)

assert db.name == DATABASE_NAME
assert db_writer.name == DATABASE_NAME
assert os.path.exists(os.path.join(results_path, "metadata.db"))

# get number of tasks to run in parallel
Expand All @@ -139,30 +141,62 @@ def test_run_tasks(tmp_path):
base = str(tmp_path),
)

# run all the tasks
sched = sched.get_all_runs()
def pred(part_name, version, config_id, seed) -> bool:
"""Fn to filter out tasks where the output is True.
In this case we check whether tau = 5.
Note: The ExperimentDefinition base path is coming from the closure.
Note: We have to make an ExperimentDefinition here to get the config values, or pass one in via closure.
Args:
part_name (str): part name
version (int): version number
config_id (int): configuration id
seed (int): seed
Returns:
bool: True if tau == 5, False otherwise
"""
# make database reader to confirm config values
db_reader = ExperimentDefinition(part_name, version, base=str(tmp_path))
config = db_reader.get_config(config_id)
return config["tau"] == 5

# get all the runs, including tau == 5
all_sched = sched.get_all_runs()

# run the filtered tasks
sched = all_sched.filter(pred)
sched.run(run_conf)

# make sure there are the correct amount of runs
assert len(sched.all_runs) == len(expected_configs.keys())

# check that the output files were created
for runspec in sched.all_runs:
for runspec in all_sched.all_runs:

# sanity check: make sure the runspec uses the hardcoded part, version, and seed
assert runspec.part_name == DATABASE_NAME
assert runspec.version == version_num
assert runspec.seed == seed_num

# get the expected output
expected_config = expected_configs[runspec.config_id]
expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})"

# check that the output file was created
output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt")
with open(output_path, "r") as f:
output = f.read()
assert output.strip() == expected_output
# in this case we filtered out the runs
if pred(*runspec):
# check that the output file was not created
output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt")
assert not os.path.exists(output_path)

# in this case we did not filter out the runs, and we should have an output file
else:
# get the expected output
expected_config = expected_configs[runspec.config_id]
expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})"

# check that the output file was created
output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt")
with open(output_path, "r") as f:
output = f.read()
assert output.strip() == expected_output



Expand Down

0 comments on commit 5895dd8

Please sign in to comment.