Skip to content

Commit

Permalink
Merge pull request #722 from StochSS/python-sbml-feature-validation
Browse files Browse the repository at this point in the history
Add unit test for SBML feature validation
  • Loading branch information
seanebum authored Feb 18, 2022
2 parents 75a4d07 + ace83f1 commit 7d3648d
Show file tree
Hide file tree
Showing 13 changed files with 127 additions and 39 deletions.
18 changes: 16 additions & 2 deletions gillespy2/core/gillespySolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

from .gillespyError import SimulationError
from .gillespyError import SimulationError, ModelError
from typing import Set, Type


class GillesPySolver:
Expand Down Expand Up @@ -79,4 +80,17 @@ def get_increment(self, increment):
`increment` argument from this `solver.run()` call.
"""
)
return increment
return increment

@classmethod
def get_supported_features(cls) -> "Set[Type]":
return set()

@classmethod
def validate_sbml_features(cls, model):
unsupported_features = model.get_model_features() - cls.get_supported_features()
if unsupported_features:
unsupported_features = [feature.__name__ for feature in unsupported_features]
raise ModelError(f"Could not run Model, "
f"SBML Features not supported by {cls.name}: " +
", ".join(unsupported_features))
21 changes: 20 additions & 1 deletion gillespy2/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import gillespy2
from gillespy2.core.jsonify import TranslationTable
from gillespy2.core.reaction import *
from gillespy2.core.raterule import RateRule
Expand All @@ -27,6 +27,7 @@
from collections import OrderedDict
from gillespy2.core.gillespyError import *
from .gillespyError import SimulationError
from typing import Set, Type

try:
import lxml.etree as eTree
Expand Down Expand Up @@ -928,6 +929,24 @@ def get_best_solver_algo(self, algorithm):
raise ModelError("Invalid value for the argument 'algorithm' entered. "
"Please enter 'SSA', 'ODE', 'Tau-leaping', or 'Tau-Hybrid'.")

def get_model_features(self) -> "Set[Type]":
"""
Determine what solver-specific model features are present on the model.
Used to validate that the model is compatible with the given solver.
:returns: Set containing the classes of every solver-specific feature present on the model.
"""
features = set()
if len(self.listOfEvents):
features.add(gillespy2.Event)
if len(self.listOfRateRules):
features.add(gillespy2.RateRule)
if len(self.listOfAssignmentRules):
features.add(gillespy2.AssignmentRule)
if len(self.listOfFunctionDefinitions):
features.add(gillespy2.FunctionDefinition)
return features

def run(self, solver=None, timeout=0, t=None, increment=None, show_labels=True, cpp_support=False, algorithm=None,
**solver_args):
"""
Expand Down
11 changes: 0 additions & 11 deletions gillespy2/solvers/cpp/c_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,17 +365,6 @@ def _validate_kwargs(self, **kwargs):
for key, val in kwargs.items():
log.warning(f"Unsupported keyword argument for solver {self.name}: {key}")

def _validate_sbml_features(self, unsupported_features: "dict[str, str]"):
detected = [ ]
for feature_name, count in unsupported_features.items():
if count:
detected.append(feature_name)

if len(detected):
raise gillespyError.ModelError(f"Could not run Model, "
f"SBML Features not supported by {self.name}: "
+ ", ".join(detected))

def _validate_seed(self, seed: int):
if seed is None:
return None
Expand Down
7 changes: 1 addition & 6 deletions gillespy2/solvers/cpp/ode_c_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def run(self=None, model: Model = None, t: int = 20, number_of_trajectories: int
if model is not None and model.get_json_hash() != self.model.get_json_hash():
raise SimulationError("Model must equal ODECSolver.model.")
self.model.resolve_parameters()
self.validate_sbml_features(model=model)

increment = self.get_increment(increment=increment)

Expand All @@ -56,12 +57,6 @@ def run(self=None, model: Model = None, t: int = 20, number_of_trajectories: int

self._validate_resume(t, resume)
self._validate_kwargs(**kwargs)
self._validate_sbml_features({
"Rate Rules": len(self.model.listOfRateRules),
"Assignment Rules": len(self.model.listOfAssignmentRules),
"Events": len(self.model.listOfEvents),
"Function Definitions": len(self.model.listOfFunctionDefinitions)
})

if resume is not None:
t = abs(t - int(resume["time"][-1]))
Expand Down
7 changes: 1 addition & 6 deletions gillespy2/solvers/cpp/ssa_c_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def run(self=None, model: Model = None, t: int = 20, number_of_trajectories: int
if model is not None and model.get_json_hash() != self.model.get_json_hash():
raise SimulationError("Model must equal SSACSolver.model.")
self.model.resolve_parameters()
self.validate_sbml_features(model=model)

increment = self.get_increment(increment=increment)

Expand All @@ -57,12 +58,6 @@ def run(self=None, model: Model = None, t: int = 20, number_of_trajectories: int

self._validate_resume(t, resume)
self._validate_kwargs(**kwargs)
self._validate_sbml_features({
"Rate Rules": len(self.model.listOfRateRules),
"Assignment Rules": len(self.model.listOfAssignmentRules),
"Events": len(self.model.listOfEvents),
"Function Definitions": len(self.model.listOfFunctionDefinitions)
})

if resume is not None:
t = abs(t - int(resume["time"][-1]))
Expand Down
16 changes: 10 additions & 6 deletions gillespy2/solvers/cpp/tau_hybrid_c_solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gillespy2
from gillespy2.solvers.cpp.c_decoder import IterativeSimDecoder
from gillespy2.solvers.utilities import solverutils as cutils
from gillespy2.core import GillesPySolver, Model
from gillespy2.core import GillesPySolver, Model, Event, RateRule
from gillespy2.core.gillespyError import *
from typing import Union
from gillespy2.core import Results
Expand Down Expand Up @@ -141,6 +141,13 @@ def __create_options(cls, sanitized_model: "SanitizedModel") -> "SanitizedModel"
sanitized_model.options["GPY_HYBRID_NUM_EVENT_ASSIGNMENTS"] = str(len(event_assignment_list))
return sanitized_model

@classmethod
def get_supported_features(cls):
return {
Event,
RateRule,
}

def _build(self, model: "Union[Model, SanitizedModel]", simulation_name: str, variable: bool, debug: bool = False,
custom_definitions=None) -> str:
variable = variable or len(model.listOfEvents) > 0
Expand All @@ -156,7 +163,7 @@ def get_solver_settings(self):
return ('model', 't', 'number_of_trajectories', 'timeout', 'increment', 'seed', 'debug', 'profile')

def run(self=None, model: Model = None, t: int = 20, number_of_trajectories: int = 1, timeout: int = 0,
increment: int = None, seed: int = None, debug: bool = False, profile: bool = False, variables={},
increment: int = None, seed: int = None, debug: bool = False, profile: bool = False, variables={},
resume=None, live_output: str = None, live_output_options: dict = {}, tau_step: int = .03, tau_tol=0.03, **kwargs):

if self is None:
Expand All @@ -168,6 +175,7 @@ def run(self=None, model: Model = None, t: int = 20, number_of_trajectories: int
if model is not None and model.get_json_hash() != self.model.get_json_hash():
raise SimulationError("Model must equal TauHybridCSolver.model.")
self.model.resolve_parameters()
self.validate_sbml_features(model=model)

increment = self.get_increment(increment=increment)

Expand All @@ -176,10 +184,6 @@ def run(self=None, model: Model = None, t: int = 20, number_of_trajectories: int
self._validate_variables_in_set(variables, self.species + self.parameters)
self._validate_resume(t, resume)
self._validate_kwargs(**kwargs)
self._validate_sbml_features({
"Assignment Rules": len(self.model.listOfAssignmentRules),
"Function Definitions": len(self.model.listOfFunctionDefinitions)
})

if resume is not None:
t = abs(t - int(resume["time"][-1]))
Expand Down
7 changes: 1 addition & 6 deletions gillespy2/solvers/cpp/tau_leaping_c_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def run(self=None, model: Model = None, t: int = 20, number_of_trajectories: int
if model is not None and model.get_json_hash() != self.model.get_json_hash():
raise SimulationError("Model must equal TauLeapingCSolver.model.")
self.model.resolve_parameters()
self.validate_sbml_features(model=model)

increment = self.get_increment(increment=increment)

Expand All @@ -56,12 +57,6 @@ def run(self=None, model: Model = None, t: int = 20, number_of_trajectories: int
self._validate_variables_in_set(variables, self.species + self.parameters)
self._validate_resume(t, resume)
self._validate_kwargs(**kwargs)
self._validate_sbml_features({
"Rate Rules": len(self.model.listOfRateRules),
"Assignment Rules": len(self.model.listOfAssignmentRules),
"Events": len(self.model.listOfEvents),
"Function Definitions": len(self.model.listOfFunctionDefinitions)
})

if resume is not None:
t = abs(t - int(resume["time"][-1]))
Expand Down
1 change: 1 addition & 0 deletions gillespy2/solvers/numpy/CLE_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def run(self, model=None, t=20, number_of_trajectories=1, increment=None, seed=N
if model is not None and model.get_json_hash() != self.model.get_json_hash():
raise SimulationError("Model must equal CLESolver.model.")
self.model.resolve_parameters()
self.validate_sbml_features(model=model)

increment = self.get_increment(increment=increment)

Expand Down
1 change: 1 addition & 0 deletions gillespy2/solvers/numpy/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def run(self, model=None, t=20, number_of_trajectories=1, increment=None, show_l
if model is not None and model.get_json_hash() != self.model.get_json_hash():
raise SimulationError("Model must equal OSESolver.model.")
self.model.resolve_parameters()
self.validate_sbml_features(model=model)

increment = self.get_increment(increment=increment)

Expand Down
1 change: 1 addition & 0 deletions gillespy2/solvers/numpy/ssa_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def run(self, model=None, t=20, number_of_trajectories=1, increment=None, seed=N
if model is not None and model.get_json_hash() != self.model.get_json_hash():
raise SimulationError("Model must equal NumPySSASolver.model.")
self.model.resolve_parameters()
self.validate_sbml_features(model=model)

increment = self.get_increment(increment=increment)

Expand Down
12 changes: 11 additions & 1 deletion gillespy2/solvers/numpy/tau_hybrid_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import threading
import gillespy2
from gillespy2.solvers.utilities import Tau
from gillespy2.core import GillesPySolver, log
from gillespy2.core import GillesPySolver, log, Event, RateRule, AssignmentRule, FunctionDefinition
from gillespy2.core.gillespyError import *
from gillespy2.core.results import Results

Expand Down Expand Up @@ -764,6 +764,15 @@ def get_solver_settings(self):
return ('model', 't', 'number_of_trajectories', 'increment', 'seed', 'debug', 'profile', 'tau_tol',
'event_sensitivity', 'integrator', 'integrator_options', 'timeout')

@classmethod
def get_supported_features(cls):
return {
Event,
RateRule,
AssignmentRule,
FunctionDefinition,
}

@classmethod
def run(self, model=None, t=20, number_of_trajectories=1, increment=None, seed=None,
debug=False, profile=False, tau_tol=0.03, event_sensitivity=100, integrator='LSODA',
Expand Down Expand Up @@ -827,6 +836,7 @@ def run(self, model=None, t=20, number_of_trajectories=1, increment=None, seed=N
if model is not None and model.get_json_hash() != self.model.get_json_hash():
raise SimulationError("Model must equal TauHybridSolver.model.")
self.model.resolve_parameters()
self.validate_sbml_features(model=model)

increment = self.get_increment(increment=increment)

Expand Down
1 change: 1 addition & 0 deletions gillespy2/solvers/numpy/tau_leaping_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def run(self, model=None, t=20, number_of_trajectories=1, increment=None, seed=N
if model is not None and model.get_json_hash() != self.model.get_json_hash():
raise SimulationError("Model must equal TauLeapingSolver.model.")
self.model.resolve_parameters()
self.validate_sbml_features(model=model)

increment = self.get_increment(increment=increment)

Expand Down
63 changes: 63 additions & 0 deletions test/test_all_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,43 @@ class TestAllSolvers(unittest.TestCase):
TauHybridCSolver,
]

sbml_features = {
"AssignmentRule": lambda model, variable:
model.add_assignment_rule(gillespy2.AssignmentRule(variable=variable, formula="1/(t+1)")),
"RateRule": lambda model, variable:
model.add_rate_rule(gillespy2.RateRule(variable=variable, formula="2*t")),
"Event": lambda model, variable:
model.add_event(gillespy2.Event(
trigger=gillespy2.EventTrigger(expression="t>1"),
assignments=[gillespy2.EventAssignment(variable=variable, expression="100")]
)),
"FunctionDefinition": lambda model, variable:
model.add_function_definition(
gillespy2.FunctionDefinition(name="fn", function="variable", args=["variable"])),
}

# List of supported SBML features for each solver.
# When a feature is implemented for a particular solver, add the feature to its list.
solver_supported_sbml_features = {
NumPySSASolver: [],
TauLeapingSolver: [],
ODESolver: [],
TauHybridSolver: [
"AssignmentRule",
"RateRule",
"Event",
"FunctionDefinition",
],

SSACSolver: [],
ODECSolver: [],
TauLeapingCSolver: [],
TauHybridCSolver: [
"RateRule",
"Event",
],
}

model = Example()
for sp in model.listOfSpecies.values():
sp.mode = 'discrete'
Expand Down Expand Up @@ -129,6 +166,32 @@ def test_basic_solver_import(self):
results3 = model.run(solver=BasicTauHybridSolver)
self.assertTrue(results3[0].solver_name == 'TauHybridSolver')

def test_sbml_feature_validation(self):
class TestModel(gillespy2.Model):
def __init__(self):
gillespy2.Model.__init__(self, name="TestModel")
self.add_species(gillespy2.Species(name="S", initial_value=0))
self.timespan(np.linspace(0, 10, 11))

all_features = set(self.sbml_features.keys())
for solver in self.solvers:
unsupported_features = all_features.difference(self.solver_supported_sbml_features.get(solver))
with self.subTest(solver=solver.name):
for sbml_feature_name in unsupported_features:
model = TestModel()
with self.subTest("Unsupported model features raise an error", sbml_feature=sbml_feature_name):
add_sbml_feature = self.sbml_features.get(sbml_feature_name)
add_sbml_feature(model, "S")
with self.assertRaises(gillespy2.ModelError):
solver.validate_sbml_features(model=model)

for sbml_feature_name in self.solver_supported_sbml_features.get(solver):
model = TestModel()
with self.subTest("Supported model features validate successfully", sbml_feature=sbml_feature_name):
add_sbml_feature = self.sbml_features.get(sbml_feature_name)
add_sbml_feature(model, "S")
solver.validate_sbml_features(model=model)


if __name__ == '__main__':
unittest.main()

0 comments on commit 7d3648d

Please sign in to comment.