diff --git a/aiozmq/rpc/base.py b/aiozmq/rpc/base.py index 3fbe325..93607ea 100644 --- a/aiozmq/rpc/base.py +++ b/aiozmq/rpc/base.py @@ -105,7 +105,6 @@ class Service(asyncio.AbstractServer): def __init__(self, loop, proto): self._loop = loop self._proto = proto - self._closing = False @property def transport(self): @@ -120,9 +119,9 @@ def transport(self): return transport def close(self): - if self._closing: + if self._proto.closing: return - self._closing = True + self._proto.closing = True if self._proto.transport is None: return self._proto.transport.close() @@ -143,6 +142,8 @@ def __init__(self, loop, *, translation_table=None): self.transport = None self.done_waiters = [] self.packer = _Packer(translation_table=translation_table) + self.pending_waiters = set() + self.closing = False def connection_made(self, transport): self.transport = transport @@ -152,6 +153,12 @@ def connection_lost(self, exc): for waiter in self.done_waiters: waiter.set_result(None) + def add_pending(self, fut): + self.pending_waiters.add(fut) + + def discard_pending(self, fut): + self.pending_waiters.discard(fut) + class _BaseServerProtocol(_BaseProtocol): @@ -164,7 +171,6 @@ def __init__(self, loop, handler, *, self.handler = handler self.log_exceptions = log_exceptions self.exclude_log_exceptions = exclude_log_exceptions - self.pending_waiters = set() def connection_lost(self, exc): super().connection_lost(exc) diff --git a/aiozmq/rpc/pipeline.py b/aiozmq/rpc/pipeline.py index 850586e..4c664c7 100644 --- a/aiozmq/rpc/pipeline.py +++ b/aiozmq/rpc/pipeline.py @@ -129,7 +129,7 @@ def msg_received(self, data): else: if asyncio.iscoroutinefunction(func): fut = asyncio.async(func(*args, **kwargs), loop=self.loop) - self.pending_waiters.add(fut) + self.add_pending(fut) else: fut = asyncio.Future(loop=self.loop) try: @@ -140,7 +140,7 @@ def msg_received(self, data): name=name, args=args, kwargs=kwargs)) def process_call_result(self, fut, *, name, args, kwargs): - self.pending_waiters.discard(fut) + self.discard_pending(fut) try: if fut.result() is not None: logger.warning("Pipeline handler %r returned not None", name) diff --git a/aiozmq/rpc/pubsub.py b/aiozmq/rpc/pubsub.py index 4f40303..6f05c1f 100644 --- a/aiozmq/rpc/pubsub.py +++ b/aiozmq/rpc/pubsub.py @@ -204,7 +204,7 @@ def msg_received(self, data): else: if asyncio.iscoroutinefunction(func): fut = asyncio.async(func(*args, **kwargs), loop=self.loop) - self.pending_waiters.add(fut) + self.add_pending(fut) else: fut = asyncio.Future(loop=self.loop) try: @@ -215,7 +215,7 @@ def msg_received(self, data): name=name, args=args, kwargs=kwargs)) def process_call_result(self, fut, *, name, args, kwargs): - self.pending_waiters.discard(fut) + self.discard_pending(fut) try: if fut.result() is not None: logger.warning("PubSub handler %r returned not None", name) diff --git a/aiozmq/rpc/rpc.py b/aiozmq/rpc/rpc.py index c6d00c7..d969448 100644 --- a/aiozmq/rpc/rpc.py +++ b/aiozmq/rpc/rpc.py @@ -251,7 +251,7 @@ def msg_received(self, data): else: if asyncio.iscoroutinefunction(func): fut = asyncio.async(func(*args, **kwargs), loop=self.loop) - self.pending_waiters.add(fut) + self.add_pending(fut) else: fut = asyncio.Future(loop=self.loop) try: @@ -268,7 +268,7 @@ def msg_received(self, data): def process_call_result(self, fut, *, req_id, pre, name, args, kwargs, return_annotation=None): - self.pending_waiters.discard(fut) + self.discard_pending(fut) self.try_log(fut, name, args, kwargs) if self.transport is None: return diff --git a/tests/rpc_test.py b/tests/rpc_test.py index 0ab1623..2682c47 100644 --- a/tests/rpc_test.py +++ b/tests/rpc_test.py @@ -73,6 +73,12 @@ def cancelled_fut(self): def exc2(self, arg): raise ValueError("bad arg", arg) + @aiozmq.rpc.method + @asyncio.coroutine + def not_so_fast(self): + yield from asyncio.sleep(0.001, loop=self.loop) + return 'ok' + class Protocol(aiozmq.ZmqProtocol): @@ -629,6 +635,19 @@ def communicate(): self.loop.run_until_complete(communicate()) + def xtest_wait_closed(self): + client, server = self.make_rpc_pair() + + @asyncio.coroutine + def go(): + f1 = client.call.not_so_fast() + client.close() + client.wait_closed() + r = yield from f1 + self.assertEqual('ok', r) + + self.loop.run_until_complete(go()) + class LoopRpcTests(unittest.TestCase, RpcTestsMixin):