From be788243c10ab54a5853b6d6ab6b3adaa920cdb7 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 21 Aug 2014 12:29:25 +0300 Subject: [PATCH 1/3] Refactor pending waiters --- aiozmq/rpc/base.py | 8 +++++++- aiozmq/rpc/pipeline.py | 4 ++-- aiozmq/rpc/pubsub.py | 4 ++-- aiozmq/rpc/rpc.py | 4 ++-- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/aiozmq/rpc/base.py b/aiozmq/rpc/base.py index 3fbe325..328ec45 100644 --- a/aiozmq/rpc/base.py +++ b/aiozmq/rpc/base.py @@ -143,6 +143,7 @@ def __init__(self, loop, *, translation_table=None): self.transport = None self.done_waiters = [] self.packer = _Packer(translation_table=translation_table) + self.pending_waiters = set() def connection_made(self, transport): self.transport = transport @@ -152,6 +153,12 @@ def connection_lost(self, exc): for waiter in self.done_waiters: waiter.set_result(None) + def add_pending(self, fut): + self.pending_waiters.add(fut) + + def discard_pending(self, fut): + self.pending_waiters.discard(fut) + class _BaseServerProtocol(_BaseProtocol): @@ -164,7 +171,6 @@ def __init__(self, loop, handler, *, self.handler = handler self.log_exceptions = log_exceptions self.exclude_log_exceptions = exclude_log_exceptions - self.pending_waiters = set() def connection_lost(self, exc): super().connection_lost(exc) diff --git a/aiozmq/rpc/pipeline.py b/aiozmq/rpc/pipeline.py index 850586e..4c664c7 100644 --- a/aiozmq/rpc/pipeline.py +++ b/aiozmq/rpc/pipeline.py @@ -129,7 +129,7 @@ def msg_received(self, data): else: if asyncio.iscoroutinefunction(func): fut = asyncio.async(func(*args, **kwargs), loop=self.loop) - self.pending_waiters.add(fut) + self.add_pending(fut) else: fut = asyncio.Future(loop=self.loop) try: @@ -140,7 +140,7 @@ def msg_received(self, data): name=name, args=args, kwargs=kwargs)) def process_call_result(self, fut, *, name, args, kwargs): - self.pending_waiters.discard(fut) + self.discard_pending(fut) try: if fut.result() is not None: logger.warning("Pipeline handler %r returned not None", name) diff --git a/aiozmq/rpc/pubsub.py b/aiozmq/rpc/pubsub.py index 4f40303..6f05c1f 100644 --- a/aiozmq/rpc/pubsub.py +++ b/aiozmq/rpc/pubsub.py @@ -204,7 +204,7 @@ def msg_received(self, data): else: if asyncio.iscoroutinefunction(func): fut = asyncio.async(func(*args, **kwargs), loop=self.loop) - self.pending_waiters.add(fut) + self.add_pending(fut) else: fut = asyncio.Future(loop=self.loop) try: @@ -215,7 +215,7 @@ def msg_received(self, data): name=name, args=args, kwargs=kwargs)) def process_call_result(self, fut, *, name, args, kwargs): - self.pending_waiters.discard(fut) + self.discard_pending(fut) try: if fut.result() is not None: logger.warning("PubSub handler %r returned not None", name) diff --git a/aiozmq/rpc/rpc.py b/aiozmq/rpc/rpc.py index c6d00c7..d969448 100644 --- a/aiozmq/rpc/rpc.py +++ b/aiozmq/rpc/rpc.py @@ -251,7 +251,7 @@ def msg_received(self, data): else: if asyncio.iscoroutinefunction(func): fut = asyncio.async(func(*args, **kwargs), loop=self.loop) - self.pending_waiters.add(fut) + self.add_pending(fut) else: fut = asyncio.Future(loop=self.loop) try: @@ -268,7 +268,7 @@ def msg_received(self, data): def process_call_result(self, fut, *, req_id, pre, name, args, kwargs, return_annotation=None): - self.pending_waiters.discard(fut) + self.discard_pending(fut) self.try_log(fut, name, args, kwargs) if self.transport is None: return From 092dad4e3c40ceef7861a39bed8f29bd6388b5f5 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 21 Aug 2014 12:38:11 +0300 Subject: [PATCH 2/3] Continue rafactoring --- aiozmq/rpc/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aiozmq/rpc/base.py b/aiozmq/rpc/base.py index 328ec45..93607ea 100644 --- a/aiozmq/rpc/base.py +++ b/aiozmq/rpc/base.py @@ -105,7 +105,6 @@ class Service(asyncio.AbstractServer): def __init__(self, loop, proto): self._loop = loop self._proto = proto - self._closing = False @property def transport(self): @@ -120,9 +119,9 @@ def transport(self): return transport def close(self): - if self._closing: + if self._proto.closing: return - self._closing = True + self._proto.closing = True if self._proto.transport is None: return self._proto.transport.close() @@ -144,6 +143,7 @@ def __init__(self, loop, *, translation_table=None): self.done_waiters = [] self.packer = _Packer(translation_table=translation_table) self.pending_waiters = set() + self.closing = False def connection_made(self, transport): self.transport = transport From ca5cdef314a8023ed0ac235891f747a8dff05d2e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 21 Aug 2014 14:30:28 +0300 Subject: [PATCH 3/3] Add disabled test --- tests/rpc_test.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/rpc_test.py b/tests/rpc_test.py index 0ab1623..2682c47 100644 --- a/tests/rpc_test.py +++ b/tests/rpc_test.py @@ -73,6 +73,12 @@ def cancelled_fut(self): def exc2(self, arg): raise ValueError("bad arg", arg) + @aiozmq.rpc.method + @asyncio.coroutine + def not_so_fast(self): + yield from asyncio.sleep(0.001, loop=self.loop) + return 'ok' + class Protocol(aiozmq.ZmqProtocol): @@ -629,6 +635,19 @@ def communicate(): self.loop.run_until_complete(communicate()) + def xtest_wait_closed(self): + client, server = self.make_rpc_pair() + + @asyncio.coroutine + def go(): + f1 = client.call.not_so_fast() + client.close() + client.wait_closed() + r = yield from f1 + self.assertEqual('ok', r) + + self.loop.run_until_complete(go()) + class LoopRpcTests(unittest.TestCase, RpcTestsMixin):