From abd026b73e35dffc52ce1cab74346bfc4d9ee6bb Mon Sep 17 00:00:00 2001 From: Parham Panahi Date: Thu, 3 Oct 2024 14:06:14 -0600 Subject: [PATCH] refactor: use utility function for getting results path --- ml_experiment/DefinitionPart.py | 9 +++------ ml_experiment/ExperimentDefinition.py | 12 ++++-------- tests/test_ExperimentDefinition.py | 8 ++++---- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/ml_experiment/DefinitionPart.py b/ml_experiment/DefinitionPart.py index cf310fd..d39f34b 100644 --- a/ml_experiment/DefinitionPart.py +++ b/ml_experiment/DefinitionPart.py @@ -5,6 +5,7 @@ from collections import defaultdict import ml_experiment._utils.sqlite as sqlu +from ml_experiment._utils.path import get_results_path from ml_experiment.metadata.MetadataTableRegistry import MetadataTableRegistry ValueType = int | float | str | bool @@ -13,6 +14,7 @@ class DefinitionPart: def __init__(self, name: str, base: str | None = None): self.name = name self.base_path = base or os.getcwd() + self.get_results_path = get_results_path self._properties: Dict[str, Set[ValueType]] = defaultdict(set) self._prior_values: Dict[str, ValueType] = {} @@ -39,15 +41,10 @@ def add_sweepable_property( if assume_prior_value is not None: self._prior_values[key] = assume_prior_value - def get_results_path(self) -> str: - import __main__ - experiment_name = __main__.__file__.split('/')[-2] - return os.path.join(self.base_path, 'results', experiment_name) - def commit(self): configurations = list(generate_configurations(self._properties)) - save_path = self.get_results_path() + save_path = self.get_results_path(self.base_path) db_path = os.path.join(save_path, 'metadata.db') con = sqlu.init_db(db_path) cur = con.cursor() diff --git a/ml_experiment/ExperimentDefinition.py b/ml_experiment/ExperimentDefinition.py index d08aa76..cb13225 100644 --- a/ml_experiment/ExperimentDefinition.py +++ b/ml_experiment/ExperimentDefinition.py @@ -3,17 +3,19 @@ import sqlite3 from ml_experiment.metadata.MetadataTable import MetadataTable +from ml_experiment._utils.path import get_results_path class ExperimentDefinition: def __init__(self, part_name: str, version: int, base: str | None = None): self.part_name = part_name self.version = version self.base_path = base or os.getcwd() + self.get_results_path = get_results_path self.table = MetadataTable(self.part_name, self.version) def get_config(self, config_id: int) -> dict[str, Any]: - save_path = self.get_results_path() + save_path = self.get_results_path(self.base_path) db_path = os.path.join(save_path, 'metadata.db') with sqlite3.connect(db_path) as con: @@ -23,7 +25,7 @@ def get_config(self, config_id: int) -> dict[str, Any]: def get_configs(self, config_ids: list[int], product_seeds: list[int] | None = None) -> list[dict[str, Any]]: - save_path = self.get_results_path() + save_path = self.get_results_path(self.base_path) db_path = os.path.join(save_path, 'metadata.db') with sqlite3.connect(db_path) as con: @@ -41,9 +43,3 @@ def get_configs(self, config_ids: list[int], product_seeds: list[int] | None = N ] else: return _c - - - def get_results_path(self) -> str: - import __main__ - experiment_name = __main__.__file__.split('/')[-2] - return os.path.join(self.base_path, 'results', experiment_name) diff --git a/tests/test_ExperimentDefinition.py b/tests/test_ExperimentDefinition.py index 2d18ce3..787fa14 100644 --- a/tests/test_ExperimentDefinition.py +++ b/tests/test_ExperimentDefinition.py @@ -50,13 +50,13 @@ def __init__(self, exp_name: str, name: str, base: str | None = None): self.exp_name = exp_name super().__init__(name, base) - def get_results_path(self) -> str: - return os.path.join(self.base_path, 'results', self.exp_name) + def get_results_path(self, base_path) -> str: + return os.path.join(base_path, 'results', self.exp_name) class stubbed_ExperimentDefinition(ExperimentDefinition): def __init__(self, exp_name: str, part_name: str, version: int, base: str | None = None): self.exp_name = exp_name super().__init__(part_name, version, base) - def get_results_path(self) -> str: - return os.path.join(self.base_path, 'results', self.exp_name) + def get_results_path(self, base_path) -> str: + return os.path.join(base_path, 'results', self.exp_name)