Skip to content

Commit

Permalink
test: implement acceptance test using scheduler
Browse files Browse the repository at this point in the history
This test includes some awkward passing of the results path; this needs to be fixed outside of the test.

Towards: #8
  • Loading branch information
yasuiniko committed Oct 30, 2024
1 parent fd1882e commit ebc0acf
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 5 deletions.
68 changes: 68 additions & 0 deletions tests/acceptance/my_experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import argparse
import random

from ml_experiment.ExperimentDefinition import ExperimentDefinition

# Command-line interface shared by the scheduler's subprocess invocation.
parser = argparse.ArgumentParser()
for flag, value_type in (
    ("--part", str),
    ("--config-id", int),
    ("--seed", int),
    ("--version", int),
):
    parser.add_argument(flag, type=value_type, required=True)

# Extra argument used only for testing: it lets the experiment script
# locate the metadata database written by the test.
# TODO: remove this
parser.add_argument("--results_path", type=str, required=True)

class SoftmaxAC:
    """Dummy agent standing in for a real SoftmaxAC implementation.

    It only stores its hyper-parameters and reports them back as a
    string via :meth:`run`, which lets the acceptance test check that
    the scheduler delivered the right configuration.
    """

    def __init__(
        self,
        alpha: float,
        tau: float,
        nstep: float,
        tiles: int,
        tilings: int,
    ) -> None:
        self.name = "SoftmaxAC"
        self.alpha = alpha
        self.tau = tau
        self.nstep = nstep
        self.tiles = tiles
        self.tilings = tilings

    def run(self) -> str:
        """Return ``"SoftmaxAC(alpha, tau, nstep, tiles, tilings)"``."""
        params = ", ".join(
            str(value)
            for value in (self.alpha, self.tau, self.nstep, self.tiles, self.tilings)
        )
        return f"{self.name}({params})"


def main():
    """Entry point for a scheduler-launched run.

    Reads the run identity from the command line, pulls the matching
    config out of the experiment database, builds a dummy agent from
    it, and asserts the agent reports exactly those hyper-parameters.
    """
    args = parser.parse_args()

    # this script only knows how to run the softmaxAC part
    if args.part != "softmaxAC":
        raise ValueError(f"Unknown part: {args.part}")

    # seed the rng for reproducibility
    random.seed(args.seed)

    # extract configs from the database
    exp = ExperimentDefinition("softmaxAC", args.version)
    # TODO: don't overwrite this
    exp.get_results_path = lambda *a, **kw: args.results_path  # overwrite results path
    config = exp.get_config(args.config_id)

    # build the dummy agent from the stored hyper-parameters
    hypers = [config[key] for key in ("alpha", "tau", "n_step", "tiles", "tilings")]
    agent = SoftmaxAC(*hypers)

    # run the agent and check that it echoes the config back
    expected = "SoftmaxAC({}, {}, {}, {}, {})".format(*hypers)
    assert agent.run() == expected


if __name__ == "__main__":
    main()
89 changes: 84 additions & 5 deletions tests/acceptance/test_softmaxAC_mc.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,56 @@
import os
import subprocess

from ml_experiment.ExperimentDefinition import ExperimentDefinition
from ml_experiment.DefinitionPart import DefinitionPart
from ml_experiment.Scheduler import LocalRunConfig, RunSpec, Scheduler


def write_database(results_path, alphas: list[float], taus: list[float]):
# make table writer
softmaxAC = DefinitionPart("softmaxAC")
# TODO: don't overwrite this
softmaxAC.get_results_path = lambda *args, **kwargs: results_path

def init_softmaxAC_mc(tmp_path, alphas: list[float], taus: list[float]):
softmaxAC = DefinitionPart("softmaxAC-mc", base=str(tmp_path))
# add properties to sweep
softmaxAC.add_sweepable_property("alpha", alphas)
softmaxAC.add_sweepable_property("tau", taus)

# add properties that are static
softmaxAC.add_property("n_step", 1)
softmaxAC.add_property("tiles", 4)
softmaxAC.add_property("tilings", 16)
softmaxAC.add_property("total_steps", 100000)
softmaxAC.add_property("episode_cutoff", 5000)
softmaxAC.add_property("seed", 10)

# write the properties to the database
softmaxAC.commit()

return softmaxAC


## TODO: remove this and use the actual scheduler
# overwrite the run_single function
class StubScheduler(Scheduler):

def test_read_configs(tmp_path):
# allows us to force the results path to be in a specific spot
def __init__(self, results_path: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.results_path = results_path

# adding the results path to the command
def _run_single(self: Scheduler, r: RunSpec) -> None:
subprocess.run(['python', self.entry, '--part', r.part_name, '--config-id', str(r.config_id), '--seed', str(r.seed), '--version', str(r.version), '--results_path', self.results_path])


def test_read_database(tmp_path):
"""
Test that we can retrieve the configurations from the experiment definition.
"""

# expected outputs
results_path = os.path.join(tmp_path, "results", "temp")
alphas = [0.05, 0.01]
taus = [10.0, 20.0, 5.0]
partial_configs = (
Expand All @@ -31,6 +62,7 @@ def test_read_configs(tmp_path):
"tilings": 16,
"total_steps": 100000,
"episode_cutoff": 5000,
"seed": 10,
}
for a in alphas
for t in taus
Expand All @@ -44,10 +76,14 @@ def test_read_configs(tmp_path):
)

# write experiment definition to table
init_softmaxAC_mc(tmp_path, alphas, taus)
write_database(results_path, alphas, taus)

# make Experiment object (versions start at 0)
softmaxAC_mc = ExperimentDefinition(part_name="softmaxAC-mc", version=0, base=str(tmp_path))
softmaxAC_mc = ExperimentDefinition(
part_name="softmaxAC", version=0
)
# TODO: don't overwrite this
softmaxAC_mc.get_results_path = lambda *args, **kwargs: results_path

# TODO: This can't be the intended way to get the configurations?
num_configs = len(alphas) * len(taus)
Expand All @@ -56,3 +92,46 @@ def test_read_configs(tmp_path):
for cid, expected_config in zip(config_ids, expected_configs, strict=True):
config = softmaxAC_mc.get_config(cid)
assert config == expected_config


def test_run_tasks(tmp_path):
    """Make sure that the scheduler runs all the tasks, and that they return the correct results."""
    # sweep definition written into a temporary database
    alphas = [0.05, 0.01]
    taus = [10.0, 20.0, 5.0]
    results_path = os.path.join(tmp_path, "results", "temp")
    db = write_database(results_path, alphas, taus)

    # sanity-check that the definition landed where we expect
    assert db.name == "softmaxAC"
    assert os.path.exists(os.path.join(results_path, "metadata.db"))

    # get number of tasks to run in parallel
    try:
        import multiprocessing

        ntasks = multiprocessing.cpu_count() - 1
    except (ImportError, NotImplementedError):
        ntasks = 1

    # initialize run config
    run_config = LocalRunConfig(tasks_in_parallel=ntasks, log_path=".logs/")

    # the experiment script the scheduler will invoke per run
    entry_point = "tests/acceptance/my_experiment.py"

    # set up the stubbed scheduler (forces the results path)
    scheduler = StubScheduler(
        exp_name="temp",
        entry=entry_point,
        seeds=[10],
        version=0,
        results_path=results_path,
        base=str(tmp_path),
    )

    # enumerate every run and execute them all
    all_runs = scheduler.get_all_runs()
    all_runs.run(run_config)

0 comments on commit ebc0acf

Please sign in to comment.