diff --git a/Directory.Build.props b/Directory.Build.props index f3332cb93b..e47830316f 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ 3.46.1 3.47.0 preview.1 - 3.37.5 + 3.37.6 1.0.0 beta.0 2.0.4 diff --git a/Microsoft.Azure.Cosmos/src/DocumentClient.cs b/Microsoft.Azure.Cosmos/src/DocumentClient.cs index d2df25a539..4f0f800d43 100644 --- a/Microsoft.Azure.Cosmos/src/DocumentClient.cs +++ b/Microsoft.Azure.Cosmos/src/DocumentClient.cs @@ -6652,17 +6652,6 @@ private bool IsValidConsistency( private void InitializeDirectConnectivity(IStoreClientFactory storeClientFactory) { - this.AddressResolver = new GlobalAddressResolver( - this.GlobalEndpointManager, - this.PartitionKeyRangeLocation, - this.ConnectionPolicy.ConnectionProtocol, - this, - this.collectionCache, - this.partitionKeyRangeCache, - this.accountServiceConfiguration, - this.ConnectionPolicy, - this.httpClient); - // Check if we have a store client factory in input and if we do, do not initialize another store client // The purpose is to reuse store client factory across all document clients inside compute gateway if (storeClientFactory != null) @@ -6704,7 +6693,6 @@ private void InitializeDirectConnectivity(IStoreClientFactory storeClientFactory sendHangDetectionTimeSeconds: this.rntbdSendHangDetectionTimeSeconds, retryWithConfiguration: this.ConnectionPolicy.RetryOptions?.GetRetryWithConfiguration(), enableTcpConnectionEndpointRediscovery: this.ConnectionPolicy.EnableTcpConnectionEndpointRediscovery, - addressResolver: this.AddressResolver, rntbdMaxConcurrentOpeningConnectionCount: this.rntbdMaxConcurrentOpeningConnectionCount, remoteCertificateValidationCallback: this.remoteCertificateValidationCallback, distributedTracingOptions: distributedTracingOptions, @@ -6719,6 +6707,18 @@ private void InitializeDirectConnectivity(IStoreClientFactory storeClientFactory this.isStoreClientFactoryCreatedInternally = true; } + this.AddressResolver = new GlobalAddressResolver( + this.GlobalEndpointManager, + this.PartitionKeyRangeLocation, + this.ConnectionPolicy.ConnectionProtocol, + this, + this.collectionCache, + this.partitionKeyRangeCache, + this.accountServiceConfiguration, + this.ConnectionPolicy, + this.httpClient, + this.storeClientFactory.GetConnectionStateListener()); + this.CreateStoreModel(subscribeRntbdStatus: true); } diff --git a/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs b/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs index f2012fe0bc..4106509ed5 100644 --- a/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs +++ b/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs @@ -53,6 +53,7 @@ internal class GatewayAddressCache : IAddressCache, IDisposable private readonly SemaphoreSlim semaphore; private readonly CosmosHttpClient httpClient; private readonly bool isReplicaAddressValidationEnabled; + private readonly IConnectionStateListener connectionStateListener; private Tuple masterPartitionAddressCache; private DateTime suboptimalMasterPartitionTimestamp; @@ -67,6 +68,7 @@ public GatewayAddressCache( IServiceConfigurationReader serviceConfigReader, CosmosHttpClient httpClient, IOpenConnectionsHandler openConnectionsHandler, + IConnectionStateListener connectionStateListener, long suboptimalPartitionForceRefreshIntervalInSeconds = 600, bool enableTcpConnectionEndpointRediscovery = false, bool replicaAddressValidationEnabled = false) @@ -81,6 +83,7 @@ public GatewayAddressCache( this.serverPartitionAddressToPkRangeIdMap = new ConcurrentDictionary>(); this.suboptimalMasterPartitionTimestamp = DateTime.MaxValue; this.enableTcpConnectionEndpointRediscovery = enableTcpConnectionEndpointRediscovery; + this.connectionStateListener = connectionStateListener; this.suboptimalPartitionForceRefreshIntervalInSeconds = suboptimalPartitionForceRefreshIntervalInSeconds; @@ -496,6 +499,13 @@ private static void LogPartitionCacheRefresh( public async Task MarkAddressesToUnhealthyAsync( ServerKey serverKey) { + if (this.disposedValue) + { + // Will enable Listener to un-register in-case of un-graceful dispose + // + throw new ObjectDisposedException(nameof(GatewayAddressCache)); + } + if (serverKey == null) { throw new ArgumentNullException(nameof(serverKey)); @@ -538,6 +548,9 @@ where serverKey.Equals(transportAddress.ReplicaServerKey) address.SetUnhealthy(); } + + // Update the health status + this.CaptureTransportAddressUriHealthStates(addressInfo, transportAddresses); } } } @@ -828,9 +841,21 @@ internal Tuple ToPartiti partitionKeyRangeIdentity.PartitionKeyRangeId, addressInfo.PhysicalUri); + HashSet createdValue = null; + ServerKey serverKey = new ServerKey(new Uri(addressInfo.PhysicalUri)); HashSet pkRangeIdSet = this.serverPartitionAddressToPkRangeIdMap.GetOrAdd( - new ServerKey(new Uri(addressInfo.PhysicalUri)), - (_) => new HashSet()); + serverKey, + (_) => + { + createdValue = new HashSet(); + return createdValue; + }); + + if (object.ReferenceEquals(pkRangeIdSet, createdValue)) + { + this.connectionStateListener.Register(serverKey, this.MarkAddressesToUnhealthyAsync); + } + lock (pkRangeIdSet) { pkRangeIdSet.Add(partitionKeyRangeIdentity); @@ -893,7 +918,7 @@ private static void LogAddressResolutionEnd(DocumentServiceRequest request, stri private static Protocol ProtocolFromString(string protocol) { - return (protocol.ToLowerInvariant()) switch + return protocol.ToLowerInvariant() switch { RuntimeConstants.Protocols.HTTPS => Protocol.Https, RuntimeConstants.Protocols.RNTBD => Protocol.Tcp, @@ -903,7 +928,7 @@ private static Protocol ProtocolFromString(string protocol) private static string ProtocolString(Protocol protocol) { - return ((int)protocol) switch + return (int)protocol switch { (int)Protocol.Https => RuntimeConstants.Protocols.HTTPS, (int)Protocol.Tcp => RuntimeConstants.Protocols.RNTBD, @@ -1071,11 +1096,18 @@ protected virtual void Dispose(bool disposing) { if (this.disposedValue) { + DefaultTrace.TraceInformation("GatewayAddressCache is already disposed {0}", this.GetHashCode()); return; } if (disposing) { + // Unregister the server-key + foreach (ServerKey serverKey in this.serverPartitionAddressToPkRangeIdMap.Keys) + { + this.connectionStateListener.UnRegister(serverKey, this.MarkAddressesToUnhealthyAsync); + } + this.serverPartitionAddressCache?.Dispose(); } diff --git a/Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs b/Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs index 058e9c3ce4..3205d4e60d 100644 --- a/Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs +++ b/Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs @@ -39,6 +39,7 @@ internal sealed class GlobalAddressResolver : IAddressResolverExtension, IDispos private readonly ConcurrentDictionary addressCacheByEndpoint; private readonly bool enableTcpConnectionEndpointRediscovery; private readonly bool isReplicaAddressValidationEnabled; + private readonly IConnectionStateListener connectionStateListener; private IOpenConnectionsHandler openConnectionsHandler; public GlobalAddressResolver( @@ -50,7 +51,8 @@ public GlobalAddressResolver( PartitionKeyRangeCache routingMapProvider, IServiceConfigurationReader serviceConfigReader, ConnectionPolicy connectionPolicy, - CosmosHttpClient httpClient) + CosmosHttpClient httpClient, + IConnectionStateListener connectionStateListener) { this.endpointManager = endpointManager; this.partitionKeyRangeLocationCache = partitionKeyRangeLocationCache; @@ -60,6 +62,7 @@ public GlobalAddressResolver( this.routingMapProvider = routingMapProvider; this.serviceConfigReader = serviceConfigReader; this.httpClient = httpClient; + this.connectionStateListener = connectionStateListener; int maxBackupReadEndpoints = !connectionPolicy.EnableReadRequestsFallback.HasValue || connectionPolicy.EnableReadRequestsFallback.Value @@ -229,19 +232,6 @@ public async Task ResolveAsync( return await resolver.ResolveAsync(request, forceRefresh, cancellationToken); } - public async Task UpdateAsync( - ServerKey serverKey, - CancellationToken cancellationToken) - { - foreach (KeyValuePair addressCache in this.addressCacheByEndpoint) - { - // since we don't know which address cache contains the pkRanges mapped to this node, - // we mark all transport uris that has the same server key to unhealthy status in the - // AddressCaches of all regions. - await addressCache.Value.AddressCache.MarkAddressesToUnhealthyAsync(serverKey); - } - } - /// /// ReplicatedResourceClient will use this API to get the direct connectivity AddressCache for given request. /// @@ -283,6 +273,7 @@ private EndpointCache GetOrAddEndpoint(Uri endpoint) this.serviceConfigReader, this.httpClient, this.openConnectionsHandler, + this.connectionStateListener, enableTcpConnectionEndpointRediscovery: this.enableTcpConnectionEndpointRediscovery, replicaAddressValidationEnabled: this.isReplicaAddressValidationEnabled); diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ConnectionStateMuxListenerTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ConnectionStateMuxListenerTests.cs new file mode 100644 index 0000000000..0b4aaae2a6 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ConnectionStateMuxListenerTests.cs @@ -0,0 +1,276 @@ +namespace Microsoft.Azure.Cosmos.Tests +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Net.Http; + using System.Text; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Routing; + using Microsoft.Azure.Documents; + using Microsoft.Azure.Documents.Client; + using Microsoft.Azure.Documents.Rntbd; + using Microsoft.Azure.Documents.Routing; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Moq; + + [TestClass] + public class ConnectionStateMuxListenerTests + { + [Owner("kirankk")] + [TestMethod] + public void StoreClientFactoryV2Setup() + { + StoreClientFactory storeClientFactory = new StoreClientFactory(Protocol.Tcp, + requestTimeoutInSeconds: 10, + maxConcurrentConnectionOpenRequests: 10); + Assert.IsInstanceOfType(storeClientFactory.GetConnectionStateListener(), typeof(ConnectionStateMuxListener)); + + } + + [Owner("kirankk")] + [TestMethod] + public async Task DisableEndpointRediscoveryAsync() + { + bool enableTcpConnectionEndpointRediscovery = false; + ConnectionStateMuxListener connectionStateMuxListener = new ConnectionStateMuxListener(enableTcpConnectionEndpointRediscovery); + + ServerKey serverKey = new ServerKey(new Uri("http://localhost:8081/ep1")); + ServerKey calledBackServerKey = null; + connectionStateMuxListener.Register(serverKey, + async (sk) => + { + await Task.Yield(); + calledBackServerKey = sk; + }); + + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + + Assert.IsFalse(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + Assert.IsFalse(connectionStateMuxListener.serverKeyEventHandlers.Any()); + Assert.IsNull(calledBackServerKey); + } + + [Owner("kirankk")] + [TestMethod] + public async Task RegisterUnregisterNotifyAsync() + { + bool enableTcpConnectionEndpointRediscovery = true; + ConnectionStateMuxListener connectionStateMuxListener = new ConnectionStateMuxListener(enableTcpConnectionEndpointRediscovery); + + ServerKey serverKey = new ServerKey(new Uri("http://localhost:8081/ep1")); + ServerKey calledBackServerKey = null; + Func callback = async (sk) => + { + await Task.Yield(); + calledBackServerKey = sk; + }; + + // Register + connectionStateMuxListener.Register(serverKey, callback); + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + Assert.AreEqual(1, connectionStateMuxListener.serverKeyEventHandlers.Count); + Assert.AreEqual(1, connectionStateMuxListener.serverKeyEventHandlers.First().Value.Count); + + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + Assert.AreEqual(serverKey, calledBackServerKey); + + // Call with a different equivalent object + calledBackServerKey = null; + ServerKey serverKeyDuplicate = new ServerKey(new Uri("http://localhost:8081/ep1")); + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKeyDuplicate); + + Assert.AreEqual(1, connectionStateMuxListener.serverKeyEventHandlers.Count); + Assert.AreEqual(serverKey, serverKeyDuplicate); + + // UnRegister + calledBackServerKey = null; + connectionStateMuxListener.UnRegister(serverKey, callback); + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + Assert.AreEqual(0, connectionStateMuxListener.serverKeyEventHandlers.Count); // Unregiter will not remove entry fully + Assert.IsNull(calledBackServerKey); + } + + [Owner("kirankk")] + [TestMethod] + public async Task DuplicateRegisterAndUnRegisterSingleCallbackAsync() + { + bool enableTcpConnectionEndpointRediscovery = true; + ConnectionStateMuxListener connectionStateMuxListener = new ConnectionStateMuxListener(enableTcpConnectionEndpointRediscovery); + + ServerKey serverKey = new ServerKey(new Uri("http://localhost:8081/ep1")); + ServerKey calledBackServerKey = null; + int callCount = 0; + Func callback = async (sk) => + { + await Task.Yield(); + callCount++; + calledBackServerKey = sk; + }; + + // Register + connectionStateMuxListener.Register(serverKey, callback); + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + Assert.AreEqual(1, connectionStateMuxListener.serverKeyEventHandlers.Count); + Assert.AreEqual(serverKey, calledBackServerKey); + Assert.AreEqual(1, callCount); + + // Re-register the name key and callback => also should get called only once + connectionStateMuxListener.Register(serverKey, callback); + calledBackServerKey = null; + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + + Assert.AreEqual(serverKey, calledBackServerKey); + Assert.AreEqual(2, callCount); + + // UnRegister + calledBackServerKey = null; + connectionStateMuxListener.UnRegister(serverKey, callback); + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + Assert.AreEqual(0, connectionStateMuxListener.serverKeyEventHandlers.Count); + Assert.IsNull(calledBackServerKey); + + // Double un-register + connectionStateMuxListener.UnRegister(serverKey, callback); + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + Assert.AreEqual(0, connectionStateMuxListener.serverKeyEventHandlers.Count); + } + + [Owner("kirankk")] + [TestMethod] + public async Task DuplicateRegisterAndUnRegisterDifferentCallbacksAsync() + { + bool enableTcpConnectionEndpointRediscovery = true; + ConnectionStateMuxListener connectionStateMuxListener = new ConnectionStateMuxListener(enableTcpConnectionEndpointRediscovery); + + ServerKey serverKey = new ServerKey(new Uri("http://localhost:8081/ep1")); + int callCount = 0; + Func callback1 = async (sk) => + { + await Task.Yield(); + callCount++; + }; + Func callback2 = async (sk) => + { + await Task.Yield(); + callCount++; + }; + + // Register + connectionStateMuxListener.Register(serverKey, callback1); + connectionStateMuxListener.Register(serverKey, callback2); + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + Assert.AreEqual(1, connectionStateMuxListener.serverKeyEventHandlers.Count); + Assert.AreEqual(2, connectionStateMuxListener.serverKeyEventHandlers.First().Value.Count); + Assert.AreEqual(2, callCount); + + // UnRegister + connectionStateMuxListener.UnRegister(serverKey, callback1); + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + Assert.AreEqual(1, connectionStateMuxListener.serverKeyEventHandlers.Count); + Assert.AreEqual(3, callCount); + + // Double un-register + connectionStateMuxListener.UnRegister(serverKey, callback2); + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + Assert.AreEqual(0, connectionStateMuxListener.serverKeyEventHandlers.Count); + } + + [Owner("kirankk")] + [TestMethod] + public void ArgumentCheck() + { + ConnectionStateMuxListener connectionStateMuxListener = new ConnectionStateMuxListener(true); + + Assert.ThrowsException(() => connectionStateMuxListener.Register(null, null)); + Assert.ThrowsException(() => connectionStateMuxListener.Register(new ServerKey(new Uri("http://localost:8081")), null)); + + Assert.ThrowsException(() => connectionStateMuxListener.UnRegister(null, null)); + Assert.ThrowsException(() => connectionStateMuxListener.UnRegister(new ServerKey(new Uri("http://localost:8081")), null)); + } + + [Owner("kirankk")] + [TestMethod] + public async Task DynamicConcurrencyMuxListenerTestAsync() + { + ConnectionStateMuxListener connectionStateMuxListener = new ConnectionStateMuxListener(true); + Assert.AreEqual(Environment.ProcessorCount, connectionStateMuxListener.notificationConcurrency); + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + + int callbackCount = 0; + ServerKey serverKey = new ServerKey(new Uri("http://localhost:8081/")); + connectionStateMuxListener.Register(serverKey, + async (sk) => + { + callbackCount++; + await Task.Yield(); + }); + + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + Assert.AreEqual(1, callbackCount); + + connectionStateMuxListener.SetConnectionEventConcurrency(0); + Assert.AreEqual(0, connectionStateMuxListener.notificationConcurrency); + Assert.IsTrue(connectionStateMuxListener.enableTcpConnectionEndpointRediscovery); + + callbackCount = 0; + await connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + Assert.AreEqual(0, callbackCount); + } + + [Owner("kirankk")] + [TestMethod] + public async Task NotificationsConcurrencyListenerTestAsync() + { + int concurrentToTest = 2; + int totalRequests = concurrentToTest * 2; + IConnectionStateListener connectionStateMuxListener = new ConnectionStateMuxListener(true); + connectionStateMuxListener.SetConnectionEventConcurrency(concurrentToTest); + + SemaphoreSlim mainTaskSemsphore = new SemaphoreSlim(concurrentToTest); + ManualResetEvent manualResetEvent = new ManualResetEvent(false); + + int callbackCount = 0; + ServerKey serverKey = new ServerKey(new Uri("http://localhost:8081/")); + connectionStateMuxListener.Register(serverKey, + async (sk) => { + await mainTaskSemsphore.WaitAsync(); + + lock (sk) + { + callbackCount++; + if (callbackCount >= concurrentToTest) + { + manualResetEvent.Set(); + } + } + }); + + Task[] allTasks = new Task[totalRequests]; + for (int i = 0; i < totalRequests; i++) + { + allTasks[i] = connectionStateMuxListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, serverKey); + } + + manualResetEvent.WaitOne(); + Assert.AreEqual(concurrentToTest, callbackCount); + + mainTaskSemsphore.Release(concurrentToTest); + await Task.WhenAll(allTasks); + Assert.AreEqual(totalRequests, callbackCount); + + mainTaskSemsphore.Release(concurrentToTest); + } + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/GatewayAddressCacheTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/GatewayAddressCacheTests.cs index 58bc94e72d..e3bca41d0a 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/GatewayAddressCacheTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/GatewayAddressCacheTests.cs @@ -85,6 +85,7 @@ public void TestGatewayAddressCacheAutoRefreshOnSuboptimalPartition() this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: null, + Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2); int initialAddressesCount = cache.TryGetAddressesAsync( @@ -108,8 +109,9 @@ public void TestGatewayAddressCacheAutoRefreshOnSuboptimalPartition() Assert.IsTrue(finalAddressCount == this.targetReplicaSetSize); } + [TestMethod] - public async Task TestGatewayAddressCacheUpdateOnConnectionResetAsync() + public async Task TestGatewayAddressCacheAndConnectionStateListenerRegAndUnReg() { FakeMessageHandler messageHandler = new FakeMessageHandler(); HttpClient httpClient = new HttpClient(messageHandler) @@ -117,6 +119,8 @@ public async Task TestGatewayAddressCacheUpdateOnConnectionResetAsync() Timeout = TimeSpan.FromSeconds(120) }; + Mock connectionListenerMock = new Mock(); + GatewayAddressCache cache = new GatewayAddressCache( new Uri(GatewayAddressCacheTests.DatabaseAccountApiEndpoint), Documents.Client.Protocol.Tcp, @@ -124,6 +128,7 @@ public async Task TestGatewayAddressCacheUpdateOnConnectionResetAsync() this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: null, + connectionListenerMock.Object, suboptimalPartitionForceRefreshIntervalInSeconds: 2, enableTcpConnectionEndpointRediscovery: true); @@ -134,6 +139,39 @@ public async Task TestGatewayAddressCacheUpdateOnConnectionResetAsync() false, CancellationToken.None); + Mock.Get(connectionListenerMock.Object).Verify(x => x.Register(It.IsAny(), It.IsAny>()), Times.Exactly(3)); + cache.Dispose(); + Mock.Get(connectionListenerMock.Object).Verify(x => x.UnRegister(It.IsAny(), It.IsAny>()), Times.Exactly(3)); + + } + + [TestMethod] + public async Task TestGatewayAddressCacheUpdateOnConnectionResetAsync() + { + FakeMessageHandler messageHandler = new FakeMessageHandler(); + HttpClient httpClient = new HttpClient(messageHandler) + { + Timeout = TimeSpan.FromSeconds(120) + }; + + GatewayAddressCache cache = new GatewayAddressCache( + new Uri(GatewayAddressCacheTests.DatabaseAccountApiEndpoint), + Documents.Client.Protocol.Tcp, + this.mockTokenProvider.Object, + this.mockServiceConfigReader.Object, + MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), + openConnectionsHandler: null, + Mock.Of(), + suboptimalPartitionForceRefreshIntervalInSeconds: 2, + enableTcpConnectionEndpointRediscovery: true); + + PartitionAddressInformation addresses = await cache.TryGetAddressesAsync( + DocumentServiceRequest.Create(OperationType.Invalid, ResourceType.Address, AuthorizationTokenType.Invalid), + this.testPartitionKeyRangeIdentity, + this.serviceIdentity, + false, + CancellationToken.None); + Assert.IsNotNull(addresses.AllAddresses.Select(address => address.PhysicalUri == "https://blabla.com")); // Mark transport addresses to Unhealthy depcting a connection reset event. @@ -191,6 +229,7 @@ public async Task TestGatewayAddressCacheAvoidCacheRefresWhenAlreadyUpdatedAsync this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: null, + Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2, enableTcpConnectionEndpointRediscovery: true); @@ -269,19 +308,20 @@ public void GlobalAddressResolverUpdateAsyncSynchronizationTest() RequestTimeout = TimeSpan.FromSeconds(10) }; + IConnectionStateListener connectionStateListener = new ConnectionStateMuxListener(true); GlobalAddressResolver globalAddressResolver = new GlobalAddressResolver( endpointManager: globalEndpointManager, - partitionKeyRangeLocationCache: partitionKeyRangeLocationCache, - protocol: Documents.Client.Protocol.Tcp, - tokenProvider: this.mockTokenProvider.Object, - collectionCache: null, - routingMapProvider: null, - serviceConfigReader: this.mockServiceConfigReader.Object, - connectionPolicy: connectionPolicy, - httpClient: MockCosmosUtil.CreateCosmosHttpClient(() => new HttpClient(messageHandler))); - - ConnectionStateListener connectionStateListener = new ConnectionStateListener(globalAddressResolver); - connectionStateListener.OnConnectionEvent(ConnectionEvent.ReadEof, DateTime.Now, new Documents.Rntbd.ServerKey(new Uri("https://endpoint.azure.com:4040/"))); + null, + Protocol.Tcp, + this.mockTokenProvider.Object, + null, + null, + this.mockServiceConfigReader.Object, + connectionPolicy, + httpClient: MockCosmosUtil.CreateCosmosHttpClient(() => new HttpClient(messageHandler)), + connectionStateListener: connectionStateListener); + + connectionStateListener.OnConnectionEventAsync(ConnectionEvent.ReadEof, DateTime.Now, new ServerKey(new Uri("https://endpoint.azure.com:4040/"))).Wait(); }, state: null); } @@ -307,6 +347,7 @@ public async Task GatewayAddressCacheInNetworkRequestTestAsync() this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: null, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2, enableTcpConnectionEndpointRediscovery: true); @@ -367,6 +408,7 @@ public async Task OpenConnectionsAsync_WithValidOpenConnectionHandler_ShouldInvo this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: fakeOpenConnectionHandler, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2); // Act. @@ -413,6 +455,7 @@ public async Task OpenConnectionsAsync_WhenConnectionHandlerThrowsException_Shou this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: fakeOpenConnectionHandler, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2); // Act. @@ -462,6 +505,7 @@ public async Task OpenConnectionsAsync_WithNullOpenConnectionHandler_ShouldNotIn this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: null, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2); // Act. @@ -529,6 +573,7 @@ public async Task OpenConnectionsAsync_WithValidOpenConnectionHandlerAndCancella this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: fakeOpenConnectionHandler, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2); // Act. @@ -611,6 +656,7 @@ public async Task GlobalAddressResolver_OpenConnectionsToAllReplicasAsync_WithVa routingMapProvider: this.partitionKeyRangeCache.Object, serviceConfigReader: this.mockServiceConfigReader.Object, connectionPolicy: connectionPolicy, + connectionStateListener: Mock.Of(), httpClient: MockCosmosUtil.CreateCosmosHttpClient(() => new HttpClient(messageHandler))); globalAddressResolver.SetOpenConnectionsHandler( @@ -689,6 +735,7 @@ public async Task GlobalAddressResolver_OpenConnectionsToAllReplicasAsync_WhenHa routingMapProvider: this.partitionKeyRangeCache.Object, serviceConfigReader: this.mockServiceConfigReader.Object, connectionPolicy: connectionPolicy, + connectionStateListener: Mock.Of(), httpClient: MockCosmosUtil.CreateCosmosHttpClient(() => new HttpClient(messageHandler))); globalAddressResolver.SetOpenConnectionsHandler( @@ -778,6 +825,7 @@ public async Task GlobalAddressResolver_OpenConnectionsToAllReplicasAsync_WhenIn routingMapProvider: partitionKeyRangeCache.Object, serviceConfigReader: this.mockServiceConfigReader.Object, connectionPolicy: connectionPolicy, + connectionStateListener: Mock.Of(), httpClient: MockCosmosUtil.CreateCosmosHttpClient(() => new HttpClient(messageHandler))); globalAddressResolver.SetOpenConnectionsHandler( @@ -854,6 +902,7 @@ public async Task GlobalAddressResolver_OpenConnectionsToAllReplicasAsync_WhenNu routingMapProvider: this.partitionKeyRangeCache.Object, serviceConfigReader: this.mockServiceConfigReader.Object, connectionPolicy: connectionPolicy, + connectionStateListener: Mock.Of(), httpClient: MockCosmosUtil.CreateCosmosHttpClient(() => new HttpClient(messageHandler))); globalAddressResolver.SetOpenConnectionsHandler( @@ -939,6 +988,7 @@ public async Task OpenConnectionsAsync_WhenSomeAddressResolvingFailsWithExceptio this.mockServiceConfigReader.Object, mockHttpClient.Object, openConnectionsHandler: fakeOpenConnectionHandler, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2); // Act. @@ -1003,6 +1053,7 @@ public async Task TryGetAddressesAsync_WhenReplicaVlidationEnabled_ShouldValidat this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: fakeOpenConnectionHandler, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2, enableTcpConnectionEndpointRediscovery: true, replicaAddressValidationEnabled: true); @@ -1148,6 +1199,7 @@ public async Task TryGetAddressesAsync_WhenReplicaVlidationEnabledAndUnhealthyUr this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: fakeOpenConnectionHandler, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2, enableTcpConnectionEndpointRediscovery: true, replicaAddressValidationEnabled: true); @@ -1365,6 +1417,7 @@ public async Task TryGetAddressesAsync_WhenReplicaVlidationEnabledAndCSListenerM this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: fakeOpenConnectionHandler, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2, enableTcpConnectionEndpointRediscovery: true, replicaAddressValidationEnabled: true); @@ -1481,6 +1534,7 @@ public async Task TryGetAddressesAsync_WhenReplicaVlidationDisabled_ShouldNotVal this.mockServiceConfigReader.Object, MockCosmosUtil.CreateCosmosHttpClient(() => httpClient), openConnectionsHandler: fakeOpenConnectionHandler, + connectionStateListener: Mock.Of(), suboptimalPartitionForceRefreshIntervalInSeconds: 2, enableTcpConnectionEndpointRediscovery: true);