-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Abstract away socket logic and make able to handle both named pipes f…
…or Win10 SSH and AF_UNIX sockets for WSL
- Loading branch information
Showing
9 changed files
with
369 additions
and
138 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,54 @@ | ||
using System; | ||
using System.IO; | ||
using System.Net.Sockets; | ||
using System.Threading; | ||
using Microsoft.Extensions.CommandLineUtils; | ||
using System.Collections.Generic; | ||
using System.Threading.Tasks; | ||
|
||
namespace WslSSHPageant | ||
{ | ||
class Program | ||
{ | ||
static Mutex mutex; | ||
|
||
static async Task Main(string[] args) | ||
static void Main(string[] args) | ||
{ | ||
var socketPath = @".\ssh-agent.sock"; | ||
CommandLineApplication commandLineApplication = new CommandLineApplication(throwOnUnexpectedArg: false); | ||
|
||
if (args.Length == 1) | ||
{ | ||
socketPath = args[0]; | ||
} | ||
else if (args.Length != 0) | ||
{ | ||
Console.WriteLine(@"wsl-ssh-agent.exe <path: .\ssh-agent.sock>"); | ||
return; | ||
} | ||
CommandOption wslSocketPath = commandLineApplication.Option( | ||
"--wsl <path>", | ||
"Which path to listen on with the AF_UNIX socket for WSL", | ||
CommandOptionType.SingleValue); | ||
|
||
socketPath = Path.GetFullPath(socketPath); | ||
CommandOption winsshPipeName = commandLineApplication.Option( | ||
"--winssh <name>", | ||
"Which pipe to listen on for Windows 10 OpenSSH Client", | ||
CommandOptionType.SingleValue); | ||
|
||
var mutexName = socketPath + "-{642b3e23-f0f5-4cc1-8a41-bf95e9a438ad}"; | ||
mutexName = mutexName.Replace(Path.DirectorySeparatorChar, '_'); | ||
mutex = new Mutex(true, mutexName); | ||
commandLineApplication.HelpOption("-? | -h | --help"); | ||
|
||
if (!mutex.WaitOne(TimeSpan.Zero, true)) | ||
{ | ||
Console.Error.WriteLine("Already running on that AF_UNIX path"); | ||
Console.In.ReadLine(); | ||
return; | ||
} | ||
List<Task> runningServers = new List<Task>(); | ||
|
||
try | ||
commandLineApplication.OnExecute(() => | ||
{ | ||
File.Delete(socketPath); | ||
var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.IP); | ||
server.Bind(new UnixEndPoint(socketPath)); | ||
server.Listen(5); | ||
|
||
Console.WriteLine(@"Listening on {0}", socketPath); | ||
|
||
// Enter the listening loop. | ||
while (true) | ||
if (wslSocketPath.HasValue()) | ||
{ | ||
WSLSocket wslSocket = new WSLSocket(wslSocketPath.Value()); | ||
runningServers.Add(wslSocket.Listen()); | ||
} | ||
if (winsshPipeName.HasValue()) | ||
{ | ||
WSLClient client = new WSLClient(await server.AcceptAsync()); | ||
WinSSHSocket winsshSocket = new WinSSHSocket(winsshPipeName.Value()); | ||
runningServers.Add(winsshSocket.Listen()); | ||
} | ||
|
||
// Don't await this, we want to service other sockets | ||
#pragma warning disable CS4014 | ||
client.WorkSocket(); | ||
#pragma warning restore CS4014 | ||
if (runningServers.Count < 1) | ||
{ | ||
commandLineApplication.ShowHelp(); | ||
return 1; | ||
} | ||
} | ||
finally | ||
{ | ||
mutex.ReleaseMutex(); | ||
} | ||
|
||
Task.WaitAny(runningServers.ToArray()); | ||
|
||
return 0; | ||
}); | ||
|
||
commandLineApplication.Execute(args); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
using System; | ||
using System.Net.Sockets; | ||
using System.Threading.Tasks; | ||
|
||
namespace WslSSHPageant | ||
{ | ||
internal abstract class SSHAgentClientPartialRead : SSHAgentClient | ||
{ | ||
internal SSHAgentClientPartialRead() | ||
{ | ||
} | ||
|
||
protected override async Task<bool> ReceiveArraySegment(ArraySegment<byte> buf) | ||
{ | ||
int i; | ||
while ((i = await ReceivePartialArraySegment(buf)) != 0) | ||
{ | ||
buf = buf.Slice(i); | ||
if (buf.Count <= 0) | ||
{ | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
protected abstract Task<int> ReceivePartialArraySegment(ArraySegment<byte> buf); | ||
} | ||
|
||
internal abstract class SSHAgentClient | ||
{ | ||
internal SSHAgentClient() | ||
{ | ||
} | ||
|
||
protected virtual void Initialize() | ||
{ | ||
return; | ||
} | ||
|
||
protected abstract bool IsConnected(); | ||
|
||
protected abstract void Close(); | ||
|
||
protected abstract Task<bool> ReceiveArraySegment(ArraySegment<byte> buf); | ||
|
||
protected abstract Task<bool> SendArraySegment(ArraySegment<byte> buf); | ||
|
||
internal async Task WorkSocket() | ||
{ | ||
Initialize(); | ||
|
||
bool clientWasSuccess = false; | ||
|
||
try | ||
{ | ||
clientWasSuccess = await ServiceSocket(); | ||
} | ||
catch (TimeoutException) | ||
{ | ||
// Ignore timeouts, those should not explode our stuff | ||
Console.Error.WriteLine("Socket timeout"); | ||
} | ||
// These two just mean the remote end closed the socket, we don't care, same for TaskCanceledException | ||
catch (ObjectDisposedException) { } | ||
catch (InvalidOperationException) { } | ||
catch (TaskCanceledException) { } | ||
catch (SocketException e) | ||
{ | ||
// Other socket errors can happen and shouldn't kill the app | ||
Console.Error.WriteLine(e); | ||
} | ||
catch (PageantException e) | ||
{ | ||
// Pageant errors can happen, too | ||
Console.Error.WriteLine(e); | ||
} | ||
catch (Exception e) | ||
{ | ||
Console.Error.WriteLine(e); | ||
throw e; | ||
} | ||
finally | ||
{ | ||
if (IsConnected() && !clientWasSuccess) | ||
{ | ||
try | ||
{ | ||
await SendArraySegment(PageantHandler.AGENT_EMPTY_RESPONSE); | ||
} | ||
catch { } | ||
} | ||
|
||
Close(); | ||
} | ||
} | ||
|
||
private async Task<bool> ServiceSocket() | ||
{ | ||
var bytes = new byte[PageantHandler.AGENT_MAX_MSGLEN]; | ||
|
||
bool lastWasSuccess = true; | ||
|
||
while (IsConnected()) | ||
{ | ||
// Read length as uint32 (4 bytes) | ||
if (!await ReceiveArraySegment(new ArraySegment<byte>(bytes, 0, 4))) | ||
{ | ||
break; | ||
} | ||
|
||
lastWasSuccess = false; | ||
|
||
var len = (bytes[0] << 24) | | ||
(bytes[1] << 16) | | ||
(bytes[2] << 8) | | ||
(bytes[3]); | ||
|
||
if (len + 4 > PageantHandler.AGENT_MAX_MSGLEN) | ||
{ | ||
break; | ||
} | ||
|
||
// Read actual data in the part after len | ||
if (!await ReceiveArraySegment(new ArraySegment<byte>(bytes, 4, len))) | ||
{ | ||
break; | ||
} | ||
|
||
var msg = PageantHandler.Query(new ArraySegment<byte>(bytes, 0, len + 4)); | ||
await SendArraySegment(new ArraySegment<byte>(msg, 0, msg.Length)); | ||
lastWasSuccess = true; | ||
} | ||
|
||
return lastWasSuccess; | ||
} | ||
} | ||
} |
Oops, something went wrong.