Skip to content

Commit

Permalink
refactor: use utility function for getting results path
Browse files Browse the repository at this point in the history
  • Loading branch information
panahiparham committed Oct 23, 2024
1 parent ed08c84 commit abd026b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 18 deletions.
9 changes: 3 additions & 6 deletions ml_experiment/DefinitionPart.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict

import ml_experiment._utils.sqlite as sqlu
from ml_experiment._utils.path import get_results_path
from ml_experiment.metadata.MetadataTableRegistry import MetadataTableRegistry

ValueType = int | float | str | bool
Expand All @@ -13,6 +14,7 @@ class DefinitionPart:
def __init__(self, name: str, base: str | None = None):
self.name = name
self.base_path = base or os.getcwd()
self.get_results_path = get_results_path

self._properties: Dict[str, Set[ValueType]] = defaultdict(set)
self._prior_values: Dict[str, ValueType] = {}
Expand All @@ -39,15 +41,10 @@ def add_sweepable_property(
if assume_prior_value is not None:
self._prior_values[key] = assume_prior_value

def get_results_path(self) -> str:
import __main__
experiment_name = __main__.__file__.split('/')[-2]
return os.path.join(self.base_path, 'results', experiment_name)

def commit(self):
configurations = list(generate_configurations(self._properties))

save_path = self.get_results_path()
save_path = self.get_results_path(self.base_path)
db_path = os.path.join(save_path, 'metadata.db')
con = sqlu.init_db(db_path)
cur = con.cursor()
Expand Down
12 changes: 4 additions & 8 deletions ml_experiment/ExperimentDefinition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
import sqlite3

from ml_experiment.metadata.MetadataTable import MetadataTable
from ml_experiment._utils.path import get_results_path

class ExperimentDefinition:
def __init__(self, part_name: str, version: int, base: str | None = None):
self.part_name = part_name
self.version = version
self.base_path = base or os.getcwd()
self.get_results_path = get_results_path

self.table = MetadataTable(self.part_name, self.version)

def get_config(self, config_id: int) -> dict[str, Any]:
save_path = self.get_results_path()
save_path = self.get_results_path(self.base_path)
db_path = os.path.join(save_path, 'metadata.db')

with sqlite3.connect(db_path) as con:
Expand All @@ -23,7 +25,7 @@ def get_config(self, config_id: int) -> dict[str, Any]:


def get_configs(self, config_ids: list[int], product_seeds: list[int] | None = None) -> list[dict[str, Any]]:
save_path = self.get_results_path()
save_path = self.get_results_path(self.base_path)
db_path = os.path.join(save_path, 'metadata.db')

with sqlite3.connect(db_path) as con:
Expand All @@ -41,9 +43,3 @@ def get_configs(self, config_ids: list[int], product_seeds: list[int] | None = N
]
else:
return _c


def get_results_path(self) -> str:
import __main__
experiment_name = __main__.__file__.split('/')[-2]
return os.path.join(self.base_path, 'results', experiment_name)
8 changes: 4 additions & 4 deletions tests/test_ExperimentDefinition.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def __init__(self, exp_name: str, name: str, base: str | None = None):
self.exp_name = exp_name
super().__init__(name, base)

def get_results_path(self) -> str:
return os.path.join(self.base_path, 'results', self.exp_name)
def get_results_path(self, base_path) -> str:
return os.path.join(base_path, 'results', self.exp_name)

class stubbed_ExperimentDefinition(ExperimentDefinition):
def __init__(self, exp_name: str, part_name: str, version: int, base: str | None = None):
self.exp_name = exp_name
super().__init__(part_name, version, base)

def get_results_path(self) -> str:
return os.path.join(self.base_path, 'results', self.exp_name)
def get_results_path(self, base_path) -> str:
return os.path.join(base_path, 'results', self.exp_name)

0 comments on commit abd026b

Please sign in to comment.