Skip to content

Commit

Permalink
Merge branch '0.5' of github.com:aio-libs/aiozmq into 0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Aug 22, 2014
2 parents f111b72 + ca5cdef commit df3b0f2
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 10 deletions.
14 changes: 10 additions & 4 deletions aiozmq/rpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ class Service(asyncio.AbstractServer):
def __init__(self, loop, proto):
self._loop = loop
self._proto = proto
self._closing = False

@property
def transport(self):
Expand All @@ -120,9 +119,9 @@ def transport(self):
return transport

def close(self):
if self._closing:
if self._proto.closing:
return
self._closing = True
self._proto.closing = True
if self._proto.transport is None:
return
self._proto.transport.close()
Expand All @@ -143,6 +142,8 @@ def __init__(self, loop, *, translation_table=None):
self.transport = None
self.done_waiters = []
self.packer = _Packer(translation_table=translation_table)
self.pending_waiters = set()
self.closing = False

def connection_made(self, transport):
self.transport = transport
Expand All @@ -152,6 +153,12 @@ def connection_lost(self, exc):
for waiter in self.done_waiters:
waiter.set_result(None)

def add_pending(self, fut):
self.pending_waiters.add(fut)

def discard_pending(self, fut):
self.pending_waiters.discard(fut)


class _BaseServerProtocol(_BaseProtocol):

Expand All @@ -164,7 +171,6 @@ def __init__(self, loop, handler, *,
self.handler = handler
self.log_exceptions = log_exceptions
self.exclude_log_exceptions = exclude_log_exceptions
self.pending_waiters = set()

def connection_lost(self, exc):
super().connection_lost(exc)
Expand Down
4 changes: 2 additions & 2 deletions aiozmq/rpc/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def msg_received(self, data):
else:
if asyncio.iscoroutinefunction(func):
fut = asyncio.async(func(*args, **kwargs), loop=self.loop)
self.pending_waiters.add(fut)
self.add_pending(fut)
else:
fut = asyncio.Future(loop=self.loop)
try:
Expand All @@ -140,7 +140,7 @@ def msg_received(self, data):
name=name, args=args, kwargs=kwargs))

def process_call_result(self, fut, *, name, args, kwargs):
self.pending_waiters.discard(fut)
self.discard_pending(fut)
try:
if fut.result() is not None:
logger.warning("Pipeline handler %r returned not None", name)
Expand Down
4 changes: 2 additions & 2 deletions aiozmq/rpc/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def msg_received(self, data):
else:
if asyncio.iscoroutinefunction(func):
fut = asyncio.async(func(*args, **kwargs), loop=self.loop)
self.pending_waiters.add(fut)
self.add_pending(fut)
else:
fut = asyncio.Future(loop=self.loop)
try:
Expand All @@ -215,7 +215,7 @@ def msg_received(self, data):
name=name, args=args, kwargs=kwargs))

def process_call_result(self, fut, *, name, args, kwargs):
self.pending_waiters.discard(fut)
self.discard_pending(fut)
try:
if fut.result() is not None:
logger.warning("PubSub handler %r returned not None", name)
Expand Down
4 changes: 2 additions & 2 deletions aiozmq/rpc/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def msg_received(self, data):
else:
if asyncio.iscoroutinefunction(func):
fut = asyncio.async(func(*args, **kwargs), loop=self.loop)
self.pending_waiters.add(fut)
self.add_pending(fut)
else:
fut = asyncio.Future(loop=self.loop)
try:
Expand All @@ -268,7 +268,7 @@ def msg_received(self, data):
def process_call_result(self, fut, *, req_id, pre, name,
args, kwargs,
return_annotation=None):
self.pending_waiters.discard(fut)
self.discard_pending(fut)
self.try_log(fut, name, args, kwargs)
if self.transport is None:
return
Expand Down
19 changes: 19 additions & 0 deletions tests/rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def cancelled_fut(self):
def exc2(self, arg):
raise ValueError("bad arg", arg)

@aiozmq.rpc.method
@asyncio.coroutine
def not_so_fast(self):
yield from asyncio.sleep(0.001, loop=self.loop)
return 'ok'


class Protocol(aiozmq.ZmqProtocol):

Expand Down Expand Up @@ -629,6 +635,19 @@ def communicate():

self.loop.run_until_complete(communicate())

def xtest_wait_closed(self):
client, server = self.make_rpc_pair()

@asyncio.coroutine
def go():
f1 = client.call.not_so_fast()
client.close()
client.wait_closed()
r = yield from f1
self.assertEqual('ok', r)

self.loop.run_until_complete(go())


class LoopRpcTests(unittest.TestCase, RpcTestsMixin):

Expand Down

0 comments on commit df3b0f2

Please sign in to comment.