Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ThermodynamicState support for more barostats #437

Merged
merged 13 commits into from
Oct 30, 2019
Merged
44 changes: 42 additions & 2 deletions openmmtools/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,9 @@ class ContextCache(object):
"""

def __init__(self, platform=None, platform_properties=None, **kwargs):
self._validate_platform_properties(platform, platform_properties)
self._platform = platform
self._platform_properties = platform_properties
if platform_properties is not None and platform is None:
raise ValueError("To set platform_properties, you need to also specify the platform.")
self._lru = LRUCache(**kwargs)

def __len__(self):
Expand All @@ -330,7 +329,17 @@ def platform(self):
def platform(self, new_platform):
if len(self._lru) > 0:
raise RuntimeError('Cannot change platform of a non-empty ContextCache')
if new_platform is None:
self._platform_properties = None
self._validate_platform_properties(new_platform, self._platform_properties)
self._platform = new_platform

def set_platform(self, new_platform, platform_properties=None):
if len(self._lru) > 0:
raise RuntimeError('Cannot change platform of a non-empty ContextCache')
self._validate_platform_properties(new_platform, platform_properties)
self._platform = new_platform
self._platform_properties = platform_properties

@property
def capacity(self):
Expand Down Expand Up @@ -448,6 +457,7 @@ def get_context(self, thermodynamic_state, integrator=None):
return context, context_integrator

def __getstate__(self):
# this serialization format was introduced in openmmtools > 0.18.3 (pull request #437)
if self.platform is not None:
platform_serialization = self.platform.getName()
else:
Expand All @@ -456,6 +466,7 @@ def __getstate__(self):
time_to_live=self.time_to_live, platform_properties=self._platform_properties)

def __setstate__(self, serialization):
# this serialization format was introduced in openmmtools > 0.18.3 (pull request #437)
if serialization['platform'] is None:
self._platform = None
else:
Expand Down Expand Up @@ -640,6 +651,35 @@ def _default_integrator_id(cls):
return cls._cached_default_integrator_id
_cached_default_integrator_id = None

@staticmethod
def _validate_platform_properties(platform=None, platform_properties=None):
"""Check if platform properties are valid for the platform; else raise ValueError."""
if platform_properties is None:
return True
if platform_properties is not None and platform is None:
raise ValueError("To set platform_properties, you need to also specify the platform.")
if not isinstance(platform_properties, dict):
raise ValueError("platform_properties must be a dictionary")
for key, value in platform_properties.items():
if not isinstance(value, str):
raise ValueError(
"All platform properties must be strings. You supplied {}: {} of type {}".format(
key, value, type(value)
)
)
# create a context to check if all properties are
dummy_system = openmm.System()
dummy_system.addParticle(1)
dummy_integrator = openmm.VerletIntegrator(1.0*unit.femtoseconds)
try:
openmm.Context(dummy_system, dummy_integrator, platform, platform_properties)
return True
except Exception as e:
if "Illegal property name" in str(e):
raise ValueError("Invalid platform property for this platform. {}".format(e))
else:
raise e


# =============================================================================
# DUMMY CONTEXT CACHE
Expand Down
Loading