diff --git a/rllab/sampler/stateful_pool.py b/rllab/sampler/stateful_pool.py index beba84cf4..512dd2dc0 100644 --- a/rllab/sampler/stateful_pool.py +++ b/rllab/sampler/stateful_pool.py @@ -1,13 +1,14 @@ - - -from joblib.pool import MemmapingPool +import inspect import multiprocessing as mp -from rllab.misc import logger -import pyprind import time import traceback import sys +from joblib.pool import MemmapingPool +import pyprind + +from rllab.misc import logger + class ProgBarCounter(object): def __init__(self, total_count): @@ -68,6 +69,12 @@ def run_each(self, runner, args_list=None): in the args_list, if any :return: """ + assert not inspect.ismethod(runner), ( + "run_each() cannot run a class method. Please ensure that runner is" + " a function with the prototype def foo(G, ...), where G is an " + "object of type rllab.sampler.stateful_pool.SharedGlobal" + ) + if args_list is None: args_list = [tuple()] * self.n_parallel assert len(args_list) == self.n_parallel @@ -83,6 +90,12 @@ def run_each(self, runner, args_list=None): return [runner(self.G, *args_list[0])] def run_map(self, runner, args_list): + assert not inspect.ismethod(runner), ( + "run_map() cannot run a class method. Please ensure that runner is " + "a function with the prototype 'def foo(G, ...)', where G is an " + "object of type rllab.sampler.stateful_pool.SharedGlobal" + ) + if self.n_parallel > 1: return self.pool.map(_worker_run_map, [(runner, args) for args in args_list]) else: @@ -92,6 +105,13 @@ def run_map(self, runner, args_list): return ret def run_imap_unordered(self, runner, args_list): + assert not inspect.ismethod(runner), ( + "run_imap_unordered() cannot run a class method. Please ensure that" + "runner is a function with the prototype 'def foo(G, ...)', where " + "G is an object of type rllab.sampler.stateful_pool.SharedGlobal" + ) + + if self.n_parallel > 1: for x in self.pool.imap_unordered(_worker_run_map, [(runner, args) for args in args_list]): yield x @@ -117,6 +137,13 @@ def collect_once(G): :param threshold: :return: """ + assert not inspect.ismethod(collect_once), ( + "run_collect() cannot run a class method. Please ensure that " + "collect_once is a function with the prototype 'def foo(G, ...)', " + "where G is an object of type " + "rllab.sampler.stateful_pool.SharedGlobal" + ) + if args is None: args = tuple() if self.pool: