Skip to content

Commit

Permalink
Typed requests
Browse files Browse the repository at this point in the history
  • Loading branch information
flcl42 committed Dec 6, 2024
1 parent 903a777 commit 9e6fac5
Show file tree
Hide file tree
Showing 15 changed files with 215 additions and 44 deletions.
5 changes: 3 additions & 2 deletions src/libp2p/Libp2p.Core.Tests/ContextTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT

using Multiformats.Address;
using Nethermind.Libp2p.Core.Discovery;
using Nethermind.Libp2p.Stack;

namespace Nethermind.Libp2p.Core.Tests;
Expand All @@ -28,8 +29,8 @@ public async Task E2e()
TopProtocols = [tProto]
};

LocalPeer peer1 = new(new Identity(), protocolStackSettings);
LocalPeer peer2 = new(new Identity(), protocolStackSettings);
LocalPeer peer1 = new(new Identity(), new PeerStore(), protocolStackSettings);
LocalPeer peer2 = new(new Identity(), new PeerStore(), protocolStackSettings);

await peer1.StartListenAsync([new Multiaddress()]);
await peer2.StartListenAsync([new Multiaddress()]);
Expand Down
7 changes: 4 additions & 3 deletions src/libp2p/Libp2p.Core.TestsBase/E2e/TestBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT

using Microsoft.Extensions.Logging;
using Nethermind.Libp2p.Core.Discovery;
using Nethermind.Libp2p.Protocols;
using Nethermind.Libp2p.Stack;
using System.Collections.Concurrent;
Expand All @@ -25,18 +26,18 @@ .. additionalProtocols
}
}

public class TestPeerFactory(IProtocolStackSettings protocolStackSettings, ILoggerFactory? loggerFactory = null) : PeerFactory(protocolStackSettings)
public class TestPeerFactory(IProtocolStackSettings protocolStackSettings, PeerStore peerStore, ILoggerFactory? loggerFactory = null) : PeerFactory(protocolStackSettings, peerStore)
{
ConcurrentDictionary<PeerId, IPeer> peers = new();

public override IPeer Create(Identity? identity = default)
{
ArgumentNullException.ThrowIfNull(identity);
return peers.GetOrAdd(identity.PeerId, (p) => new TestLocalPeer(identity, protocolStackSettings, loggerFactory));
return peers.GetOrAdd(identity.PeerId, (p) => new TestLocalPeer(identity, protocolStackSettings, peerStore, loggerFactory));
}
}

internal class TestLocalPeer(Identity id, IProtocolStackSettings protocolStackSettings, ILoggerFactory? loggerFactory = null) : LocalPeer(id, protocolStackSettings, loggerFactory)
internal class TestLocalPeer(Identity id, IProtocolStackSettings protocolStackSettings, PeerStore peerStore, ILoggerFactory? loggerFactory = null) : LocalPeer(id, peerStore, protocolStackSettings, loggerFactory)
{
protected override async Task ConnectedTo(ISession session, bool isDialer)
{
Expand Down
11 changes: 6 additions & 5 deletions src/libp2p/Libp2p.Core.TestsBase/E2e/TestMuxerProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private async Task HandleRemote(IChannel downChannel, INewConnectionContext conn
uint chanId = Interlocked.Add(ref counter, 2);
logger?.LogDebug($"{context.Peer.Identity.PeerId}({chanId}): Sub-request {item.SelectedProtocol} {item.CompletionSource is not null} to call {connection.State.RemoteAddress.GetPeerId()}");

chans[chanId] = new MuxerChannel { Tcs = item.CompletionSource };
chans[chanId] = new MuxerChannel { Tcs = item.CompletionSource, Argument = item.Argument };
MuxerPacket response = new()
{
ChannelId = chanId,
Expand Down Expand Up @@ -170,7 +170,7 @@ private async Task HandleRemote(IChannel downChannel, INewConnectionContext conn
case MuxerPacketType.NewStreamResponse:
if (packet.Protocols.Any())
{
UpgradeOptions req = new() { SelectedProtocol = session.SubProtocols.FirstOrDefault(x => x.Id == packet.Protocols.First()), ModeOverride = UpgradeModeOverride.Dial };
UpgradeOptions req = new() { SelectedProtocol = session.SubProtocols.FirstOrDefault(x => x.Id == packet.Protocols.First()), CompletionSource = chans[packet.ChannelId].Tcs, Argument = chans[packet.ChannelId].Argument, ModeOverride = UpgradeModeOverride.Dial };
IChannel upChannel = session.Upgrade(req);
chans[packet.ChannelId].UpChannel = upChannel;
logger?.LogDebug($"{logPrefix}({packet.ChannelId}): Start upchanel with {req.SelectedProtocol}");
Expand Down Expand Up @@ -201,7 +201,7 @@ private async Task HandleRemote(IChannel downChannel, INewConnectionContext conn

if (chans[packet.ChannelId].LocalClosedWrites)
{
chans[packet.ChannelId].Tcs?.SetResult();
//chans[packet.ChannelId].Tcs?.SetResult(null);
_ = chans[packet.ChannelId].UpChannel?.CloseAsync();
chans.Remove(packet.ChannelId);
}
Expand Down Expand Up @@ -248,7 +248,7 @@ private Task HandleUpchannelData(IChannel downChannel, Dictionary<uint, MuxerCha
if (chans[channelId].RemoteClosedWrites)
{
logger?.LogDebug($"{logPrefix}({channelId}): Upchannel dial/listen complete");
chans[channelId].Tcs?.SetResult();
//chans[channelId].Tcs?.SetResult(null);
_ = upChannel.CloseAsync();
chans.Remove(channelId);
}
Expand Down Expand Up @@ -278,8 +278,9 @@ private Task HandleUpchannelData(IChannel downChannel, Dictionary<uint, MuxerCha
class MuxerChannel
{
public IChannel? UpChannel { get; set; }
public TaskCompletionSource? Tcs { get; set; }
public TaskCompletionSource<object?>? Tcs { get; set; }
public bool RemoteClosedWrites { get; set; }
public bool LocalClosedWrites { get; set; }
public object? Argument { get; internal set; }
}
}
10 changes: 10 additions & 0 deletions src/libp2p/Libp2p.Core.TestsBase/LocalPeerStub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public Task<ISession> DialAsync(Multiaddress[] samePeerAddrs, CancellationToken
return Task.FromResult<ISession>(new TestRemotePeer(samePeerAddrs.First()));
}

public Task<ISession> DialAsync(PeerId peerId, CancellationToken token = default)
{
throw new NotImplementedException();
}

public Task DisconnectAsync()
{
return Task.CompletedTask;
Expand Down Expand Up @@ -60,6 +65,11 @@ public Task DialAsync<TProtocol>(CancellationToken token = default) where TProto
return Task.CompletedTask;
}

public Task<TResponse> DialAsync<TProtocol, TRequest, TResponse>(TRequest request, CancellationToken token = default) where TProtocol : ISessionProtocol<TRequest, TResponse>
{
throw new NotImplementedException();
}

public Task DisconnectAsync()
{
return Task.CompletedTask;
Expand Down
1 change: 1 addition & 0 deletions src/libp2p/Libp2p.Core/Discovery/PeerStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public PeerInfo GetPeerInfo(PeerId peerId)
public class PeerInfo
{
public ByteString? SignedPeerRecord { get; set; }
public string[]? SupportedProtocols { get; set; }
public HashSet<Multiaddress>? Addrs { get; set; }
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/libp2p/Libp2p.Core/IChannelFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ public record UpgradeOptions
{
public IProtocol? SelectedProtocol { get; init; }
public UpgradeModeOverride ModeOverride { get; init; }
public TaskCompletionSource? CompletionSource { get; init; }
public TaskCompletionSource<object?>? CompletionSource { get; init; }
public object? Argument { get; set; }
}

public enum UpgradeModeOverride
Expand Down
6 changes: 6 additions & 0 deletions src/libp2p/Libp2p.Core/IPeer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ public interface IPeer
Task<ISession> DialAsync(Multiaddress addr, CancellationToken token = default);
Task<ISession> DialAsync(Multiaddress[] samePeerAddrs, CancellationToken token = default);

/// <summary>
/// Find existing session or dial a peer if found in peer store
/// </summary>
Task<ISession> DialAsync(PeerId peerId, CancellationToken token = default);

Task StartListenAsync(Multiaddress[] addrs, CancellationToken token = default);

Task DisconnectAsync();
Expand All @@ -28,5 +33,6 @@ public interface ISession
{
Multiaddress RemoteAddress { get; }
Task DialAsync<TProtocol>(CancellationToken token = default) where TProtocol : ISessionProtocol;
Task<TResponse> DialAsync<TProtocol, TRequest, TResponse>(TRequest request, CancellationToken token = default) where TProtocol : ISessionProtocol<TRequest, TResponse>;
Task DisconnectAsync();
}
14 changes: 13 additions & 1 deletion src/libp2p/Libp2p.Core/IProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,20 @@ public interface IConnectionProtocol : IProtocol
Task DialAsync(IChannel downChannel, IConnectionContext context);
}

public interface ISessionProtocol : IProtocol

public interface ISessionListenerProtocol : IProtocol
{
Task ListenAsync(IChannel downChannel, ISessionContext context);
}

public interface ISessionProtocol : ISessionListenerProtocol
{
Task DialAsync(IChannel downChannel, ISessionContext context);
}

public class Void;

public interface ISessionProtocol<TRequest, TResponse> : ISessionListenerProtocol
{
Task<TResponse> DialAsync(IChannel downChannel, ISessionContext context, TRequest request);
}
122 changes: 117 additions & 5 deletions src/libp2p/Libp2p.Core/Peer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Microsoft.Extensions.Logging;
using Multiformats.Address;
using Multiformats.Address.Protocols;
using Nethermind.Libp2p.Core.Discovery;
using Nethermind.Libp2p.Core.Exceptions;
using Nethermind.Libp2p.Core.Extensions;
using Nethermind.Libp2p.Stack;
Expand All @@ -15,15 +16,17 @@ namespace Nethermind.Libp2p.Core;
public class LocalPeer : IPeer
{
protected readonly ILogger? _logger;
protected readonly PeerStore _peerStore;
protected readonly IProtocolStackSettings _protocolStackSettings;

Dictionary<object, TaskCompletionSource<Multiaddress>> listenerReadyTcs = new();
private ObservableCollection<Session> sessions { get; } = [];


public LocalPeer(Identity identity, IProtocolStackSettings protocolStackSettings, ILoggerFactory? loggerFactory = null)
public LocalPeer(Identity identity, PeerStore peerStore, IProtocolStackSettings protocolStackSettings, ILoggerFactory? loggerFactory = null)
{
Identity = identity;
_peerStore = peerStore;
_protocolStackSettings = protocolStackSettings;
_logger = loggerFactory?.CreateLogger($"peer-{identity.PeerId}");
}
Expand All @@ -50,19 +53,30 @@ public class Session(LocalPeer peer) : ISession

public async Task DialAsync<TProtocol>(CancellationToken token = default) where TProtocol : ISessionProtocol
{
TaskCompletionSource tcs = new();
TaskCompletionSource<object> tcs = new();
SubDialRequests.Add(new UpgradeOptions() { CompletionSource = tcs, SelectedProtocol = peer.GetProtocolInstance<TProtocol>() }, token);
await tcs.Task;
MarkAsConnected();
}

public async Task DialAsync(ISessionProtocol protocol, CancellationToken token = default)
{
TaskCompletionSource tcs = new();
TaskCompletionSource<object> tcs = new();
SubDialRequests.Add(new UpgradeOptions() { CompletionSource = tcs, SelectedProtocol = protocol }, token);
await tcs.Task;
MarkAsConnected();
}

public async Task<TResponse> DialAsync<TProtocol, TRequest, TResponse>(TRequest request, CancellationToken token = default) where TProtocol : ISessionProtocol<TRequest, TResponse>
{
TaskCompletionSource<object> tcs = new();
SubDialRequests.Add(new UpgradeOptions() { CompletionSource = tcs, SelectedProtocol = peer.GetProtocolInstance<TProtocol>(), Argument = request }, token);
await tcs.Task;
MarkAsConnected();
return (TResponse)tcs.Task.Result;
}


private CancellationTokenSource connectionTokenSource = new();

public Task DisconnectAsync()
Expand All @@ -77,7 +91,7 @@ public Task DisconnectAsync()

public TaskCompletionSource ConnectedTcs = new();
public Task Connected => ConnectedTcs.Task;
internal void MarkAsConnected() => ConnectedTcs.SetResult();
internal void MarkAsConnected() => ConnectedTcs?.TrySetResult();


internal IEnumerable<UpgradeOptions> GetRequestQueue() => SubDialRequests.GetConsumingEnumerable(ConnectionToken);
Expand Down Expand Up @@ -232,7 +246,7 @@ public async Task<ISession> DialAsync(Multiaddress[] addrs, CancellationToken to
cancellations[addr] = CancellationTokenSource.CreateLinkedTokenSource(token);
}

Task timeoutTask = Task.Delay(15_000, token);
Task timeoutTask = Task.Delay(1511111_000, token);
Task wait = await TaskHelper.FirstSuccess([timeoutTask, .. addrs.Select(addr => DialAsync(addr, cancellations[addr].Token))]);

if (wait == timeoutTask)
Expand Down Expand Up @@ -281,6 +295,25 @@ public async Task<ISession> DialAsync(Multiaddress addr, CancellationToken token
return session;
}

public Task<ISession> DialAsync(PeerId peerId, CancellationToken token = default)
{
ISession? existingSession = sessions.FirstOrDefault(s => s.State.RemotePeerId == peerId);

if (existingSession is not null)
{
return Task.FromResult(existingSession);
}

PeerStore.PeerInfo existingPeerInfo = _peerStore.GetPeerInfo(peerId);

if (existingPeerInfo?.Addrs is null)
{
throw new Libp2pException("Peer not found");
}

return DialAsync([.. existingPeerInfo.Addrs], token);
}

internal IChannel Upgrade(Session session, ProtocolRef parentProtocol, IProtocol? upgradeProtocol, UpgradeOptions? options, bool isListener)
{
if (_protocolStackSettings.Protocols is null)
Expand Down Expand Up @@ -321,6 +354,38 @@ internal IChannel Upgrade(Session session, ProtocolRef parentProtocol, IProtocol
break;
}
default:
if (isListener && top.Protocol is ISessionListenerProtocol listenerProtocol)
{
SessionContext ctx = new(this, session, top, isListener, options);
upgradeTask = listenerProtocol.ListenAsync(downChannel.Reverse, ctx);
break;
}

var genericInterface = top.Protocol.GetType().GetInterfaces()
.FirstOrDefault(i =>
i.IsGenericType &&
i.GetGenericTypeDefinition() == typeof(ISessionProtocol<,>));

if (genericInterface != null)
{
var genericArguments = genericInterface.GetGenericArguments();
var requestType = genericArguments[0];
var responseType = genericArguments[1];

if (options?.Argument is not null && !options.Argument.GetType().IsAssignableTo(requestType))
{
throw new ArgumentException($"Invalid request. Argument is of {options.Argument.GetType()} type which is not assignable to {requestType.FullName}");
}

// Dynamically invoke DialAsync
var dialAsyncMethod = genericInterface.GetMethod("DialAsync");
if (dialAsyncMethod != null)
{
SessionContext ctx = new(this, session, top, isListener, options);
upgradeTask = (Task)dialAsyncMethod.Invoke(top.Protocol, [downChannel.Reverse, ctx, options?.Argument])!;
break;
}
}
throw new Libp2pSetupException($"Protocol {top.Protocol} does not implement proper protocol interface");
}

Expand All @@ -345,6 +410,21 @@ internal IChannel Upgrade(Session session, ProtocolRef parentProtocol, IProtocol
return downChannel;
}

private static void MapToTaskCompletionSource(Task t, TaskCompletionSource<object?> tcs)
{
if (t.IsCompletedSuccessfully)
{
tcs.SetResult(t.GetType().GenericTypeArguments.Any() ? t.GetType().GetProperty("Result")!.GetValue(t) : null);
return;
}
if (t.IsCanceled)
{
tcs.SetCanceled();
return;
}
tcs.SetException(t.Exception!);
}

private static void MapToTaskCompletionSource(Task t, TaskCompletionSource tcs)
{
if (t.IsCompletedSuccessfully)
Expand Down Expand Up @@ -396,6 +476,38 @@ internal async Task Upgrade(Session session, IChannel parentChannel, ProtocolRef
break;
}
default:
if (isListener && top.Protocol is ISessionListenerProtocol listenerProtocol)
{
SessionContext ctx = new(this, session, top, isListener, options);
upgradeTask = listenerProtocol.ListenAsync(parentChannel, ctx);
break;
}

var genericInterface = top.Protocol.GetType().GetInterfaces()
.FirstOrDefault(i =>
i.IsGenericType &&
i.GetGenericTypeDefinition() == typeof(ISessionProtocol<,>));

if (genericInterface != null)
{
var genericArguments = genericInterface.GetGenericArguments();
var requestType = genericArguments[0];
var responseType = genericArguments[1];

if (options?.Argument is not null && !options.Argument.GetType().IsInstanceOfType(requestType))
{
throw new ArgumentException($"Invalid request. Argument is of {options.Argument.GetType()} type which is not assignable to {requestType.FullName}");
}

// Dynamically invoke DialAsync
var dialAsyncMethod = genericInterface.GetMethod("DialAsync");
if (dialAsyncMethod != null)
{
SessionContext ctx = new(this, session, top, isListener, options);
upgradeTask = (Task)dialAsyncMethod.Invoke(top.Protocol, [parentChannel, ctx, options?.Argument])!;
break;
}
}
throw new Libp2pSetupException($"Protocol {top.Protocol} does not implement proper protocol interface");
}

Expand Down
Loading

0 comments on commit 9e6fac5

Please sign in to comment.