Skip to content

Commit

Permalink
Merge pull request #2 from andnp/main
Browse files Browse the repository at this point in the history
Merge WIP into main
  • Loading branch information
panahiparham authored Sep 9, 2024
2 parents 2d8f6be + 6978237 commit a4dd108
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 83 deletions.
41 changes: 0 additions & 41 deletions .github/workflows/benchmark.yml

This file was deleted.

40 changes: 0 additions & 40 deletions .github/workflows/pr_benchmark.yml

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- run: echo "$PWD/.venv/bin" >> $GITHUB_PATH

- run: pytest tests/ --junitxml=junit/test-results-${{ matrix.os }}-${{ matrix.python-version }}.xml --cov=recommence --cov-report=html:coverage/cov-${{ matrix.os }}-${{ matrix.python-version }}.html
- run: pytest tests/ --junitxml=junit/test-results-${{ matrix.os }}-${{ matrix.python-version }}.xml --cov=ml_experiment --cov-report=html:coverage/cov-${{ matrix.os }}-${{ matrix.python-version }}.html

- name: Upload pytest test results
uses: actions/upload-artifact@v4
Expand Down
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
requirements.txt
.python-version
test.py
**/__pycache__/
.vscode/
results/
experiments/
.pytest_cache/
.venv/
104 changes: 104 additions & 0 deletions ml_experiment/DefinitionPart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os
import sqlite3

from itertools import product

from typing import Dict, Iterable, Set
from collections import defaultdict

import ml_experiment._utils.sqlite as sqlu

ValueType = int | float | str | bool

class DefinitionPart:
def __init__(self, name: str, base: str | None = None):
self.name = name
self.base_path = base or os.getcwd()

self._properties: Dict[str, Set[ValueType]] = defaultdict(set)

def add_property(self, key: str, value: ValueType):
self._properties[key].add(value)

def add_sweepable_property(self, key: str, values: Iterable[ValueType]):
self._properties[key] |= set(values)

def get_results_path(self) -> str:
import __main__
experiment_name = __main__.__file__.split('/')[-2]
return os.path.join(self.base_path, 'results', experiment_name)

def commit(self):
configurations = list(generate_configurations(self._properties))

save_path = self.get_results_path()
db_path = os.path.join(save_path, 'metadata.db')
con = _init_db(db_path)
cur = con.cursor()
tables = sqlu.get_tables(cur)

# get table version
latest_version = -1
for t in {t for t in tables if t.startswith(self.name)}:
version = int(t.replace(self.name + '-v', ''))
if version > latest_version:
latest_version = version

if latest_version == -1:
for i, configuration in enumerate(configurations):
configuration['id'] = i
else:

# find next id for new configurations
all_ids = []
for i in range(latest_version + 1):
_table_name = self.name + f'-v{i}'
cur.execute(f"SELECT DISTINCT id FROM '{_table_name}'")
all_ids.extend([x[0] for x in cur.fetchall()])
next_id = max(all_ids) + 1

# assign ids to new configurations / find ids for existing configurations
for configuration in configurations:
_id = None

for i in range(latest_version, -1, -1):
table_name = self.name + f'-v{i}'

# check if properties match the table schema
cur.execute(f"PRAGMA table_info('{table_name}')")
columns = set([x[1] for x in cur.fetchall()])

if not set(configuration.keys()) == columns - {'id'} :
continue

# check if configuration exists
cur.execute(f"SELECT id FROM '{table_name}' WHERE {' AND '.join([f'{k}=?' for k in configuration.keys()])}", list(configuration.values()))
_id = cur.fetchone()
if _id:
break

if _id:
configuration['id'] = _id[0]
else:
configuration['id'] = next_id
next_id += 1

table_name = self.name + f'-v{latest_version + 1}'

sqlu.create_table(cur, table_name, list(self._properties.keys()) + ['id INTEGER PRIMARY KEY'])
conf_string = ', '.join(['?'] * (len(self._properties) + 1))
column_names = ', '.join(list(self._properties.keys()) + ['id'])
cur.executemany(f"INSERT INTO '{table_name}' ({column_names}) VALUES ({conf_string})", [list(c.values()) for c in configurations])

con.commit()
con.close()


def _init_db(db_path: str):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
con = sqlite3.connect(db_path)
return con

def generate_configurations(properties: Dict[str, Set[ValueType]]):
for configuration in product(*properties.values()):
yield dict(zip(properties.keys(), configuration, strict=True))
10 changes: 10 additions & 0 deletions ml_experiment/_utils/sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from sqlite3 import Cursor
from typing import Set, List

def get_tables(cur: Cursor) -> Set[str]:
res = cur.execute("SELECT name FROM sqlite_master")
return set(r[0] for r in res.fetchall())

def create_table(cur: Cursor, table_name: str, columns: List[str]):
columns_str = ', '.join(columns)
cur.execute(f"CREATE TABLE '{table_name}' ({columns_str})")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ select = ['F', 'E', 'W', 'B']
ignore = ['E501', 'E701']

[tool.pyright]
include = ['ml-experiment-definition']
include = ['ml_experiment']
venvPath = '.'
venv = '.venv'
typeCheckingMode = 'standard'
Expand Down
Empty file added tests/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions tests/test_DefinitionPart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ml_experiment.DefinitionPart import DefinitionPart


def test_add_sweepable_property():
builder = DefinitionPart('qrc')
builder.add_sweepable_property('key_1', [1, 2, 3])

for i in range(1, 4):
builder.add_property('key_2', i)

0 comments on commit a4dd108

Please sign in to comment.