Skip to content

Commit

Permalink
updating analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
edisj committed Jul 15, 2024
1 parent 6529e80 commit 68ff3e8
Showing 1 changed file with 94 additions and 47 deletions.
141 changes: 94 additions & 47 deletions mdaadb/analysis.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,125 @@
from typing import List, NamedTuple
from pathlib import Path
import inspect
from typing import List, NamedTuple, Optional, Callable

from napalib.system.universe import NapAUniverse
import MDAnalysis as mda
from MDAnalysis.analysis.base import Results

from . import Database
from .database import Database, Table


def get_NapA_universe_by_simID(db: Database, simID: int) -> NapAUniverse:
row = db.get_table("Simulations").get_row(simID)
topology = row.topology
trajectory = row.trajectory
u = NapAUniverse(topology)
u.load_new(trajectory)
class DBAnalysisManager:
"""Class that connects database IO with analysis running.
return u
"""

def __init__(self, Analysis, dbfile, hooks=None):
"""
def get_universe_by_simID(db: Database, simID: int) -> mda.Universe:
row = db.get_table("Simulations").get_row(simID)
topology = row.topology
trajectory = row.trajectory
Parameters
----------
Analysis :
dbfile :
hooks :
return mda.Universe(topology, trajectory)
"""


class DBAnalysisRunner:

def __init__(self, db: Database, Analysis):

self.db = db
self.Analysis = Analysis
self.name = self.Analysis.name
self._analysis = None
self.db = Database(dbfile)

try:
self.analysis_name = self.Analysis.name
except AttributeError:
self.analysis_name = self.Analysis.__name__
try:
self.analysis_notes = self.Analysis.notes
except AttributeError:
self.analysis_notes = None
self.analysis_path = inspect.getfile(self.Analysis)

try:
self.observables = self.db.get_table("Observables")
self.obsv = self.db.get_table("Observables")
except ValueError:
self.observables = self.db.create_table(
"Observables (name TEXT, progenitor TEXT)"
self.obsv = self.db.create_table(
"""
Observables (
obsName TEXT,
notes TEXT,
creator TEXT,
timestamp DATETIME DEFAULT (strftime('%m-%d-%Y %H:%M', 'now', 'localtime'))
)
""",
STRICT=False
)
finally:
self.observables.insert_array([
(self.name, self.Analysis._path),
])

def __enter__(self):
self.db.open()
return self
if self.analysis_name not in self.obsv.get_column("obsName").data:
self.obsv.insert_row(
(self.analysis_name, self.analysis_notes, self.analysis_path),
columns=["obsName, notes, creator"],
)

def __exit__(self, *args):
self.db.close()
self._analysis = None

@property
def results(self):
if self._analysis is not None:
return self._analysis.results

def run_for_simID(self, simID: int, **kwargs) -> None:
""""""
universe = get_NapA_universe_by_simID(self.db, simID)
self._analysis = self.Analysis(universe, **kwargs)
def results(self) -> Results:
"""Analysis results."""

if self._analysis is None:
raise ValueError("Must call run() for results to exist.")
return self._analysis.results

def _get_universe(self, simID: int, get_universe: Optional[Callable]):

if get_universe is not None:
#if self.hooks["get_universe"]:
return get_universe(self.db, simID)

row = self.db.get_table("Simulations").get_row(simID)
return mda.Universe(row.topology, row.trajecory)

def run(
self,
simID: int,
get_universe: Optional[Callable] = None,
**kwargs: dict,
) -> None:
"""
Parameters
----------
simID : int
get_universe : Callable[Database, int]
**kwargs : dict
additional keyword arguments to be passed to the Analysis class
"""
u = self._get_universe(simID, get_universe)

self._analysis = self.Analysis(u, **kwargs)
self._analysis._simID = simID
self._analysis.run()

def save(self) -> None:
"""Save the results of the analysis to the database."""

if not self.results:
raise ValueError("no results")

if self.name not in self.db._get_table_names():
analysis_table = Table(self.analysis_name, self.db)

if analysis_table not in self.db:
self.db.create_table(self.Analysis.schema)
else:
simID = self._analysis._simID
if simID in analysis_table._get_rowids():
raise ValueError(
f"{self.analysis_name} table already has data for simID {simID}"
)

rows = self.results[self.Analysis.results_key]

table = self.db.get_table(self.name)
table.insert_array(rows)
self.db.get_table(self.analysis_name).insert_rows(rows)

def __enter__(self):
return self

def __exit__(self, *args):
self.db.close()

0 comments on commit 68ff3e8

Please sign in to comment.