-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from panahiparham/acceptance
Beginning of acceptance test framework
- Loading branch information
Showing
7 changed files
with
307 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ authors = [ | |
{name = "Andy Patterson", email = "[email protected]"}, | ||
] | ||
dependencies = [] | ||
requires-python = ">=3.10,<3.13" | ||
requires-python = ">=3.11,<3.13" | ||
readme = "README.md" | ||
license = {text = "MIT"} | ||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import argparse | ||
import os | ||
import random | ||
|
||
from ml_experiment.ExperimentDefinition import ExperimentDefinition | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--part", type=str, required=True) | ||
parser.add_argument("--config-id", type=int, required=True) | ||
parser.add_argument("--seed", type=int, required=True) | ||
parser.add_argument("--version", type=int, required=True) | ||
parser.add_argument("--results_path", type=str, required=True) | ||
|
||
class SoftmaxAC: | ||
def __init__( | ||
self, | ||
alpha: float, | ||
tau: float, | ||
nstep: float, | ||
tiles: int, | ||
tilings: int, | ||
) -> None: | ||
self.alpha = alpha | ||
self.tau = tau | ||
self.nstep = nstep | ||
self.tiles = tiles | ||
self.tilings = tilings | ||
self.name = "SoftmaxAC" | ||
|
||
def run(self) -> str: | ||
return f"{self.name}({self.alpha}, {self.tau}, {self.nstep}, {self.tiles}, {self.tilings})" | ||
|
||
|
||
def main(): | ||
cmdline = parser.parse_args() | ||
|
||
# make sure we are using softmaxAC | ||
if cmdline.part != "softmaxAC": | ||
raise ValueError(f"Unknown part: {cmdline.part}") | ||
|
||
# do some rng control | ||
random.seed(cmdline.seed) | ||
|
||
# extract configs from the database | ||
exp = ExperimentDefinition("softmaxAC", cmdline.version) | ||
# TODO: don't overwrite this | ||
exp.get_results_path = lambda *args, **kwargs: cmdline.results_path # overwrite results path | ||
config = exp.get_config(cmdline.config_id) | ||
|
||
# make our dummy agent | ||
alpha = config["alpha"] | ||
tau = config["tau"] | ||
n_step = config["n_step"] | ||
tiles = config["tiles"] | ||
tilings = config["tilings"] | ||
agent = SoftmaxAC(alpha, tau, n_step, tiles, tilings) | ||
|
||
# run the agent | ||
output = agent.run() | ||
|
||
# write the output to a file | ||
output_path = os.path.join(cmdline.results_path, f"output_{cmdline.config_id}.txt") | ||
with open(output_path, "w") as f: | ||
f.write(output) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
import os | ||
import pytest | ||
import subprocess | ||
|
||
from ml_experiment.ExperimentDefinition import ExperimentDefinition | ||
from ml_experiment.DefinitionPart import DefinitionPart | ||
from ml_experiment.Scheduler import LocalRunConfig, RunSpec, Scheduler | ||
|
||
|
||
@pytest.fixture | ||
def base_path(request): | ||
"""Overwrite the __main__.__file__ to be the path to the current file. This allows _utils.get_experiment_name to look at ./tests/acceptance/this_file.py rather than ./.venv/bin/pytest.""" | ||
import __main__ | ||
__main__.__file__ = request.path.__fspath__() | ||
|
||
def write_database(tmp_path, alphas: list[float], taus: list[float]): | ||
# make table writer | ||
softmaxAC = DefinitionPart("softmaxAC", base=str(tmp_path)) | ||
|
||
# add properties to sweep | ||
softmaxAC.add_sweepable_property("alpha", alphas) | ||
softmaxAC.add_sweepable_property("tau", taus) | ||
|
||
# add properties that are static | ||
softmaxAC.add_property("n_step", 1) | ||
softmaxAC.add_property("tiles", 4) | ||
softmaxAC.add_property("tilings", 16) | ||
softmaxAC.add_property("total_steps", 100000) | ||
softmaxAC.add_property("episode_cutoff", 5000) | ||
softmaxAC.add_property("seed", 10) | ||
|
||
# write the properties to the database | ||
softmaxAC.commit() | ||
|
||
return softmaxAC | ||
|
||
|
||
# overwrite the run_single function | ||
class StubScheduler(Scheduler): | ||
|
||
# allows us to force the results path to be in a specific spot | ||
def __init__(self, results_path: str, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.results_path = results_path | ||
|
||
# adding the results path to the command | ||
def _run_single(self: Scheduler, r: RunSpec) -> None: | ||
subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version), '--results_path', self.results_path]) | ||
|
||
|
||
def test_read_database(tmp_path, base_path): | ||
""" | ||
Test that we can retrieve the configurations from the experiment definition. | ||
""" | ||
|
||
# expected outputs | ||
alphas = [0.05, 0.01] | ||
taus = [10.0, 20.0, 5.0] | ||
partial_configs = ( | ||
{ | ||
"alpha": a, | ||
"tau": t, | ||
"n_step": 1, | ||
"tiles": 4, | ||
"tilings": 16, | ||
"total_steps": 100000, | ||
"episode_cutoff": 5000, | ||
"seed": 10, | ||
} | ||
for a in alphas | ||
for t in taus | ||
) | ||
expected_configs = ( | ||
{ | ||
**config, | ||
"id": i, | ||
} | ||
for i, config in enumerate(partial_configs) | ||
) | ||
|
||
# write experiment definition to table | ||
write_database(tmp_path, alphas, taus) | ||
|
||
# make Experiment object (versions start at 0) | ||
softmaxAC_mc = ExperimentDefinition( | ||
part_name="softmaxAC", version=0, base=str(tmp_path) | ||
) | ||
|
||
# get the configuration ids | ||
num_configs = len(alphas) * len(taus) | ||
config_ids = list(range(num_configs)) | ||
|
||
for cid, expected_config in zip(config_ids, expected_configs, strict=True): | ||
config = softmaxAC_mc.get_config(cid) | ||
assert config == expected_config | ||
|
||
|
||
def test_run_tasks(tmp_path): | ||
"""Make sure that the scheduler runs all the tasks, and that they return the correct results.""" | ||
# setup | ||
alphas = [0.05, 0.01] | ||
taus = [10.0, 20.0, 5.0] | ||
seed_num = 10 | ||
version_num = 0 | ||
exp_name = "acceptance" | ||
|
||
# expected outputs | ||
partial_configs = ( | ||
{ | ||
"alpha": a, | ||
"tau": t, | ||
"n_step": 1, | ||
"tiles": 4, | ||
"tilings": 16, | ||
"total_steps": 100000, | ||
"episode_cutoff": 5000, | ||
"seed": 10, | ||
} | ||
for a in alphas | ||
for t in taus | ||
) | ||
expected_configs = {i : config for i, config in enumerate(partial_configs)} | ||
|
||
# set experiment file name | ||
experiment_file_name = f"tests/{exp_name}/my_experiment.py" | ||
|
||
# set results path | ||
results_path = os.path.join(tmp_path, "results", f"{exp_name}") | ||
|
||
# write experiment definition to table | ||
db = write_database(tmp_path, alphas, taus) | ||
|
||
assert db.name == "softmaxAC" | ||
assert os.path.exists(os.path.join(results_path, "metadata.db")) | ||
|
||
# get number of tasks to run in parallel | ||
try: | ||
import multiprocessing | ||
|
||
ntasks = multiprocessing.cpu_count() - 1 | ||
except (ImportError, NotImplementedError): | ||
ntasks = 1 | ||
|
||
# initialize run config | ||
run_conf = LocalRunConfig(tasks_in_parallel=ntasks, log_path=".logs/") | ||
|
||
# set up scheduler | ||
sched = StubScheduler( | ||
exp_name=exp_name, | ||
entry=experiment_file_name, | ||
seeds=[seed_num], | ||
version=version_num, | ||
results_path=results_path, | ||
base = str(tmp_path), | ||
) | ||
|
||
# run all the tasks | ||
sched = sched.get_all_runs() | ||
sched.run(run_conf) | ||
|
||
# make sure there are the correct amount of runs | ||
assert len(sched.all_runs) == len(expected_configs.keys()) | ||
|
||
# check that the output files were created | ||
for runspec in sched.all_runs: | ||
|
||
# sanity check: make sure the runspec uses the hardcoded part, version, and seed | ||
assert runspec.part_name == "softmaxAC" | ||
assert runspec.version == version_num | ||
assert runspec.seed == seed_num | ||
|
||
# get the expected output | ||
expected_config = expected_configs[runspec.config_id] | ||
expected_output = f"SoftmaxAC({expected_config['alpha']}, {expected_config['tau']}, {expected_config['n_step']}, {expected_config['tiles']}, {expected_config['tilings']})" | ||
|
||
# check that the output file was created | ||
output_path = os.path.join(results_path, f"output_{runspec.config_id}.txt") | ||
with open(output_path, "r") as f: | ||
output = f.read() | ||
assert output.strip() == expected_output | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from ml_experiment import DefinitionPart as dp | ||
|
||
|
||
def init_esarsa_mc(tmp_path, alphas: list[float], epsilons: list[float], n_steps: list[int]): | ||
|
||
esarsa = dp.DefinitionPart("esarsa-mc", base=str(tmp_path)) | ||
esarsa.add_sweepable_property("alpha", alphas) | ||
esarsa.add_sweepable_property("epsilon", epsilons) | ||
esarsa.add_sweepable_property("n_step", n_steps) | ||
esarsa.add_property("tiles", 4) | ||
esarsa.add_property("tilings", 16) | ||
esarsa.add_property("total_steps", 100000) | ||
esarsa.add_property("episode_cutoff", 5000) | ||
esarsa.commit() | ||
|
||
return esarsa | ||
|
||
|
||
def test_generate_configurations(tmp_path): | ||
""" | ||
Tests that the dp.generate_configurations function returns the same configurations as the ones written by the dp.DefinitionPart.commit function. | ||
Note: configs do not have ID numbers | ||
""" | ||
|
||
# expected outputs | ||
alphas = [0.5, 0.25, 0.125] | ||
epsilons = [0.1, 0.05, 0.15] | ||
n_steps = [2, 3] | ||
expected_configs = ( | ||
{ | ||
"alpha": a, | ||
"epsilon": e, | ||
"n_step": n, | ||
"tiles": 4, | ||
"tilings": 16, | ||
"total_steps": 100000, | ||
"episode_cutoff": 5000, | ||
} | ||
for a in alphas | ||
for e in epsilons | ||
for n in n_steps | ||
) | ||
|
||
# write experiment definition to table | ||
esarsa_mc = init_esarsa_mc(tmp_path, alphas, epsilons, n_steps) | ||
|
||
# get all the hyperparameter configurations | ||
configs = dp.generate_configurations(esarsa_mc._properties) | ||
|
||
for config, expected_config in zip(configs, expected_configs, strict=True): | ||
assert config == expected_config |