diff --git a/evergreen/evergreen.yml b/evergreen/evergreen.yml
index 94c6cddffa8..0199db345f5 100644
--- a/evergreen/evergreen.yml
+++ b/evergreen/evergreen.yml
@@ -869,6 +869,15 @@ functions:
cd ${DRIVERS_TOOLS}/.evergreen/csfle
. ./activate-kmstlsvenv.sh
python -u kms_http_server.py -v --ca_file ../x509gen/ca.pem --cert_file ../x509gen/server.pem --port 8002 --require_client_cert
+ - command: shell.exec
+ params:
+ background: true
+ shell: "bash"
+ script: |
+ #server.pem client cert
+ cd ${DRIVERS_TOOLS}/.evergreen/csfle
+ . ./activate-kmstlsvenv.sh
+ python -u kms_failpoint_server.py --port 9003
start-kms-mock-kmip-server:
- command: shell.exec
diff --git a/src/MongoDB.Driver.Encryption/CryptClientFactory.cs b/src/MongoDB.Driver.Encryption/CryptClientFactory.cs
index e9dc48cd839..0a2a5352b5a 100644
--- a/src/MongoDB.Driver.Encryption/CryptClientFactory.cs
+++ b/src/MongoDB.Driver.Encryption/CryptClientFactory.cs
@@ -163,6 +163,8 @@ public static CryptClient Create(CryptOptions options)
Library.mongocrypt_setopt_use_need_kms_credentials_state(handle);
+ Library.mongocrypt_setopt_retry_kms(handle, true);
+
Library.mongocrypt_init(handle);
if (options.IsCryptSharedLibRequired)
diff --git a/src/MongoDB.Driver.Encryption/CryptContext.cs b/src/MongoDB.Driver.Encryption/CryptContext.cs
index 23ab25cae99..297d669d733 100644
--- a/src/MongoDB.Driver.Encryption/CryptContext.cs
+++ b/src/MongoDB.Driver.Encryption/CryptContext.cs
@@ -156,18 +156,12 @@ public Binary FinalizeForEncryption()
}
///
- /// Gets a collection of KMS message requests to make
+ /// Gets the next KMS message request
///
- /// Collection of KMS Messages
- public KmsRequestCollection GetKmsMessageRequests()
+ public KmsRequest GetNextKmsMessageRequest()
{
- var requests = new List();
- for (IntPtr request = Library.mongocrypt_ctx_next_kms_ctx(_handle); request != IntPtr.Zero; request = Library.mongocrypt_ctx_next_kms_ctx(_handle))
- {
- requests.Add(new KmsRequest(request));
- }
-
- return new KmsRequestCollection(requests, this);
+ var request = Library.mongocrypt_ctx_next_kms_ctx(_handle);
+ return request == IntPtr.Zero ? null : new KmsRequest(request);
}
///
diff --git a/src/MongoDB.Driver.Encryption/KmsRequest.cs b/src/MongoDB.Driver.Encryption/KmsRequest.cs
index 2e8b6ecdc8a..a93a61c22fc 100644
--- a/src/MongoDB.Driver.Encryption/KmsRequest.cs
+++ b/src/MongoDB.Driver.Encryption/KmsRequest.cs
@@ -77,6 +77,11 @@ public string KmsProvider
}
}
+ ///
+ /// The number of milliseconds to wait before sending this request.
+ ///
+ public int Sleep => (int)(Library.mongocrypt_kms_ctx_usleep(_id) / 1000);
+
///
/// Gets the message to send to KMS.
///
@@ -88,6 +93,12 @@ public Binary GetMessage()
return binary;
}
+ ///
+ /// Indicates a network-level failure.
+ ///
+ /// A boolean indicating whether the failed request may be retried.
+ public bool Fail() => Library.mongocrypt_kms_ctx_fail(_id);
+
///
/// Feeds the response back to the libmongocrypt
///
diff --git a/src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs b/src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs
index a572b3f329d..9d07cd28b2a 100644
--- a/src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs
+++ b/src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs
@@ -211,22 +211,20 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider)
private void ProcessNeedKmsState(CryptContext context, CancellationToken cancellationToken)
{
- var requests = context.GetKmsMessageRequests();
- foreach (var request in requests)
+ while (context.GetNextKmsMessageRequest() is { } request)
{
SendKmsRequest(request, cancellationToken);
}
- requests.MarkDone();
+ context.MarkKmsDone();
}
private async Task ProcessNeedKmsStateAsync(CryptContext context, CancellationToken cancellationToken)
{
- var requests = context.GetKmsMessageRequests();
- foreach (var request in requests)
+ while (context.GetNextKmsMessageRequest() is { } request)
{
await SendKmsRequestAsync(request, cancellationToken).ConfigureAwait(false);
}
- requests.MarkDone();
+ context.MarkKmsDone();
}
private void ProcessNeedMongoKeysState(CryptContext context, CancellationToken cancellationToken)
@@ -278,13 +276,21 @@ private static byte[] ProcessReadyState(CryptContext context)
private void SendKmsRequest(KmsRequest request, CancellationToken cancellation)
{
- var endpoint = CreateKmsEndPoint(request.Endpoint);
-
- var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
- var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
- using (var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation))
- using (var binary = request.GetMessage())
+ try
{
+ var endpoint = CreateKmsEndPoint(request.Endpoint);
+
+ var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
+ var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
+ using var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation);
+
+ var sleepMs = request.Sleep;
+ if (sleepMs > 0)
+ {
+ Thread.Sleep(sleepMs);
+ }
+
+ using var binary = request.GetMessage();
var requestBytes = binary.ToArray();
sslStream.Write(requestBytes, 0, requestBytes.Length);
@@ -292,22 +298,43 @@ private void SendKmsRequest(KmsRequest request, CancellationToken cancellation)
{
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
var count = sslStream.Read(buffer, 0, buffer.Length);
+
+ if (count == 0)
+ {
+ throw new IOException("Unexpected end of stream. No data was read from the SSL stream.");
+ }
+
var responseBytes = new byte[count];
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
request.Feed(responseBytes);
}
}
+ catch (Exception ex) when (ex is IOException or SocketException)
+ {
+ if (!request.Fail())
+ {
+ throw;
+ }
+ }
}
private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken cancellation)
{
- var endpoint = CreateKmsEndPoint(request.Endpoint);
-
- var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
- var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
- using (var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false))
- using (var binary = request.GetMessage())
+ try
{
+ var endpoint = CreateKmsEndPoint(request.Endpoint);
+
+ var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
+ var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
+ using var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false);
+
+ var sleepMs = request.Sleep;
+ if (sleepMs > 0)
+ {
+ await Task.Delay(sleepMs, cancellation).ConfigureAwait(false);
+ }
+
+ using var binary = request.GetMessage();
var requestBytes = binary.ToArray();
await sslStream.WriteAsync(requestBytes, 0, requestBytes.Length).ConfigureAwait(false);
@@ -315,11 +342,24 @@ private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken can
{
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
var count = await sslStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
+
+ if (count == 0)
+ {
+ throw new IOException("Unexpected end of stream. No data was read from the SSL stream.");
+ }
+
var responseBytes = new byte[count];
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
request.Feed(responseBytes);
}
}
+ catch (Exception ex) when (ex is IOException or SocketException)
+ {
+ if (!request.Fail())
+ {
+ throw;
+ }
+ }
}
// nested type
diff --git a/src/MongoDB.Driver.Encryption/Library.cs b/src/MongoDB.Driver.Encryption/Library.cs
index 6baf1abdad2..38538006fc9 100644
--- a/src/MongoDB.Driver.Encryption/Library.cs
+++ b/src/MongoDB.Driver.Encryption/Library.cs
@@ -147,6 +147,9 @@ static Library()
_mongocrypt_ctx_setopt_query_type = new Lazy(
() => __loader.Value.GetFunction(
("mongocrypt_ctx_setopt_query_type")), true);
+ _mongocrypt_setopt_retry_kms = new Lazy(
+ () => __loader.Value.GetFunction(
+ ("mongocrypt_setopt_retry_kms")), true);
_mongocrypt_ctx_status = new Lazy(
() => __loader.Value.GetFunction(("mongocrypt_ctx_status")), true);
@@ -210,6 +213,11 @@ static Library()
() => __loader.Value.GetFunction(("mongocrypt_ctx_destroy")), true);
_mongocrypt_kms_ctx_get_kms_provider = new Lazy(
() => __loader.Value.GetFunction(("mongocrypt_kms_ctx_get_kms_provider")), true);
+
+ _mongocrypt_kms_ctx_usleep = new Lazy(
+ () => __loader.Value.GetFunction(("mongocrypt_kms_ctx_usleep")), true);
+ _mongocrypt_kms_ctx_fail = new Lazy(
+ () => __loader.Value.GetFunction(("mongocrypt_kms_ctx_fail")), true);
}
///
@@ -287,6 +295,7 @@ public static string Version
internal static Delegates.mongocrypt_ctx_setopt_algorithm_range mongocrypt_ctx_setopt_algorithm_range => _mongocrypt_ctx_setopt_algorithm_range.Value;
internal static Delegates.mongocrypt_ctx_setopt_contention_factor mongocrypt_ctx_setopt_contention_factor => _mongocrypt_ctx_setopt_contention_factor.Value;
internal static Delegates.mongocrypt_ctx_setopt_query_type mongocrypt_ctx_setopt_query_type => _mongocrypt_ctx_setopt_query_type.Value;
+ internal static Delegates.mongocrypt_setopt_retry_kms mongocrypt_setopt_retry_kms => _mongocrypt_setopt_retry_kms.Value;
internal static Delegates.mongocrypt_ctx_state mongocrypt_ctx_state => _mongocrypt_ctx_state.Value;
internal static Delegates.mongocrypt_ctx_mongo_op mongocrypt_ctx_mongo_op => _mongocrypt_ctx_mongo_op.Value;
@@ -305,6 +314,9 @@ public static string Version
internal static Delegates.mongocrypt_ctx_destroy mongocrypt_ctx_destroy => _mongocrypt_ctx_destroy.Value;
internal static Delegates.mongocrypt_kms_ctx_get_kms_provider mongocrypt_kms_ctx_get_kms_provider => _mongocrypt_kms_ctx_get_kms_provider.Value;
+ internal static Delegates.mongocrypt_kms_ctx_usleep mongocrypt_kms_ctx_usleep => _mongocrypt_kms_ctx_usleep.Value;
+ internal static Delegates.mongocrypt_kms_ctx_fail mongocrypt_kms_ctx_fail => _mongocrypt_kms_ctx_fail.Value;
+
private static readonly Lazy __loader = new Lazy(
() => new LibraryLoader(), true);
private static readonly Lazy _mongocrypt_version;
@@ -392,6 +404,10 @@ public static string Version
private static readonly Lazy _mongocrypt_ctx_destroy;
private static readonly Lazy _mongocrypt_kms_ctx_get_kms_provider;
+ private static readonly Lazy _mongocrypt_kms_ctx_usleep;
+ private static readonly Lazy _mongocrypt_kms_ctx_fail;
+ private static readonly Lazy _mongocrypt_setopt_retry_kms;
+
// nested types
internal enum StatusType
{
@@ -640,6 +656,9 @@ public delegate bool
[return: MarshalAs(UnmanagedType.I1)]
public delegate bool mongocrypt_ctx_setopt_query_type(ContextSafeHandle ctx, [MarshalAs(UnmanagedType.LPStr)] string query_type, int length);
+ [return: MarshalAs(UnmanagedType.I1)]
+ public delegate bool mongocrypt_setopt_retry_kms(MongoCryptSafeHandle handle, bool enable);
+
public delegate CryptContext.StateCode mongocrypt_ctx_state(ContextSafeHandle handle);
[return: MarshalAs(UnmanagedType.I1)]
@@ -681,6 +700,11 @@ public delegate bool
public delegate void mongocrypt_ctx_destroy(IntPtr ptr);
public delegate IntPtr mongocrypt_kms_ctx_get_kms_provider(IntPtr handle, out uint length);
+
+ public delegate long mongocrypt_kms_ctx_usleep(IntPtr handle);
+
+ [return: MarshalAs(UnmanagedType.I1)]
+ public delegate bool mongocrypt_kms_ctx_fail(IntPtr handle);
}
}
}
diff --git a/src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj b/src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj
index 77c2546b839..f48395f28d5 100644
--- a/src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj
+++ b/src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj
@@ -14,11 +14,11 @@
- https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz
- https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz
- https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz
- https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz
- https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz
+ https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz
+ https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz
+ https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz
+ https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz
+ https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz
diff --git a/tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs b/tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs
index f24b5d3d780..ed1567a33f9 100644
--- a/tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs
+++ b/tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs
@@ -433,7 +433,7 @@ public void TestGetKmsProviderName(string kmsName)
using (var cryptClient = CryptClientFactory.Create(cryptOptions))
using (var context = cryptClient.StartCreateDataKeyContext(keyId))
{
- var request = context.GetKmsMessageRequests().Single();
+ var request = context.GetNextKmsMessageRequest();
request.KmsProvider.Should().Be(kmsName);
}
}
@@ -634,10 +634,9 @@ private static (CryptContext.StateCode stateProcessed, Binary binaryProduced, Bs
case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS:
{
- var requests = context.GetKmsMessageRequests();
- foreach (var req in requests)
+ while (context.GetNextKmsMessageRequest() is { } request)
{
- using var binary = req.GetMessage();
+ using var binary = request.GetMessage();
_output.WriteLine("Key Document: " + binary);
var postRequest = binary.ToString();
// TODO: add different hosts handling
@@ -645,11 +644,11 @@ private static (CryptContext.StateCode stateProcessed, Binary binaryProduced, Bs
var reply = ReadHttpTestFile(isKmsDecrypt ? "kms-decrypt-reply.txt" : "kms-encrypt-reply.txt");
_output.WriteLine("Reply: " + reply);
- req.Feed(Encoding.UTF8.GetBytes(reply));
- req.BytesNeeded.Should().Be(0);
+ request.Feed(Encoding.UTF8.GetBytes(reply));
+ request.BytesNeeded.Should().Be(0);
}
- requests.MarkDone();
+ context.MarkKmsDone();
return (CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS, null, null);
}
diff --git a/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs
index 51ba57dd7b7..7037164a1c1 100644
--- a/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs
+++ b/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs
@@ -24,6 +24,7 @@
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
+using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Amazon.Runtime;
@@ -1428,6 +1429,163 @@ public void ExternalKeyVaultTest(
}
}
+ [Theory]
+ [ParameterAttributeData]
+ public async Task KmsRetryTest(
+ [Values("aws", "azure", "gcp")] string kmsProvider,
+ [Values("network", "http")] string failureType,
+ [Values(false, true)] bool async)
+ {
+ RequireServer.Check().Supports(Feature.ClientSideEncryption);
+ RequireEnvironment.Check().EnvironmentVariable("KMS_MOCK_SERVERS_ENABLED", isDefined: true);
+
+ const string endpoint = "127.0.0.1:9003";
+
+ var masterKey = kmsProvider switch
+ {
+ "aws" => new BsonDocument
+ {
+ { "region", "foo" },
+ { "key", "bar" },
+ { "endpoint", $"{endpoint}" }
+ },
+ "azure" => new BsonDocument
+ {
+ { "keyVaultEndpoint", $"{endpoint}" },
+ { "keyName", "foo" },
+ },
+ "gcp" => new BsonDocument
+ {
+ { "projectId", "foo" },
+ { "location", "bar" },
+ { "keyRing", "baz" },
+ { "keyName", "qux" },
+ { "endpoint", $"{endpoint}" }
+ },
+ _ => throw new ArgumentException(nameof(kmsProvider))
+ };
+
+ await ResetServer();
+
+ using var clientEncrypted = ConfigureClientEncrypted();
+ using var clientEncryption = ConfigureClientEncryption(
+ clientEncrypted,
+ kmsProviderFilter: kmsProvider,
+ kmsProviderConfigurator: KmsProviderEndpointConfigurator
+ );
+
+ var dataKeyOptions = CreateDataKeyOptions(kmsProvider, customMasterKey: masterKey);
+
+ await SetFailure(failureType, 1);
+
+ Guid dataKey = default;
+ Exception ex;
+ if (async)
+ {
+ ex = await Record.ExceptionAsync(async () => dataKey = await clientEncryption
+ .CreateDataKeyAsync(kmsProvider, dataKeyOptions, CancellationToken.None));
+ }
+ else
+ {
+ ex = Record.Exception(() => dataKey = clientEncryption
+ .CreateDataKey(kmsProvider, dataKeyOptions, CancellationToken.None));
+ }
+ ex.Should().BeNull();
+
+ await SetFailure(failureType, 1);
+
+ Exception ex2;
+ if (async)
+ {
+ ex2 = await Record.ExceptionAsync(async () => await clientEncryption.EncryptAsync(new BsonInt32(123),
+ new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic", keyId: dataKey)));
+ }
+ else
+ {
+ ex2 = Record.Exception(() => clientEncryption.Encrypt(new BsonInt32(123),
+ new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic", keyId: dataKey)));
+ }
+ ex2.Should().BeNull();
+
+ if (failureType == "network")
+ {
+ await SetFailure("network", 4);
+
+ Exception ex3;
+ if (async)
+ {
+ ex3 = await Record.ExceptionAsync(async () => dataKey = await clientEncryption
+ .CreateDataKeyAsync(kmsProvider, dataKeyOptions, CancellationToken.None));
+ }
+ else
+ {
+ ex3 = Record.Exception(() => dataKey = clientEncryption
+ .CreateDataKey(kmsProvider, dataKeyOptions, CancellationToken.None));
+ }
+ ex3.Should().NotBeNull();
+ }
+
+ return;
+
+ void KmsProviderEndpointConfigurator(string kmsProviderName, Dictionary kmsOptions)
+ {
+ switch (kmsProviderName)
+ {
+ case "aws":
+ break;
+ case "azure":
+ kmsOptions.Add("identityPlatformEndpoint", endpoint);
+ break;
+ case "gcp":
+ kmsOptions.Add("endpoint", endpoint);
+ break;
+ default:
+ throw new Exception($"Unexpected kmsProvider {endpoint}.");
+ }
+ }
+
+ HttpClient GetClient()
+ {
+ var handler = new HttpClientHandler
+ {
+ ClientCertificates =
+ {
+ new X509Certificate2(Environment.GetEnvironmentVariable("MONGO_X509_CLIENT_CERTIFICATE_PATH")!,
+ Environment.GetEnvironmentVariable("MONGO_X509_CLIENT_CERTIFICATE_PASSWORD"))
+ },
+ };
+
+ return new HttpClient(handler);
+ }
+
+ async Task SetFailure(string failure, int count)
+ {
+ using var client = GetClient();
+ var jsonData = new { count }.ToJson();
+ var content = new StringContent(jsonData, Encoding.UTF8, "application/json");
+
+ var uri = new Uri($"https://{endpoint}/set_failpoint/{failure}");
+ var response = await client.PostAsync(uri, content);
+
+ if (!response.IsSuccessStatusCode)
+ {
+ throw new Exception("Error while setting failure!");
+ }
+ }
+
+ async Task ResetServer()
+ {
+ using var client = GetClient();
+ var uri = new Uri($"https://{endpoint}/reset");
+ var response = await client.PostAsync(uri, null);
+
+ if (!response.IsSuccessStatusCode)
+ {
+ throw new Exception("Error while resetting!");
+ }
+ }
+ }
+
[Theory]
[ParameterAttributeData]
public void KmsTlsOptionsTest(