-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: adding tests for experiment definition class
- Loading branch information
1 parent
b95c5b4
commit e3f82a8
Showing
1 changed file
with
62 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import os | ||
|
||
from ml_experiment.DefinitionPart import DefinitionPart | ||
from ml_experiment.ExperimentDefinition import ExperimentDefinition | ||
|
||
|
||
def test_ExperimentDefinition(): | ||
|
||
# build dummy experiment | ||
exp_name = 'dummy_experiment' | ||
part_name = 'qrc' | ||
|
||
part = stubbed_DefinitionPart(exp_name, part_name) | ||
part.add_sweepable_property('alpha', (2**-i for i in range(3, 8))) | ||
part.add_sweepable_property('beta', [0.5, 1.0, 2.0]) | ||
part.commit() | ||
|
||
|
||
# load experiment definition | ||
version = 0 | ||
config_ids = [1, 2, 3] | ||
seeds = [1, 2] | ||
|
||
exp = stubbed_ExperimentDefinition(exp_name, part_name, version) | ||
|
||
config = exp.get_config(0) | ||
assert config == {'alpha': 0.125, 'beta': 0.5, 'id': 0} | ||
|
||
configs = exp.get_configs(config_ids) | ||
assert configs == [ | ||
{'alpha': 0.125, 'beta': 1.0, 'id': 1}, | ||
{'alpha': 0.125, 'beta': 2.0, 'id': 2}, | ||
{'alpha': 0.0625, 'beta': 0.5, 'id': 3}, | ||
] | ||
|
||
|
||
configs_and_seeds = exp.get_configs(config_ids, product_seeds=seeds) | ||
assert configs_and_seeds == [ | ||
{'alpha': 0.125, 'beta': 1.0, 'id': 1, 'seed': 1}, | ||
{'alpha': 0.125, 'beta': 1.0, 'id': 1, 'seed': 2}, | ||
{'alpha': 0.125, 'beta': 2.0, 'id': 2, 'seed': 1}, | ||
{'alpha': 0.125, 'beta': 2.0, 'id': 2, 'seed': 2}, | ||
{'alpha': 0.0625, 'beta': 0.5, 'id': 3, 'seed': 1}, | ||
{'alpha': 0.0625, 'beta': 0.5, 'id': 3, 'seed': 2}, | ||
] | ||
|
||
|
||
class stubbed_DefinitionPart(DefinitionPart): | ||
def __init__(self, exp_name: str, name: str, base: str | None = None): | ||
self.exp_name = exp_name | ||
super().__init__(name, base) | ||
|
||
def get_results_path(self) -> str: | ||
return os.path.join(self.base_path, 'results', self.exp_name) | ||
|
||
class stubbed_ExperimentDefinition(ExperimentDefinition): | ||
def __init__(self, exp_name: str, part_name: str, version: int, base: str | None = None): | ||
self.exp_name = exp_name | ||
super().__init__(part_name, version, base) | ||
|
||
def get_results_path(self) -> str: | ||
return os.path.join(self.base_path, 'results', self.exp_name) |