Add filter to the acceptance tests #26

Draft · wants to merge 3 commits into main
tests/acceptance/my_experiment.py (6 changes: 1 addition & 5 deletions)
```diff
@@ -34,15 +34,11 @@ def run(self) -> str:
 def main():
     cmdline = parser.parse_args()
 
-    # make sure we are using softmaxAC
-    if cmdline.part != "softmaxAC":
-        raise ValueError(f"Unknown part: {cmdline.part}")
-
     # do some rng control
     random.seed(cmdline.seed)
 
     # extract configs from the database
-    exp = ExperimentDefinition("softmaxAC", cmdline.version)
+    exp = ExperimentDefinition(cmdline.part, cmdline.version)
     # TODO: don't overwrite this
     exp.get_results_path = lambda *args, **kwargs: cmdline.results_path  # overwrite results path
     config = exp.get_config(cmdline.config_id)
```
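With the hardcoded `"softmaxAC"` check removed, `main()` passes `cmdline.part` straight through, so any part name stored in the database can be run. The argument parser itself sits above this hunk; a minimal sketch of the arguments the code appears to expect (the flag names here are assumptions, not shown in the diff):

```python
# Hypothetical parser sketch: my_experiment.py reads part, version, config_id,
# seed, and results_path from parse_args(). Flag names are guesses.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--part", type=str)          # any part name in the database, e.g. "softmaxAC"
parser.add_argument("--version", type=int)       # experiment definition version
parser.add_argument("--config_id", type=int)     # which swept configuration to run
parser.add_argument("--seed", type=int)          # value handed to random.seed(...)
parser.add_argument("--results_path", type=str)  # overrides ExperimentDefinition.get_results_path
```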
```diff
@@ -5,33 +5,34 @@
 from ml_experiment.experiment_definition import ExperimentDefinition
 from ml_experiment.Scheduler import LocalRunConfig, Scheduler
 
+DATABASE_NAME = "MyHyperparameterDatabase"
 
 @pytest.fixture
 def base_path(request):
-    """Overwrite the __main__.__file__ to be the path to the current file. This allows _utils.get_experiment_name to look at ./tests/acceptance/this_file.py rather than ./.venv/bin/pytest."""
+    """Overwrite the __main__.__file__ to be the path to the current file. This allows _utils.get_experiment_name to look at ./tests/acceptance/this_file.py rather than the default of ./.venv/bin/pytest."""
     import __main__
     __main__.__file__ = request.path.__fspath__()
 
 def write_database(tmp_path, alphas: list[float], taus: list[float]):
-    # make table writer
-    softmaxAC = DefinitionPart("softmaxAC", base=str(tmp_path))
+    # make database writer
+    db_writer = DefinitionPart(DATABASE_NAME, base=str(tmp_path))
 
     # add properties to sweep
-    softmaxAC.add_sweepable_property("alpha", alphas)
-    softmaxAC.add_sweepable_property("tau", taus)
+    db_writer.add_sweepable_property("alpha", alphas)
+    db_writer.add_sweepable_property("tau", taus)
 
     # add properties that are static
-    softmaxAC.add_property("n_step", 1)
-    softmaxAC.add_property("tiles", 4)
-    softmaxAC.add_property("tilings", 16)
-    softmaxAC.add_property("total_steps", 100000)
-    softmaxAC.add_property("episode_cutoff", 5000)
-    softmaxAC.add_property("seed", 10)
+    db_writer.add_property("n_step", 1)
+    db_writer.add_property("tiles", 4)
+    db_writer.add_property("tilings", 16)
+    db_writer.add_property("total_steps", 100000)
+    db_writer.add_property("episode_cutoff", 5000)
+    db_writer.add_property("seed", 10)
 
     # write the properties to the database
-    softmaxAC.commit()
+    db_writer.commit()
 
-    return softmaxAC
+    return db_writer
 
 def test_read_database(tmp_path, base_path):
     """
```
```diff
@@ -68,7 +69,7 @@ def test_read_database(tmp_path, base_path):
 
     # make Experiment object (versions start at 0)
     softmaxAC_mc = ExperimentDefinition(
-        part_name="softmaxAC", version=0, base=str(tmp_path)
+        part_name=DATABASE_NAME, version=0, base=str(tmp_path)
     )
 
     # get the configuration ids
```
```diff
@@ -104,7 +105,9 @@ def test_run_tasks(tmp_path):
         for a in alphas
         for t in taus
     )
-    expected_configs = {i : config for i, config in enumerate(partial_configs)}
+    # filter out the tau == 5 configs
+    configs_with_tau5 = {i : config for i, config in enumerate(partial_configs)}
+    expected_configs = {k: v for k, v in configs_with_tau5.items() if v["tau"] != 5}
 
     # set experiment file name
     experiment_file_name = f"tests/{exp_name}/my_experiment.py"
```
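Config ids matter here: `enumerate` numbers the configurations in the order the nested generator yields them (`alpha`-major, `tau`-minor), and the filtered `expected_configs` dict keeps the original ids rather than renumbering. A self-contained worked example with hypothetical sweep values:

```python
# Hypothetical sweep values for illustration; the real alphas/taus are
# defined earlier in the test, outside this hunk.
alphas = [0.05, 0.01]
taus = [10.0, 5.0]

partial_configs = ({"alpha": a, "tau": t} for a in alphas for t in taus)
configs_with_tau5 = {i: config for i, config in enumerate(partial_configs)}
# id 0 -> (0.05, 10.0), id 1 -> (0.05, 5.0), id 2 -> (0.01, 10.0), id 3 -> (0.01, 5.0)

expected_configs = {k: v for k, v in configs_with_tau5.items() if v["tau"] != 5}
assert sorted(expected_configs) == [0, 2]  # the tau == 5 entries (ids 1 and 3) drop out
```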
```diff
@@ -113,9 +116,9 @@ def test_run_tasks(tmp_path):
     results_path = os.path.join(tmp_path, "results", f"{exp_name}")
 
     # write experiment definition to table
-    db = write_database(tmp_path, alphas, taus)
+    db_writer = write_database(tmp_path, alphas, taus)
 
-    assert db.name == "softmaxAC"
+    assert db_writer.name == DATABASE_NAME
     assert os.path.exists(os.path.join(results_path, "metadata.db"))
 
     # get number of tasks to run in parallel
```
```diff
@@ -138,30 +141,62 @@ def test_run_tasks(tmp_path):
         base = str(tmp_path),
     )
 
-    # run all the tasks
-    sched = sched.get_all_runs()
+    def pred(part_name, version, config_id, seed) -> bool:
+        """Predicate used to filter out tasks: any run for which this returns True is excluded.
+        In this case we check whether tau == 5.
+
+        Note: The ExperimentDefinition base path comes from the closure.
+        Note: We have to make an ExperimentDefinition here to get the config values, or pass one in via the closure.
+
+        Args:
+            part_name (str): part name
+            version (int): version number
+            config_id (int): configuration id
+            seed (int): seed
+
+        Returns:
+            bool: True if tau == 5, False otherwise
+        """
+        # make database reader to confirm config values
+        db_reader = ExperimentDefinition(part_name, version, base=str(tmp_path))
+        config = db_reader.get_config(config_id)
+        return config["tau"] == 5
+
+    # get all the runs, including tau == 5
+    all_sched = sched.get_all_runs()
+
+    # run the filtered tasks
+    sched = all_sched.filter(pred)
     sched.run(run_conf)
 
     # make sure there are the correct amount of runs
     assert len(sched.all_runs) == len(expected_configs.keys())
 
     # check that the output files were created
-    for runspec in sched.all_runs:
+    for runspec in all_sched.all_runs:
 
         # sanity check: make sure the runspec uses the hardcoded part, version, and seed
-        assert runspec.part_name == "softmaxAC"
+        assert runspec.part_name == DATABASE_NAME
         assert runspec.version == version_num
         assert runspec.seed == seed_num
 
-        # get the expected output
-        expected_config = expected_configs[runspec.config_id]
-        expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})"
-
-        # check that the output file was created
-        output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt")
-        with open(output_path, "r") as f:
-            output = f.read()
-        assert output.strip() == expected_output
+        # in this case we filtered out the run, so there should be no output file
+        if pred(*runspec):
+            # check that the output file was not created
+            output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt")
+            assert not os.path.exists(output_path)
+
+        # in this case we did not filter out the run, and we should have an output file
+        else:
+            # get the expected output
+            expected_config = expected_configs[runspec.config_id]
+            expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})"
+
+            # check that the output file was created
+            output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt")
+            with open(output_path, "r") as f:
+                output = f.read()
+            assert output.strip() == expected_output
```
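`Scheduler.filter` itself is not part of this diff, so the test pins down its contract instead: the predicate receives `(part_name, version, config_id, seed)`, and every run for which it returns True is excluded from the schedule. Under that assumption (and assuming runspecs unpack positionally, as `pred(*runspec)` implies), the behaviour amounts to roughly this sketch, not the library's actual implementation:

```python
def filter_runs(all_runs, pred):
    """Keep only the runspecs the predicate does NOT flag (a semantics sketch)."""
    return [runspec for runspec in all_runs if not pred(*runspec)]
```

Iterating over the unfiltered `all_sched.all_runs` while re-applying `pred` then lets the test assert both sides: filtered runs left no output file behind, and every surviving run wrote the expected one.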


