Skip to content

Commit

Permalink
test: adding tests for experiment definition class
Browse files Browse the repository at this point in the history
  • Loading branch information
panahiparham committed Sep 24, 2024
1 parent b95c5b4 commit e3f82a8
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions tests/test_ExperimentDefinition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os

from ml_experiment.DefinitionPart import DefinitionPart
from ml_experiment.ExperimentDefinition import ExperimentDefinition


def test_ExperimentDefinition():

# build dummy experiment
exp_name = 'dummy_experiment'
part_name = 'qrc'

part = stubbed_DefinitionPart(exp_name, part_name)
part.add_sweepable_property('alpha', (2**-i for i in range(3, 8)))
part.add_sweepable_property('beta', [0.5, 1.0, 2.0])
part.commit()


# load experiment definition
version = 0
config_ids = [1, 2, 3]
seeds = [1, 2]

exp = stubbed_ExperimentDefinition(exp_name, part_name, version)

config = exp.get_config(0)
assert config == {'alpha': 0.125, 'beta': 0.5, 'id': 0}

configs = exp.get_configs(config_ids)
assert configs == [
{'alpha': 0.125, 'beta': 1.0, 'id': 1},
{'alpha': 0.125, 'beta': 2.0, 'id': 2},
{'alpha': 0.0625, 'beta': 0.5, 'id': 3},
]


configs_and_seeds = exp.get_configs(config_ids, product_seeds=seeds)
assert configs_and_seeds == [
{'alpha': 0.125, 'beta': 1.0, 'id': 1, 'seed': 1},
{'alpha': 0.125, 'beta': 1.0, 'id': 1, 'seed': 2},
{'alpha': 0.125, 'beta': 2.0, 'id': 2, 'seed': 1},
{'alpha': 0.125, 'beta': 2.0, 'id': 2, 'seed': 2},
{'alpha': 0.0625, 'beta': 0.5, 'id': 3, 'seed': 1},
{'alpha': 0.0625, 'beta': 0.5, 'id': 3, 'seed': 2},
]


class stubbed_DefinitionPart(DefinitionPart):
def __init__(self, exp_name: str, name: str, base: str | None = None):
self.exp_name = exp_name
super().__init__(name, base)

def get_results_path(self) -> str:
return os.path.join(self.base_path, 'results', self.exp_name)

class stubbed_ExperimentDefinition(ExperimentDefinition):
def __init__(self, exp_name: str, part_name: str, version: int, base: str | None = None):
self.exp_name = exp_name
super().__init__(part_name, version, base)

def get_results_path(self) -> str:
return os.path.join(self.base_path, 'results', self.exp_name)

0 comments on commit e3f82a8

Please sign in to comment.