Skip to content

Commit

Permalink
test: add file io to confirm the correct dummy experiments are being run
Browse files Browse the repository at this point in the history
  • Loading branch information
yasuiniko committed Oct 29, 2024
1 parent 977e1ba commit 45f9124
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 17 deletions.
11 changes: 5 additions & 6 deletions tests/acceptance/my_experiment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os
import random

from ml_experiment.ExperimentDefinition import ExperimentDefinition
Expand All @@ -8,10 +9,6 @@
parser.add_argument("--config-id", type=int, required=True)
parser.add_argument("--seed", type=int, required=True)
parser.add_argument("--version", type=int, required=True)

# this is an extra argument for testing
# we need to be able to find the metadata database
# TODO: remove this
parser.add_argument("--results_path", type=str, required=True)

class SoftmaxAC:
Expand Down Expand Up @@ -61,8 +58,10 @@ def main():
# run the agent
output = agent.run()

# test output
assert output == f"SoftmaxAC({alpha}, {tau}, {n_step}, {tiles}, {tilings})"
# write the output to a file
output_path = os.path.join(cmdline.results_path, f"output_{cmdline.config_id}.txt")
with open(output_path, "w") as f:
f.write(output)

if __name__ == "__main__":
main()
66 changes: 55 additions & 11 deletions tests/acceptance/test_softmaxAC_mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
def write_database(results_path, alphas: list[float], taus: list[float]):
# make table writer
softmaxAC = DefinitionPart("softmaxAC")
# TODO: don't overwrite this

# overwrite the results path to be in the temp directory
softmaxAC.get_results_path = lambda *args, **kwargs: results_path

# add properties to sweep
Expand All @@ -30,7 +31,6 @@ def write_database(results_path, alphas: list[float], taus: list[float]):
return softmaxAC


## TODO: remove this and use the actual scheduler
# overwrite the run_single function
class StubScheduler(Scheduler):

Expand Down Expand Up @@ -82,10 +82,11 @@ def test_read_database(tmp_path):
softmaxAC_mc = ExperimentDefinition(
part_name="softmaxAC", version=0
)
# TODO: don't overwrite this

# overwrite the results path to be in the temp directory
softmaxAC_mc.get_results_path = lambda *args, **kwargs: results_path

# TODO: This can't be the intended way to get the configurations?
# get the configuration ids
num_configs = len(alphas) * len(taus)
config_ids = list(range(num_configs))

Expand All @@ -99,6 +100,27 @@ def test_run_tasks(tmp_path):
# setup
alphas = [0.05, 0.01]
taus = [10.0, 20.0, 5.0]
seed_num = 10
version_num = 0

# expected outputs
partial_configs = (
{
"alpha": a,
"tau": t,
"n_step": 1,
"tiles": 4,
"tilings": 16,
"total_steps": 100000,
"episode_cutoff": 5000,
"seed": 10,
}
for a in alphas
for t in taus
)
expected_configs = {i : config for i, config in enumerate(partial_configs)}

# write experiment definition to table
results_path = os.path.join(tmp_path, "results", "temp")
db = write_database(results_path, alphas, taus)

Expand All @@ -123,15 +145,37 @@ def test_run_tasks(tmp_path):
sched = StubScheduler(
exp_name="temp",
entry=experiment_file_name,
seeds=[10],
version=0,
seeds=[seed_num],
version=version_num,
results_path=results_path,
base = str(tmp_path),
)

# run all the tasks
(
sched
.get_all_runs()
.run(run_conf)
)
sched = sched.get_all_runs()
sched.run(run_conf)

# make sure there are the correct amount of runs
assert len(sched.all_runs) == len(expected_configs.keys())

# check that the output files were created
for runspec in sched.all_runs:

# sanity check: make sure the runspec uses the hardcoded part, version, and seed
assert runspec.part_name == "softmaxAC"
assert runspec.version == version_num
assert runspec.seed == seed_num

# get the expected output
expected_config = expected_configs[runspec.config_id]
expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})"

# check that the output file was created
output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt")
with open(output_path, "r") as f:
output = f.read()
assert output.strip() == expected_output




0 comments on commit 45f9124

Please sign in to comment.