Skip to content

Commit

Permalink
Use ExpectedResponseChannels for pongs
Browse files Browse the repository at this point in the history
  • Loading branch information
gsalgado committed Jan 6, 2020
1 parent 3c56ee8 commit c28c9e5
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 72 deletions.
100 changes: 46 additions & 54 deletions p2p/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ def __init__(self,
self.this_node = Node(self.pubkey, address)
self.routing = RoutingTable(self.this_node)
self.enr_response_channels = ExpectedResponseChannels[Tuple[ENR, Hash32]]()
self.pong_channels = ExpectedResponseChannels[Tuple[Hash32, int]]()

This comment has been minimized.

Copy link
@pipermerriam

pipermerriam Jan 6, 2020

I think there is a bug in ExpectedResponseChannels. The send side needs to be inserted into the dictionary of tracked channels before the async with statement for the receive channel. Otherwise, two competing coroutines that are both racing for the same node id will overwrite each other's channels.

I think the solution for this is simply to insert the send side of the channel before entering the async with block.

This comment has been minimized.

Copy link
@gsalgado

gsalgado Jan 7, 2020

Author Owner

Yeah, good catch!

# FIXME: Use a persistent EnrDb implementation.
self._enr_db = MemoryEnrDb(default_identity_scheme_registry)
# FIXME: Use a concurrency-safe EnrDb implementation.
self._enr_db_lock = trio.Lock()
self.pong_callbacks = CallbackManager()
self.ping_callbacks = CallbackManager()
self.neighbours_callbacks = CallbackManager()
self.parity_pong_tokens: Dict[Hash32, Hash32] = {}
Expand Down Expand Up @@ -214,6 +214,14 @@ def update_routing_table(self, node: NodeAPI) -> None:
# replacement cache.
self.manager.run_task(self.bond, eviction_candidate)

async def wait_pong_from(self, remote: NodeAPI) -> Tuple[Hash32, int]:
send_chan: trio.MemorySendChannel[Tuple[Hash32, int]]

This comment has been minimized.

Copy link
@pipermerriam

pipermerriam Jan 6, 2020

Curious about reducing this boilerplate with a local helper:

TChannelElement = TypeVar('TChannelElement')

def open_channel(sample_element: TChannelElement, buffer_size: int = 0) -> Tuple[trio.MemorySendChannel[TChannelElement], trio.MemoryReceiveChannel[TChannelElement]]:
    ...
    return send_chan, recv_chan

The idea is that we can do something like this in our code:

element: Tuple[Hash32, int] = (DUMMY_HASH, 0)
send_chan, recv_chan = open_channel(element, 0)

I'm not sold on this being a great idea but thought I'd drop it and see what anyone thinks.

This comment has been minimized.

Copy link
@gsalgado

gsalgado Jan 7, 2020

Author Owner

I kinda liked your idea but there was just no way I could make it work -- either mypy complained with some cryptic error msg or I got a runtime error because MemorySendChannel is not a generic type. However, because of that I decided to look around for an alternative and ended up finding python-trio/trio#908, so we now have a much simpler/cleaner alternative:

send_chan, recv_chan = trio.open_memory_channel[Tuple[Hash32, int]](1)
recv_chan: trio.MemoryReceiveChannel[Tuple[Hash32, int]]
send_chan, recv_chan = trio.open_memory_channel(1)
with trio.move_on_after(constants.KADEMLIA_REQUEST_TIMEOUT):
return await self.pong_channels.wait_receive(remote, send_chan, recv_chan)
return None, None

This comment has been minimized.

Copy link
@pipermerriam

pipermerriam Jan 6, 2020

Just pointing out that this does return value checking which is an anti-pattern we'd like to try to avoid. I think this has previously been noted but it always jumps out at me so I like to comment on it to be sure. If there's already an issue for this maybe we can link to it from here with a comment.

This comment has been minimized.

Copy link
@gsalgado

gsalgado Jan 7, 2020

Author Owner

I've addressed this


async def bond(self, node: NodeAPI) -> bool:
"""Bond with the given node.
Expand All @@ -226,16 +234,20 @@ async def bond(self, node: NodeAPI) -> bool:
return False

token = await self.send_ping_v4(node)
log_version = "v4"

try:
got_pong = await self.wait_pong_v4(node, token)
received_token, enr_seq = await self.wait_pong_from(node)
except AlreadyWaitingDiscoveryResponse:
self.logger.debug("bonding failed, awaiting %s pong from %s", log_version, node)
self.logger.debug(f"Bonding failed, already waiting pong from {node}")
return False

if not got_pong:
self.logger.debug("bonding failed, didn't receive %s pong from %s", log_version, node)
if received_token is None:

This comment has been minimized.

Copy link
@pipermerriam

pipermerriam Jan 6, 2020

I'm 👍 on doing this in steps but it would be ideal if we could get the whole API updated so that you just did all of this in one call.

await self.do_ping_pong(node)
# raise various exceptions to signal the different failure modes.
# - Timeout
# - TokenMismatch

This comment has been minimized.

Copy link
@gsalgado

gsalgado Jan 7, 2020

Author Owner

So, these are all internal APIs, only used here (in bond(), which is also internal), and wait_pong_from() is really just opening the channels and calling self.pong_channels.wait_receive() with a timeout, so I'm wondering if it wouldn't make more sense to unroll it here instead... I see no advantage in having separate (trivial, nearly one-liners) methods that raise exceptions for all possible failure modes when this is the only place they'd be handled. How do you feel about that?

self.logger.debug(f"Bonding failed, didn't receive pong from {node} with token {token}")
self.routing.remove_node(node)
return False
elif received_token != token:
self.logger.debug(
f"Bonding with {node} failed, expected pong with token {token!r}, "
f"but got {received_token!r}")
self.routing.remove_node(node)
return False

Expand Down Expand Up @@ -328,38 +340,6 @@ async def wait_ping(self, remote: NodeAPI) -> bool:

return got_ping

async def wait_pong_v4(self, remote: NodeAPI, token: Hash32) -> bool:
"""Wait for a pong from the given remote containing the given token.
This coroutine adds a callback to pong_callbacks and yields control until the given event
is set or a timeout (k_request_timeout) occurs. At that point it returns whether or not
a pong was received with the given token.
"""

event = trio.Event()

def callback(received_token: Hash32) -> None:
if received_token == token:
event.set()
else:
self.logger.warning("Pong from %s with wrong token: %s", received_token)

with self.pong_callbacks.acquire(remote, callback):
with trio.move_on_after(constants.KADEMLIA_REQUEST_TIMEOUT) as cancel_scope:
await event.wait()
if cancel_scope.cancelled_caught:
got_pong = False
self.logger.debug2(
'timed out waiting for pong from %s (token == %s)',
remote,
encode_hex(token),
)
else:
got_pong = True
self.logger.debug2('got expected pong with token %s', encode_hex(token))

return got_pong

async def wait_neighbours(self, remote: NodeAPI) -> Tuple[NodeAPI, ...]:
"""Wait for a neihgbours packet from the given node.
Expand Down Expand Up @@ -413,7 +393,8 @@ async def _find_node(node_id: int, remote: NodeAPI) -> Tuple[NodeAPI, ...]:
all_candidates = tuple(c for c in candidates if c not in nodes_seen)
candidates = tuple(
c for c in all_candidates
if (not self.ping_callbacks.locked(c) and not self.pong_callbacks.locked(c))
if (not self.ping_callbacks.locked(c) and
not self.pong_channels.already_waiting_for(c))
)
self.logger.debug2("got %s new candidates", len(candidates))
# Add new candidates to nodes_seen so that we don't attempt to bond with failing ones
Expand Down Expand Up @@ -504,7 +485,8 @@ async def bootstrap(self) -> None:
(self.bond, n)
for n
in self.bootstrap_nodes
if (not self.ping_callbacks.locked(n) and not self.pong_callbacks.locked(n))
if (not self.ping_callbacks.locked(n) and
not self.pong_channels.already_waiting_for(n))
)
bonded = await trio_utils.gather(*bonding_queries)
if not any(bonded):
Expand Down Expand Up @@ -562,7 +544,7 @@ async def recv_pong_v4(self, node: NodeAPI, payload: Sequence[Any], _: Hash32) -
if self._is_msg_expired(expiration):
return
self.logger.debug2('<<< pong (v4) from %s (token == %s)', node, encode_hex(token))
self.process_pong_v4(node, token)
await self.process_pong_v4(node, token, enr_seq)

async def recv_neighbours_v4(self, node: NodeAPI, payload: Sequence[Any], _: Hash32) -> None:
# The neighbours payload should have 2 elements: nodes, expiration
Expand Down Expand Up @@ -653,7 +635,11 @@ async def recv_enr_response(
enr.validate_signature()
self.logger.debug(
"Received ENR %s with expected response token: %s", enr, encode_hex(token))
await channel.send((enr, token))
try:
await channel.send((enr, token))
except trio.BrokenResourceError:
# This means the receiver has already closed, probably because it timed out.
pass

def send_enr_request(self, node: NodeAPI) -> Hash32:
message = self.send(node, CMD_ENR_REQUEST, [_get_msg_expiration()])
Expand Down Expand Up @@ -724,13 +710,7 @@ def process_neighbours(self, remote: NodeAPI, neighbours: List[NodeAPI]) -> None
else:
callback(neighbours)

def process_pong_v4(self, remote: NodeAPI, token: Hash32) -> None:
"""Process a pong packet.
Pong packets should only be received as a response to a ping, so the actual processing is
left to the callback from pong_callbacks, which is added (and removed after it's done
or timed out) in wait_pong().
"""
async def process_pong_v4(self, remote: NodeAPI, token: Hash32, enr_seq: int) -> None:
# XXX: This hack is needed because there are lots of parity 1.10 nodes out there that send
# the wrong token on pong msgs (https://github.com/paritytech/parity/issues/8038). We
# should get rid of this once there are no longer too many parity 1.10 nodes out there.
Expand All @@ -744,11 +724,20 @@ def process_pong_v4(self, remote: NodeAPI, token: Hash32) -> None:
lambda val: val != token, self.parity_pong_tokens)

try:
callback = self.pong_callbacks.get_callback(remote)
channel = self.pong_channels.get_channel(remote)
except KeyError:
self.logger.debug('unexpected v4 pong from %s (token == %s)', remote, encode_hex(token))
else:
callback(token)
# This is probably a Node which changed its identity since it was added to the DHT,
# causing us to expect a pong signed with a certain key when in fact it's using
# a different one. Another possibility is that the pong came after we've given up
# waiting.
self.logger.debug(f'Unexpected pong from {remote} with token {encode_hex(token)}')
return

try:
await channel.send((token, enr_seq))
except trio.BrokenResourceError:
# This means the receiver has already closed, probably because it timed out.
pass

def process_ping(self, remote: NodeAPI, hash_: Hash32) -> None:
"""Process a received ping packet.
Expand Down Expand Up @@ -1123,6 +1112,9 @@ class ExpectedResponseChannels(Generic[TMsg]):
def __init__(self) -> None:
self._channels: Dict[NodeAPI, trio.abc.SendChannel[TMsg]] = {}

def already_waiting_for(self, remote: NodeAPI) -> bool:
return remote in self._channels

def get_channel(self, remote: NodeAPI) -> trio.abc.SendChannel[TMsg]:
return self._channels[remote]

Expand Down
34 changes: 16 additions & 18 deletions tests-trio/p2p-trio/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def test_wait_ping(nursery, echo):


@pytest.mark.trio
async def test_wait_pong(nursery):
async def test_wait_pong_from(nursery):
service = MockDiscoveryService([])
us = service.this_node
node = NodeFactory()
Expand All @@ -224,23 +224,18 @@ async def test_wait_pong(nursery):
pong_msg_payload = [us.address.to_endpoint(), token, expiration]
nursery.start_soon(service.recv_pong_v4, node, pong_msg_payload, b'')

got_pong = await service.wait_pong_v4(node, token)
received_token, _ = await service.wait_pong_from(node)

assert got_pong
# Ensure wait_pong() cleaned up after itself.
pingid = service._mkpingid(token, node)
assert pingid not in service.pong_callbacks
assert received_token == token
# Ensure wait_pong_from() cleaned up after itself.
assert not service.pong_channels.already_waiting_for(node)

# If the remote node echoed something different than what we expected, wait_pong() would
# timeout.
wrong_token = b"foo"
pong_msg_payload = [us.address.to_endpoint(), wrong_token, expiration]
nursery.start_soon(service.recv_pong_v4, node, pong_msg_payload, b'')

got_pong = await service.wait_pong_v4(node, token)
# If the pong doesn't arrive, wait_pong_from() should hit its internal timeout and return
# (None, None).
received_token, _ = await service.wait_pong_from(node)

assert not got_pong
assert pingid not in service.pong_callbacks
assert received_token is None
assert not service.pong_channels.already_waiting_for(node)


@pytest.mark.trio
Expand Down Expand Up @@ -284,10 +279,13 @@ async def send_ping(node):
monkeypatch.setattr(discovery, 'send_ping_v4', send_ping)

# Pretend we get a pong from the node we are bonding with.
async def wait_pong_v4(remote, t) -> bool:
return t == token and remote == node
async def wait_pong_from(remote) -> bool:
if remote == node:
return token, 0
else:
return None, None

monkeypatch.setattr(discovery, 'wait_pong_v4', wait_pong_v4)
monkeypatch.setattr(discovery, 'wait_pong_from', wait_pong_from)

bonded = await discovery.bond(node)

Expand Down

0 comments on commit c28c9e5

Please sign in to comment.