-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
94 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,78 +1,125 @@ | ||
from typing import List, NamedTuple | ||
from pathlib import Path | ||
import inspect | ||
from typing import List, NamedTuple, Optional, Callable | ||
|
||
from napalib.system.universe import NapAUniverse | ||
import MDAnalysis as mda | ||
from MDAnalysis.analysis.base import Results | ||
|
||
from . import Database | ||
from .database import Database, Table | ||
|
||
|
||
def get_NapA_universe_by_simID(db: Database, simID: int) -> NapAUniverse: | ||
row = db.get_table("Simulations").get_row(simID) | ||
topology = row.topology | ||
trajectory = row.trajectory | ||
u = NapAUniverse(topology) | ||
u.load_new(trajectory) | ||
class DBAnalysisManager: | ||
"""Class that connects database IO with analysis running. | ||
return u | ||
""" | ||
|
||
def __init__(self, Analysis, dbfile, hooks=None): | ||
""" | ||
def get_universe_by_simID(db: Database, simID: int) -> mda.Universe: | ||
row = db.get_table("Simulations").get_row(simID) | ||
topology = row.topology | ||
trajectory = row.trajectory | ||
Parameters | ||
---------- | ||
Analysis : | ||
dbfile : | ||
hooks : | ||
return mda.Universe(topology, trajectory) | ||
""" | ||
|
||
|
||
class DBAnalysisRunner: | ||
|
||
def __init__(self, db: Database, Analysis): | ||
|
||
self.db = db | ||
self.Analysis = Analysis | ||
self.name = self.Analysis.name | ||
self._analysis = None | ||
self.db = Database(dbfile) | ||
|
||
try: | ||
self.analysis_name = self.Analysis.name | ||
except AttributeError: | ||
self.analysis_name = self.Analysis.__name__ | ||
try: | ||
self.analysis_notes = self.Analysis.notes | ||
except AttributeError: | ||
self.analysis_notes = None | ||
self.analysis_path = inspect.getfile(self.Analysis) | ||
|
||
try: | ||
self.observables = self.db.get_table("Observables") | ||
self.obsv = self.db.get_table("Observables") | ||
except ValueError: | ||
self.observables = self.db.create_table( | ||
"Observables (name TEXT, progenitor TEXT)" | ||
self.obsv = self.db.create_table( | ||
""" | ||
Observables ( | ||
obsName TEXT, | ||
notes TEXT, | ||
creator TEXT, | ||
timestamp DATETIME DEFAULT (strftime('%m-%d-%Y %H:%M', 'now', 'localtime')) | ||
) | ||
""", | ||
STRICT=False | ||
) | ||
finally: | ||
self.observables.insert_array([ | ||
(self.name, self.Analysis._path), | ||
]) | ||
|
||
def __enter__(self): | ||
self.db.open() | ||
return self | ||
if self.analysis_name not in self.obsv.get_column("obsName").data: | ||
self.obsv.insert_row( | ||
(self.analysis_name, self.analysis_notes, self.analysis_path), | ||
columns=["obsName, notes, creator"], | ||
) | ||
|
||
def __exit__(self, *args): | ||
self.db.close() | ||
self._analysis = None | ||
|
||
@property | ||
def results(self): | ||
if self._analysis is not None: | ||
return self._analysis.results | ||
|
||
def run_for_simID(self, simID: int, **kwargs) -> None: | ||
"""""" | ||
universe = get_NapA_universe_by_simID(self.db, simID) | ||
self._analysis = self.Analysis(universe, **kwargs) | ||
def results(self) -> Results: | ||
"""Analysis results.""" | ||
|
||
if self._analysis is None: | ||
raise ValueError("Must call run() for results to exist.") | ||
return self._analysis.results | ||
|
||
def _get_universe(self, simID: int, get_universe: Optional[Callable]): | ||
|
||
if get_universe is not None: | ||
#if self.hooks["get_universe"]: | ||
return get_universe(self.db, simID) | ||
|
||
row = self.db.get_table("Simulations").get_row(simID) | ||
return mda.Universe(row.topology, row.trajecory) | ||
|
||
def run( | ||
self, | ||
simID: int, | ||
get_universe: Optional[Callable] = None, | ||
**kwargs: dict, | ||
) -> None: | ||
""" | ||
Parameters | ||
---------- | ||
simID : int | ||
get_universe : Callable[Database, int] | ||
**kwargs : dict | ||
additional keyword arguments to be passed to the Analysis class | ||
""" | ||
u = self._get_universe(simID, get_universe) | ||
|
||
self._analysis = self.Analysis(u, **kwargs) | ||
self._analysis._simID = simID | ||
self._analysis.run() | ||
|
||
def save(self) -> None: | ||
"""Save the results of the analysis to the database.""" | ||
|
||
if not self.results: | ||
raise ValueError("no results") | ||
|
||
if self.name not in self.db._get_table_names(): | ||
analysis_table = Table(self.analysis_name, self.db) | ||
|
||
if analysis_table not in self.db: | ||
self.db.create_table(self.Analysis.schema) | ||
else: | ||
simID = self._analysis._simID | ||
if simID in analysis_table._get_rowids(): | ||
raise ValueError( | ||
f"{self.analysis_name} table already has data for simID {simID}" | ||
) | ||
|
||
rows = self.results[self.Analysis.results_key] | ||
|
||
table = self.db.get_table(self.name) | ||
table.insert_array(rows) | ||
self.db.get_table(self.analysis_name).insert_rows(rows) | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, *args): | ||
self.db.close() |