diff --git a/execnet/gateway_base.py b/execnet/gateway_base.py index 99cc3f60..bdf57260 100644 --- a/execnet/gateway_base.py +++ b/execnet/gateway_base.py @@ -64,8 +64,29 @@ def reraise(cls, val, tb): # def log_extra(*msg): # f.write(" ".join([str(x) for x in msg]) + "\n") +if sys.version_info >= (3, 7): + from contextlib import nullcontext +else: + class nullcontext(object): + """Context manager that does no additional processing. + Used as a stand-in for a normal context manager, when a particular + block of code is only sometimes used with a normal context manager: + cm = optional_cm if condition else nullcontext() + with cm: + # Perform operation, using optional_cm if condition is True + """ + + def __init__(self, enter_result=None): + self.enter_result = enter_result -class EmptySemaphore: + def __enter__(self): + return self.enter_result + + def __exit__(self, *excinfo): + pass + + +class EmptySemaphore(nullcontext): acquire = release = lambda self: None @@ -238,13 +259,16 @@ class WorkerPool(object): when the pool received a trigger_shutdown(). """ - def __init__(self, execmodel, hasprimary=False): - """by default allow unlimited number of spawns.""" +< + def __init__(self, execmodel, hasprimary=False, size=None): + """ by default allow unlimited number of spawns. """ + self.execmodel = execmodel self._running_lock = self.execmodel.Lock() self._running = set() self._shuttingdown = False self._waitall_events = [] + self._semaphore = self.execmodel.Semaphore(size) if hasprimary: if self.execmodel.backend != "thread": raise ValueError("hasprimary=True requires thread model") @@ -307,7 +331,7 @@ def spawn(self, func, *args, **kwargs): of the given func(*args, **kwargs). """ reply = Reply((func, args, kwargs), self.execmodel) - with self._running_lock: + with self._semaphore, self._running_lock: if self._shuttingdown: raise ValueError("pool is shutting down") self._running.add(reply) diff --git a/testing/test_threadpool.py b/testing/test_threadpool.py index d4694368..96c399b6 100644 --- a/testing/test_threadpool.py +++ b/testing/test_threadpool.py @@ -74,7 +74,6 @@ def test_waitfinish_on_reply(pool): pytest.raises(ZeroDivisionError, reply.get) -@pytest.mark.xfail(reason="WorkerPool does not implement limited size") def test_limited_size(execmodel): pool = WorkerPool(execmodel, size=1) q = execmodel.queue.Queue()