Skip to content

Commit

Permalink
replaced get --> scheduler in leaflet and use 'multiprocessing' as st…
Browse files Browse the repository at this point in the history
…ring

- modified tests so that they use default scheduler
- supplying n_jobs
- NOTE: test_leaflets() failes for n_jobs=2; this NEEDS TO BE FIXED in a
        separate PR; right now this is marked as XFAIL
  • Loading branch information
orbeckst committed Oct 29, 2018
1 parent cc11410 commit 5103271
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
18 changes: 9 additions & 9 deletions pmda/leaflet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#
# Released under the GNU Public Licence, v2 or any higher version
"""
LeafletFInder Analysis tool --- :mod:`pmda.leaflet`
==========================================================
LeafletFinder Analysis tool --- :mod:`pmda.leaflet`
===================================================
This module contains parallel versions of analysis tasks in
:mod:`MDAnalysis.analysis.leaflet`.
Expand All @@ -27,7 +27,7 @@
from scipy.spatial import cKDTree

import MDAnalysis as mda
from dask import distributed, multiprocessing
from dask import distributed
from joblib import cpu_count

from .parallel import ParallelAnalysisBase, Timing
Expand Down Expand Up @@ -59,8 +59,8 @@ class LeafletFinder(ParallelAnalysisBase):
At the moment, this class has far fewer features than the serial
version :class:`MDAnalysis.analysis.leaflet.LeafletFinder`.
This version offers Leaflet Finder algorithm 4 ("Tree-based Nearest
Neighbor and Parallel-Connected Com- ponents (Tree-Search)") in
This version offers LeafletFinder algorithm 4 ("Tree-based Nearest
Neighbor and Parallel-Connected Components (Tree-Search)") in
[Paraskevakos2018]_.
Currently, periodic boundaries are not taken into account.
Expand Down Expand Up @@ -254,10 +254,10 @@ def run(self,
"""
if scheduler is None:
scheduler = multiprocessing
scheduler = 'multiprocessing'

if n_jobs == -1:
if scheduler == multiprocessing:
if scheduler == 'multiprocessing':
n_jobs = cpu_count()
elif isinstance(scheduler, distributed.Client):
n_jobs = len(scheduler.ncores())
Expand All @@ -269,8 +269,8 @@ def run(self,
with timeit() as b_universe:
universe = mda.Universe(self._top, self._traj)

scheduler_kwargs = {'get': scheduler.get}
if scheduler == multiprocessing:
scheduler_kwargs = {'scheduler': scheduler}
if scheduler == 'multiprocessing':
scheduler_kwargs['num_workers'] = n_jobs

start, stop, step = self._trajectory.check_slice_indices(
Expand Down
14 changes: 9 additions & 5 deletions pmda/test/test_leaflet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import MDAnalysis
from MDAnalysisTests.datafiles import Martini_membrane_gro
from MDAnalysisTests.datafiles import GRO_MEMPROT, XTC_MEMPROT
from dask import multiprocessing
from pmda import leaflet
import numpy as np

Expand Down Expand Up @@ -39,24 +38,29 @@ def correct_values(self):
def correct_values_single_frame(self):
return [np.arange(1, 2150, 12), np.arange(2521, 4670, 12)]

def test_leaflet(self, universe, correct_values):
# XFAIL for 2 jobs needs to be fixed!
@pytest.mark.parametrize('n_jobs', (1, pytest.mark.xfail(2)))
def test_leaflet(self, universe, correct_values, n_jobs):
lipid_heads = universe.select_atoms("name P and resname POPG")
universe.trajectory.rewind()
leaflets = leaflet.LeafletFinder(universe, lipid_heads)
leaflets.run(scheduler=multiprocessing, n_jobs=1)
leaflets.run(n_jobs=n_jobs)
results = [atoms.indices for atomgroup in leaflets.results
for atoms in atomgroup]
[assert_almost_equal(x, y, err_msg="error: leaflets should match " +
"test values") for x, y in
zip(results, correct_values)]

@pytest.mark.parametrize('n_jobs', (1, 2))
def test_leaflet_single_frame(self,
u_one_frame,
correct_values_single_frame):
correct_values_single_frame,
n_jobs):
lipid_heads = u_one_frame.select_atoms("name PO4")
u_one_frame.trajectory.rewind()
leaflets = leaflet.LeafletFinder(u_one_frame,
lipid_heads).run(start=0, stop=1)
lipid_heads).run(start=0, stop=1,
n_jobs=n_jobs)

assert_almost_equal([atoms.indices for atomgroup in leaflets.results
for atoms in atomgroup],
Expand Down

0 comments on commit 5103271

Please sign in to comment.