diff --git a/Cassandra.ThriftClient.Tests/FunctionalTests/CustomNodeTests/AuthenticationTest.cs b/Cassandra.ThriftClient.Tests/FunctionalTests/CustomNodeTests/AuthenticationTest.cs index 1498e1d..f1780f0 100644 --- a/Cassandra.ThriftClient.Tests/FunctionalTests/CustomNodeTests/AuthenticationTest.cs +++ b/Cassandra.ThriftClient.Tests/FunctionalTests/CustomNodeTests/AuthenticationTest.cs @@ -1,8 +1,11 @@ using System; using System.IO; +using System.Linq.Expressions; using Cassandra; +using Moq; + using NUnit.Framework; using SkbKontur.Cassandra.Local; @@ -84,11 +87,34 @@ public void TestNonAuthenticatedConnection() Assert.AreEqual("Provided username non-existent and/or password are incorrect", authenticationException.Why); } - private void SomeActionThatRequiresAuthentication(string username, string password) + [Test] + public void TestThriftConnectionClosedAfterNonSuccessfulAuthentication() + { + var logger = new Mock(MockBehavior.Strict); + logger.Setup(l => l.ForContext(It.IsAny())) + .Returns(logger.Object); + + logger.Setup(l => l.IsEnabledFor(It.IsAny())) + .Returns((LogLevel level) => level == LogLevel.Error); + + Expression> logAuthFailSetup = l => + l.Log(It.Is( + e => e.Exception is AuthenticationException + && e.MessageTemplate == "Error occured while opening thrift connection. Will try to close open transports. Failed action: {ActionName}." + && e.Properties != null + && e.Properties.ContainsKey("ActionName") + && e.Properties["ActionName"] as string == "login")); + + logger.Setup(logAuthFailSetup).Verifiable(); + Assert.Throws(() => SomeActionThatRequiresAuthentication("cassandra", "wrong_password", logger.Object)); + logger.Verify(logAuthFailSetup, Times.Exactly(2)); + } + + private void SomeActionThatRequiresAuthentication(string username, string password, ILog logger = null) { var settings = node.CreateSettings(); settings.Credentials = new Credentials(username, password); - using (var cluster = new CassandraCluster(settings, new SilentLog())) + using (var cluster = new CassandraCluster(settings, logger ?? new SilentLog())) cluster.RetrieveClusterConnection().RetrieveKeyspaces(); } diff --git a/Cassandra.ThriftClient/Core/ThriftConnection.cs b/Cassandra.ThriftClient/Core/ThriftConnection.cs index a2c252d..f512c1c 100644 --- a/Cassandra.ThriftClient/Core/ThriftConnection.cs +++ b/Cassandra.ThriftClient/Core/ThriftConnection.cs @@ -110,28 +110,63 @@ private void OpenTransport() if (!cassandraClient.InputProtocol.Transport.Equals(cassandraClient.OutputProtocol.Transport)) cassandraClient.OutputProtocol.Transport.Open(); - if (credentials != null) - cassandraClient.login(new AuthenticationRequest(new Dictionary - { - ["username"] = credentials.Username, - ["password"] = credentials.Password, - })); - - if (!string.IsNullOrEmpty(keyspaceName)) - cassandraClient.set_keyspace(keyspaceName); + WithCloseTransportOnError(Login, "login"); + WithCloseTransportOnError(SetKeyspace, "set keyspace"); } } + private void WithCloseTransportOnError(Action action, string actionName) + { + try + { + action(); + } + catch (Exception e) + { + logger.Error(e, "Error occured while opening thrift connection. Will try to close open transports. Failed action: {ActionName}.", new {ActionName = actionName}); + try + { + DoCloseTransport(); + } + catch (Exception closeException) + { + logger.Error(closeException, "Error occured while closing connection's transports."); + } + throw; + } + } + + private void Login() + { + if (credentials != null) + cassandraClient.login(new AuthenticationRequest(new Dictionary + { + ["username"] = credentials.Username, + ["password"] = credentials.Password, + })); + } + + private void SetKeyspace() + { + if (!string.IsNullOrEmpty(keyspaceName)) + cassandraClient.set_keyspace(keyspaceName); + } + private void CloseTransport() { lock (locker) { - cassandraClient.InputProtocol.Transport.Close(); - if (!cassandraClient.InputProtocol.Transport.Equals(cassandraClient.OutputProtocol.Transport)) - cassandraClient.OutputProtocol.Transport.Close(); + DoCloseTransport(); } } + private void DoCloseTransport() + { + cassandraClient.InputProtocol.Transport.Close(); + if (!cassandraClient.InputProtocol.Transport.Equals(cassandraClient.OutputProtocol.Transport)) + cassandraClient.OutputProtocol.Transport.Close(); + } + private Timestamp lastSuccessPingTimestamp; private bool isAlive; diff --git a/global.json b/global.json index 9ee6c25..70db897 100644 --- a/global.json +++ b/global.json @@ -1,5 +1,5 @@ { "sdk": { - "version": "3.1.300" + "version": "3.1.402" } } \ No newline at end of file