From 5103271c4f4cc9604de905f11fee27b696c7a579 Mon Sep 17 00:00:00 2001 From: Oliver Beckstein Date: Mon, 29 Oct 2018 01:04:49 -0700 Subject: [PATCH] replaced get --> scheduler in leaflet and use 'multiprocessing' as string - modified tests so that they use default scheduler - supplying n_jobs - NOTE: test_leaflet() fails for n_jobs=2; this NEEDS TO BE FIXED in a separate PR; right now this is marked as XFAIL --- pmda/leaflet.py | 18 +++++++++--------- pmda/test/test_leaflet.py | 14 +++++++++----- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/pmda/leaflet.py b/pmda/leaflet.py index 3f879ae5..bc9e94b9 100644 --- a/pmda/leaflet.py +++ b/pmda/leaflet.py @@ -7,8 +7,8 @@ # # Released under the GNU Public Licence, v2 or any higher version """ -LeafletFInder Analysis tool --- :mod:`pmda.leaflet` -========================================================== +LeafletFinder Analysis tool --- :mod:`pmda.leaflet` +=================================================== This module contains parallel versions of analysis tasks in :mod:`MDAnalysis.analysis.leaflet`. @@ -27,7 +27,7 @@ from scipy.spatial import cKDTree import MDAnalysis as mda -from dask import distributed, multiprocessing +from dask import distributed from joblib import cpu_count from .parallel import ParallelAnalysisBase, Timing @@ -59,8 +59,8 @@ class LeafletFinder(ParallelAnalysisBase): At the moment, this class has far fewer features than the serial version :class:`MDAnalysis.analysis.leaflet.LeafletFinder`. - This version offers Leaflet Finder algorithm 4 ("Tree-based Nearest - Neighbor and Parallel-Connected Com- ponents (Tree-Search)") in + This version offers LeafletFinder algorithm 4 ("Tree-based Nearest + Neighbor and Parallel-Connected Components (Tree-Search)") in [Paraskevakos2018]_. Currently, periodic boundaries are not taken into account.
@@ -254,10 +254,10 @@ def run(self, """ if scheduler is None: - scheduler = multiprocessing + scheduler = 'multiprocessing' if n_jobs == -1: - if scheduler == multiprocessing: + if scheduler == 'multiprocessing': n_jobs = cpu_count() elif isinstance(scheduler, distributed.Client): n_jobs = len(scheduler.ncores()) @@ -269,8 +269,8 @@ def run(self, with timeit() as b_universe: universe = mda.Universe(self._top, self._traj) - scheduler_kwargs = {'get': scheduler.get} - if scheduler == multiprocessing: + scheduler_kwargs = {'scheduler': scheduler} + if scheduler == 'multiprocessing': scheduler_kwargs['num_workers'] = n_jobs start, stop, step = self._trajectory.check_slice_indices( diff --git a/pmda/test/test_leaflet.py b/pmda/test/test_leaflet.py index cf16b57b..cf8aff29 100644 --- a/pmda/test/test_leaflet.py +++ b/pmda/test/test_leaflet.py @@ -7,7 +7,6 @@ import MDAnalysis from MDAnalysisTests.datafiles import Martini_membrane_gro from MDAnalysisTests.datafiles import GRO_MEMPROT, XTC_MEMPROT -from dask import multiprocessing from pmda import leaflet import numpy as np @@ -39,24 +38,29 @@ def correct_values(self): def correct_values_single_frame(self): return [np.arange(1, 2150, 12), np.arange(2521, 4670, 12)] - def test_leaflet(self, universe, correct_values): + # XFAIL for 2 jobs needs to be fixed! 
+ @pytest.mark.parametrize('n_jobs', (1, pytest.mark.xfail(2))) + def test_leaflet(self, universe, correct_values, n_jobs): lipid_heads = universe.select_atoms("name P and resname POPG") universe.trajectory.rewind() leaflets = leaflet.LeafletFinder(universe, lipid_heads) - leaflets.run(scheduler=multiprocessing, n_jobs=1) + leaflets.run(n_jobs=n_jobs) results = [atoms.indices for atomgroup in leaflets.results for atoms in atomgroup] [assert_almost_equal(x, y, err_msg="error: leaflets should match " + "test values") for x, y in zip(results, correct_values)] + @pytest.mark.parametrize('n_jobs', (1, 2)) def test_leaflet_single_frame(self, u_one_frame, - correct_values_single_frame): + correct_values_single_frame, + n_jobs): lipid_heads = u_one_frame.select_atoms("name PO4") u_one_frame.trajectory.rewind() leaflets = leaflet.LeafletFinder(u_one_frame, - lipid_heads).run(start=0, stop=1) + lipid_heads).run(start=0, stop=1, + n_jobs=n_jobs) assert_almost_equal([atoms.indices for atomgroup in leaflets.results for atoms in atomgroup],