Skip to content

Commit

Permalink
Protect StatefulPool from class methods
Browse files Browse the repository at this point in the history
StatefulPool doesn't support batching over with class methods because it
always passes `G` as the first argument to the worker function. If one of the
`run_` methods in StatefulPool is called with a class method it can lead to
a silent lock-up of the pool, which is very difficult to debug.

Note: this bug does not appear unless n_parallel > 1
  • Loading branch information
ryanjulian committed Mar 27, 2018
1 parent b3a2899 commit 4b41d10
Showing 1 changed file with 32 additions and 5 deletions.
37 changes: 32 additions & 5 deletions rllab/sampler/stateful_pool.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@


from joblib.pool import MemmapingPool
import inspect
import multiprocessing as mp
from rllab.misc import logger
import pyprind
import time
import traceback
import sys

from joblib.pool import MemmapingPool
import pyprind

from rllab.misc import logger


class ProgBarCounter(object):
def __init__(self, total_count):
Expand Down Expand Up @@ -68,6 +69,12 @@ def run_each(self, runner, args_list=None):
in the args_list, if any
:return:
"""
assert not inspect.ismethod(runner), (
"run_each() cannot run a class method. Please ensure that runner is"
" a function with the prototype def foo(G, ...), where G is an "
"object of type rllab.sampler.stateful_pool.SharedGlobal"
)

if args_list is None:
args_list = [tuple()] * self.n_parallel
assert len(args_list) == self.n_parallel
Expand All @@ -83,6 +90,12 @@ def run_each(self, runner, args_list=None):
return [runner(self.G, *args_list[0])]

def run_map(self, runner, args_list):
assert not inspect.ismethod(runner), (
"run_map() cannot run a class method. Please ensure that runner is "
"a function with the prototype 'def foo(G, ...)', where G is an "
"object of type rllab.sampler.stateful_pool.SharedGlobal"
)

if self.n_parallel > 1:
return self.pool.map(_worker_run_map, [(runner, args) for args in args_list])
else:
Expand All @@ -92,6 +105,13 @@ def run_map(self, runner, args_list):
return ret

def run_imap_unordered(self, runner, args_list):
assert not inspect.ismethod(runner), (
"run_imap_unordered() cannot run a class method. Please ensure that"
"runner is a function with the prototype 'def foo(G, ...)', where "
"G is an object of type rllab.sampler.stateful_pool.SharedGlobal"
)


if self.n_parallel > 1:
for x in self.pool.imap_unordered(_worker_run_map, [(runner, args) for args in args_list]):
yield x
Expand All @@ -117,6 +137,13 @@ def collect_once(G):
:param threshold:
:return:
"""
assert not inspect.ismethod(collect_once), (
"run_collect() cannot run a class method. Please ensure that "
"collect_once is a function with the prototype 'def foo(G, ...)', "
"where G is an object of type "
"rllab.sampler.stateful_pool.SharedGlobal"
)

if args is None:
args = tuple()
if self.pool:
Expand Down

0 comments on commit 4b41d10

Please sign in to comment.