Skip to content

Commit

Permalink
[core] Refactor PeerID out of EncryptorFactory
Browse files Browse the repository at this point in the history
This is one fewer places where the catch-all 'PeerId' is not
actually needed. This simplifies the testing of this section
of code as we don't need a full ClientEngine/TorrentManager
etc.
  • Loading branch information
alanmcgovern committed Aug 15, 2019
1 parent 7f98fa0 commit e0cbf35
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 75 deletions.
6 changes: 3 additions & 3 deletions src/MonoTorrent.Tests/Client/MetadataModeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ public async Task Setup(bool metadataMode, string metadataPath)
id.Connection = connection;


byte[] data = EncryptorFactory.CheckEncryptionAsync(id, 68, new InfoHash[] { id.TorrentManager.InfoHash }).Result;
decryptor = id.Decryptor;
encryptor = id.Encryptor;
var result = await EncryptorFactory.CheckIncomingConnectionAsync(id.Connection, id.Peer.Encryption, rig.Engine.Settings, HandshakeMessage.HandshakeLength, new InfoHash[] { id.TorrentManager.InfoHash });
decryptor = id.Decryptor = result.Decryptor;
encryptor = id.Encryptor = result.Encryptor;
}

[TearDown]
Expand Down
4 changes: 2 additions & 2 deletions src/MonoTorrent.Tests/Client/PeerMessagesTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public void HandshakeEncoding()
Assert.IsTrue(Toolbox.ByteMatch(new byte[8], 0, buffer, offset + 20, 8), "3");
Assert.IsTrue(Toolbox.ByteMatch(infohash, 0, buffer, offset + 28, 20), "4");
Assert.IsTrue(Toolbox.ByteMatch(peerId, 0, buffer, offset + 48, 20), "5");
Assert.AreEqual(length, 68, "6");
Assert.AreEqual(length, HandshakeMessage.HandshakeLength, "6");

length = new HandshakeMessage(new InfoHash (infohash), "12312312345645645678", VersionInfo.ProtocolStringV100, true, false).Encode(buffer, offset);
Assert.AreEqual(BitConverter.ToString(buffer, offset, length), "13-42-69-74-54-6F-72-72-65-6E-74-20-70-72-6F-74-6F-63-6F-6C-00-00-00-00-00-00-00-04-01-02-03-04-05-06-07-08-09-0A-0B-0C-0D-0E-0F-00-0C-0F-0C-34-31-32-33-31-32-33-31-32-33-34-35-36-34-35-36-34-35-36-37-38", "#7");
Expand All @@ -175,7 +175,7 @@ public void HandshakeDecoding()
HandshakeMessage orig = new HandshakeMessage(new InfoHash (infohash), "12312312345645645678", VersionInfo.ProtocolStringV100);
orig.Encode(buffer, offset);
HandshakeMessage dec = new HandshakeMessage();
dec.Decode(buffer, offset, 68);
dec.Decode(buffer, offset, HandshakeMessage.HandshakeLength);
Assert.IsTrue(orig.Equals(dec));
Assert.AreEqual(orig.Encode(), dec.Encode());
}
Expand Down
4 changes: 2 additions & 2 deletions src/MonoTorrent.Tests/Client/TestEncryption.cs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async Task PeerATest(EncryptionTypes encryption, bool addInitial)
}

int received = await conn.Outgoing.ReceiveAsync(buffer, 0, buffer.Length);
Assert.AreEqual (68, received, "Recived handshake");
Assert.AreEqual (HandshakeMessage.HandshakeLength, received, "Recived handshake");

a.Decryptor.Decrypt(buffer);
message.Decode(buffer, 0, buffer.Length);
Expand All @@ -263,7 +263,7 @@ async Task PeerBTest(EncryptionTypes encryption)
Assert.Fail("Handshake timed out");

HandshakeMessage message = new HandshakeMessage();
byte[] buffer = new byte[68];
byte[] buffer = new byte[HandshakeMessage.HandshakeLength];

await conn.Incoming.ReceiveAsync(buffer, 0, buffer.Length);

Expand Down
21 changes: 11 additions & 10 deletions src/MonoTorrent.Tests/Client/TransferTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,15 @@ public async Task NegativeData()

public async Task InitiateTransfer(CustomConnection connection, EncryptionTypes allowedEncryption)
{
PeerId id = new PeerId(new Peer("", connection.Uri), rig.Manager);
id.Peer.Encryption = allowedEncryption;
id.Connection = connection;

var data = await EncryptorFactory.CheckEncryptionAsync(id, 68, new InfoHash[] { id.TorrentManager.InfoHash });
decryptor = id.Decryptor;
encryptor = id.Encryptor;
TestHandshake(data, connection);
EncryptorFactory.EncryptorResult result;
if (connection.IsIncoming) {
result = await EncryptorFactory.CheckIncomingConnectionAsync(connection, allowedEncryption, rig.Engine.Settings, HandshakeMessage.HandshakeLength, new [] { rig.Manager.InfoHash });
} else {
result = await EncryptorFactory.CheckOutgoingConnectionAsync(connection, allowedEncryption, rig.Engine.Settings, rig.Manager.InfoHash);
}
decryptor = result.Decryptor;
encryptor = result.Encryptor;
TestHandshake(result.InitialData, connection);
}

public void TestHandshake(byte[] buffer, CustomConnection connection)
Expand All @@ -298,8 +299,8 @@ public void TestHandshake(byte[] buffer, CustomConnection connection)
// 2) Receive remote handshake
if (buffer == null || buffer.Length == 0)
{
buffer = new byte[68];
Receive (connection, buffer, 0, 68);
buffer = new byte[HandshakeMessage.HandshakeLength];
Receive (connection, buffer, 0, HandshakeMessage.HandshakeLength);
decryptor.Decrypt(buffer);
}

Expand Down
130 changes: 80 additions & 50 deletions src/MonoTorrent/MonoTorrent.Client.Encryption/EncryptorFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@


using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -38,74 +39,103 @@ namespace MonoTorrent.Client.Encryption
{
static class EncryptorFactory
{
internal static async Task<byte[]> CheckEncryptionAsync(PeerId id, int bytesToReceive, InfoHash[] sKeys)
internal struct EncryptorResult
{
using (var cts = new CancellationTokenSource (TimeSpan.FromSeconds (1000)))
using (var registration = cts.Token.Register (id.Connection.Dispose))
return await CheckEncryptionAsync (id, bytesToReceive, sKeys, cts.Token);
public IEncryption Decryptor { get; }
public IEncryption Encryptor { get; }
public byte [] InitialData { get; }

public EncryptorResult (IEncryption decryptor, IEncryption encryptor, byte [] data)
{
Decryptor = decryptor;
Encryptor = encryptor;
InitialData = data;
}
}

static TimeSpan Timeout => Debugger.IsAttached ? TimeSpan.FromHours (1) : TimeSpan.FromSeconds (10);

internal static async Task<EncryptorResult> CheckIncomingConnectionAsync (IConnection connection, EncryptionTypes encryption, EngineSettings settings, int bytesToReceive, InfoHash[] sKeys)
{
if (!connection.IsIncoming)
throw new Exception ("oops");

using (var cts = new CancellationTokenSource (Timeout))
using (var registration = cts.Token.Register (connection.Dispose))
return await DoCheckIncomingConnectionAsync (connection, encryption, settings, bytesToReceive, sKeys);
}

static async Task<byte[]> CheckEncryptionAsync(PeerId id, int bytesToReceive, InfoHash[] sKeys, CancellationToken token)
static async Task<EncryptorResult> DoCheckIncomingConnectionAsync(IConnection connection, EncryptionTypes encryption, EngineSettings settings, int bytesToReceive, InfoHash[] sKeys)
{
IConnection connection = id.Connection;
var allowedEncryption = (id.Engine?.Settings.AllowedEncryption ?? EncryptionTypes.All) & id.Peer.Encryption;
var allowedEncryption = (settings?.AllowedEncryption ?? EncryptionTypes.All) & encryption;
var supportsRC4Header = allowedEncryption.HasFlag (EncryptionTypes.RC4Header);
var supportsRC4Full = allowedEncryption.HasFlag (EncryptionTypes.RC4Full);
var supportsPlainText = allowedEncryption.HasFlag (EncryptionTypes.PlainText);

// If the connection is incoming, receive the handshake before
// trying to decide what encryption to use
if (connection.IsIncoming)
{
var buffer = new byte[bytesToReceive];
await NetworkIO.ReceiveAsync(connection, buffer, 0, bytesToReceive, null, null, null).ConfigureAwait (false);

HandshakeMessage message = new HandshakeMessage();
message.Decode(buffer, 0, buffer.Length);
var buffer = new byte[bytesToReceive];
await NetworkIO.ReceiveAsync(connection, buffer, 0, bytesToReceive, null, null, null).ConfigureAwait (false);

if (message.ProtocolString == VersionInfo.ProtocolStringV100) {
if (supportsPlainText) {
id.Encryptor = id.Decryptor = PlainTextEncryption.Instance;
return buffer;
}
}
else if (supportsRC4Header || supportsRC4Full)
{
// The data we just received was part of an encrypted handshake and was *not* the BitTorrent handshake
var encSocket = new PeerBEncryption(sKeys, EncryptionTypes.All);
await encSocket.HandshakeAsync(connection, buffer, 0, buffer.Length);
if (encSocket.Decryptor is RC4Header && !supportsRC4Header)
throw new EncryptionException("Decryptor was RC4Header but that is not allowed");
if (encSocket.Decryptor is RC4 && !supportsRC4Full)
throw new EncryptionException("Decryptor was RC4Full but that is not allowed");

id.Decryptor = encSocket.Decryptor;
id.Encryptor = encSocket.Encryptor;
return encSocket.InitialData?.Length > 0 ? encSocket.InitialData : null;
HandshakeMessage message = new HandshakeMessage();
message.Decode(buffer, 0, buffer.Length);

if (message.ProtocolString == VersionInfo.ProtocolStringV100) {
if (supportsPlainText) {
return new EncryptorResult (PlainTextEncryption.Instance, PlainTextEncryption.Instance, buffer);
}
}
else
else if (supportsRC4Header || supportsRC4Full)
{
if ((id.Engine.Settings.PreferEncryption || !supportsPlainText) && (supportsRC4Header || supportsRC4Full)) {
var encSocket = new PeerAEncryption(id.TorrentManager.InfoHash, allowedEncryption);
await encSocket.HandshakeAsync(connection);
if (encSocket.Decryptor is RC4Header && !supportsRC4Header)
throw new EncryptionException("Decryptor was RC4Header but that is not allowed");
if (encSocket.Decryptor is RC4 && !supportsRC4Full)
throw new EncryptionException("Decryptor was RC4Full but that is not allowed");

id.Decryptor = encSocket.Decryptor;
id.Encryptor = encSocket.Encryptor;
return encSocket.InitialData?.Length > 0 ? encSocket.InitialData : null;
}
else if (supportsPlainText)
{
id.Encryptor = id.Decryptor = PlainTextEncryption.Instance;
return null;
}
// The data we just received was part of an encrypted handshake and was *not* the BitTorrent handshake
var encSocket = new PeerBEncryption(sKeys, EncryptionTypes.All);
await encSocket.HandshakeAsync(connection, buffer, 0, buffer.Length);
if (encSocket.Decryptor is RC4Header && !supportsRC4Header)
throw new EncryptionException("Decryptor was RC4Header but that is not allowed");
if (encSocket.Decryptor is RC4 && !supportsRC4Full)
throw new EncryptionException("Decryptor was RC4Full but that is not allowed");

var data = encSocket.InitialData?.Length > 0 ? encSocket.InitialData : null;
return new EncryptorResult (encSocket.Decryptor, encSocket.Encryptor, data);
}

throw new EncryptionException("Invalid handshake received and no decryption works");
}

internal static async Task<EncryptorResult> CheckOutgoingConnectionAsync(IConnection connection, EncryptionTypes encryption, EngineSettings settings, InfoHash infoHash)
{
if (connection.IsIncoming)
throw new Exception ("oops");

using (var cts = new CancellationTokenSource (Timeout))
using (var registration = cts.Token.Register (connection.Dispose))
return await DoCheckOutgoingConnectionAsync (connection, encryption, settings, infoHash);
}

static async Task<EncryptorResult> DoCheckOutgoingConnectionAsync(IConnection connection, EncryptionTypes encryption, EngineSettings settings, InfoHash infoHash)
{
var allowedEncryption = settings.AllowedEncryption & encryption;
var supportsRC4Header = allowedEncryption.HasFlag (EncryptionTypes.RC4Header);
var supportsRC4Full = allowedEncryption.HasFlag (EncryptionTypes.RC4Full);
var supportsPlainText = allowedEncryption.HasFlag (EncryptionTypes.PlainText);

if ((settings.PreferEncryption || !supportsPlainText) && (supportsRC4Header || supportsRC4Full)) {
var encSocket = new PeerAEncryption(infoHash, allowedEncryption);
await encSocket.HandshakeAsync(connection);
if (encSocket.Decryptor is RC4Header && !supportsRC4Header)
throw new EncryptionException("Decryptor was RC4Header but that is not allowed");
if (encSocket.Decryptor is RC4 && !supportsRC4Full)
throw new EncryptionException("Decryptor was RC4Full but that is not allowed");

var data = encSocket.InitialData?.Length > 0 ? encSocket.InitialData : null;
return new EncryptorResult (encSocket.Decryptor, encSocket.Encryptor, data);
}
else if (supportsPlainText)
{
return new EncryptorResult (PlainTextEncryption.Instance, PlainTextEncryption.Instance, null);
}
throw new EncryptionException("Invalid handshake received and no decryption works");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class HandshakeMessage : PeerMessage

public override int ByteLength
{
get { return 68; }
get { return HandshakeMessage.HandshakeLength; }
}

/// <summary>
Expand Down
12 changes: 6 additions & 6 deletions src/MonoTorrent/MonoTorrent.Client/Managers/ConnectionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,11 @@ internal async void ProcessFreshConnection(PeerId id)
try
{
// Increase the count of the "open" connections
var initialData = await EncryptorFactory.CheckEncryptionAsync (id, 0, new[] { id.TorrentManager.InfoHash });
await EndCheckEncryption(id, initialData);
var result = await EncryptorFactory.CheckOutgoingConnectionAsync (id.Connection, id.Peer.Encryption, engine.Settings, id.TorrentManager.InfoHash);
id.Decryptor = result.Decryptor;
id.Encryptor = result.Encryptor;

await EndCheckEncryption(id);

id.WhenConnected.Restart ();
// Baseline the time the last block was received
Expand All @@ -225,13 +228,10 @@ internal async void ProcessFreshConnection(PeerId id)
}
}

private async Task EndCheckEncryption(PeerId id, byte[] initialData)
private async Task EndCheckEncryption(PeerId id)
{
try
{
if (initialData != null && initialData.Length > 0)
throw new EncryptionException("unhandled initial data");

EncryptionTypes e = engine.Settings.AllowedEncryption;
if (id.Encryptor is RC4 && !e.HasFlag (EncryptionTypes.RC4Full) ||
id.Encryptor is RC4Header && !e.HasFlag (EncryptionTypes.RC4Header) ||
Expand Down
5 changes: 4 additions & 1 deletion src/MonoTorrent/MonoTorrent.Client/Managers/ListenManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ private async void ConnectionReceived(object sender, NewConnectionEventArgs e)
for (int i = 0; i < Engine.Torrents.Count; i++)
skeys.Add(Engine.Torrents[i].InfoHash);

var initialData = await EncryptorFactory.CheckEncryptionAsync(id, HandshakeMessage.HandshakeLength, skeys.ToArray());
var result = await EncryptorFactory.CheckIncomingConnectionAsync(id.Connection, id.Peer.Encryption, Engine.Settings, HandshakeMessage.HandshakeLength, skeys.ToArray());
id.Decryptor = result.Decryptor;
id.Encryptor = result.Encryptor;
var initialData = result.InitialData;
if (initialData != null && initialData.Length != HandshakeMessage.HandshakeLength)
{
e.Connection.Dispose();
Expand Down

0 comments on commit e0cbf35

Please sign in to comment.