diff --git a/src/Foundatio.TestHarness/Messaging/MessageBusTestBase.cs b/src/Foundatio.TestHarness/Messaging/MessageBusTestBase.cs index 7bcdff6c2..0fcc8e43d 100644 --- a/src/Foundatio.TestHarness/Messaging/MessageBusTestBase.cs +++ b/src/Foundatio.TestHarness/Messaging/MessageBusTestBase.cs @@ -304,7 +304,15 @@ public virtual async Task CanTolerateSubscriberFailureAsync() { return; try { - var countdown = new AsyncCountdownEvent(2); + var countdown = new AsyncCountdownEvent(4); + await messageBus.SubscribeAsync(msg => { + Assert.Equal("Hello", msg.Data); + countdown.Signal(); + }); + await messageBus.SubscribeAsync(msg => { + Assert.Equal("Hello", msg.Data); + countdown.Signal(); + }); await messageBus.SubscribeAsync(msg => { throw new Exception(); }); diff --git a/src/Foundatio/Messaging/InMemoryMessageBus.cs b/src/Foundatio/Messaging/InMemoryMessageBus.cs index cd77c0fc8..572ec58e3 100644 --- a/src/Foundatio/Messaging/InMemoryMessageBus.cs +++ b/src/Foundatio/Messaging/InMemoryMessageBus.cs @@ -31,20 +31,20 @@ public void ResetMessagesSent() { _messageCounts.Clear(); } - protected override Task PublishImplAsync(string messageType, object message, TimeSpan? delay, CancellationToken cancellationToken) { + protected override async Task PublishImplAsync(string messageType, object message, TimeSpan? delay, CancellationToken cancellationToken) { Interlocked.Increment(ref _messagesSent); _messageCounts.AddOrUpdate(messageType, t => 1, (t, c) => c + 1); Type mappedType = GetMappedMessageType(messageType); if (_subscribers.IsEmpty) - return Task.CompletedTask; + return; bool isTraceLogLevelEnabled = _logger.IsEnabled(LogLevel.Trace); if (delay.HasValue && delay.Value > TimeSpan.Zero) { if (isTraceLogLevelEnabled) _logger.LogTrace("Schedule delayed message: {MessageType} ({Delay}ms)", messageType, delay.Value.TotalMilliseconds); SendDelayedMessage(mappedType, message, delay.Value); - return Task.CompletedTask; + return; } var body = SerializeMessageBody(messageType, message); @@ -54,8 +54,12 @@ protected override Task PublishImplAsync(string messageType, object message, Tim Data = body }; - SendMessageToSubscribers(messageData); - return Task.CompletedTask; + try { + await SendMessageToSubscribers(messageData); + } catch (Exception ex) { + // swallow exceptions from subscriber handlers for the in memory bus + _logger.LogWarning(ex, "Error sending message to subscribers: {ErrorMessage}", ex.Message); + } } } } \ No newline at end of file diff --git a/src/Foundatio/Messaging/MessageBusBase.cs b/src/Foundatio/Messaging/MessageBusBase.cs index a181aa18f..9ec8a3ccf 100644 --- a/src/Foundatio/Messaging/MessageBusBase.cs +++ b/src/Foundatio/Messaging/MessageBusBase.cs @@ -149,7 +149,7 @@ protected virtual object DeserializeMessageBody(string messageType, byte[] data) return body; } - protected void SendMessageToSubscribers(IMessage message) { + protected async Task SendMessageToSubscribers(IMessage message) { bool isTraceLogLevelEnabled = _logger.IsEnabled(LogLevel.Trace); var subscribers = GetMessageSubscribers(message); @@ -167,19 +167,19 @@ protected void SendMessageToSubscribers(IMessage message) { return; } - foreach (var subscriber in subscribers) { + var subscriberHandlers = subscribers.Select(subscriber => { if (subscriber.CancellationToken.IsCancellationRequested) { - if (_subscribers.TryRemove(subscriber.Id, out var _)) { + if (_subscribers.TryRemove(subscriber.Id, out _)) { if (isTraceLogLevelEnabled) _logger.LogTrace("Removed cancelled subscriber: {SubscriberId}", subscriber.Id); } else if (isTraceLogLevelEnabled) { _logger.LogTrace("Unable to remove cancelled subscriber: {SubscriberId}", subscriber.Id); } - continue; + return Task.CompletedTask; } - Task.Factory.StartNew(async () => { + return Task.Run(async () => { if (subscriber.CancellationToken.IsCancellationRequested) { if (isTraceLogLevelEnabled) _logger.LogTrace("The cancelled subscriber action will not be called: {SubscriberId}", subscriber.Id); @@ -190,19 +190,23 @@ protected void SendMessageToSubscribers(IMessage message) { if (isTraceLogLevelEnabled) _logger.LogTrace("Calling subscriber action: {SubscriberId}", subscriber.Id); - try { - if (subscriber.Type == typeof(IMessage)) - await subscriber.Action(message, subscriber.CancellationToken).AnyContext(); - else - await subscriber.Action(body.Value, subscriber.CancellationToken).AnyContext(); - - if (isTraceLogLevelEnabled) - _logger.LogTrace("Finished calling subscriber action: {SubscriberId}", subscriber.Id); - } catch (Exception ex) { - if (_logger.IsEnabled(LogLevel.Warning)) - _logger.LogWarning(ex, "Error sending message to subscriber: {ErrorMessage}", ex.Message); - } + if (subscriber.Type == typeof(IMessage)) + await subscriber.Action(message, subscriber.CancellationToken).AnyContext(); + else + await subscriber.Action(body.Value, subscriber.CancellationToken).AnyContext(); + + if (isTraceLogLevelEnabled) + _logger.LogTrace("Finished calling subscriber action: {SubscriberId}", subscriber.Id); }); + }); + + try { + await Task.WhenAll(subscriberHandlers.ToArray()); + } catch (Exception ex) { + if (_logger.IsEnabled(LogLevel.Warning)) + _logger.LogWarning(ex, "Error sending message to subscribers: {ErrorMessage}", ex.Message); + + throw; } if (isTraceLogLevelEnabled)