Skip to content

Commit

Permalink
enable new dask 0.18 scheduler selection idioms
Browse files Browse the repository at this point in the history
- someone should check with dask. It seems a bit brittle
- fix tests maybe
- update documentation

fixes #17
  • Loading branch information
kain88-de committed Oct 30, 2018
1 parent 8686767 commit d71eb98
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 45 deletions.
9 changes: 6 additions & 3 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Released under the GNU Public Licence, v2 or any higher version

from dask import distributed
import dask
import pytest


Expand All @@ -24,9 +25,11 @@ def client(tmpdir_factory, request):
lc.close()


@pytest.fixture(scope='session', params=('distributed', 'multiprocessing', 'single-threaded'))
def scheduler(request, client):
    """Install a dask scheduler for the duration of a test.

    Parametrized over the distributed client and the named
    'multiprocessing' and 'single-threaded' single-machine schedulers.
    Instead of returning the scheduler (the pre-dask-0.18 idiom), the
    chosen scheduler is set globally via ``dask.config.set`` so that
    ``run()`` picks it up from the dask configuration.
    """
    if request.param == 'distributed':
        # use the session-wide distributed client fixture as scheduler
        arg = client
    else:
        # a named single-machine scheduler string
        arg = request.param
    # dask >= 0.18 idiom: scheduler selection via the global config;
    # yield inside the context manager so it stays active for the test
    with dask.config.set(scheduler=arg):
        yield
6 changes: 3 additions & 3 deletions docs/userguide/parallelization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Internally, this uses the multiprocessing `scheduler`_ of dask. If you
want to make use of more advanced scheduler features or scale your
analysis to multiple nodes, e.g., in an HPC (high performance
computing) environment, then use the :mod:`distributed` scheduler, as
described next.
described next. If ``n_jobs == 1``, a single-threaded scheduler is used.

.. _`scheduler`:
https://dask.pydata.org/en/latest/scheduler-overview.html
Expand Down Expand Up @@ -58,7 +58,7 @@ use the :ref:`RMSD example<example-parallel-rmsd>`):

.. code:: python
rmsd_ana = rms.RMSD(u.atoms, ref.atoms).run(scheduler=client)
rmsd_ana = rms.RMSD(u.atoms, ref.atoms).run()
Because the local cluster contains 8 workers, the RMSD trajectory
analysis will be parallelized over 8 trajectory segments.
Expand All @@ -78,7 +78,7 @@ analysis :meth:`~pmda.parallel.ParallelAnalysisBase.run` method:
import distributed
client = distributed.Client('192.168.0.1:8786')
rmsd_ana = rms.RMSD(u.atoms, ref.atoms).run(scheduler=client)
rmsd_ana = rms.RMSD(u.atoms, ref.atoms).run()
In this way one can spread an analysis task over many different nodes.

Expand Down
2 changes: 1 addition & 1 deletion docs/userguide/pmda_classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ are provided as keyword arguments:

set up the parallel analysis

.. method:: run(n_jobs=-1, scheduler=None)
.. method:: run(n_jobs=-1)

perform parallel analysis; see :ref:`parallelization`
for explanation of the arguments
Expand Down
43 changes: 30 additions & 13 deletions pmda/leaflet.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def run(self,
start=None,
stop=None,
step=None,
scheduler=None,
n_jobs=-1,
cutoff=15.0):
"""Perform the calculation
Expand All @@ -244,35 +243,53 @@ def run(self,
stop frame of analysis
step : int, optional
number of frames to skip between each analysed frame
scheduler : dask scheduler, optional
Use dask scheduler, defaults to multiprocessing. This can be used
to spread work to a distributed scheduler
n_jobs : int, optional
number of tasks to start, if `-1` use number of logical cpu cores.
This argument will be ignored when the distributed scheduler is
used
"""
if scheduler is None:
# are we using a distributed scheduler or should we use multiprocessing?
scheduler = dask.config.get('scheduler', None)
if scheduler is None and client is None:
scheduler = 'multiprocessing'
elif scheduler is None:
# maybe we can grab a global worker
try:
from dask import distributed
scheduler = distributed.worker.get_client()
except ValueError:
pass
except ImportError:
pass

if n_jobs == -1:
n_jobs = cpu_count()

# we could not find a global scheduler to use and we ask for a single
# job. Therefore we run this on the single threaded scheduler for
# debugging.
if scheduler is None and n_jobs == 1:
scheduler = 'single-threaded'

if n_blocks is None:
if scheduler == 'multiprocessing':
n_jobs = cpu_count()
n_blocks = n_jobs
elif isinstance(scheduler, distributed.Client):
n_jobs = len(scheduler.ncores())
n_blocks = len(scheduler.ncores())
else:
raise ValueError(
"Couldn't guess ideal number of jobs from scheduler."
"Please provide `n_jobs` in call to method.")

with timeit() as b_universe:
universe = mda.Universe(self._top, self._traj)
n_blocks = 1
warnings.warn(
"Couldn't guess ideal number of blocks from scheduler. Set n_blocks=1"
"Please provide `n_blocks` in call to method.")

scheduler_kwargs = {'scheduler': scheduler}
if scheduler == 'multiprocessing':
scheduler_kwargs['num_workers'] = n_jobs

with timeit() as b_universe:
universe = mda.Universe(self._top, self._traj)

start, stop, step = self._trajectory.check_slice_indices(
start, stop, step)
with timeit() as total:
Expand Down
29 changes: 21 additions & 8 deletions pmda/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from six.moves import range

import MDAnalysis as mda
from dask import distributed
from dask.delayed import delayed
from joblib import cpu_count
import numpy as np
Expand Down Expand Up @@ -267,7 +266,6 @@ def run(self,
start=None,
stop=None,
step=None,
scheduler=None,
n_jobs=1,
n_blocks=None):
"""Perform the calculation
Expand All @@ -280,9 +278,6 @@ def run(self,
stop frame of analysis
step : int, optional
number of frames to skip between each analysed frame
scheduler : dask scheduler, optional
Use dask scheduler, defaults to multiprocessing. This can be used
to spread work to a distributed scheduler
n_jobs : int, optional
number of jobs to start, if `-1` use number of logical cpu cores.
This argument will be ignored when the distributed scheduler is
Expand All @@ -292,20 +287,38 @@ def run(self,
to n_jobs or number of available workers in scheduler.
"""
if scheduler is None:
# are we using a distributed scheduler or should we use multiprocessing?
scheduler = dask.config.get('scheduler', None)
if scheduler is None and client is None:
scheduler = 'multiprocessing'
elif scheduler is None:
# maybe we can grab a global worker
try:
from dask import distributed
scheduler = distributed.worker.get_client()
except ValueError:
pass
except ImportError:
pass

if n_jobs == -1:
n_jobs = cpu_count()

# we could not find a global scheduler to use and we ask for a single
# job. Therefore we run this on the single threaded scheduler for
# debugging.
if scheduler is None and n_jobs == 1:
scheduler = 'single-threaded'

if n_blocks is None:
if scheduler == 'multiprocessing':
n_blocks = n_jobs
elif isinstance(scheduler, distributed.Client):
n_blocks = len(scheduler.ncores())
else:
raise ValueError(
"Couldn't guess ideal number of blocks from scheduler."
n_blocks = 1
warnings.warn(
"Couldn't guess ideal number of blocks from scheduler. Set n_blocks=1"
"Please provide `n_blocks` in call to method.")

scheduler_kwargs = {'scheduler': scheduler}
Expand Down
6 changes: 3 additions & 3 deletions pmda/test/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ def test_AnalysisFromFunction(scheduler):
u = mda.Universe(PSF, DCD)
step = 2
ana1 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
step=step, scheduler=scheduler
step=step
)
ana2 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
step=step, scheduler=scheduler
step=step
)
ana3 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
step=step, scheduler=scheduler
step=step
)

results = []
Expand Down
15 changes: 1 addition & 14 deletions pmda/test/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ def analysis():
return ana


def test_wrong_scheduler(analysis):
with pytest.raises(ValueError):
analysis.run(scheduler=2)


@pytest.mark.parametrize('n_jobs', (1, 2))
def test_all_frames(analysis, n_jobs):
analysis.run(n_jobs=n_jobs)
Expand All @@ -91,16 +86,8 @@ def test_no_frames(analysis, n_jobs):
assert analysis.timing.universe == 0


@pytest.fixture(scope='session', params=('distributed', 'multiprocessing'))
def scheduler(request, client):
    """Return a scheduler argument for ``run(scheduler=...)``.

    Returns either the session-wide distributed ``client`` fixture or
    the scheduler name ``'multiprocessing'``, depending on the fixture
    parameter.
    """
    # NOTE(review): duplicates the ``scheduler`` fixture in conftest.py;
    # the commit removes this copy in favour of the conftest version.
    if request.param == 'distributed':
        return client
    else:
        return request.param


def test_scheduler(analysis, scheduler):
    """Run the analysis once for each scheduler set up by the fixture.

    With the dask >= 0.18 idiom the scheduler is taken from the global
    dask configuration (installed by the ``scheduler`` fixture), so
    ``run()`` is called without the removed ``scheduler`` keyword — the
    old ``analysis.run(scheduler=scheduler)`` call would now raise
    ``TypeError``.
    """
    analysis.run()


def test_nframes_less_nblocks_warning(analysis):
Expand Down

0 comments on commit d71eb98

Please sign in to comment.