Skip to content

Commit

Permalink
fix intermittent OperationCancelledException on closing websocket con…
Browse files Browse the repository at this point in the history
…nection (#177)

* fix build.

* fix intermittent OperationCancelledException on closing websocket connection.

* always close connection after RunAsync has finished.

* Make sure token is not created multiple times as ObjectDisposedException could throw.

* added ConnectionId to log messages.

* do not forcefully close connection unless exception is caught.

* capture any exception on CloseAsync.

* Some corrections to OWIN/System.WebSockets implementation

Code format

Removed IsConnected property

Code formatting

Use CancellationToken.None

Fixing if

Checking for multiple disposes in AsyncWampConnection, checking for cancellation requests in WebSocketWrapperConnection
  • Loading branch information
bigbearzhu authored and darkl committed May 8, 2017
1 parent 13c9e63 commit 2bee21b
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 58 deletions.
40 changes: 34 additions & 6 deletions src/net45/Extensions/WampSharp.Owin/Owin/OwinWebSocketWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ internal class OwinWebSocketWrapper : IWebSocketWrapper
private readonly Func<ArraySegment<byte>, int, bool, CancellationToken, Task> mSendAsync;
private readonly Func<ArraySegment<byte>, CancellationToken, Task<Tuple<int, bool, int>>> mReceiveAsync;
private readonly Func<int, string, CancellationToken, Task> mCloseAsync;
private WebSocketState mState = WebSocketState.Open;

private const string WebSocketSendAsync = "websocket.SendAsync";
private const string WebSocketReceiveAsync = "websocket.ReceiveAsync";
Expand Down Expand Up @@ -95,7 +96,29 @@ public async Task<WebSocketReceiveResult> ReceiveAsync
await mReceiveAsync(arraySegment, callCancelled)
.ConfigureAwait(false);

return new WebSocketReceiveResult(count: result.Item3, messageType: GetMessageType(result.Item1), endOfMessage: result.Item2);
WebSocketMessageType webSocketMessageType = GetMessageType(result.Item1);

if (webSocketMessageType == WebSocketMessageType.Close)
{
ChangeState(actionDone: WebSocketState.CloseReceived,
dualAction: WebSocketState.CloseSent);

return new WebSocketReceiveResult(count: result.Item3,
messageType: webSocketMessageType,
endOfMessage: result.Item2,
closeStatus: this.ClientCloseStatus,
closeStatusDescription: WebSocketClientCloseDescription);
}

return new WebSocketReceiveResult(count: result.Item3, messageType: webSocketMessageType, endOfMessage: result.Item2);
}

public WebSocketState State
{
get
{
return mState;
}
}

public Task SendAsync(ArraySegment<byte> data, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancel)
Expand All @@ -105,16 +128,21 @@ public Task SendAsync(ArraySegment<byte> data, WebSocketMessageType messageType,

public Task CloseAsync(WebSocketCloseStatus closeStatus, string closeDescription, CancellationToken cancel)
{
ChangeState(actionDone: WebSocketState.CloseSent,
dualAction: WebSocketState.CloseReceived);

return mCloseAsync((int) closeStatus, closeDescription, cancel);
}

public bool IsConnected
private void ChangeState(WebSocketState actionDone, WebSocketState dualAction)
{
get
if (State == WebSocketState.Open)
{
WebSocketCloseStatus? closeStatus = ClientCloseStatus;

return ((closeStatus == null) || (closeStatus == 0));
mState = actionDone;
}
else if (mState == dualAction)
{
mState = WebSocketState.Closed;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
{
{
"frameworks": {
"net45": {}
},
"runtimes": {
"win": {}
},
"dependencies": {
"Microsoft.Owin": "3.0.1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Threading.Tasks;
using WampSharp.Core.Listener;
using WampSharp.Core.Message;
using WampSharp.Logging;
using WampSharp.V2.Authentication;
using WampSharp.V2.Binding.Parsers;

Expand All @@ -16,15 +17,17 @@ public abstract class WebSocketWrapperConnection<TMessage> : AsyncWebSocketWampC
{
private readonly IWampStreamingMessageParser<TMessage> mParser;
private readonly IWebSocketWrapper mWebSocket;
private readonly CancellationTokenSource mCancellationTokenSource;
private CancellationTokenSource mCancellationTokenSource;
private readonly Uri mAddressUri;
private CancellationToken mCancellationToken;

public WebSocketWrapperConnection(IWebSocketWrapper webSocket, IWampStreamingMessageParser<TMessage> parser, ICookieProvider cookieProvider, ICookieAuthenticatorFactory cookieAuthenticatorFactory) :
base(cookieProvider, cookieAuthenticatorFactory)
{
mWebSocket = webSocket;
mParser = parser;
mCancellationTokenSource = new CancellationTokenSource();
mCancellationToken = mCancellationTokenSource.Token;
}

protected WebSocketWrapperConnection(IClientWebSocketWrapper clientWebSocket, Uri addressUri, string protocolName, IWampStreamingMessageParser<TMessage> parser) :
Expand All @@ -36,8 +39,9 @@ protected WebSocketWrapperConnection(IClientWebSocketWrapper clientWebSocket, Ur

protected override Task SendAsync(WampMessage<object> message)
{
mLogger.Debug("Attempting to send a message");
ArraySegment<byte> messageToSend = GetMessageInBytes(message);
return mWebSocket.SendAsync(messageToSend, WebSocketMessageType, true, mCancellationTokenSource.Token);
return mWebSocket.SendAsync(messageToSend, WebSocketMessageType, true, mCancellationToken);
}

protected abstract ArraySegment<byte> GetMessageInBytes(WampMessage<object> message);
Expand All @@ -48,12 +52,12 @@ protected async void Connect()
{
try
{
await this.ClientWebSocket.ConnectAsync(mAddressUri, mCancellationTokenSource.Token)
await this.ClientWebSocket.ConnectAsync(mAddressUri, mCancellationToken)
.ConfigureAwait(false);

RaiseConnectionOpen();

Task task = Task.Run((Func<Task>) this.RunAsync, mCancellationTokenSource.Token);
Task task = Task.Run((Func<Task>) this.RunAsync, mCancellationToken);
}
catch (Exception ex)
{
Expand Down Expand Up @@ -87,36 +91,13 @@ data is very small.
MemoryStream memoryStream = new MemoryStream();

// Checks WebSocket state.
while (mWebSocket.IsConnected)
while (IsConnected && !mCancellationToken.IsCancellationRequested)
{
// Reads data.
WebSocketReceiveResult webSocketReceiveResult;
WebSocketReceiveResult webSocketReceiveResult =
await ReadMessage(receivedDataBuffer, memoryStream);

long length = 0;
do
{
webSocketReceiveResult =
await mWebSocket.ReceiveAsync(receivedDataBuffer, mCancellationTokenSource.Token)
.ConfigureAwait(false);

length += webSocketReceiveResult.Count;

await memoryStream.WriteAsync(receivedDataBuffer.Array, receivedDataBuffer.Offset,
webSocketReceiveResult.Count, mCancellationTokenSource.Token)
.ConfigureAwait(false);

} while (!webSocketReceiveResult.EndOfMessage);

// If input frame is cancelation frame, send close command.
if (webSocketReceiveResult.MessageType == WebSocketMessageType.Close)
{
this.RaiseConnectionClosed();

await mWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure,
String.Empty, mCancellationTokenSource.Token)
.ConfigureAwait(false);
}
else
if (webSocketReceiveResult.MessageType != WebSocketMessageType.Close)
{
memoryStream.Position = 0;
OnNewMessage(memoryStream);
Expand All @@ -128,9 +109,62 @@ await mWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure,
}
catch (Exception ex)
{
RaiseConnectionError(ex);
RaiseConnectionClosed();
// Cancellation token could be cancelled in Dispose if a
// Goodbye message has been received.
if (!(ex is OperationCanceledException) ||
!mCancellationToken.IsCancellationRequested)
{
RaiseConnectionError(ex);
}
}

if (mWebSocket.State != WebSocketState.CloseReceived &&
mWebSocket.State != WebSocketState.Closed)
{
await CloseWebSocket().ConfigureAwait(false);
}

RaiseConnectionClosed();
}

private async Task CloseWebSocket()
{
try
{
await mWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure,
String.Empty,
CancellationToken.None)
.ConfigureAwait(false);
}
catch (Exception ex)
{
mLogger.WarnException("Failed sending a close message to client", ex);
}
}

private async Task<WebSocketReceiveResult> ReadMessage(ArraySegment<byte> receivedDataBuffer, MemoryStream memoryStream)
{
WebSocketReceiveResult webSocketReceiveResult;

long length = 0;

do
{
webSocketReceiveResult =
await mWebSocket.ReceiveAsync(receivedDataBuffer, mCancellationToken)
.ConfigureAwait(false);

length += webSocketReceiveResult.Count;

await memoryStream.WriteAsync(receivedDataBuffer.Array,
receivedDataBuffer.Offset,
webSocketReceiveResult.Count,
mCancellationToken)
.ConfigureAwait(false);
}
while (!webSocketReceiveResult.EndOfMessage);

return webSocketReceiveResult;
}

private void OnNewMessage(MemoryStream payloadData)
Expand All @@ -142,14 +176,16 @@ private void OnNewMessage(MemoryStream payloadData)
protected override void Dispose()
{
mCancellationTokenSource.Cancel();
mCancellationTokenSource.Dispose();
mCancellationTokenSource = null;
}

protected override bool IsConnected
{
get
{
return mWebSocket.IsConnected;
return mWebSocket.State == WebSocketState.Open;
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace WampSharp.WebSockets
{
public interface IWebSocketWrapper
{
bool IsConnected { get; }
WebSocketState State { get; }
Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> arraySegment, CancellationToken callCancelled);
Task SendAsync(ArraySegment<byte> data, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancel);
Task CloseAsync(WebSocketCloseStatus closeStatus, string closeDescription, CancellationToken cancel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@ public WebSocketWrapper(WebSocket webSocket)
mWebSocket = webSocket;
}

public bool IsConnected
{
get
{
return mWebSocket.State == WebSocketState.Open;
}
}

public Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> arraySegment, CancellationToken callCancelled)
{
return mWebSocket.ReceiveAsync(arraySegment, callCancelled);
Expand All @@ -36,5 +28,10 @@ public Task CloseAsync(WebSocketCloseStatus closeStatus, string closeDescription
{
return mWebSocket.CloseAsync(closeStatus, closeDescription, cancel);
}

public WebSocketState State
{
get { return mWebSocket.State; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ public abstract class AsyncWampConnection<TMessage> : IWampConnection<TMessage>,
{
private readonly ActionBlock<WampMessage<object>> mSendBlock;
protected readonly ILog mLogger;
private int mDisposeCalled = 0;

protected AsyncWampConnection()
{
mLogger = LogProvider.GetLogger(this.GetType());
mLogger = new LoggerWithConnectionId(LogProvider.GetLogger(this.GetType()));
mSendBlock = new ActionBlock<WampMessage<object>>(x => InnerSend(x));
}

Expand Down Expand Up @@ -112,6 +113,7 @@ protected virtual void RaiseMessageArrived(WampMessage<TMessage> message)

protected virtual void RaiseConnectionClosed()
{
mLogger.Debug("Connection has been closed");
var handler = ConnectionClosed;
if (handler != null) handler(this, EventArgs.Empty);
}
Expand All @@ -125,9 +127,12 @@ protected virtual void RaiseConnectionError(Exception ex)

void IDisposable.Dispose()
{
mSendBlock.Complete();
mSendBlock.Completion.Wait();
this.Dispose();
if (Interlocked.CompareExchange(ref mDisposeCalled, 1, 0) == 0)
{
mSendBlock.Complete();
mSendBlock.Completion.Wait();
this.Dispose();
}
}

protected abstract void Dispose();
Expand All @@ -136,9 +141,12 @@ void IDisposable.Dispose()

async Task IAsyncDisposable.DisposeAsync()
{
mSendBlock.Complete();
await mSendBlock.Completion;
this.Dispose();
if (Interlocked.CompareExchange(ref mDisposeCalled, 1, 0) == 0)
{
mSendBlock.Complete();
await mSendBlock.Completion;
this.Dispose();
}
}

#else
Expand All @@ -151,5 +159,26 @@ Task IAsyncDisposable.DisposeAsync()

#endif

// TODO: move this to another file (after making it more generic)
// TODO: or get rid of this.
private class LoggerWithConnectionId : ILog
{
private readonly ILog mLogger;
private readonly string mConnectionId;

public LoggerWithConnectionId(ILog logger)
{
mConnectionId = Guid.NewGuid().ToString();
mLogger = logger;
}

public bool Log(LogLevel logLevel, Func<string> messageFunc, Exception exception = null, params object[] formatParameters)
{
using (LogProvider.OpenMappedContext("ConncetionId", mConnectionId))
{
return mLogger.Log(logLevel, messageFunc, exception, formatParameters);
}
}
}
}
}

0 comments on commit 2bee21b

Please sign in to comment.