Skip to content

Commit

Permalink
Merge pull request #437 from Olllom/other_barostats
Browse files Browse the repository at this point in the history
ThermodynamicState support for more barostats
  • Loading branch information
andrrizzi authored Oct 30, 2019
2 parents 21dee06 + 46308b0 commit 6e07232
Show file tree
Hide file tree
Showing 5 changed files with 536 additions and 71 deletions.
2 changes: 2 additions & 0 deletions docs/releasehistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Release History
New features
------------
- Added support in ``AbsoluteAlchemicalFactory`` for handling multiple independent alchemical regions (`#438 <https://github.com/choderalab/openmmtools/pull/438>`_).
- Added support for anisotropic and membrane barostats in `ThermodynamicState` (`#437 <https://github.com/choderalab/openmmtools/pull/437>`_)
- Added support for platform properties in ContextCache (e.g. for mixed and double precision CUDA in multistate sampler) (`#437 <https://github.com/choderalab/openmmtools/pull/437>`_)

Bugfixes
--------
Expand Down
56 changes: 53 additions & 3 deletions openmmtools/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ class ContextCache(object):
platform : simtk.openmm.Platform, optional
The OpenMM platform to use to create Contexts. If None, OpenMM
tries to select the fastest one available (default is None).
platform_properties : dict, optional
A dictionary of platform properties for the OpenMM platform.
Only valid if the platform is not None (default is None).
**kwargs
Parameters to pass to the underlying LRUCache constructor such
as capacity and time_to_live.
Expand Down Expand Up @@ -303,8 +306,10 @@ class ContextCache(object):
"""

def __init__(self, platform=None, **kwargs):
def __init__(self, platform=None, platform_properties=None, **kwargs):
self._validate_platform_properties(platform, platform_properties)
self._platform = platform
self._platform_properties = platform_properties
self._lru = LRUCache(**kwargs)

def __len__(self):
Expand All @@ -324,8 +329,18 @@ def platform(self):
def platform(self, new_platform):
if len(self._lru) > 0:
raise RuntimeError('Cannot change platform of a non-empty ContextCache')
if new_platform is None:
self._platform_properties = None
self._validate_platform_properties(new_platform, self._platform_properties)
self._platform = new_platform

def set_platform(self, new_platform, platform_properties=None):
if len(self._lru) > 0:
raise RuntimeError('Cannot change platform of a non-empty ContextCache')
self._validate_platform_properties(new_platform, platform_properties)
self._platform = new_platform
self._platform_properties = platform_properties

@property
def capacity(self):
"""The maximum number of Context cached.
Expand Down Expand Up @@ -429,7 +444,7 @@ def get_context(self, thermodynamic_state, integrator=None):
try:
context = self._lru[context_id]
except KeyError:
context = thermodynamic_state.create_context(integrator, self._platform)
context = thermodynamic_state.create_context(integrator, self._platform, self._platform_properties)
self._lru[context_id] = context
context_integrator = context.getIntegrator()

Expand All @@ -442,18 +457,24 @@ def get_context(self, thermodynamic_state, integrator=None):
return context, context_integrator

def __getstate__(self):
# this serialization format was introduced in openmmtools > 0.18.3 (pull request #437)
if self.platform is not None:
platform_serialization = self.platform.getName()
else:
platform_serialization = None
return dict(platform=platform_serialization, capacity=self.capacity,
time_to_live=self.time_to_live)
time_to_live=self.time_to_live, platform_properties=self._platform_properties)

def __setstate__(self, serialization):
# this serialization format was introduced in openmmtools > 0.18.3 (pull request #437)
if serialization['platform'] is None:
self._platform = None
else:
self._platform = openmm.Platform.getPlatformByName(serialization['platform'])
if not 'platform_properties' in serialization:
self._platform_properties = None
else:
self._platform_properties = serialization["platform_properties"]
self._lru = LRUCache(serialization['capacity'], serialization['time_to_live'])

# -------------------------------------------------------------------------
Expand Down Expand Up @@ -630,6 +651,35 @@ def _default_integrator_id(cls):
return cls._cached_default_integrator_id
_cached_default_integrator_id = None

@staticmethod
def _validate_platform_properties(platform=None, platform_properties=None):
"""Check if platform properties are valid for the platform; else raise ValueError."""
if platform_properties is None:
return True
if platform_properties is not None and platform is None:
raise ValueError("To set platform_properties, you need to also specify the platform.")
if not isinstance(platform_properties, dict):
raise ValueError("platform_properties must be a dictionary")
for key, value in platform_properties.items():
if not isinstance(value, str):
raise ValueError(
"All platform properties must be strings. You supplied {}: {} of type {}".format(
key, value, type(value)
)
)
# create a context to check if all properties are
dummy_system = openmm.System()
dummy_system.addParticle(1)
dummy_integrator = openmm.VerletIntegrator(1.0*unit.femtoseconds)
try:
openmm.Context(dummy_system, dummy_integrator, platform, platform_properties)
return True
except Exception as e:
if "Illegal property name" in str(e):
raise ValueError("Invalid platform property for this platform. {}".format(e))
else:
raise e


# =============================================================================
# DUMMY CONTEXT CACHE
Expand Down
Loading

0 comments on commit 6e07232

Please sign in to comment.