Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ThermodynamicState support for more barostats #437

Merged
merged 13 commits into from
Oct 30, 2019
Merged
2 changes: 2 additions & 0 deletions docs/releasehistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Release History
New features
------------
- Added support in ``AbsoluteAlchemicalFactory`` for handling multiple independent alchemical regions (`#438 <https://github.com/choderalab/openmmtools/pull/438>`_).
- Added support for anisotropic and membrane barostats in `ThermodynamicState` (`#437 <https://github.com/choderalab/openmmtools/pull/437>`_)
- Added support for platform properties in ContextCache (e.g. for mixed and double precision CUDA in multistate sampler) (`#437 <https://github.com/choderalab/openmmtools/pull/437>`_)

0.18.3 - Storage enhancements and bugfixes
==========================================
Expand Down
56 changes: 53 additions & 3 deletions openmmtools/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ class ContextCache(object):
platform : simtk.openmm.Platform, optional
The OpenMM platform to use to create Contexts. If None, OpenMM
tries to select the fastest one available (default is None).
platform_properties : dict, optional
A dictionary of platform properties for the OpenMM platform.
Only valid if the platform is not None (default is None).
**kwargs
Parameters to pass to the underlying LRUCache constructor such
as capacity and time_to_live.
Expand Down Expand Up @@ -303,8 +306,10 @@ class ContextCache(object):

"""

def __init__(self, platform=None, **kwargs):
def __init__(self, platform=None, platform_properties=None, **kwargs):
self._validate_platform_properties(platform, platform_properties)
self._platform = platform
self._platform_properties = platform_properties
self._lru = LRUCache(**kwargs)

def __len__(self):
Expand All @@ -324,8 +329,18 @@ def platform(self):
def platform(self, new_platform):
if len(self._lru) > 0:
raise RuntimeError('Cannot change platform of a non-empty ContextCache')
if new_platform is None:
self._platform_properties = None
self._validate_platform_properties(new_platform, self._platform_properties)
self._platform = new_platform

def set_platform(self, new_platform, platform_properties=None):
if len(self._lru) > 0:
raise RuntimeError('Cannot change platform of a non-empty ContextCache')
self._validate_platform_properties(new_platform, platform_properties)
self._platform = new_platform
self._platform_properties = platform_properties

@property
def capacity(self):
"""The maximum number of Context cached.
Expand Down Expand Up @@ -429,7 +444,7 @@ def get_context(self, thermodynamic_state, integrator=None):
try:
context = self._lru[context_id]
except KeyError:
context = thermodynamic_state.create_context(integrator, self._platform)
context = thermodynamic_state.create_context(integrator, self._platform, self._platform_properties)
self._lru[context_id] = context
context_integrator = context.getIntegrator()

Expand All @@ -442,18 +457,24 @@ def get_context(self, thermodynamic_state, integrator=None):
return context, context_integrator

def __getstate__(self):
# this serialization format was introduced in openmmtools > 0.18.3 (pull request #437)
if self.platform is not None:
platform_serialization = self.platform.getName()
else:
platform_serialization = None
return dict(platform=platform_serialization, capacity=self.capacity,
time_to_live=self.time_to_live)
time_to_live=self.time_to_live, platform_properties=self._platform_properties)

def __setstate__(self, serialization):
# this serialization format was introduced in openmmtools > 0.18.3 (pull request #437)
if serialization['platform'] is None:
self._platform = None
else:
self._platform = openmm.Platform.getPlatformByName(serialization['platform'])
if not 'platform_properties' in serialization:
self._platform_properties = None
else:
self._platform_properties = serialization["platform_properties"]
self._lru = LRUCache(serialization['capacity'], serialization['time_to_live'])

# -------------------------------------------------------------------------
Expand Down Expand Up @@ -630,6 +651,35 @@ def _default_integrator_id(cls):
return cls._cached_default_integrator_id
_cached_default_integrator_id = None

@staticmethod
def _validate_platform_properties(platform=None, platform_properties=None):
"""Check if platform properties are valid for the platform; else raise ValueError."""
if platform_properties is None:
return True
if platform_properties is not None and platform is None:
raise ValueError("To set platform_properties, you need to also specify the platform.")
if not isinstance(platform_properties, dict):
raise ValueError("platform_properties must be a dictionary")
for key, value in platform_properties.items():
if not isinstance(value, str):
raise ValueError(
"All platform properties must be strings. You supplied {}: {} of type {}".format(
key, value, type(value)
)
)
# create a context to check if all properties are
dummy_system = openmm.System()
dummy_system.addParticle(1)
dummy_integrator = openmm.VerletIntegrator(1.0*unit.femtoseconds)
try:
openmm.Context(dummy_system, dummy_integrator, platform, platform_properties)
return True
except Exception as e:
if "Illegal property name" in str(e):
raise ValueError("Invalid platform property for this platform. {}".format(e))
else:
raise e


# =============================================================================
# DUMMY CONTEXT CACHE
Expand Down
Loading