diff --git a/Snowflake.Data.Tests/UnitTests/PutGetStageInfoTest.cs b/Snowflake.Data.Tests/UnitTests/PutGetStageInfoTest.cs new file mode 100644 index 000000000..85a34b222 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/PutGetStageInfoTest.cs @@ -0,0 +1,59 @@ +using System.Collections.Generic; +using NUnit.Framework; +using Snowflake.Data.Core; +using Snowflake.Data.Core.FileTransfer; + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture] + public class PutGetStageInfoTest + { + [Test] + [TestCaseSource(nameof(TestCases))] + public void TestGcmRegionalUrl(string region, bool useRegionalUrl, string endPoint, string expectedGcmEndpoint) + { + // arrange + var stageInfo = CreateGCSStageInfo(region, useRegionalUrl, endPoint); + + // act + var gcsCustomEndpoint = stageInfo.GcsCustomEndpoint(); + + // assert + Assert.AreEqual(expectedGcmEndpoint, gcsCustomEndpoint); + } + + internal static IEnumerable TestCases() + { + yield return new object[] { "US-CENTRAL1", false, null, null }; + yield return new object[] { "US-CENTRAL1", false, "", null }; + yield return new object[] { "US-CENTRAL1", false, "null", null }; + yield return new object[] { "US-CENTRAL1", false, " ", null }; + yield return new object[] { "US-CENTRAL1", false, "example.com", "example.com" }; + yield return new object[] { "ME-CENTRAL2", false, null, "storage.me-central2.rep.googleapis.com" }; + yield return new object[] { "ME-CENTRAL2", true, null, "storage.me-central2.rep.googleapis.com" }; + yield return new object[] { "ME-CENTRAL2", true, "", "storage.me-central2.rep.googleapis.com" }; + yield return new object[] { "ME-CENTRAL2", true, " ", "storage.me-central2.rep.googleapis.com" }; + yield return new object[] { "ME-CENTRAL2", true, "example.com", "example.com" }; + yield return new object[] { "US-CENTRAL1", true, null, "storage.us-central1.rep.googleapis.com" }; + yield return new object[] { "US-CENTRAL1", true, "", "storage.us-central1.rep.googleapis.com" }; + yield return new object[] { "US-CENTRAL1", true, " ", "storage.us-central1.rep.googleapis.com" }; + yield return new object[] { "US-CENTRAL1", true, "null", "storage.us-central1.rep.googleapis.com" }; + yield return new object[] { "US-CENTRAL1", true, "example.com", "example.com" }; + } + + private PutGetStageInfo CreateGCSStageInfo(string region, bool useRegionalUrl, string endPoint) => + new PutGetStageInfo + { + locationType = SFRemoteStorageUtil.GCS_FS, + location = "some location", + path = "some path", + region = region, + storageAccount = "some storage account", + isClientSideEncrypted = true, + stageCredentials = new Dictionary(), + presignedUrl = "some pre-signed url", + endPoint = endPoint, + useRegionalUrl = useRegionalUrl + }; + } +} diff --git a/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs index d47742743..599a0b29b 100644 --- a/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs @@ -18,7 +18,7 @@ namespace Snowflake.Data.Tests.UnitTests using Snowflake.Data.Tests.Mock; using Moq; - [TestFixture] + [TestFixture, NonParallelizable] class SFGCSClientTest : SFBaseTest { // Mock data for file metadata @@ -340,6 +340,37 @@ public async Task TestDownloadFileAsync(HttpStatusCode? httpStatusCode, ResultSt AssertForDownloadFileTests(expectedResultStatus); } + [Test] + [TestCase("us-central1", null, null, "https://storage.googleapis.com/mock-customer-stage/mock-id/tables/mock-key/")] + [TestCase("us-central1", "example.com", null, "https://example.com/mock-customer-stage/mock-id/tables/mock-key/")] + [TestCase("us-central1", "https://example.com", null, "https://example.com/mock-customer-stage/mock-id/tables/mock-key/")] + [TestCase("us-central1", null, true, "https://storage.us-central1.rep.googleapis.com/mock-customer-stage/mock-id/tables/mock-key/")] + [TestCase("me-central2", null, null, "https://storage.me-central2.rep.googleapis.com/mock-customer-stage/mock-id/tables/mock-key/")] + public void TestUseUriWithRegionsWhenNeeded(string region, string endPoint, bool useRegionalUrl, string expectedRequestUri) + { + var fileMetadata = new SFFileMetadata() + { + stageInfo = new PutGetStageInfo() + { + endPoint = endPoint, + location = Location, + locationType = SFRemoteStorageUtil.GCS_FS, + path = LocationPath, + presignedUrl = null, + region = region, + stageCredentials = _stageCredentials, + storageAccount = null, + useRegionalUrl = useRegionalUrl + } + }; + + // act + var uri = _client.FormBaseRequest(fileMetadata, "PUT").RequestUri.ToString(); + + // assert + Assert.AreEqual(expectedRequestUri, uri); + } + private void AssertForDownloadFileTests(ResultStatus expectedResultStatus) { if (expectedResultStatus == ResultStatus.DOWNLOADED) diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs index f56baf2fa..b51afed36 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs @@ -10,6 +10,8 @@ using Newtonsoft.Json; using Snowflake.Data.Log; using System.Net; +using Google.Apis.Storage.v1; +using Google.Cloud.Storage.V1; namespace Snowflake.Data.Core.FileTransfer.StorageClient { @@ -52,6 +54,8 @@ class SFGCSClient : ISFRemoteStorageClient /// private WebRequest _customWebRequest = null; + private static readonly string[] s_scopes = new[] { StorageService.Scope.DevstorageFullControl }; + /// /// GCS client with access token. /// @@ -65,15 +69,32 @@ public SFGCSClient(PutGetStageInfo stageInfo) Logger.Debug("Constructing client using access token"); AccessToken = accessToken; GoogleCredential creds = GoogleCredential.FromAccessToken(accessToken, null); - StorageClient = Google.Cloud.Storage.V1.StorageClient.Create(creds); + var storageClientBuilder = new StorageClientBuilder + { + Credential = creds?.CreateScoped(s_scopes), + EncryptionKey = null + }; + StorageClient = BuildStorageClient(storageClientBuilder, stageInfo); } else { Logger.Info("No access token received from GS, constructing anonymous client with no encryption support"); - StorageClient = Google.Cloud.Storage.V1.StorageClient.CreateUnauthenticated(); + var storageClientBuilder = new StorageClientBuilder + { + UnauthenticatedAccess = true + }; + StorageClient = BuildStorageClient(storageClientBuilder, stageInfo); } } + private Google.Cloud.Storage.V1.StorageClient BuildStorageClient(StorageClientBuilder builder, PutGetStageInfo stageInfo) + { + var gcmCustomEndpoint = stageInfo.GcsCustomEndpoint(); + if (!string.IsNullOrEmpty(gcmCustomEndpoint)) + builder.BaseUri = gcmCustomEndpoint; + return builder.Build(); + } + internal void SetCustomWebRequest(WebRequest mockWebRequest) { _customWebRequest = mockWebRequest; @@ -112,7 +133,7 @@ public RemoteLocation ExtractBucketNameAndPath(string stageLocation) internal WebRequest FormBaseRequest(SFFileMetadata fileMetadata, string method) { string url = string.IsNullOrEmpty(fileMetadata.presignedUrl) ? - generateFileURL(fileMetadata.stageInfo.location, fileMetadata.RemoteFileName()) : + generateFileURL(fileMetadata.stageInfo, fileMetadata.RemoteFileName()) : fileMetadata.presignedUrl; WebRequest request = WebRequest.Create(url); @@ -219,19 +240,26 @@ public async Task GetFileHeaderAsync(SFFileMetadata fileMetadata, Ca return null; } - /// - /// Generate the file URL. - /// - /// The GCS file metadata. - /// The GCS file metadata. - internal string generateFileURL(string stageLocation, string fileName) + internal string generateFileURL(PutGetStageInfo stageInfo, string fileName) { - var gcsLocation = ExtractBucketNameAndPath(stageLocation); + var storageHostPath = ExtractStorageHostPath(stageInfo); + var gcsLocation = ExtractBucketNameAndPath(stageInfo.location); var fullFilePath = gcsLocation.key + fileName; - var link = "https://storage.googleapis.com/" + gcsLocation.bucket + "/" + fullFilePath; + var link = storageHostPath + gcsLocation.bucket + "/" + fullFilePath; return link; } + private string ExtractStorageHostPath(PutGetStageInfo stageInfo) + { + var gcsEndpoint = stageInfo.GcsCustomEndpoint(); + var storageHostPath = string.IsNullOrEmpty(gcsEndpoint) ? "https://storage.googleapis.com/" : gcsEndpoint; + if (!storageHostPath.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + storageHostPath = "https://" + storageHostPath; + if (!storageHostPath.EndsWith("/")) + storageHostPath = storageHostPath + "/"; + return storageHostPath; + } + /// /// Upload the file to the GCS location. /// diff --git a/Snowflake.Data/Core/RestResponse.cs b/Snowflake.Data/Core/RestResponse.cs index 64275fa42..b490ddcdc 100755 --- a/Snowflake.Data/Core/RestResponse.cs +++ b/Snowflake.Data/Core/RestResponse.cs @@ -8,6 +8,7 @@ using Newtonsoft.Json.Converters; using Newtonsoft.Json.Linq; using Snowflake.Data.Client; +using Snowflake.Data.Core.FileTransfer; namespace Snowflake.Data.Core { @@ -439,6 +440,22 @@ internal class PutGetStageInfo [JsonProperty(PropertyName = "endPoint", NullValueHandling = NullValueHandling.Ignore)] internal string endPoint { get; set; } + + [JsonProperty(PropertyName = "useRegionalUrl", NullValueHandling = NullValueHandling.Ignore)] + internal bool useRegionalUrl { get; set; } + + private const string GcsRegionMeCentral2 = "me-central2"; + + internal string GcsCustomEndpoint() + { + if (!(locationType ?? string.Empty).Equals(SFRemoteStorageUtil.GCS_FS, StringComparison.OrdinalIgnoreCase)) + return null; + if (!string.IsNullOrWhiteSpace(endPoint) && endPoint != "null") + return endPoint; + if (GcsRegionMeCentral2.Equals(region, StringComparison.OrdinalIgnoreCase) || useRegionalUrl) + return $"storage.{region.ToLower()}.rep.googleapis.com"; + return null; + } } internal class PutGetEncryptionMaterial