diff --git a/evergreen/evergreen.yml b/evergreen/evergreen.yml index 94c6cddffa8..0199db345f5 100644 --- a/evergreen/evergreen.yml +++ b/evergreen/evergreen.yml @@ -869,6 +869,15 @@ functions: cd ${DRIVERS_TOOLS}/.evergreen/csfle . ./activate-kmstlsvenv.sh python -u kms_http_server.py -v --ca_file ../x509gen/ca.pem --cert_file ../x509gen/server.pem --port 8002 --require_client_cert + - command: shell.exec + params: + background: true + shell: "bash" + script: | + #server.pem client cert + cd ${DRIVERS_TOOLS}/.evergreen/csfle + . ./activate-kmstlsvenv.sh + python -u kms_failpoint_server.py --port 9003 start-kms-mock-kmip-server: - command: shell.exec diff --git a/src/MongoDB.Driver.Encryption/CryptClientFactory.cs b/src/MongoDB.Driver.Encryption/CryptClientFactory.cs index e9dc48cd839..0a2a5352b5a 100644 --- a/src/MongoDB.Driver.Encryption/CryptClientFactory.cs +++ b/src/MongoDB.Driver.Encryption/CryptClientFactory.cs @@ -163,6 +163,8 @@ public static CryptClient Create(CryptOptions options) Library.mongocrypt_setopt_use_need_kms_credentials_state(handle); + Library.mongocrypt_setopt_retry_kms(handle, true); + Library.mongocrypt_init(handle); if (options.IsCryptSharedLibRequired) diff --git a/src/MongoDB.Driver.Encryption/CryptContext.cs b/src/MongoDB.Driver.Encryption/CryptContext.cs index 23ab25cae99..297d669d733 100644 --- a/src/MongoDB.Driver.Encryption/CryptContext.cs +++ b/src/MongoDB.Driver.Encryption/CryptContext.cs @@ -156,18 +156,12 @@ public Binary FinalizeForEncryption() } /// - /// Gets a collection of KMS message requests to make + /// Gets the next KMS message request /// - /// Collection of KMS Messages - public KmsRequestCollection GetKmsMessageRequests() + public KmsRequest GetNextKmsMessageRequest() { - var requests = new List(); - for (IntPtr request = Library.mongocrypt_ctx_next_kms_ctx(_handle); request != IntPtr.Zero; request = Library.mongocrypt_ctx_next_kms_ctx(_handle)) - { - requests.Add(new KmsRequest(request)); - } - - return new KmsRequestCollection(requests, this); + var request = Library.mongocrypt_ctx_next_kms_ctx(_handle); + return request == IntPtr.Zero ? null : new KmsRequest(request); } /// diff --git a/src/MongoDB.Driver.Encryption/KmsRequest.cs b/src/MongoDB.Driver.Encryption/KmsRequest.cs index 2e8b6ecdc8a..a93a61c22fc 100644 --- a/src/MongoDB.Driver.Encryption/KmsRequest.cs +++ b/src/MongoDB.Driver.Encryption/KmsRequest.cs @@ -77,6 +77,11 @@ public string KmsProvider } } + /// + /// The number of milliseconds to wait before sending this request. + /// + public int Sleep => (int)(Library.mongocrypt_kms_ctx_usleep(_id) / 1000); + /// /// Gets the message to send to KMS. /// @@ -88,6 +93,12 @@ public Binary GetMessage() return binary; } + /// + /// Indicates a network-level failure. + /// + /// A boolean indicating whether the failed request may be retried. + public bool Fail() => Library.mongocrypt_kms_ctx_fail(_id); + /// /// Feeds the response back to the libmongocrypt /// diff --git a/src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs b/src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs index a572b3f329d..9d07cd28b2a 100644 --- a/src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs +++ b/src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs @@ -211,22 +211,20 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider) private void ProcessNeedKmsState(CryptContext context, CancellationToken cancellationToken) { - var requests = context.GetKmsMessageRequests(); - foreach (var request in requests) + while (context.GetNextKmsMessageRequest() is { } request) { SendKmsRequest(request, cancellationToken); } - requests.MarkDone(); + context.MarkKmsDone(); } private async Task ProcessNeedKmsStateAsync(CryptContext context, CancellationToken cancellationToken) { - var requests = context.GetKmsMessageRequests(); - foreach (var request in requests) + while (context.GetNextKmsMessageRequest() is { } request) { await SendKmsRequestAsync(request, cancellationToken).ConfigureAwait(false); } - requests.MarkDone(); + context.MarkKmsDone(); } private void ProcessNeedMongoKeysState(CryptContext context, CancellationToken cancellationToken) @@ -278,13 +276,21 @@ private static byte[] ProcessReadyState(CryptContext context) private void SendKmsRequest(KmsRequest request, CancellationToken cancellation) { - var endpoint = CreateKmsEndPoint(request.Endpoint); - - var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider); - var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory); - using (var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation)) - using (var binary = request.GetMessage()) + try { + var endpoint = CreateKmsEndPoint(request.Endpoint); + + var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider); + var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory); + using var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation); + + var sleepMs = request.Sleep; + if (sleepMs > 0) + { + Thread.Sleep(sleepMs); + } + + using var binary = request.GetMessage(); var requestBytes = binary.ToArray(); sslStream.Write(requestBytes, 0, requestBytes.Length); @@ -292,22 +298,43 @@ private void SendKmsRequest(KmsRequest request, CancellationToken cancellation) { var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive. var count = sslStream.Read(buffer, 0, buffer.Length); + + if (count == 0) + { + throw new IOException("Unexpected end of stream. No data was read from the SSL stream."); + } + var responseBytes = new byte[count]; Buffer.BlockCopy(buffer, 0, responseBytes, 0, count); request.Feed(responseBytes); } } + catch (Exception ex) when (ex is IOException or SocketException) + { + if (!request.Fail()) + { + throw; + } + } } private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken cancellation) { - var endpoint = CreateKmsEndPoint(request.Endpoint); - - var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider); - var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory); - using (var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false)) - using (var binary = request.GetMessage()) + try { + var endpoint = CreateKmsEndPoint(request.Endpoint); + + var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider); + var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory); + using var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false); + + var sleepMs = request.Sleep; + if (sleepMs > 0) + { + await Task.Delay(sleepMs, cancellation).ConfigureAwait(false); + } + + using var binary = request.GetMessage(); var requestBytes = binary.ToArray(); await sslStream.WriteAsync(requestBytes, 0, requestBytes.Length).ConfigureAwait(false); @@ -315,11 +342,24 @@ private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken can { var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive. var count = await sslStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false); + + if (count == 0) + { + throw new IOException("Unexpected end of stream. No data was read from the SSL stream."); + } + var responseBytes = new byte[count]; Buffer.BlockCopy(buffer, 0, responseBytes, 0, count); request.Feed(responseBytes); } } + catch (Exception ex) when (ex is IOException or SocketException) + { + if (!request.Fail()) + { + throw; + } + } } // nested type diff --git a/src/MongoDB.Driver.Encryption/Library.cs b/src/MongoDB.Driver.Encryption/Library.cs index 6baf1abdad2..38538006fc9 100644 --- a/src/MongoDB.Driver.Encryption/Library.cs +++ b/src/MongoDB.Driver.Encryption/Library.cs @@ -147,6 +147,9 @@ static Library() _mongocrypt_ctx_setopt_query_type = new Lazy( () => __loader.Value.GetFunction( ("mongocrypt_ctx_setopt_query_type")), true); + _mongocrypt_setopt_retry_kms = new Lazy( + () => __loader.Value.GetFunction( + ("mongocrypt_setopt_retry_kms")), true); _mongocrypt_ctx_status = new Lazy( () => __loader.Value.GetFunction(("mongocrypt_ctx_status")), true); @@ -210,6 +213,11 @@ static Library() () => __loader.Value.GetFunction(("mongocrypt_ctx_destroy")), true); _mongocrypt_kms_ctx_get_kms_provider = new Lazy( () => __loader.Value.GetFunction(("mongocrypt_kms_ctx_get_kms_provider")), true); + + _mongocrypt_kms_ctx_usleep = new Lazy( + () => __loader.Value.GetFunction(("mongocrypt_kms_ctx_usleep")), true); + _mongocrypt_kms_ctx_fail = new Lazy( + () => __loader.Value.GetFunction(("mongocrypt_kms_ctx_fail")), true); } /// @@ -287,6 +295,7 @@ public static string Version internal static Delegates.mongocrypt_ctx_setopt_algorithm_range mongocrypt_ctx_setopt_algorithm_range => _mongocrypt_ctx_setopt_algorithm_range.Value; internal static Delegates.mongocrypt_ctx_setopt_contention_factor mongocrypt_ctx_setopt_contention_factor => _mongocrypt_ctx_setopt_contention_factor.Value; internal static Delegates.mongocrypt_ctx_setopt_query_type mongocrypt_ctx_setopt_query_type => _mongocrypt_ctx_setopt_query_type.Value; + internal static Delegates.mongocrypt_setopt_retry_kms mongocrypt_setopt_retry_kms => _mongocrypt_setopt_retry_kms.Value; internal static Delegates.mongocrypt_ctx_state mongocrypt_ctx_state => _mongocrypt_ctx_state.Value; internal static Delegates.mongocrypt_ctx_mongo_op mongocrypt_ctx_mongo_op => _mongocrypt_ctx_mongo_op.Value; @@ -305,6 +314,9 @@ public static string Version internal static Delegates.mongocrypt_ctx_destroy mongocrypt_ctx_destroy => _mongocrypt_ctx_destroy.Value; internal static Delegates.mongocrypt_kms_ctx_get_kms_provider mongocrypt_kms_ctx_get_kms_provider => _mongocrypt_kms_ctx_get_kms_provider.Value; + internal static Delegates.mongocrypt_kms_ctx_usleep mongocrypt_kms_ctx_usleep => _mongocrypt_kms_ctx_usleep.Value; + internal static Delegates.mongocrypt_kms_ctx_fail mongocrypt_kms_ctx_fail => _mongocrypt_kms_ctx_fail.Value; + private static readonly Lazy __loader = new Lazy( () => new LibraryLoader(), true); private static readonly Lazy _mongocrypt_version; @@ -392,6 +404,10 @@ public static string Version private static readonly Lazy _mongocrypt_ctx_destroy; private static readonly Lazy _mongocrypt_kms_ctx_get_kms_provider; + private static readonly Lazy _mongocrypt_kms_ctx_usleep; + private static readonly Lazy _mongocrypt_kms_ctx_fail; + private static readonly Lazy _mongocrypt_setopt_retry_kms; + // nested types internal enum StatusType { @@ -640,6 +656,9 @@ public delegate bool [return: MarshalAs(UnmanagedType.I1)] public delegate bool mongocrypt_ctx_setopt_query_type(ContextSafeHandle ctx, [MarshalAs(UnmanagedType.LPStr)] string query_type, int length); + [return: MarshalAs(UnmanagedType.I1)] + public delegate bool mongocrypt_setopt_retry_kms(MongoCryptSafeHandle handle, bool enable); + public delegate CryptContext.StateCode mongocrypt_ctx_state(ContextSafeHandle handle); [return: MarshalAs(UnmanagedType.I1)] @@ -681,6 +700,11 @@ public delegate bool public delegate void mongocrypt_ctx_destroy(IntPtr ptr); public delegate IntPtr mongocrypt_kms_ctx_get_kms_provider(IntPtr handle, out uint length); + + public delegate long mongocrypt_kms_ctx_usleep(IntPtr handle); + + [return: MarshalAs(UnmanagedType.I1)] + public delegate bool mongocrypt_kms_ctx_fail(IntPtr handle); } } } diff --git a/src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj b/src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj index 77c2546b839..f48395f28d5 100644 --- a/src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj +++ b/src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj @@ -14,11 +14,11 @@ - https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz - https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz - https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz - https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz - https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz + https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz + https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz + https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz + https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz + https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz diff --git a/tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs b/tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs index f24b5d3d780..ed1567a33f9 100644 --- a/tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs +++ b/tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs @@ -433,7 +433,7 @@ public void TestGetKmsProviderName(string kmsName) using (var cryptClient = CryptClientFactory.Create(cryptOptions)) using (var context = cryptClient.StartCreateDataKeyContext(keyId)) { - var request = context.GetKmsMessageRequests().Single(); + var request = context.GetNextKmsMessageRequest(); request.KmsProvider.Should().Be(kmsName); } } @@ -634,10 +634,9 @@ private static (CryptContext.StateCode stateProcessed, Binary binaryProduced, Bs case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS: { - var requests = context.GetKmsMessageRequests(); - foreach (var req in requests) + while (context.GetNextKmsMessageRequest() is { } request) { - using var binary = req.GetMessage(); + using var binary = request.GetMessage(); _output.WriteLine("Key Document: " + binary); var postRequest = binary.ToString(); // TODO: add different hosts handling @@ -645,11 +644,11 @@ private static (CryptContext.StateCode stateProcessed, Binary binaryProduced, Bs var reply = ReadHttpTestFile(isKmsDecrypt ? "kms-decrypt-reply.txt" : "kms-encrypt-reply.txt"); _output.WriteLine("Reply: " + reply); - req.Feed(Encoding.UTF8.GetBytes(reply)); - req.BytesNeeded.Should().Be(0); + request.Feed(Encoding.UTF8.GetBytes(reply)); + request.BytesNeeded.Should().Be(0); } - requests.MarkDone(); + context.MarkKmsDone(); return (CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS, null, null); } diff --git a/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs index 51ba57dd7b7..7037164a1c1 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs @@ -24,6 +24,7 @@ using System.Net.Sockets; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; +using System.Text; using System.Threading; using System.Threading.Tasks; using Amazon.Runtime; @@ -1428,6 +1429,163 @@ public void ExternalKeyVaultTest( } } + [Theory] + [ParameterAttributeData] + public async Task KmsRetryTest( + [Values("aws", "azure", "gcp")] string kmsProvider, + [Values("network", "http")] string failureType, + [Values(false, true)] bool async) + { + RequireServer.Check().Supports(Feature.ClientSideEncryption); + RequireEnvironment.Check().EnvironmentVariable("KMS_MOCK_SERVERS_ENABLED", isDefined: true); + + const string endpoint = "127.0.0.1:9003"; + + var masterKey = kmsProvider switch + { + "aws" => new BsonDocument + { + { "region", "foo" }, + { "key", "bar" }, + { "endpoint", $"{endpoint}" } + }, + "azure" => new BsonDocument + { + { "keyVaultEndpoint", $"{endpoint}" }, + { "keyName", "foo" }, + }, + "gcp" => new BsonDocument + { + { "projectId", "foo" }, + { "location", "bar" }, + { "keyRing", "baz" }, + { "keyName", "qux" }, + { "endpoint", $"{endpoint}" } + }, + _ => throw new ArgumentException(nameof(kmsProvider)) + }; + + await ResetServer(); + + using var clientEncrypted = ConfigureClientEncrypted(); + using var clientEncryption = ConfigureClientEncryption( + clientEncrypted, + kmsProviderFilter: kmsProvider, + kmsProviderConfigurator: KmsProviderEndpointConfigurator + ); + + var dataKeyOptions = CreateDataKeyOptions(kmsProvider, customMasterKey: masterKey); + + await SetFailure(failureType, 1); + + Guid dataKey = default; + Exception ex; + if (async) + { + ex = await Record.ExceptionAsync(async () => dataKey = await clientEncryption + .CreateDataKeyAsync(kmsProvider, dataKeyOptions, CancellationToken.None)); + } + else + { + ex = Record.Exception(() => dataKey = clientEncryption + .CreateDataKey(kmsProvider, dataKeyOptions, CancellationToken.None)); + } + ex.Should().BeNull(); + + await SetFailure(failureType, 1); + + Exception ex2; + if (async) + { + ex2 = await Record.ExceptionAsync(async () => await clientEncryption.EncryptAsync(new BsonInt32(123), + new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic", keyId: dataKey))); + } + else + { + ex2 = Record.Exception(() => clientEncryption.Encrypt(new BsonInt32(123), + new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic", keyId: dataKey))); + } + ex2.Should().BeNull(); + + if (failureType == "network") + { + await SetFailure("network", 4); + + Exception ex3; + if (async) + { + ex3 = await Record.ExceptionAsync(async () => dataKey = await clientEncryption + .CreateDataKeyAsync(kmsProvider, dataKeyOptions, CancellationToken.None)); + } + else + { + ex3 = Record.Exception(() => dataKey = clientEncryption + .CreateDataKey(kmsProvider, dataKeyOptions, CancellationToken.None)); + } + ex3.Should().NotBeNull(); + } + + return; + + void KmsProviderEndpointConfigurator(string kmsProviderName, Dictionary kmsOptions) + { + switch (kmsProviderName) + { + case "aws": + break; + case "azure": + kmsOptions.Add("identityPlatformEndpoint", endpoint); + break; + case "gcp": + kmsOptions.Add("endpoint", endpoint); + break; + default: + throw new Exception($"Unexpected kmsProvider {endpoint}."); + } + } + + HttpClient GetClient() + { + var handler = new HttpClientHandler + { + ClientCertificates = + { + new X509Certificate2(Environment.GetEnvironmentVariable("MONGO_X509_CLIENT_CERTIFICATE_PATH")!, + Environment.GetEnvironmentVariable("MONGO_X509_CLIENT_CERTIFICATE_PASSWORD")) + }, + }; + + return new HttpClient(handler); + } + + async Task SetFailure(string failure, int count) + { + using var client = GetClient(); + var jsonData = new { count }.ToJson(); + var content = new StringContent(jsonData, Encoding.UTF8, "application/json"); + + var uri = new Uri($"https://{endpoint}/set_failpoint/{failure}"); + var response = await client.PostAsync(uri, content); + + if (!response.IsSuccessStatusCode) + { + throw new Exception("Error while setting failure!"); + } + } + + async Task ResetServer() + { + using var client = GetClient(); + var uri = new Uri($"https://{endpoint}/reset"); + var response = await client.PostAsync(uri, null); + + if (!response.IsSuccessStatusCode) + { + throw new Exception("Error while resetting!"); + } + } + } + [Theory] [ParameterAttributeData] public void KmsTlsOptionsTest(