diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5352e82e..413ad491 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,7 +36,11 @@ jobs: TEST_OPTS: -c ${{ env.BUILD_CONFIG }} --no-restore run: | dotnet test Libp2p.Core.Tests ${{ env.PACK_OPTS }} - #dotnet test Libp2p.Protocols.Multistream.Tests ${{ env.PACK_OPTS }} + dotnet test Libp2p.Protocols.Multistream.Tests ${{ env.PACK_OPTS }} dotnet test Libp2p.Protocols.Noise.Tests ${{ env.PACK_OPTS }} dotnet test Libp2p.Protocols.Pubsub.Tests ${{ env.PACK_OPTS }} dotnet test Libp2p.Protocols.Quic.Tests ${{ env.PACK_OPTS }} + dotnet test Libp2p.Protocols.Yamux.Tests ${{ env.PACK_OPTS }} + dotnet test Libp2p.E2eTests ${{ env.PACK_OPTS }} + dotnet test Libp2p.Protocols.Pubsub.E2eTests ${{ env.PACK_OPTS }} + dotnet test Libp2p.Protocols.PubsubPeerDiscovery.E2eTests ${{ env.PACK_OPTS }} diff --git a/README.md b/README.md index cac0d299..e8d2b180 100644 --- a/README.md +++ b/README.md @@ -40,14 +40,14 @@ The target is to provide a performant well-tested implementation of a wide range | Protocol | Version | Status | |--------------------|--------------------|-----------------| | TCP | tcp | ✅ | -| QUIC | quic-v1 | ✅ | +| QUIC | quic-v1 | 🚧 | | multistream-select | /multistream/1.0.0 | ✅ | | plaintext | /plaintext/2.0.0 | ✅ | | noise | /noise | ✅ | | tls | /tls/1.0.0 | 🚧 | | WebTransport | | ⬜ help wanted | | yamux | /yamux/1.0.0 | ✅ | -| Circuit Relay | /libp2p/circuit/relay/0.2.0/* | ⬜ help wanted | +| Circuit Relay | /libp2p/circuit/relay/0.2.0/* | 🚧 | | hole punching | | ⬜ help wanted | | **Application layer** | Identify | /ipfs/id/1.0.0 | ✅ | @@ -55,11 +55,11 @@ The target is to provide a performant well-tested implementation of a wide range | pubsub | /floodsub/1.0.0 | ✅ | | | /meshsub/1.0.0 | ✅ | | | /meshsub/1.1.0 | 🚧 | -| | /meshsub/1.2.0 | ⬜ | +| | /meshsub/1.2.0 | 🚧 | | **Discovery** | mDns | basic | ✅ | | | DNS-SD | 🚧 | -| [discv5](https://github.com/Pier-Two/Lantern.Discv5) | 5.1 | 🚧 help wanted | +| [discv5](https://github.com/Pier-Two/Lantern.Discv5) (wrapper) | 5.1 | 🚧 help wanted | ⬜ - not yet implemented
🚧 - work in progress
diff --git a/src/libp2p/Directory.Packages.props b/src/libp2p/Directory.Packages.props index 9e92e867..40d639fe 100644 --- a/src/libp2p/Directory.Packages.props +++ b/src/libp2p/Directory.Packages.props @@ -23,7 +23,7 @@ - + @@ -34,4 +34,4 @@ - \ No newline at end of file + diff --git a/src/libp2p/Libp2p.Core.Tests/TaskHelperTests.cs b/src/libp2p/Libp2p.Core.Tests/TaskHelperTests.cs new file mode 100644 index 00000000..772740ff --- /dev/null +++ b/src/libp2p/Libp2p.Core.Tests/TaskHelperTests.cs @@ -0,0 +1,52 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Nethermind.Libp2p.Core.Extensions; + +namespace Nethermind.Libp2p.Core.Tests; +internal class TaskHelperTests +{ + + [Test] + public async Task Test_AllExceptions_RaiseAggregateException() + { + TaskCompletionSource tcs1 = new(); + TaskCompletionSource tcs2 = new(); + TaskCompletionSource tcs3 = new(); + + Task t = TaskHelper.FirstSuccess(tcs1.Task, tcs2.Task, tcs3.Task); + + tcs1.SetException(new Exception()); + tcs2.SetException(new Exception()); + tcs3.SetException(new Exception()); + + await t.ContinueWith((t) => + { + Assert.Multiple(() => + { + Assert.That(t.IsFaulted, Is.True); + Assert.That(t.Exception?.InnerException, Is.TypeOf()); + Assert.That((t.Exception?.InnerException as AggregateException)?.InnerExceptions, Has.Count.EqualTo(3)); + }); + }); + } + + [Test] + public async Task Test_SingleSuccess_ReturnsCompletedTask() + { + TaskCompletionSource tcs1 = new(); + TaskCompletionSource tcs2 = new(); + TaskCompletionSource tcs3 = new(); + + Task t = TaskHelper.FirstSuccess(tcs1.Task, tcs2.Task, tcs3.Task); + + tcs1.SetException(new Exception()); + tcs2.SetException(new Exception()); + _ = Task.Delay(100).ContinueWith(t => tcs3.SetResult(true)); + + Task result = await t; + + Assert.That(result, Is.EqualTo(tcs3.Task)); + Assert.That((result as Task)!.Result, Is.EqualTo(true)); + } +} diff --git a/src/libp2p/Libp2p.Core.TestsBase/E2e/ChannelBus.cs b/src/libp2p/Libp2p.Core.TestsBase/E2e/ChannelBus.cs index e7db3c7d..ef083406 100644 --- a/src/libp2p/Libp2p.Core.TestsBase/E2e/ChannelBus.cs +++ b/src/libp2p/Libp2p.Core.TestsBase/E2e/ChannelBus.cs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: MIT using Microsoft.Extensions.Logging; -using Newtonsoft.Json; using System.Threading.Channels; namespace Nethermind.Libp2p.Core.TestsBase.E2e; @@ -30,7 +29,7 @@ public async IAsyncEnumerable GetIncomingRequests(PeerId serverId) logger?.LogDebug($"Listen {serverId}"); - await foreach (var item in col.Reader.ReadAllAsync()) + await foreach (ClientChannel item in col.Reader.ReadAllAsync()) { logger?.LogDebug($"New request from {item.Client} to {serverId}"); yield return item.Channel; diff --git a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestBuilder.cs b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestBuilder.cs index 91f51f44..7912768d 100644 --- a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestBuilder.cs +++ b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestBuilder.cs @@ -1,13 +1,45 @@ // SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited // SPDX-License-Identifier: MIT +using Microsoft.Extensions.Logging; +using Nethermind.Libp2p.Core.Discovery; +using Nethermind.Libp2p.Protocols; +using System.Collections.Concurrent; + namespace Nethermind.Libp2p.Core.TestsBase.E2e; -public class TestBuilder(ChannelBus? commmonBus = null, IServiceProvider? serviceProvider = null) : PeerFactoryBuilderBase(serviceProvider) +public class TestBuilder(IServiceProvider? serviceProvider = null) : PeerFactoryBuilderBase(serviceProvider) +{ + protected override ProtocolRef[] BuildStack(IEnumerable additionalProtocols) + { + ProtocolRef root = Get(); + + Connect([root], + [ + Get(), + Get(), + .. additionalProtocols + ]); + + return [root]; + } +} + +public class TestPeerFactory(IProtocolStackSettings protocolStackSettings, PeerStore peerStore, ILoggerFactory? loggerFactory = null) : PeerFactory(protocolStackSettings, peerStore) +{ + ConcurrentDictionary peers = new(); + + public override ILocalPeer Create(Identity? identity = default) + { + ArgumentNullException.ThrowIfNull(identity); + return peers.GetOrAdd(identity.PeerId, (p) => new TestLocalPeer(identity, protocolStackSettings, peerStore, loggerFactory)); + } +} + +internal class TestLocalPeer(Identity id, IProtocolStackSettings protocolStackSettings, PeerStore peerStore, ILoggerFactory? loggerFactory = null) : LocalPeer(id, peerStore, protocolStackSettings, loggerFactory) { - protected override ProtocolStack BuildStack() + protected override async Task ConnectedTo(ISession session, bool isDialer) { - return Over(new TestMuxerProtocol(commmonBus ?? new ChannelBus(), new TestContextLoggerFactory())) - .AddAppLayerProtocol(); + await session.DialAsync(); } } diff --git a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestLocalPeer.cs b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestLocalPeer.cs deleted file mode 100644 index d09db191..00000000 --- a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestLocalPeer.cs +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Multiformats.Address; - -namespace Nethermind.Libp2p.Core.TestsBase.E2e; - -internal class TestLocalPeer(Identity id) : ILocalPeer -{ - public Identity Identity { get => id; set => throw new NotImplementedException(); } - public Multiaddress Address { get => $"/p2p/{id.PeerId}"; set => throw new NotImplementedException(); } - - public Task DialAsync(Multiaddress addr, CancellationToken token = default) - { - throw new NotImplementedException(); - } - - public Task ListenAsync(Multiaddress addr, CancellationToken token = default) - { - throw new NotImplementedException(); - } -} diff --git a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestMuxerProtocol.cs b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestMuxerProtocol.cs index 4ead3772..10fd051d 100644 --- a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestMuxerProtocol.cs +++ b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestMuxerProtocol.cs @@ -1,84 +1,115 @@ using Google.Protobuf; using Microsoft.Extensions.Logging; +using Multiformats.Address; using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Core.Dto; +using Nethermind.Libp2p.Core.Exceptions; using Nethermind.Libp2p.Core.TestsBase.Dto; using Nethermind.Libp2p.Core.TestsBase.E2e; using Org.BouncyCastle.Utilities.Encoders; using System.Buffers; -class TestMuxerProtocol(ChannelBus bus, ILoggerFactory? loggerFactory = null) : IProtocol +class TestMuxerProtocol(ChannelBus bus, ILoggerFactory? loggerFactory = null) : ITransportProtocol { private const string id = "test-muxer"; private readonly ILogger? logger = loggerFactory?.CreateLogger(id); public string Id => id; + public static Multiaddress[] GetDefaultAddresses(PeerId peerId) => [$"/p2p/{peerId}"]; + public static bool IsAddressMatch(Multiaddress addr) => true; - public async Task DialAsync(IChannel downChannel, IChannelFactory? upChannelFactory, IPeerContext context) + public async Task DialAsync(ITransportContext context, Multiaddress remoteAddr, CancellationToken token) { - logger?.LogDebug($"{context.LocalPeer.Identity.PeerId}: Dial async"); - context.Connected(context.RemotePeer); - await Task.Run(() => HandleRemote(bus.Dial(context.LocalPeer.Identity.PeerId, context.RemotePeer.Address.GetPeerId()!), upChannelFactory!, context)); + logger?.LogDebug($"{context.Peer.Identity.PeerId}: Dial async"); + + //await Task.Run(async () => + //{ + IChannel chan = bus.Dial(context.Peer.Identity.PeerId, remoteAddr.GetPeerId()!); + using INewConnectionContext connection = context.CreateConnection(); + connection.State.RemoteAddress = remoteAddr; + + await HandleRemote(chan, connection, context); + //}); } - public async Task ListenAsync(IChannel downChannel, IChannelFactory? upChannelFactory, IPeerContext context) + public async Task ListenAsync(ITransportContext context, Multiaddress listenAddr, CancellationToken token) { - context.ListenerReady(); - logger?.LogDebug($"{context.LocalPeer.Identity.PeerId}: Listen async"); - await foreach (var item in bus.GetIncomingRequests(context.LocalPeer.Identity.PeerId)) + context.ListenerReady(listenAddr); + logger?.LogDebug($"{context.Peer.Identity.PeerId}: Listen async"); + await foreach (IChannel item in bus.GetIncomingRequests(context.Peer.Identity.PeerId)) { - logger?.LogDebug($"{context.LocalPeer.Identity.PeerId}: Listener handles new con"); - _ = HandleRemote(item, upChannelFactory!, context, true); + _ = Task.Run(async () => + { + INewConnectionContext connection = context.CreateConnection(); + logger?.LogDebug($"{context.Peer.Identity.PeerId}: Listener handles new con"); + try + { + await HandleRemote(item, connection, context, true); + } + catch (SessionExistsException) + { + logger?.LogDebug($"{context.Peer.Identity.PeerId}: Listener rejected inititation of a redundant session"); + } + catch (Exception e) + { + logger?.LogError(e, $"{context.Peer.Identity.PeerId}: Listener exception"); + } + }, token); } } - private async Task HandleRemote(IChannel downChannel, IChannelFactory upChannelFactory, IPeerContext context, bool isListen = false) + private async Task HandleRemote(IChannel downChannel, INewConnectionContext connection, ITransportContext context, bool isListen = false) { uint counter = isListen ? 1u : 0u; Dictionary chans = []; - string peer = ""; - context = context.Fork(); + PublicKey? remotePublicKey; + PeerId? remotePeerId; if (isListen) { - peer = await downChannel.ReadLineAsync(); - await downChannel.WriteLineAsync(context.LocalPeer.Identity.PeerId!.ToString()); - logger?.LogDebug($"{context.LocalPeer.Identity.PeerId}: Listener handles remote {peer}"); + remotePublicKey = await downChannel.ReadPrefixedProtobufAsync(PublicKey.Parser); + remotePeerId = new PeerId(remotePublicKey); + await downChannel.WriteSizeAndProtobufAsync(context.Peer.Identity.PublicKey); + logger?.LogDebug($"{context.Peer.Identity.PeerId}: Listener handles remote {remotePeerId}"); } else { - await downChannel.WriteLineAsync(context.LocalPeer.Identity.PeerId!.ToString()); - peer = await downChannel.ReadLineAsync(); - logger?.LogDebug($"{context.LocalPeer.Identity.PeerId}: Dialer handles remote {peer}"); + await downChannel.WriteSizeAndProtobufAsync(context.Peer.Identity.PublicKey); + remotePublicKey = await downChannel.ReadPrefixedProtobufAsync(PublicKey.Parser); + remotePeerId = new PeerId(remotePublicKey); + logger?.LogDebug($"{context.Peer.Identity.PeerId}: Dialer handles remote {remotePeerId}"); } - context.RemotePeer.Address = $"/p2p/{peer}"; + connection.State.RemotePublicKey = remotePublicKey; + connection.State.RemoteAddress = $"/p2p/{remotePeerId}"; + using INewSessionContext? session = connection.UpgradeToSession(); - string logPrefix = $"{context.LocalPeer.Identity.PeerId}<>{peer}"; + string logPrefix = $"{context.Peer.Identity.PeerId}<>{remotePeerId}"; - _ = Task.Run(async () => + _ = Task.Run(() => { - foreach (var item in context.SubDialRequests.GetConsumingEnumerable()) + foreach (UpgradeOptions item in session.DialRequests) { uint chanId = Interlocked.Add(ref counter, 2); - logger?.LogDebug($"{context.LocalPeer.Identity.PeerId}({chanId}): Sub-request {item.SubProtocol} {item.CompletionSource is not null} from {context.RemotePeer.Address.GetPeerId()}"); + logger?.LogDebug($"{context.Peer.Identity.PeerId}({chanId}): Sub-request {item.SelectedProtocol} {item.CompletionSource is not null} to call {connection.State.RemoteAddress.GetPeerId()}"); - chans[chanId] = new MuxerChannel { Tcs = item.CompletionSource }; - var response = new MuxerPacket() + chans[chanId] = new MuxerChannel { Tcs = item.CompletionSource, Argument = item.Argument }; + MuxerPacket response = new() { ChannelId = chanId, Type = MuxerPacketType.NewStreamRequest, - Protocols = { item.SubProtocol!.Id } + Protocols = { item.SelectedProtocol!.Id } }; logger?.LogDebug($"{logPrefix}({response.ChannelId}): > Packet {response.Type} {string.Join(",", response.Protocols)} {response.Data?.Length ?? 0}"); _ = downChannel.WriteSizeAndProtobufAsync(response); } - logger?.LogDebug($"{context.LocalPeer.Identity.PeerId}: SubDialRequests End"); - + logger?.LogDebug($"{context.Peer.Identity.PeerId}: SubDialRequests End"); + return Task.CompletedTask; }); while (true) @@ -87,7 +118,7 @@ private async Task HandleRemote(IChannel downChannel, IChannelFactory upChannelF { logger?.LogDebug($"{logPrefix}: < READY({(isListen ? "list" : "dial")})"); - var packet = await downChannel.ReadPrefixedProtobufAsync(MuxerPacket.Parser); + MuxerPacket packet = await downChannel.ReadPrefixedProtobufAsync(MuxerPacket.Parser); logger?.LogDebug($"{logPrefix}({packet.ChannelId}): < Packet {packet.Type} {string.Join(",", packet.Protocols)} {packet.Data?.Length ?? 0}"); @@ -95,15 +126,15 @@ private async Task HandleRemote(IChannel downChannel, IChannelFactory upChannelF { case MuxerPacketType.NewStreamRequest: IProtocol? selected = null; - foreach (var proto in packet.Protocols) + foreach (string? proto in packet.Protocols) { - selected = upChannelFactory.SubProtocols.FirstOrDefault(x => x.Id == proto); + selected = session.SubProtocols.FirstOrDefault(x => x.Id == proto); if (selected is not null) break; } if (selected is not null) { logger?.LogDebug($"{logPrefix}({packet.ChannelId}): Matched {selected}"); - var response = new MuxerPacket() + MuxerPacket response = new() { ChannelId = packet.ChannelId, Type = MuxerPacketType.NewStreamResponse, @@ -113,9 +144,9 @@ private async Task HandleRemote(IChannel downChannel, IChannelFactory upChannelF } }; - var req = new ChannelRequest { SubProtocol = selected }; + UpgradeOptions req = new() { SelectedProtocol = selected, ModeOverride = UpgradeModeOverride.Listen }; - IChannel upChannel = upChannelFactory.SubListen(context, req); + IChannel upChannel = session.Upgrade(selected, req); chans[packet.ChannelId] = new MuxerChannel { UpChannel = upChannel }; _ = HandleUpchannelData(downChannel, chans, packet.ChannelId, upChannel, logPrefix); @@ -127,7 +158,7 @@ private async Task HandleRemote(IChannel downChannel, IChannelFactory upChannelF { logger?.LogDebug($"{logPrefix}({packet.ChannelId}): No match {packet.Type} {string.Join(",", packet.Protocols)} {packet.Data?.Length ?? 0}"); - var response = new MuxerPacket() + MuxerPacket response = new() { ChannelId = packet.ChannelId, Type = MuxerPacketType.NewStreamResponse, @@ -141,10 +172,10 @@ private async Task HandleRemote(IChannel downChannel, IChannelFactory upChannelF case MuxerPacketType.NewStreamResponse: if (packet.Protocols.Any()) { - var req = new ChannelRequest { SubProtocol = upChannelFactory.SubProtocols.FirstOrDefault(x => x.Id == packet.Protocols.First()) }; - IChannel upChannel = upChannelFactory.SubDial(context, req); + UpgradeOptions req = new() { SelectedProtocol = session.SubProtocols.FirstOrDefault(x => x.Id == packet.Protocols.First()), CompletionSource = chans[packet.ChannelId].Tcs, Argument = chans[packet.ChannelId].Argument, ModeOverride = UpgradeModeOverride.Dial }; + IChannel upChannel = session.Upgrade(session.SubProtocols.FirstOrDefault(x => x.Id == packet.Protocols.First())!, req); chans[packet.ChannelId].UpChannel = upChannel; - logger?.LogDebug($"{logPrefix}({packet.ChannelId}): Start upchanel with {req.SubProtocol}"); + logger?.LogDebug($"{logPrefix}({packet.ChannelId}): Start upchanel with {req.SelectedProtocol}"); _ = HandleUpchannelData(downChannel, chans, packet.ChannelId, upChannel, logPrefix); } else @@ -153,23 +184,38 @@ private async Task HandleRemote(IChannel downChannel, IChannelFactory upChannelF } break; case MuxerPacketType.Data: - logger?.LogDebug($"{logPrefix}({packet.ChannelId}): Data to upchanel {packet.Data?.Length ?? 0} {Hex.ToHexString(packet.Data?.ToByteArray() ?? [])}"); - _ = chans[packet.ChannelId].UpChannel!.WriteAsync(new ReadOnlySequence(packet.Data.ToByteArray())); + if (packet.Data is null or []) + { + logger?.LogWarning($"{logPrefix}({packet.ChannelId}): Empty data received"); + break; + } + logger?.LogDebug($"{logPrefix}({packet.ChannelId}): Data to upchanel {packet.Data.Length} {Hex.ToHexString(packet.Data.ToByteArray())}"); + _ = chans.GetValueOrDefault(packet.ChannelId)?.UpChannel?.WriteAsync(new ReadOnlySequence(packet.Data.ToByteArray())); break; case MuxerPacketType.CloseWrite: logger?.LogDebug($"{logPrefix}({packet.ChannelId}): Remote EOF"); - chans[packet.ChannelId].RemoteClosedWrites = true; - _ = chans[packet.ChannelId].UpChannel!.WriteEofAsync(); + lock (chans[packet.ChannelId]) + { + chans[packet.ChannelId].RemoteClosedWrites = true; + + _ = chans[packet.ChannelId].UpChannel?.WriteEofAsync(); + + if (chans[packet.ChannelId].LocalClosedWrites) + { + //chans[packet.ChannelId].Tcs?.SetResult(null); + _ = chans[packet.ChannelId].UpChannel?.CloseAsync(); + chans.Remove(packet.ChannelId); + } + } break; default: break; } } - catch + catch (Exception e) { - - + logger?.LogError(e, $"{logPrefix}: Muxer listener exception"); } } @@ -181,12 +227,12 @@ private Task HandleUpchannelData(IChannel downChannel, Dictionary item in upChannel.ReadAllAsync()) { - var data = item.ToArray(); + byte[] data = item.ToArray(); logger?.LogDebug($"{logPrefix}({channelId}): Upchannel data {data.Length} {Hex.ToHexString(data, false)}"); - var packet = new MuxerPacket() + MuxerPacket packet = new() { ChannelId = channelId, Type = MuxerPacketType.Data, @@ -197,23 +243,31 @@ private Task HandleUpchannelData(IChannel downChannel, Dictionary Packet {packet.Type} {string.Join(",", packet.Protocols)} {packet.Data?.Length ?? 0}"); + logger?.LogDebug($"{logPrefix}({channelId}): Upchannel write close"); - _ = downChannel.WriteSizeAndProtobufAsync(packet); + { + MuxerPacket packet = new() + { + ChannelId = channelId, + Type = MuxerPacketType.CloseWrite, + }; + + logger?.LogDebug($"{logPrefix}({packet.ChannelId}): > Packet {packet.Type} {string.Join(",", packet.Protocols)} {packet.Data?.Length ?? 0}"); + + _ = downChannel.WriteSizeAndProtobufAsync(packet); + } } } catch @@ -226,7 +280,9 @@ private Task HandleUpchannelData(IChannel downChannel, Dictionary? Tcs { get; set; } public bool RemoteClosedWrites { get; set; } + public bool LocalClosedWrites { get; set; } + public object? Argument { get; internal set; } } } diff --git a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestMuxerTests.cs b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestMuxerTests.cs index b5593514..496fe013 100644 --- a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestMuxerTests.cs +++ b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestMuxerTests.cs @@ -11,20 +11,21 @@ internal class TestMuxerTests [Test] public async Task Test_ConnectionEstablished_AfterHandshake() { - ServiceProvider sp = new ServiceCollection() - .AddSingleton(sp => new TestBuilder(null, sp)) + ChannelBus channelBus = new(); + ServiceProvider MakeServiceProvider() => new ServiceCollection() + .AddSingleton(sp => new TestBuilder(sp)) + .AddSingleton() .AddSingleton() + .AddSingleton(channelBus) .AddSingleton(sp => sp.GetService()!.Build()) .BuildServiceProvider(); - IPeerFactory peerFactory = sp.GetService()!; + ILocalPeer peerA = MakeServiceProvider().GetRequiredService().Create(TestPeers.Identity(1)); + await peerA.StartListenAsync(); + ILocalPeer peerB = MakeServiceProvider().GetRequiredService().Create(TestPeers.Identity(2)); + await peerB.StartListenAsync(); - ILocalPeer peerA = peerFactory.Create(TestPeers.Identity(1)); - await peerA.ListenAsync(TestPeers.Multiaddr(1)); - ILocalPeer peerB = peerFactory.Create(TestPeers.Identity(2)); - await peerB.ListenAsync(TestPeers.Multiaddr(2)); - - IRemotePeer remotePeerB = await peerA.DialAsync(peerB.Address); + ISession remotePeerB = await peerA.DialAsync(TestPeers.Multiaddr(2)); await remotePeerB.DialAsync(); } } diff --git a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestPeerFactory.cs b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestPeerFactory.cs deleted file mode 100644 index dcfcaebb..00000000 --- a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestPeerFactory.cs +++ /dev/null @@ -1,18 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Multiformats.Address; -using System.Collections.Concurrent; - -namespace Nethermind.Libp2p.Core.TestsBase.E2e; - -internal class TestPeerFactory(IServiceProvider serviceProvider) : PeerFactory(serviceProvider) -{ - ConcurrentDictionary peers = new(); - - public override ILocalPeer Create(Identity? identity = null, Multiaddress? localAddr = null) - { - ArgumentNullException.ThrowIfNull(identity); - return peers.GetOrAdd(identity.PeerId, (p) => new TestLocalPeer(identity)); - } -} diff --git a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestPingProtocol.cs b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestPingProtocol.cs index 0a735853..6d4438d6 100644 --- a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestPingProtocol.cs +++ b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestPingProtocol.cs @@ -4,11 +4,11 @@ using NUnit.Framework; namespace Nethermind.Libp2p.Core.TestsBase.E2e; -class TestPingProtocol : IProtocol +class TestPingProtocol : ISessionProtocol { public string Id => "test-ping"; - public async Task DialAsync(IChannel downChannel, IChannelFactory? upChannelFactory, IPeerContext context) + public async Task DialAsync(IChannel downChannel, ISessionContext context) { string str = "hello"; await downChannel.WriteLineAsync(str); @@ -16,7 +16,7 @@ public async Task DialAsync(IChannel downChannel, IChannelFactory? upChannelFact Assert.That(res, Is.EqualTo(str + " there")); } - public async Task ListenAsync(IChannel downChannel, IChannelFactory? upChannelFactory, IPeerContext context) + public async Task ListenAsync(IChannel downChannel, ISessionContext context) { string str = await downChannel.ReadLineAsync(); await downChannel.WriteLineAsync(str + " there"); diff --git a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestSuite.cs b/src/libp2p/Libp2p.Core.TestsBase/E2e/TestSuite.cs deleted file mode 100644 index a355e05b..00000000 --- a/src/libp2p/Libp2p.Core.TestsBase/E2e/TestSuite.cs +++ /dev/null @@ -1,12 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -namespace Nethermind.Libp2p.Core.TestsBase.E2e; - -public class TestSuite -{ - public static IPeerFactory CreateLibp2p(params Type[] appProcols) - { - return new TestBuilder().Build(); - } -} diff --git a/src/libp2p/Libp2p.Core.TestsBase/Libp2p.Core.TestsBase.csproj b/src/libp2p/Libp2p.Core.TestsBase/Libp2p.Core.TestsBase.csproj index 01c7ec59..bec831a2 100644 --- a/src/libp2p/Libp2p.Core.TestsBase/Libp2p.Core.TestsBase.csproj +++ b/src/libp2p/Libp2p.Core.TestsBase/Libp2p.Core.TestsBase.csproj @@ -33,6 +33,7 @@ + diff --git a/src/libp2p/Libp2p.Core.TestsBase/LocalPeerStub.cs b/src/libp2p/Libp2p.Core.TestsBase/LocalPeerStub.cs index 86045f3c..78fbf036 100644 --- a/src/libp2p/Libp2p.Core.TestsBase/LocalPeerStub.cs +++ b/src/libp2p/Libp2p.Core.TestsBase/LocalPeerStub.cs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT using Multiformats.Address; +using System.Collections.ObjectModel; namespace Nethermind.Libp2p.Core.TestsBase; @@ -16,18 +17,37 @@ public LocalPeerStub() public Identity Identity { get; set; } public Multiaddress Address { get; set; } - public Task DialAsync(Multiaddress addr, CancellationToken token = default) + public ObservableCollection ListenAddresses => throw new NotImplementedException(); + + public event Connected? OnConnected; + + public Task DialAsync(Multiaddress addr, CancellationToken token = default) + { + return Task.FromResult(new TestRemotePeer(addr)); + } + + public Task DialAsync(Multiaddress[] samePeerAddrs, CancellationToken token = default) { - return Task.FromResult(new TestRemotePeer(addr)); + return Task.FromResult(new TestRemotePeer(samePeerAddrs.First())); } - public Task ListenAsync(Multiaddress addr, CancellationToken token = default) + public Task DialAsync(PeerId peerId, CancellationToken token = default) { - return Task.FromResult(null); + throw new NotImplementedException(); + } + + public Task DisconnectAsync() + { + return Task.CompletedTask; + } + + public Task StartListenAsync(Multiaddress[] addrs, CancellationToken token = default) + { + throw new NotImplementedException(); } } -public class TestRemotePeer : IRemotePeer +public class TestRemotePeer : ISession { public TestRemotePeer(Multiaddress addr) { @@ -38,11 +58,18 @@ public TestRemotePeer(Multiaddress addr) public Identity Identity { get; set; } public Multiaddress Address { get; set; } - public Task DialAsync(CancellationToken token = default) where TProtocol : IProtocol + public Multiaddress RemoteAddress => $"/p2p/{Identity.PeerId}"; + + public Task DialAsync(CancellationToken token = default) where TProtocol : ISessionProtocol { return Task.CompletedTask; } + public Task DialAsync(TRequest request, CancellationToken token = default) where TProtocol : ISessionProtocol + { + throw new NotImplementedException(); + } + public Task DisconnectAsync() { return Task.CompletedTask; diff --git a/src/libp2p/Libp2p.Core.TestsBase/TestContextLoggerFactory.cs b/src/libp2p/Libp2p.Core.TestsBase/TestContextLoggerFactory.cs index cd337265..9c1520bd 100644 --- a/src/libp2p/Libp2p.Core.TestsBase/TestContextLoggerFactory.cs +++ b/src/libp2p/Libp2p.Core.TestsBase/TestContextLoggerFactory.cs @@ -13,39 +13,32 @@ class TestContextLogger(string categoryName) : ILogger, IDisposable { private readonly string _categoryName = categoryName; - public IDisposable? BeginScope(TState state) where TState : notnull - { - return this; - } + public IDisposable? BeginScope(TState state) where TState : notnull => this; - public void Dispose() - { - } + public void Dispose() { } + public bool IsEnabled(LogLevel logLevel) => true; - public bool IsEnabled(LogLevel logLevel) + private static string ToString(LogLevel level) => level switch { - return true; - } + LogLevel.Trace => "TRAC", + LogLevel.Debug => "DEBG", + LogLevel.Information => "INFO", + LogLevel.Warning => "WARN", + LogLevel.Error => "EROR", + LogLevel.Critical => "CRIT", + LogLevel.None => "NONE", + _ => throw new NotImplementedException() + }; public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) { - TestContext.Out.WriteLine($"{logLevel} {_categoryName}:{eventId}: {(exception is null ? state?.ToString() : formatter(state, exception))}"); - Debug.WriteLine($"{logLevel} {_categoryName}:{eventId}: {(exception is null ? state?.ToString() : formatter(state, exception))}"); + string log = $"{ToString(logLevel)} {_categoryName}: {(exception is null ? state?.ToString() : formatter(state, exception))}"; + TestContext.Out.WriteLine(log); + Debug.WriteLine(log); } } - public void AddProvider(ILoggerProvider provider) - { - - } - - public ILogger CreateLogger(string categoryName) - { - return new TestContextLogger(categoryName); - } - - public void Dispose() - { - - } + public void AddProvider(ILoggerProvider provider) { } + public ILogger CreateLogger(string categoryName) => new TestContextLogger(categoryName); + public void Dispose() { } } diff --git a/src/libp2p/Libp2p.Core.TestsBase/TestDiscoveryProtocol.cs b/src/libp2p/Libp2p.Core.TestsBase/TestDiscoveryProtocol.cs index 3aae5d9c..476e496d 100644 --- a/src/libp2p/Libp2p.Core.TestsBase/TestDiscoveryProtocol.cs +++ b/src/libp2p/Libp2p.Core.TestsBase/TestDiscoveryProtocol.cs @@ -8,7 +8,7 @@ namespace Nethermind.Libp2p.Core.TestsBase; // public Func? OnAddPeer { get; set; } // public Func? OnRemovePeer { get; set; } -// public Task DiscoverAsync(Multiaddress localPeerAddr, CancellationToken token = default) +// public Task DiscoverAsync(IPeer peer, CancellationToken token = default) // { // TaskCompletionSource task = new(); // token.Register(task.SetResult); diff --git a/src/libp2p/Libp2p.Core/Channel.cs b/src/libp2p/Libp2p.Core/Channel.cs index 1e3f5707..e19596ef 100644 --- a/src/libp2p/Libp2p.Core/Channel.cs +++ b/src/libp2p/Libp2p.Core/Channel.cs @@ -11,7 +11,7 @@ namespace Nethermind.Libp2p.Core; -internal class Channel : IChannel +public class Channel : IChannel { private IChannel? _reversedChannel; private ReaderWriter _reader; diff --git a/src/libp2p/Libp2p.Core/ChannelFactory.cs b/src/libp2p/Libp2p.Core/ChannelFactory.cs deleted file mode 100644 index f63821c8..00000000 --- a/src/libp2p/Libp2p.Core/ChannelFactory.cs +++ /dev/null @@ -1,108 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Nethermind.Libp2p.Core.Extensions; - -namespace Nethermind.Libp2p.Core; - -public class ChannelFactory : IChannelFactory -{ - private readonly IServiceProvider _serviceProvider; - private readonly ILoggerFactory? _loggerFactory; - private IDictionary _factories; - private readonly ILogger? _logger; - - public ChannelFactory(IServiceProvider serviceProvider) - { - _serviceProvider = serviceProvider; - _loggerFactory = _serviceProvider.GetService(); - _logger = _loggerFactory?.CreateLogger(); - } - - public IEnumerable SubProtocols => _factories.Keys; - - public IChannel SubDial(IPeerContext context, IChannelRequest? req = null) - { - IProtocol? subProtocol = req?.SubProtocol ?? SubProtocols.FirstOrDefault(); - Channel channel = new(); - ChannelFactory? channelFactory = _factories[subProtocol] as ChannelFactory; - - - _ = subProtocol.DialAsync(channel.Reverse, channelFactory, context) - .ContinueWith(async task => - { - if (!task.IsCompletedSuccessfully) - { - _logger?.DialFailed(subProtocol.Id, task.Exception, task.Exception.GetErrorMessage()); - } - await channel.CloseAsync(); - - req?.CompletionSource?.SetResult(); - }); - - return channel; - } - - public IChannel SubListen(IPeerContext context, IChannelRequest? req = null) - { - IProtocol? subProtocol = req?.SubProtocol ?? SubProtocols.FirstOrDefault(); - Channel channel = new(); - ChannelFactory? channelFactory = _factories[subProtocol] as ChannelFactory; - - - _ = subProtocol.ListenAsync(channel.Reverse, channelFactory, context) - .ContinueWith(async task => - { - if (!task.IsCompletedSuccessfully) - { - _logger?.ListenFailed(subProtocol.Id, task.Exception, task.Exception.GetErrorMessage()); - } - await channel.CloseAsync(); - - req?.CompletionSource?.SetResult(); - }); - - return channel; - } - - public Task SubDialAndBind(IChannel parent, IPeerContext context, - IChannelRequest? req = null) - { - IProtocol? subProtocol = req?.SubProtocol ?? SubProtocols.FirstOrDefault(); - ChannelFactory? channelFactory = _factories[subProtocol] as ChannelFactory; - - return subProtocol.DialAsync(((Channel)parent), channelFactory, context) - .ContinueWith(async task => - { - if (!task.IsCompletedSuccessfully) - { - _logger?.DialAndBindFailed(subProtocol.Id, task.Exception, task.Exception.GetErrorMessage()); - } - await parent.CloseAsync(); - - req?.CompletionSource?.SetResult(); - }); - } - - public Task SubListenAndBind(IChannel parent, IPeerContext context, - IChannelRequest? req = null) - { - IProtocol? subProtocol = req?.SubProtocol ?? SubProtocols.FirstOrDefault(); - ChannelFactory? channelFactory = _factories[subProtocol] as ChannelFactory; - - return subProtocol.ListenAsync(((Channel)parent), channelFactory, context) - .ContinueWith(async task => - { - await parent.CloseAsync(); - req?.CompletionSource?.SetResult(); - }); - } - - public ChannelFactory Setup(IDictionary factories) - { - _factories = factories; - return this; - } -} diff --git a/src/libp2p/Libp2p.Core/ChannelRequest.cs b/src/libp2p/Libp2p.Core/ChannelRequest.cs deleted file mode 100644 index 3f909665..00000000 --- a/src/libp2p/Libp2p.Core/ChannelRequest.cs +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -namespace Nethermind.Libp2p.Core; - -public class ChannelRequest : IChannelRequest -{ - public IProtocol? SubProtocol { get; init; } - public TaskCompletionSource? CompletionSource { get; init; } - - public override string ToString() - { - return $"Request for {SubProtocol?.Id ?? "unknown protocol"}"; - } -} diff --git a/src/libp2p/Libp2p.Core/Context/ConnectionContext.cs b/src/libp2p/Libp2p.Core/Context/ConnectionContext.cs new file mode 100644 index 00000000..ae912a68 --- /dev/null +++ b/src/libp2p/Libp2p.Core/Context/ConnectionContext.cs @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +namespace Nethermind.Libp2p.Core.Context; + +public class ConnectionContext(LocalPeer localPeer, LocalPeer.Session session, ProtocolRef protocol, bool isListener, UpgradeOptions? upgradeOptions) : ContextBase(localPeer, session, protocol, isListener, upgradeOptions), IConnectionContext +{ + public UpgradeOptions? UpgradeOptions => upgradeOptions; + + public Task DisconnectAsync() + { + return session.DisconnectAsync(); + } +} diff --git a/src/libp2p/Libp2p.Core/Context/ContextBase.cs b/src/libp2p/Libp2p.Core/Context/ContextBase.cs new file mode 100644 index 00000000..3c70d6ab --- /dev/null +++ b/src/libp2p/Libp2p.Core/Context/ContextBase.cs @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Multiformats.Address; + +namespace Nethermind.Libp2p.Core.Context; + +public class ContextBase(LocalPeer localPeer, LocalPeer.Session session, ProtocolRef protocol, bool isListener, UpgradeOptions? upgradeOptions) : IChannelFactory +{ + protected bool isListener = isListener; + public ILocalPeer Peer => localPeer; + public State State => session.State; + + public IEnumerable SubProtocols => localPeer.GetProtocolsFor(protocol); + + public string Id { get; } = session.Id; + + protected LocalPeer localPeer = localPeer; + protected LocalPeer.Session session = session; + protected ProtocolRef protocol = protocol; + protected UpgradeOptions? upgradeOptions = upgradeOptions; + + public IChannel Upgrade(UpgradeOptions? upgradeOptions = null) + { + return localPeer.Upgrade(session, protocol, null, upgradeOptions ?? this.upgradeOptions, isListener); + } + + public IChannel Upgrade(IProtocol specificProtocol, UpgradeOptions? upgradeOptions = null) + { + return localPeer.Upgrade(session, protocol, specificProtocol, upgradeOptions ?? this.upgradeOptions, isListener); + } + + public Task Upgrade(IChannel parentChannel, UpgradeOptions? upgradeOptions = null) + { + return localPeer.Upgrade(session, parentChannel, protocol, null, upgradeOptions ?? this.upgradeOptions, isListener); + } + + public Task Upgrade(IChannel parentChannel, IProtocol specificProtocol, UpgradeOptions? upgradeOptions = null) + { + return localPeer.Upgrade(session, parentChannel, protocol, specificProtocol, upgradeOptions ?? this.upgradeOptions, isListener); + } + + public INewConnectionContext CreateConnection() + { + return localPeer.CreateConnection(protocol, null, isListener); + } + + public INewSessionContext UpgradeToSession() + { + return localPeer.UpgradeToSession(session, protocol, isListener); + } + + public void ListenerReady(Multiaddress addr) + { + localPeer.ListenerReady(this, addr); + } +} diff --git a/src/libp2p/Libp2p.Core/Context/NewConnectionContext.cs b/src/libp2p/Libp2p.Core/Context/NewConnectionContext.cs new file mode 100644 index 00000000..e6744d48 --- /dev/null +++ b/src/libp2p/Libp2p.Core/Context/NewConnectionContext.cs @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +namespace Nethermind.Libp2p.Core.Context; + +public class NewConnectionContext(LocalPeer localPeer, LocalPeer.Session session, ProtocolRef protocol, bool isListener, UpgradeOptions? upgradeOptions) : ContextBase(localPeer, session, protocol, isListener, upgradeOptions), INewConnectionContext +{ + public CancellationToken Token => session.ConnectionToken; + + public void Dispose() + { + + } +} diff --git a/src/libp2p/Libp2p.Core/Context/NewSessionContext.cs b/src/libp2p/Libp2p.Core/Context/NewSessionContext.cs new file mode 100644 index 00000000..21a6bbcf --- /dev/null +++ b/src/libp2p/Libp2p.Core/Context/NewSessionContext.cs @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +namespace Nethermind.Libp2p.Core.Context; + +public class NewSessionContext(LocalPeer localPeer, LocalPeer.Session session, ProtocolRef protocol, bool isListener, UpgradeOptions? upgradeOptions) : ContextBase(localPeer, session, protocol, isListener, upgradeOptions), INewSessionContext +{ + public IEnumerable DialRequests => session.GetRequestQueue(); + + public CancellationToken Token => session.ConnectionToken; + + public void Dispose() + { + + } +} diff --git a/src/libp2p/Libp2p.Core/Context/SessionContext.cs b/src/libp2p/Libp2p.Core/Context/SessionContext.cs new file mode 100644 index 00000000..73cdff53 --- /dev/null +++ b/src/libp2p/Libp2p.Core/Context/SessionContext.cs @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +namespace Nethermind.Libp2p.Core.Context; + +public class SessionContext(LocalPeer localPeer, LocalPeer.Session session, ProtocolRef protocol, bool isListener, UpgradeOptions? upgradeOptions) : ContextBase(localPeer, session, protocol, isListener, upgradeOptions), ISessionContext +{ + public UpgradeOptions? UpgradeOptions => upgradeOptions; + + public async Task DialAsync() where TProtocol : ISessionProtocol + { + await session.DialAsync(); + } + + public async Task DialAsync(ISessionProtocol protocol) + { + await session.DialAsync(protocol); + } + + public Task DisconnectAsync() + { + return session.DisconnectAsync(); + } +} diff --git a/src/libp2p/Libp2p.Core/Discovery/IDiscoveryProtocol.cs b/src/libp2p/Libp2p.Core/Discovery/IDiscoveryProtocol.cs index abd933ec..fd4029a5 100644 --- a/src/libp2p/Libp2p.Core/Discovery/IDiscoveryProtocol.cs +++ b/src/libp2p/Libp2p.Core/Discovery/IDiscoveryProtocol.cs @@ -7,5 +7,5 @@ namespace Nethermind.Libp2p.Core.Discovery; public interface IDiscoveryProtocol { - Task DiscoverAsync(Multiaddress localPeerAddr, CancellationToken token = default); + Task StartDiscoveryAsync(IReadOnlyList localPeerAddr, CancellationToken token = default); } diff --git a/src/libp2p/Libp2p.Core/Discovery/PeerStore.cs b/src/libp2p/Libp2p.Core/Discovery/PeerStore.cs index 7008217c..c026dc5c 100644 --- a/src/libp2p/Libp2p.Core/Discovery/PeerStore.cs +++ b/src/libp2p/Libp2p.Core/Discovery/PeerStore.cs @@ -4,13 +4,38 @@ using Google.Protobuf; using Multiformats.Address; using Nethermind.Libp2p.Core.Dto; +using Nethermind.Libp2p.Core.Extensions; using System.Collections.Concurrent; namespace Nethermind.Libp2p.Core.Discovery; public class PeerStore { - ConcurrentDictionary store = []; + private readonly ConcurrentDictionary _store = []; + + public void Discover(ByteString signedPeerRecord) + { + SignedEnvelope signedEnvelope = SignedEnvelope.Parser.ParseFrom(signedPeerRecord); + PublicKey publicKey = PublicKey.Parser.ParseFrom(signedEnvelope.PublicKey); + PeerId peerId = new Identity(publicKey).PeerId; + + if (!SigningHelper.VerifyPeerRecord(signedEnvelope, publicKey)) + { + return; + } + + Multiaddress[] addresses = PeerRecord.Parser.ParseFrom(signedEnvelope.Payload).Addresses + .Select(ai => Multiaddress.Decode(ai.Multiaddr.ToByteArray())) + .Where(a => a.GetPeerId() == peerId) + .ToArray(); + + if (addresses.Length == 0) + { + return; + } + + Discover(addresses); + } public void Discover(Multiaddress[] addrs) { @@ -24,8 +49,8 @@ public void Discover(Multiaddress[] addrs) if (peerId is not null) { PeerInfo? newOne = null; - PeerInfo peerInfo = store.GetOrAdd(peerId, (id) => newOne = new PeerInfo { Addrs = [.. addrs] }); - if (peerInfo != newOne && peerInfo.Addrs is not null && peerInfo.Addrs.Count == addrs.Length && addrs.All(peerInfo.Addrs.Contains)) + PeerInfo peerInfo = _store.GetOrAdd(peerId, (id) => newOne = new PeerInfo { Addrs = [.. addrs] }); + if (peerInfo != newOne && peerInfo.Addrs is not null && addrs.UnorderedSequenceEqual(peerInfo.Addrs)) { return; } @@ -45,7 +70,7 @@ public event Action? OnNewPeer } onNewPeer += value; - foreach (var item in store.Select(x => x.Value).ToArray()) + foreach (PeerInfo? item in _store.Select(x => x.Value).ToArray()) { if (item.Addrs is not null) value.Invoke(item.Addrs.ToArray()); } @@ -56,19 +81,17 @@ public event Action? OnNewPeer } } - public override string ToString() - { - return $"peerStore({store.Count}):{string.Join(",", store.Select(x => x.Key.ToString() ?? "null"))}"; - } + public override string ToString() => $"peerStore({_store.Count}):{string.Join(",", _store.Select(x => x.Key.ToString() ?? "null"))}"; public PeerInfo GetPeerInfo(PeerId peerId) { - return store.GetOrAdd(peerId, id => new PeerInfo()); + return _store.GetOrAdd(peerId, id => new PeerInfo()); } public class PeerInfo { public ByteString? SignedPeerRecord { get; set; } + public string[]? SupportedProtocols { get; set; } public HashSet? Addrs { get; set; } } } diff --git a/src/libp2p/Libp2p.Protocols.Identify/Dto/PeerRecord.cs b/src/libp2p/Libp2p.Core/Dto/PeerRecord.cs similarity index 94% rename from src/libp2p/Libp2p.Protocols.Identify/Dto/PeerRecord.cs rename to src/libp2p/Libp2p.Core/Dto/PeerRecord.cs index 2e2eef5f..b7099a92 100644 --- a/src/libp2p/Libp2p.Protocols.Identify/Dto/PeerRecord.cs +++ b/src/libp2p/Libp2p.Core/Dto/PeerRecord.cs @@ -9,7 +9,7 @@ using pbc = global::Google.Protobuf.Collections; using pbr = global::Google.Protobuf.Reflection; using scg = global::System.Collections.Generic; -namespace Nethermind.Libp2p.Protocols.Identify.Dto { +namespace Nethermind.Libp2p.Core.Dto { /// Holder for reflection information generated from PeerRecord.proto public static partial class PeerRecordReflection { @@ -26,13 +26,13 @@ static PeerRecordReflection() { string.Concat( "ChBQZWVyUmVjb3JkLnByb3RvIiAKC0FkZHJlc3NJbmZvEhEKCW11bHRpYWRk", "chgBIAIoDCJLCgpQZWVyUmVjb3JkEg8KB3BlZXJfaWQYASACKAwSCwoDc2Vx", - "GAIgAigEEh8KCWFkZHJlc3NlcxgDIAMoCzIMLkFkZHJlc3NJbmZvQiuqAihO", - "ZXRoZXJtaW5kLkxpYnAycC5Qcm90b2NvbHMuSWRlbnRpZnkuRHRv")); + "GAIgAigEEh8KCWFkZHJlc3NlcxgDIAMoCzIMLkFkZHJlc3NJbmZvQh2qAhpO", + "ZXRoZXJtaW5kLkxpYnAycC5Db3JlLkR0bw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Nethermind.Libp2p.Protocols.Identify.Dto.AddressInfo), global::Nethermind.Libp2p.Protocols.Identify.Dto.AddressInfo.Parser, new[]{ "Multiaddr" }, null, null, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::Nethermind.Libp2p.Protocols.Identify.Dto.PeerRecord), global::Nethermind.Libp2p.Protocols.Identify.Dto.PeerRecord.Parser, new[]{ "PeerId", "Seq", "Addresses" }, null, null, null, null) + new pbr::GeneratedClrTypeInfo(typeof(global::Nethermind.Libp2p.Core.Dto.AddressInfo), global::Nethermind.Libp2p.Core.Dto.AddressInfo.Parser, new[]{ "Multiaddr" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Nethermind.Libp2p.Core.Dto.PeerRecord), global::Nethermind.Libp2p.Core.Dto.PeerRecord.Parser, new[]{ "PeerId", "Seq", "Addresses" }, null, null, null, null) })); } #endregion @@ -54,7 +54,7 @@ public sealed partial class AddressInfo : pb::IMessage [global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { - get { return global::Nethermind.Libp2p.Protocols.Identify.Dto.PeerRecordReflection.Descriptor.MessageTypes[0]; } + get { return global::Nethermind.Libp2p.Core.Dto.PeerRecordReflection.Descriptor.MessageTypes[0]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -267,7 +267,7 @@ public sealed partial class PeerRecord : pb::IMessage [global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { - get { return global::Nethermind.Libp2p.Protocols.Identify.Dto.PeerRecordReflection.Descriptor.MessageTypes[1]; } + get { return global::Nethermind.Libp2p.Core.Dto.PeerRecordReflection.Descriptor.MessageTypes[1]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -361,15 +361,15 @@ public void ClearSeq() { /// Field number for the "addresses" field. public const int AddressesFieldNumber = 3; - private static readonly pb::FieldCodec _repeated_addresses_codec - = pb::FieldCodec.ForMessage(26, global::Nethermind.Libp2p.Protocols.Identify.Dto.AddressInfo.Parser); - private readonly pbc::RepeatedField addresses_ = new pbc::RepeatedField(); + private static readonly pb::FieldCodec _repeated_addresses_codec + = pb::FieldCodec.ForMessage(26, global::Nethermind.Libp2p.Core.Dto.AddressInfo.Parser); + private readonly pbc::RepeatedField addresses_ = new pbc::RepeatedField(); /// /// addresses is a list of public listen addresses for the peer. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] - public pbc::RepeatedField Addresses { + public pbc::RepeatedField Addresses { get { return addresses_; } } diff --git a/src/libp2p/Libp2p.Protocols.Identify/Dto/PeerRecord.proto b/src/libp2p/Libp2p.Core/Dto/PeerRecord.proto similarity index 89% rename from src/libp2p/Libp2p.Protocols.Identify/Dto/PeerRecord.proto rename to src/libp2p/Libp2p.Core/Dto/PeerRecord.proto index 186cd3b2..db1628e0 100644 --- a/src/libp2p/Libp2p.Protocols.Identify/Dto/PeerRecord.proto +++ b/src/libp2p/Libp2p.Core/Dto/PeerRecord.proto @@ -1,6 +1,6 @@ syntax = "proto2"; -option csharp_namespace = "Nethermind.Libp2p.Protocols.Identify.Dto"; +option csharp_namespace = "Nethermind.Libp2p.Core.Dto"; message AddressInfo { required bytes multiaddr = 1; diff --git a/src/libp2p/Libp2p.Core/Dto/SigningHelper.cs b/src/libp2p/Libp2p.Core/Dto/SigningHelper.cs new file mode 100644 index 00000000..e11c6b0b --- /dev/null +++ b/src/libp2p/Libp2p.Core/Dto/SigningHelper.cs @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Google.Protobuf; +using Multiformats.Address; + +namespace Nethermind.Libp2p.Core.Dto; + +public static class SigningHelper +{ + private static readonly byte[] PayloadType = [((ushort)Enums.Libp2p.Libp2pPeerRecord >> 8) & 0xFF, (ushort)Enums.Libp2p.Libp2pPeerRecord & 0xFF]; + private static readonly byte[] Domain = "libp2p-peer-record"u8.ToArray().ToArray(); + public static bool VerifyPeerRecord(ByteString signedEnvelopeBytes, PublicKey publicKey) + { + SignedEnvelope signedEnvelope = SignedEnvelope.Parser.ParseFrom(signedEnvelopeBytes); + return VerifyPeerRecord(signedEnvelope, publicKey); + } + + public static bool VerifyPeerRecord(SignedEnvelope signedEnvelope, PublicKey publicKey) + { + Identity identity = new(publicKey); + + if (signedEnvelope.PayloadType?.Take(2).SequenceEqual(PayloadType) is not true) + { + return false; + } + + PeerRecord pr = PeerRecord.Parser.ParseFrom(signedEnvelope.Payload); + + if (identity.PeerId != new PeerId(pr.PeerId.ToByteArray())) + { + return false; + } + + byte[] signedData = new byte[ + VarInt.GetSizeInBytes(Domain.Length) + Domain.Length + + VarInt.GetSizeInBytes(PayloadType.Length) + PayloadType.Length + + VarInt.GetSizeInBytes(signedEnvelope.Payload.Length) + signedEnvelope.Payload.Length]; + + int offset = 0; + + VarInt.Encode(Domain.Length, signedData.AsSpan(), ref offset); + Array.Copy(Domain, 0, signedData, offset, Domain.Length); + offset += Domain.Length; + + VarInt.Encode(PayloadType.Length, signedData.AsSpan(), ref offset); + Array.Copy(PayloadType, 0, signedData, offset, PayloadType.Length); + offset += PayloadType.Length; + + VarInt.Encode(signedEnvelope.Payload.Length, signedData.AsSpan(), ref offset); + Array.Copy(signedEnvelope.Payload.ToByteArray(), 0, signedData, offset, signedEnvelope.Payload.Length); + + return identity.VerifySignature(signedData, signedEnvelope.Signature.ToByteArray()); + } + + public static ByteString CreateSignedEnvelope(Identity identity, Multiaddress[] addresses, ulong seq) + { + PeerRecord payload = new() + { + PeerId = ByteString.CopyFrom(identity.PeerId.Bytes), + Seq = seq + }; + + foreach (Multiaddress address in addresses) + { + payload.Addresses.Add(new AddressInfo + { + Multiaddr = ByteString.CopyFrom(address.ToBytes()) + }); + } + + SignedEnvelope envelope = new() + { + PayloadType = ByteString.CopyFrom(PayloadType), + Payload = payload.ToByteString(), + PublicKey = identity.PublicKey.ToByteString(), + }; + + int payloadLength = payload.CalculateSize(); + + byte[] signingData = new byte[ + VarInt.GetSizeInBytes(Domain.Length) + Domain.Length + + VarInt.GetSizeInBytes(PayloadType.Length) + PayloadType.Length + + VarInt.GetSizeInBytes(payloadLength) + payloadLength]; + + int offset = 0; + + VarInt.Encode(Domain.Length, signingData.AsSpan(), ref offset); + Array.Copy(Domain, 0, signingData, offset, Domain.Length); + offset += Domain.Length; + + VarInt.Encode(PayloadType.Length, signingData.AsSpan(), ref offset); + Array.Copy(PayloadType, 0, signingData, offset, PayloadType.Length); + offset += PayloadType.Length; + + VarInt.Encode(payloadLength, signingData.AsSpan(), ref offset); + Array.Copy(payload.ToByteArray(), 0, signingData, offset, payloadLength); + + envelope.Signature = ByteString.CopyFrom(identity.Sign(signingData).ToArray()); + + return envelope.ToByteString(); + } +} diff --git a/src/libp2p/Libp2p.Core/Exceptions/Libp2pException.cs b/src/libp2p/Libp2p.Core/Exceptions/Libp2pException.cs index d36b4528..afe39951 100644 --- a/src/libp2p/Libp2p.Core/Exceptions/Libp2pException.cs +++ b/src/libp2p/Libp2p.Core/Exceptions/Libp2pException.cs @@ -5,17 +5,33 @@ namespace Nethermind.Libp2p.Core.Exceptions; public class Libp2pException : Exception { - public Libp2pException(string? message) : base(message) - { + public Libp2pException(string? message) : base(message) { } + public Libp2pException() : base() { } +} - } - public Libp2pException() : base() - { +/// +/// Exception instead of IOResult to signal a channel cannot send or receive data anymore +/// +public class ChannelClosedException() : Libp2pException("Channel closed"); - } -} +/// +/// Appears when libp2p is not set up properly in part of protocol tack, IoC, etc. +/// +/// +public class Libp2pSetupException(string? message = null) : Libp2pException(message); + +/// +/// Appears when there is already active session for the given peer +/// +public class SessionExistsException(PeerId remotePeerId) : Libp2pException($"Session is already established with {remotePeerId}"); + + +/// +/// Appears if connection to peer failed or declined +/// +public class PeerConnectionException(string? message = null) : Libp2pException(message); -public class ChannelClosedException : Libp2pException +public class DLibp2pException : Libp2pException { } diff --git a/src/libp2p/Libp2p.Core/Extensions/ChannelFactoryExtensions.cs b/src/libp2p/Libp2p.Core/Extensions/ChannelFactoryExtensions.cs deleted file mode 100644 index 3c5c70ed..00000000 --- a/src/libp2p/Libp2p.Core/Extensions/ChannelFactoryExtensions.cs +++ /dev/null @@ -1,10 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -namespace Nethermind.Libp2p.Core.Extensions; - -internal static class ChannelFactoryExtensions -{ - public static IEnumerable GetSubProtocols(this ChannelFactory? channelFactory) - => channelFactory?.SubProtocols.Select(protocol => protocol.Id) ?? Enumerable.Empty(); -} diff --git a/src/libp2p/Libp2p.Core/Extensions/EnumerableExtensions.cs b/src/libp2p/Libp2p.Core/Extensions/EnumerableExtensions.cs new file mode 100644 index 00000000..a34ab6ba --- /dev/null +++ b/src/libp2p/Libp2p.Core/Extensions/EnumerableExtensions.cs @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +namespace Nethermind.Libp2p.Core.Extensions; + +public static class EnumerableExtensions +{ + public static bool UnorderedSequenceEqual(this IEnumerable left, IEnumerable right) => left.OrderBy(x => x).SequenceEqual(right.OrderBy(x => x)); +} diff --git a/src/libp2p/Libp2p.Core/Extensions/TaskHelper.cs b/src/libp2p/Libp2p.Core/Extensions/TaskHelper.cs new file mode 100644 index 00000000..249a68f9 --- /dev/null +++ b/src/libp2p/Libp2p.Core/Extensions/TaskHelper.cs @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Nethermind.Libp2p.Core.Exceptions; + +namespace Nethermind.Libp2p.Core.Extensions; + +internal static class TaskHelper +{ + public static async Task FirstSuccess(params Task[] tasks) + { + TaskCompletionSource tcs = new(); + + Task all = Task.WhenAll(tasks.Select(t => t.ContinueWith(t => + { + if (t.IsCompletedSuccessfully) + { + tcs.TrySetResult(t); + } + if (t.IsFaulted && t.Exception.InnerException is SessionExistsException) + { + tcs.TrySetResult(t); + } + }))); + + Task result = await Task.WhenAny(tcs.Task, all); + if (result == all) + { + throw new AggregateException(tasks.Select(t => t.Exception?.InnerException).Where(ex => ex is not null)!); + } + return tcs.Task.Result; + } +} diff --git a/src/libp2p/Libp2p.Core/IChannel.cs b/src/libp2p/Libp2p.Core/IChannel.cs index ce24c06d..c933bb2d 100644 --- a/src/libp2p/Libp2p.Core/IChannel.cs +++ b/src/libp2p/Libp2p.Core/IChannel.cs @@ -14,9 +14,9 @@ CancellationToken CancellationToken { get { - var token = new CancellationTokenSource(); - GetAwaiter().OnCompleted(token.Cancel); - return token.Token; + CancellationTokenSource cts = new(); + GetAwaiter().OnCompleted(cts.Cancel); + return cts.Token; } } } diff --git a/src/libp2p/Libp2p.Core/IChannelFactory.cs b/src/libp2p/Libp2p.Core/IChannelFactory.cs index d8ff230e..698cab80 100644 --- a/src/libp2p/Libp2p.Core/IChannelFactory.cs +++ b/src/libp2p/Libp2p.Core/IChannelFactory.cs @@ -6,33 +6,25 @@ namespace Nethermind.Libp2p.Core; public interface IChannelFactory { IEnumerable SubProtocols { get; } - IChannel SubDial(IPeerContext context, IChannelRequest? request = null); - IChannel SubListen(IPeerContext context, IChannelRequest? request = null); + IChannel Upgrade(UpgradeOptions? options = null); + IChannel Upgrade(IProtocol specificProtocol, UpgradeOptions? options = null); - Task SubDialAndBind(IChannel parentChannel, IPeerContext context, IChannelRequest? request = null); - - Task SubListenAndBind(IChannel parentChannel, IPeerContext context, IChannelRequest? request = null); - - - - IChannel SubDial(IPeerContext context, IProtocol protocol) - { - return SubDial(context, new ChannelRequest { SubProtocol = protocol }); - } - - IChannel SubListen(IPeerContext context, IProtocol protocol) - { - return SubListen(context, new ChannelRequest { SubProtocol = protocol }); - } + Task Upgrade(IChannel parentChannel, UpgradeOptions? options = null); + Task Upgrade(IChannel parentChannel, IProtocol specificProtocol, UpgradeOptions? options = null); +} - Task SubDialAndBind(IChannel parentChannel, IPeerContext context, IProtocol protocol) - { - return SubDialAndBind(parentChannel, context, new ChannelRequest { SubProtocol = protocol }); - } +public record UpgradeOptions +{ + public IProtocol? SelectedProtocol { get; init; } + public UpgradeModeOverride ModeOverride { get; init; } + public TaskCompletionSource? CompletionSource { get; init; } + public object? Argument { get; set; } +} - Task SubListenAndBind(IChannel parentChannel, IPeerContext context, IProtocol protocol) - { - return SubListenAndBind(parentChannel, context, new ChannelRequest { SubProtocol = protocol }); - } +public enum UpgradeModeOverride +{ + None, + Dial, + Listen, } diff --git a/src/libp2p/Libp2p.Core/IChannelRequest.cs b/src/libp2p/Libp2p.Core/IChannelRequest.cs deleted file mode 100644 index 999b486e..00000000 --- a/src/libp2p/Libp2p.Core/IChannelRequest.cs +++ /dev/null @@ -1,10 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -namespace Nethermind.Libp2p.Core; - -public interface IChannelRequest -{ - IProtocol? SubProtocol { get; } - public TaskCompletionSource? CompletionSource { get; } -} diff --git a/src/libp2p/Libp2p.Core/ILibp2pBuilderContext.cs b/src/libp2p/Libp2p.Core/ILibp2pBuilderContext.cs new file mode 100644 index 00000000..41799cca --- /dev/null +++ b/src/libp2p/Libp2p.Core/ILibp2pBuilderContext.cs @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +namespace Nethermind.Libp2p.Core; + +public interface IProtocolStackSettings +{ + Dictionary? Protocols { get; set; } + ProtocolRef[]? TopProtocols { get; set; } +} diff --git a/src/libp2p/Libp2p.Core/ILibp2pPeerFactoryBuilder.cs b/src/libp2p/Libp2p.Core/ILibp2pPeerFactoryBuilder.cs index 353b8ad0..8c57cf37 100644 --- a/src/libp2p/Libp2p.Core/ILibp2pPeerFactoryBuilder.cs +++ b/src/libp2p/Libp2p.Core/ILibp2pPeerFactoryBuilder.cs @@ -6,4 +6,6 @@ namespace Nethermind.Libp2p.Core; public interface ILibp2pPeerFactoryBuilder : IPeerFactoryBuilder { public ILibp2pPeerFactoryBuilder WithPlaintextEnforced(); + public ILibp2pPeerFactoryBuilder WithPubsub(); + public ILibp2pPeerFactoryBuilder WithRelay(); } diff --git a/src/libp2p/Libp2p.Core/IListener.cs b/src/libp2p/Libp2p.Core/IListener.cs deleted file mode 100644 index 1e19c470..00000000 --- a/src/libp2p/Libp2p.Core/IListener.cs +++ /dev/null @@ -1,18 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Multiformats.Address; -using System.Runtime.CompilerServices; - -namespace Nethermind.Libp2p.Core; - -public interface IListener -{ - Multiaddress Address { get; } - - event OnConnection OnConnection; - Task DisconnectAsync(); - TaskAwaiter GetAwaiter(); -} - -public delegate Task OnConnection(IRemotePeer peer); diff --git a/src/libp2p/Libp2p.Core/ILocalPeer.cs b/src/libp2p/Libp2p.Core/ILocalPeer.cs index 64c7e1eb..0364a98b 100644 --- a/src/libp2p/Libp2p.Core/ILocalPeer.cs +++ b/src/libp2p/Libp2p.Core/ILocalPeer.cs @@ -2,11 +2,29 @@ // SPDX-License-Identifier: MIT using Multiformats.Address; +using System.Collections.ObjectModel; namespace Nethermind.Libp2p.Core; -public interface ILocalPeer : IPeer +public interface ILocalPeer { - Task DialAsync(Multiaddress addr, CancellationToken token = default); - Task ListenAsync(Multiaddress addr, CancellationToken token = default); + Identity Identity { get; } + + Task DialAsync(Multiaddress addr, CancellationToken token = default); + Task DialAsync(Multiaddress[] samePeerAddrs, CancellationToken token = default); + + /// + /// Find existing session or dial a peer if found in peer store + /// + Task DialAsync(PeerId peerId, CancellationToken token = default); + + Task StartListenAsync(Multiaddress[]? addrs = default, CancellationToken token = default); + + Task DisconnectAsync(); + + ObservableCollection ListenAddresses { get; } + + event Connected? OnConnected; } + +public delegate Task Connected(ISession newSession); diff --git a/src/libp2p/Libp2p.Core/IPeer.cs b/src/libp2p/Libp2p.Core/IPeer.cs deleted file mode 100644 index 148e9885..00000000 --- a/src/libp2p/Libp2p.Core/IPeer.cs +++ /dev/null @@ -1,12 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Multiformats.Address; - -namespace Nethermind.Libp2p.Core; - -public interface IPeer -{ - Identity Identity { get; set; } - Multiaddress Address { get; set; } -} diff --git a/src/libp2p/Libp2p.Core/IPeerContext.cs b/src/libp2p/Libp2p.Core/IPeerContext.cs index e9fe0425..be01190e 100644 --- a/src/libp2p/Libp2p.Core/IPeerContext.cs +++ b/src/libp2p/Libp2p.Core/IPeerContext.cs @@ -2,34 +2,53 @@ // SPDX-License-Identifier: MIT using Multiformats.Address; -using System.Collections.Concurrent; +using Nethermind.Libp2p.Core.Dto; namespace Nethermind.Libp2p.Core; -public interface IPeerContext +public interface ITransportContext { - string Id { get; } - IPeer LocalPeer { get; } - IPeer RemotePeer { get; } + ILocalPeer Peer { get; } + void ListenerReady(Multiaddress addr); + INewConnectionContext CreateConnection(); +} - Multiaddress RemoteEndpoint { get; set; } - Multiaddress LocalEndpoint { get; set; } +public interface IContextState +{ + string Id { get; } + State State { get; } +} - // TODO: Get rid of this: - IPeerContext Fork(); +public interface IConnectionContext : ITransportContext, IChannelFactory, IContextState +{ + UpgradeOptions? UpgradeOptions { get; } + Task DisconnectAsync(); + INewSessionContext UpgradeToSession(); +} - #region Allows muxer to manage session and channels for the app protocols - BlockingCollection SubDialRequests { get; } +public interface ISessionContext : IConnectionContext +{ + Task DialAsync() where TProtocol : ISessionProtocol; + Task DialAsync(ISessionProtocol protocol); +} - IChannelRequest? SpecificProtocolRequest { get; set; } - event RemotePeerConnected OnRemotePeerConnection; - event ListenerReady OnListenerReady; +public interface INewConnectionContext : IDisposable, IChannelFactory, IContextState +{ + ILocalPeer Peer { get; } + CancellationToken Token { get; } + INewSessionContext UpgradeToSession(); +} - void Connected(IPeer peer); - void ListenerReady(); - #endregion +public interface INewSessionContext : IDisposable, INewConnectionContext +{ + IEnumerable DialRequests { get; } } -public delegate void RemotePeerConnected(IRemotePeer peer); -public delegate void ListenerReady(); +public class State +{ + public Multiaddress? LocalAddress { get; set; } + public Multiaddress? RemoteAddress { get; set; } + public PublicKey? RemotePublicKey { get; set; } + public PeerId? RemotePeerId => RemoteAddress?.GetPeerId(); +} diff --git a/src/libp2p/Libp2p.Core/IPeerFactory.cs b/src/libp2p/Libp2p.Core/IPeerFactory.cs index c48bbba1..7531410f 100644 --- a/src/libp2p/Libp2p.Core/IPeerFactory.cs +++ b/src/libp2p/Libp2p.Core/IPeerFactory.cs @@ -1,11 +1,9 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT -using Multiformats.Address; - namespace Nethermind.Libp2p.Core; public interface IPeerFactory { - ILocalPeer Create(Identity? identity = default, Multiaddress? localAddr = default); + ILocalPeer Create(Identity? identity = default); } diff --git a/src/libp2p/Libp2p.Core/IPeerFactoryBuilder.cs b/src/libp2p/Libp2p.Core/IPeerFactoryBuilder.cs index e04af45b..2be44d06 100644 --- a/src/libp2p/Libp2p.Core/IPeerFactoryBuilder.cs +++ b/src/libp2p/Libp2p.Core/IPeerFactoryBuilder.cs @@ -5,7 +5,7 @@ namespace Nethermind.Libp2p.Core; public interface IPeerFactoryBuilder { - IPeerFactoryBuilder AddAppLayerProtocol(TProtocol? instance = default) where TProtocol : IProtocol; + IPeerFactoryBuilder AddAppLayerProtocol(TProtocol? instance = default, bool isExposed = true) where TProtocol : IProtocol; IPeerFactory Build(); IEnumerable AppLayerProtocols { get; } } diff --git a/src/libp2p/Libp2p.Core/IProtocol.cs b/src/libp2p/Libp2p.Core/IProtocol.cs index a8b207c9..4a4890a8 100644 --- a/src/libp2p/Libp2p.Core/IProtocol.cs +++ b/src/libp2p/Libp2p.Core/IProtocol.cs @@ -1,31 +1,48 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT +using Multiformats.Address; + namespace Nethermind.Libp2p.Core; -// TODO: Try the synchronous approach public interface IProtocol { - /// - /// Id used to during connection establishedment, exchanging information about protocol versions and so on - /// string Id { get; } +} + +public interface ITransportProtocol : IProtocol +{ + static bool IsAddressMatch(IProtocol proto, Multiaddress addr) => (bool)proto.GetType() + .GetMethod(nameof(IsAddressMatch))!.Invoke(null, [addr])!; + static Multiaddress[] GetDefaultAddresses(IProtocol proto, PeerId peerId) => (Multiaddress[])proto.GetType() + .GetMethod(nameof(GetDefaultAddresses))!.Invoke(null, [peerId])!; + - /// - /// Actively dials a peer - /// - /// A channel to communicate with a bottom layer protocol - /// Factory that spawns new channels used to interact with top layer protocols - /// Holds information about local and remote peers - /// - Task DialAsync(IChannel downChannel, IChannelFactory? upChannelFactory, IPeerContext context); - - /// - /// Opens a channel to listen to a remote peer - /// - /// A channel to communicate with a bottom layer protocol - /// Factory that spawns new channels used to interact with top layer protocols - /// Holds information about local and remote peers - /// - Task ListenAsync(IChannel downChannel, IChannelFactory? upChannelFactory, IPeerContext context); + static abstract Multiaddress[] GetDefaultAddresses(PeerId peerId); + static abstract bool IsAddressMatch(Multiaddress addr); + + Task ListenAsync(ITransportContext context, Multiaddress listenAddr, CancellationToken token); + Task DialAsync(ITransportContext context, Multiaddress remoteAddr, CancellationToken token); +} + +public interface IConnectionProtocol : IProtocol +{ + Task ListenAsync(IChannel downChannel, IConnectionContext context); + Task DialAsync(IChannel downChannel, IConnectionContext context); +} + + +public interface ISessionListenerProtocol : IProtocol +{ + Task ListenAsync(IChannel downChannel, ISessionContext context); +} + +public interface ISessionProtocol : ISessionListenerProtocol +{ + Task DialAsync(IChannel downChannel, ISessionContext context); +} + +public interface ISessionProtocol : ISessionListenerProtocol +{ + Task DialAsync(IChannel downChannel, ISessionContext context, TRequest request); } diff --git a/src/libp2p/Libp2p.Core/IReader.cs b/src/libp2p/Libp2p.Core/IReader.cs index fc167e7d..d8105178 100644 --- a/src/libp2p/Libp2p.Core/IReader.cs +++ b/src/libp2p/Libp2p.Core/IReader.cs @@ -12,7 +12,6 @@ public interface IReader { ValueTask ReadAsync(int length, ReadBlockingMode blockingMode = ReadBlockingMode.WaitAll, CancellationToken token = default); - #region Read helpers async IAsyncEnumerable> ReadAllAsync( [EnumeratorCancellation] CancellationToken token = default) diff --git a/src/libp2p/Libp2p.Core/IRemotePeer.cs b/src/libp2p/Libp2p.Core/IRemotePeer.cs deleted file mode 100644 index be9cf928..00000000 --- a/src/libp2p/Libp2p.Core/IRemotePeer.cs +++ /dev/null @@ -1,10 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -namespace Nethermind.Libp2p.Core; - -public interface IRemotePeer : IPeer -{ - Task DialAsync(CancellationToken token = default) where TProtocol : IProtocol; - Task DisconnectAsync(); -} diff --git a/src/libp2p/Libp2p.Core/ISession.cs b/src/libp2p/Libp2p.Core/ISession.cs new file mode 100644 index 00000000..0bf75d29 --- /dev/null +++ b/src/libp2p/Libp2p.Core/ISession.cs @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Multiformats.Address; + +namespace Nethermind.Libp2p.Core; + +public interface ISession +{ + Multiaddress RemoteAddress { get; } + Task DialAsync(CancellationToken token = default) where TProtocol : ISessionProtocol; + Task DialAsync(TRequest request, CancellationToken token = default) where TProtocol : ISessionProtocol; + Task DisconnectAsync(); +} diff --git a/src/libp2p/Libp2p.Core/IWriter.cs b/src/libp2p/Libp2p.Core/IWriter.cs index 3948e798..7e652811 100644 --- a/src/libp2p/Libp2p.Core/IWriter.cs +++ b/src/libp2p/Libp2p.Core/IWriter.cs @@ -47,8 +47,12 @@ ValueTask WriteSizeAndDataAsync(byte[] data) async ValueTask WriteSizeAndProtobufAsync(T grpcMessage) where T : IMessage { - byte[] serializedMessage = grpcMessage.ToByteArray(); - await WriteSizeAndDataAsync(serializedMessage); + int length = grpcMessage.CalculateSize(); + byte[] buf = new byte[VarInt.GetSizeInBytes(length) + length]; + int offset = 0; + VarInt.Encode(length, buf, ref offset); + grpcMessage.WriteTo(buf.AsSpan(offset)); + await WriteAsync(new ReadOnlySequence(buf)); } ValueTask WriteAsync(ReadOnlySequence bytes, CancellationToken token = default); diff --git a/src/libp2p/Libp2p.Core/Identity.cs b/src/libp2p/Libp2p.Core/Identity.cs index 212b7a5b..44d5a83b 100644 --- a/src/libp2p/Libp2p.Core/Identity.cs +++ b/src/libp2p/Libp2p.Core/Identity.cs @@ -149,10 +149,7 @@ private static PublicKey GetPublicKey(PrivateKey privateKey) public bool VerifySignature(byte[] message, byte[] signature) { - if (PublicKey is null) - { - throw new ArgumentNullException(nameof(PublicKey)); - } + ArgumentNullException.ThrowIfNull(PublicKey); switch (PublicKey.Type) { @@ -164,6 +161,7 @@ public bool VerifySignature(byte[] message, byte[] signature) { using RSA rsa = RSA.Create(); rsa.ImportSubjectPublicKeyInfo(PublicKey.Data.Span, out _); + return rsa.VerifyData(message, signature, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); } case KeyType.Secp256K1: @@ -242,54 +240,4 @@ public byte[] Sign(byte[] message) } public PeerId PeerId => new(PublicKey); - - - //public byte[] CreateSignedEnvelope(byte[] message) - //{ - // if (PrivateKey is null) - // { - // throw new ArgumentException(nameof(PrivateKey)); - // } - - // switch (PublicKey.Type) - // { - // case KeyType.Ed25519: - // { - // byte[] sig = new byte[Ed25519.SignatureSize]; - // Ed25519.Sign(PrivateKey.Data.ToByteArray(), 0, PublicKey.Data.ToByteArray(), 0, - // message, 0, message.Length, sig, 0); - // return sig; - // } - // case KeyType.Ecdsa: - // { - // ECDsa e = ECDsa.Create(); - // e.ImportECPrivateKey(PrivateKey.Data.Span, out _); - // return e.SignData(message, HashAlgorithmName.SHA256, - // DSASignatureFormat.Rfc3279DerSequence); - // } - // case KeyType.Rsa: - // { - // using RSA rsa = RSA.Create(); - // rsa.ImportRSAPrivateKey(PrivateKey.Data.Span, out _); - // return rsa.SignData(message, 0, message.Length, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); - // } - // case KeyType.Secp256K1: - // { - // X9ECParameters curve = CustomNamedCurves.GetByName("secp256k1"); - // ISigner signer = SignerUtilities.GetSigner("SHA-256withECDSA"); - - // ECPrivateKeyParameters privateKeyParams = new( - // "ECDSA", - // new BigInteger(1, PrivateKey.Data.ToArray()), - // new ECDomainParameters(curve) - // ); - - // signer.Init(true, privateKeyParams); - // signer.BlockUpdate(message, 0, message.Length); - // return signer.GenerateSignature(); - // } - // default: - // throw new NotImplementedException($"{PublicKey.Type} is not supported"); - // } - //} } diff --git a/src/libp2p/Libp2p.Core/Libp2p.Core.csproj b/src/libp2p/Libp2p.Core/Libp2p.Core.csproj index 8868b8b7..ddb785e4 100644 --- a/src/libp2p/Libp2p.Core/Libp2p.Core.csproj +++ b/src/libp2p/Libp2p.Core/Libp2p.Core.csproj @@ -19,6 +19,9 @@ + + Never + diff --git a/src/libp2p/Libp2p.Core/LocalPeer.Session.cs b/src/libp2p/Libp2p.Core/LocalPeer.Session.cs new file mode 100644 index 00000000..b9779fd6 --- /dev/null +++ b/src/libp2p/Libp2p.Core/LocalPeer.Session.cs @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Multiformats.Address; +using Nethermind.Libp2p.Core.Exceptions; +using System.Collections.Concurrent; + +namespace Nethermind.Libp2p.Core; + +public partial class LocalPeer +{ + public class Session(LocalPeer peer) : ISession + { + private static int SessionIdCounter; + + public string Id { get; } = Interlocked.Increment(ref SessionIdCounter).ToString(); + public State State { get; } = new(); + public Multiaddress RemoteAddress => State.RemoteAddress ?? throw new Libp2pException("Session contains uninitialized remote address."); + + private readonly BlockingCollection SubDialRequests = []; + + public async Task DialAsync(CancellationToken token = default) where TProtocol : ISessionProtocol + { + TaskCompletionSource tcs = new(); + SubDialRequests.Add(new UpgradeOptions() { CompletionSource = tcs!, SelectedProtocol = peer.GetProtocolInstance() }, token); + await tcs.Task; + MarkAsConnected(); + } + + public async Task DialAsync(ISessionProtocol protocol, CancellationToken token = default) + { + TaskCompletionSource tcs = new(); + SubDialRequests.Add(new UpgradeOptions() { CompletionSource = tcs, SelectedProtocol = protocol }, token); + await tcs.Task; + MarkAsConnected(); + } + + public async Task DialAsync(TRequest request, CancellationToken token = default) where TProtocol : ISessionProtocol + { + TaskCompletionSource tcs = new(); + SubDialRequests.Add(new UpgradeOptions() { CompletionSource = tcs, SelectedProtocol = peer.GetProtocolInstance(), Argument = request }, token); + await tcs.Task; + MarkAsConnected(); + return (TResponse)tcs.Task.Result; + } + + + private CancellationTokenSource connectionTokenSource = new(); + + public Task DisconnectAsync() + { + connectionTokenSource.Cancel(); + peer.sessions.Remove(this); + return Task.CompletedTask; + } + + public CancellationToken ConnectionToken => connectionTokenSource.Token; + + + public TaskCompletionSource ConnectedTcs = new(); + public Task Connected => ConnectedTcs.Task; + internal void MarkAsConnected() => ConnectedTcs?.TrySetResult(); + + + internal IEnumerable GetRequestQueue() => SubDialRequests.GetConsumingEnumerable(ConnectionToken); + + } +} diff --git a/src/libp2p/Libp2p.Core/Peer.cs b/src/libp2p/Libp2p.Core/Peer.cs new file mode 100644 index 00000000..0cc4acf9 --- /dev/null +++ b/src/libp2p/Libp2p.Core/Peer.cs @@ -0,0 +1,417 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Microsoft.Extensions.Logging; +using Multiformats.Address; +using Multiformats.Address.Protocols; +using Nethermind.Libp2p.Core.Context; +using Nethermind.Libp2p.Core.Discovery; +using Nethermind.Libp2p.Core.Exceptions; +using Nethermind.Libp2p.Core.Extensions; +using System.Collections.ObjectModel; + +namespace Nethermind.Libp2p.Core; + +public partial class LocalPeer : ILocalPeer +{ + protected readonly ILogger? _logger; + protected readonly PeerStore _peerStore; + protected readonly IProtocolStackSettings _protocolStackSettings; + + Dictionary> listenerReadyTcs = []; + private ObservableCollection sessions { get; } = []; + + + public LocalPeer(Identity identity, PeerStore peerStore, IProtocolStackSettings protocolStackSettings, ILoggerFactory? loggerFactory = null) + { + Identity = identity; + _peerStore = peerStore; + _protocolStackSettings = protocolStackSettings; + _logger = loggerFactory?.CreateLogger($"peer-{identity.PeerId}"); + } + + public override string ToString() + { + return $"peer({Identity.PeerId}): addresses {string.Join(",", ListenAddresses)} sessions {string.Join("|", sessions.Select(x => $"{x.State.RemotePeerId}"))}"; + } + + public Identity Identity { get; } + + public ObservableCollection ListenAddresses { get; } = []; + + + protected virtual Task ConnectedTo(ISession peer, bool isDialer) => Task.CompletedTask; + + protected virtual ProtocolRef SelectProtocol(Multiaddress addr) + { + if (_protocolStackSettings.TopProtocols is null or []) + { + throw new Libp2pSetupException($"Protocols are not set in {nameof(_protocolStackSettings)}"); + } + + return _protocolStackSettings.TopProtocols.First(p => ITransportProtocol.IsAddressMatch(p.Protocol, addr)); + } + + protected virtual Multiaddress[] GetDefaultAddresses() + { + if (_protocolStackSettings.TopProtocols is null or []) + { + throw new Libp2pSetupException($"Protocols are not set in {nameof(_protocolStackSettings)}"); + } + + return _protocolStackSettings.TopProtocols.SelectMany(p => ITransportProtocol.GetDefaultAddresses(p.Protocol, Identity.PeerId)).ToArray(); + } + + protected virtual IEnumerable PrepareAddresses(Multiaddress[] addrs) + { + foreach (Multiaddress addr in addrs) + { + if (!addr.Has()) + { + yield return addr.Add(Identity.PeerId.ToString()); + } + else + { + yield return addr; + } + } + } + + + public event Connected? OnConnected; + + public virtual async Task StartListenAsync(Multiaddress[]? addrs = default, CancellationToken token = default) + { + addrs ??= GetDefaultAddresses(); + + List listenTasks = new(addrs.Length); + + foreach (Multiaddress addr in PrepareAddresses(addrs)) + { + ProtocolRef listenerProtocol = SelectProtocol(addr); + + if (listenerProtocol.Protocol is not ITransportProtocol transportProtocol) + { + throw new Libp2pSetupException($"{nameof(ITransportProtocol)} should be implemented by {listenerProtocol.GetType()}"); + } + + ITransportContext ctx = new TransportContext(this, listenerProtocol, true); + TaskCompletionSource tcs = new(); + listenerReadyTcs[ctx] = tcs; + + _ = transportProtocol.ListenAsync(ctx, addr, token).ContinueWith(t => + { + if (t.IsFaulted) + { + tcs.SetException(t.Exception); + } + ListenAddresses.Remove(tcs.Task.Result); + }); + + listenTasks.Add(tcs.Task.WaitAsync(TimeSpan.FromMilliseconds(5000)).ContinueWith(t => + { + if (t.IsFaulted) + { + _logger?.LogDebug($"Failed to start listener for an address"); + return null; + } + + return t.Result; + })); + } + + await Task.WhenAll(listenTasks); + + foreach (Task startTask in listenTasks) + { + Multiaddress? addr = (startTask as Task)?.Result; + + if (addr is not null) + { + ListenAddresses.Add(addr); + } + } + } + + public void ListenerReady(object sender, Multiaddress addr) + { + if (listenerReadyTcs.Remove(sender, out TaskCompletionSource? tcs)) + { + tcs.SetResult(addr); + } + } + + public INewConnectionContext CreateConnection(ProtocolRef proto, Session? session, bool isListener) + { + session ??= new(this); + return new NewConnectionContext(this, session, proto, isListener, null); + } + + public INewSessionContext UpgradeToSession(Session session, ProtocolRef proto, bool isListener) + { + PeerId? remotePeerId = session.State.RemotePeerId ?? + throw new Libp2pSetupException($"{nameof(session.State.RemoteAddress)} should be initialiazed before session creation"); + + lock (sessions) + { + if (sessions.Any(s => !ReferenceEquals(session, s) && s.State.RemoteAddress.GetPeerId() == remotePeerId)) + { + _ = session.DisconnectAsync(); + throw new SessionExistsException(remotePeerId); + } + _logger?.LogDebug($"New session with {remotePeerId}"); + sessions.Add(session); + } + + Task initializeSession = ConnectedTo(session, !isListener); + initializeSession.ContinueWith(t => + { + if (t.IsFaulted) + { + _ = session.DisconnectAsync(); + _logger?.LogError(t.Exception.InnerException, $"Disconnecting due to exception"); + return; + } + session.ConnectedTcs.TrySetResult(); + OnConnected?.Invoke(session); + }); + return new NewSessionContext(this, session, proto, isListener, null); + } + + internal IEnumerable GetProtocolsFor(ProtocolRef protocol) + { + if (_protocolStackSettings.Protocols is null) + { + throw new Libp2pSetupException($"Protocols are not set in {nameof(_protocolStackSettings)}"); + } + + if (!_protocolStackSettings.Protocols.ContainsKey(protocol)) + { + throw new Libp2pSetupException($"{protocol} is not added"); + } + + return _protocolStackSettings.Protocols[protocol].Select(p => p.Protocol); + } + + + // TODO: Remove locking in the entire stack, look only on level for the given parent protocol + internal IProtocol? GetProtocolInstance() + { + return _protocolStackSettings.Protocols?.Keys.FirstOrDefault(p => p.Protocol.GetType() == typeof(TProtocol))?.Protocol; + } + + public async Task DialAsync(Multiaddress[] addrs, CancellationToken token) + { + PeerId? remotePeerId = addrs.FirstOrDefault()?.GetPeerId(); + ISession? existingSession = sessions.FirstOrDefault(s => s.State.RemotePeerId == remotePeerId); + + if (existingSession is not null) + { + return existingSession; + } + + Dictionary cancellations = []; + foreach (Multiaddress addr in addrs) + { + cancellations[addr] = CancellationTokenSource.CreateLinkedTokenSource(token); + } + + Task timeoutTask = Task.Delay(15_000, token); + Task wait = await TaskHelper.FirstSuccess([timeoutTask, .. addrs.Select(addr => DialAsync(addr, cancellations[addr].Token))]); + + if (wait == timeoutTask) + { + throw new TimeoutException(); + } + + ISession firstConnected = (wait as Task)!.Result; + + foreach (KeyValuePair c in cancellations) + { + if (c.Key != firstConnected.RemoteAddress) + { + c.Value.Cancel(false); + } + } + + return firstConnected; + } + + public async Task DialAsync(Multiaddress addr, CancellationToken token = default) + { + ProtocolRef dialerProtocol = SelectProtocol(addr); + + if (dialerProtocol.Protocol is not ITransportProtocol transportProtocol) + { + throw new Libp2pSetupException($"{nameof(ITransportProtocol)} should be implemented by {dialerProtocol.GetType()}"); + } + + Session session = new(this); + ITransportContext ctx = new DialerTransportContext(this, session, dialerProtocol); + + Task dialingTask = transportProtocol.DialAsync(ctx, addr, token); + + Task dialingResult = await Task.WhenAny(dialingTask, session.Connected); + + if (dialingResult == dialingTask) + { + if (dialingResult.IsFaulted) + { + throw dialingResult.Exception; + } + throw new Libp2pException("Not able to dial the peer"); + } + await session.Connected; + return session; + } + + public Task DialAsync(PeerId peerId, CancellationToken token = default) + { + ISession? existingSession = sessions.FirstOrDefault(s => s.State.RemotePeerId == peerId); + + if (existingSession is not null) + { + return Task.FromResult(existingSession); + } + + PeerStore.PeerInfo existingPeerInfo = _peerStore.GetPeerInfo(peerId); + + if (existingPeerInfo?.Addrs is null) + { + throw new Libp2pException("Peer not found"); + } + + return DialAsync([.. existingPeerInfo.Addrs], token); + } + + private static void MapToTaskCompletionSource(Task t, TaskCompletionSource tcs) + { + if (t.IsCompletedSuccessfully) + { + tcs.SetResult(t.GetType().GenericTypeArguments.Any() ? t.GetType().GetProperty("Result")!.GetValue(t) : null); + return; + } + if (t.IsCanceled) + { + tcs.SetCanceled(); + return; + } + tcs.SetException(t.Exception!); + } + + private static void MapToTaskCompletionSource(Task t, TaskCompletionSource tcs) + { + if (t.IsCompletedSuccessfully) + { + tcs.SetResult(); + return; + } + if (t.IsCanceled) + { + tcs.SetCanceled(); + return; + } + tcs.SetException(t.Exception!); + } + + internal IChannel Upgrade(Session session, ProtocolRef parentProtocol, IProtocol? upgradeProtocol, UpgradeOptions? options, bool isListener) + { + Channel downChannel = new(); + + _ = Upgrade(session, downChannel.Reverse, parentProtocol, upgradeProtocol, options, isListener); + + return downChannel; + } + + internal Task Upgrade(Session session, IChannel downChannel, ProtocolRef parentProtocol, IProtocol? upgradeProtocol, UpgradeOptions? options, bool isListener) + { + if (_protocolStackSettings.Protocols is null) + { + throw new Libp2pSetupException($"Protocols are not set in {nameof(_protocolStackSettings)}"); + } + + if (upgradeProtocol is not null && !_protocolStackSettings.Protocols[parentProtocol].Any(p => p.Protocol == upgradeProtocol)) + { + _protocolStackSettings.Protocols.Add(new ProtocolRef(upgradeProtocol, false), []); + } + + ProtocolRef top = upgradeProtocol is not null ? + _protocolStackSettings.Protocols[parentProtocol].FirstOrDefault(p => p.Protocol == upgradeProtocol, _protocolStackSettings.Protocols.Keys.First(k => k.Protocol == upgradeProtocol)) : + _protocolStackSettings.Protocols[parentProtocol].Single(); + + isListener = options?.ModeOverride switch { UpgradeModeOverride.Dial => false, UpgradeModeOverride.Listen => true, _ => isListener }; + + _logger?.LogInformation($"Upgrade and bind {parentProtocol} to {top}, listen={isListener}"); + + Task upgradeTask; + switch (top.Protocol) + { + case IConnectionProtocol tProto: + { + ConnectionContext ctx = new(this, session, top, isListener, options); + upgradeTask = isListener ? tProto.ListenAsync(downChannel, ctx) : tProto.DialAsync(downChannel, ctx); + break; + } + case ISessionProtocol sProto: + { + SessionContext ctx = new(this, session, top, isListener, options); + upgradeTask = isListener ? sProto.ListenAsync(downChannel, ctx) : sProto.DialAsync(downChannel, ctx); + break; + } + default: + if (isListener && top.Protocol is ISessionListenerProtocol listenerProtocol) + { + SessionContext ctx = new(this, session, top, isListener, options); + upgradeTask = listenerProtocol.ListenAsync(downChannel, ctx); + break; + } + + Type? genericInterface = top.Protocol.GetType().GetInterfaces() + .FirstOrDefault(i => + i.IsGenericType && + i.GetGenericTypeDefinition() == typeof(ISessionProtocol<,>)); + + if (genericInterface != null) + { + Type[] genericArguments = genericInterface.GetGenericArguments(); + Type requestType = genericArguments[0]; + + if (options?.Argument is not null && !options.Argument.GetType().IsAssignableTo(requestType)) + { + throw new ArgumentException($"Invalid request. Argument is of {options.Argument.GetType()} type which is not assignable to {requestType.FullName}"); + } + + // Dynamically invoke DialAsync + System.Reflection.MethodInfo? dialAsyncMethod = genericInterface.GetMethod("DialAsync"); + if (dialAsyncMethod != null) + { + SessionContext ctx = new(this, session, top, isListener, options); + upgradeTask = (Task)dialAsyncMethod.Invoke(top.Protocol, [downChannel, ctx, options?.Argument])!; + break; + } + } + throw new Libp2pSetupException($"Protocol {top.Protocol} does not implement proper protocol interface"); + } + + if (options?.SelectedProtocol == top.Protocol && options?.CompletionSource is not null) + { + _ = upgradeTask.ContinueWith(async t => + { + MapToTaskCompletionSource(t, options.CompletionSource); + await downChannel.CloseAsync(); + }); + } + + return upgradeTask.ContinueWith(t => + { + if (t.IsFaulted) + { + _logger?.LogError($"Upgrade task failed with {t.Exception}"); + } + _ = downChannel.CloseAsync(); + _logger?.LogInformation($"Finished#2 {parentProtocol} to {top}, listen={isListener}"); + }); + } + + public Task DisconnectAsync() => Task.WhenAll(sessions.ToArray().Select(s => s.DisconnectAsync())); +} diff --git a/src/libp2p/Libp2p.Core/PeerConnectionException.cs b/src/libp2p/Libp2p.Core/PeerConnectionException.cs deleted file mode 100644 index f16439ba..00000000 --- a/src/libp2p/Libp2p.Core/PeerConnectionException.cs +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-FileCopyrightText:2023 Demerzel Solutions Limited -// SPDX-License-Identifier:MIT - -namespace Nethermind.Libp2p.Core; - -public class PeerConnectionException : Exception -{ - -} diff --git a/src/libp2p/Libp2p.Core/PeerContext.cs b/src/libp2p/Libp2p.Core/PeerContext.cs deleted file mode 100644 index 19abb63c..00000000 --- a/src/libp2p/Libp2p.Core/PeerContext.cs +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Multiformats.Address; -using System.Collections.Concurrent; - -namespace Nethermind.Libp2p.Core; - -public class PeerContext : IPeerContext -{ - public string Id { get; set; } - public IPeer LocalPeer { get; set; } - public IPeer RemotePeer { get; set; } - public Multiaddress RemoteEndpoint { get; set; } - public Multiaddress LocalEndpoint { get; set; } - public BlockingCollection SubDialRequests { get; set; } = new(); - public IChannelRequest? SpecificProtocolRequest { get; set; } - - public IPeerContext Fork() - { - PeerContext result = (PeerContext)MemberwiseClone(); - result.RemotePeer = ((PeerFactory.RemotePeer)RemotePeer).Fork(); - return result; - } - - - - public event RemotePeerConnected? OnRemotePeerConnection; - public void Connected(IPeer peer) - { - OnRemotePeerConnection?.Invoke((IRemotePeer)peer); - } - - public event ListenerReady? OnListenerReady; - public void ListenerReady() - { - OnListenerReady?.Invoke(); - } -} diff --git a/src/libp2p/Libp2p.Core/PeerFactory.cs b/src/libp2p/Libp2p.Core/PeerFactory.cs index d39a614e..f41610ee 100644 --- a/src/libp2p/Libp2p.Core/PeerFactory.cs +++ b/src/libp2p/Libp2p.Core/PeerFactory.cs @@ -1,237 +1,20 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT -using Microsoft.Extensions.DependencyInjection; -using Multiformats.Address; -using Multiformats.Address.Protocols; -using System.Runtime.CompilerServices; +using Microsoft.Extensions.Logging; +using Nethermind.Libp2p.Core.Discovery; namespace Nethermind.Libp2p.Core; -public class PeerFactory : IPeerFactory +public class PeerFactory(IProtocolStackSettings protocolStackSettings, PeerStore peerStore, ILoggerFactory? loggerFactory = null) : IPeerFactory { - private readonly IServiceProvider _serviceProvider; + protected IProtocolStackSettings protocolStackSettings = protocolStackSettings; - private IProtocol _protocol; - private IChannelFactory _upChannelFactory; - private static int CtxId = 0; + protected PeerStore PeerStore { get; } = peerStore; + protected ILoggerFactory? LoggerFactory { get; } = loggerFactory; - public PeerFactory(IServiceProvider serviceProvider) + public virtual ILocalPeer Create(Identity? identity = default) { - _serviceProvider = serviceProvider; - } - - public virtual ILocalPeer Create(Identity? identity = default, Multiaddress? localAddr = default) - { - identity ??= new Identity(); - return new LocalPeer(this) { Identity = identity ?? new Identity(), Address = localAddr ?? $"/ip4/0.0.0.0/tcp/0/p2p/{identity.PeerId}" }; - } - - /// - /// PeerFactory interface ctor - /// - /// - /// - /// - /// - public void Setup(IProtocol protocol, IChannelFactory upChannelFactory) - { - _protocol = protocol; - _upChannelFactory = upChannelFactory; - } - - private async Task ListenAsync(LocalPeer peer, Multiaddress addr, CancellationToken token) - { - peer.Address = addr; - if (!peer.Address.Has()) - { - peer.Address = peer.Address.Add(peer.Identity.PeerId.ToString()); - } - - Channel chan = new(); - if (token != default) - { - token.Register(() => chan.CloseAsync()); - } - - TaskCompletionSource ts = new(); - - - PeerContext peerContext = new() - { - Id = $"ctx-{++CtxId}", - LocalPeer = peer, - }; - - peerContext.OnListenerReady += OnListenerReady; - - void OnListenerReady() - { - ts.SetResult(); - peerContext.OnListenerReady -= OnListenerReady; - } - - RemotePeer remotePeer = new(this, peer, peerContext); - peerContext.RemotePeer = remotePeer; - - PeerListener result = new(chan, peer); - peerContext.OnRemotePeerConnection += remotePeer => - { - if (((RemotePeer)remotePeer).LocalPeer != peer) - { - return; - } - - ConnectedTo(remotePeer, false) - .ContinueWith(t => { result.RaiseOnConnection(remotePeer); }, token); - }; - _ = _protocol.ListenAsync(chan, _upChannelFactory, peerContext); - - await ts.Task; - return result; - } - - protected virtual Task ConnectedTo(IRemotePeer peer, bool isDialer) - { - return Task.CompletedTask; - } - - private Task DialAsync(IPeerContext peerContext, CancellationToken token) where TProtocol : IProtocol - { - TaskCompletionSource cts = new(token); - peerContext.SubDialRequests.Add(new ChannelRequest - { - SubProtocol = (_serviceProvider.GetService() as ICreateProtocolInstance)!.CreateProtocolInstance(_serviceProvider), - CompletionSource = cts - }); - return cts.Task; - } - - protected virtual async Task DialAsync(LocalPeer peer, Multiaddress addr, CancellationToken token) - { - try - { - Channel chan = new(); - token.Register(() => _ = chan.CloseAsync()); - - PeerContext context = new() - { - Id = $"ctx-{++CtxId}", - LocalPeer = peer, - }; - RemotePeer result = new(this, peer, context) { Address = addr, Channel = chan }; - context.RemotePeer = result; - - TaskCompletionSource tcs = new(); - RemotePeerConnected remotePeerConnected = null!; - - remotePeerConnected = remotePeer => - { - if (((RemotePeer)remotePeer).LocalPeer != peer) - { - return; - } - - ConnectedTo(remotePeer, true).ContinueWith((t) => { tcs.TrySetResult(true); }); - context.OnRemotePeerConnection -= remotePeerConnected; - }; - context.OnRemotePeerConnection += remotePeerConnected; - - _ = _protocol.DialAsync(chan, _upChannelFactory, context); - - await tcs.Task; - return result; - } - catch - { - throw; - } - } - - private class PeerListener : IListener - { - private readonly Channel _chan; - private readonly LocalPeer _localPeer; - - public PeerListener(Channel chan, LocalPeer localPeer) - { - _chan = chan; - _localPeer = localPeer; - } - - public event OnConnection? OnConnection; - public Multiaddress Address => _localPeer.Address; - - public Task DisconnectAsync() - { - return _chan.CloseAsync().AsTask(); - } - - public TaskAwaiter GetAwaiter() - { - return _chan.GetAwaiter(); - } - - internal void RaiseOnConnection(IRemotePeer peer) - { - OnConnection?.Invoke(peer); - } - } - - protected class LocalPeer : ILocalPeer - { - private readonly PeerFactory _factory; - - public LocalPeer(PeerFactory factory) - { - _factory = factory; - } - - public Identity? Identity { get; set; } - public Multiaddress Address { get; set; } - - public Task DialAsync(Multiaddress addr, CancellationToken token = default) - { - return _factory.DialAsync(this, addr, token); - } - - public Task ListenAsync(Multiaddress addr, CancellationToken token = default) - { - return _factory.ListenAsync(this, addr, token); - } - } - - internal class RemotePeer : IRemotePeer - { - private readonly PeerFactory _factory; - private readonly IPeerContext peerContext; - - public RemotePeer(PeerFactory factory, ILocalPeer localPeer, IPeerContext peerContext) - { - _factory = factory; - LocalPeer = localPeer; - this.peerContext = peerContext; - } - - public Channel Channel { get; set; } - - public Identity Identity { get; set; } - public Multiaddress Address { get; set; } - internal ILocalPeer LocalPeer { get; } - - public Task DialAsync(CancellationToken token = default) where TProtocol : IProtocol - { - return _factory.DialAsync(peerContext, token); - } - - public Task DisconnectAsync() - { - return Channel.CloseAsync().AsTask(); - } - - public IPeer Fork() - { - return (IPeer)MemberwiseClone(); - } + return new LocalPeer(identity ?? new Identity(), PeerStore, protocolStackSettings, LoggerFactory); } } diff --git a/src/libp2p/Libp2p.Core/PeerFactoryBuilderBase.cs b/src/libp2p/Libp2p.Core/PeerFactoryBuilderBase.cs index 15f66979..790953aa 100644 --- a/src/libp2p/Libp2p.Core/PeerFactoryBuilderBase.cs +++ b/src/libp2p/Libp2p.Core/PeerFactoryBuilderBase.cs @@ -5,183 +5,99 @@ namespace Nethermind.Libp2p.Core; -public interface ICreateProtocolInstance +public class ProtocolRef(IProtocol protocol, bool isExposed = true) { - IProtocol CreateProtocolInstance(IServiceProvider serviceProvider, TProtocol? instance = default) where TProtocol : IProtocol; + static int RefIdCounter = 0; + + public string RefId { get; } = Interlocked.Increment(ref RefIdCounter).ToString(); + public IProtocol Protocol => protocol; + public bool IsExposed => isExposed; + + public string Id => Protocol.Id; + + public override string ToString() + { + return $"ref#{RefId}({Protocol.Id})"; + } } -public abstract class PeerFactoryBuilderBase : IPeerFactoryBuilder, ICreateProtocolInstance + +public abstract class PeerFactoryBuilderBase : IPeerFactoryBuilder where TBuilder : PeerFactoryBuilderBase, IPeerFactoryBuilder - where TPeerFactory : PeerFactory + where TPeerFactory : IPeerFactory { - private HashSet protocols = new(); + private readonly HashSet protocolInstances = []; - public IProtocol CreateProtocolInstance(IServiceProvider serviceProvider, TProtocol? instance = default) where TProtocol : IProtocol + private TProtocol CreateProtocolInstance(IServiceProvider serviceProvider, TProtocol? instance = default) where TProtocol : IProtocol { if (instance is not null) { - protocols.Add(instance); + protocolInstances.Add(instance); } - IProtocol? existing = instance ?? protocols.OfType().FirstOrDefault(); + IProtocol? existing = instance ?? protocolInstances.OfType().FirstOrDefault(); if (existing is null) { existing = ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider); - protocols.Add(existing); + protocolInstances.Add(existing); } - return existing; + return (TProtocol)existing; } - private readonly List _appLayerProtocols = new(); - public IEnumerable AppLayerProtocols { get => _appLayerProtocols; } - internal readonly IServiceProvider ServiceProvider; + private readonly List _appLayerProtocols = []; + public IEnumerable AppLayerProtocols => _appLayerProtocols.Select(x => x.Protocol); - protected readonly ProtocolStack? _stack; + internal readonly IServiceProvider ServiceProvider; protected PeerFactoryBuilderBase(IServiceProvider? serviceProvider = default) { ServiceProvider = serviceProvider ?? new ServiceCollection().BuildServiceProvider(); } - protected ProtocolStack Over(TProtocol? instance = default) where TProtocol : IProtocol + public IPeerFactoryBuilder AddAppLayerProtocol(TProtocol? instance = default, bool isExposed = true) where TProtocol : IProtocol { - ProtocolStack result = new ProtocolStack(this, ServiceProvider, CreateProtocolInstance(ServiceProvider, instance), this); - result.Root = result; - return result; - } - - public IPeerFactoryBuilder AddAppLayerProtocol(TProtocol? instance = default) where TProtocol : IProtocol - { - _appLayerProtocols.Add(CreateProtocolInstance(ServiceProvider!, instance)); + _appLayerProtocols.Add(new ProtocolRef(CreateProtocolInstance(ServiceProvider!, instance), isExposed)); return (TBuilder)this; } - protected class ProtocolStack - { - private readonly IPeerFactoryBuilder builder; - private readonly IServiceProvider serviceProvider; - private readonly ICreateProtocolInstance createProtocolInstance; - - public ProtocolStack? Root { get; set; } - public ProtocolStack? Parent { get; private set; } - public ProtocolStack? PrevSwitch { get; private set; } - public IProtocol Protocol { get; } - public HashSet TopProtocols { get; } = new(); - public ChannelFactory UpChannelsFactory { get; } - - public ProtocolStack(IPeerFactoryBuilder builder, IServiceProvider serviceProvider, IProtocol protocol, ICreateProtocolInstance createProtocolInstance) - { - this.builder = builder; - this.serviceProvider = serviceProvider; - Protocol = protocol; - this.createProtocolInstance = createProtocolInstance; - UpChannelsFactory = ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider); - } + protected abstract ProtocolRef[] BuildStack(IEnumerable additionalProtocols); - public ProtocolStack AddAppLayerProtocol(TProtocol? instance = default) where TProtocol : IProtocol - { - builder.AddAppLayerProtocol(instance); - return this; - } - - public ProtocolStack Over(TProtocol? instance = default) where TProtocol : IProtocol - { - ProtocolStack nextNode = new(builder, serviceProvider, createProtocolInstance.CreateProtocolInstance(serviceProvider!, instance), createProtocolInstance); - return Over(nextNode); - } + private Dictionary protocols = []; - public ProtocolStack Or(TProtocol? instance = default) where TProtocol : IProtocol - { - if (Parent is null) - { - throw new NotImplementedException(); - } - IProtocol protocol = createProtocolInstance.CreateProtocolInstance(serviceProvider!, instance); - ProtocolStack stack = new(builder, serviceProvider, protocol, createProtocolInstance); - return Or(stack); - } - - public ProtocolStack Over(ProtocolStack stack) + protected ProtocolRef[] Connect(ProtocolRef[] protocols, params ProtocolRef[][] upgradeToStacks) + { + ProtocolRef[] previous = protocols; + foreach (ProtocolRef[] upgradeTo in upgradeToStacks) { - PeerFactoryBuilderBase.ProtocolStack rootProto = stack.Root ?? stack; - TopProtocols.Add(rootProto); - - if (PrevSwitch != null) + foreach (ProtocolRef protocolRef in previous) { - PrevSwitch.Over(stack); - } - - rootProto.Root = stack.Root = Root ?? this; - rootProto.Parent = this; - - return stack; - } + this.protocols[protocolRef] = upgradeTo; - public ProtocolStack Or(ProtocolStack stack) - { - if (Parent is null) - { - throw new NotImplementedException(); + foreach (ProtocolRef upgradeToRef in upgradeTo) + { + this.protocols.TryAdd(upgradeToRef, []); + } } - stack.PrevSwitch = this; - return Parent.Over(stack); + previous = upgradeTo; } - public override string ToString() - { - return $"{Protocol.Id}({TopProtocols.Count}): {string.Join(" or ", TopProtocols.Select(p => p.Protocol.Id))}"; - } + return previous; } - protected abstract ProtocolStack BuildStack(); + protected ProtocolRef Get() where TProtocol : IProtocol + { + return new ProtocolRef(CreateProtocolInstance(ServiceProvider)); + } public IPeerFactory Build() { - ProtocolStack transportLayer = BuildStack(); - ProtocolStack? appLayer = default; - - foreach (IProtocol appLayerProtocol in _appLayerProtocols) - { - appLayer = appLayer is null ? transportLayer.Over(appLayerProtocol) : appLayer.Or(appLayerProtocol); - } - - ProtocolStack? root = transportLayer.Root; - - if (root?.Protocol is null || root.UpChannelsFactory is null) - { - throw new ApplicationException("Root protocol is not properly defined"); - } - - static void SetupChannelFactories(ProtocolStack root) - { - root.UpChannelsFactory.Setup(new Dictionary(root.TopProtocols - .Select(p => new KeyValuePair(p.Protocol, p.UpChannelsFactory)))); - foreach (ProtocolStack topProto in root.TopProtocols) - { - if (!root.TopProtocols.Any()) - { - return; - } - SetupChannelFactories(topProto); - } - } - - SetupChannelFactories(root); + IProtocolStackSettings protocolStackSettings = ActivatorUtilities.GetServiceOrCreateInstance(ServiceProvider); + protocolStackSettings.TopProtocols = BuildStack(_appLayerProtocols.ToArray()); + protocolStackSettings.Protocols = protocols; TPeerFactory result = ActivatorUtilities.GetServiceOrCreateInstance(ServiceProvider); - result.Setup(root?.Protocol!, root!.UpChannelsFactory); - return result; - } - - private class Layer - { - public List Protocols { get; } = new(); - public bool IsSelector { get; set; } - public override string ToString() - { - return (IsSelector ? "(selector)" : "") + string.Join(",", Protocols.Select(p => p.Id)); - } + return result; } } diff --git a/src/libp2p/Libp2p.Core/PeerId.cs b/src/libp2p/Libp2p.Core/PeerId.cs index 29c0a303..5ca6920b 100644 --- a/src/libp2p/Libp2p.Core/PeerId.cs +++ b/src/libp2p/Libp2p.Core/PeerId.cs @@ -127,7 +127,7 @@ static int ComputeHash(params byte[] data) return hashCode ??= ComputeHash(Bytes); } - public static bool operator ==(PeerId left, PeerId right) + public static bool operator ==(PeerId? left, PeerId? right) { if (left is null) { @@ -141,7 +141,7 @@ static int ComputeHash(params byte[] data) return left.Equals(right); } - public static bool operator !=(PeerId left, PeerId right) => !(left == right); + public static bool operator !=(PeerId? left, PeerId? right) => !(left == right); #endregion } diff --git a/src/libp2p/Libp2p.Core/ProtocolStackSettings.cs b/src/libp2p/Libp2p.Core/ProtocolStackSettings.cs new file mode 100644 index 00000000..f3cc0da2 --- /dev/null +++ b/src/libp2p/Libp2p.Core/ProtocolStackSettings.cs @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +namespace Nethermind.Libp2p.Core; + +public class ProtocolStackSettings : IProtocolStackSettings +{ + public ProtocolRef[]? TopProtocols { get; set; } + public Dictionary? Protocols { get; set; } +} diff --git a/src/libp2p/Libp2p.Core/Stream.cs b/src/libp2p/Libp2p.Core/Stream.cs index 4b1d1720..7f3e36e5 100644 --- a/src/libp2p/Libp2p.Core/Stream.cs +++ b/src/libp2p/Libp2p.Core/Stream.cs @@ -38,7 +38,7 @@ public override int Read(Span buffer) { if (buffer is { Length: 0 } && _canRead) return 0; - var result = _chan.ReadAsync(buffer.Length, ReadBlockingMode.WaitAny).Result; + ReadResult result = _chan.ReadAsync(buffer.Length, ReadBlockingMode.WaitAny).Result; if (result.Result != IOResult.Ok) { _canRead = false; @@ -72,7 +72,7 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, { if (buffer is { Length: 0 } && _canRead) return 0; - var result = await _chan.ReadAsync(buffer.Length, ReadBlockingMode.WaitAny); + ReadResult result = await _chan.ReadAsync(buffer.Length, ReadBlockingMode.WaitAny); if (result.Result != IOResult.Ok) { _canRead = false; diff --git a/src/libp2p/Libp2p.Core/SymetricProtocol.cs b/src/libp2p/Libp2p.Core/SymetricProtocol.cs index 3ca90875..a5f63583 100644 --- a/src/libp2p/Libp2p.Core/SymetricProtocol.cs +++ b/src/libp2p/Libp2p.Core/SymetricProtocol.cs @@ -5,16 +5,29 @@ namespace Nethermind.Libp2p.Core; public abstract class SymmetricProtocol { - public Task DialAsync(IChannel channel, IChannelFactory? channelFactory, IPeerContext context) + public Task DialAsync(IChannel channel, IConnectionContext context) { - return ConnectAsync(channel, channelFactory, context, false); + return ConnectAsync(channel, context, false); } - public Task ListenAsync(IChannel channel, IChannelFactory? channelFactory, IPeerContext context) + public Task ListenAsync(IChannel channel, IConnectionContext context) { - return ConnectAsync(channel, channelFactory, context, true); + return ConnectAsync(channel, context, true); } - protected abstract Task ConnectAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context, bool isListener); + protected abstract Task ConnectAsync(IChannel channel, IConnectionContext context, bool isListener); +} +public abstract class SymmetricSessionProtocol +{ + public Task DialAsync(IChannel channel, ISessionContext context) + { + return ConnectAsync(channel, context, false); + } + + public Task ListenAsync(IChannel channel, ISessionContext context) + { + return ConnectAsync(channel, context, true); + } + + protected abstract Task ConnectAsync(IChannel channel, ISessionContext context, bool isListener); } diff --git a/src/libp2p/Libp2p.Core/TransportContext.cs b/src/libp2p/Libp2p.Core/TransportContext.cs new file mode 100644 index 00000000..2551cf7b --- /dev/null +++ b/src/libp2p/Libp2p.Core/TransportContext.cs @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Multiformats.Address; + +namespace Nethermind.Libp2p.Core; + +public class TransportContext(LocalPeer peer, ProtocolRef proto, bool isListener) : ITransportContext +{ + public Identity Identity => peer.Identity; + public ILocalPeer Peer => peer; + public bool IsListener => isListener; + + public void ListenerReady(Multiaddress addr) + { + peer.ListenerReady(this, addr); + } + + public virtual INewConnectionContext CreateConnection() + { + return peer.CreateConnection(proto, null, isListener); + } +} + +public class DialerTransportContext(LocalPeer peer, LocalPeer.Session session, ProtocolRef proto) : TransportContext(peer, proto, false) +{ + public override INewConnectionContext CreateConnection() + { + return peer.CreateConnection(proto, session, false); + } +} diff --git a/src/libp2p/Libp2p.Core/Utils/IpHelper.cs b/src/libp2p/Libp2p.Core/Utils/IpHelper.cs new file mode 100644 index 00000000..3921878b --- /dev/null +++ b/src/libp2p/Libp2p.Core/Utils/IpHelper.cs @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using System.Net; +using System.Net.NetworkInformation; + +namespace Nethermind.Libp2p.Core.Utils; +public class IpHelper +{ + public static IEnumerable GetListenerAddresses() => NetworkInterface.GetAllNetworkInterfaces().SelectMany(i => i.GetIPProperties().UnicastAddresses.Select(a => a.Address)); +} diff --git a/src/libp2p/Libp2p.E2eTests/E2eTestSetup.cs b/src/libp2p/Libp2p.E2eTests/E2eTestSetup.cs new file mode 100644 index 00000000..2015065b --- /dev/null +++ b/src/libp2p/Libp2p.E2eTests/E2eTestSetup.cs @@ -0,0 +1,100 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Nethermind.Libp2p; +using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Core.Discovery; +using Nethermind.Libp2p.Core.TestsBase; +using System.Text; + +namespace Libp2p.E2eTests; + +public class E2eTestSetup : IDisposable +{ + private readonly CancellationTokenSource _commonTokenSource = new(); + public void Dispose() + { + _commonTokenSource.Cancel(); + _commonTokenSource.Dispose(); + } + + protected CancellationToken Token => _commonTokenSource.Token; + + protected static TestContextLoggerFactory loggerFactory = new(); + private int _peerCounter = 0; + + protected ILogger TestLogger { get; set; } = loggerFactory.CreateLogger("test-setup"); + + public Dictionary Peers { get; } = []; + public Dictionary PeerStores { get; } = []; + public Dictionary ServiceProviders { get; } = []; + + protected virtual IPeerFactoryBuilder ConfigureLibp2p(ILibp2pPeerFactoryBuilder builder) + { + return builder.AddAppLayerProtocol(); + } + + protected virtual IServiceCollection ConfigureServices(IServiceCollection col) + { + return col; + } + + protected virtual void AddToPrintState(StringBuilder sb, int index) + { + } + + protected virtual void AddAt(int index) + { + + } + + public async Task AddPeersAsync(int count) + { + int totalCount = _peerCounter + count; + + for (; _peerCounter < totalCount; _peerCounter++) + { + // But we create a seprate setup for every peer + ServiceProvider sp = ServiceProviders[_peerCounter] = + ConfigureServices( + new ServiceCollection() + .AddLibp2p(ConfigureLibp2p) + .AddSingleton(sp => new TestContextLoggerFactory()) + ) + .BuildServiceProvider(); + + PeerStores[_peerCounter] = ServiceProviders[_peerCounter].GetService()!; + Peers[_peerCounter] = sp.GetService()!.Create(TestPeers.Identity(_peerCounter)); + + await Peers[_peerCounter].StartListenAsync(token: Token); + + AddAt(_peerCounter); + } + } + + + private int stateCounter = 1; + + public void PrintState(bool outputToConsole = false) + { + StringBuilder reportBuilder = new(); + reportBuilder.AppendLine($"Test state#{stateCounter++}"); + + foreach ((int index, ILocalPeer peer) in Peers) + { + AddToPrintState(reportBuilder, index); + reportBuilder.AppendLine(peer.ToString()); + reportBuilder.AppendLine(); + } + + string report = reportBuilder.ToString(); + + if (outputToConsole) + { + Console.WriteLine(report); + } + else + { + TestLogger.LogInformation(report.ToString()); + } + } +} diff --git a/src/libp2p/Libp2p.E2eTests/IncrementNumberTestProtocol.cs b/src/libp2p/Libp2p.E2eTests/IncrementNumberTestProtocol.cs new file mode 100644 index 00000000..af81e1e5 --- /dev/null +++ b/src/libp2p/Libp2p.E2eTests/IncrementNumberTestProtocol.cs @@ -0,0 +1,20 @@ +using Nethermind.Libp2p.Core; + +namespace Libp2p.E2eTests; + +public class IncrementNumberTestProtocol : ISessionProtocol +{ + public string Id => "1"; + + public async Task DialAsync(IChannel downChannel, ISessionContext context, int request) + { + await downChannel.WriteVarintAsync(request); + return await downChannel.ReadVarintAsync(); + } + + public async Task ListenAsync(IChannel downChannel, ISessionContext context) + { + int request = await downChannel.ReadVarintAsync(); + await downChannel.WriteVarintAsync(request + 1); + } +} diff --git a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/Libp2p.Protocols.PubsubPeerDiscovery.Tests.csproj b/src/libp2p/Libp2p.E2eTests/Libp2p.E2eTests.csproj similarity index 79% rename from src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/Libp2p.Protocols.PubsubPeerDiscovery.Tests.csproj rename to src/libp2p/Libp2p.E2eTests/Libp2p.E2eTests.csproj index 2132529c..9c85b1e2 100644 --- a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/Libp2p.Protocols.PubsubPeerDiscovery.Tests.csproj +++ b/src/libp2p/Libp2p.E2eTests/Libp2p.E2eTests.csproj @@ -3,13 +3,14 @@ enable enable - Nethermind.$(MSBuildProjectName.Replace(" ", "_")) - false - Nethermind.$(MSBuildProjectName) + + + + @@ -18,13 +19,13 @@ - + diff --git a/src/libp2p/Libp2p.E2eTests/RequestResponseTests.cs b/src/libp2p/Libp2p.E2eTests/RequestResponseTests.cs new file mode 100644 index 00000000..3d38f9b7 --- /dev/null +++ b/src/libp2p/Libp2p.E2eTests/RequestResponseTests.cs @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Nethermind.Libp2p.Core; +using NUnit.Framework; + +namespace Libp2p.E2eTests; + +public class RequestResponseTests +{ + [Test] + public async Task Test_RequestReponse() + { + E2eTestSetup test = new(); + int request = 1; + + await test.AddPeersAsync(2); + ISession session = await test.Peers[0].DialAsync(test.Peers[1].ListenAddresses.ToArray()); + int response = await session.DialAsync(1); + + Assert.That(response, Is.EqualTo(request + 1)); + } + +} diff --git a/src/libp2p/Libp2p.Protocols.Identify/IPAddressExtensions.cs b/src/libp2p/Libp2p.Protocols.Identify/IPAddressExtensions.cs new file mode 100644 index 00000000..c72802b5 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Identify/IPAddressExtensions.cs @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using System.Net; +using System.Net.Sockets; + +namespace Nethermind.Libp2p.Protocols.Identify; + + +// Picked from https://gist.github.com/angularsen/f77b53ee9966fcd914025e25a2b3a085 creds: Andreas Gullberg Larsen +public static class IPAddressExtensions +{ + /// + /// Returns true if the IP address is in a private range.
+ /// IPv4: Loopback, link local ("169.254.x.x"), class A ("10.x.x.x"), class B ("172.16.x.x" to "172.31.x.x") and class C ("192.168.x.x").
+ /// IPv6: Loopback, link local, site local, unique local and private IPv4 mapped to IPv6.
+ ///
+ /// The IP address. + /// True if the IP address was in a private range. + /// bool isPrivate = IPAddress.Parse("127.0.0.1").IsPrivate(); + public static bool IsPrivate(this IPAddress ip) + { + // Map back to IPv4 if mapped to IPv6, for example "::ffff:1.2.3.4" to "1.2.3.4". + if (ip.IsIPv4MappedToIPv6) + ip = ip.MapToIPv4(); + + // Checks loopback ranges for both IPv4 and IPv6. + if (IPAddress.IsLoopback(ip)) return true; + + // IPv4 + if (ip.AddressFamily == AddressFamily.InterNetwork) + return IsPrivateIPv4(ip.GetAddressBytes()); + + // IPv6 + if (ip.AddressFamily == AddressFamily.InterNetworkV6) + { + return ip.IsIPv6LinkLocal || + ip.IsIPv6UniqueLocal || + ip.IsIPv6SiteLocal; + } + + throw new NotSupportedException( + $"IP address family {ip.AddressFamily} is not supported, expected only IPv4 (InterNetwork) or IPv6 (InterNetworkV6)."); + } + + private static bool IsPrivateIPv4(byte[] ipv4Bytes) + { + // Link local (no IP assigned by DHCP): 169.254.0.0 to 169.254.255.255 (169.254.0.0/16) + bool IsLinkLocal() => ipv4Bytes[0] == 169 && ipv4Bytes[1] == 254; + + // Class A private range: 10.0.0.0 – 10.255.255.255 (10.0.0.0/8) + bool IsClassA() => ipv4Bytes[0] == 10; + + // Class B private range: 172.16.0.0 – 172.31.255.255 (172.16.0.0/12) + bool IsClassB() => ipv4Bytes[0] == 172 && ipv4Bytes[1] >= 16 && ipv4Bytes[1] <= 31; + + // Class C private range: 192.168.0.0 – 192.168.255.255 (192.168.0.0/16) + bool IsClassC() => ipv4Bytes[0] == 192 && ipv4Bytes[1] == 168; + + return IsLinkLocal() || IsClassA() || IsClassC() || IsClassB(); + } +} diff --git a/src/libp2p/Libp2p.Protocols.Identify/IdentifyProtocol.cs b/src/libp2p/Libp2p.Protocols.Identify/IdentifyProtocol.cs index eef62dfb..54e30ce3 100644 --- a/src/libp2p/Libp2p.Protocols.Identify/IdentifyProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.Identify/IdentifyProtocol.cs @@ -2,145 +2,103 @@ // SPDX-License-Identifier: MIT using Google.Protobuf; -using Nethermind.Libp2p.Core; using Microsoft.Extensions.Logging; -using Nethermind.Libp2p.Core.Dto; -using Multiformats.Address; -using Multiformats.Address.Protocols; +using Multiformats.Address.Net; +using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Protocols.Identify; +using System.Net.Sockets; using Nethermind.Libp2p.Core.Discovery; -using Nethermind.Libp2p.Protocols.Identify.Dto; +using Nethermind.Libp2p.Core.Dto; +using Nethermind.Libp2p.Core.Exceptions; namespace Nethermind.Libp2p.Protocols; /// /// https://github.com/libp2p/specs/tree/master/identify /// -public class IdentifyProtocol : IProtocol +public class IdentifyProtocol : ISessionProtocol { - private readonly string _agentVersion; - private readonly string _protocolVersion; - private readonly ILogger? _logger; - private readonly IPeerFactoryBuilder _peerFactoryBuilder; private readonly PeerStore? _peerStore; - - private static readonly byte[] Libp2pPeerRecordAsArray = [((ushort)Core.Enums.Libp2p.Libp2pPeerRecord >> 8) & 0xFF, (ushort)Core.Enums.Libp2p.Libp2pPeerRecord & 0xFF]; + private readonly IProtocolStackSettings _protocolStackSettings; + private readonly IdentifyProtocolSettings _settings; public string Id => "/ipfs/id/1.0.0"; - public IdentifyProtocol(IPeerFactoryBuilder peerFactoryBuilder, IdentifyProtocolSettings? settings = null, PeerStore? peerStore = null, ILoggerFactory? loggerFactory = null) + public IdentifyProtocol(IProtocolStackSettings protocolStackSettings, IdentifyProtocolSettings? settings = null, PeerStore? peerStore = null, ILoggerFactory? loggerFactory = null) { _logger = loggerFactory?.CreateLogger(); - _peerFactoryBuilder = peerFactoryBuilder; _peerStore = peerStore; - _agentVersion = settings?.AgentVersion ?? IdentifyProtocolSettings.Default.AgentVersion!; - _protocolVersion = settings?.ProtocolVersion ?? IdentifyProtocolSettings.Default.ProtocolVersion!; + _protocolStackSettings = protocolStackSettings; + _settings = settings ?? new IdentifyProtocolSettings(); } - public async Task DialAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context) + public async Task DialAsync(IChannel channel, ISessionContext context) { + ArgumentNullException.ThrowIfNull(context.State.RemotePublicKey); + ArgumentNullException.ThrowIfNull(context.State.RemotePeerId); + _logger?.LogInformation("Dial"); Identify.Dto.Identify identify = await channel.ReadPrefixedProtobufAsync(Identify.Dto.Identify.Parser); _logger?.LogInformation("Received peer info: {identify}", identify); - context.RemotePeer.Identity = new Identity(PublicKey.Parser.ParseFrom(identify.PublicKey)); - if (_peerStore is not null && identify.SignedPeerRecord is not null) + if (_peerStore is not null) { - if (!VerifyPeerRecord(identify.SignedPeerRecord, context.RemotePeer.Identity)) + _peerStore.GetPeerInfo(context.State.RemotePeerId).SupportedProtocols = identify.Protocols.ToArray(); + + if (identify.SignedPeerRecord is not null) + { + if (!SigningHelper.VerifyPeerRecord(identify.SignedPeerRecord, context.State.RemotePublicKey)) + { + if (_settings?.PeerRecordsVerificationPolicy == PeerRecordsVerificationPolicy.RequireCorrect) + { + throw new PeerConnectionException("Malformed peer identity: peer record signature is not valid"); + } + else + { + _logger?.LogWarning("Malformed peer identity: peer record signature is not valid"); + } + } + else + { + _peerStore.GetPeerInfo(context.State.RemotePeerId).SignedPeerRecord = identify.SignedPeerRecord; + _logger?.LogDebug("Confirmed peer record: {peerId}", context.State.RemotePeerId); + } + } + else if (_settings.PeerRecordsVerificationPolicy != PeerRecordsVerificationPolicy.DoesNotRequire) { - throw new PeerConnectionException(); + throw new PeerConnectionException("Malformed peer identity: there is no peer record which is required"); } - _peerStore.GetPeerInfo(context.RemotePeer.Identity.PeerId).SignedPeerRecord = identify.SignedPeerRecord; } - if (context.RemotePeer.Identity.PublicKey.ToByteString() != identify.PublicKey) + if (context.State.RemotePublicKey.ToByteString() != identify.PublicKey) { - throw new PeerConnectionException(); + throw new PeerConnectionException("Malformed peer identity: the remote public key corresponds to a different peer id"); } } - public async Task ListenAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context) + public async Task ListenAsync(IChannel channel, ISessionContext context) { _logger?.LogInformation("Listen"); Identify.Dto.Identify identify = new() { - ProtocolVersion = _protocolVersion, - AgentVersion = _agentVersion, - PublicKey = context.LocalPeer.Identity.PublicKey.ToByteString(), - ListenAddrs = { ByteString.CopyFrom(ToEndpoint(context.LocalEndpoint).ToBytes()) }, - ObservedAddr = ByteString.CopyFrom(ToEndpoint(context.RemoteEndpoint).ToBytes()), - Protocols = { _peerFactoryBuilder.AppLayerProtocols.Select(p => p.Id) }, - SignedPeerRecord = CreateSignedEnvelope(context.LocalPeer.Identity, [context.LocalPeer.Address], 1), + ProtocolVersion = _settings.ProtocolVersion, + AgentVersion = _settings.AgentVersion, + PublicKey = context.Peer.Identity.PublicKey.ToByteString(), + ListenAddrs = { context.Peer.ListenAddresses.Select(x => ByteString.CopyFrom(x.ToBytes())) }, + ObservedAddr = ByteString.CopyFrom(context.State.RemoteAddress!.ToEndPoint(out ProtocolType proto).ToMultiaddress(proto).ToBytes()), + Protocols = { _protocolStackSettings.Protocols!.Select(r => r.Key.Protocol).OfType().Select(p => p.Id) }, + SignedPeerRecord = SigningHelper.CreateSignedEnvelope(context.Peer.Identity, [.. context.Peer.ListenAddresses], 1), }; - byte[] ar = new byte[identify.CalculateSize()]; - identify.WriteTo(ar); + ByteString[] endpoints = context.Peer.ListenAddresses.Where(a => !a.ToEndPoint().Address.IsPrivate()).Select(a => a.ToEndPoint(out ProtocolType proto).ToMultiaddress(proto)).Select(a => ByteString.CopyFrom(a.ToBytes())).ToArray(); + identify.ListenAddrs.AddRange(endpoints); - await channel.WriteSizeAndDataAsync(ar); + await channel.WriteSizeAndProtobufAsync(identify); _logger?.LogDebug("Sent peer info {identify}", identify); } - - private static bool VerifyPeerRecord(ByteString signedPeerRecordBytes, Identity identity) - { - SignedEnvelope envelope = SignedEnvelope.Parser.ParseFrom(signedPeerRecordBytes); - - if (envelope.PayloadType?.Take(2).SequenceEqual(Libp2pPeerRecordAsArray) is not true) - { - return false; - } - - PeerRecord pr = PeerRecord.Parser.ParseFrom(envelope.Payload); - - if (identity.PeerId != new PeerId(pr.PeerId.ToByteArray())) - { - return false; - } - - SignedEnvelope envelopeWithoutSignature = envelope.Clone(); - envelopeWithoutSignature.ClearSignature(); - - return identity.VerifySignature(envelopeWithoutSignature.ToByteArray(), envelope.Signature.ToByteArray()); - } - - private static ByteString CreateSignedEnvelope(Identity identity, Multiaddress[] addresses, ulong seq) - { - PeerRecord paylaod = new() - { - PeerId = ByteString.CopyFrom(identity.PeerId.Bytes), - Seq = seq - }; - - foreach (var address in addresses) - { - paylaod.Addresses.Add(new AddressInfo - { - Multiaddr = ByteString.CopyFrom(address.ToBytes()) - }); - } - - SignedEnvelope envelope = new() - { - PayloadType = ByteString.CopyFrom(Libp2pPeerRecordAsArray), - Payload = paylaod.ToByteString(), - PublicKey = identity.PublicKey.ToByteString(), - }; - - envelope.Signature = ByteString.CopyFrom(identity.Sign(envelope.ToByteArray())); - return envelope.ToByteString(); - } - - private static Multiaddress ToEndpoint(Multiaddress addr) => new() - { - Protocols = - { - addr.Has() ? addr.Get() : addr.Get(), - addr.Has() ? addr.Get() : addr.Get() - } - }; } diff --git a/src/libp2p/Libp2p.Protocols.Identify/IdentifyProtocolSettings.cs b/src/libp2p/Libp2p.Protocols.Identify/IdentifyProtocolSettings.cs index ab150f50..3078962e 100644 --- a/src/libp2p/Libp2p.Protocols.Identify/IdentifyProtocolSettings.cs +++ b/src/libp2p/Libp2p.Protocols.Identify/IdentifyProtocolSettings.cs @@ -5,12 +5,15 @@ namespace Nethermind.Libp2p.Protocols; public class IdentifyProtocolSettings { - public string? AgentVersion { get; set; } - public string? ProtocolVersion { get; set; } + public string AgentVersion { get; set; } = "ipfs/1.0.0"; + public string ProtocolVersion { get; set; } = "dotnet-libp2p/1.0.0"; + public PeerRecordsVerificationPolicy PeerRecordsVerificationPolicy { get; set; } = PeerRecordsVerificationPolicy.RequireWithWarning; +} + - public static IdentifyProtocolSettings Default { get; } = new() - { - ProtocolVersion = "ipfs/1.0.0", - AgentVersion = "dotnet-libp2p/1.0.0", - }; +public enum PeerRecordsVerificationPolicy +{ + RequireCorrect, + RequireWithWarning, + DoesNotRequire } diff --git a/src/libp2p/Libp2p.Protocols.Identify/Libp2p.Protocols.Identify.csproj b/src/libp2p/Libp2p.Protocols.Identify/Libp2p.Protocols.Identify.csproj index 5e2c68ac..4f264a38 100644 --- a/src/libp2p/Libp2p.Protocols.Identify/Libp2p.Protocols.Identify.csproj +++ b/src/libp2p/Libp2p.Protocols.Identify/Libp2p.Protocols.Identify.csproj @@ -13,13 +13,6 @@ - - - - - - Never - diff --git a/src/libp2p/Libp2p.Protocols.IpTcp/IpTcpProtocol.cs b/src/libp2p/Libp2p.Protocols.IpTcp/IpTcpProtocol.cs index bb6c5f71..f2471f7c 100644 --- a/src/libp2p/Libp2p.Protocols.IpTcp/IpTcpProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.IpTcp/IpTcpProtocol.cs @@ -8,242 +8,190 @@ using Microsoft.Extensions.Logging; using Multiformats.Address; using Multiformats.Address.Protocols; -using System.Threading.Channels; +using Multiformats.Address.Net; +using Nethermind.Libp2p.Core.Exceptions; +using Nethermind.Libp2p.Core.Utils; namespace Nethermind.Libp2p.Protocols; -public class IpTcpProtocol(ILoggerFactory? loggerFactory = null) : IProtocol +public class IpTcpProtocol(ILoggerFactory? loggerFactory = null) : ITransportProtocol { private readonly ILogger? _logger = loggerFactory?.CreateLogger(); public string Id => "ip-tcp"; + public static Multiaddress[] GetDefaultAddresses(PeerId peerId) => IpHelper.GetListenerAddresses() + .Select(a => Multiaddress.Decode($"/{(a.AddressFamily is AddressFamily.InterNetwork ? "ip4" : "ip6")}/{a}/tcp/0/p2p/{peerId}")).ToArray(); + public static bool IsAddressMatch(Multiaddress addr) => addr.Has(); - public async Task ListenAsync(IChannel signalingChannel, IChannelFactory? channelFactory, IPeerContext context) + public async Task ListenAsync(ITransportContext context, Multiaddress listenAddr, CancellationToken token) { - try - { - if (channelFactory is null) - { - throw new Exception("Protocol is not properly instantiated"); - } + Socket listener = new(SocketType.Stream, ProtocolType.Tcp); - Multiaddress addr = context.LocalPeer.Address; - bool isIP4 = addr.Has(); - MultiaddressProtocol ipProtocol = isIP4 ? addr.Get() : addr.Get(); - IPAddress ipAddress = IPAddress.Parse(ipProtocol.ToString()); - int tcpPort = int.Parse(addr.Get().ToString()); + IPEndPoint endpoint = listenAddr.ToEndPoint(); - Socket srv = new(SocketType.Stream, ProtocolType.Tcp); - srv.Bind(new IPEndPoint(ipAddress, tcpPort)); - srv.Listen(tcpPort); - signalingChannel.GetAwaiter().OnCompleted(() => - { - srv.Close(); - }); + listener.Bind(endpoint); + listener.Listen(); - IPEndPoint localIpEndpoint = (IPEndPoint)srv.LocalEndPoint!; + if (endpoint.Port is 0) + { + IPEndPoint localIpEndpoint = (IPEndPoint)listener.LocalEndPoint!; + listenAddr.ReplaceOrAdd(localIpEndpoint.Port); + } - Multiaddress localMultiaddress = new(); - localMultiaddress = isIP4 ? localMultiaddress.Add(localIpEndpoint.Address.MapToIPv4()) : localMultiaddress.Add(localIpEndpoint.Address.MapToIPv6()); - localMultiaddress = localMultiaddress.Add(localIpEndpoint.Port); - context.LocalEndpoint = localMultiaddress; + token.Register(listener.Close); - if (tcpPort == 0) - { - context.LocalPeer.Address = context.LocalPeer.Address - .ReplaceOrAdd(localIpEndpoint.Port); - } - _logger?.LogDebug("Ready to handle connections"); - context.ListenerReady(); + _logger?.LogDebug("Ready to handle connections"); + context.ListenerReady(listenAddr); - await Task.Run(async () => + await Task.Run(async () => + { + for (; ; ) { - for (; ; ) - { - Socket client = await srv.AcceptAsync(); - IPeerContext clientContext = context.Fork(); - IPEndPoint remoteIpEndpoint = (IPEndPoint)client.RemoteEndPoint!; - - Multiaddress remoteMultiaddress = new(); - remoteMultiaddress = isIP4 ? remoteMultiaddress.Add(remoteIpEndpoint.Address.MapToIPv4()) : remoteMultiaddress.Add(remoteIpEndpoint.Address.MapToIPv6()); - remoteMultiaddress = remoteMultiaddress.Add(remoteIpEndpoint.Port); + Socket client = await listener.AcceptAsync(); - clientContext.RemoteEndpoint = clientContext.RemotePeer.Address = remoteMultiaddress; + INewConnectionContext connectionCtx = context.CreateConnection(); + connectionCtx.Token.Register(client.Close); + connectionCtx.State.RemoteAddress = client.RemoteEndPoint.ToMultiaddress(ProtocolType.Tcp); - IChannel upChannel = channelFactory.SubListen(clientContext); + IChannel upChannel = connectionCtx.Upgrade(); - _ = Task.Run(async () => + Task readTask = Task.Run(async () => + { + try { - try + for (; client.Connected;) { - for (; client.Connected;) + if (client.Available == 0) { - if (client.Available == 0) - { - await Task.Yield(); - } - - byte[] buf = new byte[client.ReceiveBufferSize]; - int length = await client.ReceiveAsync(buf, SocketFlags.None); - if (length != 0) - { - if ((await upChannel.WriteAsync(new ReadOnlySequence(buf.AsMemory()[..length]))) != IOResult.Ok) - { - break; - } - } - else - { - break; - } + await Task.Yield(); + } + + byte[] buf = new byte[client.ReceiveBufferSize]; + int length = await client.ReceiveAsync(buf, SocketFlags.None); + + if (length is 0 || await upChannel.WriteAsync(new ReadOnlySequence(buf.AsMemory()[..length])) != IOResult.Ok) + { + break; } } - catch (SocketException e) - { - await upChannel.CloseAsync(); - } - }); - _ = Task.Run(async () => + } + catch (SocketException e) + { + await upChannel.CloseAsync(); + } + }); + + Task writeTask = Task.Run(async () => + { + try { - try + await foreach (ReadOnlySequence data in upChannel.ReadAllAsync()) { - await foreach (ReadOnlySequence data in upChannel.ReadAllAsync()) + int sent = await client.SendAsync(data.ToArray(), SocketFlags.None); + if (sent is 0 || !client.Connected) { - int sent = await client.SendAsync(data.ToArray(), SocketFlags.None); - if (sent is 0 || !client.Connected) - { - break; - } + await upChannel.CloseAsync(); + break; } } - catch (SocketException) - { - _logger?.LogInformation($"Disconnected({context.Id}) due to a socket exception"); - await upChannel.CloseAsync(); - } - }); - } - }); - } - catch (Exception ex) - { - _logger?.LogError(ex, $"Listener error"); - throw; - } - } - - public async Task DialAsync(IChannel signalingChannel, IChannelFactory? channelFactory, IPeerContext context) - { - try - { - if (channelFactory is null) - { - throw new ProtocolViolationException(); - } - - Socket client = new(SocketType.Stream, ProtocolType.Tcp); - Multiaddress addr = context.RemotePeer.Address; - MultiaddressProtocol ipProtocol = addr.Has() ? addr.Get() : addr.Get(); - IPAddress ipAddress = IPAddress.Parse(ipProtocol.ToString()); - int tcpPort = addr.Get().Port; - - _logger?.LogDebug("Dialing {0}:{1}", ipAddress, tcpPort); + } + catch (SocketException e) + { + _logger?.LogInformation($"Disconnected due to a socket exception"); + await upChannel.CloseAsync(); + } + }); - try - { - await client.ConnectAsync(new IPEndPoint(ipAddress, tcpPort), signalingChannel.CancellationToken); - } - catch (SocketException e) - { - _logger?.LogDebug($"Failed({context.Id}) to connect {addr}"); - _logger?.LogTrace($"Failed with {e.GetType()}: {e.Message}"); - _ = signalingChannel.CloseAsync(); - return; + _ = Task.WhenAll(readTask, writeTask).ContinueWith(_ => connectionCtx.Dispose()); } + }); + } - signalingChannel.GetAwaiter().OnCompleted(() => - { - client.Close(); - }); - - IPEndPoint localEndpoint = (IPEndPoint)client.LocalEndPoint!; - IPEndPoint remoteEndpoint = (IPEndPoint)client.RemoteEndPoint!; + public async Task DialAsync(ITransportContext context, Multiaddress remoteAddr, CancellationToken token) + { + Socket client = new(SocketType.Stream, ProtocolType.Tcp); - var isIP4 = addr.Has(); + IPEndPoint remoteEndpoint = remoteAddr.ToEndPoint(); + _logger?.LogDebug("Dialing {0}:{1}", remoteEndpoint.Address, remoteEndpoint.Port); - var remoteMultiaddress = new Multiaddress(); - var remoteIpAddress = isIP4 ? remoteEndpoint.Address.MapToIPv4() : remoteEndpoint.Address.MapToIPv6(); - remoteMultiaddress = isIP4 ? remoteMultiaddress.Add(remoteIpAddress) : remoteMultiaddress.Add(remoteIpAddress); - context.RemoteEndpoint = remoteMultiaddress.Add(remoteEndpoint.Port); + try + { + await client.ConnectAsync(remoteEndpoint, token); + } + catch (SocketException e) + { + _logger?.LogDebug($"Failed to connect {remoteAddr}"); + _logger?.LogTrace($"Failed with {e.GetType()}: {e.Message}"); + throw; + } + if (client.LocalEndPoint is null) + { + throw new Libp2pException($"{nameof(client.LocalEndPoint)} is not set for client connection."); + } + if (client.RemoteEndPoint is null) + { + throw new Libp2pException($"{nameof(client.RemoteEndPoint)} is not set for client connection."); + } - var localMultiaddress = new Multiaddress(); - var localIpAddress = isIP4 ? localEndpoint.Address.MapToIPv4() : localEndpoint.Address.MapToIPv6(); - localMultiaddress = isIP4 ? localMultiaddress.Add(localIpAddress) : localMultiaddress.Add(localIpAddress); - context.LocalEndpoint = localMultiaddress.Add(localEndpoint.Port); + INewConnectionContext connectionCtx = context.CreateConnection(); + connectionCtx.State.RemoteAddress = client.RemoteEndPoint.ToMultiaddress(ProtocolType.Tcp); + connectionCtx.State.LocalAddress = client.LocalEndPoint.ToMultiaddress(ProtocolType.Tcp); - context.LocalPeer.Address = context.LocalEndpoint.Add(context.LocalPeer.Identity.PeerId.ToString()); + connectionCtx.Token.Register(client.Close); + token.Register(client.Close); - IChannel upChannel = channelFactory.SubDial(context); + IChannel upChannel = connectionCtx.Upgrade(); - Task receiveTask = Task.Run(async () => + Task receiveTask = Task.Run(async () => + { + try { - byte[] buf = new byte[client.ReceiveBufferSize]; - try + for (; client.Connected;) { - for (; client.Connected;) + byte[] buf = new byte[client.ReceiveBufferSize]; + int dataLength = await client.ReceiveAsync(buf, SocketFlags.None); + _logger?.LogDebug("Ctx{0}: receive, length={1}", connectionCtx.Id, dataLength); + + if (dataLength == 0 || (await upChannel.WriteAsync(new ReadOnlySequence(buf[..dataLength]))) != IOResult.Ok) { - int dataLength = await client.ReceiveAsync(buf, SocketFlags.None); - if (dataLength != 0) - { - _logger?.LogDebug("Receive {0} data, len={1}", context.Id, dataLength); - if ((await upChannel.WriteAsync(new ReadOnlySequence(buf[..dataLength]))) != IOResult.Ok) - { - break; - } - } - else - { - break; - } + break; } - } - catch (SocketException) - { - _ = upChannel.CloseAsync(); - } - }); + } + catch (SocketException) + { + _ = upChannel.CloseAsync(); + } + }); - Task sendTask = Task.Run(async () => + Task sendTask = Task.Run(async () => + { + try { - try + await foreach (ReadOnlySequence data in upChannel.ReadAllAsync()) { - await foreach (ReadOnlySequence data in upChannel.ReadAllAsync()) + _logger?.LogDebug("Ctx{0}: send, length={2}", connectionCtx.Id, data.Length); + int sent = await client.SendAsync(data.ToArray(), SocketFlags.None); + if (sent is 0 || !client.Connected) { - _logger?.LogDebug("Send {0} data, len={1}", context.Id, data.Length); - int sent = await client.SendAsync(data.ToArray(), SocketFlags.None); - if (sent is 0 || !client.Connected) - { - break; - } + break; } } - catch (SocketException) - { - _ = upChannel.CloseAsync(); - } - }); + } + catch (SocketException) + { + _ = upChannel.CloseAsync(); + return; + } - await Task.WhenAll(receiveTask, sendTask); - _ = upChannel.CloseAsync(); - } - catch (Exception ex) - { - _logger?.LogError(ex, $"Listener error"); - throw; - } + client.Close(); + }); + + await Task.WhenAll(receiveTask, sendTask).ContinueWith(t => connectionCtx.Dispose()); + + _ = upChannel.CloseAsync(); } } diff --git a/src/libp2p/Libp2p.Protocols.MDns/MDnsDiscoveryProtocol.cs b/src/libp2p/Libp2p.Protocols.MDns/MDnsDiscoveryProtocol.cs index 858cfa48..a4a891da 100644 --- a/src/libp2p/Libp2p.Protocols.MDns/MDnsDiscoveryProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.MDns/MDnsDiscoveryProtocol.cs @@ -9,6 +9,8 @@ using Makaretu.Dns; using Multiformats.Address; using Multiformats.Address.Protocols; +using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Core.Exceptions; namespace Nethermind.Libp2p.Protocols; @@ -22,52 +24,60 @@ public class MDnsDiscoveryProtocol(PeerStore peerStore, ILoggerFactory? loggerFa private string PeerName = null!; - public async Task DiscoverAsync(Multiaddress localPeerAddr, CancellationToken token = default) + public Task StartDiscoveryAsync(IReadOnlyList localPeerAddrs, CancellationToken token = default) { ObservableCollection peers = []; ServiceDiscovery sd = new(); + string? localPeerId = localPeerAddrs.First().GetPeerId()?.ToString(); + + if (localPeerId is null) + { + throw new Libp2pException(); + } try { PeerName = RandomString(32); ServiceProfile service = new(PeerName, ServiceNameOverride ?? ServiceName, 0); - if (localPeerAddr.Get().ToString() == "0.0.0.0") + foreach (Multiaddress localPeerAddr in localPeerAddrs) { - service.Resources.Add(new TXTRecord() + if (localPeerAddr.Get().ToString() == "0.0.0.0") { - Name = service.FullyQualifiedName, - Strings = new List(MulticastService.GetLinkLocalAddresses() - .Where(x => x.AddressFamily == AddressFamily.InterNetwork) - .Select(item => $"dnsaddr={localPeerAddr.ReplaceOrAdd(item.ToString())}")) - }); - } - else - { - service.Resources.Add(new TXTRecord() + service.Resources.Add(new TXTRecord() + { + Name = service.FullyQualifiedName, + Strings = new List(MulticastService.GetLinkLocalAddresses() + .Where(x => x.AddressFamily == AddressFamily.InterNetwork) + .Select(item => $"dnsaddr={localPeerAddr.ReplaceOrAdd(item.ToString())}")) + }); + } + else { - Name = service.FullyQualifiedName, - Strings = [$"dnsaddr={localPeerAddr}"] - }); + service.Resources.Add(new TXTRecord() + { + Name = service.FullyQualifiedName, + Strings = [$"dnsaddr={localPeerAddr}"] + }); + } } _logger?.LogInformation("Started as {0} {1}", PeerName, ServiceNameOverride ?? ServiceName); - - sd.ServiceDiscovered += (s, serviceName) => { _logger?.LogTrace("Srv disc {0}", serviceName); }; + sd.ServiceInstanceDiscovered += (s, e) => { Multiaddress[] records = e.Message.AdditionalRecords.OfType() .Select(x => x.Strings.Where(x => x.StartsWith("dnsaddr"))) .SelectMany(x => x).Select(x => Multiaddress.Decode(x.Replace("dnsaddr=", ""))).ToArray(); _logger?.LogTrace("Inst disc {0}, nmsg: {1}", e.ServiceInstanceName, e.Message); - if (Enumerable.Any(records) && !peers.Contains(Enumerable.First(records)) && localPeerAddr.Get().ToString() != Enumerable.First(records).Get().ToString()) + if (records.Length != 0 && !peers.Contains(records[0]) && localPeerId != records[0].Get().ToString()) { - List peerAddresses = new(); + List peerAddresses = []; foreach (Multiaddress peer in records) { peers.Add(peer); @@ -83,6 +93,12 @@ public async Task DiscoverAsync(Multiaddress localPeerAddr, CancellationToken to _logger?.LogError(ex, "Error setting up mDNS"); } + _ = RunAsync(sd, token); + return Task.CompletedTask; + } + + private async Task RunAsync(ServiceDiscovery sd, CancellationToken token) + { while (!token.IsCancellationRequested) { try @@ -96,7 +112,6 @@ public async Task DiscoverAsync(Multiaddress localPeerAddr, CancellationToken to } await Task.Delay(MdnsQueryInterval, token); } - } private static string RandomString(int length) diff --git a/src/libp2p/Libp2p.Protocols.Multistream.Tests/MultistreamProtocolTests.cs b/src/libp2p/Libp2p.Protocols.Multistream.Tests/MultistreamProtocolTests.cs index 8b9ce61c..1916bb26 100644 --- a/src/libp2p/Libp2p.Protocols.Multistream.Tests/MultistreamProtocolTests.cs +++ b/src/libp2p/Libp2p.Protocols.Multistream.Tests/MultistreamProtocolTests.cs @@ -1,6 +1,10 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT +using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Core.TestsBase; +using NSubstitute; + namespace Nethermind.Libp2p.Protocols.Multistream.Tests; [TestFixture] @@ -12,19 +16,16 @@ public async Task Test_ConnectionEstablished_AfterHandshake() { IChannel downChannel = new TestChannel(); IChannel downChannelFromProtocolPov = ((TestChannel)downChannel).Reverse(); - IChannelFactory channelFactory = Substitute.For(); - IPeerContext peerContext = Substitute.For(); - peerContext.SpecificProtocolRequest.Returns((IChannelRequest?)null); + IConnectionContext peerContext = Substitute.For(); + peerContext.UpgradeOptions.Returns(new UpgradeOptions()); IProtocol? proto1 = Substitute.For(); proto1.Id.Returns("proto1"); - channelFactory.SubProtocols.Returns(new[] { proto1 }); - IChannel upChannel = new TestChannel(); - channelFactory.SubDialAndBind(Arg.Any(), Arg.Any(), Arg.Any()) - .Returns(Task.CompletedTask); + peerContext.SubProtocols.Returns([proto1]); + peerContext.Upgrade(Arg.Any(), Arg.Any()).Returns(Task.CompletedTask); MultistreamProtocol proto = new(); - Task dialTask = proto.DialAsync(downChannelFromProtocolPov, channelFactory, peerContext); + Task dialTask = proto.DialAsync(downChannelFromProtocolPov, peerContext); _ = Task.Run(async () => { await downChannel.WriteLineAsync(proto.Id); @@ -36,7 +37,7 @@ public async Task Test_ConnectionEstablished_AfterHandshake() await dialTask; - _ = channelFactory.Received().SubDialAndBind(downChannelFromProtocolPov, peerContext, proto1); + _ = peerContext.Received().Upgrade(downChannelFromProtocolPov, proto1); await downChannel.CloseAsync(); } @@ -45,21 +46,16 @@ public async Task Test_ConnectionEstablished_AfterHandshake_With_SpecificRequest { IChannel downChannel = new TestChannel(); IChannel downChannelFromProtocolPov = ((TestChannel)downChannel).Reverse(); - IChannelFactory channelFactory = Substitute.For(); - IPeerContext peerContext = Substitute.For(); - IChannelRequest channelRequest = Substitute.For(); - peerContext.SpecificProtocolRequest.Returns(channelRequest); + IConnectionContext peerContext = Substitute.For(); IProtocol? proto1 = Substitute.For(); proto1.Id.Returns("proto1"); - channelRequest.SubProtocol.Returns(proto1); - IChannel upChannel = new TestChannel(); + peerContext.UpgradeOptions.Returns(new UpgradeOptions { SelectedProtocol = proto1 }); - channelFactory.SubDialAndBind(Arg.Any(), Arg.Any(), Arg.Any()) - .Returns(Task.CompletedTask); + peerContext.Upgrade(Arg.Any(), Arg.Any()).Returns(Task.CompletedTask); MultistreamProtocol proto = new(); - Task dialTask = proto.DialAsync(downChannelFromProtocolPov, channelFactory, peerContext); + Task dialTask = proto.DialAsync(downChannelFromProtocolPov, peerContext); _ = Task.Run(async () => { await downChannel.WriteLineAsync(proto.Id); @@ -71,7 +67,7 @@ public async Task Test_ConnectionEstablished_AfterHandshake_With_SpecificRequest await dialTask; - _ = channelFactory.Received().SubDialAndBind(downChannelFromProtocolPov, peerContext, proto1); + _ = peerContext.Received().Upgrade(downChannelFromProtocolPov, proto1); await downChannel.CloseAsync(); } @@ -80,13 +76,12 @@ public async Task Test_ConnectionClosed_ForUnknownProtocol() { IChannel downChannel = new TestChannel(); IChannel downChannelFromProtocolPov = ((TestChannel)downChannel).Reverse(); - IChannelFactory channelFactory = Substitute.For(); - IPeerContext peerContext = Substitute.For(); - peerContext.SpecificProtocolRequest.Returns((IChannelRequest?)null); + IConnectionContext peerContext = Substitute.For(); + peerContext.UpgradeOptions.Returns(new UpgradeOptions() { SelectedProtocol = null }); IProtocol? proto1 = Substitute.For(); proto1.Id.Returns("proto1"); - channelFactory.SubProtocols.Returns(new[] { proto1 }); + peerContext.SubProtocols.Returns([proto1]); MultistreamProtocol proto = new(); _ = Task.Run(async () => @@ -95,14 +90,14 @@ public async Task Test_ConnectionClosed_ForUnknownProtocol() await downChannel.WriteLineAsync("proto2"); }); - Task dialTask = proto.DialAsync(downChannelFromProtocolPov, channelFactory, peerContext); + Task dialTask = proto.DialAsync(downChannelFromProtocolPov, peerContext); Assert.That(await downChannel.ReadLineAsync(), Is.EqualTo(proto.Id)); Assert.That(await downChannel.ReadLineAsync(), Is.EqualTo("proto1")); await dialTask; - _ = channelFactory.DidNotReceive().SubDialAndBind(downChannelFromProtocolPov, peerContext, proto1); + _ = peerContext.DidNotReceive().Upgrade(downChannelFromProtocolPov, proto1); } [Test] @@ -110,21 +105,20 @@ public async Task Test_ConnectionEstablished_ForAnyOfProtocols() { IChannel downChannel = new TestChannel(); IChannel downChannelFromProtocolPov = ((TestChannel)downChannel).Reverse(); - IChannelFactory channelFactory = Substitute.For(); - IPeerContext peerContext = Substitute.For(); - peerContext.SpecificProtocolRequest.Returns((IChannelRequest?)null); + IConnectionContext peerContext = Substitute.For(); + peerContext.UpgradeOptions.Returns(new UpgradeOptions()); IProtocol? proto1 = Substitute.For(); proto1.Id.Returns("proto1"); IProtocol? proto2 = Substitute.For(); proto2.Id.Returns("proto2"); - channelFactory.SubProtocols.Returns(new[] { proto1, proto2 }); + peerContext.SubProtocols.Returns([proto1, proto2]); IChannel upChannel = new TestChannel(); - channelFactory.SubDialAndBind(Arg.Any(), Arg.Any(), Arg.Any()) + peerContext.Upgrade(Arg.Any(), Arg.Any()) .Returns(Task.CompletedTask); MultistreamProtocol proto = new(); - Task dialTask = proto.DialAsync(downChannelFromProtocolPov, channelFactory, peerContext); + Task dialTask = proto.DialAsync(downChannelFromProtocolPov, peerContext); _ = Task.Run(async () => { await downChannel.WriteLineAsync(proto.Id); @@ -138,7 +132,7 @@ public async Task Test_ConnectionEstablished_ForAnyOfProtocols() await dialTask; - _ = channelFactory.Received().SubDialAndBind(downChannelFromProtocolPov, peerContext, proto2); + _ = peerContext.Received().Upgrade(downChannelFromProtocolPov, proto2); await upChannel.CloseAsync(); } @@ -147,18 +141,17 @@ public async Task Test_ConnectionClosed_ForBadProtocol() { IChannel downChannel = new TestChannel(); IChannel downChannelFromProtocolPov = ((TestChannel)downChannel).Reverse(); - IChannelFactory channelFactory = Substitute.For(); - IPeerContext peerContext = Substitute.For(); - peerContext.SpecificProtocolRequest.Returns((IChannelRequest?)null); + IConnectionContext peerContext = Substitute.For(); + peerContext.UpgradeOptions.Returns(new UpgradeOptions()); IProtocol? proto1 = Substitute.For(); proto1.Id.Returns("proto1"); IProtocol? proto2 = Substitute.For(); proto1.Id.Returns("proto2"); - channelFactory.SubProtocols.Returns(new[] { proto1, proto2 }); + peerContext.SubProtocols.Returns([proto1, proto2]); MultistreamProtocol proto = new(); - Task dialTask = proto.DialAsync(downChannelFromProtocolPov, channelFactory, peerContext); + Task dialTask = proto.DialAsync(downChannelFromProtocolPov, peerContext); _ = Task.Run(async () => { await downChannel.WriteLineAsync(proto.Id); @@ -171,6 +164,6 @@ public async Task Test_ConnectionClosed_ForBadProtocol() await dialTask; - _ = channelFactory.DidNotReceiveWithAnyArgs().SubDialAndBind(null!, null!, (IProtocol)null!); + _ = peerContext.DidNotReceiveWithAnyArgs().Upgrade(Arg.Any(), Arg.Any()); } } diff --git a/src/libp2p/Libp2p.Protocols.Multistream.Tests/Usings.cs b/src/libp2p/Libp2p.Protocols.Multistream.Tests/Usings.cs index 852eb690..e160e2a1 100644 --- a/src/libp2p/Libp2p.Protocols.Multistream.Tests/Usings.cs +++ b/src/libp2p/Libp2p.Protocols.Multistream.Tests/Usings.cs @@ -1,7 +1,4 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT -global using Nethermind.Libp2p.Core; -global using Nethermind.Libp2p.Core.TestsBase; -global using NSubstitute; global using NUnit.Framework; diff --git a/src/libp2p/Libp2p.Protocols.Multistream/MultistreamProtocol.cs b/src/libp2p/Libp2p.Protocols.Multistream/MultistreamProtocol.cs index 94696151..51a34c20 100644 --- a/src/libp2p/Libp2p.Protocols.Multistream/MultistreamProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.Multistream/MultistreamProtocol.cs @@ -9,7 +9,7 @@ namespace Nethermind.Libp2p.Protocols; /// /// https://github.com/multiformats/multistream-select /// -public class MultistreamProtocol : IProtocol +public class MultistreamProtocol : IConnectionProtocol { private readonly ILogger? _logger; private const string ProtocolNotSupported = "na"; @@ -19,14 +19,18 @@ public MultistreamProtocol(ILoggerFactory? loggerFactory = null) { _logger = loggerFactory?.CreateLogger(); } - public async Task DialAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context) + + public async Task DialAsync(IChannel channel, IConnectionContext context) { + _logger?.LogTrace($"Hello started"); + if (!await SendHello(channel)) { await channel.CloseAsync(); + _logger?.LogTrace($"Hello failed"); return; } + _logger?.LogTrace($"Hello passed"); async Task DialProtocol(IProtocol selector) { @@ -48,19 +52,17 @@ public async Task DialAsync(IChannel channel, IChannelFactory? channelFactory, IProtocol? selected = null; - if (context.SpecificProtocolRequest?.SubProtocol is not null) + if (context.UpgradeOptions?.SelectedProtocol is not null) { - selected = context.SpecificProtocolRequest.SubProtocol; - - context.SpecificProtocolRequest = null; - if (await DialProtocol(selected) != true) + _logger?.LogDebug($"Proposing just {context.UpgradeOptions.SelectedProtocol}"); + if (await DialProtocol(context.UpgradeOptions.SelectedProtocol) == true) { - return; + selected = context.UpgradeOptions.SelectedProtocol; } } else { - foreach (IProtocol selector in channelFactory!.SubProtocols) + foreach (IProtocol selector in context!.SubProtocols) { bool? dialResult = await DialProtocol(selector); if (dialResult == true) @@ -80,12 +82,11 @@ public async Task DialAsync(IChannel channel, IChannelFactory? channelFactory, _logger?.LogDebug($"Negotiation failed"); return; } - _logger?.LogDebug($"Protocol selected during dialing: {selected}"); - await channelFactory.SubDialAndBind(channel, context, selected); + _logger?.LogDebug($"Protocol selected during dialing: {selected.Id}"); + await context.Upgrade(channel, selected); } - public async Task ListenAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context) + public async Task ListenAsync(IChannel channel, IConnectionContext context) { if (!await SendHello(channel)) { @@ -97,7 +98,7 @@ public async Task ListenAsync(IChannel channel, IChannelFactory? channelFactory, for (; ; ) { string proto = await channel.ReadLineAsync(); - selected = channelFactory!.SubProtocols.FirstOrDefault(x => x.Id == proto); + selected = context.SubProtocols.FirstOrDefault(x => x.Id == proto) as IProtocol; if (selected is not null) { await channel.WriteLineAsync(selected.Id); @@ -116,7 +117,7 @@ public async Task ListenAsync(IChannel channel, IChannelFactory? channelFactory, } _logger?.LogDebug($"Protocol selected during listening: {selected}"); - await channelFactory.SubListenAndBind(channel, context, selected); + await context.Upgrade(channel, selected); } private async Task SendHello(IChannel channel) diff --git a/src/libp2p/Libp2p.Protocols.Noise.Tests/Libp2p.Protocols.Noise.Tests.csproj b/src/libp2p/Libp2p.Protocols.Noise.Tests/Libp2p.Protocols.Noise.Tests.csproj index 860c1a07..e4deaead 100644 --- a/src/libp2p/Libp2p.Protocols.Noise.Tests/Libp2p.Protocols.Noise.Tests.csproj +++ b/src/libp2p/Libp2p.Protocols.Noise.Tests/Libp2p.Protocols.Noise.Tests.csproj @@ -9,7 +9,7 @@ - + diff --git a/src/libp2p/Libp2p.Protocols.Noise.Tests/NoiseProtocolTests.cs b/src/libp2p/Libp2p.Protocols.Noise.Tests/NoiseProtocolTests.cs index fa53f1f0..498334b6 100644 --- a/src/libp2p/Libp2p.Protocols.Noise.Tests/NoiseProtocolTests.cs +++ b/src/libp2p/Libp2p.Protocols.Noise.Tests/NoiseProtocolTests.cs @@ -1,6 +1,11 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT +using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Core.TestsBase; +using NSubstitute; +using NUnit.Framework; + namespace Nethermind.Libp2p.Protocols.Noise.Tests; [TestFixture] @@ -13,9 +18,6 @@ public async Task Test_ConnectionEstablished_AfterHandshake() // Arrange IChannel downChannel = new TestChannel(); IChannel downChannelFromProtocolPov = ((TestChannel)downChannel).Reverse(); - IChannelFactory channelFactory = Substitute.For(); - IPeerContext peerContext = Substitute.For(); - IPeerContext listenerContext = Substitute.For(); IProtocol? proto1 = Substitute.For(); proto1.Id.Returns("proto1"); @@ -23,101 +25,52 @@ public async Task Test_ConnectionEstablished_AfterHandshake() IProtocol? proto2 = Substitute.For(); proto2.Id.Returns("proto2"); - channelFactory.SubProtocols.Returns([proto1, proto2]); + // Dialer + MultiplexerSettings dialerSettings = new(); + dialerSettings.Add(proto2); + dialerSettings.Add(proto1); - TestChannel upChannel = new TestChannel(); - channelFactory.SubDial(Arg.Any(), Arg.Any()) - .Returns(upChannel); + IConnectionContext dialerContext = Substitute.For(); + dialerContext.Peer.Identity.Returns(TestPeers.Identity(1)); + dialerContext.Peer.ListenAddresses.Returns([TestPeers.Multiaddr(1)]); + dialerContext.State.Returns(new State() { RemoteAddress = $"/ip4/0.0.0.0/tcp/0/p2p/{TestPeers.PeerId(2)}" }); - TestChannel listenerUpChannel = new TestChannel(); - channelFactory.SubListen(Arg.Any(), Arg.Any()) - .Returns(listenerUpChannel); + TestChannel dialerUpChannel = new(); + dialerContext.Upgrade(Arg.Any()).Returns(dialerUpChannel); - var i_multiplexerSettings = new MultiplexerSettings(); - var r_multiplexerSettings = new MultiplexerSettings(); - r_multiplexerSettings.Add(proto2); - r_multiplexerSettings.Add(proto1); - i_multiplexerSettings.Add(proto1); + NoiseProtocol dialer = new(dialerSettings); - NoiseProtocol proto_initiator = new(i_multiplexerSettings); - NoiseProtocol proto_responder = new(r_multiplexerSettings); + // Listener + MultiplexerSettings listenerSettings = new(); + listenerSettings.Add(proto1); - peerContext.LocalPeer.Identity.Returns(new Identity()); - listenerContext.LocalPeer.Identity.Returns(new Identity()); + IConnectionContext listenerContext = Substitute.For(); + listenerContext.Peer.Identity.Returns(TestPeers.Identity(2)); + listenerContext.Peer.ListenAddresses.Returns([TestPeers.Multiaddr(2)]); + listenerContext.State.Returns(new State() { RemoteAddress = $"/ip4/0.0.0.0/tcp/0/p2p/{TestPeers.PeerId(1)}" }); - string peerId = peerContext.LocalPeer.Identity.PeerId.ToString(); - Multiaddress localAddr = $"/ip4/0.0.0.0/tcp/0/p2p/{peerId}"; - peerContext.RemotePeer.Address.Returns(localAddr); + TestChannel listenerUpChannel = new(); + listenerContext.Upgrade(Arg.Any()).Returns(listenerUpChannel); - string listenerPeerId = listenerContext.LocalPeer.Identity.PeerId.ToString(); - Multiaddress listenerAddr = $"/ip4/0.0.0.0/tcp/0/p2p/{listenerPeerId}"; - listenerContext.RemotePeer.Address.Returns(listenerAddr); + NoiseProtocol listener = new(listenerSettings); // Act - Task listenTask = proto_responder.ListenAsync(downChannel, channelFactory, listenerContext); - Task dialTask = proto_initiator.DialAsync(downChannelFromProtocolPov, channelFactory, peerContext); + Task listenTask = listener.ListenAsync(downChannel, listenerContext); + Task dialTask = dialer.DialAsync(downChannelFromProtocolPov, dialerContext); int sent = 42; - ValueTask writeTask = upChannel.Reverse().WriteVarintAsync(sent); + ValueTask writeTask = dialerUpChannel.Reverse().WriteVarintAsync(sent); int received = await listenerUpChannel.Reverse().ReadVarintAsync(); await writeTask; - await upChannel.CloseAsync(); + await dialerUpChannel.CloseAsync(); await listenerUpChannel.CloseAsync(); await downChannel.CloseAsync(); - Assert.That(received, Is.EqualTo(sent)); - } - - [Test] - public async Task Test_ConnectionEstablished_With_PreSelectedMuxer() - { - // Arrange - IChannel downChannel = new TestChannel(); - IChannel downChannelFromProtocolPov = ((TestChannel)downChannel).Reverse(); - IChannelFactory channelFactory = Substitute.For(); - IPeerContext peerContext = Substitute.For(); - IPeerContext listenerContext = Substitute.For(); - - IProtocol? proto1 = Substitute.For(); - proto1.Id.Returns("proto1"); + await dialTask; + await listenTask; - IProtocol? proto2 = Substitute.For(); - proto2.Id.Returns("proto2"); - - channelFactory.SubProtocols.Returns(new[] { proto1, proto2 }); - - - var i_multiplexerSettings = new MultiplexerSettings(); - var r_multiplexerSettings = new MultiplexerSettings(); - r_multiplexerSettings.Add(proto2); - r_multiplexerSettings.Add(proto1); - i_multiplexerSettings.Add(proto1); - - NoiseProtocol proto_initiator = new(i_multiplexerSettings); - NoiseProtocol proto_responder = new(r_multiplexerSettings); - - peerContext.LocalPeer.Identity.Returns(new Identity()); - listenerContext.LocalPeer.Identity.Returns(new Identity()); - string peerId = peerContext.LocalPeer.Identity.PeerId.ToString(); - Multiaddress localAddr = $"/ip4/0.0.0.0/tcp/0/p2p/{peerId}"; - peerContext.RemotePeer.Address.Returns(localAddr); - - string listenerPeerId = listenerContext.LocalPeer.Identity.PeerId.ToString(); - Multiaddress listenerAddr = $"/ip4/0.0.0.0/tcp/0/p2p/{listenerPeerId}"; - listenerContext.RemotePeer.Address.Returns(listenerAddr); - - // Act - Task listenTask = proto_responder.ListenAsync(downChannel, channelFactory, listenerContext); - Task dialTask = proto_initiator.DialAsync(downChannelFromProtocolPov, channelFactory, peerContext); - - await Task.Delay(TimeSpan.FromSeconds(2)); - - // Assert - Assert.That(peerContext.SpecificProtocolRequest.SubProtocol, Is.EqualTo(proto1)); - - // Cleanup - await downChannel.CloseAsync(); + Assert.That(received, Is.EqualTo(sent)); } } diff --git a/src/libp2p/Libp2p.Protocols.Noise.Tests/Usings.cs b/src/libp2p/Libp2p.Protocols.Noise.Tests/Usings.cs index 07d11c14..29387723 100644 --- a/src/libp2p/Libp2p.Protocols.Noise.Tests/Usings.cs +++ b/src/libp2p/Libp2p.Protocols.Noise.Tests/Usings.cs @@ -1,9 +1,3 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT -global using Nethermind.Libp2p.Core; -global using Nethermind.Libp2p.Core.TestsBase; -global using NSubstitute; -global using NUnit.Framework; -global using Multiformats.Address; -global using System.Threading.Tasks; diff --git a/src/libp2p/Libp2p.Protocols.Noise/NoiseProtocol.cs b/src/libp2p/Libp2p.Protocols.Noise/NoiseProtocol.cs index 5088b82e..34598b41 100644 --- a/src/libp2p/Libp2p.Protocols.Noise/NoiseProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.Noise/NoiseProtocol.cs @@ -7,7 +7,6 @@ using Nethermind.Libp2p.Core; using Noise; using System.Text; -using Org.BouncyCastle.Math.EC.Rfc8032; using Microsoft.Extensions.Logging; using Multiformats.Address.Protocols; using Nethermind.Libp2p.Protocols.Noise.Dto; @@ -17,7 +16,7 @@ namespace Nethermind.Libp2p.Protocols; /// /// -public class NoiseProtocol(MultiplexerSettings? multiplexerSettings = null, ILoggerFactory? loggerFactory = null) : IProtocol +public class NoiseProtocol(MultiplexerSettings? multiplexerSettings = null, ILoggerFactory? loggerFactory = null) : IConnectionProtocol { private readonly Protocol _protocol = new( HandshakePattern.XX, @@ -26,20 +25,21 @@ public class NoiseProtocol(MultiplexerSettings? multiplexerSettings = null, ILog ); private readonly ILogger? _logger = loggerFactory?.CreateLogger(); - private readonly NoiseExtensions _extensions = new NoiseExtensions() + private NoiseExtensions _extensions => new() { - StreamMuxers = - { - multiplexerSettings is null ? ["na"] : !multiplexerSettings.Multiplexers.Any() ? ["na"] : [.. multiplexerSettings.Multiplexers.Select(proto => proto.Id)] - } + StreamMuxers = { } // TODO: return the following after go question resolution: + //{ + // multiplexerSettings is null || !multiplexerSettings.Multiplexers.Any() ? ["na"] : [.. multiplexerSettings.Multiplexers.Select(proto => proto.Id)] + //} }; public string Id => "/noise"; + private const string PayloadSigPrefix = "noise-libp2p-static-key:"; - public async Task DialAsync(IChannel downChannel, IChannelFactory? upChannelFactory, IPeerContext context) + public async Task DialAsync(IChannel downChannel, IConnectionContext context) { - ArgumentNullException.ThrowIfNull(upChannelFactory); + ArgumentNullException.ThrowIfNull(context.State.RemoteAddress); KeyPair? clientStatic = KeyPair.Generate(); using HandshakeState? handshakeState = _protocol.Create(true, s: clientStatic.PrivateKey); @@ -59,33 +59,37 @@ public async Task DialAsync(IChannel downChannel, IChannelFactory? upChannelFact (int BytesRead, byte[] HandshakeHash, Transport Transport) msg1 = handshakeState.ReadMessage(received.ToArray(), buffer); NoiseHandshakePayload? msg1Decoded = NoiseHandshakePayload.Parser.ParseFrom(buffer.AsSpan(0, msg1.BytesRead)); + PublicKey? msg1KeyDecoded = PublicKey.Parser.ParseFrom(msg1Decoded.IdentityKey); - //var key = new byte[] { 0x1 }.Concat(clientStatic.PublicKey).ToArray(); + context.State.RemotePublicKey = msg1KeyDecoded; + // TODO: verify signature + List responderMuxers = msg1Decoded.Extensions.StreamMuxers .Where(m => !string.IsNullOrEmpty(m)) .ToList(); - IProtocol? commonMuxer = multiplexerSettings?.Multiplexers.FirstOrDefault(m => responderMuxers.Contains(m.Id)); + IProtocol? commonMuxer = null;// multiplexerSettings?.Multiplexers.FirstOrDefault(m => responderMuxers.Contains(m.Id)); + + UpgradeOptions? upgradeOptions = null; + if (commonMuxer is not null) { - context.SpecificProtocolRequest = new ChannelRequest + upgradeOptions = new UpgradeOptions { - SubProtocol = commonMuxer, - CompletionSource = context.SpecificProtocolRequest?.CompletionSource + SelectedProtocol = commonMuxer, }; } PeerId remotePeerId = new(msg1KeyDecoded); - if (!context.RemotePeer.Address.Has()) + if (!context.State.RemoteAddress.Has()) { - context.RemotePeer.Address.Add(new P2P(remotePeerId.ToString())); + context.State.RemoteAddress.Add(new P2P(remotePeerId.ToString())); } byte[] msg = [.. Encoding.UTF8.GetBytes(PayloadSigPrefix), .. ByteString.CopyFrom(clientStatic.PublicKey)]; - byte[] sig = new byte[64]; - Ed25519.Sign([.. context.LocalPeer.Identity.PrivateKey!.Data], 0, msg, 0, msg.Length, sig, 0); + byte[] sig = context.Peer.Identity.Sign(msg); NoiseHandshakePayload payload = new() { - IdentityKey = context.LocalPeer.Identity.PublicKey.ToByteString(), + IdentityKey = context.Peer.Identity.PublicKey.ToByteString(), IdentitySig = ByteString.CopyFrom(sig), Extensions = _extensions }; @@ -104,9 +108,9 @@ public async Task DialAsync(IChannel downChannel, IChannelFactory? upChannelFact await downChannel.WriteAsync(new ReadOnlySequence(buffer, 0, msg2.BytesWritten)); Transport? transport = msg2.Transport; - _logger?.LogDebug("Established connection to {peer}", context.RemotePeer.Address); + _logger?.LogDebug("Established connection to {peer}", context.State.RemoteAddress); - IChannel upChannel = upChannelFactory.SubDial(context); + IChannel upChannel = context.Upgrade(upgradeOptions); await ExchangeData(transport, downChannel, upChannel); @@ -114,9 +118,9 @@ public async Task DialAsync(IChannel downChannel, IChannelFactory? upChannelFact _logger?.LogDebug("Closed"); } - public async Task ListenAsync(IChannel downChannel, IChannelFactory? upChannelFactory, IPeerContext context) + public async Task ListenAsync(IChannel downChannel, IConnectionContext context) { - ArgumentNullException.ThrowIfNull(upChannelFactory); + ArgumentNullException.ThrowIfNull(context.State.RemoteAddress); KeyPair? serverStatic = KeyPair.Generate(); using HandshakeState? handshakeState = @@ -132,11 +136,11 @@ public async Task ListenAsync(IChannel downChannel, IChannelFactory? upChannelFa byte[] msg = Encoding.UTF8.GetBytes(PayloadSigPrefix) .Concat(ByteString.CopyFrom(serverStatic.PublicKey)) .ToArray(); - byte[] sig = new byte[64]; - Ed25519.Sign(context.LocalPeer.Identity.PrivateKey!.Data.ToArray(), 0, msg, 0, msg.Length, sig, 0); + byte[] sig = context.Peer.Identity.Sign(msg); + NoiseHandshakePayload payload = new() { - IdentityKey = context.LocalPeer.Identity.PublicKey.ToByteString(), + IdentityKey = context.Peer.Identity.PublicKey.ToByteString(), IdentitySig = ByteString.CopyFrom(sig), Extensions = _extensions }; @@ -155,30 +159,32 @@ public async Task ListenAsync(IChannel downChannel, IChannelFactory? upChannelFa handshakeState.ReadMessage(hs2Bytes.ToArray(), buffer); NoiseHandshakePayload? msg2Decoded = NoiseHandshakePayload.Parser.ParseFrom(buffer.AsSpan(0, msg2.BytesRead)); PublicKey? msg2KeyDecoded = PublicKey.Parser.ParseFrom(msg2Decoded.IdentityKey); - Transport? transport = msg2.Transport; - - PeerId remotePeerId = new(msg2KeyDecoded); + context.State.RemotePublicKey = msg2KeyDecoded; + // TODO: verify signature + Transport? transport = msg2.Transport; List initiatorMuxers = msg2Decoded.Extensions.StreamMuxers.Where(m => !string.IsNullOrEmpty(m)).ToList(); - IProtocol? commonMuxer = multiplexerSettings?.Multiplexers.FirstOrDefault(m => initiatorMuxers.Contains(m.Id)); + IProtocol? commonMuxer = null; // multiplexerSettings?.Multiplexers.FirstOrDefault(m => initiatorMuxers.Contains(m.Id)); + + UpgradeOptions? upgradeOptions = null; if (commonMuxer is not null) { - context.SpecificProtocolRequest = new ChannelRequest + upgradeOptions = new UpgradeOptions { - SubProtocol = commonMuxer, - CompletionSource = context.SpecificProtocolRequest?.CompletionSource + SelectedProtocol = commonMuxer, }; } - if (!context.RemotePeer.Address.Has()) + if (!context.State.RemoteAddress.Has()) { - context.RemotePeer.Address.Add(new P2P(remotePeerId.ToString())); + PeerId remotePeerId = new(msg2KeyDecoded); + context.State.RemoteAddress.Add(new P2P(remotePeerId.ToString())); } - _logger?.LogDebug("Established connection to {peer}", context.RemotePeer.Address); + _logger?.LogDebug("Established connection to {peer}", context.State.RemoteAddress); - IChannel upChannel = upChannelFactory.SubListen(context); + IChannel upChannel = context.Upgrade(upgradeOptions); await ExchangeData(transport, downChannel, upChannel); @@ -240,9 +246,6 @@ private static Task ExchangeData(Transport transport, IChannel downChannel, ICha } }); - return Task.WhenAny(t, t2).ContinueWith((t) => - { - - }); + return Task.WhenAll(t, t2); } } diff --git a/src/libp2p/Libp2p.Protocols.Ping/PingProtocol.cs b/src/libp2p/Libp2p.Protocols.Ping/PingProtocol.cs index ce49759d..3fa2178a 100644 --- a/src/libp2p/Libp2p.Protocols.Ping/PingProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.Ping/PingProtocol.cs @@ -11,7 +11,7 @@ namespace Nethermind.Libp2p.Protocols; /// /// https://github.com/libp2p/specs/blob/master/ping/ping.md /// -public class PingProtocol : IProtocol +public class PingProtocol : ISessionProtocol { private const int PayloadLength = 32; @@ -24,39 +24,41 @@ public PingProtocol(ILoggerFactory? loggerFactory = null) _logger = loggerFactory?.CreateLogger(); } - public async Task DialAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context) + public async Task DialAsync(IChannel channel, ISessionContext context) { + ArgumentNullException.ThrowIfNull(context.State.RemoteAddress); + byte[] ping = new byte[PayloadLength]; _random.NextBytes(ping.AsSpan(0, PayloadLength)); ReadOnlySequence bytes = new(ping); - _logger?.LogPing(context.RemotePeer.Address); + _logger?.LogPing(context.State.RemoteAddress); await channel.WriteAsync(bytes); _logger?.LogTrace("Sent ping: {ping}", Convert.ToHexString(ping)); - _logger?.ReadingPong(context.RemotePeer.Address); + _logger?.ReadingPong(context.State.RemoteAddress); ReadOnlySequence response = await channel.ReadAsync(PayloadLength, ReadBlockingMode.WaitAll).OrThrow(); _logger?.LogTrace("Received pong: {ping}", Convert.ToHexString(ping)); - _logger?.VerifyingPong(context.RemotePeer.Address); + _logger?.VerifyingPong(context.State.RemoteAddress); if (!ping[0..PayloadLength].SequenceEqual(response.ToArray())) { - _logger?.PingFailed(context.RemotePeer.Address); + _logger?.PingFailed(context.State.RemoteAddress); throw new ApplicationException(); } - _logger?.LogPinged(context.RemotePeer.Address); + _logger?.LogPinged(context.State.RemoteAddress); } - public async Task ListenAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context) + public async Task ListenAsync(IChannel channel, ISessionContext context) { - _logger?.PingListenStarted(context.RemotePeer.Address); + ArgumentNullException.ThrowIfNull(context.State.RemoteAddress); + + _logger?.PingListenStarted(context.State.RemoteAddress); while (true) { - _logger?.ReadingPing(context.RemotePeer.Address); + _logger?.ReadingPing(context.State.RemoteAddress); ReadResult read = await channel.ReadAsync(PayloadLength, ReadBlockingMode.WaitAny); if (read.Result != IOResult.Ok) { @@ -66,11 +68,11 @@ public async Task ListenAsync(IChannel channel, IChannelFactory? channelFactory, byte[] ping = read.Data.ToArray(); _logger?.LogTrace("Received ping: {ping}", Convert.ToHexString(ping)); - _logger?.ReturningPong(context.RemotePeer.Address); + _logger?.ReturningPong(context.State.RemoteAddress); await channel.WriteAsync(new ReadOnlySequence(ping)); _logger?.LogTrace("Sent pong: {ping}", Convert.ToHexString(ping)); } - _logger?.PingFinished(context.RemotePeer.Address); + _logger?.PingFinished(context.State.RemoteAddress); } } diff --git a/src/libp2p/Libp2p.Protocols.Plaintext/PlainTextProtocol.cs b/src/libp2p/Libp2p.Protocols.Plaintext/PlainTextProtocol.cs index 6dc12e34..d38d582d 100644 --- a/src/libp2p/Libp2p.Protocols.Plaintext/PlainTextProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.Plaintext/PlainTextProtocol.cs @@ -10,17 +10,16 @@ namespace Nethermind.Libp2p.Protocols; /// /// -public class PlainTextProtocol : SymmetricProtocol, IProtocol +public class PlainTextProtocol : SymmetricProtocol, IConnectionProtocol { public string Id => "/plaintext/2.0.0"; - protected override async Task ConnectAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context, bool isListener) + protected override async Task ConnectAsync(IChannel channel, IConnectionContext context, bool isListener) { Exchange src = new() { - Id = ByteString.CopyFrom(context.LocalPeer.Identity.PeerId.Bytes), - Pubkey = context.LocalPeer.Identity.PublicKey.ToByteString() + Id = ByteString.CopyFrom(context.Peer.Identity.PeerId.Bytes), + Pubkey = context.Peer.Identity.PublicKey.ToByteString() }; int size = src.CalculateSize(); int sizeOfSize = VarInt.GetSizeInBytes(size); @@ -35,8 +34,6 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch buf = (await channel.ReadAsync(structSize).OrThrow()).ToArray(); Exchange? dest = Exchange.Parser.ParseFrom(buf); - await (isListener - ? channelFactory.SubListenAndBind(channel, context) - : channelFactory.SubDialAndBind(channel, context)); + await context.Upgrade(channel); } } diff --git a/src/libp2p/Libp2p.Protocols.Pubsub.E2eTests/Libp2p.Protocols.Pubsub.E2eTests.csproj b/src/libp2p/Libp2p.Protocols.Pubsub.E2eTests/Libp2p.Protocols.Pubsub.E2eTests.csproj new file mode 100644 index 00000000..990a0d3f --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Pubsub.E2eTests/Libp2p.Protocols.Pubsub.E2eTests.csproj @@ -0,0 +1,32 @@ + + + + enable + enable + + + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + + diff --git a/src/libp2p/Libp2p.Protocols.Pubsub.E2eTests/PubsubE2eTestSetup.cs b/src/libp2p/Libp2p.Protocols.Pubsub.E2eTests/PubsubE2eTestSetup.cs new file mode 100644 index 00000000..990f78c5 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Pubsub.E2eTests/PubsubE2eTestSetup.cs @@ -0,0 +1,91 @@ +using Libp2p.E2eTests; +using Microsoft.Extensions.DependencyInjection; +using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Protocols.Pubsub; +using System.Text; + +namespace Libp2p.Protocols.Pubsub.E2eTests; + +public class PubsubE2eTestSetup : E2eTestSetup +{ + public PubsubSettings DefaultSettings { get; set; } = new PubsubSettings { LowestDegree = 2, Degree = 3, LazyDegree = 3, HighestDegree = 4, HeartbeatInterval = 200 }; + public Dictionary Routers { get; } = []; + + + protected override IPeerFactoryBuilder ConfigureLibp2p(ILibp2pPeerFactoryBuilder builder) + { + return base.ConfigureLibp2p(builder.WithPubsub()); + } + + protected override IServiceCollection ConfigureServices(IServiceCollection col) + { + return base.ConfigureServices(col); + } + + protected override void AddToPrintState(StringBuilder sb, int index) + { + base.AddToPrintState(sb, index); + sb.AppendLine(Routers[index].ToString()); + } + + protected override void AddAt(int index) + { + base.AddAt(index); + Routers[index] = ServiceProviders[index].GetService()!; + _ = Routers[index].StartAsync(Peers[index]); + } + + /// + /// Manual heartbeat in case the period is set to infinite + /// + /// + public async Task Heartbeat() + { + foreach (PubsubRouter router in Routers.Values) + { + await router.Heartbeat(); + } + } + + public void Subscribe(string topic) + { + foreach (PubsubRouter router in Routers.Values) + { + router.GetTopic(topic); + } + } + + public async Task WaitForFullMeshAsync(string topic, int timeoutMs = 15_000) + { + int requiredCount = int.Min(Routers.Count - 1, DefaultSettings.LowestDegree); + + CancellationTokenSource cts = new(); + Task delayTask = Task.Delay(timeoutMs).ContinueWith((t) => cts.Cancel()); + + while (true) + { + PrintState(); + + if (cts.IsCancellationRequested) + { + throw new Exception("Timeout waiting for the network"); + } + + + cts.Token.ThrowIfCancellationRequested(); + await Task.Delay(1000); + + bool stillWaiting = false; + + foreach (IRoutingStateContainer router in Routers.Values) + { + if (router.Mesh[topic].Count < requiredCount) + { + stillWaiting = true; + } + } + + if (!stillWaiting) break; + } + } +} diff --git a/src/libp2p/Libp2p.Protocols.Pubsub.Profiler/Libp2p.Protocols.Pubsub.Profiler.csproj b/src/libp2p/Libp2p.Protocols.Pubsub.Profiler/Libp2p.Protocols.Pubsub.Profiler.csproj deleted file mode 100644 index b3a22fa1..00000000 --- a/src/libp2p/Libp2p.Protocols.Pubsub.Profiler/Libp2p.Protocols.Pubsub.Profiler.csproj +++ /dev/null @@ -1,16 +0,0 @@ - - - - Exe - net8.0 - enable - enable - - - - - - - - - diff --git a/src/libp2p/Libp2p.Protocols.Pubsub.Profiler/Program.cs b/src/libp2p/Libp2p.Protocols.Pubsub.Profiler/Program.cs deleted file mode 100644 index a4103fcf..00000000 --- a/src/libp2p/Libp2p.Protocols.Pubsub.Profiler/Program.cs +++ /dev/null @@ -1,95 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Nethermind.Libp2p.Core; -using Nethermind.Libp2p.Core.Discovery; -using Nethermind.Libp2p.Core.TestsBase; -using Nethermind.Libp2p.Core.TestsBase.E2e; -using Nethermind.Libp2p.Protocols; -using Nethermind.Libp2p.Protocols.Pubsub; -using System.Text; - -int totalCount = 7; -TestContextLoggerFactory fac = new(); -// There is common communication point -ChannelBus commonBus = new(fac); -ILocalPeer[] peers = new ILocalPeer[totalCount]; -PeerStore[] peerStores = new PeerStore[totalCount]; -PubsubRouter[] routers = new PubsubRouter[totalCount]; - - -for (int i = 0; i < totalCount; i++) -{ - // But we create a seprate setup for every peer - ServiceProvider sp = new ServiceCollection() - .AddSingleton(sp => new TestBuilder(commonBus, sp).AddAppLayerProtocol()) - .AddSingleton(sp => fac) - .AddSingleton() - .AddSingleton() - .AddSingleton(sp => new Settings { LowestDegree = 1, Degree = 2, LazyDegree = 2, HighestDegree = 3 }) - .AddSingleton(sp => sp.GetService()!.Build()) - .BuildServiceProvider(); - - IPeerFactory peerFactory = sp.GetService()!; - ILocalPeer peer = peers[i] = peerFactory.Create(TestPeers.Identity(i)); - PubsubRouter router = routers[i] = sp.GetService()!; - PubsubPeerDiscoveryProtocol disc = new(router, peerStores[i] = sp.GetService()!, new PubsubPeerDiscoverySettings() { Interval = 300 }, peer); - - await peer.ListenAsync(TestPeers.Multiaddr(i)); - _ = router.RunAsync(peer, sp.GetService()); - //_ = disc.DiscoverAsync(peer.Address); -} - -Console.WriteLine($"Emulate peer exchange with one bootstrap peer"); - -for (int i = 0; i < routers.Length; i++) -{ - routers[i].GetTopic("test"); -} - -Console.WriteLine($"Center: {peers[0].Address}"); - -for (int i = 1; i < peers.Length; i++) -{ - peerStores[i].Discover([peers[0].Address]); -} - -await Task.Delay(10000); - -Console.WriteLine("Routers"); - -for (int i = 0; i < routers.Length; i++) -{ - Console.WriteLine(routers[i].ToString()); -} - -Console.WriteLine("Stores"); - -for (int i = 0; i < peerStores.Length; i++) -{ - Console.WriteLine(peerStores[i].ToString()); -} - - -await Task.Delay(5000); - -var testTopic = routers[1].GetTopic("test"); -var testTopicEnd = routers[totalCount - 1].GetTopic("test"); -testTopicEnd.OnMessage += (s) => Console.WriteLine(Encoding.UTF8.GetString(s)); - -testTopic.Publish(Encoding.UTF8.GetBytes("test")); - -for (int i = 0; i < 20; i++) -{ - Console.WriteLine(i * 100); - await Task.Delay(100); -} - -Console.WriteLine("Routers"); - -for (int i = 0; i < routers.Length; i++) -{ - Console.WriteLine(routers[i].ToString()); -} diff --git a/src/libp2p/Libp2p.Protocols.Pubsub.Tests/FloodsubProtocolTests.cs b/src/libp2p/Libp2p.Protocols.Pubsub.Tests/FloodsubProtocolTests.cs index 74dff9ec..e2701232 100644 --- a/src/libp2p/Libp2p.Protocols.Pubsub.Tests/FloodsubProtocolTests.cs +++ b/src/libp2p/Libp2p.Protocols.Pubsub.Tests/FloodsubProtocolTests.cs @@ -10,6 +10,7 @@ namespace Nethermind.Libp2p.Protocols.Pubsub.Tests; [TestFixture] public class FloodsubProtocolTests { + [Ignore("TODO")] [Test] public async Task Test_Peer_is_in_fpeers() { @@ -24,26 +25,26 @@ public async Task Test_Peer_is_in_fpeers() const string commonTopic = "topic1"; ILocalPeer peer = Substitute.For(); - peer.Address.Returns(localPeerAddr); + peer.ListenAddresses.Returns([localPeerAddr]); peer.Identity.Returns(TestPeers.Identity(2)); peer.DialAsync(discoveredPeerAddress, Arg.Any()).Returns(new TestRemotePeer(discoveredPeerAddress)); CancellationToken token = default; - List sentRpcs = new(); + List sentRpcs = []; - _ = router.RunAsync(peer, token: token); + _ = router.StartAsync(peer, token: token); router.GetTopic(commonTopic); Assert.That(state.FloodsubPeers.Keys, Has.Member(commonTopic)); peerStore.Discover([discoveredPeerAddress]); await Task.Delay(100); - _ = peer.Received().DialAsync(discoveredPeerAddress, Arg.Any()); + _ = peer.Received().DialAsync([discoveredPeerAddress], Arg.Any()); TaskCompletionSource tcs = new(); router.OutboundConnection(discoveredPeerAddress, PubsubRouter.FloodsubProtocolVersion, tcs.Task, sentRpcs.Add); router.InboundConnection(discoveredPeerAddress, PubsubRouter.FloodsubProtocolVersion, tcs.Task, tcs.Task, () => Task.CompletedTask); - await router.OnRpc(discoveredPeer.PeerId, new Rpc().WithTopics(new[] { commonTopic }, [])); + router.OnRpc(discoveredPeer.PeerId, new Rpc().WithTopics(new[] { commonTopic }, [])); Assert.Multiple(() => { diff --git a/src/libp2p/Libp2p.Protocols.Pubsub.Tests/GossipsubProtocolTests.cs b/src/libp2p/Libp2p.Protocols.Pubsub.Tests/GossipsubProtocolTests.cs index 54e79060..e190360a 100644 --- a/src/libp2p/Libp2p.Protocols.Pubsub.Tests/GossipsubProtocolTests.cs +++ b/src/libp2p/Libp2p.Protocols.Pubsub.Tests/GossipsubProtocolTests.cs @@ -15,20 +15,18 @@ public async Task Test_New_messages_are_sent_to_mesh_only() { PeerStore peerStore = new(); PubsubRouter router = new(peerStore); - Settings settings = new() { HeartbeatInterval = int.MaxValue }; + PubsubSettings settings = new() { HeartbeatInterval = int.MaxValue }; IRoutingStateContainer state = router; - int peerCount = Settings.Default.HighestDegree + 1; + int peerCount = PubsubSettings.Default.HighestDegree + 1; const string commonTopic = "topic1"; ILocalPeer peer = new LocalPeerStub(); - CancellationToken token = default; List sentRpcs = []; - _ = router.RunAsync(peer, token: token); router.GetTopic(commonTopic); Assert.That(state.FloodsubPeers.Keys, Has.Member(commonTopic)); Assert.That(state.GossipsubPeers.Keys, Has.Member(commonTopic)); - + await router.StartAsync(peer); TaskCompletionSource tcs = new(); foreach (int index in Enumerable.Range(1, peerCount)) @@ -39,7 +37,7 @@ public async Task Test_New_messages_are_sent_to_mesh_only() peerStore.Discover([discoveredPeer]); router.OutboundConnection(discoveredPeer, PubsubRouter.GossipsubProtocolVersionV10, tcs.Task, sentRpcs.Add); router.InboundConnection(discoveredPeer, PubsubRouter.GossipsubProtocolVersionV10, tcs.Task, tcs.Task, () => Task.CompletedTask); - await router.OnRpc(peerId, new Rpc().WithTopics([commonTopic], [])); + router.OnRpc(peerId, new Rpc().WithTopics([commonTopic], [])); } await router.Heartbeat(); @@ -47,7 +45,7 @@ public async Task Test_New_messages_are_sent_to_mesh_only() Assert.Multiple(() => { Assert.That(state.GossipsubPeers[commonTopic], Has.Count.EqualTo(peerCount)); - Assert.That(state.Mesh[commonTopic], Has.Count.EqualTo(Settings.Default.Degree)); + Assert.That(state.Mesh[commonTopic], Has.Count.EqualTo(PubsubSettings.Default.Degree)); }); tcs.SetResult(); diff --git a/src/libp2p/Libp2p.Protocols.Pubsub.Tests/PubsubProtocolTests.cs b/src/libp2p/Libp2p.Protocols.Pubsub.Tests/PubsubProtocolTests.cs index a7ce5aa9..ec222a15 100644 --- a/src/libp2p/Libp2p.Protocols.Pubsub.Tests/PubsubProtocolTests.cs +++ b/src/libp2p/Libp2p.Protocols.Pubsub.Tests/PubsubProtocolTests.cs @@ -1,10 +1,11 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT +//// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited +//// SPDX-License-Identifier: MIT + +//namespace Nethermind.Libp2p.Protocols.Pubsub.Tests; using Multiformats.Address; using Nethermind.Libp2p.Core.Discovery; - -namespace Nethermind.Libp2p.Protocols.Pubsub.Tests; +using Nethermind.Libp2p.Protocols.Pubsub; [TestFixture] public class PubsubProtocolTests @@ -15,25 +16,25 @@ public async Task Test_Peer_is_dialed_when_added_by_discovery() PeerStore peerStore = new(); PubsubRouter router = new(peerStore); IRoutingStateContainer state = router; - Multiaddress discoveredPeerAddr = TestPeers.Multiaddr(1); - Multiaddress localPeer = TestPeers.Multiaddr(2); + Multiaddress localPeerAddr = TestPeers.Multiaddr(1); + Multiaddress[] discoveredPeerAddrs = [TestPeers.Multiaddr(2)]; ILocalPeer peer = Substitute.For(); - peer.Address.Returns(localPeer); - peer.Identity.Returns(TestPeers.Identity(2)); - peer.DialAsync(discoveredPeerAddr, Arg.Any()).Returns(new TestRemotePeer(discoveredPeerAddr)); + peer.ListenAddresses.Returns([localPeerAddr]); + peer.Identity.Returns(TestPeers.Identity(1)); + peer.DialAsync(discoveredPeerAddrs, Arg.Any()).Returns(new TestRemotePeer(discoveredPeerAddrs[0])); CancellationToken token = default; TaskCompletionSource taskCompletionSource = new(); - _ = router.RunAsync(peer, token: token); - peerStore.Discover([discoveredPeerAddr]); + await router.StartAsync(peer, token: token); + peerStore.Discover(discoveredPeerAddrs); await Task.Delay(100); - _ = peer.Received().DialAsync(discoveredPeerAddr, Arg.Any()); + _ = peer.Received().DialAsync(discoveredPeerAddrs, Arg.Any()); - router.OutboundConnection(discoveredPeerAddr, PubsubRouter.FloodsubProtocolVersion, taskCompletionSource.Task, (rpc) => { }); - Assert.That(state.ConnectedPeers, Has.Member(discoveredPeerAddr.GetPeerId())); + router.OutboundConnection(discoveredPeerAddrs[0], PubsubRouter.FloodsubProtocolVersion, taskCompletionSource.Task, (rpc) => { }); + Assert.That(state.ConnectedPeers, Has.Member(discoveredPeerAddrs[0].GetPeerId())); taskCompletionSource.SetResult(); } } diff --git a/src/libp2p/Libp2p.Protocols.Pubsub/IRoutingStateContainer.cs b/src/libp2p/Libp2p.Protocols.Pubsub/IRoutingStateContainer.cs new file mode 100644 index 00000000..d5bf0ee7 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Pubsub/IRoutingStateContainer.cs @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Nethermind.Libp2p.Core; +using System.Collections.Concurrent; + +namespace Nethermind.Libp2p.Protocols.Pubsub; + +public interface IRoutingStateContainer +{ + ConcurrentDictionary> FloodsubPeers { get; } + ConcurrentDictionary> GossipsubPeers { get; } + ConcurrentDictionary> Mesh { get; } + ConcurrentDictionary> Fanout { get; } + ConcurrentDictionary FanoutLastPublished { get; } + ICollection ConnectedPeers { get; } + bool Started { get; } + Task Heartbeat(); +} diff --git a/src/libp2p/Libp2p.Protocols.Pubsub/ManagedPeer.cs b/src/libp2p/Libp2p.Protocols.Pubsub/ManagedPeer.cs deleted file mode 100644 index 6caa80f1..00000000 --- a/src/libp2p/Libp2p.Protocols.Pubsub/ManagedPeer.cs +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Nethermind.Libp2p.Core; - -namespace Nethermind.Libp2p.Protocols.Pubsub; -internal class ManagedPeer(ILocalPeer peer) -{ - internal async Task DialAsync(Multiaddress[] addrs, CancellationToken token) - { - Dictionary cancellations = new(); - foreach (Multiaddress addr in addrs) - { - cancellations[addr] = CancellationTokenSource.CreateLinkedTokenSource(token); - } - - Task timoutTask = Task.Delay(15_000, token); - Task> firstConnectedTask = Task.WhenAny(addrs - .Select(addr => peer.DialAsync(addr, cancellations[addr].Token))); - - Task wait = await Task.WhenAny(firstConnectedTask, timoutTask); - - if (wait == timoutTask) - { - throw new TimeoutException(); - } - - IRemotePeer firstConnected = firstConnectedTask.Result.Result; - - foreach (KeyValuePair c in cancellations) - { - if (c.Key != firstConnected.Address) - { - c.Value.Cancel(false); - } - } - - return firstConnected; - } -} diff --git a/src/libp2p/Libp2p.Protocols.Pubsub/Settings.cs b/src/libp2p/Libp2p.Protocols.Pubsub/PubSubSettings.cs similarity index 92% rename from src/libp2p/Libp2p.Protocols.Pubsub/Settings.cs rename to src/libp2p/Libp2p.Protocols.Pubsub/PubSubSettings.cs index 4ba2e3ae..3fc31cae 100644 --- a/src/libp2p/Libp2p.Protocols.Pubsub/Settings.cs +++ b/src/libp2p/Libp2p.Protocols.Pubsub/PubSubSettings.cs @@ -6,9 +6,9 @@ namespace Nethermind.Libp2p.Protocols.Pubsub; -public class Settings +public class PubsubSettings { - public static Settings Default { get; } = new(); + public static PubsubSettings Default { get; } = new(); public int ReconnectionAttempts { get; set; } = 10; public int ReconnectionPeriod { get; set; } = 15_000; @@ -17,6 +17,9 @@ public class Settings public int LowestDegree { get; set; } = 4; //Lower bound for outbound degree 4 public int HighestDegree { get; set; } = 12;//Upper bound for outbound degree 12 public int LazyDegree { get; set; } = 6;//(Optional) the outbound degree for gossip emission D + + public int MaxConnections { get; set; } + public int HeartbeatInterval { get; set; } = 1_000;//Time between heartbeats 1 second public int FanoutTtl { get; set; } = 60 * 1000;//Time-to-live for each topic's fanout state 60 seconds public int mcache_len { get; set; } = 5;//Number of history windows in message cache 5 diff --git a/src/libp2p/Libp2p.Protocols.Pubsub/PubsubProtocol.cs b/src/libp2p/Libp2p.Protocols.Pubsub/PubsubProtocol.cs index 52ff139e..e2752e5f 100644 --- a/src/libp2p/Libp2p.Protocols.Pubsub/PubsubProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.Pubsub/PubsubProtocol.cs @@ -2,16 +2,16 @@ // SPDX-License-Identifier: MIT using Microsoft.Extensions.Logging; -using Multiformats.Address.Protocols; using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Protocols.Pubsub; using Nethermind.Libp2p.Protocols.Pubsub.Dto; -namespace Nethermind.Libp2p.Protocols.Pubsub; +namespace Nethermind.Libp2p.Protocols; /// /// https://github.com/libp2p/specs/tree/master/pubsub /// -public abstract class PubsubProtocol : IProtocol +public abstract class PubsubProtocol : ISessionProtocol { private readonly ILogger? _logger; private readonly PubsubRouter router; @@ -25,43 +25,49 @@ public PubsubProtocol(string protocolId, PubsubRouter router, ILoggerFactory? lo this.router = router; } - public async Task DialAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context) + public async Task DialAsync(IChannel channel, ISessionContext context) { - string peerId = context.RemotePeer.Address.Get().ToString()!; - _logger?.LogDebug($"{context.LocalPeer.Address} dials {context.RemotePeer.Address}"); + ArgumentNullException.ThrowIfNull(context.State.RemoteAddress); + ArgumentNullException.ThrowIfNull(context.State.RemotePeerId); + + PeerId? remotePeerId = context.State.RemotePeerId; + + _logger?.LogDebug($"Dialed({context.Id}) {context.State.RemoteAddress}"); TaskCompletionSource dialTcs = new(); - CancellationToken token = router.OutboundConnection(context.RemotePeer.Address, Id, dialTcs.Task, (rpc) => + CancellationToken token = router.OutboundConnection(context.State.RemoteAddress, Id, dialTcs.Task, (rpc) => { - var t = channel.WriteSizeAndProtobufAsync(rpc); + ValueTask t = channel.WriteSizeAndProtobufAsync(rpc); t.AsTask().ContinueWith((t) => { if (!t.IsCompletedSuccessfully) { - _logger?.LogWarning($"Sending RPC failed message to {peerId}: {rpc}"); + _logger?.LogWarning($"Sending RPC failed message to {remotePeerId}: {rpc}"); } }); - _logger?.LogTrace($"Sent message to {peerId}: {rpc}"); + _logger?.LogTrace($"Sent message to {remotePeerId}: {rpc}"); }); await channel; dialTcs.SetResult(); - _logger?.LogDebug($"Finished dial({context.Id}) {context.RemotePeer.Address}"); + _logger?.LogDebug($"Finished dial({context.Id}) {context.State.RemoteAddress}"); } - public async Task ListenAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context) + public async Task ListenAsync(IChannel channel, ISessionContext context) { - string peerId = context.RemotePeer.Address.Get().ToString()!; - _logger?.LogDebug($"{context.LocalPeer.Address} listens to {context.RemotePeer.Address}"); + ArgumentNullException.ThrowIfNull(context.State.RemoteAddress); + ArgumentNullException.ThrowIfNull(context.State.RemotePeerId); + + PeerId? remotePeerId = context.State.RemotePeerId; + + _logger?.LogDebug($"Listen({context.Id}) to {context.State.RemoteAddress}"); TaskCompletionSource listTcs = new(); TaskCompletionSource dialTcs = new(); - CancellationToken token = router.InboundConnection(context.RemotePeer.Address, Id, listTcs.Task, dialTcs.Task, () => + CancellationToken token = router.InboundConnection(context.State.RemoteAddress, Id, listTcs.Task, dialTcs.Task, () => { - context.SubDialRequests.Add(new ChannelRequest { SubProtocol = this }); + _ = context.DialAsync(this); return dialTcs.Task; }); @@ -70,23 +76,21 @@ public async Task ListenAsync(IChannel channel, IChannelFactory? channelFactory, Rpc? rpc = await channel.ReadAnyPrefixedProtobufAsync(Rpc.Parser, token); if (rpc is null) { - _logger?.LogDebug($"Received a broken message or EOF from {peerId}"); + _logger?.LogDebug($"Received a broken message or EOF from {remotePeerId}"); break; } else { - _logger?.LogTrace($"Received message from {peerId}: {rpc}"); - _ = router.OnRpc(peerId, rpc); + _logger?.LogTrace($"Received message from {remotePeerId}: {rpc}"); + router.OnRpc(remotePeerId, rpc); } } + listTcs.SetResult(); - _logger?.LogDebug($"Finished({context.Id}) list {context.RemotePeer.Address}"); + _logger?.LogDebug($"Finished({context.Id}) list {context.State.RemoteAddress}"); } - public override string ToString() - { - return Id; - } + public override string ToString() => Id; } public class FloodsubProtocol(PubsubRouter router, ILoggerFactory? loggerFactory = null) : PubsubProtocol(PubsubRouter.FloodsubProtocolVersion, router, loggerFactory); diff --git a/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.Rpc.cs b/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.Rpc.cs new file mode 100644 index 00000000..5eb582d0 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.Rpc.cs @@ -0,0 +1,290 @@ +// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Google.Protobuf; +using Microsoft.Extensions.Logging; +using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Protocols.Pubsub.Dto; +using System.Collections.Concurrent; + +namespace Nethermind.Libp2p.Protocols.Pubsub; + +public partial class PubsubRouter : IRoutingStateContainer, IDisposable +{ + internal void OnRpc(PeerId peerId, Rpc rpc) + { + try + { + ConcurrentDictionary peerMessages = new(); + lock (this) + { + if (rpc.Publish.Count != 0) + { + HandleNewMessages(peerId, rpc.Publish, peerMessages); + } + + if (rpc.Subscriptions.Count != 0) + { + HandleSubscriptions(peerId, rpc.Subscriptions); + } + + if (rpc.Control is not null) + { + if (rpc.Control.Graft.Count != 0) + { + HandleGraft(peerId, rpc.Control.Graft, peerMessages); + } + + if (rpc.Control.Prune.Count != 0) + { + HandlePrune(peerId, rpc.Control.Prune, peerMessages); + } + + if (rpc.Control.Ihave.Count != 0) + { + HandleIhave(peerId, rpc.Control.Ihave, peerMessages); + } + + if (rpc.Control.Iwant.Count != 0) + { + HandleIwant(peerId, rpc.Control.Iwant, peerMessages); + } + + if (rpc.Control.Idontwant.Count != 0) + { + HandleIdontwant(peerId, rpc.Control.Idontwant); + } + } + } + foreach (KeyValuePair peerMessage in peerMessages) + { + peerState.GetValueOrDefault(peerMessage.Key)?.Send(peerMessage.Value); + } + } + catch (Exception ex) + { + logger?.LogError(ex, "Exception while processing RPC"); + } + } + + private void HandleNewMessages(PeerId peerId, IEnumerable messages, ConcurrentDictionary peerMessages) + { + logger?.LogDebug($"Messages received: {messages.Select(_settings.GetMessageId).Count(messageId => _limboMessageCache.Contains(messageId) || _messageCache!.Contains(messageId))}/{messages.Count()}"); + + foreach (Message? message in messages) + { + MessageId messageId = _settings.GetMessageId(message); + + if (_limboMessageCache.Contains(messageId) || _messageCache!.Contains(messageId)) + { + continue; + } + + switch (VerifyMessage?.Invoke(message)) + { + case MessageValidity.Rejected: + case MessageValidity.Ignored: + _limboMessageCache.Add(messageId, message); + continue; + case MessageValidity.Trottled: + continue; + } + + if (!message.VerifySignature(_settings.DefaultSignaturePolicy)) + { + _limboMessageCache!.Add(messageId, message); + continue; + } + + _messageCache.Add(messageId, message); + + PeerId author = new(message.From.ToArray()); + OnMessage?.Invoke(message.Topic, message.Data.ToByteArray()); + + if (fPeers.TryGetValue(message.Topic, out HashSet? topicPeers)) + { + foreach (PeerId peer in topicPeers) + { + if (peer == author || peer == peerId) + { + continue; + } + peerMessages.GetOrAdd(peer, _ => new Rpc()).Publish.Add(message); + } + } + if (mesh.TryGetValue(message.Topic, out topicPeers)) + { + foreach (PeerId peer in topicPeers) + { + if (peer == author || peer == peerId) + { + continue; + } + peerMessages.GetOrAdd(peer, _ => new Rpc()).Publish.Add(message); + } + } + } + } + + private void HandleSubscriptions(PeerId peerId, IEnumerable subscriptions) + { + foreach (Rpc.Types.SubOpts? sub in subscriptions) + { + PubsubPeer? state = peerState.GetValueOrDefault(peerId); + if (state is null) + { + return; + } + if (sub.Subscribe) + { + if (state.IsGossipSub) + { + gPeers.GetOrAdd(sub.Topicid, _ => []).Add(peerId); + } + else if (state.IsFloodSub) + { + fPeers.GetOrAdd(sub.Topicid, _ => []).Add(peerId); + } + } + else + { + if (state.IsGossipSub) + { + gPeers.GetOrAdd(sub.Topicid, _ => []).Remove(peerId); + if (mesh.ContainsKey(sub.Topicid)) + { + mesh[sub.Topicid].Remove(peerId); + } + if (fanout.ContainsKey(sub.Topicid)) + { + fanout[sub.Topicid].Remove(peerId); + } + } + else if (state.IsFloodSub) + { + fPeers.GetOrAdd(sub.Topicid, _ => []).Remove(peerId); + } + } + } + } + + private void HandleGraft(PeerId peerId, IEnumerable grafts, ConcurrentDictionary peerMessages) + { + foreach (ControlGraft? graft in grafts) + { + if (!topicState.ContainsKey(graft.TopicID)) + { + peerMessages.GetOrAdd(peerId, _ => new Rpc()) + .Ensure(r => r.Control.Prune) + .Add(new ControlPrune { TopicID = graft.TopicID }); + } + else + { + HashSet topicMesh = mesh[graft.TopicID]; + + if (topicMesh.Count >= _settings.HighestDegree) + { + ControlPrune prune = new() { TopicID = graft.TopicID }; + + if (peerState.TryGetValue(peerId, out PubsubPeer? state) && state.IsGossipSub && state.Protocol >= PubsubPeer.PubsubProtocol.GossipsubV11) + { + state.Backoff[prune.TopicID] = DateTime.Now.AddSeconds(prune.Backoff == 0 ? 60 : prune.Backoff); + prune.Peers.AddRange(topicMesh.ToArray().Select(pid => (PeerId: pid, Record: _peerStore.GetPeerInfo(pid)?.SignedPeerRecord)).Where(pid => pid.Record is not null).Select(pid => new PeerInfo + { + PeerID = ByteString.CopyFrom(pid.PeerId.Bytes), + SignedPeerRecord = pid.Record, + })); + } + + peerMessages.GetOrAdd(peerId, _ => new Rpc()) + .Ensure(r => r.Control.Prune) + .Add(prune); + } + else + { + if (!topicMesh.Contains(peerId)) + { + topicMesh.Add(peerId); + gPeers[graft.TopicID].Add(peerId); + peerMessages.GetOrAdd(peerId, _ => new Rpc()) + .Ensure(r => r.Control.Graft) + .Add(new ControlGraft { TopicID = graft.TopicID }); + } + } + } + } + } + + private void HandlePrune(PeerId peerId, IEnumerable prunes, ConcurrentDictionary peerMessages) + { + foreach (ControlPrune? prune in prunes) + { + if (topicState.ContainsKey(prune.TopicID) && mesh[prune.TopicID].Contains(peerId)) + { + if (peerState.TryGetValue(peerId, out PubsubPeer? state)) + { + state.Backoff[prune.TopicID] = DateTime.Now.AddSeconds(prune.Backoff == 0 ? 60 : prune.Backoff); + } + mesh[prune.TopicID].Remove(peerId); + peerMessages.GetOrAdd(peerId, _ => new Rpc()) + .Ensure(r => r.Control.Prune) + .Add(new ControlPrune { TopicID = prune.TopicID }); + + foreach (PeerInfo? peer in prune.Peers) + { + _peerStore.Discover(peer.SignedPeerRecord); + } + } + } + } + + private void HandleIhave(PeerId peerId, IEnumerable ihaves, ConcurrentDictionary peerMessages) + { + List messageIds = []; + + foreach (ControlIHave? ihave in ihaves.Where(iw => topicState.ContainsKey(iw.TopicID))) + { + messageIds.AddRange(ihave.MessageIDs.Select(m => new MessageId(m.ToByteArray())) + .Where(mid => !_messageCache.Contains(mid))); + } + + if (messageIds.Any()) + { + ControlIWant ciw = new(); + foreach (MessageId mId in messageIds) + { + ciw.MessageIDs.Add(ByteString.CopyFrom(mId.Bytes)); + } + peerMessages.GetOrAdd(peerId, _ => new Rpc()) + .Ensure(r => r.Control.Iwant) + .Add(ciw); + } + } + + private void HandleIwant(PeerId peerId, IEnumerable iwants, ConcurrentDictionary peerMessages) + { + IEnumerable messageIds = iwants.SelectMany(iw => iw.MessageIDs).Select(m => new MessageId(m.ToByteArray())); + List messages = []; + foreach (MessageId? mId in messageIds) + { + Message message = _messageCache.Get(mId); + if (message != default) + { + messages.Add(message); + } + } + if (messages.Any()) + { + peerMessages.GetOrAdd(peerId, _ => new Rpc()) + .Publish.AddRange(messages); + } + } + + private void HandleIdontwant(PeerId peerId, IEnumerable idontwants) + { + foreach (MessageId messageId in idontwants.SelectMany(iw => iw.MessageIDs).Select(m => new MessageId(m.ToByteArray())).Take(_settings.MaxIdontwantMessages)) + { + _dontWantMessages.Add((peerId, messageId)); + } + } +} diff --git a/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.Topics.cs b/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.Topics.cs index 865b4d24..53840a6d 100644 --- a/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.Topics.cs +++ b/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.Topics.cs @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited // SPDX-License-Identifier: MIT +using Microsoft.Extensions.Logging; using Nethermind.Libp2p.Core; using Nethermind.Libp2p.Protocols.Pubsub.Dto; using System.Buffers.Binary; @@ -49,7 +50,7 @@ public void Subscribe(string topicId) } Rpc topicUpdate = new Rpc().WithTopics([topicId], []); - foreach (var peer in peerState) + foreach (KeyValuePair peer in peerState) { peer.Value.Send(topicUpdate); } @@ -82,34 +83,41 @@ public void Unsubscribe(string topicId) public void UnsubscribeAll() { - foreach (PeerId? peerId in fPeers.SelectMany(kv => kv.Value)) + try { - Rpc msg = new Rpc().WithTopics([], topicState.Keys); + foreach (PeerId? peerId in fPeers.SelectMany(kv => kv.Value)) + { + Rpc msg = new Rpc().WithTopics([], topicState.Keys); - peerState.GetValueOrDefault(peerId)?.Send(msg); - } + peerState.GetValueOrDefault(peerId)?.Send(msg); + } - Dictionary peerMessages = new(); + Dictionary peerMessages = []; - foreach (PeerId? peerId in gPeers.SelectMany(kv => kv.Value)) - { - (peerMessages[peerId] ??= new Rpc()) - .WithTopics([], topicState.Keys); - } - - foreach (KeyValuePair> topicMesh in mesh) - { - foreach (PeerId peerId in topicMesh.Value) + foreach (PeerId? peerId in gPeers.SelectMany(kv => kv.Value)) { (peerMessages[peerId] ??= new Rpc()) - .Ensure(r => r.Control.Prune) - .Add(new ControlPrune { TopicID = topicMesh.Key }); + .WithTopics([], topicState.Keys); } - } - foreach (KeyValuePair peerMessage in peerMessages) + foreach (KeyValuePair> topicMesh in mesh.ToDictionary()) + { + foreach (PeerId peerId in topicMesh.Value) + { + (peerMessages[peerId] ??= new Rpc()) + .Ensure(r => r.Control.Prune) + .Add(new ControlPrune { TopicID = topicMesh.Key }); + } + } + + foreach (KeyValuePair peerMessage in peerMessages) + { + peerState.GetValueOrDefault(peerMessage.Key)?.Send(peerMessage.Value); + } + } + catch (Exception e) { - peerState.GetValueOrDefault(peerMessage.Key)?.Send(peerMessage.Value); + logger?.LogError(e, $"Error during {nameof(UnsubscribeAll)}"); } } @@ -120,7 +128,7 @@ public void Publish(string topicId, byte[] message) ulong seqNo = this.seqNo++; Span seqNoBytes = stackalloc byte[8]; BinaryPrimitives.WriteUInt64BigEndian(seqNoBytes, seqNo); - Rpc rpc = new Rpc().WithMessages(topicId, seqNo, localPeer!.Address.GetPeerId()!.Bytes, message, localPeer.Identity); + Rpc rpc = new Rpc().WithMessages(topicId, seqNo, localPeer!.Identity.PeerId.Bytes, message, localPeer.Identity); foreach (PeerId peerId in fPeers[topicId]) { @@ -137,14 +145,14 @@ public void Publish(string topicId, byte[] message) else { fanoutLastPublished[topicId] = DateTime.Now; - HashSet topicFanout = fanout.GetOrAdd(topicId, _ => new HashSet()); + HashSet topicFanout = fanout.GetOrAdd(topicId, _ => []); if (topicFanout.Count == 0) { HashSet? topicPeers = gPeers.GetValueOrDefault(topicId); if (topicPeers is { Count: > 0 }) { - foreach (PeerId peer in topicPeers.Take(settings.Degree)) + foreach (PeerId peer in topicPeers.Take(_settings.Degree)) { topicFanout.Add(peer); } diff --git a/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.cs b/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.cs index 17e298ac..c4c3a3db 100644 --- a/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.cs +++ b/src/libp2p/Libp2p.Protocols.Pubsub/PubsubRouter.cs @@ -6,27 +6,12 @@ using Multiformats.Address.Protocols; using Nethermind.Libp2p.Core; using Nethermind.Libp2p.Core.Discovery; -using Nethermind.Libp2p.Core.Dto; -using Nethermind.Libp2p.Protocols.Identify.Dto; using Nethermind.Libp2p.Protocols.Pubsub.Dto; -using Org.BouncyCastle.Tls; using System.Collections.Concurrent; -using System.Diagnostics; namespace Nethermind.Libp2p.Protocols.Pubsub; -internal interface IRoutingStateContainer -{ - ConcurrentDictionary> FloodsubPeers { get; } - ConcurrentDictionary> GossipsubPeers { get; } - ConcurrentDictionary> Mesh { get; } - ConcurrentDictionary> Fanout { get; } - ConcurrentDictionary FanoutLastPublished { get; } - ICollection ConnectedPeers { get; } - Task Heartbeat(); -} - -public partial class PubsubRouter(PeerStore store, ILoggerFactory? loggerFactory = default) : IRoutingStateContainer +public partial class PubsubRouter : IRoutingStateContainer, IDisposable { static int routerCounter = 0; readonly int routerId = Interlocked.Increment(ref routerCounter); @@ -34,8 +19,8 @@ public partial class PubsubRouter(PeerStore store, ILoggerFactory? loggerFactory public override string ToString() { //{string.Join("|", peerState.Select(x => $"{x.Key}:{x.Value.SendRpc is not null}"))} - return $"Router#{routerId}: {localPeer?.Address.GetPeerId() ?? "null"}, " + - $"peers: {peerState.Count(x => x.Value.SendRpc is not null)}/{peerState.Count}, " + + return $"Router#{routerId}: {localPeer?.Identity.PeerId ?? "null"}, " + + $"peers: {peerState.Count(x => x.Value.SendRpc is not null)}/{peerState.Count} ({string.Join(",", peerState.Keys)}), " + $"mesh: {string.Join("|", mesh.Select(m => $"{m.Key}:{m.Value.Count}"))}, " + $"fanout: {string.Join("|", fanout.Select(m => $"{m.Key}:{m.Value.Count}"))}, " + $"fPeers: {string.Join("|", fPeers.Select(m => $"{m.Key}:{m.Value.Count}"))}, " + @@ -48,9 +33,10 @@ public override string ToString() class PubsubPeer { - public PubsubPeer(PeerId peerId, string protocolId) + public PubsubPeer(PeerId peerId, string protocolId, ILogger? logger) { PeerId = peerId; + _logger = logger; Protocol = protocolId switch { GossipsubProtocolVersionV10 => PubsubProtocol.GossipsubV10, @@ -90,11 +76,13 @@ public void Send(Rpc rpc) public Dictionary Backoff { get; internal set; } public ConcurrentQueue SendRpcQueue { get; } private Action? _sendRpc; + private readonly ILogger? _logger; + public Action? SendRpc { get => _sendRpc; set { - Debug.WriteLine($"Set SENDRPC for {this.PeerId}: {value}"); + _logger?.LogDebug($"Set SENDRPC for {PeerId}: {value}"); _sendRpc = value; if (_sendRpc is not null) lock (SendRpcQueue) @@ -125,6 +113,7 @@ public Action? SendRpc ConcurrentDictionary> IRoutingStateContainer.Mesh => mesh; ConcurrentDictionary> IRoutingStateContainer.Fanout => fanout; ConcurrentDictionary IRoutingStateContainer.FanoutLastPublished => fanoutLastPublished; + bool IRoutingStateContainer.Started => localPeer is not null; ICollection IRoutingStateContainer.ConnectedPeers => peerState.Keys; Task IRoutingStateContainer.Heartbeat() => Heartbeat(); #endregion @@ -132,14 +121,13 @@ public Action? SendRpc public event Action? OnMessage; public Func? VerifyMessage = null; - private Settings settings; - private TtlCache messageCache; - private TtlCache limboMessageCache; - private TtlCache<(PeerId, MessageId)> dontWantMessages; + private readonly PubsubSettings _settings; + private readonly TtlCache _messageCache; + private readonly TtlCache _limboMessageCache; + private readonly TtlCache<(PeerId, MessageId)> _dontWantMessages; private ILocalPeer? localPeer; - private ManagedPeer peer; - private readonly ILogger? logger = loggerFactory?.CreateLogger(); + private readonly ILogger? logger; // all floodsub peers in topics private readonly ConcurrentDictionary> fPeers = new(); @@ -157,7 +145,8 @@ public Action? SendRpc // all peers with their connection status private readonly ConcurrentDictionary peerState = new(); - private readonly ConcurrentBag reconnections = new(); + private readonly ConcurrentBag reconnections = []; + private readonly PeerStore _peerStore; private ulong seqNo = 1; private record Reconnection(Multiaddress[] Addresses, int Attempts); @@ -169,24 +158,29 @@ static PubsubRouter() Canceled = cts.Token; } - public async Task RunAsync(ILocalPeer localPeer, Settings? settings = null, CancellationToken token = default) + public PubsubRouter(PeerStore store, PubsubSettings? settings = null, ILoggerFactory? loggerFactory = default) + { + logger = loggerFactory?.CreateLogger("pubsub-router"); + + _peerStore = store; + _settings = settings ?? PubsubSettings.Default; + _messageCache = new(_settings.MessageCacheTtl); + _limboMessageCache = new(_settings.MessageCacheTtl); + _dontWantMessages = new(_settings.MessageCacheTtl); + } + + public Task StartAsync(ILocalPeer localPeer, CancellationToken token = default) { - logger?.LogDebug($"Running pubsub for {localPeer.Address}"); + logger?.LogDebug($"Running pubsub for {string.Join(",", localPeer.ListenAddresses)}"); if (this.localPeer is not null) { throw new InvalidOperationException("Router has been already started"); } this.localPeer = localPeer; - peer = new ManagedPeer(localPeer); - this.settings = settings ?? Settings.Default; - messageCache = new(this.settings.MessageCacheTtl); - limboMessageCache = new(this.settings.MessageCacheTtl); - dontWantMessages = new(this.settings.MessageCacheTtl); - logger?.LogInformation("Started"); - store.OnNewPeer += (addrs) => + _peerStore.OnNewPeer += (addrs) => { if (addrs.Any(a => a.GetPeerId()! == localPeer.Identity.PeerId)) { @@ -196,20 +190,20 @@ public async Task RunAsync(ILocalPeer localPeer, Settings? settings = null, Canc { try { - IRemotePeer remotePeer = await peer.DialAsync(addrs, token); + ISession session = await localPeer.DialAsync(addrs, token); - if (!peerState.ContainsKey(remotePeer.Address.Get().ToString())) + if (!peerState.ContainsKey(session.RemoteAddress.Get().ToString())) { - await remotePeer.DialAsync(token); - if (peerState.TryGetValue(remotePeer.Address.GetPeerId()!, out PubsubPeer? state) && state.InititatedBy == ConnectionInitiation.Remote) + await session.DialAsync(token); + if (peerState.TryGetValue(session.RemoteAddress.GetPeerId()!, out PubsubPeer? state) && state.InititatedBy == ConnectionInitiation.Remote) { - _ = remotePeer.DisconnectAsync(); + _ = session.DisconnectAsync(); } } } catch { - reconnections.Add(new Reconnection(addrs, this.settings.ReconnectionAttempts)); + reconnections.Add(new Reconnection(addrs, _settings.ReconnectionAttempts)); } }); }; @@ -218,7 +212,7 @@ public async Task RunAsync(ILocalPeer localPeer, Settings? settings = null, Canc { while (!token.IsCancellationRequested) { - await Task.Delay(this.settings.HeartbeatInterval); + await Task.Delay(_settings.HeartbeatInterval); await Heartbeat(); } }, token); @@ -228,16 +222,20 @@ public async Task RunAsync(ILocalPeer localPeer, Settings? settings = null, Canc { while (!token.IsCancellationRequested) { - await Task.Delay(this.settings.ReconnectionPeriod); + await Task.Delay(_settings.ReconnectionPeriod); await Reconnect(token); } }, token); - await Task.Delay(Timeout.Infinite, token); - messageCache.Dispose(); - limboMessageCache.Dispose(); + logger?.LogInformation("Started"); + return Task.CompletedTask; } + public void Dispose() + { + _messageCache.Dispose(); + _limboMessageCache.Dispose(); + } private async Task Reconnect(CancellationToken token) { @@ -245,7 +243,7 @@ private async Task Reconnect(CancellationToken token) { try { - IRemotePeer remotePeer = await peer.DialAsync(rec.Addresses, token); + ISession remotePeer = await localPeer.DialAsync(rec.Addresses, token); await remotePeer.DialAsync(token); } catch @@ -265,13 +263,11 @@ public Task Heartbeat() { foreach (KeyValuePair> mesh in mesh) { - logger?.LogDebug($"MESH({localPeer!.Address.GetPeerId()}) {mesh.Key}: {mesh.Value.Count} ({mesh.Value})"); - if (mesh.Value.Count < settings.LowestDegree) + if (mesh.Value.Count < _settings.LowestDegree) { PeerId[] peersToGraft = gPeers[mesh.Key] - .Where(p => !mesh.Value.Contains(p) - && (peerState.GetValueOrDefault(p)?.Backoff.TryGetValue(mesh.Key, out DateTime backoff) != true || - backoff < DateTime.Now)).Take(settings.Degree - mesh.Value.Count).ToArray(); + .Where(p => !mesh.Value.Contains(p) && (peerState.GetValueOrDefault(p)?.Backoff.TryGetValue(mesh.Key, out DateTime backoff) != true || backoff < DateTime.Now)) + .Take(_settings.Degree - mesh.Value.Count).ToArray(); foreach (PeerId peerId in peersToGraft) { mesh.Value.Add(peerId); @@ -280,18 +276,21 @@ public Task Heartbeat() .Add(new ControlGraft { TopicID = mesh.Key }); } } - else if (mesh.Value.Count > settings.HighestDegree) + else if (mesh.Value.Count > _settings.HighestDegree) { - PeerId[] peerstoPrune = mesh.Value.Take(mesh.Value.Count - settings.HighestDegree).ToArray(); + PeerId[] peerstoPrune = mesh.Value.Take(mesh.Value.Count - _settings.HighestDegree).ToArray(); foreach (PeerId? peerId in peerstoPrune) { mesh.Value.Remove(peerId); - var prune = new ControlPrune { TopicID = mesh.Key, Backoff = 60 }; - prune.Peers.AddRange(mesh.Value.ToArray().Select(pid => (PeerId: pid, Record: store.GetPeerInfo(pid)?.SignedPeerRecord)).Where(pid => pid.Record is not null).Select(pid => new PeerInfo - { - PeerID = ByteString.CopyFrom(pid.PeerId.Bytes), - SignedPeerRecord = pid.Record, - })); + ControlPrune prune = new() { TopicID = mesh.Key, Backoff = 60 }; + prune.Peers.AddRange(mesh.Value.ToArray() + .Select(pid => (PeerId: pid, Record: _peerStore.GetPeerInfo(pid)?.SignedPeerRecord)) + .Where(pid => pid.Record is not null) + .Select(pid => new PeerInfo + { + PeerID = ByteString.CopyFrom(pid.PeerId.Bytes), + SignedPeerRecord = pid.Record, + })); peerMessages.GetOrAdd(peerId, _ => new Rpc()) .Ensure(r => r.Control.Prune) .Add(prune); @@ -301,14 +300,14 @@ public Task Heartbeat() foreach (string? fanoutTopic in fanout.Keys.ToArray()) { - if (fanoutLastPublished.GetOrAdd(fanoutTopic, _ => DateTime.Now).AddMilliseconds(settings.FanoutTtl) < DateTime.Now) + if (fanoutLastPublished.GetOrAdd(fanoutTopic, _ => DateTime.Now).AddMilliseconds(_settings.FanoutTtl) < DateTime.Now) { fanout.Remove(fanoutTopic, out _); fanoutLastPublished.Remove(fanoutTopic, out _); } else { - int peerCountToAdd = settings.Degree - fanout[fanoutTopic].Count; + int peerCountToAdd = _settings.Degree - fanout[fanoutTopic].Count; if (peerCountToAdd > 0) { foreach (PeerId? peerId in gPeers[fanoutTopic].Where(p => !fanout[fanoutTopic].Contains(p)).Take(peerCountToAdd)) @@ -319,7 +318,7 @@ public Task Heartbeat() } } - IEnumerable> msgs = messageCache.ToList().GroupBy(m => m.Topic); + IEnumerable> msgs = _messageCache.ToList().GroupBy(m => m.Topic); foreach (string? topic in gPeers.Keys.Concat(fanout.Keys).Distinct().ToArray()) { @@ -327,9 +326,9 @@ public Task Heartbeat() if (msgsInTopic is not null) { ControlIHave ihave = new() { TopicID = topic }; - ihave.MessageIDs.AddRange(msgsInTopic.Select(m => ByteString.CopyFrom(settings.GetMessageId(m).Bytes))); + ihave.MessageIDs.AddRange(msgsInTopic.Select(m => ByteString.CopyFrom(_settings.GetMessageId(m).Bytes))); - foreach (PeerId? peer in gPeers[topic].Where(p => !mesh[topic].Contains(p) && !fanout[topic].Contains(p)).Take(settings.LazyDegree)) + foreach (PeerId? peer in gPeers[topic].Where(p => !mesh[topic].Contains(p) && !fanout[topic].Contains(p)).Take(_settings.LazyDegree)) { peerMessages.GetOrAdd(peer, _ => new Rpc()) .Ensure(r => r.Control.Ihave).Add(ihave); @@ -355,7 +354,7 @@ internal CancellationToken OutboundConnection(Multiaddress addr, string protocol return Canceled; } - PubsubPeer peer = peerState.GetOrAdd(peerId, (id) => new PubsubPeer(peerId, protocolId) { Address = addr, SendRpc = sendRpc, InititatedBy = ConnectionInitiation.Local }); + PubsubPeer peer = peerState.GetOrAdd(peerId, (id) => new PubsubPeer(peerId, protocolId, logger) { Address = addr, SendRpc = sendRpc, InititatedBy = ConnectionInitiation.Local }); lock (peer) { @@ -367,337 +366,100 @@ internal CancellationToken OutboundConnection(Multiaddress addr, string protocol } else { + logger?.LogDebug("Outbound, rpc set for {peerId}, cancelling", peerId); return Canceled; } } - } - - dialTask.ContinueWith(t => - { - peerState.GetValueOrDefault(peerId)?.TokenSource.Cancel(); - peerState.TryRemove(peerId, out _); - foreach (var topicPeers in fPeers) - { - topicPeers.Value.Remove(peerId); - } - foreach (var topicPeers in gPeers) - { - topicPeers.Value.Remove(peerId); - } - foreach (var topicPeers in fanout) - { - topicPeers.Value.Remove(peerId); - } - foreach (var topicPeers in mesh) - { - topicPeers.Value.Remove(peerId); - } - reconnections.Add(new Reconnection([addr], settings.ReconnectionAttempts)); - }); - - string[] topics = topicState.Keys.ToArray(); - - if (topics.Any()) - { - logger?.LogDebug("Topics sent to {peerId}: {topics}", peerId, string.Join(",", topics)); - Rpc helloMessage = new Rpc().WithTopics(topics, []); - peer.Send(helloMessage); - } - logger?.LogDebug("Outbound {peerId}", peerId); - return peer.TokenSource.Token; - } + logger?.LogDebug("Outbound, let's dial {peerId} via remotely initiated connection", peerId); - internal CancellationToken InboundConnection(Multiaddress addr, string protocolId, Task listTask, Task dialTask, Func subDial) - { - PeerId? peerId = addr.GetPeerId(); - - if (peerId is null || peerId == localPeer!.Identity.PeerId) - { - return Canceled; - } - - PubsubPeer? newPeer = null; - PubsubPeer existingPeer = peerState.GetOrAdd(peerId, (id) => newPeer = new PubsubPeer(peerId, protocolId) { Address = addr, InititatedBy = ConnectionInitiation.Remote }); - if (newPeer is not null) - { - logger?.LogDebug("Inbound, let's dial {peerId} via remotely initiated connection", peerId); - listTask.ContinueWith(t => + dialTask.ContinueWith(t => { peerState.GetValueOrDefault(peerId)?.TokenSource.Cancel(); peerState.TryRemove(peerId, out _); - foreach (var topicPeers in fPeers) + foreach (KeyValuePair> topicPeers in fPeers) { topicPeers.Value.Remove(peerId); } - foreach (var topicPeers in gPeers) + foreach (KeyValuePair> topicPeers in gPeers) { topicPeers.Value.Remove(peerId); } - foreach (var topicPeers in fanout) + foreach (KeyValuePair> topicPeers in fanout) { topicPeers.Value.Remove(peerId); } - foreach (var topicPeers in mesh) + foreach (KeyValuePair> topicPeers in mesh) { topicPeers.Value.Remove(peerId); } - reconnections.Add(new Reconnection([addr], settings.ReconnectionAttempts)); + reconnections.Add(new Reconnection([addr], _settings.ReconnectionAttempts)); }); - subDial(); - - return newPeer.TokenSource.Token; - } - else - { - return existingPeer.TokenSource.Token; - } - } + string[] topics = topicState.Keys.ToArray(); - internal async Task OnRpc(PeerId peerId, Rpc rpc) - { - try - { - ConcurrentDictionary peerMessages = new(); - lock (this) + if (topics.Any()) { - if (rpc.Subscriptions.Any()) - { - foreach (Rpc.Types.SubOpts? sub in rpc.Subscriptions) - { - var state = peerState.GetValueOrDefault(peerId); - if (state is null) - { - return; - } - if (sub.Subscribe) - { - if (state.IsGossipSub) - { - gPeers.GetOrAdd(sub.Topicid, _ => []).Add(peerId); - } - else if (state.IsFloodSub) - { - fPeers.GetOrAdd(sub.Topicid, _ => []).Add(peerId); - } - } - else - { - if (state.IsGossipSub) - { - gPeers.GetOrAdd(sub.Topicid, _ => []).Remove(peerId); - if (mesh.ContainsKey(sub.Topicid)) - { - mesh[sub.Topicid].Remove(peerId); - } - if (fanout.ContainsKey(sub.Topicid)) - { - fanout[sub.Topicid].Remove(peerId); - } - } - else if (state.IsFloodSub) - { - fPeers.GetOrAdd(sub.Topicid, _ => []).Remove(peerId); - } - } - } - } + logger?.LogDebug("Topics sent to {peerId}: {topics}", peerId, string.Join(",", topics)); - if (rpc.Publish.Any()) - { - if (rpc.Publish.Any()) - { - logger?.LogDebug($"Messages received: {rpc.Publish.Select(settings.GetMessageId).Count(messageId => limboMessageCache.Contains(messageId) || messageCache!.Contains(messageId))}/{rpc.Publish.Count}: {rpc.Publish.Count}"); - } - - foreach (Message? message in rpc.Publish) - { - MessageId messageId = settings.GetMessageId(message); - - if (limboMessageCache.Contains(messageId) || messageCache!.Contains(messageId)) - { - continue; - } - - switch (VerifyMessage?.Invoke(message)) - { - case MessageValidity.Rejected: - case MessageValidity.Ignored: - limboMessageCache.Add(messageId, message); - continue; - case MessageValidity.Trottled: - continue; - } + Rpc helloMessage = new Rpc().WithTopics(topics, []); + peer.Send(helloMessage); + } - if (!message.VerifySignature(settings.DefaultSignaturePolicy)) - { - limboMessageCache!.Add(messageId, message); - continue; - } + logger?.LogDebug("Outbound {peerId}", peerId); + return peer.TokenSource.Token; + } + } - messageCache.Add(messageId, message); + internal CancellationToken InboundConnection(Multiaddress addr, string protocolId, Task listTask, Task dialTask, Func subDial) + { + PeerId? peerId = addr.GetPeerId(); - PeerId author = new(message.From.ToArray()); - OnMessage?.Invoke(message.Topic, message.Data.ToByteArray()); + if (peerId is null || peerId == localPeer!.Identity.PeerId) + { + return Canceled; + } - if (fPeers.TryGetValue(message.Topic, out HashSet? topicPeers)) - { - foreach (PeerId peer in topicPeers) - { - if (peer == author || peer == peerId) - { - continue; - } - peerMessages.GetOrAdd(peer, _ => new Rpc()).Publish.Add(message); - } - } - if (fPeers.TryGetValue(message.Topic, out topicPeers)) - { - foreach (PeerId peer in mesh[message.Topic]) - { - if (peer == author || peer == peerId) - { - continue; - } - peerMessages.GetOrAdd(peer, _ => new Rpc()).Publish.Add(message); - } - } - } - } + PubsubPeer? newPeer = null; + PubsubPeer existingPeer = peerState.GetOrAdd(peerId, (id) => newPeer = new PubsubPeer(peerId, protocolId, logger) { Address = addr, InititatedBy = ConnectionInitiation.Remote }); + lock (existingPeer) + { - if (rpc.Control is not null) + if (newPeer is not null) + { + logger?.LogDebug("Inbound, let's dial {peerId} via remotely initiated connection", peerId); + listTask.ContinueWith(t => { - if (rpc.Control.Graft.Any()) + peerState.GetValueOrDefault(peerId)?.TokenSource.Cancel(); + peerState.TryRemove(peerId, out _); + foreach (KeyValuePair> topicPeers in fPeers) { - foreach (ControlGraft? graft in rpc.Control.Graft) - { - if (!topicState.ContainsKey(graft.TopicID)) - { - peerMessages.GetOrAdd(peerId, _ => new Rpc()) - .Ensure(r => r.Control.Prune) - .Add(new ControlPrune { TopicID = graft.TopicID }); - } - else - { - HashSet topicMesh = mesh[graft.TopicID]; - - if (topicMesh.Count >= settings.HighestDegree) - { - ControlPrune prune = new() { TopicID = graft.TopicID }; - - if (peerState.TryGetValue(peerId, out PubsubPeer? state) && state.IsGossipSub && state.Protocol >= PubsubPeer.PubsubProtocol.GossipsubV11) - { - state.Backoff[prune.TopicID] = DateTime.Now.AddSeconds(prune.Backoff == 0 ? 60 : prune.Backoff); - prune.Peers.AddRange(topicMesh.ToArray().Select(pid => (PeerId: pid, Record: store.GetPeerInfo(pid)?.SignedPeerRecord)).Where(pid => pid.Record is not null).Select(pid => new PeerInfo - { - PeerID = ByteString.CopyFrom(pid.PeerId.Bytes), - SignedPeerRecord = pid.Record, - })); - } - - peerMessages.GetOrAdd(peerId, _ => new Rpc()) - .Ensure(r => r.Control.Prune) - .Add(prune); - } - else - { - topicMesh.Add(peerId); - peerMessages.GetOrAdd(peerId, _ => new Rpc()) - .Ensure(r => r.Control.Graft) - .Add(new ControlGraft { TopicID = graft.TopicID }); - } - } - } + topicPeers.Value.Remove(peerId); } - - if (rpc.Control.Prune.Any()) + foreach (KeyValuePair> topicPeers in gPeers) { - foreach (ControlPrune? prune in rpc.Control.Prune) - { - if (topicState.ContainsKey(prune.TopicID) && mesh[prune.TopicID].Contains(peerId)) - { - if (peerState.TryGetValue(peerId, out PubsubPeer? state)) - { - state.Backoff[prune.TopicID] = DateTime.Now.AddSeconds(prune.Backoff == 0 ? 60 : prune.Backoff); - } - mesh[prune.TopicID].Remove(peerId); - peerMessages.GetOrAdd(peerId, _ => new Rpc()) - .Ensure(r => r.Control.Prune) - .Add(new ControlPrune { TopicID = prune.TopicID }); - - foreach (var peer in prune.Peers) - { - // TODO verify payload type, signature, etc - // TODO check if it's working - reconnections.Add(new Reconnection(PeerRecord.Parser.ParseFrom(SignedEnvelope.Parser.ParseFrom(peer.SignedPeerRecord).Payload).Addresses.Select(ai => Multiaddress.Decode(ai.Multiaddr.ToByteArray())).ToArray(), 5)); - } - } - } + topicPeers.Value.Remove(peerId); } - - if (rpc.Control.Ihave.Any()) + foreach (KeyValuePair> topicPeers in fanout) { - List messageIds = new(); - - foreach (ControlIHave? ihave in rpc.Control.Ihave - .Where(iw => topicState.ContainsKey(iw.TopicID))) - { - messageIds.AddRange(ihave.MessageIDs.Select(m => new MessageId(m.ToByteArray())) - .Where(mid => !messageCache.Contains(mid))); - } - - if (messageIds.Any()) - { - ControlIWant ciw = new(); - foreach (MessageId mId in messageIds) - { - ciw.MessageIDs.Add(ByteString.CopyFrom(mId.Bytes)); - } - peerMessages.GetOrAdd(peerId, _ => new Rpc()) - .Ensure(r => r.Control.Iwant) - .Add(ciw); - } + topicPeers.Value.Remove(peerId); } - - if (rpc.Control.Iwant.Any()) + foreach (KeyValuePair> topicPeers in mesh) { - IEnumerable messageIds = rpc.Control.Iwant.SelectMany(iw => iw.MessageIDs).Select(m => new MessageId(m.ToByteArray())); - List messages = new(); - foreach (MessageId? mId in messageIds) - { - Message message = messageCache.Get(mId); - if (message != default) - { - messages.Add(message); - } - } - if (messages.Any()) - { - peerMessages.GetOrAdd(peerId, _ => new Rpc()) - .Publish.AddRange(messages); - } + topicPeers.Value.Remove(peerId); } + reconnections.Add(new Reconnection([addr], _settings.ReconnectionAttempts)); + }); - if (rpc.Control.Idontwant.Any()) - { - foreach (MessageId messageId in rpc.Control.Iwant.SelectMany(iw => iw.MessageIDs).Select(m => new MessageId(m.ToByteArray())).Take(settings.MaxIdontwantMessages)) - { - dontWantMessages.Add((peerId, messageId)); - } - } - } + subDial(); + return newPeer.TokenSource.Token; } - foreach (KeyValuePair peerMessage in peerMessages) + else { - peerState.GetValueOrDefault(peerMessage.Key)?.Send(peerMessage.Value); + return existingPeer.TokenSource.Token; } } - catch (Exception ex) - { - logger?.LogError("Exception during rpc handling: {exception}", ex); - } } } diff --git a/src/libp2p/Libp2p.Protocols.Pubsub/RpcExtensions.cs b/src/libp2p/Libp2p.Protocols.Pubsub/RpcExtensions.cs index fbbd1003..c2a5a90d 100644 --- a/src/libp2p/Libp2p.Protocols.Pubsub/RpcExtensions.cs +++ b/src/libp2p/Libp2p.Protocols.Pubsub/RpcExtensions.cs @@ -45,9 +45,9 @@ public static Rpc WithTopics(this Rpc rpc, IEnumerable addTopics, IEnume return rpc; } - public static bool VerifySignature(this Message message, Settings.SignaturePolicy signaturePolicy) + public static bool VerifySignature(this Message message, PubsubSettings.SignaturePolicy signaturePolicy) { - if (signaturePolicy is Settings.SignaturePolicy.StrictNoSign) + if (signaturePolicy is PubsubSettings.SignaturePolicy.StrictNoSign) { return message.Signature.IsEmpty; } diff --git a/src/libp2p/Libp2p.Protocols.Pubsub/TtlCache.cs b/src/libp2p/Libp2p.Protocols.Pubsub/TtlCache.cs index e8d6edef..1c777b74 100644 --- a/src/libp2p/Libp2p.Protocols.Pubsub/TtlCache.cs +++ b/src/libp2p/Libp2p.Protocols.Pubsub/TtlCache.cs @@ -13,7 +13,7 @@ private struct CachedItem public DateTimeOffset ValidTill { get; set; } } - private readonly SortedDictionary items = new(); + private readonly SortedDictionary items = []; private bool isDisposed; public TtlCache(int ttl) diff --git a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests.csproj b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests.csproj new file mode 100644 index 00000000..a0a90143 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests.csproj @@ -0,0 +1,32 @@ + + + + enable + enable + + + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + + diff --git a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests/NetworkDiscoveryTests.cs b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests/NetworkDiscoveryTests.cs new file mode 100644 index 00000000..19dd6c22 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests/NetworkDiscoveryTests.cs @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Nethermind.Libp2p.Core.Discovery; +using NUnit.Framework; + +namespace Libp2p.Protocols.PubsubPeerDiscovery.E2eTests; + +public class NetworkDiscoveryTests +{ + [Test] + public async Task Test_NetworkDiscoveredByEveryPeer() + { + string commonTopic = "test"; + + int totalCount = 2; + using PubsubDiscoveryE2eTestSetup test = new(); + + await test.AddPeersAsync(totalCount); + test.Subscribe(commonTopic); + foreach ((_, PeerStore peerStore) in test.PeerStores.Skip(1)) + { + peerStore.Discover(test.Peers[0].ListenAddresses.ToArray()); + } + + + await test.WaitForFullMeshAsync(commonTopic); + } +} diff --git a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests/PubsubDiscoveryE2eTestSetup.cs b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests/PubsubDiscoveryE2eTestSetup.cs new file mode 100644 index 00000000..560020f4 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.E2eTests/PubsubDiscoveryE2eTestSetup.cs @@ -0,0 +1,36 @@ +using Libp2p.E2eTests; +using Libp2p.Protocols.Pubsub.E2eTests; +using Microsoft.Extensions.DependencyInjection; +using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Protocols; +using Nethermind.Libp2p.Protocols.PubsubPeerDiscovery; + +namespace Libp2p.Protocols.PubsubPeerDiscovery.E2eTests; + +public class PubsubDiscoveryE2eTestSetup : PubsubE2eTestSetup +{ + public PubsubPeerDiscoverySettings DefaultDiscoverySettings { get; set; } = new PubsubPeerDiscoverySettings { Interval = 300 }; + + public Dictionary Discovery { get; } = []; + + protected override IPeerFactoryBuilder ConfigureLibp2p(ILibp2pPeerFactoryBuilder builder) + { + return base.ConfigureLibp2p(builder) + .AddAppLayerProtocol(); + } + + protected override IServiceCollection ConfigureServices(IServiceCollection col) + { + return base.ConfigureServices(col) + .AddSingleton(new PubsubPeerDiscoverySettings()) + .AddSingleton(); + } + + protected override void AddAt(int index) + { + base.AddAt(index); + Discovery[index] = new PubsubPeerDiscoveryProtocol(Routers[index], PeerStores[index], DefaultDiscoverySettings, Peers[index], loggerFactory); + + _ = Discovery[index].StartDiscoveryAsync(Peers[index].ListenAddresses, Token); + } +} diff --git a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/MultistreamProtocolTests.cs b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/MultistreamProtocolTests.cs deleted file mode 100644 index e9f98a90..00000000 --- a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/MultistreamProtocolTests.cs +++ /dev/null @@ -1,136 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Nethermind.Libp2p.Core.Discovery; -using Nethermind.Libp2p.Core.TestsBase.E2e; -using Nethermind.Libp2p.Protocols.Pubsub; -using NUnit.Framework.Internal; - -namespace Nethermind.Libp2p.Protocols.PubsubPeerDiscovery.Tests; - -[TestFixture, Ignore("No support of time mock yet")] -[Parallelizable(scope: ParallelScope.All)] -public class MultistreamProtocolTests -{ - [Test, CancelAfter(5000)] - public async Task Test_PeersConnect() - { - IPeerFactory peerFactory = new TestBuilder().Build(); - ChannelBus commonBus = new(); - - ServiceProvider sp1 = new ServiceCollection() - .AddSingleton(sp => new TestBuilder(commonBus, sp).AddAppLayerProtocol()) - .AddSingleton(sp => new TestContextLoggerFactory()) - .AddSingleton() - .AddSingleton() - .AddSingleton(sp => sp.GetService()!.Build()) - .BuildServiceProvider(); - - - ServiceProvider sp2 = new ServiceCollection() - .AddSingleton(sp => new TestBuilder(commonBus, sp).AddAppLayerProtocol()) - .AddSingleton(sp => new TestContextLoggerFactory()) - .AddSingleton() - .AddSingleton() - .AddSingleton(sp => sp.GetService()!.Build()) - .BuildServiceProvider(); - - - ILocalPeer peerA = sp1.GetService()!.Create(TestPeers.Identity(1)); - await peerA.ListenAsync(TestPeers.Multiaddr(1)); - ILocalPeer peerB = sp2.GetService()!.Create(TestPeers.Identity(2)); - await peerB.ListenAsync(TestPeers.Multiaddr(2)); - - IRemotePeer remotePeerB = await peerA.DialAsync(peerB.Address); - await remotePeerB.DialAsync(); - } - - [Test] - public async Task Test_ConnectionEstablished_AfterHandshake() - { - int totalCount = 5; - TestContextLoggerFactory fac = new(); - // There is common communication point - ChannelBus commonBus = new(fac); - ILocalPeer[] peers = new ILocalPeer[totalCount]; - PeerStore[] peerStores = new PeerStore[totalCount]; - PubsubRouter[] routers = new PubsubRouter[totalCount]; - - for (int i = 0; i < totalCount; i++) - { - // But we create a seprate setup for every peer - ServiceProvider sp = new ServiceCollection() - .AddSingleton(sp => new TestBuilder(commonBus, sp).AddAppLayerProtocol()) - .AddSingleton(sp => fac) - .AddSingleton() - .AddSingleton() - .AddSingleton(sp => sp.GetService()!.Build()) - .BuildServiceProvider(); - - IPeerFactory peerFactory = sp.GetService()!; - ILocalPeer peer = peers[i] = peerFactory.Create(TestPeers.Identity(i)); - PubsubRouter router = routers[i] = sp.GetService()!; - PeerStore peerStore = sp.GetService()!; - PubsubPeerDiscoveryProtocol disc = new(router, peerStore, new PubsubPeerDiscoverySettings() { Interval = 300 }, peer); - await peer.ListenAsync(TestPeers.Multiaddr(i)); - _ = router.RunAsync(peer); - peerStores[i] = peerStore; - _ = disc.DiscoverAsync(peers[i].Address); - } - - await Task.Delay(1000); - - for (int i = 0; i < peers.Length; i++) - { - peerStores[i].Discover([peers[(i + 1) % totalCount].Address]); - } - - await Task.Delay(30000); - - foreach (var router in routers) - { - Assert.That(((IRoutingStateContainer)router).ConnectedPeers.Count, Is.EqualTo(totalCount - 1)); - } - } - - [Test, CancelAfter(5000)] - public async Task Test_ConnectionEstablished_AfterHandshak3e() - { - int totalCount = 5; - - PubsubTestSetup setup = new(); - Dictionary discoveries = []; - - await setup.AddAsync(totalCount); - - // discover in circle - for (int i = 0; i < setup.Peers.Count; i++) - { - setup.PeerStores[i].Discover([setup.Peers[(i + 1) % setup.Peers.Count].Address]); - } - - for (int i = 0; i < setup.Peers.Count; i++) - { - discoveries[i] = new(setup.Routers[i], setup.PeerStores[i], new PubsubPeerDiscoverySettings() { Interval = int.MaxValue }, setup.Peers[i]); - _ = discoveries[i].DiscoverAsync(setup.Peers[i].Address); - } - - await Task.Delay(100); - - await setup.Heartbeat(); - await setup.Heartbeat(); - await setup.Heartbeat(); - - for (int i = 0; i < setup.Peers.Count; i++) - { - discoveries[i].BroadcastPeerInfo(); - } - - foreach (var router in setup.Routers.Values) - { - Assert.That(((IRoutingStateContainer)router).ConnectedPeers.Count, Is.EqualTo(totalCount - 1)); - } - } -} diff --git a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/PubSubTestSetup.cs b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/PubSubTestSetup.cs deleted file mode 100644 index 00e24b7f..00000000 --- a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/PubSubTestSetup.cs +++ /dev/null @@ -1,57 +0,0 @@ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Nethermind.Libp2p.Core.Discovery; -using Nethermind.Libp2p.Core.TestsBase.E2e; -using Nethermind.Libp2p.Protocols.Pubsub; - -namespace Nethermind.Libp2p.Protocols.PubsubPeerDiscovery.Tests; - -class PubsubTestSetup -{ - static TestContextLoggerFactory fac = new TestContextLoggerFactory(); - - public ChannelBus CommonBus { get; } = new(fac); - public Dictionary Peers { get; } = new(); - public Dictionary PeerStores { get; } = new(); - public Dictionary Routers { get; } = new(); - - public async Task AddAsync(int count) - { - int initialCount = Peers.Count; - // There is common communication point - - for (int i = initialCount; i < initialCount + count; i++) - { - // But we create a seprate setup for every peer - Settings settings = new Settings - { - HeartbeatInterval = int.MaxValue, - }; - - ServiceProvider sp = new ServiceCollection() - .AddSingleton(sp => new TestBuilder(CommonBus, sp).AddAppLayerProtocol()) - .AddSingleton((Func)(sp => fac)) - .AddSingleton() - .AddSingleton(settings) - .AddSingleton() - .AddSingleton(sp => sp.GetService()!.Build()) - .BuildServiceProvider(); - - IPeerFactory peerFactory = sp.GetService()!; - ILocalPeer peer = Peers[i] = peerFactory.Create(TestPeers.Identity(i)); - PubsubRouter router = Routers[i] = sp.GetService()!; - PeerStore peerStore = sp.GetService()!; - await peer.ListenAsync(TestPeers.Multiaddr(i)); - _ = router.RunAsync(peer); - PeerStores[i] = peerStore; - } - } - - internal async Task Heartbeat() - { - foreach (PubsubRouter router in Routers.Values) - { - await router.Heartbeat(); - } - } -} diff --git a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/Usings.cs b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/Usings.cs deleted file mode 100644 index 6c255752..00000000 --- a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery.Tests/Usings.cs +++ /dev/null @@ -1,6 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -global using Nethermind.Libp2p.Core; -global using Nethermind.Libp2p.Core.TestsBase; -global using NUnit.Framework; diff --git a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery/PubsubPeerDiscoveryProtocol.cs b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery/PubsubPeerDiscoveryProtocol.cs index 6e0cc306..591f05ea 100644 --- a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery/PubsubPeerDiscoveryProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery/PubsubPeerDiscoveryProtocol.cs @@ -1,28 +1,25 @@ // SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited // SPDX-License-Identifier: MIT +using Nethermind.Libp2p.Protocols.PubsubPeerDiscovery; using Nethermind.Libp2p.Protocols.PubsubPeerDiscovery.Dto; namespace Nethermind.Libp2p.Protocols; -public class PubsubPeerDiscoverySettings -{ - public string[] Topics { get; set; } = ["_peer-discovery._p2p._pubsub"]; - public int Interval { get; set; } = 10_000; - public bool ListenOnly { get; set; } -} - public class PubsubPeerDiscoveryProtocol(PubsubRouter pubSubRouter, PeerStore peerStore, PubsubPeerDiscoverySettings settings, ILocalPeer peer, ILoggerFactory? loggerFactory = null) : IDiscoveryProtocol { private readonly PubsubRouter _pubSubRouter = pubSubRouter; - private Multiaddress? _localPeerAddr; + private IReadOnlyList? _localPeerAddrs; + private PeerId? localPeerId; private ITopic[]? topics; private readonly PubsubPeerDiscoverySettings _settings = settings; - private ILogger? logger = loggerFactory?.CreateLogger(); + private readonly ILogger? logger = loggerFactory?.CreateLogger(); - public async Task DiscoverAsync(Multiaddress localPeerAddr, CancellationToken token = default) + public Task StartDiscoveryAsync(IReadOnlyList localPeerAddrs, CancellationToken token = default) { - _localPeerAddr = localPeerAddr; + _localPeerAddrs = localPeerAddrs; + localPeerId = localPeerAddrs.First().GetPeerId(); + topics = _settings.Topics.Select(topic => { ITopic subscription = _pubSubRouter.GetTopic(topic); @@ -32,7 +29,7 @@ public async Task DiscoverAsync(Multiaddress localPeerAddr, CancellationToken to token.Register(() => { - foreach (var topic in topics) + foreach (ITopic topic in topics) { topic.Unsubscribe(); } @@ -40,11 +37,18 @@ public async Task DiscoverAsync(Multiaddress localPeerAddr, CancellationToken to if (!_settings.ListenOnly) { - while (!token.IsCancellationRequested) - { - await Task.Delay(_settings.Interval, token); - BroadcastPeerInfo(); - } + _ = RunAsync(token); + } + + return Task.CompletedTask; + } + + private async Task RunAsync(CancellationToken token) + { + while (!token.IsCancellationRequested) + { + await Task.Delay(_settings.Interval, token); + BroadcastPeerInfo(); } } @@ -52,15 +56,15 @@ internal void BroadcastPeerInfo() { if (topics is null) { - throw new NullReferenceException($"{nameof(topics)} should be previously set in ${nameof(DiscoverAsync)}"); + throw new NullReferenceException($"{nameof(topics)} should be previously set in ${nameof(StartDiscoveryAsync)}"); } - foreach (var topic in topics) + foreach (ITopic topic in topics) { topic.Publish(new Peer { PublicKey = peer.Identity.PublicKey.ToByteString(), - Addrs = { ByteString.CopyFrom(peer.Address.ToBytes()) }, + Addrs = { peer.ListenAddresses.Select(a => ByteString.CopyFrom(a.ToBytes())) }, }); } } @@ -72,11 +76,11 @@ private void OnPeerMessage(byte[] msg) Peer peer = Peer.Parser.ParseFrom(msg); Multiaddress[] addrs = [.. peer.Addrs.Select(a => Multiaddress.Decode(a.ToByteArray()))]; PeerId? remotePeerId = addrs.FirstOrDefault()?.GetPeerId(); - if (remotePeerId is not null && remotePeerId != _localPeerAddr?.GetPeerId()!) + if (remotePeerId is not null && remotePeerId != localPeerId!) { peerStore.Discover(addrs); } - logger?.LogDebug($"{_localPeerAddr}: New peer discovered {peer}"); + logger?.LogDebug($"New peer discovered {peer}"); } catch (Exception ex) { diff --git a/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery/PubsubPeerDiscoverySettings.cs b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery/PubsubPeerDiscoverySettings.cs new file mode 100644 index 00000000..a948182c --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.PubsubPeerDiscovery/PubsubPeerDiscoverySettings.cs @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2024 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +namespace Nethermind.Libp2p.Protocols.PubsubPeerDiscovery; + +public class PubsubPeerDiscoverySettings +{ + public string[] Topics { get; set; } = ["_peer-discovery._p2p._pubsub"]; + public int Interval { get; set; } = 10_000; + public bool ListenOnly { get; set; } +} + diff --git a/src/libp2p/Libp2p.Protocols.Quic.Tests/Libp2p.Protocols.Quic.Tests.csproj b/src/libp2p/Libp2p.Protocols.Quic.Tests/Libp2p.Protocols.Quic.Tests.csproj index bde9026d..c7fd46c0 100644 --- a/src/libp2p/Libp2p.Protocols.Quic.Tests/Libp2p.Protocols.Quic.Tests.csproj +++ b/src/libp2p/Libp2p.Protocols.Quic.Tests/Libp2p.Protocols.Quic.Tests.csproj @@ -3,7 +3,6 @@ enable enable - true Nethermind.$(MSBuildProjectName.Replace(" ", "_")) Nethermind.$(MSBuildProjectName) diff --git a/src/libp2p/Libp2p.Protocols.Quic.Tests/ProtocolTests.cs b/src/libp2p/Libp2p.Protocols.Quic.Tests/ProtocolTests.cs new file mode 100644 index 00000000..0b72f45d --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Quic.Tests/ProtocolTests.cs @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Nethermind.Libp2p.Core; + +namespace Nethermind.Libp2p.Protocols.Quic.Tests; + +public class ProtocolTests +{ + [Test] + public async Task Test_CreateProtocol() + { + CancellationTokenSource cts = new(); + QuicProtocol proto = new(); + _ = new QuicProtocol().ListenAsync(new TransportContext(new LocalPeer(new Identity(), new Core.Discovery.PeerStore(), new ProtocolStackSettings()), new ProtocolRef(proto), true), "/ip4/127.0.0.1/udp/0", cts.Token); + await Task.Delay(1000); + cts.Cancel(); + } +} diff --git a/src/libp2p/Libp2p.Protocols.Quic/Libp2p.Protocols.Quic.csproj b/src/libp2p/Libp2p.Protocols.Quic/Libp2p.Protocols.Quic.csproj index 75536ddd..b4bd9659 100644 --- a/src/libp2p/Libp2p.Protocols.Quic/Libp2p.Protocols.Quic.csproj +++ b/src/libp2p/Libp2p.Protocols.Quic/Libp2p.Protocols.Quic.csproj @@ -7,6 +7,7 @@ latest Nethermind.$(MSBuildProjectName) Nethermind.$(MSBuildProjectName.Replace(" ", "_")) + true diff --git a/src/libp2p/Libp2p.Protocols.Quic/QuicProtocol.cs b/src/libp2p/Libp2p.Protocols.Quic/QuicProtocol.cs index 9bf3a57b..28f853bd 100644 --- a/src/libp2p/Libp2p.Protocols.Quic/QuicProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.Quic/QuicProtocol.cs @@ -5,60 +5,50 @@ using Multiformats.Address; using Multiformats.Address.Protocols; using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Core.Utils; using Nethermind.Libp2p.Protocols.Quic; using System.Buffers; using System.Net; using System.Net.Quic; using System.Net.Security; using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Runtime.Versioning; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; namespace Nethermind.Libp2p.Protocols; #pragma warning disable CA1416 // Do not inform about platform compatibility +#pragma warning disable CA2252 // Do not inform about platform compatibility /// /// https://github.com/libp2p/specs/blob/master/quic/README.md /// -[RequiresPreviewFeatures] -public class QuicProtocol : IProtocol +public class QuicProtocol(ILoggerFactory? loggerFactory = null) : ITransportProtocol { - private readonly ILogger? _logger; - private readonly ECDsa _sessionKey; + private readonly ILogger? _logger = loggerFactory?.CreateLogger(); + private readonly ECDsa _sessionKey = ECDsa.Create(); - public QuicProtocol(ILoggerFactory? loggerFactory = null) - { - _logger = loggerFactory?.CreateLogger(); - _sessionKey = ECDsa.Create(); - } - - private static readonly List protocols = new() - { + private static readonly List protocols = + [ new SslApplicationProtocol("libp2p"), // SslApplicationProtocol.Http3, // webtransport - }; + ]; public string Id => "quic-v1"; + public static Multiaddress[] GetDefaultAddresses(PeerId peerId) => IpHelper.GetListenerAddresses() + .Select(a => Multiaddress.Decode($"/{(a.AddressFamily is AddressFamily.InterNetwork ? "ip4" : "ip6")}/{a}/udp/0/quic-v1/p2p/{peerId}")).ToArray(); + public static bool IsAddressMatch(Multiaddress addr) => addr.Has(); - public async Task ListenAsync(IChannel signalingChannel, IChannelFactory? channelFactory, IPeerContext context) + public async Task ListenAsync(ITransportContext context, Multiaddress localAddr, CancellationToken token) { - if (channelFactory is null) - { - throw new ArgumentException($"The protocol requires {nameof(channelFactory)}"); - } - if (!QuicListener.IsSupported) { throw new NotSupportedException("QUIC is not supported, check for presence of libmsquic and support of TLS 1.3."); } - Multiaddress addr = context.LocalPeer.Address; - MultiaddressProtocol ipProtocol = addr.Has() ? addr.Get() : addr.Get(); + MultiaddressProtocol ipProtocol = localAddr.Has() ? localAddr.Get() : localAddr.Get(); IPAddress ipAddress = IPAddress.Parse(ipProtocol.ToString()); - int udpPort = int.Parse(addr.Get().ToString()); + int udpPort = int.Parse(localAddr.Get().ToString()); IPEndPoint localEndpoint = new(ipAddress, udpPort); @@ -70,8 +60,8 @@ public async Task ListenAsync(IChannel signalingChannel, IChannelFactory? channe ServerAuthenticationOptions = new SslServerAuthenticationOptions { ApplicationProtocols = protocols, - RemoteCertificateValidationCallback = (_, c, _, _) => VerifyRemoteCertificate(context.RemotePeer, c), - ServerCertificate = CertificateHelper.CertificateFromIdentity(_sessionKey, context.LocalPeer.Identity) + RemoteCertificateValidationCallback = (_, c, _, _) => true, + ServerCertificate = CertificateHelper.CertificateFromIdentity(_sessionKey, context.Peer.Identity) }, }; @@ -82,39 +72,25 @@ public async Task ListenAsync(IChannel signalingChannel, IChannelFactory? channe ConnectionOptionsCallback = (_, _, _) => ValueTask.FromResult(serverConnectionOptions) }); - var localEndPoint = new Multiaddress(); - // IP (4 or 6 is based on source address). - var strLocalEndpoint = listener.LocalEndPoint.Address.ToString(); - localEndPoint = addr.Has() ? localEndPoint.Add(strLocalEndpoint) : localEndPoint.Add(strLocalEndpoint); - - // UDP - localEndPoint = localEndPoint.Add(listener.LocalEndPoint.Port); - - // Set on context - context.LocalEndpoint = localEndPoint; - if (udpPort == 0) { - context.LocalPeer.Address = context.LocalPeer.Address - .ReplaceOrAdd(listener.LocalEndPoint.Port); + localAddr = localAddr.ReplaceOrAdd(listener.LocalEndPoint.Port); } + context.ListenerReady(localAddr); _logger?.ReadyToHandleConnections(); - context.ListenerReady(); - TaskAwaiter signalingWawaiter = signalingChannel.GetAwaiter(); - signalingWawaiter.OnCompleted(() => - { - listener.DisposeAsync(); - }); + token.Register(() => _ = listener.DisposeAsync()); - while (!signalingWawaiter.IsCompleted) + while (!token.IsCancellationRequested) { try { - QuicConnection connection = await listener.AcceptConnectionAsync(); - _ = ProcessStreams(connection, context.Fork(), channelFactory); + QuicConnection connection = await listener.AcceptConnectionAsync(token); + INewConnectionContext clientContext = context.CreateConnection(); + + _ = ProcessStreams(clientContext, connection, token).ContinueWith(t => clientContext.Dispose()); } catch (Exception ex) { @@ -124,37 +100,24 @@ public async Task ListenAsync(IChannel signalingChannel, IChannelFactory? channe } } - public async Task DialAsync(IChannel signalingChannel, IChannelFactory? channelFactory, IPeerContext context) + public async Task DialAsync(ITransportContext context, Multiaddress remoteAddr, CancellationToken token) { - if (channelFactory is null) - { - throw new ArgumentException($"The protocol requires {nameof(channelFactory)}"); - } - if (!QuicConnection.IsSupported) { throw new NotSupportedException("QUIC is not supported, check for presence of libmsquic and support of TLS 1.3."); } - Multiaddress addr = context.LocalPeer.Address; + Multiaddress addr = remoteAddr; bool isIp4 = addr.Has(); MultiaddressProtocol protocol = isIp4 ? addr.Get() : addr.Get(); + IPAddress ipAddress = IPAddress.Parse(protocol.ToString()); int udpPort = int.Parse(addr.Get().ToString()); - IPEndPoint localEndpoint = new(ipAddress, udpPort); - - addr = context.RemotePeer.Address; - isIp4 = addr.Has(); - protocol = isIp4 ? addr.Get() : addr.Get(); - ipAddress = IPAddress.Parse(protocol.ToString()!); - udpPort = int.Parse(addr.Get().ToString()!); - IPEndPoint remoteEndpoint = new(ipAddress, udpPort); QuicClientConnectionOptions clientConnectionOptions = new() { - LocalEndPoint = localEndpoint, DefaultStreamErrorCode = 0, // Protocol-dependent error code. DefaultCloseErrorCode = 1, // Protocol-dependent error code. MaxInboundUnidirectionalStreams = 256, @@ -163,8 +126,8 @@ public async Task DialAsync(IChannel signalingChannel, IChannelFactory? channelF { TargetHost = null, ApplicationProtocols = protocols, - RemoteCertificateValidationCallback = (_, c, _, _) => VerifyRemoteCertificate(context.RemotePeer, c), - ClientCertificates = new X509CertificateCollection { CertificateHelper.CertificateFromIdentity(_sessionKey, context.LocalPeer.Identity) }, + RemoteCertificateValidationCallback = (_, c, _, _) => VerifyRemoteCertificate(remoteAddr, c), + ClientCertificates = [CertificateHelper.CertificateFromIdentity(_sessionKey, context.Peer.Identity)], }, RemoteEndPoint = remoteEndpoint, }; @@ -172,71 +135,44 @@ public async Task DialAsync(IChannel signalingChannel, IChannelFactory? channelF QuicConnection connection = await QuicConnection.ConnectAsync(clientConnectionOptions); _logger?.Connected(connection.LocalEndPoint, connection.RemoteEndPoint); + INewConnectionContext connectionContext = context.CreateConnection(); - signalingChannel.GetAwaiter().OnCompleted(() => - { - connection.CloseAsync(0); - }); - - await ProcessStreams(connection, context, channelFactory); + token.Register(() => _ = connection.CloseAsync(0)); + await ProcessStreams(connectionContext, connection, token); } - private static bool VerifyRemoteCertificate(IPeer? remotePeer, X509Certificate certificate) => - CertificateHelper.ValidateCertificate(certificate as X509Certificate2, remotePeer?.Address.Get().ToString()); + private static bool VerifyRemoteCertificate(Multiaddress remoteAddr, X509Certificate certificate) => + CertificateHelper.ValidateCertificate(certificate as X509Certificate2, remoteAddr.Get().ToString()); - private async Task ProcessStreams(QuicConnection connection, IPeerContext context, IChannelFactory channelFactory, CancellationToken token = default) + private async Task ProcessStreams(INewConnectionContext context, QuicConnection connection, CancellationToken token = default) { _logger?.LogDebug("New connection to {remote}", connection.RemoteEndPoint); - bool isIP4 = connection.LocalEndPoint.AddressFamily == AddressFamily.InterNetwork; - - Multiaddress localEndPointMultiaddress = new(); - string strLocalEndpointAddress = connection.LocalEndPoint.Address.ToString(); - localEndPointMultiaddress = isIP4 ? localEndPointMultiaddress.Add(strLocalEndpointAddress) : localEndPointMultiaddress.Add(strLocalEndpointAddress); - localEndPointMultiaddress = localEndPointMultiaddress.Add(connection.LocalEndPoint.Port); - - context.LocalEndpoint = localEndPointMultiaddress; - - context.LocalPeer.Address = isIP4 ? context.LocalPeer.Address.ReplaceOrAdd(strLocalEndpointAddress) : context.LocalPeer.Address.ReplaceOrAdd(strLocalEndpointAddress); - - IPEndPoint remoteIpEndpoint = connection.RemoteEndPoint!; - isIP4 = remoteIpEndpoint.AddressFamily == AddressFamily.InterNetwork; - - Multiaddress remoteEndPointMultiaddress = new(); - string strRemoteEndpointAddress = remoteIpEndpoint.Address.ToString(); - remoteEndPointMultiaddress = isIP4 ? remoteEndPointMultiaddress.Add(strRemoteEndpointAddress) : remoteEndPointMultiaddress.Add(strRemoteEndpointAddress); - remoteEndPointMultiaddress = remoteEndPointMultiaddress.Add(remoteIpEndpoint.Port); - - context.RemoteEndpoint = remoteEndPointMultiaddress; - - context.Connected(context.RemotePeer); + using INewSessionContext session = context.UpgradeToSession(); _ = Task.Run(async () => { - foreach (IChannelRequest request in context.SubDialRequests.GetConsumingEnumerable()) + foreach (UpgradeOptions upgradeOptions in session.DialRequests) { QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); - IPeerContext dialContext = context.Fork(); - dialContext.SpecificProtocolRequest = request; - IChannel upChannel = channelFactory.SubDial(dialContext); - ExchangeData(stream, upChannel, request.CompletionSource); + IChannel upChannel = context.Upgrade(upgradeOptions with { ModeOverride = UpgradeModeOverride.Dial }); + ExchangeData(stream, upChannel); } }, token); while (!token.IsCancellationRequested) { QuicStream inboundStream = await connection.AcceptInboundStreamAsync(token); - IChannel upChannel = channelFactory.SubListen(context); - ExchangeData(inboundStream, upChannel, null); + IChannel upChannel = context.Upgrade(new UpgradeOptions { ModeOverride = UpgradeModeOverride.Listen }); + ExchangeData(inboundStream, upChannel); } } - private void ExchangeData(QuicStream stream, IChannel upChannel, TaskCompletionSource? tcs) + private void ExchangeData(QuicStream stream, IChannel upChannel) { upChannel.GetAwaiter().OnCompleted(() => { stream.Close(); - tcs?.SetResult(); _logger?.LogDebug("Stream {stream id}: Closed", stream.Id); }); diff --git a/src/libp2p/Libp2p.Protocols.Relay/Libp2p.Protocols.Relay.csproj b/src/libp2p/Libp2p.Protocols.Relay/Libp2p.Protocols.Relay.csproj new file mode 100644 index 00000000..c4e19369 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Relay/Libp2p.Protocols.Relay.csproj @@ -0,0 +1,33 @@ + + + + enable + enable + latest + Nethermind.$(MSBuildProjectName) + Nethermind.$(MSBuildProjectName.Replace(" ", "_")) + + + + README.md + libp2p network plaintext + + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + diff --git a/src/libp2p/Libp2p.Protocols.Relay/README.md b/src/libp2p/Libp2p.Protocols.Relay/README.md new file mode 100644 index 00000000..83546086 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Relay/README.md @@ -0,0 +1,3 @@ +# Relay protocols - WIP + +See the [libp2p spec](https://github.com/libp2p/specs/tree/master/relay) diff --git a/src/libp2p/Libp2p.Protocols.Relay/RelayHopProtocol.cs b/src/libp2p/Libp2p.Protocols.Relay/RelayHopProtocol.cs new file mode 100644 index 00000000..868e6862 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Relay/RelayHopProtocol.cs @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Nethermind.Libp2p.Core; + +namespace Nethermind.Libp2p.Protocols; + +public class RelayHopProtocol : ISessionProtocol +{ + public string Id => "/libp2p/circuit/relay/0.2.0/hop"; + + public Task DialAsync(IChannel downChannel, ISessionContext context) + { + throw new NotImplementedException(); + } + + public Task ListenAsync(IChannel downChannel, ISessionContext context) + { + throw new NotImplementedException(); + } +} diff --git a/src/libp2p/Libp2p.Protocols.Relay/RelayStopProtocol.cs b/src/libp2p/Libp2p.Protocols.Relay/RelayStopProtocol.cs new file mode 100644 index 00000000..b7dd9531 --- /dev/null +++ b/src/libp2p/Libp2p.Protocols.Relay/RelayStopProtocol.cs @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited +// SPDX-License-Identifier: MIT + +using Nethermind.Libp2p.Core; + +namespace Nethermind.Libp2p.Protocols; + +public class RelayStopProtocol : ISessionProtocol +{ + public string Id => "/libp2p/circuit/relay/0.2.0/stop"; + + public Task DialAsync(IChannel downChannel, ISessionContext context) + { + throw new NotImplementedException(); + } + + public Task ListenAsync(IChannel downChannel, ISessionContext context) + { + throw new NotImplementedException(); + } +} diff --git a/src/libp2p/Libp2p.Protocols.Tls.Tests/TlsProtocolTests.cs b/src/libp2p/Libp2p.Protocols.Tls.Tests/TlsProtocolTests.cs index ad20478f..d3c08ec8 100644 --- a/src/libp2p/Libp2p.Protocols.Tls.Tests/TlsProtocolTests.cs +++ b/src/libp2p/Libp2p.Protocols.Tls.Tests/TlsProtocolTests.cs @@ -1,6 +1,12 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT +using Microsoft.Extensions.Logging; +using Multiformats.Address; +using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Core.TestsBase; +using NSubstitute; + namespace Nethermind.Libp2p.Protocols.TLS.Tests; [TestFixture] @@ -8,46 +14,47 @@ namespace Nethermind.Libp2p.Protocols.TLS.Tests; public class TlsProtocolTests { [Test] - [Ignore("Infinite loop")] + [Ignore("Needs a fix on Windows")] public async Task Test_ConnectionEstablished_AfterHandshake() { // Arrange IChannel downChannel = new TestChannel(); IChannel downChannelFromProtocolPov = ((TestChannel)downChannel).Reverse(); IChannelFactory channelFactory = Substitute.For(); - IPeerContext peerContext = Substitute.For(); - IPeerContext listenerContext = Substitute.For(); + IConnectionContext listenerContext = Substitute.For(); ILoggerFactory loggerFactory = Substitute.For(); - TestChannel upChannel = new TestChannel(); - channelFactory.SubDial(Arg.Any(), Arg.Any()) - .Returns(upChannel); + TestChannel upChannel = new(); + channelFactory.Upgrade(Arg.Any()).Returns(upChannel); + + TestChannel listenerUpChannel = new(); + channelFactory.Upgrade(Arg.Any()).Returns(listenerUpChannel); + + IConnectionContext dialerContext = Substitute.For(); + dialerContext.Peer.Identity.Returns(TestPeers.Identity(1)); + dialerContext.Peer.ListenAddresses.Returns([TestPeers.Multiaddr(1)]); + dialerContext.State.Returns(new State()); - TestChannel listenerUpChannel = new TestChannel(); - channelFactory.SubListen(Arg.Any(), Arg.Any()) - .Returns(listenerUpChannel); - peerContext.LocalPeer.Identity.Returns(new Identity()); - listenerContext.LocalPeer.Identity.Returns(new Identity()); + listenerContext.Peer.Identity.Returns(TestPeers.Identity(2)); - string peerId = peerContext.LocalPeer.Identity.PeerId.ToString(); + string peerId = dialerContext.Peer.Identity.PeerId.ToString(); Multiaddress localAddr = $"/ip4/0.0.0.0/tcp/0/p2p/{peerId}"; - peerContext.LocalPeer.Address.Returns(localAddr); - listenerContext.RemotePeer.Address.Returns(localAddr); + //listenerContext.State.RemoteAddress.Returns(localAddr); - string listenerPeerId = listenerContext.LocalPeer.Identity.PeerId.ToString(); + string listenerPeerId = listenerContext.Peer.Identity.PeerId.ToString(); Multiaddress listenerAddr = $"/ip4/0.0.0.0/tcp/0/p2p/{listenerPeerId}"; - peerContext.RemotePeer.Address.Returns(listenerAddr); - listenerContext.LocalPeer.Address.Returns(listenerAddr); + //dialerContext.State.RemoteAddress.Returns(listenerAddr); + //listenerContext.State.RemoteAddress.Returns(listenerAddr); - var i_multiplexerSettings = new MultiplexerSettings(); - var r_multiplexerSettings = new MultiplexerSettings(); - TlsProtocol tlsProtocolListener = new TlsProtocol(i_multiplexerSettings, loggerFactory); - TlsProtocol tlsProtocolInitiator = new TlsProtocol(r_multiplexerSettings, loggerFactory); + MultiplexerSettings i_multiplexerSettings = new(); + MultiplexerSettings r_multiplexerSettings = new(); + TlsProtocol tlsProtocolListener = new(i_multiplexerSettings, loggerFactory); + TlsProtocol tlsProtocolInitiator = new(r_multiplexerSettings, loggerFactory); // Act - Task listenTask = tlsProtocolListener.ListenAsync(downChannel, channelFactory, listenerContext); - Task dialTask = tlsProtocolInitiator.DialAsync(downChannelFromProtocolPov, channelFactory, peerContext); + Task listenTask = tlsProtocolListener.ListenAsync(downChannel, listenerContext); + Task dialTask = tlsProtocolInitiator.DialAsync(downChannelFromProtocolPov, dialerContext); int sent = 42; ValueTask writeTask = listenerUpChannel.Reverse().WriteVarintAsync(sent); diff --git a/src/libp2p/Libp2p.Protocols.Tls.Tests/Using.cs b/src/libp2p/Libp2p.Protocols.Tls.Tests/Using.cs index 2048bd47..bb339bea 100644 --- a/src/libp2p/Libp2p.Protocols.Tls.Tests/Using.cs +++ b/src/libp2p/Libp2p.Protocols.Tls.Tests/Using.cs @@ -1,11 +1,5 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT -global using Nethermind.Libp2p.Core; -global using Nethermind.Libp2p.Core.TestsBase; -global using NSubstitute; global using NUnit.Framework; -global using Multiformats.Address; -global using System.Threading.Tasks; -global using Microsoft.Extensions.Logging; diff --git a/src/libp2p/Libp2p.Protocols.Tls/TlsProtocol.cs b/src/libp2p/Libp2p.Protocols.Tls/TlsProtocol.cs index 5a261fa0..388da7dc 100644 --- a/src/libp2p/Libp2p.Protocols.Tls/TlsProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.Tls/TlsProtocol.cs @@ -1,5 +1,4 @@ using System.Buffers; -using System.Net; using System.Net.Security; using Nethermind.Libp2p.Protocols.Quic; using System.Security.Cryptography.X509Certificates; @@ -12,104 +11,113 @@ namespace Nethermind.Libp2p.Protocols; -public class TlsProtocol(MultiplexerSettings? multiplexerSettings = null, ILoggerFactory? loggerFactory = null) : IProtocol +public class TlsProtocol(MultiplexerSettings? multiplexerSettings = null, ILoggerFactory? loggerFactory = null) : IConnectionProtocol { private readonly ECDsa _sessionKey = ECDsa.Create(); private readonly ILogger? _logger = loggerFactory?.CreateLogger(); - public Lazy> ApplicationProtocols = new Lazy>(() => multiplexerSettings?.Multiplexers.Select(proto => new SslApplicationProtocol(proto.Id)).ToList() ?? []); + public Lazy> ApplicationProtocols = new(() => multiplexerSettings?.Multiplexers.Select(proto => new SslApplicationProtocol(proto.Id)).ToList() ?? []); public SslApplicationProtocol? LastNegotiatedApplicationProtocol { get; private set; } public string Id => "/tls/1.0.0"; - public async Task ListenAsync(IChannel downChannel, IChannelFactory? channelFactory, IPeerContext context) + public async Task ListenAsync(IChannel downChannel, IConnectionContext context) { - _logger?.LogInformation("Starting ListenAsync: PeerId {LocalPeerId}", context.LocalPeer.Address.Get()); - if (channelFactory is null) + try { - throw new ArgumentException("Protocol is not properly instantiated"); - } + _logger?.LogInformation("Starting ListenAsync: PeerId {LocalPeerId}", context.Peer.Identity.PeerId); - Stream str = new ChannelStream(downChannel); - X509Certificate certificate = CertificateHelper.CertificateFromIdentity(_sessionKey, context.LocalPeer.Identity); - _logger?.LogDebug("Successfully created X509Certificate for PeerId {LocalPeerId}. Certificate Subject: {Subject}, Issuer: {Issuer}", context.LocalPeer.Address.Get(), certificate.Subject, certificate.Issuer); + Stream str = new ChannelStream(downChannel); + X509Certificate certificate = CertificateHelper.CertificateFromIdentity(_sessionKey, context.Peer.Identity); + _logger?.LogDebug("Successfully created X509Certificate for PeerId {LocalPeerId}. Certificate Subject: {Subject}, Issuer: {Issuer}", context.Peer.Identity.PeerId, certificate.Subject, certificate.Issuer); - SslServerAuthenticationOptions serverAuthenticationOptions = new() - { - ApplicationProtocols = ApplicationProtocols.Value, - RemoteCertificateValidationCallback = (_, certificate, _, _) => VerifyRemoteCertificate(context.RemotePeer.Address, certificate), - ServerCertificate = certificate, - ClientCertificateRequired = true, - }; - _logger?.LogTrace("SslServerAuthenticationOptions initialized with ApplicationProtocols: {Protocols}.", string.Join(", ", ApplicationProtocols.Value)); - SslStream sslStream = new(str, false, serverAuthenticationOptions.RemoteCertificateValidationCallback); - _logger?.LogTrace("SslStream initialized."); - try - { - await sslStream.AuthenticateAsServerAsync(serverAuthenticationOptions); - _logger?.LogInformation("Server TLS Authentication successful. PeerId: {RemotePeerId}, NegotiatedProtocol: {Protocol}.", context.RemotePeer.Address.Get(), sslStream.NegotiatedApplicationProtocol.Protocol); + SslServerAuthenticationOptions serverAuthenticationOptions = new() + { + ApplicationProtocols = ApplicationProtocols.Value, + RemoteCertificateValidationCallback = (_, certificate, _, _) => VerifyRemoteCertificate(context.State.RemoteAddress, certificate), + ServerCertificate = certificate, + ClientCertificateRequired = true, + }; + _logger?.LogTrace("SslServerAuthenticationOptions initialized with ApplicationProtocols: {Protocols}.", string.Join(", ", ApplicationProtocols.Value)); + SslStream sslStream = new(str, false, serverAuthenticationOptions.RemoteCertificateValidationCallback); + _logger?.LogTrace("SslStream initialized."); + try + { + await sslStream.AuthenticateAsServerAsync(serverAuthenticationOptions); + _logger?.LogInformation("Server TLS Authentication successful. PeerId: {RemotePeerId}, NegotiatedProtocol: {Protocol}.", context.State.RemotePeerId, sslStream.NegotiatedApplicationProtocol.Protocol); + } + catch (Exception ex) + { + _logger?.LogError("Error during TLS authentication for PeerId {RemotePeerId}: {ErrorMessage}.", context.State.RemotePeerId, ex.Message); + _logger?.LogDebug("TLS Authentication Exception Details: {StackTrace}", ex.StackTrace); + throw; + } + _logger?.LogDebug($"{Encoding.UTF8.GetString(sslStream.NegotiatedApplicationProtocol.Protocol.ToArray())} protocol negotiated"); + IChannel upChannel = context.Upgrade(); + await ExchangeData(sslStream, upChannel, _logger); + _ = upChannel.CloseAsync(); } catch (Exception ex) { - _logger?.LogError("Error during TLS authentication for PeerId {RemotePeerId}: {ErrorMessage}.", context.RemotePeer.Address.Get(), ex.Message); - _logger?.LogDebug("TLS Authentication Exception Details: {StackTrace}", ex.StackTrace); + _logger?.LogError(ex, "Error during TLS protocol negotiation."); throw; } - _logger?.LogDebug($"{Encoding.UTF8.GetString(sslStream.NegotiatedApplicationProtocol.Protocol.ToArray())} protocol negotiated"); - IChannel upChannel = channelFactory.SubListen(context); - await ExchangeData(sslStream, upChannel, _logger); - _ = upChannel.CloseAsync(); } private static bool VerifyRemoteCertificate(Multiaddress remotePeerAddress, X509Certificate certificate) => CertificateHelper.ValidateCertificate(certificate as X509Certificate2, remotePeerAddress.Get().ToString()); - public async Task DialAsync(IChannel downChannel, IChannelFactory? channelFactory, IPeerContext context) + public async Task DialAsync(IChannel downChannel, IConnectionContext context) { - _logger?.LogInformation("Starting DialAsync: LocalPeerId {LocalPeerId}", context.LocalPeer.Address.Get()); - if (channelFactory is null) + try { - throw new ArgumentException("Protocol is not properly instantiated"); - } - Multiaddress addr = context.LocalPeer.Address; - bool isIP4 = addr.Has(); - MultiaddressProtocol ipProtocol = isIP4 ? addr.Get() : addr.Get(); - IPAddress ipAddress = IPAddress.Parse(ipProtocol.ToString()); + _logger?.LogInformation("Starting DialAsync: LocalPeerId {LocalPeerId}", context.Peer.Identity.PeerId); - SslClientAuthenticationOptions clientAuthenticationOptions = new() - { - CertificateChainPolicy = new X509ChainPolicy + // TODO + Multiaddress addr = context.Peer.ListenAddresses.First(); + bool isIP4 = addr.Has(); + MultiaddressProtocol ipProtocol = isIP4 ? addr.Get() : addr.Get(); + + SslClientAuthenticationOptions clientAuthenticationOptions = new() { - RevocationMode = X509RevocationMode.NoCheck, - VerificationFlags = X509VerificationFlags.AllowUnknownCertificateAuthority - }, - TargetHost = ipAddress.ToString(), - ApplicationProtocols = ApplicationProtocols.Value, - EnabledSslProtocols = System.Security.Authentication.SslProtocols.Tls13, - RemoteCertificateValidationCallback = (_, certificate, _, _) => VerifyRemoteCertificate(context.RemotePeer.Address, certificate), - ClientCertificates = new X509CertificateCollection { CertificateHelper.CertificateFromIdentity(_sessionKey, context.LocalPeer.Identity) }, - }; - _logger?.LogTrace("SslClientAuthenticationOptions initialized for PeerId {RemotePeerId}.", context.RemotePeer.Address.Get()); - Stream str = new ChannelStream(downChannel); - SslStream sslStream = new(str, false, clientAuthenticationOptions.RemoteCertificateValidationCallback); - _logger?.LogTrace("Sslstream initialized."); - try - { - await sslStream.AuthenticateAsClientAsync(clientAuthenticationOptions); - _logger?.LogInformation("Client TLS Authentication successful. RemotePeerId: {RemotePeerId}, NegotiatedProtocol: {Protocol}.", context.RemotePeer.Address.Get(), sslStream.NegotiatedApplicationProtocol.Protocol); + CertificateChainPolicy = new X509ChainPolicy + { + RevocationMode = X509RevocationMode.NoCheck, + VerificationFlags = X509VerificationFlags.AllowUnknownCertificateAuthority + }, + TargetHost = ipProtocol?.ToString(), + ApplicationProtocols = ApplicationProtocols.Value, + EnabledSslProtocols = System.Security.Authentication.SslProtocols.Tls13, + RemoteCertificateValidationCallback = (_, certificate, _, _) => VerifyRemoteCertificate(context.State.RemoteAddress, certificate), + ClientCertificates = [CertificateHelper.CertificateFromIdentity(_sessionKey, context.Peer.Identity)], + }; + //_logger?.LogTrace("SslClientAuthenticationOptions initialized for PeerId {RemotePeerId}.", context.State.RemotePeerId); + Stream str = new ChannelStream(downChannel); + SslStream sslStream = new(str, false, clientAuthenticationOptions.RemoteCertificateValidationCallback); + _logger?.LogTrace("Sslstream initialized."); + try + { + await sslStream.AuthenticateAsClientAsync(clientAuthenticationOptions); + //_logger?.LogInformation("Client TLS Authentication successful. RemotePeerId: {RemotePeerId}, NegotiatedProtocol: {Protocol}.", context.State.RemotePeerId, sslStream.NegotiatedApplicationProtocol.Protocol); + } + catch (Exception ex) + { + //_logger?.LogError("Error during TLS client authentication for RemotePeerId {RemotePeerId}: {ErrorMessage}.", context.State.RemotePeerId, ex.Message); + _logger?.LogDebug("TLS Authentication Exception Details: {StackTrace}", ex.StackTrace); + return; + } + _logger?.LogDebug("Subdialing protocols: {Protocols}.", string.Join(", ", context.SubProtocols.Select(x => x.Id))); + IChannel upChannel = context.Upgrade(); + _logger?.LogDebug("SubDial completed for PeerId {RemotePeerId}.", context.State.RemotePeerId); + await ExchangeData(sslStream, upChannel, _logger); + _logger?.LogDebug("Connection closed for PeerId {RemotePeerId}.", context.State.RemotePeerId); + _ = upChannel.CloseAsync(); } catch (Exception ex) { - _logger?.LogError("Error during TLS client authentication for RemotePeerId {RemotePeerId}: {ErrorMessage}.", context.RemotePeer.Address.Get(), ex.Message); - _logger?.LogDebug("TLS Authentication Exception Details: {StackTrace}", ex.StackTrace); - return; + _logger?.LogError(ex, "Error during TLS protocol negotiation."); + throw; } - _logger?.LogDebug("Subdialing protocols: {Protocols}.", string.Join(", ", channelFactory.SubProtocols.Select(x => x.Id))); - IChannel upChannel = channelFactory.SubDial(context); - _logger?.LogDebug("SubDial completed for PeerId {RemotePeerId}.", context.RemotePeer.Address.Get()); - await ExchangeData(sslStream, upChannel, _logger); - _logger?.LogDebug("Connection closed for PeerId {RemotePeerId}.", context.RemotePeer.Address.Get()); - _ = upChannel.CloseAsync(); } private static async Task ExchangeData(SslStream sslStream, IChannel upChannel, ILogger? logger) @@ -139,6 +147,7 @@ private static async Task ExchangeData(SslStream sslStream, IChannel upChannel, await upChannel.CloseAsync(); } }); + Task readTask = Task.Run(async () => { try @@ -171,6 +180,7 @@ private static async Task ExchangeData(SslStream sslStream, IChannel upChannel, logger?.LogError(ex, "Error while reading from sslStream"); } }); + await Task.WhenAll(writeTask, readTask); } } diff --git a/src/libp2p/Libp2p.Protocols.Yamux.Tests/YamuxProtocolTests.cs b/src/libp2p/Libp2p.Protocols.Yamux.Tests/YamuxProtocolTests.cs index 72a1cf75..44a04098 100644 --- a/src/libp2p/Libp2p.Protocols.Yamux.Tests/YamuxProtocolTests.cs +++ b/src/libp2p/Libp2p.Protocols.Yamux.Tests/YamuxProtocolTests.cs @@ -5,7 +5,6 @@ using Nethermind.Libp2p.Core.TestsBase; using NSubstitute; using NUnit.Framework.Internal; -using System.Collections.Concurrent; namespace Nethermind.Libp2p.Protocols.Noise.Tests; @@ -22,46 +21,50 @@ public class YamuxProtocolTests // Expect error and react to it [Test] - public async Task Test_Protocol_Communication() + public async Task Test_Protocol_Communication2() { IProtocol? proto1 = Substitute.For(); proto1.Id.Returns("proto1"); - IPeerContext dialerPeerContext = Substitute.For(); - var dialerRequests = new BlockingCollection() { new ChannelRequest() { SubProtocol = proto1 } }; - dialerPeerContext.SubDialRequests.Returns(dialerRequests); - TestChannel dialerDownChannel = new TestChannel(); - IChannelFactory dialerUpchannelFactory = Substitute.For(); - dialerUpchannelFactory.SubProtocols.Returns(new[] { proto1 }); - TestChannel dialerUpChannel = new TestChannel(); - dialerUpchannelFactory.SubDial(Arg.Any(), Arg.Any()) - .Returns(dialerUpChannel); + IConnectionContext dialerContext = Substitute.For(); + INewSessionContext dialerSessionContext = Substitute.For(); + dialerContext.UpgradeToSession().Returns(dialerSessionContext); + dialerContext.State.Returns(new State { RemoteAddress = TestPeers.Multiaddr(2) }); + dialerSessionContext.State.Returns(new State { RemoteAddress = TestPeers.Multiaddr(2) }); + dialerSessionContext.Id.Returns("dialer"); + dialerSessionContext.DialRequests.Returns([new UpgradeOptions() { SelectedProtocol = proto1 }]); + + TestChannel dialerDownChannel = new(); + dialerSessionContext.SubProtocols.Returns([proto1]); + TestChannel dialerUpChannel = new(); + dialerSessionContext.Upgrade(Arg.Any()).Returns(dialerUpChannel); _ = dialerUpChannel.Reverse().WriteLineAsync("hello").AsTask().ContinueWith((e) => dialerUpChannel.CloseAsync()); - IPeerContext listenerPeerContext = Substitute.For(); IChannel listenerDownChannel = dialerDownChannel.Reverse(); - IChannelFactory listenerUpchannelFactory = Substitute.For(); - var listenerRequests = new BlockingCollection(); - listenerPeerContext.SubDialRequests.Returns(listenerRequests); - listenerUpchannelFactory.SubProtocols.Returns(new[] { proto1 }); - TestChannel listenerUpChannel = new TestChannel(); - listenerUpchannelFactory.SubListen(Arg.Any(), Arg.Any()) - .Returns(listenerUpChannel); + + IConnectionContext listenerContext = Substitute.For(); + INewSessionContext listenerSessionContext = Substitute.For(); + listenerContext.UpgradeToSession().Returns(listenerSessionContext); + listenerContext.State.Returns(new State { RemoteAddress = TestPeers.Multiaddr(1) }); + listenerSessionContext.State.Returns(new State { RemoteAddress = TestPeers.Multiaddr(1) }); + listenerSessionContext.Id.Returns("listener"); + + listenerSessionContext.SubProtocols.Returns([proto1]); + TestChannel listenerUpChannel = new(); + listenerSessionContext.Upgrade(Arg.Any()).Returns(listenerUpChannel); YamuxProtocol proto = new(loggerFactory: new TestContextLoggerFactory()); - _ = proto.ListenAsync(listenerDownChannel, listenerUpchannelFactory, listenerPeerContext); + _ = proto.ListenAsync(listenerDownChannel, listenerContext); - _ = proto.DialAsync(dialerDownChannel, dialerUpchannelFactory, dialerPeerContext); + _ = proto.DialAsync(dialerDownChannel, dialerContext); - var res = await listenerUpChannel.Reverse().ReadLineAsync(); + string res = await listenerUpChannel.Reverse().ReadLineAsync(); await listenerUpChannel.CloseAsync(); Assert.That(res, Is.EqualTo("hello")); - - await Task.Delay(1000); } } diff --git a/src/libp2p/Libp2p.Protocols.Yamux/YamuxProtocol.cs b/src/libp2p/Libp2p.Protocols.Yamux/YamuxProtocol.cs index 805193ad..d67f0e76 100644 --- a/src/libp2p/Libp2p.Protocols.Yamux/YamuxProtocol.cs +++ b/src/libp2p/Libp2p.Protocols.Yamux/YamuxProtocol.cs @@ -11,7 +11,7 @@ namespace Nethermind.Libp2p.Protocols; -public class YamuxProtocol : SymmetricProtocol, IProtocol +public class YamuxProtocol : SymmetricProtocol, IConnectionProtocol { private const int HeaderLength = 12; private const int PingDelay = 30_000; @@ -26,45 +26,42 @@ public YamuxProtocol(MultiplexerSettings? multiplexerSettings = null, ILoggerFac public string Id => "/yamux/1.0.0"; - protected override async Task ConnectAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context, bool isListener) + protected override async Task ConnectAsync(IChannel channel, IConnectionContext context, bool isListener) { - if (channelFactory is null) - { - throw new ArgumentException("ChannelFactory should be available for a muxer", nameof(channelFactory)); - } - - _logger?.LogInformation(isListener ? "Listen" : "Dial"); + _logger?.LogInformation("Ctx({ctx}): {mode} {peer}", context.Id, isListener ? "Listen" : "Dial", context.State.RemoteAddress); TaskAwaiter downChannelAwaiter = channel.GetAwaiter(); - Dictionary channels = new(); + + Dictionary channels = []; try { int streamIdCounter = isListener ? 2 : 1; - context.Connected(context.RemotePeer); + using INewSessionContext session = context.UpgradeToSession(); + + _logger?.LogInformation("Ctx({ctx}): Session created for {peer}", session.Id, session.State.RemoteAddress); int pingCounter = 0; using Timer timer = new((s) => { - _ = WriteHeaderAsync(channel, new YamuxHeader { Type = YamuxHeaderType.Ping, Flags = YamuxHeaderFlags.Syn, Length = ++pingCounter }); + _ = WriteHeaderAsync(session.Id, channel, new YamuxHeader { Type = YamuxHeaderType.Ping, Flags = YamuxHeaderFlags.Syn, Length = ++pingCounter }); }, null, PingDelay, PingDelay); _ = Task.Run(() => { - foreach (IChannelRequest request in context.SubDialRequests.GetConsumingEnumerable()) + foreach (UpgradeOptions request in session.DialRequests) { int streamId = streamIdCounter; Interlocked.Add(ref streamIdCounter, 2); - _logger?.LogDebug("Stream {stream id}: Dialing with protocol {proto}", streamId, request.SubProtocol?.Id); - channels[streamId] = CreateUpchannel(streamId, YamuxHeaderFlags.Syn, request); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Dialing with protocol {proto}", session.Id, streamId, request.SelectedProtocol?.Id); + channels[streamId] = CreateUpchannel(session.Id, streamId, YamuxHeaderFlags.Syn, request); } }); while (!downChannelAwaiter.IsCompleted) { - YamuxHeader header = await ReadHeaderAsync(channel); + YamuxHeader header = await ReadHeaderAsync(session.Id, channel); ReadOnlySequence data = default; if (header.Type > YamuxHeaderType.GoAway) @@ -77,7 +74,7 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch { if ((header.Flags & YamuxHeaderFlags.Syn) == YamuxHeaderFlags.Syn) { - _ = WriteHeaderAsync(channel, + _ = WriteHeaderAsync(session.Id, channel, new YamuxHeader { Flags = YamuxHeaderFlags.Ack, @@ -85,14 +82,14 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch Length = header.Length, }); - _logger?.LogDebug("Ping received and acknowledged"); + _logger?.LogDebug("Ctx({ctx}): Ping received and acknowledged", session.Id); } continue; } if (header.Type == YamuxHeaderType.GoAway) { - _logger?.LogDebug("Closing all streams"); + _logger?.LogDebug("Ctx({ctx}): Closing all streams", session.Id); foreach (ChannelState channelState in channels.Values) { @@ -110,7 +107,7 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch if ((header.Flags & YamuxHeaderFlags.Syn) == YamuxHeaderFlags.Syn && !channels.ContainsKey(header.StreamID)) { - channels[header.StreamID] = CreateUpchannel(header.StreamID, YamuxHeaderFlags.Ack, null); + channels[header.StreamID] = CreateUpchannel(session.Id, header.StreamID, YamuxHeaderFlags.Ack, new UpgradeOptions()); } if (!channels.ContainsKey(header.StreamID)) @@ -119,7 +116,7 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch { await channel.ReadAsync(header.Length); } - _logger?.LogDebug("Stream {stream id}: Ignored for closed stream", header.StreamID); + _logger?.LogDebug("Ctx({ctx}): Stream {stream id}: Ignored for closed stream", session.Id, header.StreamID); continue; } @@ -127,10 +124,10 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch { if (header.Length > channels[header.StreamID].LocalWindow.Available) { - _logger?.LogDebug("Stream {stream id}: Data length > windows size: {length} > {window size}", + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Data length > windows size: {length} > {window size}", session.Id, header.StreamID, header.Length, channels[header.StreamID].LocalWindow.Available); - await WriteGoAwayAsync(channel, SessionTerminationCode.ProtocolError); + await WriteGoAwayAsync(session.Id, channel, SessionTerminationCode.ProtocolError); return; } @@ -139,8 +136,8 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch bool spent = channels[header.StreamID].LocalWindow.SpendWindow((int)data.Length); if (!spent) { - _logger?.LogDebug("Stream {stream id}: Window spent out of budget", header.StreamID); - await WriteGoAwayAsync(channel, SessionTerminationCode.InternalError); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Window spent out of budget", session.Id, header.StreamID); + await WriteGoAwayAsync(session.Id, channel, SessionTerminationCode.InternalError); return; } @@ -153,7 +150,7 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch int extendedBy = channels[header.StreamID].LocalWindow.ExtendWindowIfNeeded(); if (extendedBy is not 0) { - _ = WriteHeaderAsync(channel, + _ = WriteHeaderAsync(session.Id, channel, new YamuxHeader { Type = YamuxHeaderType.WindowUpdate, @@ -172,7 +169,7 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch int extendedBy = channelState.LocalWindow.ExtendWindowIfNeeded(); if (extendedBy is not 0) { - _ = WriteHeaderAsync(channel, + _ = WriteHeaderAsync(session.Id, channel, new YamuxHeader { Type = YamuxHeaderType.WindowUpdate, @@ -189,7 +186,7 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch { int oldSize = channels[header.StreamID].RemoteWindow.Available; int newSize = channels[header.StreamID].RemoteWindow.ExtendWindow(header.Length); - _logger?.LogDebug("Stream {stream id}: Window update requested: {old} => {new}", header.StreamID, oldSize, newSize); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Window update requested: {old} => {new}", session.Id, header.StreamID, oldSize, newSize); } if ((header.Flags & YamuxHeaderFlags.Fin) == YamuxHeaderFlags.Fin) @@ -200,19 +197,19 @@ protected override async Task ConnectAsync(IChannel channel, IChannelFactory? ch } _ = state.Channel?.WriteEofAsync(); - _logger?.LogDebug("Stream {stream id}: Finish receiving", header.StreamID); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Finish receiving", session.Id, header.StreamID); } if ((header.Flags & YamuxHeaderFlags.Rst) == YamuxHeaderFlags.Rst) { _ = channels[header.StreamID].Channel?.CloseAsync(); - _logger?.LogDebug("Stream {stream id}: Reset", header.StreamID); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Reset", session.Id, header.StreamID); } } - await WriteGoAwayAsync(channel, SessionTerminationCode.Ok); + await WriteGoAwayAsync(session.Id, channel, SessionTerminationCode.Ok); - ChannelState CreateUpchannel(int streamId, YamuxHeaderFlags initiationFlag, IChannelRequest? channelRequest) + ChannelState CreateUpchannel(string contextId, int streamId, YamuxHeaderFlags initiationFlag, UpgradeOptions upgradeOptions) { bool isListenerChannel = isListener ^ (streamId % 2 == 0); @@ -221,30 +218,26 @@ ChannelState CreateUpchannel(int streamId, YamuxHeaderFlags initiationFlag, ICha if (isListenerChannel) { - upChannel = channelFactory.SubListen(context); + upChannel = session.Upgrade(upgradeOptions with { ModeOverride = UpgradeModeOverride.Listen }); } else { - IPeerContext dialContext = context.Fork(); - dialContext.SpecificProtocolRequest = channelRequest; - upChannel = channelFactory.SubDial(dialContext); + upChannel = session.Upgrade(upgradeOptions with { ModeOverride = UpgradeModeOverride.Dial }); } - ChannelState state = new(upChannel, channelRequest); - TaskCompletionSource? tcs = state.Request?.CompletionSource; + ChannelState state = new(upChannel); upChannel.GetAwaiter().OnCompleted(() => { - tcs?.SetResult(); channels.Remove(streamId); - _logger?.LogDebug("Stream {stream id}: Closed", streamId); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Closed", contextId, streamId); }); Task.Run(async () => { try { - await WriteHeaderAsync(channel, + await WriteHeaderAsync(contextId, channel, new YamuxHeader { Flags = initiationFlag, @@ -254,22 +247,22 @@ await WriteHeaderAsync(channel, if (initiationFlag == YamuxHeaderFlags.Syn) { - _logger?.LogDebug("Stream {stream id}: New stream request sent", streamId); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: New stream request sent", contextId, streamId); } else { - _logger?.LogDebug("Stream {stream id}: New stream request acknowledged", streamId); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: New stream request acknowledged", contextId, streamId); } - await foreach (var upData in upChannel.ReadAllAsync()) + await foreach (ReadOnlySequence upData in upChannel.ReadAllAsync()) { - _logger?.LogDebug("Stream {stream id}: Receive from upchannel, length={length}", streamId, upData.Length); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Receive from upchannel, length={length}", contextId, streamId, upData.Length); for (int i = 0; i < upData.Length;) { int sendingSize = await state.RemoteWindow.SpendWindowOrWait((int)upData.Length - i); - await WriteHeaderAsync(channel, + await WriteHeaderAsync(contextId, channel, new YamuxHeader { Type = YamuxHeaderType.Data, @@ -280,22 +273,22 @@ await WriteHeaderAsync(channel, } } - await WriteHeaderAsync(channel, + await WriteHeaderAsync(contextId, channel, new YamuxHeader { Flags = YamuxHeaderFlags.Fin, Type = YamuxHeaderType.WindowUpdate, StreamID = streamId }); - _logger?.LogDebug("Stream {stream id}: Upchannel finished writing", streamId); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Upchannel finished writing", contextId, streamId); } catch (ChannelClosedException e) { - _logger?.LogDebug("Stream {stream id}: Closed due to transport disconnection", streamId); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Closed due to transport disconnection", contextId, streamId); } catch (Exception e) { - await WriteHeaderAsync(channel, + await WriteHeaderAsync(contextId, channel, new YamuxHeader { Flags = YamuxHeaderFlags.Rst, @@ -305,22 +298,22 @@ await WriteHeaderAsync(channel, _ = upChannel.CloseAsync(); channels.Remove(streamId); - _logger?.LogDebug("Stream {stream id}: Unexpected error, closing: {error}", streamId, e.Message); + _logger?.LogDebug("Ctx({ctx}), stream {stream id}: Unexpected error, closing: {error}", contextId, streamId, e.Message); } }); return state; } } - catch (ChannelClosedException ex) + catch (ChannelClosedException) { - _logger?.LogDebug("Closed due to transport disconnection"); + _logger?.LogDebug("Ctx({ctx}): Closed due to transport disconnection", context.Id); } catch (Exception ex) { - await WriteGoAwayAsync(channel, SessionTerminationCode.InternalError); - _logger?.LogDebug("Closed with exception {exception}", ex.Message); - _logger?.LogTrace("{stackTrace}", ex.StackTrace); + _logger?.LogDebug("Ctx({ctx}): Closed with exception \"{exception}\" {stackTrace}", context.Id, ex.Message, ex.StackTrace); + await WriteGoAwayAsync(context.Id, channel, SessionTerminationCode.InternalError); + await channel.CloseAsync(); } foreach (ChannelState upChannel in channels.Values) @@ -329,15 +322,15 @@ await WriteHeaderAsync(channel, } } - private async Task ReadHeaderAsync(IReader reader, CancellationToken token = default) + private async Task ReadHeaderAsync(string contextId, IReader reader, CancellationToken token = default) { byte[] headerData = (await reader.ReadAsync(HeaderLength, token: token).OrThrow()).ToArray(); YamuxHeader header = YamuxHeader.FromBytes(headerData); - _logger?.LogTrace("Stream {stream id}: Receive type={type} flags={flags} length={length}", header.StreamID, header.Type, header.Flags, header.Length); + _logger?.LogTrace("Ctx({ctx}), stream {stream id}: Receive type={type} flags={flags} length={length}", contextId, header.StreamID, header.Type, header.Flags, header.Length); return header; } - private async Task WriteHeaderAsync(IWriter writer, YamuxHeader header, ReadOnlySequence data = default) + private async Task WriteHeaderAsync(string contextId, IWriter writer, YamuxHeader header, ReadOnlySequence data = default) { byte[] headerBuffer = new byte[HeaderLength]; if (header.Type == YamuxHeaderType.Data) @@ -346,22 +339,22 @@ private async Task WriteHeaderAsync(IWriter writer, YamuxHeader header, ReadOnly } YamuxHeader.ToBytes(headerBuffer, ref header); - _logger?.LogTrace("Stream {stream id}: Send type={type} flags={flags} length={length}", header.StreamID, header.Type, header.Flags, header.Length); + _logger?.LogTrace("Ctx({ ctx}), stream {stream id}: Send type={type} flags={flags} length={length}", contextId, header.StreamID, header.Type, header.Flags, header.Length); await writer.WriteAsync(data.Length == 0 ? new ReadOnlySequence(headerBuffer) : data.Prepend(headerBuffer)).OrThrow(); } - private Task WriteGoAwayAsync(IWriter channel, SessionTerminationCode code) => - WriteHeaderAsync(channel, new YamuxHeader + private Task WriteGoAwayAsync(string contextId, IWriter channel, SessionTerminationCode code) => + WriteHeaderAsync(contextId, channel, new YamuxHeader { Type = YamuxHeaderType.GoAway, Length = (int)code, StreamID = 0, }); - private class ChannelState(IChannel? channel = default, IChannelRequest? request = default) + private class ChannelState(IChannel? channel = default) { public IChannel? Channel { get; set; } = channel; - public IChannelRequest? Request { get; set; } = request; + //public ChannelRequest? Request { get; set; } = request; public DataWindow LocalWindow { get; } = new(); public DataWindow RemoteWindow { get; } = new(); diff --git a/src/libp2p/Libp2p.sln b/src/libp2p/Libp2p.sln index 3d311237..a21c51e8 100644 --- a/src/libp2p/Libp2p.sln +++ b/src/libp2p/Libp2p.sln @@ -66,18 +66,22 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Libp2p.Protocols.Yamux.Tests", "Libp2p.Protocols.Yamux.Tests\Libp2p.Protocols.Yamux.Tests.csproj", "{D9003366-1562-49CA-B32D-087BBE3973ED}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TransportInterop", "..\samples\transport-interop\TransportInterop.csproj", "{EC505F21-FC69-4432-88A8-3CD5F7899B08}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Libp2p.Protocols.PubsubPeerDiscovery", "Libp2p.Protocols.PubsubPeerDiscovery\Libp2p.Protocols.PubsubPeerDiscovery.csproj", "{F14C0226-D2B1-48B8-BC6A-163BE2C8A4C6}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Libp2p.Protocols.PubsubPeerDiscovery.Tests", "Libp2p.Protocols.PubsubPeerDiscovery.Tests\Libp2p.Protocols.PubsubPeerDiscovery.Tests.csproj", "{5883B53B-2BA5-4444-8E65-DA4B69EB8B2F}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Libp2p.Protocols.Pubsub.Profiler", "Libp2p.Protocols.Pubsub.Profiler\Libp2p.Protocols.Pubsub.Profiler.csproj", "{BFE1CCB2-59A3-4A69-B543-EBC9C16E39F7}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Libp2p.Protocols.Pubsub.E2eTests", "Libp2p.Protocols.Pubsub.E2eTests\Libp2p.Protocols.Pubsub.E2eTests.csproj", "{BFE1CCB2-59A3-4A69-B543-EBC9C16E39F7}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Libp2p.Protocols.Tls", "Libp2p.Protocols.Tls\Libp2p.Protocols.Tls.csproj", "{C3CDBAAE-C790-443A-A293-D6E2330160F7}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Libp2p.Protocols.Tls.Tests", "Libp2p.Protocols.Tls.Tests\Libp2p.Protocols.Tls.Tests.csproj", "{89BD907E-1399-4BE7-98CC-E541EAB21842}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Libp2p.E2eTests", "Libp2p.E2eTests\Libp2p.E2eTests.csproj", "{DBC86C19-3374-4001-AC8A-F672E29CB7B2}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Libp2p.Protocols.PubsubPeerDiscovery.E2eTests", "Libp2p.Protocols.PubsubPeerDiscovery.E2eTests\Libp2p.Protocols.PubsubPeerDiscovery.E2eTests.csproj", "{EC0B1626-C006-4138-A119-FE61CDAB824D}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Libp2p.Protocols.Relay", "Libp2p.Protocols.Relay\Libp2p.Protocols.Relay.csproj", "{F29F5376-4F93-486F-B933-3278177704DE}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TransportInterop", "..\samples\transport-interop\TransportInterop.csproj", "{2C00E91D-79CE-470B-A38B-975F03C8B8A3}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -184,18 +188,10 @@ Global {D9003366-1562-49CA-B32D-087BBE3973ED}.Debug|Any CPU.Build.0 = Debug|Any CPU {D9003366-1562-49CA-B32D-087BBE3973ED}.Release|Any CPU.ActiveCfg = Release|Any CPU {D9003366-1562-49CA-B32D-087BBE3973ED}.Release|Any CPU.Build.0 = Release|Any CPU - {EC505F21-FC69-4432-88A8-3CD5F7899B08}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {EC505F21-FC69-4432-88A8-3CD5F7899B08}.Debug|Any CPU.Build.0 = Debug|Any CPU - {EC505F21-FC69-4432-88A8-3CD5F7899B08}.Release|Any CPU.ActiveCfg = Release|Any CPU - {EC505F21-FC69-4432-88A8-3CD5F7899B08}.Release|Any CPU.Build.0 = Release|Any CPU {F14C0226-D2B1-48B8-BC6A-163BE2C8A4C6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {F14C0226-D2B1-48B8-BC6A-163BE2C8A4C6}.Debug|Any CPU.Build.0 = Debug|Any CPU {F14C0226-D2B1-48B8-BC6A-163BE2C8A4C6}.Release|Any CPU.ActiveCfg = Release|Any CPU {F14C0226-D2B1-48B8-BC6A-163BE2C8A4C6}.Release|Any CPU.Build.0 = Release|Any CPU - {5883B53B-2BA5-4444-8E65-DA4B69EB8B2F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {5883B53B-2BA5-4444-8E65-DA4B69EB8B2F}.Debug|Any CPU.Build.0 = Debug|Any CPU - {5883B53B-2BA5-4444-8E65-DA4B69EB8B2F}.Release|Any CPU.ActiveCfg = Release|Any CPU - {5883B53B-2BA5-4444-8E65-DA4B69EB8B2F}.Release|Any CPU.Build.0 = Release|Any CPU {BFE1CCB2-59A3-4A69-B543-EBC9C16E39F7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {BFE1CCB2-59A3-4A69-B543-EBC9C16E39F7}.Debug|Any CPU.Build.0 = Debug|Any CPU {BFE1CCB2-59A3-4A69-B543-EBC9C16E39F7}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -208,6 +204,22 @@ Global {89BD907E-1399-4BE7-98CC-E541EAB21842}.Debug|Any CPU.Build.0 = Debug|Any CPU {89BD907E-1399-4BE7-98CC-E541EAB21842}.Release|Any CPU.ActiveCfg = Release|Any CPU {89BD907E-1399-4BE7-98CC-E541EAB21842}.Release|Any CPU.Build.0 = Release|Any CPU + {DBC86C19-3374-4001-AC8A-F672E29CB7B2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DBC86C19-3374-4001-AC8A-F672E29CB7B2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DBC86C19-3374-4001-AC8A-F672E29CB7B2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DBC86C19-3374-4001-AC8A-F672E29CB7B2}.Release|Any CPU.Build.0 = Release|Any CPU + {EC0B1626-C006-4138-A119-FE61CDAB824D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EC0B1626-C006-4138-A119-FE61CDAB824D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EC0B1626-C006-4138-A119-FE61CDAB824D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EC0B1626-C006-4138-A119-FE61CDAB824D}.Release|Any CPU.Build.0 = Release|Any CPU + {F29F5376-4F93-486F-B933-3278177704DE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F29F5376-4F93-486F-B933-3278177704DE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F29F5376-4F93-486F-B933-3278177704DE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F29F5376-4F93-486F-B933-3278177704DE}.Release|Any CPU.Build.0 = Release|Any CPU + {2C00E91D-79CE-470B-A38B-975F03C8B8A3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2C00E91D-79CE-470B-A38B-975F03C8B8A3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2C00E91D-79CE-470B-A38B-975F03C8B8A3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2C00E91D-79CE-470B-A38B-975F03C8B8A3}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -231,10 +243,13 @@ Global {FC0E9BCE-2848-45DC-AE20-FB7E862A199E} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} {EEECB761-A3C3-4598-AD03-EFABBF6CAA77} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} {D9003366-1562-49CA-B32D-087BBE3973ED} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} - {EC505F21-FC69-4432-88A8-3CD5F7899B08} = {0DC1C6A1-0A5B-43BA-9605-621C21A16716} {F14C0226-D2B1-48B8-BC6A-163BE2C8A4C6} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} - {5883B53B-2BA5-4444-8E65-DA4B69EB8B2F} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} {BFE1CCB2-59A3-4A69-B543-EBC9C16E39F7} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} + {C3CDBAAE-C790-443A-A293-D6E2330160F7} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} + {89BD907E-1399-4BE7-98CC-E541EAB21842} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} + {EC0B1626-C006-4138-A119-FE61CDAB824D} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} + {F29F5376-4F93-486F-B933-3278177704DE} = {6F3D9AA9-C92D-4998-BC4E-D5EA068E8D0D} + {2C00E91D-79CE-470B-A38B-975F03C8B8A3} = {0DC1C6A1-0A5B-43BA-9605-621C21A16716} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {E337E37C-3DB8-42FA-9A83-AC4E3B2557B4} diff --git a/src/libp2p/Libp2p/Libp2p.csproj b/src/libp2p/Libp2p/Libp2p.csproj index e823c2f2..cc739c23 100644 --- a/src/libp2p/Libp2p/Libp2p.csproj +++ b/src/libp2p/Libp2p/Libp2p.csproj @@ -15,6 +15,7 @@ + @@ -22,6 +23,7 @@ + diff --git a/src/libp2p/Libp2p/Libp2pPeerFactory.cs b/src/libp2p/Libp2p/Libp2pPeerFactory.cs index 084f6ce7..5f5319b1 100644 --- a/src/libp2p/Libp2p/Libp2pPeerFactory.cs +++ b/src/libp2p/Libp2p/Libp2pPeerFactory.cs @@ -1,32 +1,22 @@ // SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited // SPDX-License-Identifier: MIT -using Multiformats.Address; -using Multiformats.Address.Protocols; +using Microsoft.Extensions.Logging; using Nethermind.Libp2p.Core; +using Nethermind.Libp2p.Core.Discovery; using Nethermind.Libp2p.Protocols; -namespace Nethermind.Libp2p.Stack; +namespace Nethermind.Libp2p; -public class Libp2pPeerFactory : PeerFactory +public class Libp2pPeerFactory(IProtocolStackSettings protocolStackSettings, PeerStore peerStore, ILoggerFactory? loggerFactory = null) : PeerFactory(protocolStackSettings, peerStore, loggerFactory) { - public Libp2pPeerFactory(IServiceProvider serviceProvider) : base(serviceProvider) - { - } - - protected override async Task ConnectedTo(IRemotePeer peer, bool isDialer) - { - await peer.DialAsync(); - } + public override ILocalPeer Create(Identity? identity = null) => new Libp2pPeer(protocolStackSettings, PeerStore, identity ?? new Identity(), LoggerFactory); +} - public override ILocalPeer Create(Identity? identity = null, Multiaddress? localAddr = null) +class Libp2pPeer(IProtocolStackSettings protocolStackSettings, PeerStore peerStore, Identity identity, ILoggerFactory? loggerFactory = null) : LocalPeer(identity, peerStore, protocolStackSettings, loggerFactory) +{ + protected override async Task ConnectedTo(ISession session, bool isDialer) { - identity ??= new Identity(); - localAddr ??= $"/ip4/0.0.0.0/tcp/0/p2p/{identity.PeerId}"; - if (localAddr.Get() is null) - { - localAddr.Add(identity.PeerId.ToString()); - } - return base.Create(identity, localAddr); + await session.DialAsync(); } } diff --git a/src/libp2p/Libp2p/Libp2pPeerFactoryBuilder.cs b/src/libp2p/Libp2p/Libp2pPeerFactoryBuilder.cs index d296f9f9..79a53b07 100644 --- a/src/libp2p/Libp2p/Libp2pPeerFactoryBuilder.cs +++ b/src/libp2p/Libp2p/Libp2pPeerFactoryBuilder.cs @@ -3,13 +3,16 @@ using Nethermind.Libp2p.Core; using Nethermind.Libp2p.Protocols; -using Nethermind.Libp2p.Protocols.Pubsub; -namespace Nethermind.Libp2p.Stack; +namespace Nethermind.Libp2p; -public class Libp2pPeerFactoryBuilder : PeerFactoryBuilderBase, ILibp2pPeerFactoryBuilder +public class Libp2pPeerFactoryBuilder(IServiceProvider? serviceProvider = default) : PeerFactoryBuilderBase(serviceProvider), + ILibp2pPeerFactoryBuilder { private bool enforcePlaintext; + private bool addPubsub; + private bool addRelay; + private bool addQuic; public ILibp2pPeerFactoryBuilder WithPlaintextEnforced() { @@ -17,28 +20,64 @@ public ILibp2pPeerFactoryBuilder WithPlaintextEnforced() return this; } - public Libp2pPeerFactoryBuilder(IServiceProvider? serviceProvider = default) : base(serviceProvider) { } + public ILibp2pPeerFactoryBuilder WithPubsub() + { + addPubsub = true; + return this; + } + + public ILibp2pPeerFactoryBuilder WithRelay() + { + addRelay = true; + return this; + } - protected override ProtocolStack BuildStack() + public ILibp2pPeerFactoryBuilder WithQuic() { - ProtocolStack tcpEncryptionStack = enforcePlaintext ? Over() : Over(); - - ProtocolStack tcpStack = Over() - .Over() - .Over(tcpEncryptionStack) - .Over() - .Over(); - - return Over() - // Quic is not working well, and requires consumers to mark projects with preview - //.Over().Or(tcpStack) - .Over(tcpStack) - .Over() - .AddAppLayerProtocol() - .AddAppLayerProtocol() - .AddAppLayerProtocol() - .AddAppLayerProtocol() - .AddAppLayerProtocol() - .AddAppLayerProtocol(); + addQuic = true; + return this; + } + + protected override ProtocolRef[] BuildStack(IEnumerable additionalProtocols) + { + ProtocolRef tcp = Get(); + + ProtocolRef[] encryption = enforcePlaintext ? [Get()] : [Get()/*, Get()*/]; + + ProtocolRef[] muxers = [Get()]; + + ProtocolRef[] commonAppProtocolSelector = [Get()]; + Connect([tcp], [Get()], encryption, [Get()], muxers, commonAppProtocolSelector); + + ProtocolRef[] relay = addRelay ? [Get(), Get()] : []; + ProtocolRef[] pubsub = addPubsub ? [ + Get(), + Get(), + Get(), + Get() + ] : []; + + ProtocolRef[] apps = [ + Get(), + Get(), + .. additionalProtocols, + .. relay, + .. pubsub, + ]; + Connect(commonAppProtocolSelector, apps); + + if (addRelay) + { + Connect(relay, [Get()], apps.Where(a => !relay.Contains(a)).ToArray()); + } + + if (addQuic) + { + ProtocolRef quic = Get(); + Connect([quic], commonAppProtocolSelector); + return [tcp, quic]; + } + + return [tcp]; } } diff --git a/src/libp2p/Libp2p/LogMessages.cs b/src/libp2p/Libp2p/LogMessages.cs index 01b41117..17777302 100644 --- a/src/libp2p/Libp2p/LogMessages.cs +++ b/src/libp2p/Libp2p/LogMessages.cs @@ -3,7 +3,7 @@ using Microsoft.Extensions.Logging; -namespace Nethermind.Libp2p.Stack; +namespace Nethermind.Libp2p; internal static partial class LogMessages { diff --git a/src/libp2p/Libp2p/MultiaddressBasedSelectorProtocol.cs b/src/libp2p/Libp2p/MultiaddressBasedSelectorProtocol.cs deleted file mode 100644 index 113354f6..00000000 --- a/src/libp2p/Libp2p/MultiaddressBasedSelectorProtocol.cs +++ /dev/null @@ -1,47 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited -// SPDX-License-Identifier: MIT - -using Microsoft.Extensions.Logging; -using Multiformats.Address.Protocols; -using Nethermind.Libp2p.Core; -using Nethermind.Libp2p.Stack; - -namespace Nethermind.Libp2p.Protocols; - -/// -/// Select protocol based on multiaddr -/// -public class MultiaddressBasedSelectorProtocol(ILoggerFactory? loggerFactory = null) : SymmetricProtocol, IProtocol -{ - private readonly ILogger? _logger = loggerFactory?.CreateLogger(); - - public string Id => "multiaddr-select"; - - protected override async Task ConnectAsync(IChannel _, IChannelFactory? channelFactory, IPeerContext context, bool isListener) - { - IProtocol protocol = null!; - // TODO: deprecate quic - if (context.LocalPeer.Address.Has()) - { - protocol = channelFactory!.SubProtocols.FirstOrDefault(proto => proto.Id == "quic-v1") ?? throw new ApplicationException("QUICv1 is not supported"); - } - else if (context.LocalPeer.Address.Has()) - { - protocol = channelFactory!.SubProtocols.FirstOrDefault(proto => proto.Id == "ip-tcp") ?? throw new ApplicationException("TCP is not supported"); - } - else if (context.LocalPeer.Address.Has()) - { - throw new ApplicationException("QUIC is not supported. Use QUICv1 instead."); - } - else - { - throw new NotImplementedException($"No transport protocol found for the given address: {context.LocalPeer.Address}"); - } - - _logger?.LogPickedProtocol(protocol.Id, isListener ? "listen" : "dial"); - - await (isListener - ? channelFactory.SubListen(context, protocol) - : channelFactory.SubDial(context, protocol)); - } -} diff --git a/src/libp2p/Libp2p/ServiceProviderExtensions.cs b/src/libp2p/Libp2p/ServiceProviderExtensions.cs index c92347c7..3f70a92f 100644 --- a/src/libp2p/Libp2p/ServiceProviderExtensions.cs +++ b/src/libp2p/Libp2p/ServiceProviderExtensions.cs @@ -7,20 +7,22 @@ using Nethermind.Libp2p.Protocols; using Nethermind.Libp2p.Protocols.Pubsub; -namespace Nethermind.Libp2p.Stack; +namespace Nethermind.Libp2p; public static class ServiceProviderExtensions { - public static IServiceCollection AddLibp2p(this IServiceCollection services, Func factorySetup) + public static IServiceCollection AddLibp2p(this IServiceCollection services, Func? factorySetup = null) { return services - .AddScoped(sp => factorySetup(new Libp2pPeerFactoryBuilder(sp))) - .AddScoped(sp => (ILibp2pPeerFactoryBuilder)factorySetup(new Libp2pPeerFactoryBuilder(sp))) - .AddScoped(sp => sp.GetService()!.Build()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() + .AddSingleton(sp => new Libp2pPeerFactoryBuilder(sp)) + .AddSingleton(sp => (ILibp2pPeerFactoryBuilder)sp.GetRequiredService()) + .AddSingleton() + .AddSingleton(sp => factorySetup is null ? sp.GetRequiredService() : factorySetup(sp.GetRequiredService())) + .AddSingleton(sp => sp.GetService()!.Build()) + .AddSingleton() + .AddSingleton() + .AddSingleton() + .AddSingleton() ; } } diff --git a/src/samples/chat/ChatProtocol.cs b/src/samples/chat/ChatProtocol.cs index cfd569f5..8cff6247 100644 --- a/src/samples/chat/ChatProtocol.cs +++ b/src/samples/chat/ChatProtocol.cs @@ -5,15 +5,14 @@ using System.Text; using Nethermind.Libp2p.Core; -internal class ChatProtocol : SymmetricProtocol, IProtocol +internal class ChatProtocol : SymmetricSessionProtocol, ISessionProtocol { private static readonly ConsoleReader Reader = new(); private readonly ConsoleColor defautConsoleColor = Console.ForegroundColor; public string Id => "/chat/1.0.0"; - protected override async Task ConnectAsync(IChannel channel, IChannelFactory? channelFactory, - IPeerContext context, bool isListener) + protected override async Task ConnectAsync(IChannel channel, ISessionContext context, bool isListener) { Console.Write("> "); _ = Task.Run(async () => diff --git a/src/samples/chat/Program.cs b/src/samples/chat/Program.cs index 8b0f9129..c6d94047 100644 --- a/src/samples/chat/Program.cs +++ b/src/samples/chat/Program.cs @@ -3,15 +3,15 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Nethermind.Libp2p.Stack; using Nethermind.Libp2p.Core; using Multiformats.Address; using Multiformats.Address.Protocols; +using Nethermind.Libp2p; ServiceProvider serviceProvider = new ServiceCollection() .AddLibp2p(builder => builder.AddAppLayerProtocol()) .AddLogging(builder => - builder.SetMinimumLevel(args.Contains("--trace") ? LogLevel.Trace : LogLevel.Information) + builder.SetMinimumLevel(args.Contains("--trace") ? LogLevel.Trace : LogLevel.Trace) .AddSimpleConsole(l => { l.SingleLine = true; @@ -32,10 +32,10 @@ "/ip4/0.0.0.0/udp/0/quic-v1" : "/ip4/0.0.0.0/tcp/0"; - ILocalPeer localPeer = peerFactory.Create(localAddr: addrTemplate); + ILocalPeer localPeer = peerFactory.Create(); logger.LogInformation("Dialing {remote}", remoteAddr); - IRemotePeer remotePeer = await localPeer.DialAsync(remoteAddr, ts.Token); + ISession remotePeer = await localPeer.DialAsync(remoteAddr, ts.Token); await remotePeer.DialAsync(ts.Token); await remotePeer.DisconnectAsync(); @@ -49,12 +49,15 @@ "/ip4/0.0.0.0/udp/{0}/quic-v1" : "/ip4/0.0.0.0/tcp/{0}"; - IListener listener = await peer.ListenAsync( - string.Format(addrTemplate, args.Length > 0 && args[0] == "-sp" ? args[1] : "0"), + peer.OnConnected += async newSession => logger.LogInformation("A peer connected {remote}", newSession.RemoteAddress); + + await peer.StartListenAsync( + [string.Format(addrTemplate, args.Length > 0 && args[0] == "-sp" ? args[1] : "0")], ts.Token); - logger.LogInformation("Listener started at {address}", listener.Address); - listener.OnConnection += async remotePeer => logger.LogInformation("A peer connected {remote}", remotePeer.Address); - Console.CancelKeyPress += delegate { listener.DisconnectAsync(); }; + logger.LogInformation("Listener started at {address}", string.Join(", ", peer.ListenAddresses)); + + Console.CancelKeyPress += delegate { ts.Cancel(); }; - await listener; + await Task.Delay(-1, ts.Token); + await peer.DisconnectAsync(); } diff --git a/src/samples/chat/Properties/launchSettings.json b/src/samples/chat/Properties/launchSettings.json new file mode 100644 index 00000000..4e8d0ddf --- /dev/null +++ b/src/samples/chat/Properties/launchSettings.json @@ -0,0 +1,12 @@ +{ + "profiles": { + "Chat": { + "commandName": "Project", + "commandLineArgs": "-d /ip4/127.0.0.1/tcp/9001/p2p/QmNMymSSfNSRvtB1tytWUYB6PXAZqmS8PYW7WKci5vuU8s --trace " + }, + "Chat server": { + "commandName": "Project", + "commandLineArgs": "--trace" + } + } +} \ No newline at end of file diff --git a/src/samples/perf-benchmarks/NoStackPeerFactoryBuilder.cs b/src/samples/perf-benchmarks/NoStackPeerFactoryBuilder.cs index 45d94a9f..999c20f4 100644 --- a/src/samples/perf-benchmarks/NoStackPeerFactoryBuilder.cs +++ b/src/samples/perf-benchmarks/NoStackPeerFactoryBuilder.cs @@ -1,9 +1,9 @@ // SPDX-FileCopyrightText:2023 Demerzel Solutions Limited // SPDX-License-Identifier:MIT -using Nethermind.Libp2p.Stack; using Nethermind.Libp2p.Core; using Nethermind.Libp2p.Protocols; +using Nethermind.Libp2p; namespace DataTransferBenchmark; @@ -16,8 +16,8 @@ public NoStackPeerFactoryBuilder() : base(default) public static Libp2pPeerFactoryBuilder Create => new(); - protected override ProtocolStack BuildStack() + protected override ProtocolRef[] BuildStack(IEnumerable additionalProtocols) { - return Over(); + return [Get()]; } } diff --git a/src/samples/perf-benchmarks/PerfBenchmarks.csproj b/src/samples/perf-benchmarks/PerfBenchmarks.csproj index 7a0264b2..e82be68b 100644 --- a/src/samples/perf-benchmarks/PerfBenchmarks.csproj +++ b/src/samples/perf-benchmarks/PerfBenchmarks.csproj @@ -5,7 +5,6 @@ enable enable true - true net8.0 PerfBenchmarks diff --git a/src/samples/perf-benchmarks/PerfProtocol.cs b/src/samples/perf-benchmarks/PerfProtocol.cs index 6dc94841..a0d91493 100644 --- a/src/samples/perf-benchmarks/PerfProtocol.cs +++ b/src/samples/perf-benchmarks/PerfProtocol.cs @@ -8,7 +8,7 @@ namespace DataTransferBenchmark; // TODO: Align with perf protocol -public class PerfProtocol : IProtocol +public class PerfProtocol : ISessionProtocol { private readonly ILogger? _logger; public string Id => "/perf/1.0.0"; @@ -21,7 +21,7 @@ public PerfProtocol(ILoggerFactory? loggerFactory = null) public const long TotalLoad = 1024L * 1024 * 100; private Random rand = new(); - public async Task DialAsync(IChannel downChannel, IChannelFactory upChannelFactory, IPeerContext context) + public async Task DialAsync(IChannel downChannel, ISessionContext context) { await downChannel.WriteVarintAsync(TotalLoad); @@ -60,7 +60,7 @@ public async Task DialAsync(IChannel downChannel, IChannelFactory upChannelFacto } } - public async Task ListenAsync(IChannel downChannel, IChannelFactory upChannelFactory, IPeerContext context) + public async Task ListenAsync(IChannel downChannel, ISessionContext context) { ulong total = await downChannel.ReadVarintUlongAsync(); ulong bytesRead = 0; diff --git a/src/samples/perf-benchmarks/Program.cs b/src/samples/perf-benchmarks/Program.cs index 5ff5cc51..d2b3f7a5 100644 --- a/src/samples/perf-benchmarks/Program.cs +++ b/src/samples/perf-benchmarks/Program.cs @@ -4,9 +4,9 @@ using System.Diagnostics; using DataTransferBenchmark; using Microsoft.Extensions.DependencyInjection; -using Nethermind.Libp2p.Stack; using Nethermind.Libp2p.Core; using Multiformats.Address; +using Nethermind.Libp2p; await Task.Delay(1000); { @@ -20,11 +20,11 @@ Identity optionalFixedIdentity = new(Enumerable.Repeat((byte)42, 32).ToArray()); ILocalPeer peer = peerFactory.Create(optionalFixedIdentity); - IListener listener = await peer.ListenAsync($"/ip4/0.0.0.0/tcp/0/p2p/{peer.Identity.PeerId}"); + await peer.StartListenAsync([$"/ip4/0.0.0.0/tcp/0/p2p/{peer.Identity.PeerId}"]); - Multiaddress remoteAddr = listener.Address; + Multiaddress remoteAddr = peer.ListenAddresses.First(); ILocalPeer localPeer = peerFactory.Create(); - IRemotePeer remotePeer = await localPeer.DialAsync(remoteAddr); + ISession remotePeer = await localPeer.DialAsync(remoteAddr); Stopwatch timeSpent = Stopwatch.StartNew(); await remotePeer.DialAsync(); @@ -42,11 +42,11 @@ .Build(); ILocalPeer peer = peerFactory.Create(); - IListener listener = await peer.ListenAsync($"/ip4/0.0.0.0/tcp/0"); + await peer.StartListenAsync([$"/ip4/0.0.0.0/tcp/0"]); - Multiaddress remoteAddr = listener.Address; + Multiaddress remoteAddr = peer.ListenAddresses.First(); ILocalPeer localPeer = peerFactory.Create(); - IRemotePeer remotePeer = await localPeer.DialAsync(remoteAddr); + ISession remotePeer = await localPeer.DialAsync(remoteAddr); Stopwatch timeSpent = Stopwatch.StartNew(); await remotePeer.DialAsync(); diff --git a/src/samples/pubsub-chat/Program.cs b/src/samples/pubsub-chat/Program.cs index e9db6b03..89e67d1c 100644 --- a/src/samples/pubsub-chat/Program.cs +++ b/src/samples/pubsub-chat/Program.cs @@ -2,24 +2,25 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Nethermind.Libp2p.Stack; using Nethermind.Libp2p.Core; using System.Text; using System.Text.Json; using Nethermind.Libp2p.Protocols.Pubsub; -using Multiformats.Address.Protocols; -using Multiformats.Address; using Nethermind.Libp2p.Protocols; +using System.Text.RegularExpressions; +using Nethermind.Libp2p; + +Regex omittedLogs = new(".*(MDnsDiscoveryProtocol|IpTcpProtocol).*"); ServiceProvider serviceProvider = new ServiceCollection() - .AddLibp2p(builder => builder) + .AddLibp2p(builder => builder.WithPubsub()) .AddLogging(builder => builder.SetMinimumLevel(args.Contains("--trace") ? LogLevel.Trace : LogLevel.Information) .AddSimpleConsole(l => { l.SingleLine = true; - l.TimestampFormat = "[HH:mm:ss.FFF]"; - })) + l.TimestampFormat = "[HH:mm:ss.fff]"; + }).AddFilter((_, type, lvl) => !omittedLogs.IsMatch(type!))) .BuildServiceProvider(); IPeerFactory peerFactory = serviceProvider.GetService()!; @@ -30,7 +31,7 @@ Identity localPeerIdentity = new(); string addr = $"/ip4/0.0.0.0/tcp/0/p2p/{localPeerIdentity.PeerId}"; -ILocalPeer peer = peerFactory.Create(localPeerIdentity, Multiaddress.Decode(addr)); +ILocalPeer peer = peerFactory.Create(localPeerIdentity); PubsubRouter router = serviceProvider.GetService()!; ITopic topic = router.GetTopic("chat-room:awesome-chat-room"); @@ -51,13 +52,12 @@ } }; -await peer.ListenAsync(addr, ts.Token); - -_ = serviceProvider.GetService()!.DiscoverAsync(peer.Address, token: ts.Token); +await peer.StartListenAsync([addr], ts.Token); -_ = router.RunAsync(peer, token: ts.Token); +string peerId = peer.Identity.PeerId.ToString(); +_ = serviceProvider.GetService()!.StartDiscoveryAsync(peer.ListenAddresses, token: ts.Token); -string peerId = peer.Address.Get().ToString(); +await router.StartAsync(peer, token: ts.Token); string nickName = "libp2p-dotnet"; diff --git a/src/samples/pubsub-chat/PubsubChat.csproj b/src/samples/pubsub-chat/PubsubChat.csproj index 84c66c97..c2f86896 100644 --- a/src/samples/pubsub-chat/PubsubChat.csproj +++ b/src/samples/pubsub-chat/PubsubChat.csproj @@ -4,7 +4,6 @@ Exe enable enable - true net8.0 PubsubChat diff --git a/src/samples/transport-interop/Program.cs b/src/samples/transport-interop/Program.cs index a477a74d..62cc206d 100644 --- a/src/samples/transport-interop/Program.cs +++ b/src/samples/transport-interop/Program.cs @@ -3,6 +3,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Multiformats.Address; using Nethermind.Libp2p.Core; using Nethermind.Libp2p.Protocols; using StackExchange.Redis; @@ -24,7 +25,7 @@ int testTimeoutSeconds = int.Parse(Environment.GetEnvironmentVariable("test_timeout_seconds") ?? "180"); - TestPlansPeerFactoryBuilder builder = new TestPlansPeerFactoryBuilder(transport, muxer, security); + TestPlansPeerFactoryBuilder builder = new(transport, muxer, security); IPeerFactory peerFactory = builder.Build(); Log($"Connecting to redis at {redisAddr}..."); @@ -33,7 +34,7 @@ if (isDialer) { - ILocalPeer localPeer = peerFactory.Create(localAddr: builder.MakeAddress()); + ILocalPeer localPeer = peerFactory.Create(); Log($"Picking an address to dial..."); @@ -46,7 +47,7 @@ Log($"Dialing {listenerAddr}..."); Stopwatch handshakeStartInstant = Stopwatch.StartNew(); - IRemotePeer remotePeer = await localPeer.DialAsync(listenerAddr); + ISession remotePeer = await localPeer.DialAsync((Multiaddress)listenerAddr); Stopwatch pingIstant = Stopwatch.StartNew(); await remotePeer.DialAsync(); @@ -62,7 +63,7 @@ { if (ip == "0.0.0.0") { - var d = NetworkInterface.GetAllNetworkInterfaces()! + List d = NetworkInterface.GetAllNetworkInterfaces()! .Where(i => i.Name == "eth0" || (i.OperationalStatus == OperationalStatus.Up && i.NetworkInterfaceType == NetworkInterfaceType.Ethernet)).ToList(); @@ -81,13 +82,15 @@ ip = addresses.First().Address.ToString()!; } Log("Starting to listen..."); - ILocalPeer localPeer = peerFactory.Create(localAddr: builder.MakeAddress(ip)); - IListener listener = await localPeer.ListenAsync(localPeer.Address); - listener.OnConnection += (peer) => { Log($"Connected {peer.Address}"); return Task.CompletedTask; }; - Log($"Listening on {listener.Address}"); - db.ListRightPush(new RedisKey("listenerAddr"), new RedisValue(listener.Address.ToString())); + ILocalPeer localPeer = peerFactory.Create(); + + CancellationTokenSource listennTcs = new(); + await localPeer.StartListenAsync([builder.MakeAddress(ip)], listennTcs.Token); + localPeer.OnConnected += (session) => { Log($"Connected {session.RemoteAddress}"); return Task.CompletedTask; }; + Log($"Listening on {string.Join(", ", localPeer.ListenAddresses)}"); + db.ListRightPush(new RedisKey("listenerAddr"), new RedisValue(localPeer.ListenAddresses.First().ToString())); await Task.Delay(testTimeoutSeconds * 1000); - await listener.DisconnectAsync(); + await listennTcs.CancelAsync(); return -1; } } @@ -127,36 +130,39 @@ public TestPlansPeerFactoryBuilder(string transport, string? muxer, string? secu private static readonly string[] stacklessProtocols = ["quic", "quic-v1", "webtransport"]; - protected override ProtocolStack BuildStack() + protected override ProtocolRef[] BuildStack(IEnumerable additionalProtocols) { - ProtocolStack stack = transport switch + ProtocolRef[] transportStack = [transport switch { - "tcp" => Over(), + "tcp" => Get(), // TODO: Improve QUIC imnteroperability - "quic-v1" => Over(), + "quic-v1" => Get(), _ => throw new NotImplementedException(), - }; + }]; - stack = stack.Over(); + ProtocolRef[] selector = [Get()]; + Connect(transportStack, selector); if (!stacklessProtocols.Contains(transport)) { - stack = security switch + ProtocolRef[] securityStack = [security switch { - "noise" => stack.Over(), + "noise" => Get(), _ => throw new NotImplementedException(), - }; - stack = stack.Over(); - stack = muxer switch + }]; + ProtocolRef[] muxerStack = [muxer switch { - "yamux" => stack.Over(), + "yamux" => Get(), _ => throw new NotImplementedException(), - }; - stack = stack.Over(); + }]; + + selector = Connect(selector, transportStack, [Get()], muxerStack, [Get()]); } - return stack.AddAppLayerProtocol() - .AddAppLayerProtocol(); + ProtocolRef[] apps = [Get(), Get()]; + Connect(selector, apps); + + return transportStack; } public string MakeAddress(string ip = "0.0.0.0", string port = "0") => transport switch diff --git a/src/samples/transport-interop/TransportInterop.csproj b/src/samples/transport-interop/TransportInterop.csproj index 42309408..13deab2a 100644 --- a/src/samples/transport-interop/TransportInterop.csproj +++ b/src/samples/transport-interop/TransportInterop.csproj @@ -8,7 +8,6 @@ enable true true - true diff --git a/src/samples/transport-interop/packages.lock.json b/src/samples/transport-interop/packages.lock.json index 2ee01cdd..880a92b3 100644 --- a/src/samples/transport-interop/packages.lock.json +++ b/src/samples/transport-interop/packages.lock.json @@ -176,8 +176,8 @@ }, "Nethermind.Multiformats.Address": { "type": "Transitive", - "resolved": "1.1.5", - "contentHash": "wm3ooKVG2w0jIuqtHXUPWMck1gQ/DxIFB3RAqxsPIhJesm+dSOUmACkJ4t3GL+VhxtHlYdDVbIKimuKe83ZCGQ==", + "resolved": "1.1.8", + "contentHash": "+nRuuVXjj/Okj/RAJtJUZ/nDRjwMfjJnF1+4Z7gKX2MjMsxR92KPJbsP4fY4IaEwXMDjNJsXJu78z2C06tElzw==", "dependencies": { "BinaryEncoding": "1.4.0", "Nethermind.Multiformats.Base": "2.0.3-preview.1", @@ -1200,7 +1200,7 @@ "Microsoft.Extensions.DependencyInjection": "[8.0.0, )", "Microsoft.Extensions.DependencyInjection.Abstractions": "[8.0.0, )", "Microsoft.Extensions.Logging.Abstractions": "[8.0.0, )", - "Nethermind.Multiformats.Address": "[1.1.5, )", + "Nethermind.Multiformats.Address": "[1.1.8, )", "SimpleBase": "[4.0.0, )" } },