From 3d0194c48db26516f60f3ef10c09747388944d75 Mon Sep 17 00:00:00 2001 From: Dariusz Stempniak Date: Tue, 8 Oct 2024 16:40:40 +0200 Subject: [PATCH 01/20] SNOW-1616044 set DateTimeKind.Unspecified for TIMESTAMP_NTZ, TIME and DATE (#1033) --- .../IntegrationTests/SFDbDataReaderIT.cs | 157 +++++++++--------- ...ructuredTypesWithEmbeddedUnstructuredIT.cs | 112 +++++++++++-- .../UnitTests/StructuredTypesTest.cs | 17 +- Snowflake.Data/Core/ArrowResultChunk.cs | 58 +++---- Snowflake.Data/Core/ArrowResultSet.cs | 46 ++--- .../Core/Converter/TimeConverter.cs | 23 +-- Snowflake.Data/Core/SFDataConverter.cs | 4 +- 7 files changed, 257 insertions(+), 160 deletions(-) diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderIT.cs index b0e555185..c6952f84a 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderIT.cs @@ -20,14 +20,14 @@ namespace Snowflake.Data.Tests.IntegrationTests class SFDbDataReaderIT : SFBaseTest { protected override string TestName => base.TestName + _resultFormat; - + private readonly ResultFormat _resultFormat; - + public SFDbDataReaderIT(ResultFormat resultFormat) { _resultFormat = resultFormat; } - + private void ValidateResultFormat(IDataReader reader) { Assert.AreEqual(_resultFormat, ((SnowflakeDbDataReader)reader).ResultFormat); @@ -39,7 +39,7 @@ public void TestRecordsAffected() using (var conn = CreateAndOpenConnection()) { CreateOrReplaceTable(conn, TableName, new []{"cola NUMBER"}); - + IDbCommand cmd = conn.CreateCommand(); string insertCommand = $"insert into {TableName} values (1),(1),(1)"; @@ -67,7 +67,7 @@ public void TestGetNumber() using (var conn = CreateAndOpenConnection()) { CreateOrReplaceTable(conn, TableName, new []{"cola NUMBER"}); - + IDbCommand cmd = conn.CreateCommand(); int numInt = 10000; @@ -114,7 +114,7 @@ public void TestGetNumber() Assert.IsFalse(reader.Read()); reader.Close(); - + CloseConnection(conn); } @@ -152,9 +152,9 @@ public void TestGetDouble() cmd.CommandText = $"select * from {TableName}"; IDataReader reader = cmd.ExecuteReader(); - + ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); Assert.AreEqual(numFloat, reader.GetFloat(0)); Assert.AreEqual((decimal)numFloat, reader.GetDecimal(0)); @@ -235,7 +235,7 @@ public void TestGetTime(string inputTimeStr, int? 
precision) [TestCase("11:22:33.4455667")] [TestCase("23:59:59.9999999")] [TestCase("16:20:00.6666666")] - [TestCase("00:00:00.0000000")] + [TestCase("00:00:00.0000000")] [TestCase("00:00:00")] [TestCase("23:59:59.1")] [TestCase("23:59:59.12")] @@ -284,7 +284,7 @@ public void TestGetTimeSpan(string inputTimeStr) Assert.AreEqual(dateTimeTime.Minute, timeSpanTime.Minutes); Assert.AreEqual(dateTimeTime.Second, timeSpanTime.Seconds); Assert.AreEqual(dateTimeTime.Millisecond, timeSpanTime.Milliseconds); - + CloseConnection(conn); } } @@ -336,7 +336,7 @@ public void TestGetTimeSpanError() IDataReader reader = cmd.ExecuteReader(); ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); // All types except TIME fail conversion when calling GetTimeSpan @@ -344,19 +344,19 @@ public void TestGetTimeSpanError() { try { - + ((SnowflakeDbDataReader)reader).GetTimeSpan(i); Assert.Fail("Data should not be converted to TIME"); } catch (SnowflakeDbException e) { - Assert.AreEqual(270003, e.ErrorCode); + Assert.AreEqual(270003, e.ErrorCode); } } // Null value // Null value can not be converted to TimeSpan because it is a non-nullable type - + try { ((SnowflakeDbDataReader)reader).GetTimeSpan(12); @@ -371,7 +371,7 @@ public void TestGetTimeSpanError() TimeSpan timeSpanTime = ((SnowflakeDbDataReader)reader).GetTimeSpan(13); reader.Close(); - + CloseConnection(conn); } } @@ -425,9 +425,9 @@ private void TestGetDateAndOrTime(string inputTimeStr, int? precision, SFDataTyp cmd.CommandText = $"select * from {TableName}"; IDataReader reader = cmd.ExecuteReader(); - + ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); // For time, we getDateTime on the column and ignore date part @@ -435,7 +435,7 @@ private void TestGetDateAndOrTime(string inputTimeStr, int? precision, SFDataTyp if (dataType == SFDataType.DATE) { - Assert.AreEqual(inputTime.Date, reader.GetDateTime(0)); + Assert.AreEqual(inputTime.Date, actualTime); Assert.AreEqual(inputTime.Date.ToString("yyyy-MM-dd"), reader.GetString(0)); } if (dataType != SFDataType.DATE) @@ -449,14 +449,17 @@ private void TestGetDateAndOrTime(string inputTimeStr, int? 
precision, SFDataTyp { if (precision == 9) { - Assert.AreEqual(inputTime, reader.GetDateTime(0)); + Assert.AreEqual(inputTime, actualTime); } else { - Assert.AreEqual(inputTime.Date, reader.GetDateTime(0).Date); + Assert.AreEqual(inputTime.Date, actualTime.Date); } } + // DATE, TIME and TIMESTAMP_NTZ should be returned with DateTimeKind.Unspecified + Assert.AreEqual(DateTimeKind.Unspecified, actualTime.Kind); + reader.Close(); CloseConnection(conn); @@ -495,9 +498,9 @@ public void TestGetTimestampTZ(int timezoneOffsetInHours) using (var conn = CreateAndOpenConnection()) { CreateOrReplaceTable(conn, TableName, new []{"cola TIMESTAMP_TZ"}); - + DateTimeOffset now = DateTimeOffset.Now.ToOffset(TimeSpan.FromHours(timezoneOffsetInHours)); - + IDbCommand cmd = conn.CreateCommand(); string insertCommand = $"insert into {TableName} values (?)"; @@ -514,9 +517,9 @@ public void TestGetTimestampTZ(int timezoneOffsetInHours) cmd.CommandText = $"select * from {TableName}"; IDataReader reader = cmd.ExecuteReader(); - + ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); DateTimeOffset dtOffset = (DateTimeOffset)reader.GetValue(0); reader.Close(); @@ -535,9 +538,9 @@ public void TestGetTimestampLTZ() using (var conn = CreateAndOpenConnection()) { CreateOrReplaceTable(conn, TableName, new []{"cola TIMESTAMP_LTZ"}); - + DateTimeOffset now = DateTimeOffset.Now; - + IDbCommand cmd = conn.CreateCommand(); string insertCommand = $"insert into {TableName} values (?)"; @@ -555,9 +558,9 @@ public void TestGetTimestampLTZ() cmd.CommandText = $"select * from {TableName}"; IDataReader reader = cmd.ExecuteReader(); - + ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); DateTimeOffset dtOffset = (DateTimeOffset)reader.GetValue(0); reader.Close(); @@ -592,9 +595,9 @@ public void TestGetBoolean([Values]bool value) cmd.CommandText = $"select * from {TableName}"; IDataReader reader = cmd.ExecuteReader(); - + ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); Assert.AreEqual(value, reader.GetBoolean(0)); reader.Close(); @@ -655,18 +658,18 @@ public void TestGetBinary() "col2 VARCHAR(50)", "col3 DOUBLE" }); - + byte[] testBytes = Encoding.UTF8.GetBytes("TEST_GET_BINARAY"); string testChars = "TEST_GET_CHARS"; double testDouble = 1.2345678; string insertCommand = $"insert into {TableName} values (?, '{testChars}',{testDouble.ToString()})"; IDbCommand cmd = conn.CreateCommand(); cmd.CommandText = insertCommand; - + var p1 = cmd.CreateParameter(); p1.ParameterName = "1"; p1.DbType = DbType.Binary; - p1.Value = testBytes; + p1.Value = testBytes; cmd.Parameters.Add(p1); var count = cmd.ExecuteNonQuery(); @@ -674,9 +677,9 @@ public void TestGetBinary() cmd.CommandText = $"select * from {TableName}"; IDataReader reader = cmd.ExecuteReader(); - + ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); // Auto type conversion Assert.IsTrue(testBytes.SequenceEqual((byte[])reader.GetValue(0))); @@ -714,7 +717,7 @@ public void TestGetBinary() Assert.AreEqual(read, toReadLength); Assert.IsTrue(testSubBytes.SequenceEqual(sub)); - // Read subset 'GET_BINARAY' from actual 'TEST_GET_BINARAY' data + // Read subset 'GET_BINARAY' from actual 'TEST_GET_BINARAY' data // and copy inside existing buffer replacing Xs toReadLength = 11; byte[] testSubBytesWithTargetOffset = Encoding.UTF8.GetBytes("OFFSET GET_BINARAY EXTRA"); @@ -731,7 +734,7 @@ public void TestGetBinary() //** Invalid data offsets **/ try { - // Data offset > data length + // Data offset > data length reader.GetBytes(0, 25, sub, 7, 
toReadLength); Assert.Fail(); } @@ -754,7 +757,7 @@ public void TestGetBinary() //** Invalid buffer offsets **// try { - // Buffer offset > buffer length + // Buffer offset > buffer length reader.GetBytes(0, 6, sub, 25, toReadLength); Assert.Fail(); } @@ -775,7 +778,7 @@ public void TestGetBinary() } //** Null buffer **// - // If null, this method returns the size required of the array in order to fit all + // If null, this method returns the size required of the array in order to fit all // of the specified data. read = reader.GetBytes(0, 6, null, 0, toReadLength); Assert.AreEqual(testBytes.Length, read); @@ -828,7 +831,7 @@ public void TestGetChars() "col2 BINARY", "col3 DOUBLE" }); - + string testChars = "TEST_GET_CHARS"; byte[] testBytes = Encoding.UTF8.GetBytes("TEST_GET_BINARY"); double testDouble = 1.2345678; @@ -849,7 +852,7 @@ public void TestGetChars() IDataReader reader = cmd.ExecuteReader(); ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); // Auto type conversion Assert.IsTrue(testChars.Equals(reader.GetValue(0))); @@ -889,7 +892,7 @@ public void TestGetChars() Assert.IsTrue(testSubChars.SequenceEqual(sub)); - // Read subset 'GET_CHARS' from actual 'TEST_GET_CHARS' data + // Read subset 'GET_CHARS' from actual 'TEST_GET_CHARS' data // and copy inside existing buffer replacing Xs char[] testSubCharsWithTargetOffset = "OFFSET GET_CHARS EXTRA".ToArray(); toReadLength = 9; @@ -906,7 +909,7 @@ public void TestGetChars() //** Invalid data offsets **// try { - // Data offset > data length + // Data offset > data length reader.GetChars(0, 25, sub, 7, toReadLength); Assert.Fail(); } @@ -929,7 +932,7 @@ public void TestGetChars() //** Invalid buffer offsets **// try { - // Buffer offset > buffer length + // Buffer offset > buffer length reader.GetChars(0, 6, sub, 25, toReadLength); Assert.Fail(); } @@ -950,7 +953,7 @@ public void TestGetChars() } //** Null buffer **// - // If null, this method returns the size required of the array in order to fit all + // If null, this method returns the size required of the array in order to fit all // of the specified data. 
read = reader.GetChars(0, 6, null, 0, toReadLength); Assert.AreEqual(testChars.Length, read); @@ -1016,7 +1019,7 @@ public void TestGetStream() "col2 BINARY", "col3 DOUBLE" }); - + string testChars = "TEST_GET_CHARS"; byte[] testBytes = Encoding.UTF8.GetBytes("TEST_GET_BINARY"); double testDouble = 1.2345678; @@ -1037,7 +1040,7 @@ public void TestGetStream() DbDataReader reader = (DbDataReader) cmd.ExecuteReader(); ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); // Auto type conversion @@ -1087,9 +1090,9 @@ public void TestGetValueIndexOutOfBound() IDbCommand cmd = conn.CreateCommand(); cmd.CommandText = "select 1"; IDataReader reader = cmd.ExecuteReader(); - + ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); try @@ -1128,7 +1131,7 @@ public void TestBasicDataReader() using (IDataReader reader = cmd.ExecuteReader()) { ValidateResultFormat(reader); - + Assert.AreEqual(2, reader.FieldCount); Assert.AreEqual(0, reader.Depth); Assert.IsTrue(((SnowflakeDbDataReader)reader).HasRows); @@ -1151,7 +1154,7 @@ public void TestBasicDataReader() reader.Close(); Assert.IsTrue(reader.IsClosed); - + try { reader.Read(); @@ -1199,7 +1202,7 @@ public void TestReadOutNullVal() using (IDataReader reader = cmd.ExecuteReader()) { ValidateResultFormat(reader); - + reader.Read(); object nullVal = reader.GetValue(0); Assert.AreEqual(DBNull.Value, nullVal); @@ -1211,7 +1214,7 @@ public void TestReadOutNullVal() } CloseConnection(conn); - } + } } [Test] @@ -1238,9 +1241,9 @@ public void TestGetGuid() cmd.CommandText = $"select * from {TableName}"; IDataReader reader = cmd.ExecuteReader(); - + ValidateResultFormat(reader); - + Assert.IsTrue(reader.Read()); Assert.AreEqual(val, reader.GetGuid(0)); @@ -1302,7 +1305,7 @@ public void TestCopyCmdResultSet() cmd.CommandText = $"create or replace stage {stageName}"; cmd.ExecuteNonQuery(); - cmd.CommandText = $"copy into {TableName} from @{stageName}"; + cmd.CommandText = $"copy into {TableName} from @{stageName}"; using (var rdr = cmd.ExecuteReader()) { // Can read the first row @@ -1433,7 +1436,7 @@ public void TestResultSetMetadata() CloseConnection(conn); } } - + [Test] public void TestHasRows() { @@ -1441,9 +1444,9 @@ public void TestHasRows() { DbCommand cmd = conn.CreateCommand(); cmd.CommandText = "select 1 where 1=2"; - + DbDataReader reader = cmd.ExecuteReader(); - + ValidateResultFormat(reader); Assert.IsFalse(reader.HasRows); @@ -1451,7 +1454,7 @@ public void TestHasRows() CloseConnection(conn); } } - + [Test] public void TestHasRowsMultiStatement() { @@ -1460,15 +1463,15 @@ public void TestHasRowsMultiStatement() DbCommand cmd = conn.CreateCommand(); cmd.CommandText = "select 1;" + "select 1 where 1=2;" + - "select 1;" + + "select 1;" + "select 1 where 1=2;"; - + DbParameter param = cmd.CreateParameter(); param.ParameterName = "MULTI_STATEMENT_COUNT"; param.DbType = DbType.Int16; param.Value = 4; cmd.Parameters.Add(param); - + DbDataReader reader = cmd.ExecuteReader(); // multi statements are always returned in JSON @@ -1483,7 +1486,7 @@ public void TestHasRowsMultiStatement() // select 1 where 1=2 Assert.IsFalse(reader.HasRows); reader.NextResult(); - + // select 1 Assert.IsTrue(reader.HasRows); reader.Read(); @@ -1494,12 +1497,12 @@ public void TestHasRowsMultiStatement() Assert.IsFalse(reader.HasRows); reader.NextResult(); Assert.IsFalse(reader.HasRows); - + reader.Close(); CloseConnection(conn); } } - + [Test] [TestCase("99")] // Int8 [TestCase("9.9")] // Int8 + scale @@ -1564,23 +1567,23 @@ public void TestTimestampTz(string 
testValue, int scale) using (var conn = CreateAndOpenConnection()) { DbCommand cmd = conn.CreateCommand(); - + cmd.CommandText = $"select '{testValue}'::TIMESTAMP_TZ({scale})"; using (SnowflakeDbDataReader reader = (SnowflakeDbDataReader)cmd.ExecuteReader()) { ValidateResultFormat(reader); reader.Read(); - + var expectedValue = DateTimeOffset.Parse(testValue); Assert.AreEqual(expectedValue, reader.GetValue(0)); } - + CloseConnection(conn); } } - + [Test] [TestCase("2019-01-01 12:12:12.1234567 +0500", 7)] [TestCase("2019-01-01 12:12:12.1234567 +1400", 7)] @@ -1591,23 +1594,23 @@ public void TestTimestampLtz(string testValue, int scale) using (var conn = CreateAndOpenConnection()) { DbCommand cmd = conn.CreateCommand(); - + cmd.CommandText = $"select '{testValue}'::TIMESTAMP_LTZ({scale})"; using (SnowflakeDbDataReader reader = (SnowflakeDbDataReader)cmd.ExecuteReader()) { ValidateResultFormat(reader); reader.Read(); - + var expectedValue = DateTimeOffset.Parse(testValue).ToLocalTime(); Assert.AreEqual(expectedValue, reader.GetValue(0)); } - + CloseConnection(conn); } } - + [Test] [TestCase("2019-01-01 12:12:12.1234567", 7)] [TestCase("0001-01-01 00:00:00.0000000", 9)] @@ -1617,19 +1620,19 @@ public void TestTimestampNtz(string testValue, int scale) using (var conn = CreateAndOpenConnection()) { DbCommand cmd = conn.CreateCommand(); - + cmd.CommandText = $"select '{testValue}'::TIMESTAMP_NTZ({scale})"; using (SnowflakeDbDataReader reader = (SnowflakeDbDataReader)cmd.ExecuteReader()) { ValidateResultFormat(reader); reader.Read(); - + var expectedValue = DateTime.Parse(testValue); Assert.AreEqual(expectedValue, reader.GetValue(0)); } - + CloseConnection(conn); } } diff --git a/Snowflake.Data.Tests/IntegrationTests/StructuredTypesWithEmbeddedUnstructuredIT.cs b/Snowflake.Data.Tests/IntegrationTests/StructuredTypesWithEmbeddedUnstructuredIT.cs index 6f88126d9..784aa4132 100644 --- a/Snowflake.Data.Tests/IntegrationTests/StructuredTypesWithEmbeddedUnstructuredIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/StructuredTypesWithEmbeddedUnstructuredIT.cs @@ -312,13 +312,55 @@ public void TestSelectDateTime(string dbValue, string dbType, DateTime? 
expected internal static IEnumerable DateTimeConversionCases() { - yield return new object[] { "2024-07-11 14:20:05", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("2024-07-11 14:20:05").ToUniversalTime(), DateTime.Parse("2024-07-11 14:20:05").ToUniversalTime() }; - yield return new object[] { "2024-07-11 14:20:05 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), null, DateTime.Parse("2024-07-11 09:20:05").ToUniversalTime() }; - yield return new object[] {"2024-07-11 14:20:05 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), null, DateTime.Parse("2024-07-11 21:20:05").ToUniversalTime() }; - yield return new object[] { "2024-07-11", SFDataType.DATE.ToString(), DateTime.Parse("2024-07-11").ToUniversalTime(), DateTime.Parse("2024-07-11").ToUniversalTime() }; - yield return new object[] { "2024-07-11 14:20:05.123456789", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("2024-07-11 14:20:05.1234567").ToUniversalTime(), DateTime.Parse("2024-07-11 14:20:05.1234568").ToUniversalTime()}; - yield return new object[] { "2024-07-11 14:20:05.123456789 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), null, DateTime.Parse("2024-07-11 09:20:05.1234568").ToUniversalTime() }; - yield return new object[] {"2024-07-11 14:20:05.123456789 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), null, DateTime.Parse("2024-07-11 21:20:05.1234568").ToUniversalTime() }; + yield return new object[] + { + "2024-07-11 14:20:05", + SFDataType.TIMESTAMP_NTZ.ToString(), + DateTime.Parse("2024-07-11 14:20:05"), + DateTime.Parse("2024-07-11 14:20:05") // kind -> Unspecified + }; + yield return new object[] + { + "2024-07-11 14:20:05 +5:00", + SFDataType.TIMESTAMP_TZ.ToString(), + null, + DateTime.SpecifyKind(DateTime.Parse("2024-07-11 09:20:05"), DateTimeKind.Utc) + }; + yield return new object[] + { + "2024-07-11 14:20:05 -7:00", + SFDataType.TIMESTAMP_LTZ.ToString(), + null, + DateTime.Parse("2024-07-11 21:20:05").ToLocalTime() + }; + yield return new object[] + { + "2024-07-11", + SFDataType.DATE.ToString(), + DateTime.SpecifyKind(DateTime.Parse("2024-07-11"), DateTimeKind.Unspecified), + DateTime.SpecifyKind(DateTime.Parse("2024-07-11"), DateTimeKind.Unspecified) + }; + yield return new object[] + { + "2024-07-11 14:20:05.123456789", + SFDataType.TIMESTAMP_NTZ.ToString(), + DateTime.Parse("2024-07-11 14:20:05.1234567"), + DateTime.Parse("2024-07-11 14:20:05.1234568") + }; + yield return new object[] + { + "2024-07-11 14:20:05.123456789 +5:00", + SFDataType.TIMESTAMP_TZ.ToString(), + null, + DateTime.SpecifyKind(DateTime.Parse("2024-07-11 09:20:05.1234568"), DateTimeKind.Utc) + }; + yield return new object[] + { + "2024-07-11 14:20:05.123456789 -7:00", + SFDataType.TIMESTAMP_LTZ.ToString(), + null, + DateTime.Parse("2024-07-11 21:20:05.1234568").ToLocalTime() + }; } [Test] @@ -354,13 +396,55 @@ public void TestSelectDateTimeOffset(string dbValue, string dbType, DateTime? 
ex internal static IEnumerable DateTimeOffsetConversionCases() { - yield return new object[] {"2024-07-11 14:20:05", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("2024-07-11 14:20:05").ToUniversalTime(), DateTimeOffset.Parse("2024-07-11 14:20:05Z")}; - yield return new object[] {"2024-07-11 14:20:05 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), null, DateTimeOffset.Parse("2024-07-11 14:20:05 +5:00")}; - yield return new object[] {"2024-07-11 14:20:05 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), null, DateTimeOffset.Parse("2024-07-11 14:20:05 -7:00")}; - yield return new object[] {"2024-07-11", SFDataType.DATE.ToString(), DateTime.Parse("2024-07-11").ToUniversalTime(), DateTimeOffset.Parse("2024-07-11Z")}; - yield return new object[] {"2024-07-11 14:20:05.123456789", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("2024-07-11 14:20:05.1234567").ToUniversalTime(), DateTimeOffset.Parse("2024-07-11 14:20:05.1234568Z")}; - yield return new object[] {"2024-07-11 14:20:05.123456789 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), null, DateTimeOffset.Parse("2024-07-11 14:20:05.1234568 +5:00")}; - yield return new object[] {"2024-07-11 14:20:05.123456789 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), null, DateTimeOffset.Parse("2024-07-11 14:20:05.1234568 -7:00")}; + yield return new object[] + { + "2024-07-11 14:20:05", + SFDataType.TIMESTAMP_NTZ.ToString(), + DateTime.Parse("2024-07-11 14:20:05"), + DateTimeOffset.Parse("2024-07-11 14:20:05Z") + }; + yield return new object[] + { + "2024-07-11 14:20:05 +5:00", + SFDataType.TIMESTAMP_TZ.ToString(), + null, + DateTimeOffset.Parse("2024-07-11 14:20:05 +5:00") + }; + yield return new object[] + { + "2024-07-11 14:20:05 -7:00", + SFDataType.TIMESTAMP_LTZ.ToString(), + null, + DateTimeOffset.Parse("2024-07-11 14:20:05 -7:00").ToLocalTime() + }; + yield return new object[] + { + "2024-07-11", + SFDataType.DATE.ToString(), + DateTime.SpecifyKind(DateTime.Parse("2024-07-11"), DateTimeKind.Unspecified), + DateTimeOffset.Parse("2024-07-11Z") + }; + yield return new object[] + { + "2024-07-11 14:20:05.123456789", + SFDataType.TIMESTAMP_NTZ.ToString(), + DateTime.Parse("2024-07-11 14:20:05.1234567"), + DateTimeOffset.Parse("2024-07-11 14:20:05.1234568Z") + }; + yield return new object[] + { + "2024-07-11 14:20:05.123456789 +5:00", + SFDataType.TIMESTAMP_TZ.ToString(), + null, + DateTimeOffset.Parse("2024-07-11 14:20:05.1234568 +5:00") + }; + yield return new object[] + { + "2024-07-11 14:20:05.123456789 -7:00", + SFDataType.TIMESTAMP_LTZ.ToString(), + null, + DateTimeOffset.Parse("2024-07-11 14:20:05.1234568 -7:00") + }; } private TimeZoneInfo GetTimeZone(SnowflakeDbConnection connection) diff --git a/Snowflake.Data.Tests/UnitTests/StructuredTypesTest.cs b/Snowflake.Data.Tests/UnitTests/StructuredTypesTest.cs index a10b4660c..0a91fdab5 100644 --- a/Snowflake.Data.Tests/UnitTests/StructuredTypesTest.cs +++ b/Snowflake.Data.Tests/UnitTests/StructuredTypesTest.cs @@ -23,24 +23,29 @@ public void TestTimeConversions(string value, string sfTypeString, object expect // assert Assert.AreEqual(expected, result); + + if (csharpType == typeof(DateTime)) + { + Assert.AreEqual(((DateTime)expected).Kind,((DateTime)result).Kind); + } } internal static IEnumerable TimeConversionCases() { - yield return new object[] {"2024-07-11 14:20:05", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("2024-07-11 14:20:05").ToUniversalTime()}; + yield return new object[] {"2024-07-11 14:20:05", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("2024-07-11 14:20:05")}; yield return 
new object[] {"2024-07-11 14:20:05", SFDataType.TIMESTAMP_NTZ.ToString(), DateTimeOffset.Parse("2024-07-11 14:20:05Z")}; yield return new object[] {"2024-07-11 14:20:05 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTimeOffset.Parse("2024-07-11 14:20:05 +5:00")}; - yield return new object[] {"2024-07-11 14:20:05 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTime.Parse("2024-07-11 09:20:05").ToUniversalTime()}; + yield return new object[] {"2024-07-11 14:20:05 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTime.SpecifyKind(DateTime.Parse("2024-07-11 09:20:05"), DateTimeKind.Utc)}; yield return new object[] {"2024-07-11 14:20:05 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTimeOffset.Parse("2024-07-11 14:20:05 -7:00")}; - yield return new object[] {"2024-07-11 14:20:05 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTime.Parse("2024-07-11 21:20:05").ToUniversalTime()}; + yield return new object[] {"2024-07-11 14:20:05 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTime.Parse("2024-07-11 21:20:05").ToLocalTime()}; yield return new object[] {"14:20:05", SFDataType.TIME.ToString(), TimeSpan.Parse("14:20:05")}; yield return new object[] {"2024-07-11", SFDataType.DATE.ToString(), DateTime.Parse("2024-07-11")}; - yield return new object[] {"2024-07-11 14:20:05.123456", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("2024-07-11 14:20:05.123456").ToUniversalTime()}; + yield return new object[] {"2024-07-11 14:20:05.123456", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("2024-07-11 14:20:05.123456")}; yield return new object[] {"2024-07-11 14:20:05.123456", SFDataType.TIMESTAMP_NTZ.ToString(), DateTimeOffset.Parse("2024-07-11 14:20:05.123456Z")}; yield return new object[] {"2024-07-11 14:20:05.123456 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTimeOffset.Parse("2024-07-11 14:20:05.123456 +5:00")}; - yield return new object[] {"2024-07-11 14:20:05.123456 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTime.Parse("2024-07-11 09:20:05.123456").ToUniversalTime()}; + yield return new object[] {"2024-07-11 14:20:05.123456 +5:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTime.SpecifyKind(DateTime.Parse("2024-07-11 09:20:05.123456"), DateTimeKind.Utc)}; yield return new object[] {"2024-07-11 14:20:05.123456 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTimeOffset.Parse("2024-07-11 14:20:05.123456 -7:00")}; - yield return new object[] {"2024-07-11 14:20:05.123456 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTime.Parse("2024-07-11 21:20:05.123456").ToUniversalTime()}; + yield return new object[] {"2024-07-11 14:20:05.123456 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTime.Parse("2024-07-11 21:20:05.123456").ToLocalTime()}; yield return new object[] {"14:20:05.123456", SFDataType.TIME.ToString(), TimeSpan.Parse("14:20:05.123456")}; } } diff --git a/Snowflake.Data/Core/ArrowResultChunk.cs b/Snowflake.Data/Core/ArrowResultChunk.cs index 1616ec42a..85e5de62c 100755 --- a/Snowflake.Data/Core/ArrowResultChunk.cs +++ b/Snowflake.Data/Core/ArrowResultChunk.cs @@ -14,16 +14,16 @@ internal class ArrowResultChunk : BaseResultChunk internal override ResultFormat ResultFormat => ResultFormat.ARROW; private static readonly DateTimeOffset s_epochDate = SFDataConverter.UnixEpoch; - - private static readonly long[] s_powersOf10 = { - 1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, + + private static readonly long[] s_powersOf10 = { + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, 100000000, 1000000000 }; @@ -62,7 +62,7 @@ public 
ArrowResultChunk(RecordBatch recordBatch) RowCount = recordBatch.Length; ColumnCount = recordBatch.ColumnCount; ChunkIndex = -1; - + ResetTempTables(); } @@ -81,11 +81,11 @@ public void AddRecordBatch(RecordBatch recordBatch) { RecordBatch.Add(recordBatch); } - + internal override void Reset(ExecResponseChunk chunkInfo, int chunkIndex) { base.Reset(chunkInfo, chunkIndex); - + _currentBatchIndex = 0; _currentRecordIndex = -1; RecordBatch.Clear(); @@ -97,7 +97,7 @@ internal override bool Next() { if (_currentBatchIndex >= RecordBatch.Count) return false; - + _currentRecordIndex += 1; if (_currentRecordIndex < RecordBatch[_currentBatchIndex].Length) return true; @@ -149,7 +149,7 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) if (column.IsNull(_currentRecordIndex)) return DBNull.Value; - + switch (srcType) { case SFDataType.FIXED: @@ -170,7 +170,7 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) if (scale == 0) return _short[columnIndex][_currentRecordIndex]; return _short[columnIndex][_currentRecordIndex] / (decimal)s_powersOf10[scale]; - + case Int32Array array: if (_int[columnIndex] == null) _int[columnIndex] = array.Values.ToArray(); @@ -184,7 +184,7 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) if (scale == 0) return _long[columnIndex][_currentRecordIndex]; return _long[columnIndex][_currentRecordIndex] / (decimal)s_powersOf10[scale]; - + case Decimal128Array array: return array.GetValue(_currentRecordIndex); } @@ -210,8 +210,8 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) _int[columnIndex] = ((StringArray)column).ValueOffsets.ToArray(); } return StringArray.DefaultEncoding.GetString( - _byte[columnIndex], - _int[columnIndex][_currentRecordIndex], + _byte[columnIndex], + _int[columnIndex][_currentRecordIndex], _int[columnIndex][_currentRecordIndex + 1] - _int[columnIndex][_currentRecordIndex]); case SFDataType.VECTOR: @@ -250,16 +250,16 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) case SFDataType.BINARY: return ((BinaryArray)column).GetBytes(_currentRecordIndex).ToArray(); - + case SFDataType.DATE: if (_int[columnIndex] == null) _int[columnIndex] = ((Date32Array)column).Values.ToArray(); - return SFDataConverter.UnixEpoch.AddTicks(_int[columnIndex][_currentRecordIndex] * TicksPerDay); - + return DateTime.SpecifyKind(SFDataConverter.UnixEpoch.AddTicks(_int[columnIndex][_currentRecordIndex] * TicksPerDay), DateTimeKind.Unspecified); + case SFDataType.TIME: { long value; - + if (column.GetType() == typeof(Int32Array)) { if (_int[columnIndex] == null) @@ -278,7 +278,7 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) value = _long[columnIndex][_currentRecordIndex]; } - + if (scale == 0) return DateTimeOffset.FromUnixTimeSeconds(value).DateTime; if (scale <= 3) @@ -292,7 +292,7 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) var structCol = (StructArray)column; if (_long[columnIndex] == null) _long[columnIndex] = ((Int64Array)structCol.Fields[0]).Values.ToArray(); - + if (structCol.Fields.Count == 2) { if (_int[columnIndex] == null) @@ -309,7 +309,7 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) _fraction[columnIndex] = ((Int32Array)structCol.Fields[1]).Values.ToArray(); if (_int[columnIndex] == null) _int[columnIndex] = ((Int32Array)structCol.Fields[2]).Values.ToArray(); - + var epoch = _long[columnIndex][_currentRecordIndex]; var fraction = 
_fraction[columnIndex][_currentRecordIndex]; var timezone = _int[columnIndex][_currentRecordIndex]; @@ -331,7 +331,7 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) { if (_long[columnIndex] == null) _long[columnIndex] = ((Int64Array)column).Values.ToArray(); - + var value = _long[columnIndex][_currentRecordIndex]; var epoch = ExtractEpoch(value, scale); var fraction = ExtractFraction(value, scale); @@ -353,7 +353,7 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) { if (_long[columnIndex] == null) _long[columnIndex] = ((Int64Array)column).Values.ToArray(); - + var value = _long[columnIndex][_currentRecordIndex]; var epoch = ExtractEpoch(value, scale); var fraction = ExtractFraction(value, scale); @@ -362,7 +362,7 @@ public object ExtractCell(int columnIndex, SFDataType srcType, long scale) } throw new NotSupportedException($"Type {srcType} is not supported."); } - + private long ExtractEpoch(long value, long scale) { return value / s_powersOf10[scale]; diff --git a/Snowflake.Data/Core/ArrowResultSet.cs b/Snowflake.Data/Core/ArrowResultSet.cs index 56a636c4e..a3a6e2628 100755 --- a/Snowflake.Data/Core/ArrowResultSet.cs +++ b/Snowflake.Data/Core/ArrowResultSet.cs @@ -18,7 +18,7 @@ class ArrowResultSet : SFBaseResultSet internal override ResultFormat ResultFormat => ResultFormat.ARROW; private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - + private readonly int _totalChunkCount; private BaseResultChunk _currentChunk; private readonly IChunkDownloader _chunkDownloader; @@ -44,7 +44,7 @@ public ArrowResultSet(QueryExecResponseData responseData, SFStatement sfStatemen isClosed = false; queryId = responseData.queryId; - + ReadChunk(responseData); } catch(Exception ex) @@ -95,21 +95,21 @@ internal override async Task NextAsync() return false; } - + internal override bool Next() { ThrowIfClosed(); if (_currentChunk.Next()) return true; - + if (_totalChunkCount > 0) { s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + $" size uncompressed: {_currentChunk.UncompressedSize}"); _currentChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result; - + return _currentChunk?.Next() ?? false; } @@ -154,21 +154,21 @@ internal override bool Rewind() return false; } - + private object GetObjectInternal(int ordinal) { ThrowIfClosed(); ThrowIfOutOfBounds(ordinal); - + var type = sfResultSetMetaData.GetTypesByIndex(ordinal).Item1; var scale = sfResultSetMetaData.GetScaleByIndex(ordinal); - + var value = ((ArrowResultChunk)_currentChunk).ExtractCell(ordinal, type, (int)scale); return value ?? 
DBNull.Value; - + } - + internal override object GetValue(int ordinal) { var value = GetObjectInternal(ordinal); @@ -176,7 +176,7 @@ internal override object GetValue(int ordinal) { return value; } - + object obj; checked { @@ -196,6 +196,10 @@ internal override object GetValue(int ordinal) break; case bool ret: obj = ret; break; + case DateTime ret: obj = ret; + break; + case DateTimeOffset ret: obj = ret; + break; default: { var dstType = sfResultSetMetaData.GetCSharpTypeByIndex(ordinal); @@ -217,7 +221,7 @@ internal override bool GetBoolean(int ordinal) { return (bool)GetObjectInternal(ordinal); } - + internal override byte GetByte(int ordinal) { var value = GetObjectInternal(ordinal); @@ -244,7 +248,7 @@ internal override char GetChar(int ordinal) { return ((string)GetObjectInternal(ordinal))[0]; } - + internal override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) { return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); @@ -303,7 +307,7 @@ internal override double GetDouble(int ordinal) case int ret: return ret; case short ret: return ret; case sbyte ret: return ret; - default: return (double)value; + default: return (double)value; } } @@ -374,7 +378,7 @@ internal override long GetInt64(int ordinal) } } } - + internal override string GetString(int ordinal) { var value = GetObjectInternal(ordinal); @@ -394,14 +398,14 @@ internal override string GetString(int ordinal) return Convert.ToString(value); } - + private void UpdateSessionStatus(QueryExecResponseData responseData) { SFSession session = this.sfStatement.SfSession; session.UpdateSessionProperties(responseData); session.UpdateSessionParameterMap(responseData.parameters); } - + private long ReadSubset(int ordinal, long dataOffset, T[] buffer, int bufferOffset, int length) where T : struct { if (dataOffset < 0) @@ -417,7 +421,7 @@ private long ReadSubset(int ordinal, long dataOffset, T[] buffer, int bufferO if (buffer != null && bufferOffset > buffer.Length) { throw new System.ArgumentException( - "Destination buffer is not long enough. Check the buffer offset, length, and the buffer's lower bounds.", + "Destination buffer is not long enough. Check the buffer offset, length, and the buffer's lower bounds.", nameof(buffer)); } @@ -446,14 +450,14 @@ private long ReadSubset(int ordinal, long dataOffset, T[] buffer, int bufferO "Source data is not long enough. 
Check the data offset, length, and the data's lower bounds.", nameof(dataOffset)); } - + long dataLength = data.Length - dataOffset; long elementsRead = Math.Min(length, dataLength); Array.Copy(data, dataOffset, buffer, bufferOffset, elementsRead); return elementsRead; - + } - + } } diff --git a/Snowflake.Data/Core/Converter/TimeConverter.cs b/Snowflake.Data/Core/Converter/TimeConverter.cs index 3f1252762..7a95de580 100644 --- a/Snowflake.Data/Core/Converter/TimeConverter.cs +++ b/Snowflake.Data/Core/Converter/TimeConverter.cs @@ -12,15 +12,15 @@ public object Convert(string value, SFDataType timestampType, Type fieldType) } if (timestampType == SFDataType.TIMESTAMP_NTZ) { - var dateTimeUtc = DateTime.Parse(value).ToUniversalTime(); + var dateTimeNoTz = DateTime.Parse(value); if (fieldType == typeof(DateTime) || fieldType == typeof(DateTime?)) { - return dateTimeUtc; + return dateTimeNoTz; } if (fieldType == typeof(DateTimeOffset) || fieldType == typeof(DateTimeOffset?)) { - return (DateTimeOffset) dateTimeUtc; + return (DateTimeOffset) DateTime.SpecifyKind(dateTimeNoTz, DateTimeKind.Utc); } throw new StructuredTypesReadingException($"Cannot read TIMESTAMP_NTZ into {fieldType} type"); @@ -35,21 +35,21 @@ public object Convert(string value, SFDataType timestampType, Type fieldType) } if (fieldType == typeof(DateTime) || fieldType == typeof(DateTime?)) { - return dateTimeOffset.ToUniversalTime().DateTime.ToUniversalTime(); + return dateTimeOffset.UtcDateTime; } throw new StructuredTypesReadingException($"Cannot read TIMESTAMP_TZ into {fieldType} type"); } if (timestampType == SFDataType.TIMESTAMP_LTZ) { - var dateTimeOffset = DateTimeOffset.Parse(value); + var dateTimeOffsetLocal = DateTimeOffset.Parse(value).ToLocalTime(); if (fieldType == typeof(DateTimeOffset) || fieldType == typeof(DateTimeOffset?)) { - return dateTimeOffset; + return dateTimeOffsetLocal; } if (fieldType == typeof(DateTime) || fieldType == typeof(DateTime?)) { - return dateTimeOffset.UtcDateTime; + return dateTimeOffsetLocal.LocalDateTime; } throw new StructuredTypesReadingException($"Cannot read TIMESTAMP_LTZ into {fieldType} type"); } @@ -63,13 +63,14 @@ public object Convert(string value, SFDataType timestampType, Type fieldType) } if (timestampType == SFDataType.DATE) { - if (fieldType == typeof(DateTimeOffset) || fieldType == typeof(DateTimeOffset?)) + var dateTime = DateTime.Parse(value); + if (fieldType == typeof(DateTime) || fieldType == typeof(DateTime?)) { - return DateTimeOffset.Parse(value).ToUniversalTime(); + return dateTime; } - if (fieldType == typeof(DateTime) || fieldType == typeof(DateTime?)) + if (fieldType == typeof(DateTimeOffset) || fieldType == typeof(DateTimeOffset?)) { - return DateTime.Parse(value).ToUniversalTime(); + return (DateTimeOffset) DateTime.SpecifyKind(dateTime, DateTimeKind.Utc); } throw new StructuredTypesReadingException($"Cannot not read DATE into {fieldType} type"); } diff --git a/Snowflake.Data/Core/SFDataConverter.cs b/Snowflake.Data/Core/SFDataConverter.cs index a415e5058..90e956314 100755 --- a/Snowflake.Data/Core/SFDataConverter.cs +++ b/Snowflake.Data/Core/SFDataConverter.cs @@ -152,12 +152,12 @@ private static DateTime ConvertToDateTime(UTF8Buffer srcVal, SFDataType srcType) { case SFDataType.DATE: long srcValLong = FastParser.FastParseInt64(srcVal.Buffer, srcVal.offset, srcVal.length); - return UnixEpoch.AddDays(srcValLong); + return DateTime.SpecifyKind(UnixEpoch.AddDays(srcValLong), DateTimeKind.Unspecified);; case SFDataType.TIME: case SFDataType.TIMESTAMP_NTZ: var 
tickDiff = GetTicksFromSecondAndNanosecond(srcVal); - return UnixEpoch.AddTicks(tickDiff); + return DateTime.SpecifyKind(UnixEpoch.AddTicks(tickDiff), DateTimeKind.Unspecified); default: throw new SnowflakeDbException(SFError.INVALID_DATA_CONVERSION, srcVal, srcType, typeof(DateTime)); From ef6db7e4f5efb523e5064f0788d464dbb6338b0a Mon Sep 17 00:00:00 2001 From: Juan Martinez Ramirez <126511805+sfc-gh-jmartinezramirez@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:39:28 -0600 Subject: [PATCH 02/20] Added SnowflakeDbDataReader implementation of GetEnumerator using DbEnumerator class (#1031) --- .../SFDbDataReaderGetEnumeratorIT.cs | 180 ++++++++++++++++++ .../Client/SnowflakeDbDataReader.cs | 5 +- 2 files changed, 181 insertions(+), 4 deletions(-) create mode 100755 Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderGetEnumeratorIT.cs diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderGetEnumeratorIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderGetEnumeratorIT.cs new file mode 100755 index 000000000..88e25256e --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderGetEnumeratorIT.cs @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Linq; +using System.Data.Common; +using System.Data; +using System.Globalization; +using System.Text; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + [TestFixture(ResultFormat.ARROW)] + [TestFixture(ResultFormat.JSON)] + class SFDbDataReaderGetEnumeratorIT : SFBaseTest + { + protected override string TestName => base.TestName + _resultFormat; + + private readonly ResultFormat _resultFormat; + + public SFDbDataReaderGetEnumeratorIT(ResultFormat resultFormat) + { + _resultFormat = resultFormat; + } + + [Test] + public void TestGetEnumerator() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from {TableName}"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + var enumerator = reader.GetEnumerator(); + Assert.IsTrue(enumerator.MoveNext()); + Assert.AreEqual(3, (enumerator.Current as DbDataRecord).GetInt64(0)); + Assert.IsTrue(enumerator.MoveNext()); + Assert.AreEqual(5, (enumerator.Current as DbDataRecord).GetInt64(0)); + Assert.IsTrue(enumerator.MoveNext()); + Assert.AreEqual(8, (enumerator.Current as DbDataRecord).GetInt64(0)); + Assert.IsFalse(enumerator.MoveNext()); + + reader.Close(); + + DropTestTableAndCloseConnection(conn); + } + } + + [Test] + public void TestGetEnumeratorShouldBeEmptyWhenNotRowsReturned() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from {TableName} WHERE cola > 10"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + var enumerator = reader.GetEnumerator(); + Assert.IsFalse(enumerator.MoveNext()); + Assert.IsNull(enumerator.Current); + + reader.Close(); + DropTestTableAndCloseConnection(conn); + } + } + + [Test] + public void TestGetEnumeratorWithCastMethod() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from 
{TableName}"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + var dataRecords = reader.Cast().ToList(); + Assert.AreEqual(3, dataRecords.Count); + + reader.Close(); + + DropTestTableAndCloseConnection(conn); + } + } + + [Test] + public void TestGetEnumeratorForEachShouldNotEnterWhenResultsIsEmpty() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from {TableName} WHERE cola > 10"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + foreach (var record in reader) + { + Assert.Fail("Should not enter when results is empty"); + } + + reader.Close(); + DropTestTableAndCloseConnection(conn); + } + } + + [Test] + public void TestGetEnumeratorShouldThrowNonSupportedExceptionWhenReset() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from {TableName}"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + var enumerator = reader.GetEnumerator(); + Assert.IsTrue(enumerator.MoveNext()); + + Assert.Throws(() => enumerator.Reset()); + + reader.Close(); + + DropTestTableAndCloseConnection(conn); + } + } + + private void DropTestTableAndCloseConnection(DbConnection conn) + { + IDbCommand cmd = conn.CreateCommand(); + cmd.CommandText = $"drop table if exists {TableName}"; + var count = cmd.ExecuteNonQuery(); + Assert.AreEqual(0, count); + + CloseConnection(conn); + } + + private void CreateAndPopulateTestTable(DbConnection conn) + { + CreateOrReplaceTable(conn, TableName, new []{"cola NUMBER"}); + + var cmd = conn.CreateCommand(); + + string insertCommand = $"insert into {TableName} values (3),(5),(8)"; + cmd.CommandText = insertCommand; + cmd.ExecuteNonQuery(); + } + + private DbConnection CreateAndOpenConnection() + { + var conn = new SnowflakeDbConnection(ConnectionString); + conn.Open(); + SessionParameterAlterer.SetResultFormat(conn, _resultFormat); + return conn; + } + + private void CloseConnection(DbConnection conn) + { + SessionParameterAlterer.RestoreResultFormat(conn); + conn.Close(); + } + } +} diff --git a/Snowflake.Data/Client/SnowflakeDbDataReader.cs b/Snowflake.Data/Client/SnowflakeDbDataReader.cs index b7bc1615e..7d624bd80 100755 --- a/Snowflake.Data/Client/SnowflakeDbDataReader.cs +++ b/Snowflake.Data/Client/SnowflakeDbDataReader.cs @@ -189,10 +189,7 @@ public override double GetDouble(int ordinal) return resultSet.GetDouble(ordinal); } - public override IEnumerator GetEnumerator() - { - throw new NotImplementedException(); - } + public override IEnumerator GetEnumerator() => new DbEnumerator(this, closeReader: false); public override Type GetFieldType(int ordinal) { From 950fa558d94d432c567ebe20394b8373c481729e Mon Sep 17 00:00:00 2001 From: Dariusz Stempniak Date: Wed, 9 Oct 2024 13:09:06 +0200 Subject: [PATCH 03/20] SNOW-1640968 chunk downloader fix (#1022) --- .../UnitTests/SFReusableChunkTest.cs | 27 + Snowflake.Data/Core/BaseResultChunk.cs | 25 +- .../Core/SFBlockingChunkDownloaderV3.cs | 478 +++++++++--------- Snowflake.Data/Core/SFReusableChunk.cs | 57 ++- 4 files changed, 327 insertions(+), 260 deletions(-) diff --git a/Snowflake.Data.Tests/UnitTests/SFReusableChunkTest.cs 
b/Snowflake.Data.Tests/UnitTests/SFReusableChunkTest.cs index 6f021994b..25627dcaf 100755 --- a/Snowflake.Data.Tests/UnitTests/SFReusableChunkTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFReusableChunkTest.cs @@ -219,6 +219,33 @@ public void TestResetClearsChunkData() Assert.AreEqual(0, chunk.ChunkIndex); Assert.AreEqual(chunkInfo.url, chunk.Url); Assert.AreEqual(chunkInfo.rowCount, chunk.RowCount); + Assert.AreEqual(chunkInfo.uncompressedSize, chunk.UncompressedSize); + Assert.Greater(chunk.data.blockCount, 0); + Assert.Greater(chunk.data.metaBlockCount, 0); + } + + [Test] + public void TestClearRemovesAllChunkData() + { + const int RowCount = 3; + string data = "[ [\"1\"], [\"2\"], [\"3\"] ]"; + var chunk = PrepareChunkAsync(data, 1, RowCount).Result; + + ExecResponseChunk chunkInfo = new ExecResponseChunk() + { + url = "new_url", + uncompressedSize = 100, + rowCount = 200 + }; + + chunk.Clear(); + + Assert.AreEqual(0, chunk.ChunkIndex); + Assert.AreEqual(null, chunk.Url); + Assert.AreEqual(0, chunk.RowCount); + Assert.AreEqual(0, chunk.UncompressedSize); + Assert.AreEqual(0, chunk.data.blockCount); + Assert.AreEqual(0, chunk.data.metaBlockCount); } private async Task PrepareChunkAsync(string stringData, int colCount, int rowCount) diff --git a/Snowflake.Data/Core/BaseResultChunk.cs b/Snowflake.Data/Core/BaseResultChunk.cs index 37e8fa114..b3b764210 100755 --- a/Snowflake.Data/Core/BaseResultChunk.cs +++ b/Snowflake.Data/Core/BaseResultChunk.cs @@ -9,21 +9,21 @@ namespace Snowflake.Data.Core public abstract class BaseResultChunk : IResultChunk { internal abstract ResultFormat ResultFormat { get; } - + public int RowCount { get; protected set; } - + public int ColumnCount { get; protected set; } - + public int ChunkIndex { get; protected set; } internal int CompressedSize; - + internal int UncompressedSize; internal string Url { get; set; } internal string[,] RowSet { get; set; } - + public int GetRowCount() => RowCount; public int GetChunkIndex() => ChunkIndex; @@ -32,11 +32,11 @@ public abstract class BaseResultChunk : IResultChunk public abstract UTF8Buffer ExtractCell(int rowIndex, int columnIndex); public abstract UTF8Buffer ExtractCell(int columnIndex); - + internal abstract bool Next(); - + internal abstract bool Rewind(); - + internal virtual void Reset(ExecResponseChunk chunkInfo, int chunkIndex) { RowCount = chunkInfo.rowCount; @@ -46,6 +46,15 @@ internal virtual void Reset(ExecResponseChunk chunkInfo, int chunkIndex) UncompressedSize = chunkInfo.uncompressedSize; } + internal virtual void Clear() + { + RowCount = 0; + Url = null; + ChunkIndex = 0; + CompressedSize = 0; + UncompressedSize = 0; + } + internal virtual void ResetForRetry() { } diff --git a/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs b/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs index 282c502b1..2e19146aa 100755 --- a/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs +++ b/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs @@ -1,239 +1,239 @@ -/* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. 
- */ - -using System; -using System.IO.Compression; -using System.IO; -using System.Collections; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Net.Http; -using Newtonsoft.Json; -using System.Diagnostics; -using Newtonsoft.Json.Serialization; -using Snowflake.Data.Log; - -namespace Snowflake.Data.Core -{ - class SFBlockingChunkDownloaderV3 : IChunkDownloader - { - static private SFLogger logger = SFLoggerFactory.GetLogger(); - - private List chunkDatas = new List(); - - private string qrmk; - - private int nextChunkToDownloadIndex; - - private int nextChunkToConsumeIndex; - - // External cancellation token, used to stop donwload - private CancellationToken externalCancellationToken; - - private readonly int prefetchSlot; - - private readonly IRestRequester _RestRequester; - - private readonly SFSessionProperties sessionProperies; - - private Dictionary chunkHeaders; - - private readonly SFBaseResultSet ResultSet; - - private readonly List chunkInfos; - - private readonly List> taskQueues; - - public SFBlockingChunkDownloaderV3(int colCount, - List chunkInfos, string qrmk, - Dictionary chunkHeaders, - CancellationToken cancellationToken, - SFBaseResultSet ResultSet, - ResultFormat resultFormat) - { - this.qrmk = qrmk; - this.chunkHeaders = chunkHeaders; - this.nextChunkToDownloadIndex = 0; - this.ResultSet = ResultSet; - this._RestRequester = ResultSet.sfStatement.SfSession.restRequester; - this.sessionProperies = ResultSet.sfStatement.SfSession.properties; - this.prefetchSlot = Math.Min(chunkInfos.Count, GetPrefetchThreads(ResultSet)); - this.chunkInfos = chunkInfos; - this.nextChunkToConsumeIndex = 0; - this.taskQueues = new List>(); - externalCancellationToken = cancellationToken; - - for (int i=0; i sessionParameters = resultSet.sfStatement.SfSession.ParameterMap; - String val = (String)sessionParameters[SFSessionParameter.CLIENT_PREFETCH_THREADS]; - return Int32.Parse(val); - } - - public async Task GetNextChunkAsync() - { - logger.Info($"NextChunkToConsume: {nextChunkToConsumeIndex}, NextChunkToDownload: {nextChunkToDownloadIndex}"); - if (nextChunkToConsumeIndex < chunkInfos.Count) - { - Task chunk = taskQueues[nextChunkToConsumeIndex % prefetchSlot]; - - if (nextChunkToDownloadIndex < chunkInfos.Count && nextChunkToConsumeIndex > 0) - { - BaseResultChunk reusableChunk = chunkDatas[nextChunkToDownloadIndex % prefetchSlot]; - reusableChunk.Reset(chunkInfos[nextChunkToDownloadIndex], nextChunkToDownloadIndex); - - taskQueues[nextChunkToDownloadIndex % prefetchSlot] = DownloadChunkAsync(new DownloadContextV3() - { - chunk = reusableChunk, - qrmk = this.qrmk, - chunkHeaders = this.chunkHeaders, - cancellationToken = externalCancellationToken - }); - nextChunkToDownloadIndex++; - - // in case of one slot we need to return the chunk already downloaded - if (prefetchSlot == 1) - { - chunk = taskQueues[0]; - } - } - nextChunkToConsumeIndex++; - return await chunk; - } - else - { - return await Task.FromResult(null); - } - } - - private async Task DownloadChunkAsync(DownloadContextV3 downloadContext) - { - BaseResultChunk chunk = downloadContext.chunk; - int backOffInSec = 1; - bool retry = false; - int retryCount = 0; - int maxRetry = int.Parse(sessionProperies[SFSessionProperty.MAXHTTPRETRIES]); - - do - { - retry = false; - - S3DownloadRequest downloadRequest = - new S3DownloadRequest() - { - Url = new UriBuilder(chunk.Url).Uri, - qrmk = 
downloadContext.qrmk, - // s3 download request timeout to one hour - RestTimeout = TimeSpan.FromHours(1), - HttpTimeout = Timeout.InfiniteTimeSpan, // Disable timeout for each request - chunkHeaders = downloadContext.chunkHeaders, - sid = ResultSet.sfStatement.SfSession.sessionId - }; - - using (var httpResponse = await _RestRequester.GetAsync(downloadRequest, downloadContext.cancellationToken) - .ConfigureAwait(continueOnCapturedContext: false)) - using (Stream stream = await httpResponse.Content.ReadAsStreamAsync() - .ConfigureAwait(continueOnCapturedContext: false)) - { - // retry on chunk downloading since the retry logic in HttpClient.RetryHandler - // doesn't cover this. The GET request could be succeeded but network error - // still could happen during reading chunk data from stream and that needs - // retry as well. - try - { - IEnumerable encoding; - if (httpResponse.Content.Headers.TryGetValues("Content-Encoding", out encoding)) - { - if (String.Compare(encoding.First(), "gzip", true) == 0) - { - using (Stream streamGzip = new GZipStream(stream, CompressionMode.Decompress)) - { - await ParseStreamIntoChunk(streamGzip, chunk).ConfigureAwait(false); - } - } - else - { - await ParseStreamIntoChunk(stream, chunk).ConfigureAwait(false); - } - } - else - { - await ParseStreamIntoChunk(stream, chunk).ConfigureAwait(false); - } - } - catch (Exception e) - { - if ((maxRetry <= 0) || (retryCount < maxRetry)) - { - logger.Debug($"Retry {retryCount}/{maxRetry} of parse stream to chunk error: " + e.Message); - retry = true; - // reset the chunk before retry in case there could be garbage - // data left from last attempt - chunk.ResetForRetry(); - await Task.Delay(TimeSpan.FromSeconds(backOffInSec), downloadContext.cancellationToken).ConfigureAwait(false); - ++retryCount; - // Set next backoff time - backOffInSec = backOffInSec * 2; - if (backOffInSec > HttpUtil.MAX_BACKOFF) - { - backOffInSec = HttpUtil.MAX_BACKOFF; - } - } - else - { - //parse error - logger.Error("Failed retries of parse stream to chunk error: " + e.Message); - throw new Exception("Parse stream to chunk error: " + e.Message); - } - } - } - } while (retry); - logger.Info($"Succeed downloading chunk #{chunk.ChunkIndex}"); - return chunk; - } - - private async Task ParseStreamIntoChunk(Stream content, BaseResultChunk resultChunk) - { - IChunkParser parser = ChunkParserFactory.Instance.GetParser(resultChunk.ResultFormat, content); - await parser.ParseChunk(resultChunk); - } - } - - class DownloadContextV3 - { - public BaseResultChunk chunk { get; set; } - - public string qrmk { get; set; } - - public Dictionary chunkHeaders { get; set; } - - public CancellationToken cancellationToken { get; set; } - } -} +/* + * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. 
+ */ + +using System; +using System.IO.Compression; +using System.IO; +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Net.Http; +using Newtonsoft.Json; +using System.Diagnostics; +using Newtonsoft.Json.Serialization; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core +{ + class SFBlockingChunkDownloaderV3 : IChunkDownloader + { + static private SFLogger logger = SFLoggerFactory.GetLogger(); + + private List chunkDatas = new List(); + + private string qrmk; + + private int nextChunkToDownloadIndex; + + private int nextChunkToConsumeIndex; + + // External cancellation token, used to stop donwload + private CancellationToken externalCancellationToken; + + private readonly int prefetchSlot; + + private readonly IRestRequester _RestRequester; + + private readonly SFSessionProperties sessionProperies; + + private Dictionary chunkHeaders; + + private readonly SFBaseResultSet ResultSet; + + private readonly List chunkInfos; + + private readonly List> taskQueues; + + public SFBlockingChunkDownloaderV3(int colCount, + List chunkInfos, string qrmk, + Dictionary chunkHeaders, + CancellationToken cancellationToken, + SFBaseResultSet ResultSet, + ResultFormat resultFormat) + { + this.qrmk = qrmk; + this.chunkHeaders = chunkHeaders; + this.nextChunkToDownloadIndex = 0; + this.ResultSet = ResultSet; + this._RestRequester = ResultSet.sfStatement.SfSession.restRequester; + this.sessionProperies = ResultSet.sfStatement.SfSession.properties; + this.prefetchSlot = Math.Min(chunkInfos.Count, GetPrefetchThreads(ResultSet)); + this.chunkInfos = chunkInfos; + this.nextChunkToConsumeIndex = 0; + this.taskQueues = new List>(); + externalCancellationToken = cancellationToken; + + for (int i=0; i sessionParameters = resultSet.sfStatement.SfSession.ParameterMap; + String val = (String)sessionParameters[SFSessionParameter.CLIENT_PREFETCH_THREADS]; + return Int32.Parse(val); + } + + public async Task GetNextChunkAsync() + { + logger.Info($"NextChunkToConsume: {nextChunkToConsumeIndex}, NextChunkToDownload: {nextChunkToDownloadIndex}"); + if (nextChunkToConsumeIndex < chunkInfos.Count) + { + Task chunk = taskQueues[nextChunkToConsumeIndex % prefetchSlot]; + + if (nextChunkToDownloadIndex < chunkInfos.Count && nextChunkToConsumeIndex > 0) + { + BaseResultChunk reusableChunk = chunkDatas[nextChunkToDownloadIndex % prefetchSlot]; + reusableChunk.Reset(chunkInfos[nextChunkToDownloadIndex], nextChunkToDownloadIndex); + + taskQueues[nextChunkToDownloadIndex % prefetchSlot] = DownloadChunkAsync(new DownloadContextV3() + { + chunk = reusableChunk, + qrmk = this.qrmk, + chunkHeaders = this.chunkHeaders, + cancellationToken = externalCancellationToken + }); + nextChunkToDownloadIndex++; + + // in case of one slot we need to return the chunk already downloaded + if (prefetchSlot == 1) + { + chunk = taskQueues[0]; + } + } + nextChunkToConsumeIndex++; + return await chunk; + } + else + { + return await Task.FromResult(null); + } + } + + private async Task DownloadChunkAsync(DownloadContextV3 downloadContext) + { + BaseResultChunk chunk = downloadContext.chunk; + int backOffInSec = 1; + bool retry = false; + int retryCount = 0; + int maxRetry = int.Parse(sessionProperies[SFSessionProperty.MAXHTTPRETRIES]); + + do + { + retry = false; + + S3DownloadRequest downloadRequest = + new S3DownloadRequest() + { + Url = new UriBuilder(chunk.Url).Uri, + qrmk = 
downloadContext.qrmk,
+                        // s3 download request timeout to one hour
+                        RestTimeout = TimeSpan.FromHours(1),
+                        HttpTimeout = Timeout.InfiniteTimeSpan, // Disable timeout for each request
+                        chunkHeaders = downloadContext.chunkHeaders,
+                        sid = ResultSet.sfStatement.SfSession.sessionId
+                    };
+
+                using (var httpResponse = await _RestRequester.GetAsync(downloadRequest, downloadContext.cancellationToken)
+                    .ConfigureAwait(continueOnCapturedContext: false))
+                using (Stream stream = await httpResponse.Content.ReadAsStreamAsync()
+                    .ConfigureAwait(continueOnCapturedContext: false))
+                {
+                    // Retry chunk downloading here since the retry logic in HttpClient.RetryHandler
+                    // doesn't cover this case. The GET request could succeed, but a network error
+                    // could still happen while reading chunk data from the stream, and that needs
+                    // a retry as well.
+                    try
+                    {
+                        IEnumerable<string> encoding;
+                        if (httpResponse.Content.Headers.TryGetValues("Content-Encoding", out encoding))
+                        {
+                            if (String.Compare(encoding.First(), "gzip", true) == 0)
+                            {
+                                using (Stream streamGzip = new GZipStream(stream, CompressionMode.Decompress))
+                                {
+                                    await ParseStreamIntoChunk(streamGzip, chunk).ConfigureAwait(false);
+                                }
+                            }
+                            else
+                            {
+                                await ParseStreamIntoChunk(stream, chunk).ConfigureAwait(false);
+                            }
+                        }
+                        else
+                        {
+                            await ParseStreamIntoChunk(stream, chunk).ConfigureAwait(false);
+                        }
+                    }
+                    catch (Exception e)
+                    {
+                        if ((maxRetry <= 0) || (retryCount < maxRetry))
+                        {
+                            logger.Debug($"Retry {retryCount}/{maxRetry} of parse stream to chunk error: " + e.Message);
+                            retry = true;
+                            // reset the chunk before retry in case there could be garbage
+                            // data left from the last attempt
+                            chunk.ResetForRetry();
+                            await Task.Delay(TimeSpan.FromSeconds(backOffInSec), downloadContext.cancellationToken).ConfigureAwait(false);
+                            ++retryCount;
+                            // Set the next backoff time
+                            backOffInSec = backOffInSec * 2;
+                            if (backOffInSec > HttpUtil.MAX_BACKOFF)
+                            {
+                                backOffInSec = HttpUtil.MAX_BACKOFF;
+                            }
+                        }
+                        else
+                        {
+                            // parse error
+                            logger.Error("Failed retries of parse stream to chunk error: " + e.Message);
+                            throw new Exception("Parse stream to chunk error: " + e.Message);
+                        }
+                    }
+                }
+            } while (retry);
+            logger.Info($"Succeeded in downloading chunk #{chunk.ChunkIndex}");
+            return chunk;
+        }
+
+        private async Task ParseStreamIntoChunk(Stream content, BaseResultChunk resultChunk)
+        {
+            IChunkParser parser = ChunkParserFactory.Instance.GetParser(resultChunk.ResultFormat, content);
+            await parser.ParseChunk(resultChunk);
+        }
+    }
+
+    class DownloadContextV3
+    {
+        public BaseResultChunk chunk { get; set; }
+
+        public string qrmk { get; set; }
+
+        public Dictionary<string, string> chunkHeaders { get; set; }
+
+        public CancellationToken cancellationToken { get; set; }
+    }
+}
diff --git a/Snowflake.Data/Core/SFReusableChunk.cs b/Snowflake.Data/Core/SFReusableChunk.cs
index 06ea7cef3..4db8ec0d7 100755
--- a/Snowflake.Data/Core/SFReusableChunk.cs
+++ b/Snowflake.Data/Core/SFReusableChunk.cs
@@ -11,8 +11,8 @@ namespace Snowflake.Data.Core
     class SFReusableChunk : BaseResultChunk
     {
         internal override ResultFormat ResultFormat => ResultFormat.JSON;
-
-        private readonly BlockResultData data;
+
+        internal readonly BlockResultData data;
 
         private int _currentRowIndex = -1;
 
@@ -29,11 +29,18 @@ internal override void Reset(ExecResponseChunk chunkInfo, int chunkIndex)
             data.Reset(RowCount, ColumnCount, chunkInfo.uncompressedSize);
         }
 
+        internal override void Clear()
+        {
+            base.Clear();
+            _currentRowIndex = -1;
+            data.Clear();
+        }
+
         internal override void ResetForRetry()
        {
             data.ResetForRetry();
         }
-
+
         [Obsolete("ExtractCell with rowIndex is deprecated", false)]
         public override UTF8Buffer ExtractCell(int rowIndex, int columnIndex)
         {
@@ -62,21 +69,22 @@ internal override bool Next()
             _currentRowIndex += 1;
             return _currentRowIndex < RowCount;
         }
-
+
         internal override bool Rewind()
         {
             _currentRowIndex -= 1;
             return _currentRowIndex >= 0;
         }
 
-        private class BlockResultData
+        internal class BlockResultData
         {
             private static readonly int NULL_VALUE = -100;
-            private int blockCount;
-            private static int blockLengthBits = 24;
+            internal int blockCount;
+
+            private static int blockLengthBits = 23;
             private static int blockLength = 1 << blockLengthBits;
 
-            int metaBlockCount;
+
+            internal int metaBlockCount;
             private static int metaBlockLengthBits = 15;
             private static int metaBlockLength = 1 << metaBlockLengthBits;
 
@@ -98,11 +106,24 @@ internal void Reset(int rowCount, int colCount, int uncompressedSize)
             savedColCount = colCount;
             currentDatOffset = 0;
             nextIndex = 0;
-            int bytesNeeded = uncompressedSize - (rowCount * 2) - (rowCount * colCount);
-            this.blockCount = getBlock(bytesNeeded - 1) + 1;
+            this.blockCount = 1; // init with 1 block only
             this.metaBlockCount = getMetaBlock(rowCount * colCount - 1) + 1;
         }
 
+        internal void Clear()
+        {
+            savedRowCount = 0;
+            savedColCount = 0;
+            currentDatOffset = 0;
+            nextIndex = 0;
+            blockCount = 0;
+            metaBlockCount = 0;
+
+            data.Clear();
+            offsets.Clear();
+            lengths.Clear();
+        }
+
         internal void ResetForRetry()
         {
             currentDatOffset = 0;
@@ -157,6 +178,16 @@ int copySize
         public void add(byte[] bytes, int length)
         {
+            // check if a new block for data is needed
+            if (getBlock(currentDatOffset) == blockCount - 1)
+            {
+                var neededSize = length - spaceLeftOnBlock(currentDatOffset);
+                while (neededSize >= 0)
+                {
+                    blockCount++;
+                    neededSize -= blockLength;
+                }
+            }
             if (data.Count < blockCount || offsets.Count < metaBlockCount)
             {
                 allocateArrays();
@@ -232,12 +263,12 @@ private void allocateArrays()
         {
             while (data.Count < blockCount)
             {
-                data.Add(new byte[1 << blockLengthBits]);
+                data.Add(new byte[blockLength]);
             }
             while (offsets.Count < metaBlockCount)
             {
-                offsets.Add(new int[1 << metaBlockLengthBits]);
-                lengths.Add(new int[1 << metaBlockLengthBits]);
+                offsets.Add(new int[metaBlockLength]);
+                lengths.Add(new int[metaBlockLength]);
             }
         }
     }

From 13c4bec7454f469ff37cbfddbf3806f1e8a6571f Mon Sep 17 00:00:00 2001
From: Piotr Bulawa
Date: Mon, 14 Oct 2024 13:15:01 +0200
Subject: [PATCH 04/20] NOSNOW: Rotate GPG files (#1039)

---
 .../parameters/parameters_AWS.json.gpg          | Bin 272 -> 278 bytes
 .../parameters/parameters_AZURE.json.gpg        |   3 ++-
 .../parameters/parameters_GCP.json.gpg          | Bin 274 -> 294 bytes
 3 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/parameters/parameters_AWS.json.gpg b/.github/workflows/parameters/parameters_AWS.json.gpg
index e897ec5fe893bda8eb5d81b1ed366e36f62bb5f6..778cca91d8f344431f51207028bbe11e529b0919 100644
GIT binary patch
(base85-encoded binary patch data omitted)

diff --git a/.github/workflows/parameters/parameters_AZURE.json.gpg b/.github/workflows/parameters/parameters_AZURE.json.gpg
(mangled binary text diff omitted)

diff --git a/.github/workflows/parameters/parameters_GCP.json.gpg b/.github/workflows/parameters/parameters_GCP.json.gpg
index ea00564b795feddfe07dc1688bc230bcf6f8ed57..5dc80fd7d8afcc2c0c3926d573595614c420c39c 100644
GIT binary patch
(base85-encoded binary patch data omitted)
From 1394f53fef6cfe2b26acbdcd8cb0a4fe91710653 Mon Sep 17 00:00:00 2001
From: Krzysztof Nozderko
Date: Mon, 14 Oct 2024 16:08:41 +0200
Subject: [PATCH 05/20] SNOW-1488701 Enable structured types, documentation for
 structured types and vector type (#1016)

Co-authored-by: Piotr Bulawa
---
 README.md                                     |   8 +
 .../IntegrationTests/StructuredArraysIT.cs    |  32 +--
 .../IntegrationTests/VectorTypesIT.cs         |  27 ---
 .../Client/SnowflakeDbDataReader.cs           |  33 +---
 doc/StructuredTypes.md                        | 185 ++++++++++++++++++
 doc/VectorType.md                             |  18 ++
 6 files changed, 233 insertions(+), 70 deletions(-)
 create mode 100644 doc/StructuredTypes.md
 create mode 100644 doc/VectorType.md

diff --git a/README.md b/README.md
index ed8b8341c..0378b5416 100644
--- a/README.md
+++ b/README.md
@@ -104,6 +104,14 @@ Snowflake data types and their .NET types is covered in: [Data Types and Data Fo
 How execute a query, use query bindings, run queries synchronously and asynchronously: [Running Queries and Reading Results](doc/QueryingData.md)
 
+## Structured types
+
+Using structured types: [Structured types](doc/StructuredTypes.md)
+
+## Vector type
+
+Using vector type: [Vector type](doc/VectorType.md)
+
 ## Stage Files
 
 Using stage files within PUT/GET commands:
diff --git a/Snowflake.Data.Tests/IntegrationTests/StructuredArraysIT.cs b/Snowflake.Data.Tests/IntegrationTests/StructuredArraysIT.cs
index 142f1eae3..aee5e666e 100644
--- a/Snowflake.Data.Tests/IntegrationTests/StructuredArraysIT.cs
+++ b/Snowflake.Data.Tests/IntegrationTests/StructuredArraysIT.cs
@@ -28,7 +28,7 @@ public void TestSelectArray()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<string>(0);
+                    var array = reader.GetArray<string>(0);
 
                     // assert
                     Assert.AreEqual(3, array.Length);
@@ -54,7 +54,7 @@ public void TestSelectArrayOfObjects()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<Identity>(0);
+                    var array = reader.GetArray<Identity>(0);
 
                     // assert
                     Assert.AreEqual(2, array.Length);
@@ -79,7 +79,7 @@ public void TestSelectArrayOfArrays()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<string[]>(0);
+                    var array = reader.GetArray<string[]>(0);
 
                     // assert
                     Assert.AreEqual(2, array.Length);
@@ -104,7 +104,7 @@ public void TestSelectArrayOfMap()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<Dictionary<string, string>>(0);
+                    var array = reader.GetArray<Dictionary<string, string>>(0);
 
                     // assert
                     Assert.AreEqual(1, array.Length);
@@ -134,7 +134,7 @@ public void TestSelectSemiStructuredTypesInArray(string valueSfString, string ex
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<string>(0);
+                    var array = reader.GetArray<string>(0);
 
                     // assert
                     Assert.NotNull(array);
@@ -159,7 +159,7 @@ public void TestSelectArrayOfIntegers()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<int>(0);
+                    var array = reader.GetArray<int>(0);
 
                     // assert
                     Assert.AreEqual(3, array.Length);
@@ -184,7 +184,7 @@ public void TestSelectArrayOfLong()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<long>(0);
+                    var array = 
reader.GetArray<long>(0);
 
                     // assert
                     Assert.AreEqual(3, array.Length);
@@ -209,7 +209,7 @@ public void TestSelectArrayOfFloats()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<float>(0);
+                    var array = reader.GetArray<float>(0);
 
                     // assert
                     Assert.AreEqual(3, array.Length);
@@ -234,7 +234,7 @@ public void TestSelectArrayOfDoubles()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<double>(0);
+                    var array = reader.GetArray<double>(0);
 
                     // assert
                     Assert.AreEqual(3, array.Length);
@@ -259,7 +259,7 @@ public void TestSelectArrayOfDoublesWithExponentNotation()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<double>(0);
+                    var array = reader.GetArray<double>(0);
 
                     // assert
                     Assert.AreEqual(2, array.Length);
@@ -284,7 +284,7 @@ public void TestSelectStringArrayWithNulls()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<string>(0);
+                    var array = reader.GetArray<string>(0);
 
                     // assert
                     Assert.AreEqual(3, array.Length);
@@ -309,7 +309,7 @@ public void TestSelectIntArrayWithNulls()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var array = reader.GetStucturedArray<int?>(0);
+                    var array = reader.GetArray<int?>(0);
 
                     // assert
                     Assert.AreEqual(3, array.Length);
@@ -334,7 +334,7 @@ public void TestSelectNullArray()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var nullArray = reader.GetStucturedArray<string>(0);
+                    var nullArray = reader.GetArray<string>(0);
 
                     // assert
                     Assert.IsNull(nullArray);
@@ -358,7 +358,7 @@ public void TestThrowExceptionForInvalidArray()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var thrown = Assert.Throws<SnowflakeDbException>(() => reader.GetStucturedArray<string>(0));
+                    var thrown = Assert.Throws<SnowflakeDbException>(() => reader.GetArray<string>(0));
 
                     // assert
                     SnowflakeDbExceptionAssert.HasErrorCode(thrown, SFError.STRUCTURED_TYPE_READ_DETAILED_ERROR);
@@ -384,7 +384,7 @@ public void TestThrowExceptionForInvalidArrayElement()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var thrown = Assert.Throws<SnowflakeDbException>(() => reader.GetStucturedArray<string>(0));
+                    var thrown = Assert.Throws<SnowflakeDbException>(() => reader.GetArray<string>(0));
 
                     // assert
                     SnowflakeDbExceptionAssert.HasErrorCode(thrown, SFError.STRUCTURED_TYPE_READ_ERROR);
@@ -411,7 +411,7 @@ public void TestThrowExceptionForNextedInvalidElement()
                     Assert.IsTrue(reader.Read());
 
                     // act
-                    var thrown = Assert.Throws<SnowflakeDbException>(() => reader.GetStucturedArray<string>(0));
+                    var thrown = Assert.Throws<SnowflakeDbException>(() => reader.GetArray<string>(0));
 
                     // assert
                     SnowflakeDbExceptionAssert.HasErrorCode(thrown, SFError.STRUCTURED_TYPE_READ_DETAILED_ERROR);
diff --git a/Snowflake.Data.Tests/IntegrationTests/VectorTypesIT.cs b/Snowflake.Data.Tests/IntegrationTests/VectorTypesIT.cs
index aa2475c86..07a24a91b 100644
--- a/Snowflake.Data.Tests/IntegrationTests/VectorTypesIT.cs
+++ b/Snowflake.Data.Tests/IntegrationTests/VectorTypesIT.cs
@@ -6,7 +6,6 @@
 using Snowflake.Data.Client;
 using System.Data.Common;
 using Snowflake.Data.Core;
-using Snowflake.Data.Tests.Util;
 using System;
 
 namespace Snowflake.Data.Tests.IntegrationTests
@@ -345,32 +344,6 @@ public void TestThrowExceptionForInvalidIdentifierForFloatVector()
             }
         }
 
-        [Test]
-        public void TestThrowExceptionForInvalidVectorType()
-        {
-            using (DbConnection conn = new SnowflakeDbConnection())
-            {
-                conn.ConnectionString = ConnectionString;
-                conn.Open();
-                AlterSessionSettings(conn);
-
-                using (DbCommand command = conn.CreateCommand())
-                {
-                    command.CommandText = "SELECT ARRAY_CONSTRUCT(1.1)::ARRAY(DOUBLE)";
-                    var reader = (SnowflakeDbDataReader)command.ExecuteReader();
-                    Assert.IsTrue(reader.Read());
-
-                    // act
-                    var thrown = Assert.Throws<SnowflakeDbException>(() => reader.GetArray<double>(0));
-
-                    // assert
-                    SnowflakeDbExceptionAssert.HasErrorCode(thrown, 
SFError.STRUCTURED_TYPE_READ_DETAILED_ERROR);
-                    Assert.That(thrown.Message, Does.Contain("Failed to read structured type when getting an array"));
-                    Assert.That(thrown.Message, Does.Contain("Method GetArray<System.Double> can be used only for vector types"));
-                }
-            }
-        }
-
         private void AlterSessionSettings(DbConnection conn)
         {
             using (var command = conn.CreateCommand())
diff --git a/Snowflake.Data/Client/SnowflakeDbDataReader.cs b/Snowflake.Data/Client/SnowflakeDbDataReader.cs
index 7d624bd80..7d475024a 100755
--- a/Snowflake.Data/Client/SnowflakeDbDataReader.cs
+++ b/Snowflake.Data/Client/SnowflakeDbDataReader.cs
@@ -253,7 +253,7 @@ public override int GetValues(object[] values)
             return count;
         }
 
-        internal T GetObject<T>(int ordinal)
+        public T GetObject<T>(int ordinal)
             where T : class, new()
         {
             try
@@ -282,9 +282,11 @@ public T[] GetArray<T>(int ordinal)
         {
             var rowType = resultSet.sfResultSetMetaData.rowTypes[ordinal];
             var fields = rowType.fields;
-            if (fields == null || fields.Count == 0 || !JsonToStructuredTypeConverter.IsVectorType(rowType.type))
+            var isArrayOrVector = JsonToStructuredTypeConverter.IsArrayType(rowType.type) ||
+                                  JsonToStructuredTypeConverter.IsVectorType(rowType.type);
+            if (fields == null || fields.Count == 0 || !isArrayOrVector)
             {
-                throw new StructuredTypesReadingException($"Method GetArray<{typeof(T)}> can be used only for vector types");
+                throw new StructuredTypesReadingException($"Method GetArray<{typeof(T)}> can be used only for structured array or vector types");
             }
 
             var stringValue = GetString(ordinal);
@@ -299,30 +301,7 @@ public T[] GetArray<T>(int ordinal)
             }
         }
 
-        internal T[] GetStucturedArray<T>(int ordinal)
-        {
-            try
-            {
-                var rowType = resultSet.sfResultSetMetaData.rowTypes[ordinal];
-                var fields = rowType.fields;
-                if (fields == null || fields.Count == 0 || !JsonToStructuredTypeConverter.IsArrayType(rowType.type))
-                {
-                    throw new StructuredTypesReadingException($"Method GetArray<{typeof(T)}> can be used only for structured array");
-                }
-
-                var stringValue = GetString(ordinal);
-                var json = stringValue == null ? null : JArray.Parse(stringValue);
-                return JsonToStructuredTypeConverter.ConvertArray<T>(fields, json);
-            }
-            catch (Exception e)
-            {
-                if (e is SnowflakeDbException)
-                    throw;
-                throw StructuredTypesReadingHandler.ToSnowflakeDbException(e, "when getting an array");
-            }
-        }
-
-        internal Dictionary<K, V> GetMap<K, V>(int ordinal)
+        public Dictionary<K, V> GetMap<K, V>(int ordinal)
         {
             try
             {
diff --git a/doc/StructuredTypes.md b/doc/StructuredTypes.md
new file mode 100644
index 000000000..bc45d98c9
--- /dev/null
+++ b/doc/StructuredTypes.md
@@ -0,0 +1,185 @@
+## Concept
+
+Snowflake structured types documentation is available here: [Snowflake Structured Types Documentation](https://docs.snowflake.com/en/sql-reference/data-types-structured).
+
+Snowflake offers a way to store structured types, which can be:
+- objects, e.g. ```OBJECT(city VARCHAR, state VARCHAR)```
+- arrays, e.g. ```ARRAY(NUMBER)```
+- maps, e.g. ```MAP(VARCHAR, VARCHAR)```
+
+The driver allows reading and casting such structured objects into customer classes.
+
+**Note**: Currently, reading structured types is available only for the JSON result format.
+
+## Enabling the feature
+
+Currently, reading structured types is available only for the JSON result format, so make sure you are using the JSON result format:
+```sql
+ALTER SESSION SET DOTNET_QUERY_RESULT_FORMAT = JSON;
+```
+
+The structured types feature is enabled starting from driver version v4.2.0.
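+
+For example, a minimal sketch of enforcing the JSON result format from driver code before reading structured types (assuming `conn` is an already opened `SnowflakeDbConnection`):
+
+```csharp
+using (var command = conn.CreateCommand())
+{
+    // Structured types are currently readable only from JSON results,
+    // so switch the session to the JSON result format first.
+    command.CommandText = "ALTER SESSION SET DOTNET_QUERY_RESULT_FORMAT = JSON";
+    command.ExecuteNonQuery();
+}
+```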
+
+## Structured types vs semi-structured types
+
+The difference between structured types and semi-structured types is that structured types contain type definitions for the given objects/arrays/maps.
+
+E.g. for a given object:
+```sql
+SELECT OBJECT_CONSTRUCT('city','San Mateo', 'state', 'CA')::OBJECT(city VARCHAR, state VARCHAR)
+```
+
+The part indicating the type of the object is `::OBJECT(city VARCHAR, state VARCHAR)`.
+This part of the definition is essential for structured types because it is used to convert the object into the customer class instance.
+
+The corresponding semi-structured type, by contrast, does not contain a detailed type definition, for instance:
+```sql
+SELECT OBJECT_CONSTRUCT('city','San Mateo', 'state', 'CA')::OBJECT
+```
+
+which means semi-structured types are returned only as a JSON string.
+
+## Handling objects
+
+You can construct structured objects by using an object constructor and providing type details:
+
+```sql
+SELECT OBJECT_CONSTRUCT('city','San Mateo', 'state', 'CA')::OBJECT(city VARCHAR, state VARCHAR)
+```
+
+You can read the object into your class by calling the `T SnowflakeDbReader.GetObject<T>(int ordinal)` method:
+
+```csharp
+var reader = (SnowflakeDbDataReader) command.ExecuteReader();
+Assert.IsTrue(reader.Read());
+var address = reader.GetObject<Address>
(0);
+```
+
+where `Address` is a customer class, e.g.
+```csharp
+public class Address
+{
+    public string city { get; set; }
+    public string state { get; set; }
+    public Zip zip { get; set; }
+}
+```
+
+There are a few possible ways of constructing an object of a customer class.
+The customer object (e.g. `Address`) can be created either:
+- by property order, which is the default method
+- by property names
+- by the constructor.
+
+### Creating objects by property order
+
+Creating objects by property order is the default construction method.
+Objects are created by the parameterless constructor, and then the n-th Snowflake object field is converted into the n-th customer object property, one by one.
+
+You can annotate your class with the `SnowflakeObject` annotation to make it explicit that this creation method is used (though this is not necessary, since it is the default):
+```csharp
+[SnowflakeObject(ConstructionMethod = SnowflakeObjectConstructionMethod.PROPERTIES_ORDER)]
+public class Address
+{
+    public string city { get; set; }
+    public string state { get; set; }
+    public Zip zip { get; set; }
+}
+```
+
+If you would like to skip a customer property, you can use the `[SnowflakeColumn(IgnoreForPropertyOrder = true)]` annotation on that property.
+For instance, the annotation used in the following class definition causes `city` to be skipped when mapping the properties:
+```csharp
+public class Address
+{
+    [SnowflakeColumn(IgnoreForPropertyOrder = true)]
+    public string city { get; set; }
+    public string state { get; set; }
+    public Zip zip { get; set; }
+}
+```
+
+So, the first field from the database object would be mapped to the `state` property because `city` is skipped.
+
+### Creating objects by property names
+
+Using the `[SnowflakeObject(ConstructionMethod = SnowflakeObjectConstructionMethod.PROPERTIES_NAMES)]` annotation on the customer class enables the creation of objects by property names.
+In this creation method, objects are created by the parameterless constructor, and then for each database object field a property of the same name is set with the field value.
+It is crucial that database object field names are the same as customer property names; otherwise, a given database object field value would not be set in the customer object.
+You can use the `SnowflakeColumn` annotation to rename the customer object property to match the database object field name.
+
+In the example:
+
+```csharp
+[SnowflakeObject(ConstructionMethod = SnowflakeObjectConstructionMethod.PROPERTIES_NAMES)]
+public class Address
+{
+    [SnowflakeColumn(Name = "nearestCity")]
+    public string city { get; set; }
+    public string state { get; set; }
+    public Zip zip { get; set; }
+}
+```
+
+the database object field `nearestCity` would be mapped to the `city` property of the `Address` class.
+
+### Creating objects by the constructor
+
+Using the `[SnowflakeObject(ConstructionMethod = SnowflakeObjectConstructionMethod.CONSTRUCTOR)]` annotation on the customer class enables the creation of objects by a constructor.
+In this creation method, an object with all its fields is created by a constructor.
+A constructor with exactly as many parameters as there are database object fields must exist, because that constructor is chosen to instantiate the customer object.
+Database object fields are mapped to customer object constructor parameters based on their order.
+
+Example:
+```csharp
+[SnowflakeObject(ConstructionMethod = SnowflakeObjectConstructionMethod.CONSTRUCTOR)]
+public class Address
+{
+    private string _city;
+    private string _state;
+
+    public Address()
+    {
+    }
+
+    public Address(string city, string state)
+    {
+        _city = city;
+        _state = state;
+    }
+}
+```
+
+## Handling arrays
+
+You can construct structured arrays like this:
+
+```sql
+SELECT ARRAY_CONSTRUCT('a', 'b', 'c')::ARRAY(TEXT)
+```
+
+You can read such a structured array using the `T[] SnowflakeDbReader.GetArray<T>(int ordinal)` method to get an array of the specified type.
+
+```csharp
+var reader = (SnowflakeDbDataReader) command.ExecuteReader();
+Assert.IsTrue(reader.Read());
+var array = reader.GetArray<string>(0);
+```
+
+## Handling maps
+
+You can construct structured maps like this:
+
+```sql
+SELECT OBJECT_CONSTRUCT('5','San Mateo', '8', 'CA', '13', '01-234')::MAP(INTEGER, VARCHAR)
+```
+
+**Note**: The only possible map key types are VARCHAR and NUMBER with scale 0.
+
+You can read a structured map using the `Dictionary<K, V> SnowflakeDbReader.GetMap<K, V>(int ordinal)` method to get a dictionary with the specified key and value types.
+
+```csharp
+var reader = (SnowflakeDbDataReader) command.ExecuteReader();
+Assert.IsTrue(reader.Read());
+var map = reader.GetMap<int, string>(0);
+```
diff --git a/doc/VectorType.md b/doc/VectorType.md
new file mode 100644
index 000000000..fcf3cdaa4
--- /dev/null
+++ b/doc/VectorType.md
@@ -0,0 +1,18 @@
+# Vector type
+
+The vector type represents a fixed-size array of either integer or float elements.
+Examples:
+- `[4, 5, 6]::VECTOR(INT, 3)` is a 3-element vector of integers
+- `[1.1, 2.2]::VECTOR(FLOAT, 2)` is a 2-element vector of floats
+
+You can read more about vectors here: [Vector data types](https://docs.snowflake.com/en/sql-reference/data-types-vector).
+
+The driver allows reading a vector column into `int[]` or `float[]` arrays by calling the `T[] SnowflakeDbReader.GetArray<T>(int ordinal)`
+method with either the int or the float type.
+
+```csharp
+var reader = (SnowflakeDbDataReader) command.ExecuteReader();
+Assert.IsTrue(reader.Read());
+int[] intVector = reader.GetArray<int>(0);
+float[] floatVector = reader.GetArray<float>(1);
+```

From b1464100aa5e80fdb968555e3eeca8540c1c9f20 Mon Sep 17 00:00:00 2001
From: Dariusz Stempniak
Date: Wed, 16 Oct 2024 11:59:46 +0200
Subject: [PATCH 06/20] SNOW-1736920 Fix bindings uploading and error handling
 for GCP (#1041)

---
 .../IntegrationTests/SFBindTestIT.cs          |  1 +
 Snowflake.Data.Tests/Mock/MockGCSClient.cs    | 22 +++++++++++++------
 .../UnitTests/SFGCSClientTest.cs              | 22 +++++++++----------
 .../FileTransfer/StorageClient/SFGCSClient.cs | 16 ++++++++++----
 Snowflake.Data/Core/SFStatement.cs            | 16 +++++++++-----
 5 files changed, 48 insertions(+), 29 deletions(-)

diff --git a/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs
index 1683700cb..956362fe8 100755
--- a/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs
+++ b/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs
@@ -649,6 +649,7 @@ public void TestPutArrayBind()
                 var count = cmd.ExecuteNonQuery();
                 Assert.AreEqual(total * 3, count);
 
+                cmd.Parameters.Clear();
                 cmd.CommandText = $"SELECT * FROM {TableName}";
                 IDataReader reader = cmd.ExecuteReader();
                 Assert.IsTrue(reader.Read());
diff --git a/Snowflake.Data.Tests/Mock/MockGCSClient.cs b/Snowflake.Data.Tests/Mock/MockGCSClient.cs
index cb36918ae..a25d4279a 100644
--- a/Snowflake.Data.Tests/Mock/MockGCSClient.cs
+++ b/Snowflake.Data.Tests/Mock/MockGCSClient.cs
@@ -25,7 +25,7 @@ class MockGCSClient
         internal const string GcsFileContent = "GCSClientTest";
 
         // Create a mock response for GetFileHeader
-        static internal HttpWebResponse CreateResponseForFileHeader(HttpStatusCode httpStatusCode)
+        internal static HttpWebResponse CreateResponseForFileHeader(HttpStatusCode httpStatusCode)
         {
             var response = new Mock<HttpWebResponse>();
@@ -46,14 +46,18 @@ static internal HttpWebResponse CreateResponseForFileHeader(HttpStatusCode httpS
         }
 
         // Create a mock response for UploadFile
-        static internal HttpWebResponse CreateResponseForUploadFile(HttpStatusCode httpStatusCode)
+        internal static HttpWebResponse CreateResponseForUploadFile(HttpStatusCode? httpStatusCode)
         {
             var response = new Mock<HttpWebResponse>();
 
-            if (httpStatusCode != HttpStatusCode.OK)
+            if (httpStatusCode is null)
+            {
+                throw new WebException("Mock GCS Error - no response", null, 0, null);
+            }
+            else if (httpStatusCode != HttpStatusCode.OK)
             {
                 response.SetupGet(c => c.StatusCode)
-                    .Returns(httpStatusCode);
+                    .Returns(httpStatusCode.Value);
                 throw new WebException("Mock GCS Error", null, 0, response.Object);
             }
 
@@ -61,11 +65,15 @@ static internal HttpWebResponse CreateResponseForUploadFile(HttpStatusCode httpS
         }
 
         // Create a mock response for DownloadFile
-        static internal HttpWebResponse CreateResponseForDownloadFile(HttpStatusCode httpStatusCode)
+        internal static HttpWebResponse CreateResponseForDownloadFile(HttpStatusCode?
 httpStatusCode)
         {
             var response = new Mock<HttpWebResponse>();
 
-            if (httpStatusCode == HttpStatusCode.OK)
+            if (httpStatusCode is null)
+            {
+                throw new WebException("Mock GCS Error - no response", null, 0, null);
+            }
+            else if (httpStatusCode == HttpStatusCode.OK)
             {
                 response.Setup(c => c.Headers).Returns(new WebHeaderCollection());
                 response.Object.Headers.Add(SFGCSClient.GCS_METADATA_ENCRYPTIONDATAPROP,
@@ -82,7 +90,7 @@ static internal HttpWebResponse CreateResponseForDownloadFile(HttpStatusCode htt
             else
             {
                 response.SetupGet(c => c.StatusCode)
-                    .Returns(httpStatusCode);
+                    .Returns(httpStatusCode.Value);
                 throw new WebException("Mock GCS Error", null, 0, response.Object);
             }
 
diff --git a/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs
index 925ce4c98..0fad57542 100644
--- a/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs
+++ b/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs
@@ -223,16 +223,14 @@ private void AssertForGetFileHeaderTests(ResultStatus expectedResultStatus, File
         [TestCase(HttpStatusCode.Forbidden, ResultStatus.NEED_RETRY)]
         [TestCase(HttpStatusCode.InternalServerError, ResultStatus.NEED_RETRY)]
         [TestCase(HttpStatusCode.ServiceUnavailable, ResultStatus.NEED_RETRY)]
-        public void TestUploadFile(HttpStatusCode httpStatusCode, ResultStatus expectedResultStatus)
+        [TestCase(null, ResultStatus.ERROR)]
+        public void TestUploadFile(HttpStatusCode? httpStatusCode, ResultStatus expectedResultStatus)
         {
             // Arrange
             var mockWebRequest = new Mock<WebRequest>();
             mockWebRequest.Setup(c => c.Headers).Returns(new WebHeaderCollection());
             mockWebRequest.Setup(client => client.GetResponse())
-                .Returns(() =>
-                {
-                    return MockGCSClient.CreateResponseForUploadFile(httpStatusCode);
-                });
+                .Returns(() => MockGCSClient.CreateResponseForUploadFile(httpStatusCode));
             mockWebRequest.Setup(client => client.GetRequestStream())
                 .Returns(() => new MemoryStream());
             _client.SetCustomWebRequest(mockWebRequest.Object);
@@ -257,16 +255,14 @@ public void TestUploadFile(HttpStatusCode httpStatusCode, ResultStatus expectedR
         [TestCase(HttpStatusCode.Forbidden, ResultStatus.NEED_RETRY)]
         [TestCase(HttpStatusCode.InternalServerError, ResultStatus.NEED_RETRY)]
         [TestCase(HttpStatusCode.ServiceUnavailable, ResultStatus.NEED_RETRY)]
-        public async Task TestUploadFileAsync(HttpStatusCode httpStatusCode, ResultStatus expectedResultStatus)
+        [TestCase(null, ResultStatus.ERROR)]
+        public async Task TestUploadFileAsync(HttpStatusCode?
 httpStatusCode, ResultStatus expectedResultStatus)
         {
             // Arrange
             var mockWebRequest = new Mock<WebRequest>();
             mockWebRequest.Setup(c => c.Headers).Returns(new WebHeaderCollection());
             mockWebRequest.Setup(client => client.GetResponseAsync())
-                .Returns(() =>
-                {
-                    return Task.FromResult((WebResponse)MockGCSClient.CreateResponseForUploadFile(httpStatusCode));
-                });
+                .Returns(() => Task.FromResult((WebResponse)MockGCSClient.CreateResponseForUploadFile(httpStatusCode)));
             mockWebRequest.Setup(client => client.GetRequestStreamAsync())
                 .Returns(() => Task.FromResult((Stream) new MemoryStream()));
             _client.SetCustomWebRequest(mockWebRequest.Object);
@@ -301,7 +297,8 @@ private void AssertForUploadFileTests(ResultStatus expectedResultStatus)
         [TestCase(HttpStatusCode.Forbidden, ResultStatus.NEED_RETRY)]
         [TestCase(HttpStatusCode.InternalServerError, ResultStatus.NEED_RETRY)]
         [TestCase(HttpStatusCode.ServiceUnavailable, ResultStatus.NEED_RETRY)]
-        public void TestDownloadFile(HttpStatusCode httpStatusCode, ResultStatus expectedResultStatus)
+        [TestCase(null, ResultStatus.ERROR)]
+        public void TestDownloadFile(HttpStatusCode? httpStatusCode, ResultStatus expectedResultStatus)
         {
             // Arrange
             var mockWebRequest = new Mock<WebRequest>();
@@ -325,7 +322,8 @@ public void TestDownloadFile(HttpStatusCode httpStatusCode, ResultStatus expecte
         [TestCase(HttpStatusCode.Forbidden, ResultStatus.NEED_RETRY)]
         [TestCase(HttpStatusCode.InternalServerError, ResultStatus.NEED_RETRY)]
         [TestCase(HttpStatusCode.ServiceUnavailable, ResultStatus.NEED_RETRY)]
-        public async Task TestDownloadFileAsync(HttpStatusCode httpStatusCode, ResultStatus expectedResultStatus)
+        [TestCase(null, ResultStatus.ERROR)]
+        public async Task TestDownloadFileAsync(HttpStatusCode? httpStatusCode, ResultStatus expectedResultStatus)
         {
             // Arrange
             var mockWebRequest = new Mock<WebRequest>();
diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs
index 14abaad4a..9e588e921 100644
--- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs
+++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs
@@ -349,7 +349,7 @@ public void DownloadFile(SFFileMetadata fileMetadata, string fullDstPath, int ma
             try
             {
                 // Issue the GET request
-                WebRequest request = _customWebRequest == null ? FormBaseRequest(fileMetadata, "GET") : _customWebRequest;
+                WebRequest request = _customWebRequest == null ? 
FormBaseRequest(fileMetadata, "GET") : _customWebRequest; using (HttpWebResponse response = (HttpWebResponse)request.GetResponse()) { // Write to file @@ -444,7 +444,7 @@ private void HandleDownloadResponse(HttpWebResponse response, SFFileMetadata fil private SFFileMetadata HandleFileHeaderErrForPresignedUrls(WebException ex, SFFileMetadata fileMetadata) { Logger.Error("Failed to get file header for presigned url: " + ex.Message); - + HttpWebResponse response = (HttpWebResponse)ex.Response; if (response.StatusCode == HttpStatusCode.Unauthorized || response.StatusCode == HttpStatusCode.Forbidden || @@ -509,7 +509,11 @@ private SFFileMetadata HandleUploadFileErr(WebException ex, SFFileMetadata fileM fileMetadata.lastError = ex; HttpWebResponse response = (HttpWebResponse)ex.Response; - if (response.StatusCode == HttpStatusCode.BadRequest && GCS_ACCESS_TOKEN != null) + if (response is null) + { + fileMetadata.resultStatus = ResultStatus.ERROR.ToString(); + } + else if (response.StatusCode == HttpStatusCode.BadRequest && GCS_ACCESS_TOKEN != null) { fileMetadata.resultStatus = ResultStatus.RENEW_PRESIGNED_URL.ToString(); } @@ -539,7 +543,11 @@ private SFFileMetadata HandleDownloadFileErr(WebException ex, SFFileMetadata fil fileMetadata.lastError = ex; HttpWebResponse response = (HttpWebResponse)ex.Response; - if (response.StatusCode == HttpStatusCode.Unauthorized) + if (response is null) + { + fileMetadata.resultStatus = ResultStatus.ERROR.ToString(); + } + else if (response.StatusCode == HttpStatusCode.Unauthorized) { fileMetadata.resultStatus = ResultStatus.RENEW_TOKEN.ToString(); } diff --git a/Snowflake.Data/Core/SFStatement.cs b/Snowflake.Data/Core/SFStatement.cs index e84690d54..146a10130 100644 --- a/Snowflake.Data/Core/SFStatement.cs +++ b/Snowflake.Data/Core/SFStatement.cs @@ -383,11 +383,14 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti SFBindUploader uploader = new SFBindUploader(SfSession, _requestId); await uploader.UploadAsync(bindings, cancellationToken).ConfigureAwait(false); _bindStage = uploader.getStagePath(); - ClearQueryRequestId(); } catch (Exception e) { - logger.Warn("Exception encountered trying to upload binds to stage. Attaching binds in payload instead. {0}", e); + logger.Warn("Exception encountered trying to upload binds to stage. Attaching binds in payload instead. Exception: " + e.Message); + } + finally + { + ClearQueryRequestId(); } } @@ -532,13 +535,14 @@ private SFBaseResultSet ExecuteSqlOtherThanPutGet(int timeout, string sql, Dicti SFBindUploader uploader = new SFBindUploader(SfSession, _requestId); uploader.Upload(bindings); _bindStage = uploader.getStagePath(); - ClearQueryRequestId(); } catch (Exception e) { - logger.Warn( - "Exception encountered trying to upload binds to stage. Attaching binds in payload instead. {0}", - e); + logger.Warn("Exception encountered trying to upload binds to stage. Attaching binds in payload instead. 
Exception: " + e.Message); + } + finally + { + ClearQueryRequestId(); } } From 90b700753dff8c6b2fb7031bc2ca8651e25aa722 Mon Sep 17 00:00:00 2001 From: Dariusz Stempniak Date: Mon, 21 Oct 2024 11:36:30 +0200 Subject: [PATCH 07/20] SNOW-1729244 support for large and small timestamps (#1038) --- .../IntegrationTests/SFBindTestIT.cs | 130 +++++++++++++----- ...ructuredTypesWithEmbeddedUnstructuredIT.cs | 70 ++++++++++ .../UnitTests/SFBindUploaderTest.cs | 29 ++-- .../UnitTests/SFDataConverterTest.cs | 30 +++- .../UnitTests/StructuredTypesTest.cs | 12 ++ Snowflake.Data/Client/SnowflakeDbCommand.cs | 8 +- Snowflake.Data/Core/ArrowResultSet.cs | 2 +- Snowflake.Data/Core/SFBindUploader.cs | 28 ++-- Snowflake.Data/Core/SFDataConverter.cs | 48 +++---- Snowflake.Data/Core/SFResultSet.cs | 2 +- 10 files changed, 271 insertions(+), 88 deletions(-) diff --git a/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs index 956362fe8..05995e0d4 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs @@ -1,6 +1,7 @@ /* * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ +#nullable enable using System; using System.Data; @@ -87,7 +88,7 @@ public void TestBindNullValue() foreach (DbType type in Enum.GetValues(typeof(DbType))) { bool isTypeSupported = true; - string colName = null; + string colName; using (IDbCommand command = dbConnection.CreateCommand()) { var param = command.CreateParameter(); @@ -226,7 +227,7 @@ public void TestBindValue() foreach (DbType type in Enum.GetValues(typeof(DbType))) { bool isTypeSupported = true; - string colName = null; + string colName; using (IDbCommand command = dbConnection.CreateCommand()) { var param = command.CreateParameter(); @@ -885,13 +886,20 @@ public void TestExplicitDbTypeAssignmentForArrayValue() [TestCase(ResultFormat.ARROW, SFTableType.Iceberg, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] */ // Session TimeZone cases - [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, "Europe/Warsaw")] + [TestCase(ResultFormat.JSON, SFTableType.Standard, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, "Europe/Warsaw")] [TestCase(ResultFormat.JSON, SFTableType.Standard, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, "Asia/Tokyo")] + [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, "Europe/Warsaw")] + [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, "Asia/Tokyo")] public void TestDateTimeBinding(ResultFormat resultFormat, SFTableType tableType, SFDataType columnType, Int32? 
columnPrecision, DbType bindingType, string comparisonFormat, string timeZone) { // Arrange - var timestamp = "2023/03/15 13:17:29.207 +05:00"; // 08:17:29.207 UTC - var expected = ExpectedTimestampWrapper.From(timestamp, columnType); + string[] timestamps = + { + "2023/03/15 13:17:29.207 +05:00", + "9999/12/30 23:24:25.987 +07:00", + "0001/01/02 02:06:07.000 -04:00" + }; + var expected = ExpectedTimestampWrapper.From(timestamps, columnType); var columnWithPrecision = ColumnTypeWithPrecision(columnType, columnPrecision); var testCase = $"ResultFormat={resultFormat}, TableType={tableType}, ColumnType={columnWithPrecision}, BindingType={bindingType}, ComparisonFormat={comparisonFormat}"; var bindingThreshold = 65280; // when exceeded enforces bindings via file on stage @@ -907,24 +915,34 @@ public void TestDateTimeBinding(ResultFormat resultFormat, SFTableType tableType if (!timeZone.IsNullOrEmpty()) // Driver ignores this setting and relies on local environment timezone conn.ExecuteNonQuery($"alter session set TIMEZONE = '{timeZone}'"); + // prepare initial column + var columns = new List { "id number(10,0) not null primary key" }; + var sql_columns = "id"; + var sql_values = "?"; + + // prepare additional columns + for (int i = 1; i <= timestamps.Length; ++i) + { + columns.Add($"ts_{i} {columnWithPrecision}"); + sql_columns += $",ts_{i}"; + sql_values += ",?"; + } + CreateOrReplaceTable(conn, TableName, tableType.TableDDLCreationPrefix(), - new[] { - "id number(10,0) not null primary key", // necessary only for HYBRID tables - $"ts {columnWithPrecision}" - }, + columns, tableType.TableDDLCreationFlags()); // Act+Assert - var sqlInsert = $"insert into {TableName} (id, ts) values (?, ?)"; + var sqlInsert = $"insert into {TableName} ({sql_columns}) values ({sql_values})"; InsertSingleRecord(conn, sqlInsert, bindingType, 1, expected); InsertMultipleRecords(conn, sqlInsert, bindingType, 2, expected, smallBatchRowCount, false); InsertMultipleRecords(conn, sqlInsert, bindingType, smallBatchRowCount+2, expected, bigBatchRowCount, true); // Assert var row = 0; - using (var select = conn.CreateCommand($"select id, ts from {TableName} order by id")) + using (var select = conn.CreateCommand($"select {sql_columns} from {TableName} order by id")) { s_logger.Debug(select.CommandText); var reader = select.ExecuteReader(); @@ -933,7 +951,11 @@ public void TestDateTimeBinding(ResultFormat resultFormat, SFTableType tableType ++row; string faultMessage = $"Mismatch for row: {row}, {testCase}"; Assert.AreEqual(row, reader.GetInt32(0)); - expected.AssertEqual(reader.GetValue(1), comparisonFormat, faultMessage); + + for (int i = 0; i < timestamps.Length; ++i) + { + expected.AssertEqual(reader.GetValue(i + 1), comparisonFormat, faultMessage, i); + } } } Assert.AreEqual(1+smallBatchRowCount+bigBatchRowCount, row); @@ -948,12 +970,24 @@ private void InsertSingleRecord(IDbConnection conn, string sqlInsert, DbType bin insert.Add("1", DbType.Int32, identifier); if (ExpectedTimestampWrapper.IsOffsetType(ts.ExpectedColumnType())) { - var parameter = (SnowflakeDbParameter)insert.Add("2", binding, ts.GetDateTimeOffset()); - parameter.SFDataType = ts.ExpectedColumnType(); + var dateTimeOffsets = ts.GetDateTimeOffsets(); + for (int i = 0; i < dateTimeOffsets.Length; ++i) + { + var parameterName = (i + 2).ToString(); + var parameterValue = dateTimeOffsets[i]; + var parameter = insert.Add(parameterName, binding, parameterValue); + parameter.SFDataType = ts.ExpectedColumnType(); + } } else { - insert.Add("2", binding, 
ts.GetDateTime()); + var dateTimes = ts.GetDateTimes(); + for (int i = 0; i < dateTimes.Length; ++i) + { + var parameterName = (i + 2).ToString(); + var parameterValue = dateTimes[i]; + insert.Add(parameterName, binding, parameterValue); + } } // Act @@ -974,12 +1008,25 @@ private void InsertMultipleRecords(IDbConnection conn, string sqlInsert, DbType insert.Add("1", DbType.Int32, Enumerable.Range(initialIdentifier, rowsCount).ToArray()); if (ExpectedTimestampWrapper.IsOffsetType(ts.ExpectedColumnType())) { - var parameter = (SnowflakeDbParameter)insert.Add("2", binding, Enumerable.Repeat(ts.GetDateTimeOffset(), rowsCount).ToArray()); - parameter.SFDataType = ts.ExpectedColumnType(); + var dateTimeOffsets = ts.GetDateTimeOffsets(); + for (int i = 0; i < dateTimeOffsets.Length; ++i) + { + var parameterName = (i + 2).ToString(); + var parameterValue = Enumerable.Repeat(dateTimeOffsets[i], rowsCount).ToArray(); + var parameter = insert.Add(parameterName, binding, parameterValue); + parameter.SFDataType = ts.ExpectedColumnType(); + } + } else { - insert.Add("2", binding, Enumerable.Repeat(ts.GetDateTime(), rowsCount).ToArray()); + var dateTimes = ts.GetDateTimes(); + for (int i = 0; i < dateTimes.Length; ++i) + { + var parameterName = (i + 2).ToString(); + var parameterValue = Enumerable.Repeat(dateTimes[i], rowsCount).ToArray(); + insert.Add(parameterName, binding, parameterValue); + } } // Act @@ -1002,57 +1049,66 @@ private static string ColumnTypeWithPrecision(SFDataType columnType, Int32? colu class ExpectedTimestampWrapper { private readonly SFDataType _columnType; - private readonly DateTime? _expectedDateTime; - private readonly DateTimeOffset? _expectedDateTimeOffset; + private readonly DateTime[]? _expectedDateTimes; + private readonly DateTimeOffset[]? 
_expectedDateTimeOffsets; - internal static ExpectedTimestampWrapper From(string timestampWithTimeZone, SFDataType columnType) + internal static ExpectedTimestampWrapper From(string[] timestampsWithTimeZone, SFDataType columnType) { if (IsOffsetType(columnType)) { - var dateTimeOffset = DateTimeOffset.ParseExact(timestampWithTimeZone, "yyyy/MM/dd HH:mm:ss.fff zzz", CultureInfo.InvariantCulture); - return new ExpectedTimestampWrapper(dateTimeOffset, columnType); + var dateTimeOffsets = + timestampsWithTimeZone + .Select(ts => DateTimeOffset.ParseExact(ts, "yyyy/MM/dd HH:mm:ss.fff zzz", CultureInfo.InvariantCulture)) + .ToArray(); + return new ExpectedTimestampWrapper(dateTimeOffsets, columnType); } - var dateTime = DateTime.ParseExact(timestampWithTimeZone, "yyyy/MM/dd HH:mm:ss.fff zzz", CultureInfo.InvariantCulture); - return new ExpectedTimestampWrapper(dateTime, columnType); + var dateTimes = + timestampsWithTimeZone + .Select(ts => DateTime.ParseExact(ts, "yyyy/MM/dd HH:mm:ss.fff zzz", CultureInfo.InvariantCulture)) + .ToArray(); + + return new ExpectedTimestampWrapper(dateTimes, columnType); } - private ExpectedTimestampWrapper(DateTime dateTime, SFDataType columnType) + private ExpectedTimestampWrapper(DateTime[] dateTimes, SFDataType columnType) { - _expectedDateTime = dateTime; - _expectedDateTimeOffset = null; + _expectedDateTimes = dateTimes; + _expectedDateTimeOffsets = null; _columnType = columnType; } - private ExpectedTimestampWrapper(DateTimeOffset dateTimeOffset, SFDataType columnType) + private ExpectedTimestampWrapper(DateTimeOffset[] dateTimeOffsets, SFDataType columnType) { - _expectedDateTimeOffset = dateTimeOffset; - _expectedDateTime = null; + _expectedDateTimeOffsets = dateTimeOffsets; + _expectedDateTimes = null; _columnType = columnType; } internal SFDataType ExpectedColumnType() => _columnType; - internal void AssertEqual(object actual, string comparisonFormat, string faultMessage) + internal void AssertEqual(object actual, string comparisonFormat, string faultMessage, int index) { switch (_columnType) { case SFDataType.TIMESTAMP_TZ: - Assert.AreEqual(GetDateTimeOffset().ToString(comparisonFormat), ((DateTimeOffset)actual).ToString(comparisonFormat), faultMessage); + Assert.AreEqual(GetDateTimeOffsets()[index].ToString(comparisonFormat), ((DateTimeOffset)actual).ToString(comparisonFormat), faultMessage); break; case SFDataType.TIMESTAMP_LTZ: - Assert.AreEqual(GetDateTimeOffset().ToUniversalTime().ToString(comparisonFormat), ((DateTimeOffset)actual).ToUniversalTime().ToString(comparisonFormat), faultMessage); + Assert.AreEqual(GetDateTimeOffsets()[index].ToUniversalTime().ToString(comparisonFormat), ((DateTimeOffset)actual).ToUniversalTime().ToString(comparisonFormat), faultMessage); break; default: - Assert.AreEqual(GetDateTime().ToString(comparisonFormat), ((DateTime)actual).ToString(comparisonFormat), faultMessage); + Assert.AreEqual(GetDateTimes()[index].ToString(comparisonFormat), ((DateTime)actual).ToString(comparisonFormat), faultMessage); break; } } - internal DateTime GetDateTime() => _expectedDateTime ?? throw new Exception($"Column {_columnType} is not matching the expected value type {typeof(DateTime)}"); + internal DateTime[] GetDateTimes() => _expectedDateTimes ?? throw new Exception($"Column {_columnType} is not matching the expected value type {typeof(DateTime)}"); - internal DateTimeOffset GetDateTimeOffset() => _expectedDateTimeOffset ?? 
throw new Exception($"Column {_columnType} is not matching the expected value type {typeof(DateTime)}"); + internal DateTimeOffset[] GetDateTimeOffsets() => _expectedDateTimeOffsets ?? throw new Exception($"Column {_columnType} is not matching the expected value type {typeof(DateTime)}"); internal static bool IsOffsetType(SFDataType type) => type == SFDataType.TIMESTAMP_LTZ || type == SFDataType.TIMESTAMP_TZ; } } + +#nullable restore diff --git a/Snowflake.Data.Tests/IntegrationTests/StructuredTypesWithEmbeddedUnstructuredIT.cs b/Snowflake.Data.Tests/IntegrationTests/StructuredTypesWithEmbeddedUnstructuredIT.cs index 784aa4132..22f8310a1 100644 --- a/Snowflake.Data.Tests/IntegrationTests/StructuredTypesWithEmbeddedUnstructuredIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/StructuredTypesWithEmbeddedUnstructuredIT.cs @@ -361,6 +361,41 @@ internal static IEnumerable DateTimeConversionCases() null, DateTime.Parse("2024-07-11 21:20:05.1234568").ToLocalTime() }; + yield return new object[] + { + "9999-12-31 23:59:59.999999", + SFDataType.TIMESTAMP_NTZ.ToString(), + DateTime.Parse("9999-12-31 23:59:59.999999"), + DateTime.Parse("9999-12-31 23:59:59.999999") + }; + yield return new object[] + { + "9999-12-31 23:59:59.999999 +1:00", + SFDataType.TIMESTAMP_TZ.ToString(), + null, + DateTime.SpecifyKind(DateTime.Parse("9999-12-31 22:59:59.999999"), DateTimeKind.Utc) + }; + yield return new object[] + { + "9999-12-31 23:59:59.999999 +13:00", + SFDataType.TIMESTAMP_LTZ.ToString(), + null, + DateTime.Parse("9999-12-31 10:59:59.999999").ToLocalTime() + }; + yield return new object[] + { + "0001-01-01 00:00:00", + SFDataType.TIMESTAMP_NTZ.ToString(), + DateTime.Parse("0001-01-01 00:00:00"), + DateTime.Parse("0001-01-01 00:00:00") + }; + yield return new object[] + { + "0001-01-01 00:00:00 -1:00", + SFDataType.TIMESTAMP_TZ.ToString(), + null, + DateTime.SpecifyKind(DateTime.Parse("0001-01-01 01:00:00"), DateTimeKind.Utc) + }; } [Test] @@ -445,6 +480,41 @@ internal static IEnumerable DateTimeOffsetConversionCases() null, DateTimeOffset.Parse("2024-07-11 14:20:05.1234568 -7:00") }; + yield return new object[] + { + "9999-12-31 23:59:59.999999", + SFDataType.TIMESTAMP_NTZ.ToString(), + DateTime.Parse("9999-12-31 23:59:59.999999"), + DateTimeOffset.Parse("9999-12-31 23:59:59.999999Z") + }; + yield return new object[] + { + "9999-12-31 23:59:59.999999 +1:00", + SFDataType.TIMESTAMP_TZ.ToString(), + null, + DateTimeOffset.Parse("9999-12-31 23:59:59.999999 +1:00") + }; + yield return new object[] + { + "9999-12-31 23:59:59.999999 +13:00", + SFDataType.TIMESTAMP_LTZ.ToString(), + null, + DateTimeOffset.Parse("9999-12-31 23:59:59.999999 +13:00") + }; + yield return new object[] + { + "0001-01-01 00:00:00", + SFDataType.TIMESTAMP_NTZ.ToString(), + DateTime.Parse("0001-01-01 00:00:00"), + DateTimeOffset.Parse("0001-01-01 00:00:00Z") + }; + yield return new object[] + { + "0001-01-01 00:00:00 -1:00", + SFDataType.TIMESTAMP_TZ.ToString(), + null, + DateTimeOffset.Parse("0001-01-01 00:00:00 -1:00") + }; } private TimeZoneInfo GetTimeZone(SnowflakeDbConnection connection) diff --git a/Snowflake.Data.Tests/UnitTests/SFBindUploaderTest.cs b/Snowflake.Data.Tests/UnitTests/SFBindUploaderTest.cs index ac5172086..46e5b5b90 100644 --- a/Snowflake.Data.Tests/UnitTests/SFBindUploaderTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFBindUploaderTest.cs @@ -20,7 +20,7 @@ public void TestCsvDataConversionForDate(SFDataType dbType, string input, string { // Arrange var dateExpected = DateTime.Parse(expected); - var check = 
SFDataConverter.csharpValToSfVal(SFDataType.DATE, dateExpected); + var check = SFDataConverter.CSharpValToSfVal(SFDataType.DATE, dateExpected); Assert.AreEqual(check, input); // Act DateTime dateActual = DateTime.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); @@ -37,51 +37,60 @@ public void TestCsvDataConversionForTime(SFDataType dbType, string input, string { // Arrange DateTime timeExpected = DateTime.Parse(expected); - var check = SFDataConverter.csharpValToSfVal(SFDataType.TIME, timeExpected); + var check = SFDataConverter.CSharpValToSfVal(SFDataType.TIME, timeExpected); Assert.AreEqual(check, input); // Act DateTime timeActual = DateTime.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); // Assert Assert.AreEqual(timeExpected, timeActual); } - - [TestCase(SFDataType.TIMESTAMP_LTZ, "39600000000000", "1970-01-01T12:00:00.0000000+01:00")] + + [TestCase(SFDataType.TIMESTAMP_LTZ, "0", "1970-01-01T00:00:00.0000000+00:00")] + [TestCase(SFDataType.TIMESTAMP_LTZ, "39600000000000", "1970-01-01T12:00:00.0000000+01:00")] [TestCase(SFDataType.TIMESTAMP_LTZ, "1341136800000000000", "2012-07-01T12:00:00.0000000+02:00")] [TestCase(SFDataType.TIMESTAMP_LTZ, "352245599987654000", "1981-02-28T23:59:59.9876540+02:00")] [TestCase(SFDataType.TIMESTAMP_LTZ, "1678868249207000000", "2023/03/15T13:17:29.207+05:00")] + [TestCase(SFDataType.TIMESTAMP_LTZ, "253402300799999999900", "9999-12-31T23:59:59.9999999+00:00")] + [TestCase(SFDataType.TIMESTAMP_LTZ, "-62135596800000000000", "0001-01-01T00:00:00.0000000+00:00")] public void TestCsvDataConversionForTimestampLtz(SFDataType dbType, string input, string expected) { // Arrange var timestampExpected = DateTimeOffset.Parse(expected); - var check = SFDataConverter.csharpValToSfVal(SFDataType.TIMESTAMP_LTZ, timestampExpected); + var check = SFDataConverter.CSharpValToSfVal(SFDataType.TIMESTAMP_LTZ, timestampExpected); Assert.AreEqual(check, input); // Act var timestampActual = DateTimeOffset.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); // Assert Assert.AreEqual(timestampExpected.ToLocalTime(), timestampActual); } - + + [TestCase(SFDataType.TIMESTAMP_TZ, "0 1440", "1970-01-01 00:00:00.000000 +00:00")] [TestCase(SFDataType.TIMESTAMP_TZ, "1341136800000000000 1560", "2012-07-01 12:00:00.000000 +02:00")] [TestCase(SFDataType.TIMESTAMP_TZ, "352245599987654000 1560", "1981-02-28 23:59:59.987654 +02:00")] + [TestCase(SFDataType.TIMESTAMP_TZ, "253402300799999999000 1440", "9999-12-31 23:59:59.999999 +00:00")] + [TestCase(SFDataType.TIMESTAMP_TZ, "-62135596800000000000 1440", "0001-01-01 00:00:00.000000 +00:00")] public void TestCsvDataConversionForTimestampTz(SFDataType dbType, string input, string expected) { // Arrange DateTimeOffset timestampExpected = DateTimeOffset.Parse(expected); - var check = SFDataConverter.csharpValToSfVal(SFDataType.TIMESTAMP_TZ, timestampExpected); + var check = SFDataConverter.CSharpValToSfVal(SFDataType.TIMESTAMP_TZ, timestampExpected); Assert.AreEqual(check, input); // Act DateTimeOffset timestampActual = DateTimeOffset.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); // Assert Assert.AreEqual(timestampExpected, timestampActual); } - + + [TestCase(SFDataType.TIMESTAMP_NTZ, "0", "1970-01-01 00:00:00.000000")] [TestCase(SFDataType.TIMESTAMP_NTZ, "1341144000000000000", "2012-07-01 12:00:00.000000")] [TestCase(SFDataType.TIMESTAMP_NTZ, "352252799987654000", "1981-02-28 23:59:59.987654")] + [TestCase(SFDataType.TIMESTAMP_NTZ, "253402300799999999000", "9999-12-31 23:59:59.999999")] + 
[TestCase(SFDataType.TIMESTAMP_NTZ, "-62135596800000000000", "0001-01-01 00:00:00.000000")] public void TestCsvDataConversionForTimestampNtz(SFDataType dbType, string input, string expected) { - // Arrange + // Arrange DateTime timestampExpected = DateTime.Parse(expected); - var check = SFDataConverter.csharpValToSfVal(SFDataType.TIMESTAMP_NTZ, timestampExpected); + var check = SFDataConverter.CSharpValToSfVal(SFDataType.TIMESTAMP_NTZ, timestampExpected); Assert.AreEqual(check, input); // Act DateTime timestampActual = DateTime.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); diff --git a/Snowflake.Data.Tests/UnitTests/SFDataConverterTest.cs b/Snowflake.Data.Tests/UnitTests/SFDataConverterTest.cs index 65160ac97..7def7ce6a 100755 --- a/Snowflake.Data.Tests/UnitTests/SFDataConverterTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFDataConverterTest.cs @@ -4,6 +4,8 @@ using System; using System.Text; +using Snowflake.Data.Client; +using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.UnitTests { @@ -36,8 +38,8 @@ public void TestConvertBindToSFValFinlandLocale() Thread.CurrentThread.CurrentCulture = ci; - System.Tuple t = - SFDataConverter.csharpTypeValToSfTypeVal(System.Data.DbType.Double, 1.2345); + System.Tuple t = + SFDataConverter.CSharpTypeValToSfTypeVal(System.Data.DbType.Double, 1.2345); Assert.AreEqual("REAL", t.Item1); Assert.AreEqual("1.2345", t.Item2); @@ -109,7 +111,7 @@ public void TestConvertTimeSpan(string inputTimeStr) var tickDiff = val.Ticks; var inputStringAsItComesBackFromDatabase = (tickDiff / 10000000.0m).ToString(CultureInfo.InvariantCulture); inputStringAsItComesBackFromDatabase += inputTimeStr.Substring(8, inputTimeStr.Length - 8); - + // Run the conversion var result = SFDataConverter.ConvertToCSharpVal(ConvertToUTF8Buffer(inputStringAsItComesBackFromDatabase), SFDataType.TIME, typeof(TimeSpan)); @@ -148,7 +150,7 @@ public void TestConvertDate(string inputTimeStr, object kind = null) private void internalTestConvertDate(DateTime dtExpected, DateTime testValue) { - var result = SFDataConverter.csharpTypeValToSfTypeVal(System.Data.DbType.Date, testValue); + var result = SFDataConverter.CSharpTypeValToSfTypeVal(System.Data.DbType.Date, testValue); // Convert result to DateTime for easier interpretation var unixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); DateTime dtResult = unixEpoch.AddMilliseconds(Int64.Parse(result.Item2)); @@ -326,5 +328,25 @@ public void TestInvalidConversionInvalidDecimal(string s) Assert.Throws(() => SFDataConverter.ConvertToCSharpVal(ConvertToUTF8Buffer(s), SFDataType.FIXED, typeof(decimal))); } + [Test] + [TestCase(SFDataType.TIMESTAMP_LTZ, typeof(DateTime))] + [TestCase(SFDataType.TIMESTAMP_TZ, typeof(DateTime))] + [TestCase(SFDataType.TIMESTAMP_NTZ, typeof(DateTimeOffset))] + [TestCase(SFDataType.TIME, typeof(DateTimeOffset))] + [TestCase(SFDataType.DATE, typeof(DateTimeOffset))] + public void TestInvalidTimestampConversion(SFDataType dataType, Type unsupportedType) + { + object unsupportedObject; + if (unsupportedType == typeof(DateTimeOffset)) + unsupportedObject = new DateTimeOffset(); + else if (unsupportedType == typeof(DateTime)) + unsupportedObject = new DateTime(); + else + unsupportedObject = null; + + Assert.NotNull(unsupportedObject); + SnowflakeDbException ex = Assert.Throws(() => SFDataConverter.CSharpValToSfVal(dataType, unsupportedObject)); + SnowflakeDbExceptionAssert.HasErrorCode(ex, SFError.INVALID_DATA_CONVERSION); + } } } diff --git a/Snowflake.Data.Tests/UnitTests/StructuredTypesTest.cs 
b/Snowflake.Data.Tests/UnitTests/StructuredTypesTest.cs index 0a91fdab5..cff0c6959 100644 --- a/Snowflake.Data.Tests/UnitTests/StructuredTypesTest.cs +++ b/Snowflake.Data.Tests/UnitTests/StructuredTypesTest.cs @@ -47,6 +47,18 @@ internal static IEnumerable TimeConversionCases() yield return new object[] {"2024-07-11 14:20:05.123456 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTimeOffset.Parse("2024-07-11 14:20:05.123456 -7:00")}; yield return new object[] {"2024-07-11 14:20:05.123456 -7:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTime.Parse("2024-07-11 21:20:05.123456").ToLocalTime()}; yield return new object[] {"14:20:05.123456", SFDataType.TIME.ToString(), TimeSpan.Parse("14:20:05.123456")}; + yield return new object[] {"9999-12-31 23:59:59.999999", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("9999-12-31 23:59:59.999999")}; + yield return new object[] {"9999-12-31 23:59:59.999999", SFDataType.TIMESTAMP_NTZ.ToString(), DateTimeOffset.Parse("9999-12-31 23:59:59.999999Z")}; + yield return new object[] {"9999-12-31 23:59:59.999999 +1:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTimeOffset.Parse("9999-12-31 23:59:59.999999 +1:00")}; + yield return new object[] {"9999-12-31 23:59:59.999999 +1:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTime.SpecifyKind(DateTime.Parse("9999-12-31 22:59:59.999999"), DateTimeKind.Utc)}; + yield return new object[] {"9999-12-31 23:59:59.999999 +1:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTimeOffset.Parse("9999-12-31 23:59:59.999999 +1:00")}; + yield return new object[] {"9999-12-31 23:59:59.999999 +13:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTime.Parse("9999-12-31 10:59:59.999999").ToLocalTime()}; + yield return new object[] {"0001-01-01 00:00:00.123456", SFDataType.TIMESTAMP_NTZ.ToString(), DateTime.Parse("0001-01-01 00:00:00.123456")}; + yield return new object[] {"0001-01-01 00:00:00.123456", SFDataType.TIMESTAMP_NTZ.ToString(), DateTimeOffset.Parse("0001-01-01 00:00:00.123456Z")}; + yield return new object[] {"0001-01-01 00:00:00.123456 -1:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTimeOffset.Parse("0001-01-01 00:00:00.123456 -1:00")}; + yield return new object[] {"0001-01-01 00:00:00.123456 -1:00", SFDataType.TIMESTAMP_TZ.ToString(), DateTime.SpecifyKind(DateTime.Parse("0001-01-01 01:00:00.123456"), DateTimeKind.Utc)}; + yield return new object[] {"0001-01-01 00:00:00.123456 -1:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTimeOffset.Parse("0001-01-01 00:00:00.123456 -1:00")}; + yield return new object[] {"0001-01-01 00:00:00.123456 -13:00", SFDataType.TIMESTAMP_LTZ.ToString(), DateTime.Parse("0001-01-01 13:00:00.123456").ToLocalTime()}; } } } diff --git a/Snowflake.Data/Client/SnowflakeDbCommand.cs b/Snowflake.Data/Client/SnowflakeDbCommand.cs index b52d53643..68d3dccb0 100755 --- a/Snowflake.Data/Client/SnowflakeDbCommand.cs +++ b/Snowflake.Data/Client/SnowflakeDbCommand.cs @@ -393,7 +393,7 @@ private static Dictionary convertToBindList(List typeAndVal = SFDataConverter - .csharpTypeValToSfTypeVal(parameter.DbType, val); + .CSharpTypeValToSfTypeVal(parameter.DbType, val); bindingType = typeAndVal.Item1; vals.Add(typeAndVal.Item2); @@ -401,7 +401,7 @@ private static Dictionary convertToBindList(List convertToBindList(List typeAndVal = SFDataConverter - .csharpTypeValToSfTypeVal(parameter.DbType, parameter.Value); + .CSharpTypeValToSfTypeVal(parameter.DbType, parameter.Value); bindingType = typeAndVal.Item1; bindingVal = typeAndVal.Item2; } else { bindingType = parameter.SFDataType.ToString(); - bindingVal = 
SFDataConverter.csharpValToSfVal(parameter.DbType, parameter.Value); + bindingVal = SFDataConverter.CSharpValToSfVal(parameter.SFDataType, parameter.Value); } } diff --git a/Snowflake.Data/Core/ArrowResultSet.cs b/Snowflake.Data/Core/ArrowResultSet.cs index a3a6e2628..178531eaf 100755 --- a/Snowflake.Data/Core/ArrowResultSet.cs +++ b/Snowflake.Data/Core/ArrowResultSet.cs @@ -392,7 +392,7 @@ internal override string GetString(int ordinal) return ret; case DateTime ret: if (type == SFDataType.DATE) - return SFDataConverter.toDateString(ret, sfResultSetMetaData.dateOutputFormat); + return SFDataConverter.ToDateString(ret, sfResultSetMetaData.dateOutputFormat); break; } diff --git a/Snowflake.Data/Core/SFBindUploader.cs b/Snowflake.Data/Core/SFBindUploader.cs index 6268c724c..400c3b0c9 100644 --- a/Snowflake.Data/Core/SFBindUploader.cs +++ b/Snowflake.Data/Core/SFBindUploader.cs @@ -251,26 +251,38 @@ internal string GetCSVData(string sType, string sValue) return '"' + sValue.Replace("\"", "\"\"") + '"'; return sValue; case "DATE": - long msFromEpoch = long.Parse(sValue); // SFDateConverter.csharpValToSfVal provides in [ms] from Epoch + long msFromEpoch = long.Parse(sValue); // SFDataConverter.CSharpValToSfVal provides in [ms] from Epoch DateTime date = epoch.AddMilliseconds(msFromEpoch); return date.ToShortDateString(); case "TIME": - long nsSinceMidnight = long.Parse(sValue); // SFDateConverter.csharpValToSfVal provides in [ns] from Midnight + long nsSinceMidnight = long.Parse(sValue); // SFDataConverter.CSharpValToSfVal provides in [ns] from Midnight DateTime time = epoch.AddTicks(nsSinceMidnight/100); return time.ToString("HH:mm:ss.fffffff"); case "TIMESTAMP_LTZ": - long nsFromEpochLtz = long.Parse(sValue); // SFDateConverter.csharpValToSfVal provides in [ns] from Epoch - DateTime ltz = epoch.AddTicks(nsFromEpochLtz/100); + long ticksFromEpochLtz = + long.TryParse(sValue, out var nsLtz) + ? nsLtz / 100 + : (long)(decimal.Parse(sValue) / 100); + + DateTime ltz = epoch.AddTicks(ticksFromEpochLtz); return ltz.ToLocalTime().ToString("O"); // ISO 8601 format case "TIMESTAMP_NTZ": - long nsFromEpochNtz = long.Parse(sValue); // SFDateConverter.csharpValToSfVal provides in [ns] from Epoch - DateTime ntz = epoch.AddTicks(nsFromEpochNtz/100); + long ticksFromEpochNtz = + long.TryParse(sValue, out var nsNtz) + ? nsNtz / 100 + : (long)(decimal.Parse(sValue) / 100); + + DateTime ntz = epoch.AddTicks(ticksFromEpochNtz); return ntz.ToString("yyyy-MM-dd HH:mm:ss.fffffff"); case "TIMESTAMP_TZ": string[] tstzString = sValue.Split(' '); - long nsFromEpochTz = long.Parse(tstzString[0]); // SFDateConverter provides in [ns] from Epoch + long ticksFromEpochTz = + long.TryParse(tstzString[0], out var nsTz) + ? 
nsTz / 100 + : (long)(decimal.Parse(tstzString[0]) / 100); + int timeZoneOffset = int.Parse(tstzString[1]) - 1440; // SFDateConverter provides in minutes increased by 1440m - DateTime timestamp = epoch.AddTicks(nsFromEpochTz/100).AddMinutes(timeZoneOffset); + DateTime timestamp = epoch.AddTicks(ticksFromEpochTz).AddMinutes(timeZoneOffset); TimeSpan offset = TimeSpan.FromMinutes(timeZoneOffset); DateTimeOffset tzDateTimeOffset = new DateTimeOffset(timestamp.Ticks, offset); return tzDateTimeOffset.ToString("yyyy-MM-dd HH:mm:ss.fffffff zzz"); diff --git a/Snowflake.Data/Core/SFDataConverter.cs b/Snowflake.Data/Core/SFDataConverter.cs index 90e956314..619976400 100755 --- a/Snowflake.Data/Core/SFDataConverter.cs +++ b/Snowflake.Data/Core/SFDataConverter.cs @@ -152,7 +152,7 @@ private static DateTime ConvertToDateTime(UTF8Buffer srcVal, SFDataType srcType) { case SFDataType.DATE: long srcValLong = FastParser.FastParseInt64(srcVal.Buffer, srcVal.offset, srcVal.length); - return DateTime.SpecifyKind(UnixEpoch.AddDays(srcValLong), DateTimeKind.Unspecified);; + return DateTime.SpecifyKind(UnixEpoch.AddDays(srcValLong), DateTimeKind.Unspecified); case SFDataType.TIME: case SFDataType.TIMESTAMP_NTZ: @@ -240,7 +240,7 @@ private static long GetTicksFromSecondAndNanosecond(UTF8Buffer srcVal) } - internal static Tuple csharpTypeValToSfTypeVal(DbType srcType, object srcVal) + internal static Tuple CSharpTypeValToSfTypeVal(DbType srcType, object srcVal) { SFDataType destType; string destVal; @@ -300,7 +300,7 @@ internal static Tuple csharpTypeValToSfTypeVal(DbType srcType, o default: throw new SnowflakeDbException(SFError.UNSUPPORTED_DOTNET_TYPE, srcType); } - destVal = csharpValToSfVal(destType, srcVal); + destVal = CSharpValToSfVal(destType, srcVal); return Tuple.Create(destType.ToString(), destVal); } @@ -323,7 +323,7 @@ internal static byte[] HexToBytes(string hex) return bytes; } - internal static string csharpValToSfVal(SFDataType sfDataType, object srcVal) + internal static string CSharpValToSfVal(SFDataType sfDataType, object srcVal) { string destVal = null; @@ -331,18 +331,6 @@ internal static string csharpValToSfVal(SFDataType sfDataType, object srcVal) { switch (sfDataType) { - case SFDataType.TIMESTAMP_LTZ: - if (srcVal.GetType() != typeof(DateTimeOffset)) - { - throw new SnowflakeDbException(SFError.INVALID_DATA_CONVERSION, srcVal, - srcVal.GetType().ToString(), SFDataType.TIMESTAMP_LTZ.ToString()); - } - else - { - destVal = ((long)(((DateTimeOffset)srcVal).UtcTicks - UnixEpoch.Ticks) * 100).ToString(); - } - break; - case SFDataType.FIXED: case SFDataType.BOOLEAN: case SFDataType.REAL: @@ -359,9 +347,8 @@ internal static string csharpValToSfVal(SFDataType sfDataType, object srcVal) else { DateTime srcDt = ((DateTime)srcVal); - long nanoSinceMidNight = (long)(srcDt.Ticks - srcDt.Date.Ticks) * 100L; - - destVal = nanoSinceMidNight.ToString(); + var tickDiff = srcDt.Ticks - srcDt.Date.Ticks; + destVal = TicksToNanoSecondsString(tickDiff); } break; @@ -380,6 +367,19 @@ internal static string csharpValToSfVal(SFDataType sfDataType, object srcVal) } break; + case SFDataType.TIMESTAMP_LTZ: + if (srcVal.GetType() != typeof(DateTimeOffset)) + { + throw new SnowflakeDbException(SFError.INVALID_DATA_CONVERSION, srcVal, + srcVal.GetType().ToString(), SFDataType.TIMESTAMP_LTZ.ToString()); + } + else + { + var tickDiff = ((DateTimeOffset)srcVal).UtcTicks - UnixEpoch.Ticks; + destVal = TicksToNanoSecondsString(tickDiff); + } + break; + case SFDataType.TIMESTAMP_NTZ: if (srcVal.GetType() != 
typeof(DateTime)) { @@ -391,7 +391,7 @@ internal static string csharpValToSfVal(SFDataType sfDataType, object srcVal) DateTime srcDt = (DateTime)srcVal; var diff = srcDt.Subtract(UnixEpoch); var tickDiff = diff.Ticks; - destVal = $"{tickDiff}00"; // Cannot multiple tickDiff by 100 because long might overflow. + destVal = TicksToNanoSecondsString(tickDiff); } break; @@ -404,8 +404,8 @@ internal static string csharpValToSfVal(SFDataType sfDataType, object srcVal) else { DateTimeOffset dtOffset = (DateTimeOffset)srcVal; - destVal = String.Format("{0} {1}", (dtOffset.UtcTicks - UnixEpoch.Ticks) * 100L, - dtOffset.Offset.TotalMinutes + 1440); + var tickDiff = dtOffset.UtcTicks - UnixEpoch.Ticks; + destVal = $"{TicksToNanoSecondsString(tickDiff)} {dtOffset.Offset.TotalMinutes + 1440}"; } break; @@ -429,7 +429,9 @@ internal static string csharpValToSfVal(SFDataType sfDataType, object srcVal) return destVal; } - internal static string toDateString(DateTime date, string formatter) + private static string TicksToNanoSecondsString(long tickDiff) => tickDiff == 0 ? "0" : $"{tickDiff}00"; + + internal static string ToDateString(DateTime date, string formatter) { // change formatter from "YYYY-MM-DD" to "yyyy-MM-dd" formatter = formatter.Replace("Y", "y").Replace("m", "M").Replace("D", "d"); diff --git a/Snowflake.Data/Core/SFResultSet.cs b/Snowflake.Data/Core/SFResultSet.cs index a7586f2c3..e81db8c14 100755 --- a/Snowflake.Data/Core/SFResultSet.cs +++ b/Snowflake.Data/Core/SFResultSet.cs @@ -283,7 +283,7 @@ internal override string GetString(int ordinal) var val = GetValue(ordinal); if (val == DBNull.Value) return null; - return SFDataConverter.toDateString((DateTime)val, sfResultSetMetaData.dateOutputFormat); + return SFDataConverter.ToDateString((DateTime)val, sfResultSetMetaData.dateOutputFormat); default: return GetObjectInternal(ordinal).SafeToString(); From 9ae511963e0073acf873d1f5564f8571c875da38 Mon Sep 17 00:00:00 2001 From: Waleed Fateem <72769898+sfc-gh-wfateem@users.noreply.github.com> Date: Tue, 22 Oct 2024 10:51:29 -0500 Subject: [PATCH 08/20] SNOW-1675321 Remove Account Identifier Question From Templates (#1032) --- .github/ISSUE_TEMPLATE/BUG_REPORT.md | 4 +++- .github/ISSUE_TEMPLATE/FEATURE_REQUEST.md | 2 -- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/BUG_REPORT.md b/.github/ISSUE_TEMPLATE/BUG_REPORT.md index 0264ced87..503d6a776 100644 --- a/.github/ISSUE_TEMPLATE/BUG_REPORT.md +++ b/.github/ISSUE_TEMPLATE/BUG_REPORT.md @@ -38,4 +38,6 @@ In order to accurately debug the issue this information is required. Thanks! https://community.snowflake.com/s/article/How-to-generate-log-file-on-Snowflake-connectors There is an example in READMD.md file showing you how to enable logging. -7. What is your Snowflake account identifier, if any? (Optional) + + Before sharing any information, please be sure to review the log and remove any sensitive + information. diff --git a/.github/ISSUE_TEMPLATE/FEATURE_REQUEST.md b/.github/ISSUE_TEMPLATE/FEATURE_REQUEST.md index 4cbd59985..147279709 100644 --- a/.github/ISSUE_TEMPLATE/FEATURE_REQUEST.md +++ b/.github/ISSUE_TEMPLATE/FEATURE_REQUEST.md @@ -18,5 +18,3 @@ otherwise continue here. ## References, Other Background -## What is your Snowflake account identifier, if any? 
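The nanosecond handling above avoids Int64 overflow in two complementary ways: CSharpValToSfVal emits nanoseconds by appending "00" to the tick count through TicksToNanoSecondsString (one .NET tick is 100 ns), and GetCSVData in SFBindUploader falls back to decimal parsing whenever the resulting digit string no longer fits in a long. A minimal sketch of that arithmetic (illustrative, not part of the patch):

    // Appending "00" converts ticks to nanoseconds without a multiply that could overflow.
    long ticks = DateTime.MaxValue.Ticks;            // 3155378975999999999
    string nanos = ticks == 0 ? "0" : $"{ticks}00";  // "315537897599999999900", larger than long.MaxValue
    long ticksBack = long.TryParse(nanos, out var ns)
        ? ns / 100                                   // small values still round-trip through long
        : (long)(decimal.Parse(nanos) / 100);        // large values need the decimal fallback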
- From 898a2f05337ec626986a80c114343b6437670f7b Mon Sep 17 00:00:00 2001 From: Dariusz Stempniak Date: Thu, 24 Oct 2024 14:53:50 +0200 Subject: [PATCH 09/20] SNOW-1739483 Fix S3 exception casting and flaky tests (#1042) --- .../IntegrationTests/SFConnectionIT.cs | 8 +- Snowflake.Data.Tests/Mock/MockS3Client.cs | 21 ++-- .../UnitTests/SFS3ClientTest.cs | 54 +++------ .../FileTransfer/StorageClient/SFS3Client.cs | 104 +++++++++++------- Snowflake.Data/Core/HttpUtil.cs | 29 +++-- 5 files changed, 117 insertions(+), 99 deletions(-) diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index 6f3c87291..554d0c2a9 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -514,8 +514,8 @@ public void TestDefaultLoginTimeout() // Should timeout after the default timeout (300 sec) Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, conn.ConnectionTimeout * 1000 - delta); - // But never more because there's no connection timeout remaining - Assert.LessOrEqual(stopwatch.ElapsedMilliseconds, (conn.ConnectionTimeout + 1) * 1000); + // But never more because there's no connection timeout remaining (with 2 seconds margin) + Assert.LessOrEqual(stopwatch.ElapsedMilliseconds, (conn.ConnectionTimeout + 2) * 1000); } } } @@ -2015,8 +2015,8 @@ public void TestAsyncDefaultLoginTimeout() // Should timeout after the default timeout (300 sec) Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, conn.ConnectionTimeout * 1000 - delta); - // But never more because there's no connection timeout remaining - Assert.LessOrEqual(stopwatch.ElapsedMilliseconds, (conn.ConnectionTimeout + 1) * 1000); + // But never more because there's no connection timeout remaining (with 2 seconds margin) + Assert.LessOrEqual(stopwatch.ElapsedMilliseconds, (conn.ConnectionTimeout + 2) * 1000); Assert.AreEqual(ConnectionState.Closed, conn.State); Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout.TotalSeconds, conn.ConnectionTimeout); diff --git a/Snowflake.Data.Tests/Mock/MockS3Client.cs b/Snowflake.Data.Tests/Mock/MockS3Client.cs index 02d08ca63..8a17efd30 100644 --- a/Snowflake.Data.Tests/Mock/MockS3Client.cs +++ b/Snowflake.Data.Tests/Mock/MockS3Client.cs @@ -35,22 +35,23 @@ class MockS3Client internal const int ContentLength = 9999; // Create AWS exception for mock requests - static Exception CreateMockAwsResponseError(string errorCode, bool isAsync) + static Exception CreateMockAwsResponseError(string awsErrorCode, bool isAsync) { - AmazonS3Exception awsError = new AmazonS3Exception(S3ErrorMessage); - awsError.ErrorCode = errorCode; + Exception exception = awsErrorCode.Length > 0 + ? 
new AmazonS3Exception(S3ErrorMessage) { ErrorCode = awsErrorCode } : new Exception("Non-AWS exception"); if (isAsync) { return exception; // S3 throws the AmazonS3Exception on async calls } Exception exceptionContainingS3Error = new Exception(S3ErrorMessage, exception); return exceptionContainingS3Error; // S3 places the AmazonS3Exception on the InnerException property on non-async calls } // Create mock response for GetFileHeader - static internal Task CreateResponseForGetFileHeader(string statusCode, bool isAsync) + internal static Task CreateResponseForGetFileHeader(string statusCode, bool isAsync) { if (statusCode == HttpStatusCode.OK.ToString()) { @@ -70,20 +71,20 @@ static internal Task CreateResponseForGetFileHeader(string st } // Create mock response for UploadFile - static internal Task CreateResponseForUploadFile(string statusCode, bool isAsync) + internal static Task CreateResponseForUploadFile(string awsStatusCode, bool isAsync) { - if (statusCode == HttpStatusCode.OK.ToString()) + if (awsStatusCode == AwsStatusOk) { return Task.FromResult(new PutObjectResponse()); } else { - throw CreateMockAwsResponseError(statusCode, isAsync); + throw CreateMockAwsResponseError(awsStatusCode, isAsync); } } // Create mock response for DownloadFile - static internal Task CreateResponseForDownloadFile(string statusCode, bool isAsync) + internal static Task CreateResponseForDownloadFile(string statusCode, bool isAsync) { if (statusCode == HttpStatusCode.OK.ToString()) { diff --git a/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs index 50faae758..54647db8b 100644 --- a/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs @@ -127,18 +127,15 @@ public void TestExtractBucketNameAndPath() [TestCase(SFS3Client.EXPIRED_TOKEN, ResultStatus.RENEW_TOKEN)] [TestCase(SFS3Client.NO_SUCH_KEY, ResultStatus.NOT_FOUND_FILE)] [TestCase(MockS3Client.AwsStatusError, ResultStatus.ERROR)] // Any error that isn't the above will return ResultStatus.ERROR - public void TestGetFileHeader(string requestKey, ResultStatus expectedResultStatus) + [TestCase("", ResultStatus.ERROR)] // A non-AWS exception will return ResultStatus.ERROR + public void TestGetFileHeader(string awsStatusCode, ResultStatus expectedResultStatus) { // Arrange var mockAmazonS3Client = new Mock(AwsKeyId, AwsSecretKey, AwsToken, _clientConfig); mockAmazonS3Client.Setup(client => client.GetObjectAsync(It.IsAny(), It.IsAny())) - .Returns((request, cancellationToken) => - { - return MockS3Client.CreateResponseForGetFileHeader(request.BucketName, false); - }); + .Returns(() => MockS3Client.CreateResponseForGetFileHeader(awsStatusCode, false)); _client = new SFS3Client(_fileMetadata.stageInfo, MaxRetry, Parallel, _proxyCredentials, mockAmazonS3Client.Object); _fileMetadata.client = _client; - _fileMetadata.stageInfo.location = requestKey; // Act FileHeader fileHeader = _client.GetFileHeader(_fileMetadata); @@ -152,18 +149,15 @@ public void TestGetFileHeader(string requestKey, ResultStatus expectedResultStat [TestCase(SFS3Client.EXPIRED_TOKEN, ResultStatus.RENEW_TOKEN)] [TestCase(SFS3Client.NO_SUCH_KEY, ResultStatus.NOT_FOUND_FILE)] [TestCase(MockS3Client.AwsStatusError, ResultStatus.ERROR)] // Any error that isn't the above will return ResultStatus.ERROR - public async Task 
TestGetFileHeaderAsync(string requestKey, ResultStatus expectedResultStatus) + [TestCase("", ResultStatus.ERROR)] // A non-AWS exception will return ResultStatus.ERROR + public async Task TestGetFileHeaderAsync(string awsStatusCode, ResultStatus expectedResultStatus) { // Arrange var mockAmazonS3Client = new Mock(AwsKeyId, AwsSecretKey, AwsToken, _clientConfig); mockAmazonS3Client.Setup(client => client.GetObjectAsync(It.IsAny(), It.IsAny())) - .Returns((request, cancellationToken) => - { - return MockS3Client.CreateResponseForGetFileHeader(request.BucketName, true); - }); + .Returns(() => MockS3Client.CreateResponseForGetFileHeader(awsStatusCode, true)); _client = new SFS3Client(_fileMetadata.stageInfo, MaxRetry, Parallel, _proxyCredentials, mockAmazonS3Client.Object); _fileMetadata.client = _client; - _fileMetadata.stageInfo.location = requestKey; // Act FileHeader fileHeader = await _client.GetFileHeaderAsync(_fileMetadata, _cancellationToken).ConfigureAwait(false); @@ -194,18 +188,15 @@ private void AssertForGetFileHeaderTests(ResultStatus expectedResultStatus, File [TestCase(MockS3Client.AwsStatusOk, ResultStatus.UPLOADED)] [TestCase(SFS3Client.EXPIRED_TOKEN, ResultStatus.RENEW_TOKEN)] [TestCase(MockS3Client.AwsStatusError, ResultStatus.NEED_RETRY)] // Any error that isn't the above will return ResultStatus.NEED_RETRY - public void TestUploadFile(string requestKey, ResultStatus expectedResultStatus) + [TestCase("", ResultStatus.NEED_RETRY)] // A non-AWS exception will return ResultStatus.NEED_RETRY + public void TestUploadFile(string awsStatusCode, ResultStatus expectedResultStatus) { // Arrange var mockAmazonS3Client = new Mock(AwsKeyId, AwsSecretKey, AwsToken, _clientConfig); mockAmazonS3Client.Setup(client => client.PutObjectAsync(It.IsAny(), It.IsAny())) - .Returns((request, cancellationToken) => - { - return MockS3Client.CreateResponseForUploadFile(request.BucketName, false); - }); + .Returns(() => MockS3Client.CreateResponseForUploadFile(awsStatusCode, false)); _client = new SFS3Client(_fileMetadata.stageInfo, MaxRetry, Parallel, _proxyCredentials, mockAmazonS3Client.Object); _fileMetadata.client = _client; - _fileMetadata.stageInfo.location = requestKey; _fileMetadata.uploadSize = UploadFileSize; // Act @@ -254,18 +245,15 @@ public void TestAppendHttpsToEndpointWithBrackets() [TestCase(MockS3Client.AwsStatusOk, ResultStatus.UPLOADED)] [TestCase(SFS3Client.EXPIRED_TOKEN, ResultStatus.RENEW_TOKEN)] [TestCase(MockS3Client.AwsStatusError, ResultStatus.NEED_RETRY)] // Any error that isn't the above will return ResultStatus.NEED_RETRY - public async Task TestUploadFileAsync(string requestKey, ResultStatus expectedResultStatus) + [TestCase("", ResultStatus.NEED_RETRY)] // A non-AWS exception will return ResultStatus.NEED_RETRY + public async Task TestUploadFileAsync(string awsStatusCode, ResultStatus expectedResultStatus) { // Arrange var mockAmazonS3Client = new Mock(AwsKeyId, AwsSecretKey, AwsToken, _clientConfig); mockAmazonS3Client.Setup(client => client.PutObjectAsync(It.IsAny(), It.IsAny())) - .Returns((request, cancellationToken) => - { - return MockS3Client.CreateResponseForUploadFile(request.BucketName, true); - }); + .Returns(() => MockS3Client.CreateResponseForUploadFile(awsStatusCode, true)); _client = new SFS3Client(_fileMetadata.stageInfo, MaxRetry, Parallel, _proxyCredentials, mockAmazonS3Client.Object); _fileMetadata.client = _client; - _fileMetadata.stageInfo.location = requestKey; _fileMetadata.uploadSize = UploadFileSize; // Act @@ -295,18 +283,15 @@ private 
void AssertForUploadFileTests(ResultStatus expectedResultStatus) [TestCase(MockS3Client.AwsStatusOk, ResultStatus.DOWNLOADED)] [TestCase(SFS3Client.EXPIRED_TOKEN, ResultStatus.RENEW_TOKEN)] [TestCase(MockS3Client.AwsStatusError, ResultStatus.NEED_RETRY)] // Any error that isn't the above will return ResultStatus.NEED_RETRY - public void TestDownloadFile(string requestKey, ResultStatus expectedResultStatus) + [TestCase("", ResultStatus.NEED_RETRY)] // A non-AWS exception will return ResultStatus.NEED_RETRY + public void TestDownloadFile(string awsStatusCode, ResultStatus expectedResultStatus) { // Arrange var mockAmazonS3Client = new Mock(AwsKeyId, AwsSecretKey, AwsToken, _clientConfig); mockAmazonS3Client.Setup(client => client.GetObjectAsync(It.IsAny(), It.IsAny())) - .Returns((request, cancellationToken) => - { - return MockS3Client.CreateResponseForDownloadFile(request.BucketName, false); - }); + .Returns(() => MockS3Client.CreateResponseForDownloadFile(awsStatusCode, false)); _client = new SFS3Client(_fileMetadata.stageInfo, MaxRetry, Parallel, _proxyCredentials, mockAmazonS3Client.Object); _fileMetadata.client = _client; - _fileMetadata.stageInfo.location = requestKey; // Act _client.DownloadFile(_fileMetadata, t_downloadFileName, Parallel); @@ -319,18 +304,15 @@ public void TestDownloadFile(string requestKey, ResultStatus expectedResultStatu [TestCase(MockS3Client.AwsStatusOk, ResultStatus.DOWNLOADED)] [TestCase(SFS3Client.EXPIRED_TOKEN, ResultStatus.RENEW_TOKEN)] [TestCase(MockS3Client.AwsStatusError, ResultStatus.NEED_RETRY)] // Any error that isn't the above will return ResultStatus.NEED_RETRY - public async Task TestDownloadFileAsync(string requestKey, ResultStatus expectedResultStatus) + [TestCase("", ResultStatus.NEED_RETRY)] // A non-AWS exception will return ResultStatus.NEED_RETRY + public async Task TestDownloadFileAsync(string awsStatusCode, ResultStatus expectedResultStatus) { // Arrange var mockAmazonS3Client = new Mock(AwsKeyId, AwsSecretKey, AwsToken, _clientConfig); mockAmazonS3Client.Setup(client => client.GetObjectAsync(It.IsAny(), It.IsAny())) - .Returns((request, cancellationToken) => - { - return MockS3Client.CreateResponseForDownloadFile(request.BucketName, true); - }); + .Returns(() => MockS3Client.CreateResponseForDownloadFile(awsStatusCode, true)); _client = new SFS3Client(_fileMetadata.stageInfo, MaxRetry, Parallel, _proxyCredentials, mockAmazonS3Client.Object); _fileMetadata.client = _client; - _fileMetadata.stageInfo.location = requestKey; // Act await _client.DownloadFileAsync(_fileMetadata, t_downloadFileName, Parallel, _cancellationToken).ConfigureAwait(false); diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs index 60d67b5d7..b6896cc79 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs @@ -206,7 +206,7 @@ public FileHeader GetFileHeader(SFFileMetadata fileMetadata) } catch (Exception ex) { - fileMetadata = HandleFileHeaderErr(ex.InnerException, fileMetadata); // S3 places the AmazonS3Exception on the InnerException on non-async calls + HandleFileHeaderErr(ex.InnerException, fileMetadata); // S3 places the AmazonS3Exception on the InnerException on non-async calls return null; } } @@ -233,7 +233,7 @@ public async Task GetFileHeaderAsync(SFFileMetadata fileMetadata, Ca } catch (Exception ex) { - fileMetadata = HandleFileHeaderErr(ex, fileMetadata); // S3 throws the 
AmazonS3Exception on async calls + HandleFileHeaderErr(ex, fileMetadata); // S3 throws the AmazonS3Exception on async calls return null; } @@ -363,7 +363,7 @@ public void UploadFile(SFFileMetadata fileMetadata, Stream fileBytesStream, SFEn } catch (Exception ex) { - fileMetadata = HandleUploadFileErr(ex.InnerException, fileMetadata); + HandleUploadFileErr(ex.InnerException, fileMetadata); return; } @@ -391,7 +391,7 @@ public async Task UploadFileAsync(SFFileMetadata fileMetadata, Stream fileBytesS } catch (Exception ex) { - fileMetadata = HandleUploadFileErr(ex, fileMetadata); + HandleUploadFileErr(ex, fileMetadata); return; } @@ -461,7 +461,7 @@ public void DownloadFile(SFFileMetadata fileMetadata, string fullDstPath, int ma } catch (Exception ex) { - fileMetadata = HandleDownloadFileErr(ex.InnerException, fileMetadata); + HandleDownloadFileErr(ex.InnerException, fileMetadata); return; } @@ -494,7 +494,7 @@ public async Task DownloadFileAsync(SFFileMetadata fileMetadata, string fullDstP } catch (Exception ex) { - fileMetadata = HandleDownloadFileErr(ex, fileMetadata); + HandleDownloadFileErr(ex, fileMetadata); return; } @@ -519,25 +519,31 @@ private GetObjectRequest GetGetObjectRequest(ref AmazonS3Client client, SFFileMe /// /// Exception from file header. /// The file metadata. - /// The file metadata. - private SFFileMetadata HandleFileHeaderErr(Exception ex, SFFileMetadata fileMetadata) + private void HandleFileHeaderErr(Exception ex, SFFileMetadata fileMetadata) { Logger.Error("Failed to get file header: " + ex.Message); - AmazonS3Exception err = (AmazonS3Exception)ex; - if (err.ErrorCode == EXPIRED_TOKEN || err.ErrorCode == HttpStatusCode.BadRequest.ToString()) + switch (ex) { - fileMetadata.resultStatus = ResultStatus.RENEW_TOKEN.ToString(); - } - else if (err.ErrorCode == NO_SUCH_KEY) - { - fileMetadata.resultStatus = ResultStatus.NOT_FOUND_FILE.ToString(); - } - else - { - fileMetadata.resultStatus = ResultStatus.ERROR.ToString(); + case AmazonS3Exception exAws: + if (exAws.ErrorCode == EXPIRED_TOKEN || exAws.ErrorCode == HttpStatusCode.BadRequest.ToString()) + { + fileMetadata.resultStatus = ResultStatus.RENEW_TOKEN.ToString(); + } + else if (exAws.ErrorCode == NO_SUCH_KEY) + { + fileMetadata.resultStatus = ResultStatus.NOT_FOUND_FILE.ToString(); + } + else + { + fileMetadata.resultStatus = ResultStatus.ERROR.ToString(); + } + + break; + default: + fileMetadata.resultStatus = ResultStatus.ERROR.ToString(); + break; } - return fileMetadata; } /// @@ -545,22 +551,29 @@ private SFFileMetadata HandleFileHeaderErr(Exception ex, SFFileMetadata fileMeta /// /// Exception from file header. /// The file metadata. - /// The file metadata. 
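The handlers in this refactor stop hard-casting the caught exception to AmazonS3Exception; the old cast was itself a crash point, because a non-AWS exception made (AmazonS3Exception)ex throw InvalidCastException before any result status could be set. The replacement is a C# type-pattern switch. The same idiom in isolation, as a sketch assuming the AWSSDK.S3 package rather than the connector's own code:

    // Pattern matching classifies the exception by its runtime type; no cast can fail here.
    static string Classify(Exception ex)
    {
        switch (ex)
        {
            case AmazonS3Exception awsEx: // AWS-specific details are safe to read in this arm
                return $"AWS error {awsEx.ErrorCode}";
            default:                      // any other exception type, e.g. a plain Exception
                return "non-AWS error";
        }
    }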
- private SFFileMetadata HandleUploadFileErr(Exception ex, SFFileMetadata fileMetadata) + private void HandleUploadFileErr(Exception ex, SFFileMetadata fileMetadata) { Logger.Error("Failed to upload file: " + ex.Message); - AmazonS3Exception err = (AmazonS3Exception)ex; - if (err.ErrorCode == EXPIRED_TOKEN) - { - fileMetadata.resultStatus = ResultStatus.RENEW_TOKEN.ToString(); - } - else + switch (ex) { - fileMetadata.lastError = err; - fileMetadata.resultStatus = ResultStatus.NEED_RETRY.ToString(); + case AmazonS3Exception exAws: + if (exAws.ErrorCode == EXPIRED_TOKEN) + { + fileMetadata.resultStatus = ResultStatus.RENEW_TOKEN.ToString(); + } + else + { + fileMetadata.lastError = exAws; + fileMetadata.resultStatus = ResultStatus.NEED_RETRY.ToString(); + } + break; + + case Exception exOther: + fileMetadata.lastError = exOther; + fileMetadata.resultStatus = ResultStatus.NEED_RETRY.ToString(); + break; } - return fileMetadata; } /// @@ -568,22 +581,29 @@ private SFFileMetadata HandleUploadFileErr(Exception ex, SFFileMetadata fileMeta /// /// Exception from file header. /// The file metadata. - /// The file metadata. - private SFFileMetadata HandleDownloadFileErr(Exception ex, SFFileMetadata fileMetadata) + private void HandleDownloadFileErr(Exception ex, SFFileMetadata fileMetadata) { Logger.Error("Failed to download file: " + ex.Message); - AmazonS3Exception err = (AmazonS3Exception)ex; - if (err.ErrorCode == EXPIRED_TOKEN) - { - fileMetadata.resultStatus = ResultStatus.RENEW_TOKEN.ToString(); - } - else + switch (ex) { - fileMetadata.lastError = err; - fileMetadata.resultStatus = ResultStatus.NEED_RETRY.ToString(); + case AmazonS3Exception exAws: + if (exAws.ErrorCode == EXPIRED_TOKEN) + { + fileMetadata.resultStatus = ResultStatus.RENEW_TOKEN.ToString(); + } + else + { + fileMetadata.lastError = exAws; + fileMetadata.resultStatus = ResultStatus.NEED_RETRY.ToString(); + } + break; + + case Exception exOther: + fileMetadata.lastError = exOther; + fileMetadata.resultStatus = ResultStatus.NEED_RETRY.ToString(); + break; } - return fileMetadata; } } } diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index 72d18bcdd..3e779e34a 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -378,9 +378,9 @@ protected override async Task SendAsync(HttpRequestMessage UriUpdater updater = new UriUpdater(requestMessage.RequestUri, includeRetryReason); int retryCount = 0; + long startTimeInMilliseconds = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); while (true) { - try { childCts = null; @@ -401,13 +401,12 @@ protected override async Task SendAsync(HttpRequestMessage lastException = e; if (cancellationToken.IsCancellationRequested) { - logger.Debug("SF rest request timeout or explicit cancel called."); + logger.Info("SF rest request timeout or explicit cancel called."); cancellationToken.ThrowIfCancellationRequested(); } else if (childCts != null && childCts.Token.IsCancellationRequested) { - logger.Warn("Http request timeout. Retry the request"); - totalRetryTime += (int)httpTimeout.TotalSeconds; + logger.Warn($"Http request timeout. 
Retry the request after {backOffInSec} sec."); } else { @@ -426,6 +425,8 @@ protected override async Task SendAsync(HttpRequestMessage } } + totalRetryTime = (int)((DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() - startTimeInMilliseconds) / 1000); + if (childCts != null) { childCts.Dispose(); @@ -464,6 +465,19 @@ protected override async Task SendAsync(HttpRequestMessage logger.Info("Response returned was null."); } + if (restTimeout.TotalSeconds > 0 && totalRetryTime > restTimeout.TotalSeconds) + { + logger.Debug($"stop retry as connection_timeout {restTimeout.TotalSeconds} sec. reached"); + if (response != null) + { + return response; + } + var errorMessage = $"http request failed and connection_timeout {restTimeout.TotalSeconds} sec. reached.\n"; + errorMessage += $"Last exception encountered: {lastException}"; + logger.Error(errorMessage); + throw new OperationCanceledException(errorMessage); + } + retryCount++; if ((maxRetryCount > 0) && (retryCount > maxRetryCount)) { @@ -486,7 +500,6 @@ protected override async Task SendAsync(HttpRequestMessage logger.Debug($"Sleep {backOffInSec} seconds and then retry the request, retryCount: {retryCount}"); await Task.Delay(TimeSpan.FromSeconds(backOffInSec), cancellationToken).ConfigureAwait(false); - totalRetryTime += backOffInSec; var jitter = GetJitter(backOffInSec); @@ -504,12 +517,14 @@ protected override async Task SendAsync(HttpRequestMessage backOffInSec *= 2; } + totalRetryTime = (int)((DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() - startTimeInMilliseconds) / 1000); if ((restTimeout.TotalSeconds > 0) && (totalRetryTime + backOffInSec > restTimeout.TotalSeconds)) { // No need to wait more than necessary if it can be avoided. // If the rest timeout will be reached before the next back-off, - // then use the remaining connection timeout - backOffInSec = Math.Min(backOffInSec, (int)restTimeout.TotalSeconds - totalRetryTime); + // then use the remaining connection timeout. + // Math.Max with 0 in case totalRetryTime > restTimeout.TotalSeconds + backOffInSec = Math.Max(Math.Min(backOffInSec, (int)restTimeout.TotalSeconds - totalRetryTime), 0); } } } From f995ba6449efe12f6fe30e140ae1ee0545f36f22 Mon Sep 17 00:00:00 2001 From: Dariusz Stempniak Date: Fri, 25 Oct 2024 11:42:17 +0200 Subject: [PATCH 10/20] SNOW-1739583 Fix of rounding issue (#1045) --- Snowflake.Data/Core/HttpUtil.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index 3e779e34a..f835b7eb5 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -523,8 +523,7 @@ protected override async Task SendAsync(HttpRequestMessage // No need to wait more than necessary if it can be avoided. // If the rest timeout will be reached before the next back-off, // then use the remaining connection timeout. 
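Taken together, the retry changes above switch the loop from bookkeeping to wall-clock measurement: totalRetryTime is now derived from a start timestamp, so time spent inside the HTTP calls counts toward the REST timeout rather than only the planned sleeps, and since that elapsed value is truncated to whole seconds, the follow-up fix below caps the final back-off at the remaining time plus one second instead of clamping it at zero. A condensed sketch of the timing logic (illustrative; restTimeoutSec stands in for restTimeout.TotalSeconds and jitter is omitted):

    static async Task<HttpResponseMessage> SendWithRetry(Func<Task<HttpResponseMessage>> send, int restTimeoutSec)
    {
        long startMs = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
        int backOffInSec = 1;
        while (true)
        {
            try { return await send(); }
            catch (Exception) { /* fall through and retry */ }
            int totalRetryTime = (int)((DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() - startMs) / 1000);
            if (restTimeoutSec > 0 && totalRetryTime > restTimeoutSec)
                throw new OperationCanceledException("connection_timeout reached"); // give up retrying
            if (restTimeoutSec > 0 && totalRetryTime + backOffInSec > restTimeoutSec)
                backOffInSec = Math.Min(backOffInSec, restTimeoutSec - totalRetryTime + 1); // +1 absorbs truncation
            await Task.Delay(TimeSpan.FromSeconds(backOffInSec));
            backOffInSec *= 2; // exponential back-off
        }
    }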
- // Math.Max with 0 in case totalRetryTime > restTimeout.TotalSeconds - backOffInSec = Math.Max(Math.Min(backOffInSec, (int)restTimeout.TotalSeconds - totalRetryTime), 0); + backOffInSec = Math.Min(backOffInSec, (int)restTimeout.TotalSeconds - totalRetryTime + 1); } } } From 6b1114ac3ca67915894afcae7f34390e6af47ee2 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Fri, 25 Oct 2024 13:46:45 +0200 Subject: [PATCH 11/20] SNOW-1524245 Gcm encryption using BouncyCastle (#1043) --- .../UnitTests/GcmEncryptionProviderTest.cs | 318 ++++++++++++++++++ .../Util/TestDataGenarator.cs | 23 +- .../Core/FileTransfer/EncryptionProvider.cs | 13 +- .../FileTransfer/GcmEncryptionProvider.cs | 188 +++++++++++ .../Core/FileTransfer/MaterialDescriptor.cs | 11 + .../Core/FileTransfer/SFFileMetadata.cs | 21 +- 6 files changed, 547 insertions(+), 27 deletions(-) create mode 100644 Snowflake.Data.Tests/UnitTests/GcmEncryptionProviderTest.cs create mode 100644 Snowflake.Data/Core/FileTransfer/GcmEncryptionProvider.cs create mode 100644 Snowflake.Data/Core/FileTransfer/MaterialDescriptor.cs diff --git a/Snowflake.Data.Tests/UnitTests/GcmEncryptionProviderTest.cs b/Snowflake.Data.Tests/UnitTests/GcmEncryptionProviderTest.cs new file mode 100644 index 000000000..60c0c2059 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/GcmEncryptionProviderTest.cs @@ -0,0 +1,318 @@ +using System; +using System.IO; +using System.Text; +using NUnit.Framework; +using Org.BouncyCastle.Crypto; +using Snowflake.Data.Core; +using Snowflake.Data.Core.FileTransfer; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture] + public class GcmEncryptionProviderTest + { + private const string PlainText = "there is no rose without thorns"; + private static readonly byte[] s_plainTextBytes = Encoding.UTF8.GetBytes(PlainText); + private static readonly byte[] s_qsmkBytes = TestDataGenarator.NextBytes(GcmEncryptionProvider.BlockSizeInBytes); + private static readonly string s_qsmk = Convert.ToBase64String(s_qsmkBytes); + private static readonly string s_queryId = Guid.NewGuid().ToString(); + private const long SmkId = 1234L; + private const string KeyAad = "key additional information"; + private static readonly byte[] s_keyAadBytes = Encoding.UTF8.GetBytes(KeyAad); + private static readonly string s_keyAadBase64 = Convert.ToBase64String(s_keyAadBytes); + private const string ContentAad = "content additional information"; + private static readonly byte[] s_contentAadBytes = Encoding.UTF8.GetBytes(ContentAad); + private static readonly string s_contentAadBase64 = Convert.ToBase64String(s_contentAadBytes); + private const string InvalidAad = "invalid additional information"; + private static readonly byte[] s_invalidAadBytes = Encoding.UTF8.GetBytes(InvalidAad); + private static readonly string s_invalidAadBase64 = Convert.ToBase64String(s_invalidAadBytes); + private static readonly string s_emptyAad = string.Empty; + private static readonly byte[] s_emptyAadBytes = Encoding.UTF8.GetBytes(s_emptyAad); + private static readonly string s_emptyAadBase64 = Convert.ToBase64String(s_emptyAadBytes); + private static readonly PutGetEncryptionMaterial s_encryptionMaterial = new PutGetEncryptionMaterial + { + queryStageMasterKey = s_qsmk, + queryId = s_queryId, + smkId = SmkId + }; + private static readonly FileTransferConfiguration s_fileTransferConfiguration = new FileTransferConfiguration + { + TempDir = Path.GetTempPath(), + MaxBytesInMemory = FileTransferConfiguration.DefaultMaxBytesInMemory + }; + + [Test] + 
public void TestEncryptAndDecryptWithoutAad() + { + // arrange + SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); + + // act + using (var encryptedStream = GcmEncryptionProvider.Encrypt( + s_encryptionMaterial, + encryptionMetadata, // this is output parameter + s_fileTransferConfiguration, + new MemoryStream(s_plainTextBytes), + null, + null)) + { + var encryptedContent = ExtractContentBytes(encryptedStream); + + // assert + Assert.NotNull(encryptionMetadata.key); + Assert.NotNull(encryptionMetadata.iv); + Assert.NotNull(encryptionMetadata.matDesc); + Assert.IsNull(encryptionMetadata.keyAad); + Assert.IsNull(encryptionMetadata.aad); + + // act + using (var decryptedStream = GcmEncryptionProvider.Decrypt(new MemoryStream(encryptedContent), s_encryptionMaterial, encryptionMetadata, s_fileTransferConfiguration)) + { + // assert + var decryptedText = ExtractContent(decryptedStream); + CollectionAssert.AreEqual(s_plainTextBytes, decryptedText); + } + } + } + + [Test] + public void TestEncryptAndDecryptWithEmptyAad() + { + // arrange + SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); + + // act + using (var encryptedStream = GcmEncryptionProvider.Encrypt( + s_encryptionMaterial, + encryptionMetadata, // this is output parameter + s_fileTransferConfiguration, + new MemoryStream(s_plainTextBytes), + s_emptyAadBytes, + s_emptyAadBytes)) + { + var encryptedContent = ExtractContentBytes(encryptedStream); + + // assert + Assert.NotNull(encryptionMetadata.key); + Assert.NotNull(encryptionMetadata.iv); + Assert.NotNull(encryptionMetadata.matDesc); + Assert.AreEqual(s_emptyAadBase64, encryptionMetadata.keyAad); + Assert.AreEqual(s_emptyAadBase64, encryptionMetadata.aad); + + // act + using (var decryptedStream = GcmEncryptionProvider.Decrypt(new MemoryStream(encryptedContent), s_encryptionMaterial, encryptionMetadata, s_fileTransferConfiguration)) + { + // assert + var decryptedText = ExtractContent(decryptedStream); + CollectionAssert.AreEqual(s_plainTextBytes, decryptedText); + } + } + } + + [Test] + public void TestEncryptAndDecryptWithAad() + { + // arrange + SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); + + // act + using (var encryptedStream = GcmEncryptionProvider.Encrypt( + s_encryptionMaterial, + encryptionMetadata, // this is output parameter + s_fileTransferConfiguration, + new MemoryStream(s_plainTextBytes), + s_contentAadBytes, + s_keyAadBytes)) + { + var encryptedContent = ExtractContentBytes(encryptedStream); + + // assert + Assert.NotNull(encryptionMetadata.key); + Assert.NotNull(encryptionMetadata.iv); + Assert.NotNull(encryptionMetadata.matDesc); + CollectionAssert.AreEqual(s_keyAadBase64, encryptionMetadata.keyAad); + CollectionAssert.AreEqual(s_contentAadBase64, encryptionMetadata.aad); + + // act + using (var decryptedStream = GcmEncryptionProvider.Decrypt(new MemoryStream(encryptedContent), s_encryptionMaterial, encryptionMetadata, s_fileTransferConfiguration)) + { + // assert + var decryptedText = ExtractContent(decryptedStream); + CollectionAssert.AreEqual(s_plainTextBytes, decryptedText); + } + } + } + + [Test] + public void TestFailDecryptWithInvalidKeyAad() + { + // arrange + SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); + using (var encryptedStream = GcmEncryptionProvider.Encrypt( + s_encryptionMaterial, + encryptionMetadata, // this is output parameter + s_fileTransferConfiguration, + new MemoryStream(s_plainTextBytes), + null, + s_keyAadBytes)) + { + var encryptedContent = 
ExtractContentBytes(encryptedStream); + Assert.NotNull(encryptionMetadata.key); + Assert.NotNull(encryptionMetadata.iv); + Assert.NotNull(encryptionMetadata.matDesc); + CollectionAssert.AreEqual(s_keyAadBase64, encryptionMetadata.keyAad); + Assert.IsNull(encryptionMetadata.aad); + encryptionMetadata.keyAad = s_invalidAadBase64; + + // act + var thrown = Assert.Throws(() => + GcmEncryptionProvider.Decrypt(new MemoryStream(encryptedContent), s_encryptionMaterial, encryptionMetadata, s_fileTransferConfiguration)); + + // assert + Assert.AreEqual("mac check in GCM failed", thrown.Message); + } + } + + [Test] + public void TestFailDecryptWithInvalidContentAad() + { + // arrange + SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); + using (var encryptedStream = GcmEncryptionProvider.Encrypt( + s_encryptionMaterial, + encryptionMetadata, // this is output parameter + s_fileTransferConfiguration, + new MemoryStream(s_plainTextBytes), + s_contentAadBytes, + null)) + { + var encryptedContent = ExtractContentBytes(encryptedStream); + Assert.NotNull(encryptionMetadata.key); + Assert.NotNull(encryptionMetadata.iv); + Assert.NotNull(encryptionMetadata.matDesc); + Assert.IsNull(encryptionMetadata.keyAad); + CollectionAssert.AreEqual(s_contentAadBase64, encryptionMetadata.aad); + encryptionMetadata.aad = s_invalidAadBase64; + + // act + var thrown = Assert.Throws(() => + GcmEncryptionProvider.Decrypt(new MemoryStream(encryptedContent), s_encryptionMaterial, encryptionMetadata, s_fileTransferConfiguration)); + + // assert + Assert.AreEqual("mac check in GCM failed", thrown.Message); + } + } + + [Test] + public void TestFailDecryptWhenMissingAad() + { + // arrange + SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); + using (var encryptedStream = GcmEncryptionProvider.Encrypt( + s_encryptionMaterial, + encryptionMetadata, // this is output parameter + s_fileTransferConfiguration, + new MemoryStream(s_plainTextBytes), + s_contentAadBytes, + s_keyAadBytes)) + { + var encryptedContent = ExtractContentBytes(encryptedStream); + Assert.NotNull(encryptionMetadata.key); + Assert.NotNull(encryptionMetadata.iv); + Assert.NotNull(encryptionMetadata.matDesc); + CollectionAssert.AreEqual(s_keyAadBase64, encryptionMetadata.keyAad); + CollectionAssert.AreEqual(s_contentAadBase64, encryptionMetadata.aad); + encryptionMetadata.keyAad = null; + encryptionMetadata.aad = null; + + // act + var thrown = Assert.Throws(() => + GcmEncryptionProvider.Decrypt(new MemoryStream(encryptedContent), s_encryptionMaterial, encryptionMetadata,s_fileTransferConfiguration)); + + // assert + Assert.AreEqual("mac check in GCM failed", thrown.Message); + } + } + + [Test] + public void TestEncryptAndDecryptFile() + { + // arrange + SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata(); + var plainTextFilePath = Path.Combine(Path.GetTempPath(), "plaintext.txt"); + var encryptedFilePath = Path.Combine(Path.GetTempPath(), "encrypted.txt"); + try + { + CreateFile(plainTextFilePath, PlainText); + + // act + using (var encryptedStream = GcmEncryptionProvider.EncryptFile(plainTextFilePath, s_encryptionMaterial, encryptionMetadata, + s_fileTransferConfiguration, s_contentAadBytes, s_keyAadBytes)) + { + CreateFile(encryptedFilePath, encryptedStream); + } + + // assert + Assert.NotNull(encryptionMetadata.key); + Assert.NotNull(encryptionMetadata.iv); + Assert.NotNull(encryptionMetadata.matDesc); + CollectionAssert.AreEqual(s_keyAadBase64, encryptionMetadata.keyAad); + 
CollectionAssert.AreEqual(s_contentAadBase64, encryptionMetadata.aad); + + // act + string result; + using (var decryptedStream = GcmEncryptionProvider.DecryptFile(encryptedFilePath, s_encryptionMaterial, encryptionMetadata, + s_fileTransferConfiguration)) + { + decryptedStream.Position = 0; + var resultBytes = new byte[decryptedStream.Length]; + var bytesRead = decryptedStream.Read(resultBytes, 0, resultBytes.Length); + Assert.AreEqual(decryptedStream.Length, bytesRead); + result = Encoding.UTF8.GetString(resultBytes); + } + + // assert + CollectionAssert.AreEqual(PlainText, result); + } + finally + { + File.Delete(plainTextFilePath); + File.Delete(encryptedFilePath); + } + } + + private static void CreateFile(string filePath, string content) + { + using (var writer = File.CreateText(filePath)) + { + writer.Write(content); + } + } + + private static void CreateFile(string filePath, Stream content) + { + using (var writer = File.Create(filePath)) + { + var buffer = new byte[1024]; + int bytesRead; + content.Position = 0; + while ((bytesRead = content.Read(buffer, 0, 1024)) > 0) + { + writer.Write(buffer, 0, bytesRead); + } + } + } + + private string ExtractContent(Stream stream) => + Encoding.UTF8.GetString(ExtractContentBytes(stream)); + + private byte[] ExtractContentBytes(Stream stream) + { + var memoryStream = new MemoryStream(); + stream.Position = 0; + stream.CopyTo(memoryStream); + return memoryStream.ToArray(); + } + } +} diff --git a/Snowflake.Data.Tests/Util/TestDataGenarator.cs b/Snowflake.Data.Tests/Util/TestDataGenarator.cs index 27dda5ab0..760f1820b 100644 --- a/Snowflake.Data.Tests/Util/TestDataGenarator.cs +++ b/Snowflake.Data.Tests/Util/TestDataGenarator.cs @@ -22,7 +22,7 @@ public class TestDataGenarator public static char SnowflakeUnicode => '\u2744'; public static string EmojiUnicode => "\uD83D\uDE00"; public static string StringWithUnicode => AsciiCodes + SnowflakeUnicode + EmojiUnicode; - + public static bool NextBool() { return s_random.Next(0, 1) == 1; @@ -32,7 +32,7 @@ public static int NextInt(int minValueInclusive, int maxValueExclusive) { return s_random.Next(minValueInclusive, maxValueExclusive); } - + public static string NextAlphaNumeric() { return NextAlphaNumeric(s_random.Next(5, 12)); @@ -72,17 +72,24 @@ public static string NextDigitsString(int length) } return new string(buffer); } - + + public static byte[] NextBytes(int length) + { + var buffer = new byte[length]; + s_random.NextBytes(buffer); + return buffer; + } + private static char NextAlphaNumericChar() => NextChar(s_alphanumericChars); - + public static string NextNonZeroDigitAsString() => NextNonZeroDigitChar().ToString(); private static char NextNonZeroDigitChar() => NextChar(s_nonZeroDigits); - - private static string NextDigitAsString() => NextDigitChar().ToString(); - + + private static string NextDigitAsString() => NextDigitChar().ToString(); + private static char NextDigitChar() => NextChar(s_digitChars); - + private static string NextLetterAsString() => NextLetterChar().ToString(); private static char NextLetterChar() => NextChar(s_letterChars); diff --git a/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs b/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs index 411a6eeab..b625f80d3 100644 --- a/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs +++ b/Snowflake.Data/Core/FileTransfer/EncryptionProvider.cs @@ -9,20 +9,9 @@ namespace Snowflake.Data.Core.FileTransfer { - /// - /// The encryption materials. 
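The provider exercised by these tests implements envelope encryption: a random per-file key encrypts the content, the query-stage master key (QSMK) encrypts that file key, and both passes run through BouncyCastle's AES/GCM/NoPadding, so each output ends in a 128-bit authentication tag that also covers any supplied AAD. That tag is why the negative tests expect "mac check in GCM failed" whenever a stored AAD is altered or dropped. The core cipher call reduces to the following sketch (one-shot over a byte array, unlike the provider's streaming version):

    // One AES-GCM pass; forEncryption selects the direction, and aad may be null.
    static byte[] AesGcm(bool forEncryption, byte[] key, byte[] iv, byte[] aad, byte[] input)
    {
        var cipher = CipherUtilities.GetCipher("AES/GCM/NoPadding");
        var keyParam = new KeyParameter(key);
        cipher.Init(forEncryption, aad == null
            ? new AeadParameters(keyParam, 128, iv)        // 128-bit tag, no AAD
            : new AeadParameters(keyParam, 128, iv, aad)); // AAD is folded into the tag
        return cipher.DoFinal(input);                      // ciphertext||tag on encrypt, plaintext on decrypt
    }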
- /// - internal class MaterialDescriptor - { - public string smkId { get; set; } - - public string queryId { get; set; } - - public string keySize { get; set; } - } - /// /// The encryptor/decryptor for PUT/GET files. + /// Handles encryption and decryption using AES CBC (for files) and ECB (for keys). /// class EncryptionProvider { diff --git a/Snowflake.Data/Core/FileTransfer/GcmEncryptionProvider.cs b/Snowflake.Data/Core/FileTransfer/GcmEncryptionProvider.cs new file mode 100644 index 000000000..50b80dd05 --- /dev/null +++ b/Snowflake.Data/Core/FileTransfer/GcmEncryptionProvider.cs @@ -0,0 +1,188 @@ +using System; +using System.IO; +using Org.BouncyCastle.Crypto; +using Org.BouncyCastle.Crypto.IO; +using Org.BouncyCastle.Crypto.Parameters; +using Org.BouncyCastle.Security; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.FileTransfer +{ + internal class GcmEncryptionProvider + { + private const int AesBlockSize = 128; + internal const int BlockSizeInBytes = AesBlockSize / 8; + private const string AesGcmNoPaddingCipher = "AES/GCM/NoPadding"; + + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private static readonly SecureRandom s_random = SecureRandom.GetInstance("SHA1PRNG"); + + public static Stream EncryptFile( + string inFile, + PutGetEncryptionMaterial encryptionMaterial, + SFEncryptionMetadata encryptionMetadata, + FileTransferConfiguration transferConfiguration, + byte[] contentAad, + byte[] keyAad + ) + { + using (var fileStream = File.OpenRead(inFile)) + { + return Encrypt(encryptionMaterial, encryptionMetadata, transferConfiguration, fileStream, contentAad, keyAad); + } + } + + public static Stream DecryptFile( + string inFile, + PutGetEncryptionMaterial encryptionMaterial, + SFEncryptionMetadata encryptionMetadata, + FileTransferConfiguration transferConfiguration) + { + using (var fileStream = File.OpenRead(inFile)) + { + return Decrypt(fileStream, encryptionMaterial, encryptionMetadata, transferConfiguration); + } + } + + public static Stream Encrypt( + PutGetEncryptionMaterial encryptionMaterial, + SFEncryptionMetadata encryptionMetadata, // this is output parameter + FileTransferConfiguration fileTransferConfiguration, + Stream inputStream, + byte[] contentAad, + byte[] keyAad) + { + byte[] decodedMasterKey = Convert.FromBase64String(encryptionMaterial.queryStageMasterKey); + int masterKeySize = decodedMasterKey.Length; + s_logger.Debug($"Master key size : {masterKeySize}"); + + var contentIV = new byte[BlockSizeInBytes]; + var keyIV = new byte[BlockSizeInBytes]; + var fileKeyBytes = new byte[masterKeySize]; // we choose a random fileKey to encrypt it with qsmk key with GCM + s_random.NextBytes(contentIV); + s_random.NextBytes(keyIV); + s_random.NextBytes(fileKeyBytes); + + var encryptedKey = EncryptKey(fileKeyBytes, decodedMasterKey, keyIV, keyAad); + var result = EncryptContent(inputStream, fileKeyBytes, contentIV, contentAad, fileTransferConfiguration); + + MaterialDescriptor matDesc = new MaterialDescriptor + { + smkId = encryptionMaterial.smkId.ToString(), + queryId = encryptionMaterial.queryId, + keySize = (masterKeySize * 8).ToString() + }; + + encryptionMetadata.key = Convert.ToBase64String(encryptedKey); + encryptionMetadata.iv = Convert.ToBase64String(contentIV); + encryptionMetadata.keyIV = Convert.ToBase64String(keyIV); + encryptionMetadata.keyAad = keyAad == null ? null : Convert.ToBase64String(keyAad); + encryptionMetadata.aad = contentAad == null ? 
null : Convert.ToBase64String(contentAad); + encryptionMetadata.matDesc = Newtonsoft.Json.JsonConvert.SerializeObject(matDesc); + + return result; + } + + public static Stream Decrypt( + Stream inputStream, + PutGetEncryptionMaterial encryptionMaterial, + SFEncryptionMetadata encryptionMetadata, + FileTransferConfiguration fileTransferConfiguration) + { + var decodedMasterKey = Convert.FromBase64String(encryptionMaterial.queryStageMasterKey); + var keyBytes = Convert.FromBase64String(encryptionMetadata.key); + var keyIVBytes = Convert.FromBase64String(encryptionMetadata.keyIV); + var keyAad = encryptionMetadata.keyAad == null ? null : Convert.FromBase64String(encryptionMetadata.keyAad); + var ivBytes = Convert.FromBase64String(encryptionMetadata.iv); + var contentAad = encryptionMetadata.aad == null ? null : Convert.FromBase64String(encryptionMetadata.aad); + var decryptedFileKey = DecryptKey(keyBytes, decodedMasterKey, keyIVBytes, keyAad); + return DecryptContent(inputStream, decryptedFileKey, ivBytes, contentAad, fileTransferConfiguration); + } + + private static byte[] EncryptKey(byte[] fileKeyBytes, byte[] qsmk, byte[] keyIV, byte[] keyAad) + { + var keyCipher = BuildAesGcmNoPaddingCipher(true, qsmk, keyIV, keyAad); + var cipherKeyData = new byte[keyCipher.GetOutputSize(fileKeyBytes.Length)]; + var processLength = keyCipher.ProcessBytes(fileKeyBytes, 0, fileKeyBytes.Length, cipherKeyData, 0); + keyCipher.DoFinal(cipherKeyData, processLength); + return cipherKeyData; + } + + private static Stream EncryptContent(Stream inputStream, byte[] fileKeyBytes, byte[] contentIV, byte[] contentAad, + FileTransferConfiguration transferConfiguration) + { + var contentCipher = BuildAesGcmNoPaddingCipher(true, fileKeyBytes, contentIV, contentAad); + var targetStream = new FileBackedOutputStream(transferConfiguration.MaxBytesInMemory, transferConfiguration.TempDir); + try + { + var cipherStream = new CipherStream(targetStream, null, contentCipher); + byte[] buffer = new byte[transferConfiguration.MaxBytesInMemory]; + int bytesRead; + while ((bytesRead = inputStream.Read(buffer, 0, buffer.Length)) > 0) + { + cipherStream.Write(buffer, 0, bytesRead); + } + + cipherStream.Flush(); // we cannot close or dispose cipherStream because: 1) it would do additional DoFinal resulting in an exception 2) closing cipherStream would close target stream + var mac = contentCipher.DoFinal(); // getting authentication tag for the whole content + targetStream.Write(mac, 0, mac.Length); + return targetStream; + } + catch (Exception) + { + targetStream.Dispose(); + throw; + } + } + + private static byte[] DecryptKey(byte[] fileKey, byte[] qsmk, byte[] keyIV, byte[] keyAad) + { + var keyCipher = BuildAesGcmNoPaddingCipher(false, qsmk, keyIV, keyAad); + var decryptedKeyData = new byte[keyCipher.GetOutputSize(fileKey.Length)]; + var processLength = keyCipher.ProcessBytes(fileKey, 0, fileKey.Length, decryptedKeyData, 0); + keyCipher.DoFinal(decryptedKeyData, processLength); + return decryptedKeyData; + } + + private static Stream DecryptContent(Stream inputStream, byte[] fileKeyBytes, byte[] contentIV, byte[] contentAad, + FileTransferConfiguration transferConfiguration) + { + var contentCipher = BuildAesGcmNoPaddingCipher(false, fileKeyBytes, contentIV, contentAad); + var targetStream = new FileBackedOutputStream(transferConfiguration.MaxBytesInMemory, transferConfiguration.TempDir); + try + { + var cipherStream = new CipherStream(targetStream, null, contentCipher); + byte[] buffer = new 
byte[transferConfiguration.MaxBytesInMemory]; + int bytesRead; + while ((bytesRead = inputStream.Read(buffer, 0, buffer.Length)) > 0) + { + cipherStream.Write(buffer, 0, bytesRead); + } + cipherStream.Flush(); // we cannot close or dispose cipherStream because closing cipherStream would close target stream + var lastBytes = contentCipher.DoFinal(); + if (lastBytes != null && lastBytes.Length > 0) + { + targetStream.Write(lastBytes, 0, lastBytes.Length); + } + return targetStream; + } + catch (Exception) + { + targetStream.Dispose(); + throw; + } + } + + private static IBufferedCipher BuildAesGcmNoPaddingCipher(bool forEncryption, byte[] keyBytes, byte[] initialisationVector, byte[] aadData) + { + var cipher = CipherUtilities.GetCipher(AesGcmNoPaddingCipher); + KeyParameter keyParameter = new KeyParameter(keyBytes); + var keyParameterAead = aadData == null + ? new AeadParameters(keyParameter, AesBlockSize, initialisationVector) + : new AeadParameters(keyParameter, AesBlockSize, initialisationVector, aadData); + cipher.Init(forEncryption, keyParameterAead); + return cipher; + } + } +} diff --git a/Snowflake.Data/Core/FileTransfer/MaterialDescriptor.cs b/Snowflake.Data/Core/FileTransfer/MaterialDescriptor.cs new file mode 100644 index 000000000..e0b352910 --- /dev/null +++ b/Snowflake.Data/Core/FileTransfer/MaterialDescriptor.cs @@ -0,0 +1,11 @@ +namespace Snowflake.Data.Core.FileTransfer +{ + internal class MaterialDescriptor + { + public string smkId { get; set; } + + public string queryId { get; set; } + + public string keySize { get; set; } + } +} diff --git a/Snowflake.Data/Core/FileTransfer/SFFileMetadata.cs b/Snowflake.Data/Core/FileTransfer/SFFileMetadata.cs index 605de0be1..e1647257b 100644 --- a/Snowflake.Data/Core/FileTransfer/SFFileMetadata.cs +++ b/Snowflake.Data/Core/FileTransfer/SFFileMetadata.cs @@ -3,21 +3,28 @@ */ using System; -using System.Collections.Generic; using System.IO; -using System.Text; using static Snowflake.Data.Core.FileTransfer.SFFileCompressionTypes; namespace Snowflake.Data.Core.FileTransfer { public class SFEncryptionMetadata { - /// Initialization vector + /// Initialization vector for file content encryption public string iv { set; get; } /// File key public string key { set; get; } + /// Additional Authentication Data for file content encryption + public string aad { set; get; } + + /// Initialization vector for key encryption + public string keyIV { set; get; } + + /// Additional Authentication Data for key encryption + public string keyAad { set; get; } + /// Encryption material descriptor public string matDesc { set; get; } } @@ -89,7 +96,7 @@ internal class SFFileMetadata /// File message digest (after compression if required) public string sha256Digest { set; get; } - /// Source compression + /// Source compression public SFFileCompressionType sourceCompression { set; get; } /// Target compression @@ -122,9 +129,9 @@ internal class SFFileMetadata // Proxy credentials of the remote storage client. 
public ProxyCredentials proxyCredentials { get; set; } - + public int MaxBytesInMemory { get; set; } - + internal CommandTypes _operationType; internal string RemoteFileName() @@ -142,7 +149,7 @@ internal class FileTransferConfiguration { private const int OneMegabyteInBytes = 1024 * 1024; - + public string TempDir { get; set; } public int MaxBytesInMemory { get; set; } From 81995633744df22936f82e1d0f135b9b76f03945 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Mon, 28 Oct 2024 11:53:12 +0100 Subject: [PATCH 12/20] NOSNOW Fix operating system name (#1047) --- Snowflake.Data/Core/Tools/Diagnostics.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Snowflake.Data/Core/Tools/Diagnostics.cs b/Snowflake.Data/Core/Tools/Diagnostics.cs index dc0b19593..0e9f5b0dd 100644 --- a/Snowflake.Data/Core/Tools/Diagnostics.cs +++ b/Snowflake.Data/Core/Tools/Diagnostics.cs @@ -51,7 +51,7 @@ private static void AppendAssemblyInfo(StringBuilder info, Assembly assembly, st private static string OsName() { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - return "UNIX"; + return "LINUX"; if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) return "WINDOWS"; if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) From fe8e84810d15f8b9f8e0f3a5aa577906fabeb351 Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf <115584722+sfc-gh-ext-simba-lf@users.noreply.github.com> Date: Mon, 28 Oct 2024 09:58:47 -0700 Subject: [PATCH 13/20] SNOW-1657238: Fix incorrect row count for rows loaded (#1044) --- .../IntegrationTests/SFDbCommandIT.cs | 90 +++++++++++++++++++ Snowflake.Data/Core/ResultSetUtil.cs | 8 +- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs index 5aa01ee46..5950e4f9b 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs @@ -8,6 +8,8 @@ using System.Threading; using System.Threading.Tasks; using Snowflake.Data.Core; +using System.Linq; +using System.IO; namespace Snowflake.Data.Tests.IntegrationTests { @@ -1674,5 +1676,93 @@ public async Task TestCommandWithCommentEmbeddedAsync() Assert.AreEqual("--", reader.GetString(0)); } } + + [Test] + public void TestExecuteNonQueryReturnsCorrectRowCountForUploadWithMultipleFiles() + { + const int NumberOfFiles = 5; + const int NumberOfRows = 3; + const int ExpectedRowCount = NumberOfFiles * NumberOfRows; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + var tempFolder = $"{Path.GetTempPath()}Temp_{Guid.NewGuid()}"; + + try + { + // Arrange + Directory.CreateDirectory(tempFolder); + var data = string.Concat(Enumerable.Repeat(string.Join(",", "TestData") + "\n", NumberOfRows)); + for (int i = 0; i < NumberOfFiles; i++) + { + File.WriteAllText(Path.Combine(tempFolder, $"{TestContext.CurrentContext.Test.Name}_{i}.csv"), data); + } + CreateOrReplaceTable(conn, TableName, new[] { "COL1 STRING" }); + cmd.CommandText = $"PUT file://{Path.Combine(tempFolder, "*.csv")} @%{TableName} AUTO_COMPRESS=FALSE"; + var reader = cmd.ExecuteReader(); + + // Act + cmd.CommandText = $"COPY INTO {TableName} FROM @%{TableName} PATTERN='.*.csv' FILE_FORMAT=(TYPE=CSV)"; + int actualRowCount = cmd.ExecuteNonQuery(); + + // Assert + Assert.AreEqual(ExpectedRowCount, actualRowCount); 
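+
+                        // Context for this assertion: COPY INTO reports one result row per file it
+                        // loads, and the driver (see the ResultSetUtil change below) now sums the
+                        // "rows_loaded" column across all of those rows: 5 files x 3 rows each = 15.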
+ } + finally + { + Directory.Delete(tempFolder, true); + } + } + } + } + + [Test] + public async Task TestExecuteNonQueryAsyncReturnsCorrectRowCountForUploadWithMultipleFiles() + { + const int NumberOfFiles = 5; + const int NumberOfRows = 3; + const int ExpectedRowCount = NumberOfFiles * NumberOfRows; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + var tempFolder = $"{Path.GetTempPath()}Temp_{Guid.NewGuid()}"; + + try + { + // Arrange + Directory.CreateDirectory(tempFolder); + var data = string.Concat(Enumerable.Repeat(string.Join(",", "TestData") + "\n", NumberOfRows)); + for (int i = 0; i < NumberOfFiles; i++) + { + File.WriteAllText(Path.Combine(tempFolder, $"{TestContext.CurrentContext.Test.Name}_{i}.csv"), data); + } + CreateOrReplaceTable(conn, TableName, new[] { "COL1 STRING" }); + cmd.CommandText = $"PUT file://{Path.Combine(tempFolder, "*.csv")} @%{TableName} AUTO_COMPRESS=FALSE"; + var reader = cmd.ExecuteReader(); + + // Act + cmd.CommandText = $"COPY INTO {TableName} FROM @%{TableName} PATTERN='.*.csv' FILE_FORMAT=(TYPE=CSV)"; + int actualRowCount = await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); + + // Assert + Assert.AreEqual(ExpectedRowCount, actualRowCount); + } + finally + { + Directory.Delete(tempFolder, true); + } + } + } + } } } diff --git a/Snowflake.Data/Core/ResultSetUtil.cs b/Snowflake.Data/Core/ResultSetUtil.cs index 9d62a17d7..236efab9c 100755 --- a/Snowflake.Data/Core/ResultSetUtil.cs +++ b/Snowflake.Data/Core/ResultSetUtil.cs @@ -36,9 +36,11 @@ internal static int CalculateUpdateCount(this SFBaseResultSet resultSet) var index = resultSet.sfResultSetMetaData.GetColumnIndexByName("rows_loaded"); if (index >= 0) { - resultSet.Next(); - updateCount = resultSet.GetInt64(index); - resultSet.Rewind(); + while (resultSet.Next()) + { + updateCount += resultSet.GetInt64(index); + } + while (resultSet.Rewind()) {} } break; case SFStatementType.COPY_UNLOAD: From b21bfb34a716062f6a7425d74325ac00dc6b6963 Mon Sep 17 00:00:00 2001 From: Dariusz Stempniak Date: Mon, 28 Oct 2024 22:04:16 +0100 Subject: [PATCH 14/20] SNOW-1739483 improve calculation of time to wait before retry (#1046) --- Snowflake.Data/Core/HttpUtil.cs | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index f835b7eb5..fa52348e8 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -406,7 +406,7 @@ protected override async Task SendAsync(HttpRequestMessage } else if (childCts != null && childCts.Token.IsCancellationRequested) { - logger.Warn($"Http request timeout. Retry the request after {backOffInSec} sec."); + logger.Warn($"Http request timeout. Retry the request after max {backOffInSec} sec."); } else { @@ -465,7 +465,7 @@ protected override async Task SendAsync(HttpRequestMessage logger.Info("Response returned was null."); } - if (restTimeout.TotalSeconds > 0 && totalRetryTime > restTimeout.TotalSeconds) + if (restTimeout.TotalSeconds > 0 && totalRetryTime >= restTimeout.TotalSeconds) { logger.Debug($"stop retry as connection_timeout {restTimeout.TotalSeconds} sec. 
reached"); if (response != null) @@ -478,6 +478,12 @@ protected override async Task SendAsync(HttpRequestMessage throw new OperationCanceledException(errorMessage); } + if (restTimeout.TotalSeconds > 0 && totalRetryTime + backOffInSec > restTimeout.TotalSeconds) + { + // No need to wait more than necessary if it can be avoided. + backOffInSec = (int)restTimeout.TotalSeconds - totalRetryTime; + } + retryCount++; if ((maxRetryCount > 0) && (retryCount > maxRetryCount)) { @@ -516,15 +522,6 @@ protected override async Task SendAsync(HttpRequestMessage // Multiply sleep by 2 for non-login requests backOffInSec *= 2; } - - totalRetryTime = (int)((DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() - startTimeInMilliseconds) / 1000); - if ((restTimeout.TotalSeconds > 0) && (totalRetryTime + backOffInSec > restTimeout.TotalSeconds)) - { - // No need to wait more than necessary if it can be avoided. - // If the rest timeout will be reached before the next back-off, - // then use the remaining connection timeout. - backOffInSec = Math.Min(backOffInSec, (int)restTimeout.TotalSeconds - totalRetryTime + 1); - } } } } From 5e166938f691b15f548ea7b577800e89a4a1b2a7 Mon Sep 17 00:00:00 2001 From: Dominik Przybysz <132913826+sfc-gh-dprzybysz@users.noreply.github.com> Date: Wed, 30 Oct 2024 11:12:10 +0100 Subject: [PATCH 15/20] SNOW-1756807: Add note about GCP regional endpoints (#1051) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 0378b5416..408563f25 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,8 @@ Read more in [certificate validation](doc/CertficateValidation.md) docs. were not performed where the insecureMode flag was set to false, which is the default setting. From version v2.1.5 CRL is working back as intended. +5. This driver currently does not support GCP regional endpoints. Please ensure that any workloads using through this driver do not require support for regional endpoints on GCP. If you have questions about this, please contact Snowflake Support. + Note that the driver is now targeting .NET Standard 2.0. When upgrading, you might also need to run “Update-Package -reinstall” to update the dependencies. 
See more:

From d8e4b63339d4b9f5500b8fcdcef87e4ca7a38f5a Mon Sep 17 00:00:00 2001
From: Dariusz Stempniak
Date: Thu, 31 Oct 2024 14:44:28 +0100
Subject: [PATCH 16/20] SNOW-1672654 Support for empty encryptionMaterial
 (#1048)

---
 .../IntegrationTests/SFPutGetTest.cs | 110 +++++++++++-------
 .../UnitTests/SFAzureClientTest.cs | 1 -
 .../UnitTests/SFGCSClientTest.cs | 1 -
 .../UnitTests/SFRemoteStorageClientTest.cs | 1 -
 .../UnitTests/SFS3ClientTest.cs | 1 -
 .../Core/FileTransfer/SFFileTransferAgent.cs | 31 ++---
 .../FileTransfer/StorageClient/SFGCSClient.cs | 48 ++++----
 .../FileTransfer/StorageClient/SFS3Client.cs | 11 +-
 .../StorageClient/SFSnowflakeAzureClient.cs | 64 +++++-----
 9 files changed, 155 insertions(+), 113 deletions(-)

diff --git a/Snowflake.Data.Tests/IntegrationTests/SFPutGetTest.cs b/Snowflake.Data.Tests/IntegrationTests/SFPutGetTest.cs
index 975d041e0..2ef0c7ef9 100644
--- a/Snowflake.Data.Tests/IntegrationTests/SFPutGetTest.cs
+++ b/Snowflake.Data.Tests/IntegrationTests/SFPutGetTest.cs
@@ -31,6 +31,7 @@ class SFPutGetTest : SFBaseTest
     [ThreadStatic] private static string t_schemaName;
     [ThreadStatic] private static string t_tableName;
     [ThreadStatic] private static string t_stageName;
+    [ThreadStatic] private static string t_stageNameSse; // server side encryption without client side encryption
     [ThreadStatic] private static string t_fileName;
     [ThreadStatic] private static string t_outputFileName;
     [ThreadStatic] private static string t_inputFilePath;
@@ -41,7 +42,7 @@ class SFPutGetTest : SFBaseTest
     [ThreadStatic] private static string t_destCompressionType;
     [ThreadStatic] private static bool t_autoCompress;
     [ThreadStatic] private static List<string> t_filesToDelete;
-
+
     public enum StageType
     {
         USER,
@@ -63,7 +64,7 @@ public static void OneTimeTearDown()
         // Delete temp output directory and downloaded files
         Directory.Delete(s_outputDirectory, true);
     }
-
+
     [SetUp]
     public void SetUp()
     {
@@ -73,6 +74,7 @@ public void SetUp()
         t_schemaName = testConfig.schema;
         t_tableName = $"TABLE_{threadSuffix}";
         t_stageName = $"STAGE_{threadSuffix}";
+        t_stageNameSse = $"STAGE_{threadSuffix}_SSE";
         t_filesToDelete = new List<string>();

         using (var conn = new SnowflakeDbConnection(ConnectionString))
@@ -88,6 +90,10 @@ public void SetUp()
                 // Create temp stage
                 command.CommandText = $"CREATE OR REPLACE STAGE {t_schemaName}.{t_stageName}";
                 command.ExecuteNonQuery();
+
+                // Create temp stage without client side encryption
+                command.CommandText = $"CREATE OR REPLACE STAGE {t_schemaName}.{t_stageNameSse} ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')";
+                command.ExecuteNonQuery();
             }
         }
     }
@@ -109,7 +115,7 @@ public void TearDown()
                 command.ExecuteNonQuery();
             }
         }
-
+
         // Delete temp files if necessary
         if (t_filesToDelete != null)
         {
@@ -130,7 +136,7 @@ public void TestPutFileAsteriskWildcard()
             $"{absolutePathPrefix}_three.csv"
         };
         PrepareFileData(files);
-
+
         // Set the PUT query variables
         t_inputFilePath = $"{absolutePathPrefix}*";
         t_internalStagePath = $"@{t_schemaName}.{t_stageName}";
@@ -142,7 +148,7 @@ public void TestPutFileAsteriskWildcard()
             VerifyFilesAreUploaded(conn, files, t_internalStagePath);
         }
     }
-
+
     [Test]
     public void TestPutFileAsteriskWildcardWithExtension()
     {
@@ -167,7 +173,7 @@ public void TestPutFileAsteriskWildcardWithExtension()
             VerifyFilesAreUploaded(conn, files, t_internalStagePath);
         }
     }
-
+
     [Test]
     public void TestPutFileQuestionMarkWildcard()
     {
@@ -180,7 +186,7 @@ public void TestPutFileQuestionMarkWildcard()
         PrepareFileData(files);

         // Create file which should be omitted during the transfer
PrepareFileData($"{absolutePathPrefix}_four.csv"); - + // Set the PUT query variables t_inputFilePath = $"{absolutePathPrefix}_?.csv"; t_internalStagePath = $"@{t_schemaName}.{t_stageName}"; @@ -192,14 +198,14 @@ public void TestPutFileQuestionMarkWildcard() VerifyFilesAreUploaded(conn, files, t_internalStagePath); } } - + [Test] public void TestPutFileRelativePathWithoutDirectory() { // Set the PUT query variables t_inputFilePath = $"{Guid.NewGuid()}_1.csv"; t_internalStagePath = $"@{t_schemaName}.{t_stageName}"; - + PrepareFileData(t_inputFilePath); using (var conn = new SnowflakeDbConnection(ConnectionString)) @@ -226,7 +232,7 @@ public void TestPutGetOnClosedConnectionThrowsWithoutQueryId([Values("GET", "PUT SnowflakeDbExceptionAssert.HasErrorCode(snowflakeDbException, SFError.EXECUTE_COMMAND_ON_CLOSED_CONNECTION); } } - + [Test] public void TestGetNonExistentFileReturnsFalseAndDoesNotThrow() { @@ -236,7 +242,7 @@ public void TestGetNonExistentFileReturnsFalseAndDoesNotThrow() // Act using (var conn = new SnowflakeDbConnection(ConnectionString)) { - conn.Open(); + conn.Open(); var sql = $"GET {t_internalStagePath}/{t_fileName} file://{s_outputDirectory}"; using (var command = conn.CreateCommand()) { @@ -246,7 +252,7 @@ public void TestGetNonExistentFileReturnsFalseAndDoesNotThrow() } } } - + [Test] public void TestPutNonExistentFileThrowsWithQueryId() { @@ -256,14 +262,14 @@ public void TestPutNonExistentFileThrowsWithQueryId() // Act using (var conn = new SnowflakeDbConnection(ConnectionString)) { - conn.Open(); + conn.Open(); var snowflakeDbException = Assert.Throws(() => PutFile(conn)); Assert.IsNotNull(snowflakeDbException); Assert.IsNotNull(snowflakeDbException.QueryId); SnowflakeDbExceptionAssert.HasErrorCode(snowflakeDbException, SFError.IO_ERROR_ON_GETPUT_COMMAND); } } - + [Test] public void TestPutFileProvidesQueryIdOnFailure() { @@ -285,7 +291,7 @@ public void TestPutFileProvidesQueryIdOnFailure() SnowflakeDbExceptionAssert.HasErrorCode(snowflakeDbException, SFError.IO_ERROR_ON_GETPUT_COMMAND); } } - + [Test] public void TestPutFileWithSyntaxErrorProvidesQueryIdOnFailure() { @@ -308,7 +314,7 @@ public void TestPutFileWithSyntaxErrorProvidesQueryIdOnFailure() Assert.That(snowflakeDbException.InnerException, Is.Null); } } - + [Test] public void TestPutFileProvidesQueryIdOnSuccess() { @@ -323,7 +329,7 @@ public void TestPutFileProvidesQueryIdOnSuccess() { conn.Open(); var queryId = PutFile(conn); - + // Assert Assert.IsNotNull(queryId); Assert.DoesNotThrow(()=>Guid.Parse(queryId)); @@ -337,11 +343,11 @@ public void TestPutFileRelativePathWithDirectory() var guid = Guid.NewGuid(); var relativePath = $"{guid}"; Directory.CreateDirectory(relativePath); - + // Set the PUT query variables t_inputFilePath = $"{relativePath}{Path.DirectorySeparatorChar}{guid}_1.csv"; t_internalStagePath = $"@{t_schemaName}.{t_stageName}"; - + PrepareFileData(t_inputFilePath); using (var conn = new SnowflakeDbConnection(ConnectionString)) @@ -351,7 +357,7 @@ public void TestPutFileRelativePathWithDirectory() VerifyFilesAreUploaded(conn, new List { t_inputFilePath }, t_internalStagePath); } } - + [Test] public void TestPutFileRelativePathAsteriskWildcard() { @@ -374,7 +380,7 @@ public void TestPutFileRelativePathAsteriskWildcard() VerifyFilesAreUploaded(conn, files, t_internalStagePath); } } - + [Test] // presigned url is enabled on CI so we need to disable the test // it should be enabled when downscoped credential is the default option @@ -384,7 +390,7 @@ public void 
TestPutFileWithoutOverwriteFlagSkipsSecondUpload() // Set the PUT query variables t_inputFilePath = $"{Guid.NewGuid()}.csv"; t_internalStagePath = $"@{t_schemaName}.{t_stageName}"; - + PrepareFileData(t_inputFilePath); using (var conn = new SnowflakeDbConnection(ConnectionString)) @@ -395,18 +401,18 @@ public void TestPutFileWithoutOverwriteFlagSkipsSecondUpload() PutFile(conn, expectedStatus: ResultStatus.SKIPPED); } } - + [Test] public void TestPutFileWithOverwriteFlagRunsSecondUpload() { var overwriteAttribute = "OVERWRITE=TRUE"; - + // Set the PUT query variables t_inputFilePath = $"{Guid.NewGuid()}.csv"; t_internalStagePath = $"@{t_schemaName}.{t_stageName}"; - + PrepareFileData(t_inputFilePath); - + using (var conn = new SnowflakeDbConnection(ConnectionString)) { conn.Open(); @@ -415,7 +421,7 @@ public void TestPutFileWithOverwriteFlagRunsSecondUpload() PutFile(conn, overwriteAttribute, expectedStatus: ResultStatus.UPLOADED); } } - + [Test] public void TestPutDirectoryAsteriskWildcard() { @@ -431,7 +437,7 @@ public void TestPutDirectoryAsteriskWildcard() PrepareFileData(fullPath); files.Add(fullPath); } - + // Set the PUT query variables t_inputFilePath = $"{path}*{Path.DirectorySeparatorChar}*"; t_internalStagePath = $"@{t_schemaName}.{t_stageName}"; @@ -459,7 +465,7 @@ public void TestPutDirectoryQuestionMarkWildcard() PrepareFileData(fullPath); files.Add(fullPath); } - + // Set the PUT query variables t_inputFilePath = $"{path}_?{Path.DirectorySeparatorChar}{guid}_?_file.csv"; t_internalStagePath = $"@{t_schemaName}.{t_stageName}"; @@ -471,7 +477,7 @@ public void TestPutDirectoryQuestionMarkWildcard() VerifyFilesAreUploaded(conn, files, t_internalStagePath); } } - + [Test] public void TestPutDirectoryMixedWildcard() { @@ -487,7 +493,7 @@ public void TestPutDirectoryMixedWildcard() PrepareFileData(fullPath); files.Add(fullPath); } - + // Set the PUT query variables t_inputFilePath = $"{path}_*{Path.DirectorySeparatorChar}{guid}_?_file.csv"; t_internalStagePath = $"@{t_schemaName}.{t_stageName}"; @@ -499,7 +505,7 @@ public void TestPutDirectoryMixedWildcard() VerifyFilesAreUploaded(conn, files, t_internalStagePath); } } - + [Test] public void TestPutGetCommand( [Values("none", "gzip", "bzip2", "brotli", "deflate", "raw_deflate", "zstd")] string sourceFileCompressionType, @@ -517,7 +523,24 @@ public void TestPutGetCommand( GetFile(conn); } } - + + [Test] + public void TestPutGetCommandForNamedStageWithoutClientSideEncryption( + [Values("none", "gzip")] string sourceFileCompressionType, + [Values("", "/DEEP/TEST_PATH")] string stagePath, + [Values] bool autoCompress) + { + PrepareTest(sourceFileCompressionType, StageType.NAMED, stagePath, autoCompress, false); + + using (var conn = new SnowflakeDbConnection(ConnectionString)) + { + conn.Open(); + PutFile(conn); + CopyIntoTable(conn); + GetFile(conn); + } + } + // Test small file upload/download with GCS_USE_DOWNSCOPED_CREDENTIAL set to true [Test] [IgnoreOnEnvIs("snowflake_cloud_env", new [] { "AWS", "AZURE" })] @@ -536,14 +559,15 @@ public void TestPutGetGcsDownscopedCredential( GetFile(conn); } } - - private void PrepareTest(string sourceFileCompressionType, StageType stageType, string stagePath, bool autoCompress) + + private void PrepareTest(string sourceFileCompressionType, StageType stageType, string stagePath, + bool autoCompress, bool clientEncryption = true) { t_stageType = stageType; t_sourceCompressionType = sourceFileCompressionType; t_autoCompress = autoCompress; // Prepare temp file name with specified file extension - 
t_fileName = Guid.NewGuid() + ".csv" + + t_fileName = Guid.NewGuid() + ".csv" + (t_autoCompress? SFFileCompressionTypes.LookUpByName(t_sourceCompressionType).FileExtension: ""); t_inputFilePath = Path.GetTempPath() + t_fileName; if (IsCompressedByTheDriver()) @@ -570,7 +594,9 @@ private void PrepareTest(string sourceFileCompressionType, StageType stageType, t_internalStagePath = $"@{t_schemaName}.%{t_tableName}{stagePath}"; break; case StageType.NAMED: - t_internalStagePath = $"@{t_schemaName}.{t_stageName}{stagePath}"; + t_internalStagePath = clientEncryption + ? $"@{t_schemaName}.{t_stageName}{stagePath}" + : $"@{t_schemaName}.{t_stageNameSse}{stagePath}"; break; } } @@ -579,11 +605,11 @@ private static bool IsCompressedByTheDriver() { return t_sourceCompressionType == "none" && t_autoCompress; } - + // PUT - upload file from local directory to the stage string PutFile( - SnowflakeDbConnection conn, - String additionalAttribute = "", + SnowflakeDbConnection conn, + String additionalAttribute = "", ResultStatus expectedStatus = ResultStatus.UPLOADED) { string queryId; @@ -704,7 +730,7 @@ private void ProcessFile(String command, SnowflakeDbConnection connection) { switch (command) { - case "GET": + case "GET": GetFile(connection); break; case "PUT": @@ -747,7 +773,7 @@ private static void PrepareFileData(string file) // Prepare csv raw data and write to temp files var rawDataRow = string.Join(",", s_colData) + "\n"; var rawData = string.Concat(Enumerable.Repeat(rawDataRow, NumberOfRows)); - + File.WriteAllText(file, rawData); t_filesToDelete.Add(file); } diff --git a/Snowflake.Data.Tests/UnitTests/SFAzureClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFAzureClientTest.cs index a1c791071..08b85a9b5 100644 --- a/Snowflake.Data.Tests/UnitTests/SFAzureClientTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFAzureClientTest.cs @@ -69,7 +69,6 @@ class SFAzureClientTest : SFBaseTest stageInfo = new PutGetStageInfo() { endPoint = EndPoint, - isClientSideEncrypted = true, location = Location, locationType = SFRemoteStorageUtil.AZURE_FS, path = LocationPath, diff --git a/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs index 0fad57542..d47742743 100644 --- a/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFGCSClientTest.cs @@ -62,7 +62,6 @@ class SFGCSClientTest : SFBaseTest stageInfo = new PutGetStageInfo() { endPoint = null, - isClientSideEncrypted = true, location = Location, locationType = SFRemoteStorageUtil.GCS_FS, path = LocationPath, diff --git a/Snowflake.Data.Tests/UnitTests/SFRemoteStorageClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFRemoteStorageClientTest.cs index 0e9d53767..76ec7c557 100644 --- a/Snowflake.Data.Tests/UnitTests/SFRemoteStorageClientTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFRemoteStorageClientTest.cs @@ -87,7 +87,6 @@ class SFRemoteStorageClientTest : SFBaseTest stageInfo = new PutGetStageInfo() { endPoint = EndPoint, - isClientSideEncrypted = true, location = Location, locationType = SFRemoteStorageUtil.GCS_FS, path = LocationPath, diff --git a/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs index 54647db8b..5432b0121 100644 --- a/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs @@ -87,7 +87,6 @@ class SFS3ClientTest : SFBaseTest stageInfo = new PutGetStageInfo() { endPoint = Endpoint, - isClientSideEncrypted = true, location = Location, locationType = 
SFRemoteStorageUtil.S3_FS,
                path = LocationPath,

diff --git a/Snowflake.Data/Core/FileTransfer/SFFileTransferAgent.cs b/Snowflake.Data/Core/FileTransfer/SFFileTransferAgent.cs
index b27daa51f..1ee8557b6 100644
--- a/Snowflake.Data/Core/FileTransfer/SFFileTransferAgent.cs
+++ b/Snowflake.Data/Core/FileTransfer/SFFileTransferAgent.cs
@@ -468,7 +468,7 @@ private void updatePresignedUrl()
                         fileMeta.stageInfo = response.data.stageInfo;
                         fileMeta.presignedUrl = response.data.stageInfo.presignedUrl;
                     }
-                }
+                }
             }
             else if (CommandTypes.DOWNLOAD == CommandType)
             {
@@ -477,7 +477,7 @@ private void updatePresignedUrl()
                 {
                     FilesMetas[index].presignedUrl = TransferMetadata.presignedUrls[index];
                 }
-            }
+            }
         }

 ///
@@ -544,7 +544,10 @@ private void initEncryptionMaterial()
         {
             if (CommandTypes.UPLOAD == CommandType)
             {
-                EncryptionMaterials.Add(TransferMetadata.encryptionMaterial[0]);
+                if (TransferMetadata.stageInfo.isClientSideEncrypted)
+                {
+                    EncryptionMaterials.Add(TransferMetadata.encryptionMaterial[0]);
+                }
             }
         }

@@ -670,7 +673,9 @@ private void initFileMetadata(
                     overwrite = TransferMetadata.overwrite,
                     presignedUrl = TransferMetadata.stageInfo.presignedUrl,
                     parallel = TransferMetadata.parallel,
-                    encryptionMaterial = TransferMetadata.encryptionMaterial[index],
+                    encryptionMaterial = index < TransferMetadata.encryptionMaterial.Count
+                        ? TransferMetadata.encryptionMaterial[index]
+                        : null,
                     MaxBytesInMemory = GetFileTransferMaxBytesInMemory(),
                     _operationType = CommandTypes.DOWNLOAD
                 };
@@ -715,7 +720,7 @@ private int GetFileTransferMaxBytesInMemory()
                 return FileTransferConfiguration.DefaultMaxBytesInMemory;
             }
         }
-
+
         ///
         /// Expand the wildcards if any to generate the list of paths for all files matched by the wildcards.
         /// Also replace the relative paths to the absolute paths for the files if needed.
         ///
@@ -731,7 +736,7 @@ private List<string> expandFileNames(string location)
             var directoryName = Path.GetDirectoryName(location);
             var foundDirectories = ExpandDirectories(directoryName);
             var filePaths = new List<string>();
-
+
             if (ContainsWildcard(fileName))
             {
                 foreach (var directory in foundDirectories)
@@ -756,8 +761,8 @@ private List<string> expandFileNames(string location)
                     {
                         filePaths.AddRange(
                             Directory.GetFiles(
-                                directory,
-                                fileName,
+                                directory,
+                                fileName,
                                 SearchOption.TopDirectoryOnly));
                     }
                 }
@@ -788,7 +793,7 @@ private List<string> expandFileNames(string location)

             return filePaths;
         }
-
+
         ///
         /// Expand the wildcards in the directory path to generate the list of directories to be searched for the files.
         ///
@@ -803,7 +808,7 @@ private static IEnumerable<string> ExpandDirectories(string directoryPath)
            {
                return new List<string> { Path.GetFullPath(directoryPath) + Path.DirectorySeparatorChar };
            }
-
+
            var pathParts = directoryPath.Split(Path.DirectorySeparatorChar);

            var resolvedPaths = new List<string>();
@@ -863,7 +868,7 @@ private static IEnumerable<string> ExpandDirectories(string directoryPath)
         private static string ExpandHomeDirectoryIfNeeded(string directoryPath)
         {
             if (!directoryPath.Contains('~')) return directoryPath;
-
+
             var homePath = (Environment.OSVersion.Platform == PlatformID.Unix || Environment.OSVersion.Platform == PlatformID.MacOSX) ?
Environment.GetEnvironmentVariable("HOME") @@ -1036,7 +1041,7 @@ private async Task UploadFilesInSequentialAsync( { await updatePresignedUrlAsync(cancellationToken).ConfigureAwait(false); } - + // Break out of loop if file is successfully uploaded or already exists if (fileMetadata.resultStatus == ResultStatus.UPLOADED.ToString() || fileMetadata.resultStatus == ResultStatus.SKIPPED.ToString()) @@ -1429,7 +1434,7 @@ private void initFileMetadataForUpload() throw new ArgumentException("No file found for: " + TransferMetadata.src_locations[0].ToString()); } } - + private static bool IsDirectory(string path) { var attr = File.GetAttributes(path); diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs index 9e588e921..f56baf2fa 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFGCSClient.cs @@ -240,11 +240,9 @@ internal string generateFileURL(string stageLocation, string fileName) /// The encryption metadata for the header. public void UploadFile(SFFileMetadata fileMetadata, Stream fileBytesStream, SFEncryptionMetadata encryptionMetadata) { - String encryptionData = GetUploadEncryptionData(encryptionMetadata); - try { - WebRequest request = GetUploadFileRequest(fileMetadata, encryptionMetadata, encryptionData); + WebRequest request = GetUploadFileRequest(fileMetadata, encryptionMetadata); Stream dataStream = request.GetRequestStream(); fileBytesStream.Position = 0; @@ -271,11 +269,9 @@ public void UploadFile(SFFileMetadata fileMetadata, Stream fileBytesStream, SFEn /// The encryption metadata for the header. public async Task UploadFileAsync(SFFileMetadata fileMetadata, Stream fileByteStream, SFEncryptionMetadata encryptionMetadata, CancellationToken cancellationToken) { - String encryptionData = GetUploadEncryptionData(encryptionMetadata); - try { - WebRequest request = GetUploadFileRequest(fileMetadata, encryptionMetadata, encryptionData); + WebRequest request = GetUploadFileRequest(fileMetadata, encryptionMetadata); Stream dataStream = await request.GetRequestStreamAsync().ConfigureAwait(false); fileByteStream.Position = 0; @@ -294,14 +290,19 @@ public async Task UploadFileAsync(SFFileMetadata fileMetadata, Stream fileByteSt } } - private WebRequest GetUploadFileRequest(SFFileMetadata fileMetadata, SFEncryptionMetadata encryptionMetadata, String encryptionData) + private WebRequest GetUploadFileRequest(SFFileMetadata fileMetadata, SFEncryptionMetadata encryptionMetadata) { // Issue the POST/PUT request WebRequest request = _customWebRequest == null ? FormBaseRequest(fileMetadata, "PUT") : _customWebRequest; request.Headers.Add(GCS_METADATA_SFC_DIGEST, fileMetadata.sha256Digest); - request.Headers.Add(GCS_METADATA_MATDESC_KEY, encryptionMetadata.matDesc); - request.Headers.Add(GCS_METADATA_ENCRYPTIONDATAPROP, encryptionData); + if (fileMetadata.stageInfo.isClientSideEncrypted) + { + String encryptionData = GetUploadEncryptionData(ref fileMetadata, encryptionMetadata); + + request.Headers.Add(GCS_METADATA_MATDESC_KEY, encryptionMetadata.matDesc); + request.Headers.Add(GCS_METADATA_ENCRYPTIONDATAPROP, encryptionData); + } return request; } @@ -311,7 +312,7 @@ private WebRequest GetUploadFileRequest(SFFileMetadata fileMetadata, SFEncryptio /// /// The encryption metadata for the header. /// Stream content. 
- private String GetUploadEncryptionData(SFEncryptionMetadata encryptionMetadata) + private String GetUploadEncryptionData(ref SFFileMetadata fileMetadata, SFEncryptionMetadata encryptionMetadata) { // Create the encryption header value string encryptionData = JsonConvert.SerializeObject(new EncryptionData @@ -415,20 +416,23 @@ private void HandleDownloadResponse(HttpWebResponse response, SFFileMetadata fil WebHeaderCollection headers = response.Headers; // Get header values - dynamic encryptionData = JsonConvert.DeserializeObject(headers.Get(GCS_METADATA_ENCRYPTIONDATAPROP)); - string matDesc = headers.Get(GCS_METADATA_MATDESC_KEY); - - // Get encryption metadata from encryption data header value - SFEncryptionMetadata encryptionMetadata = null; - if (encryptionData != null) + var encryptionDataStr = headers.Get(GCS_METADATA_ENCRYPTIONDATAPROP); + if (encryptionDataStr != null) { - encryptionMetadata = new SFEncryptionMetadata + dynamic encryptionData = JsonConvert.DeserializeObject(encryptionDataStr); + string matDesc = headers.Get(GCS_METADATA_MATDESC_KEY); + + // Get encryption metadata from encryption data header value + if (encryptionData != null) { - iv = encryptionData["ContentEncryptionIV"], - key = encryptionData["WrappedContentKey"]["EncryptedKey"], - matDesc = matDesc - }; - fileMetadata.encryptionMetadata = encryptionMetadata; + SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata + { + iv = encryptionData["ContentEncryptionIV"], + key = encryptionData["WrappedContentKey"]["EncryptedKey"], + matDesc = matDesc + }; + fileMetadata.encryptionMetadata = encryptionMetadata; + } } fileMetadata.sha256Digest = headers.Get(GCS_METADATA_SFC_DIGEST); diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs index b6896cc79..ea0eb3fd0 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs @@ -422,10 +422,13 @@ private PutObjectRequest GetPutObjectRequest(ref AmazonS3Client client, SFFileMe ContentType = HTTP_HEADER_VALUE_OCTET_STREAM }; - // Populate the S3 Request Metadata - putObjectRequest.Metadata.Add(AMZ_META_PREFIX + AMZ_IV, encryptionMetadata.iv); - putObjectRequest.Metadata.Add(AMZ_META_PREFIX + AMZ_KEY, encryptionMetadata.key); - putObjectRequest.Metadata.Add(AMZ_META_PREFIX + AMZ_MATDESC, encryptionMetadata.matDesc); + if (stageInfo.isClientSideEncrypted) + { + // Populate the S3 Request Metadata + putObjectRequest.Metadata.Add(AMZ_META_PREFIX + AMZ_IV, encryptionMetadata.iv); + putObjectRequest.Metadata.Add(AMZ_META_PREFIX + AMZ_KEY, encryptionMetadata.key); + putObjectRequest.Metadata.Add(AMZ_META_PREFIX + AMZ_MATDESC, encryptionMetadata.matDesc); + } return putObjectRequest; } diff --git a/Snowflake.Data/Core/FileTransfer/StorageClient/SFSnowflakeAzureClient.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFSnowflakeAzureClient.cs index f0ad3f09e..98c2694cb 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFSnowflakeAzureClient.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFSnowflakeAzureClient.cs @@ -158,13 +158,17 @@ private FileHeader HandleFileHeaderResponse(ref SFFileMetadata fileMetadata, Blo { fileMetadata.resultStatus = ResultStatus.UPLOADED.ToString(); - dynamic encryptionData = JsonConvert.DeserializeObject(response.Metadata["encryptiondata"]); - SFEncryptionMetadata encryptionMetadata = new SFEncryptionMetadata + SFEncryptionMetadata encryptionMetadata = null; + 
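+                // Note: stages that use server-side encryption only (e.g. created with
+                // ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')) store no client-side encryption
+                // metadata, so the "encryptiondata" entry may be absent here.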
if (response.Metadata.TryGetValue("encryptiondata", out var encryptionDataStr)) { - iv = encryptionData["ContentEncryptionIV"], - key = encryptionData.WrappedContentKey["EncryptedKey"], - matDesc = response.Metadata["matdesc"] - }; + dynamic encryptionData = JsonConvert.DeserializeObject(encryptionDataStr); + encryptionMetadata = new SFEncryptionMetadata + { + iv = encryptionData["ContentEncryptionIV"], + key = encryptionData.WrappedContentKey["EncryptedKey"], + matDesc = response.Metadata["matdesc"] + }; + } return new FileHeader { @@ -242,31 +246,35 @@ public async Task UploadFileAsync(SFFileMetadata fileMetadata, Stream fileBytesS /// The encryption metadata for the header. private BlobClient GetUploadFileBlobClient(ref IDictionarymetadata, SFFileMetadata fileMetadata, SFEncryptionMetadata encryptionMetadata) { - // Create the JSON for the encryption data header - string encryptionData = JsonConvert.SerializeObject(new EncryptionData + if (fileMetadata.stageInfo.isClientSideEncrypted) { - EncryptionMode = "FullBlob", - WrappedContentKey = new WrappedContentInfo - { - KeyId = "symmKey1", - EncryptedKey = encryptionMetadata.key, - Algorithm = "AES_CBC_256" - }, - EncryptionAgent = new EncryptionAgentInfo + // Create the JSON for the encryption data header + string encryptionData = JsonConvert.SerializeObject(new EncryptionData { - Protocol = "1.0", - EncryptionAlgorithm = "AES_CBC_256" - }, - ContentEncryptionIV = encryptionMetadata.iv, - KeyWrappingMetadata = new KeyWrappingMetadataInfo - { - EncryptionLibrary = "Java 5.3.0" - } - }); + EncryptionMode = "FullBlob", + WrappedContentKey = new WrappedContentInfo + { + KeyId = "symmKey1", + EncryptedKey = encryptionMetadata.key, + Algorithm = "AES_CBC_256" + }, + EncryptionAgent = new EncryptionAgentInfo + { + Protocol = "1.0", + EncryptionAlgorithm = "AES_CBC_256" + }, + ContentEncryptionIV = encryptionMetadata.iv, + KeyWrappingMetadata = new KeyWrappingMetadataInfo + { + EncryptionLibrary = "Java 5.3.0" + } + }); + + // Create the metadata to use for the header + metadata.Add("encryptiondata", encryptionData); + metadata.Add("matdesc", encryptionMetadata.matDesc); + } - // Create the metadata to use for the header - metadata.Add("encryptiondata", encryptionData); - metadata.Add("matdesc", encryptionMetadata.matDesc); metadata.Add("sfcdigest", fileMetadata.sha256Digest); PutGetStageInfo stageInfo = fileMetadata.stageInfo; From d69e2cb42ebd0f076476b979dcd249a0470f4e71 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Mon, 4 Nov 2024 16:09:33 +0100 Subject: [PATCH 17/20] MINOR: Bumped up DotNet connector MINOR version from 4.1.0 to 4.2.0 (#1054) Co-authored-by: Jenkins User <900904> --- Snowflake.Data/Snowflake.Data.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Snowflake.Data/Snowflake.Data.csproj b/Snowflake.Data/Snowflake.Data.csproj index 35599c903..a0b09fade 100644 --- a/Snowflake.Data/Snowflake.Data.csproj +++ b/Snowflake.Data/Snowflake.Data.csproj @@ -11,7 +11,7 @@ Snowflake Computing, Inc Snowflake Connector for .NET Snowflake - 4.1.0 + 4.2.0 Full 7.3 From 43924472903467ad741f8d1ea3ee777c02e21829 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Tue, 5 Nov 2024 20:02:11 +0100 Subject: [PATCH 18/20] SNOW-1758373 Documentation for package signature verification (#1049) --- README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/README.md b/README.md index 408563f25..682475df3 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,22 @@ Alternatively, packages can also be 
downloaded using Package Manager Console: PM> Install-Package Snowflake.Data ``` +# Verifying the package signature + +Starting from version v4.2.0 the driver package is signed with a signature allowing to verify its authenticity and integrity. +Steps to verify the signature: +1. Install `cosign` +2. Download the driver package file (`.nupkg`) from nuget, e.g.: https://www.nuget.org/packages/Snowflake.Data/4.2.0 +3. Download the signatures file from the release, e.g.: https://github.com/snowflakedb/snowflake-connector-net/releases/tag/v4.2.0 +4. Verify the signature, e.g: +```shell +cosign verify-blob snowflake.data.4.2.0.nupkg \ +--key snowflake-connector-net-v4.2.0.pub \ +--signature Snowflake.Data.4.2.0.nupkg.sig + +Verified OK +``` + # Testing and Code Coverage [Running tests](doc/Testing.md) From f27eb2a84ca797ee435d68ac58b0779833ab2a53 Mon Sep 17 00:00:00 2001 From: Juan Martinez Ramirez <126511805+sfc-gh-jmartinezramirez@users.noreply.github.com> Date: Thu, 7 Nov 2024 08:42:01 -0600 Subject: [PATCH 19/20] SNOW-1444876: Support for TOML connections (#995) Co-authored-by: Krzysztof Nozderko --- .gitignore | 633 +++++++++--------- .../IntegrationTests/SFConnectionIT.cs | 23 +- .../SFConnectionWithTomlIT.cs | 143 ++++ Snowflake.Data.Tests/SFBaseTest.cs | 15 +- .../Snowflake.Data.Tests.csproj | 1 + .../EasyLoggingConfigFinderTest.cs | 30 +- .../UnitTests/SnowflakeDbConnectionTest.cs | 61 ++ .../SnowflakeTomlConnectionBuilderTest.cs | 514 ++++++++++++++ .../UnitTests/Tools/FileOperationsTest.cs | 110 +++ .../UnitTests/Tools/UnixOperationsTest.cs | 103 ++- .../Client/SnowflakeDbConnection.cs | 35 +- Snowflake.Data/Core/TomlConnectionBuilder.cs | 171 +++++ Snowflake.Data/Core/Tools/FileOperations.cs | 16 + Snowflake.Data/Core/Tools/UnixOperations.cs | 18 + Snowflake.Data/Snowflake.Data.csproj | 1 + ci/test.sh | 1 + doc/Connecting.md | 87 +++ snowflake-connector-net.sln.DotSettings | 3 + 18 files changed, 1602 insertions(+), 363 deletions(-) create mode 100644 Snowflake.Data.Tests/IntegrationTests/SFConnectionWithTomlIT.cs create mode 100644 Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionTest.cs create mode 100644 Snowflake.Data.Tests/UnitTests/SnowflakeTomlConnectionBuilderTest.cs create mode 100644 Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs create mode 100644 Snowflake.Data/Core/TomlConnectionBuilder.cs diff --git a/.gitignore b/.gitignore index 268c8f4dc..28192867b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,314 +1,319 @@ -## Ignore Visual Studio temporary files, build results, and -## files generated by popular Visual Studio add-ons. 
-## -## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore - -# User-specific files -*.suo -*.user -*.userosscache -*.sln.docstates - -# User-specific files (MonoDevelop/Xamarin Studio) -*.userprefs - -# Build results -[Dd]ebug/ -[Dd]ebugPublic/ -[Rr]elease/ -[Rr]eleases/ -x64/ -x86/ -bld/ -[Bb]in/ -[Oo]bj/ -[Ll]og/ - -# Visual Studio 2015 cache/options directory -.vs/ -# Uncomment if you have tasks that create the project's static files in wwwroot -#wwwroot/ - -# MSTest test Results -[Tt]est[Rr]esult*/ -[Bb]uild[Ll]og.* - -# NUNIT -*.VisualState.xml -TestResult.xml - -# Build Results of an ATL Project -[Dd]ebugPS/ -[Rr]eleasePS/ -dlldata.c - -# Benchmark Results -BenchmarkDotNet.Artifacts/ - -# .NET Core -project.lock.json -project.fragment.lock.json -artifacts/ -**/Properties/launchSettings.json - -*_i.c -*_p.c -*_i.h -*.ilk -*.meta -*.obj -*.pch -*.pdb -*.pgc -*.pgd -*.rsp -*.sbr -*.tlb -*.tli -*.tlh -*.tmp -*.tmp_proj -*.log -*.vspscc -*.vssscc -.builds -*.pidb -*.svclog -*.scc - -# Chutzpah Test files -_Chutzpah* - -# Visual C++ cache files -ipch/ -*.aps -*.ncb -*.opendb -*.opensdf -*.sdf -*.cachefile -*.VC.db -*.VC.VC.opendb - -# Visual Studio profiler -*.psess -*.vsp -*.vspx -*.sap - -# TFS 2012 Local Workspace -$tf/ - -# Guidance Automation Toolkit -*.gpState - -# ReSharper is a .NET coding add-in -_ReSharper*/ -*.[Rr]e[Ss]harper -*.DotSettings.user - -# JustCode is a .NET coding add-in -.JustCode - -# TeamCity is a build add-in -_TeamCity* - -# DotCover is a Code Coverage Tool -*.dotCover - -# Visual Studio code coverage results -*.coverage -*.coveragexml - -# NCrunch -_NCrunch_* -.*crunch*.local.xml -nCrunchTemp_* - -# MightyMoose -*.mm.* -AutoTest.Net/ - -# Web workbench (sass) -.sass-cache/ - -# Installshield output folder -[Ee]xpress/ - -# DocProject is a documentation generator add-in -DocProject/buildhelp/ -DocProject/Help/*.HxT -DocProject/Help/*.HxC -DocProject/Help/*.hhc -DocProject/Help/*.hhk -DocProject/Help/*.hhp -DocProject/Help/Html2 -DocProject/Help/html - -# Click-Once directory -publish/ - -# Publish Web Output -*.[Pp]ublish.xml -*.azurePubxml -# TODO: Comment the next line if you want to checkin your web deploy settings -# but database connection strings (with potential passwords) will be unencrypted -*.pubxml -*.publishproj - -# Microsoft Azure Web App publish settings. Comment the next line if you want to -# checkin your Azure Web App publish settings, but sensitive information contained -# in these scripts will be unencrypted -PublishScripts/ - -# NuGet Packages -*.nupkg -# The packages folder can be ignored because of Package Restore -**/packages/* -# except build/, which is used as an MSBuild target. 
-!**/packages/build/ -# Uncomment if necessary however generally it will be regenerated when needed -#!**/packages/repositories.config -# NuGet v3's project.json files produces more ignorable files -*.nuget.props -*.nuget.targets - -# Microsoft Azure Build Output -csx/ -*.build.csdef - -# Microsoft Azure Emulator -ecf/ -rcf/ - -# Windows Store app package directories and files -AppPackages/ -BundleArtifacts/ -Package.StoreAssociation.xml -_pkginfo.txt -*.appx - -# Visual Studio cache files -# files ending in .cache can be ignored -*.[Cc]ache -# but keep track of directories ending in .cache -!*.[Cc]ache/ - -# Others -ClientBin/ -~$* -*~ -*.dbmdl -*.dbproj.schemaview -*.jfm -*.pfx -*.publishsettings -orleans.codegen.cs - -# Since there are multiple workflows, uncomment next line to ignore bower_components -# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) -#bower_components/ - -# RIA/Silverlight projects -Generated_Code/ - -# Backup & report files from converting an old project file -# to a newer Visual Studio version. Backup files are not needed, -# because we have git ;-) -_UpgradeReport_Files/ -Backup*/ -UpgradeLog*.XML -UpgradeLog*.htm - -# SQL Server files -*.mdf -*.ldf -*.ndf - -# Business Intelligence projects -*.rdl.data -*.bim.layout -*.bim_*.settings - -# Microsoft Fakes -FakesAssemblies/ - -# GhostDoc plugin setting file -*.GhostDoc.xml - -# Node.js Tools for Visual Studio -.ntvs_analysis.dat -node_modules/ - -# Typescript v1 declaration files -typings/ - -# Visual Studio 6 build log -*.plg - -# Visual Studio 6 workspace options file -*.opt - -# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) -*.vbw - -# Visual Studio LightSwitch build output -**/*.HTMLClient/GeneratedArtifacts -**/*.DesktopClient/GeneratedArtifacts -**/*.DesktopClient/ModelManifest.xml -**/*.Server/GeneratedArtifacts -**/*.Server/ModelManifest.xml -_Pvt_Extensions - -# Paket dependency manager -.paket/paket.exe -paket-files/ - -# FAKE - F# Make -.fake/ - -# JetBrains Rider -.idea/ -*.sln.iml - -# CodeRush -.cr/ - -# Python Tools for Visual Studio (PTVS) -__pycache__/ -*.pyc - -# Cake - Uncomment if you are using it -# tools/** -# !tools/packages.config - -# Tabs Studio -*.tss - -# Telerik's JustMock configuration file -*.jmconfig - -# BizTalk build output -*.btp.cs -*.btm.cs -*.odx.cs -*.xsd.cs - -# Unencrypted file -Snowflake.Data.Tests/parameters.json -*.xml - -# WhiteSource -wss-*.config -wss-unified-agent.jar -whitesource/ -/testEnvironments.json -/parameters.json - -# Test performance reports -Snowflake.Data.Tests/macos_*_performance.csv -Snowflake.Data.Tests/windows_*_performance.csv -Snowflake.Data.Tests/unix_*_performance.csv - -# Ignore Mac files -**/.DS_Store \ No newline at end of file +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. 
+## +## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore + +# User-specific files +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ + +# Visual Studio 2015 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUNIT +*.VisualState.xml +TestResult.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ +**/Properties/launchSettings.json + +*_i.c +*_p.c +*_i.h +*.ilk +*.meta +*.obj +*.pch +*.pdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*.log +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# JustCode is a .NET coding add-in +.JustCode + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# TODO: Comment the next line if you want to checkin your web deploy settings +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# The packages folder can be ignored because of Package Restore +**/packages/* +# except build/, which is used as an MSBuild target. 
+!**/packages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/packages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Typescript v1 declaration files +typings/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# JetBrains Rider +.idea/ +*.sln.iml + +# CodeRush +.cr/ + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# Unencrypted file +Snowflake.Data.Tests/parameters.json +*.xml + +# WhiteSource +wss-*.config +wss-unified-agent.jar +whitesource/ + +# Test performance reports +Snowflake.Data.Tests/macos_*_performance.csv +Snowflake.Data.Tests/windows_*_performance.csv +Snowflake.Data.Tests/unix_*_performance.csv + +# Ignore Mac files +**/.DS_Store + +# Ignore config files +/testEnvironments.json +/parameters.json +parameters*.json +Snowflake.Data.Tests/toml_config_folder +*.toml \ No newline at end of file diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index 554d0c2a9..e3303bdee 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -2,25 +2,24 @@ * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
*/

+using System;
+using System.Data;
 using System.Data.Common;
+using System.Diagnostics;
 using System.Net;
+using System.Runtime.InteropServices;
+using System.Threading;
+using System.Threading.Tasks;
+using NUnit.Framework;
+using Snowflake.Data.Client;
+using Snowflake.Data.Core;
 using Snowflake.Data.Core.Session;
+using Snowflake.Data.Log;
+using Snowflake.Data.Tests.Mock;
 using Snowflake.Data.Tests.Util;

 namespace Snowflake.Data.Tests.IntegrationTests
 {
-    using NUnit.Framework;
-    using Snowflake.Data.Client;
-    using System.Data;
-    using System;
-    using Snowflake.Data.Core;
-    using System.Threading.Tasks;
-    using System.Threading;
-    using Snowflake.Data.Log;
-    using System.Diagnostics;
-    using Snowflake.Data.Tests.Mock;
-    using System.Runtime.InteropServices;
-    using System.Net.Http;

     [TestFixture]
     class SFConnectionIT : SFBaseTest
diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionWithTomlIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionWithTomlIT.cs
new file mode 100644
index 000000000..29d99744c
--- /dev/null
+++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionWithTomlIT.cs
@@ -0,0 +1,143 @@
+/*
+ * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
+ */
+
+using System;
+using System.Data;
+using System.IO;
+using System.Runtime.InteropServices;
+using Mono.Unix.Native;
+using NUnit.Framework;
+using Snowflake.Data.Client;
+using Snowflake.Data.Core;
+using Snowflake.Data.Log;
+using Tomlyn;
+using Tomlyn.Model;
+
+namespace Snowflake.Data.Tests.IntegrationTests
+{
+
+    [TestFixture, NonParallelizable]
+    class SFConnectionWithTomlIT : SFBaseTest
+    {
+        private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger<SFConnectionWithTomlIT>();
+
+        private static string s_workingDirectory;
+
+
+        [SetUp]
+        public new void BeforeTest()
+        {
+            s_workingDirectory ??= Path.Combine(TestContext.CurrentContext.WorkDirectory, "../../..", "toml_config_folder");
+            if (!Directory.Exists(s_workingDirectory))
+            {
+                Directory.CreateDirectory(s_workingDirectory);
+            }
+            CreateTomlConfigBaseOnConnectionString(ConnectionString);
+        }
+
+        [TearDown]
+        public new void AfterTest()
+        {
+            Directory.Delete(s_workingDirectory, true);
+        }
+
+        [Test]
+        public void TestLocalDefaultConnectStringReadFromToml()
+        {
+            var snowflakeHome = Environment.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome);
+            Environment.SetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome, s_workingDirectory);
+            try
+            {
+                using (var conn = new SnowflakeDbConnection())
+                {
+                    conn.Open();
+                    Assert.AreEqual(ConnectionState.Open, conn.State);
+                }
+            }
+            finally
+            {
+                Environment.SetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome, snowflakeHome);
+            }
+        }
+
+        [Test]
+        public void TestThrowExceptionIfTomlNotFoundWithOtherConnectionString()
+        {
+            var snowflakeHome = Environment.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome);
+            var connectionName = Environment.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName);
+            Environment.SetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome, s_workingDirectory);
+            Environment.SetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName, "notfoundconnection");
+            try
+            {
+                using (var conn = new SnowflakeDbConnection())
+                {
+                    Assert.Throws<SnowflakeDbException>(() => conn.Open(), "Unable to connect. Specified connection name does not exist in connections.toml");
+                }
+            }
+            finally
+            {
+                Environment.SetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome, snowflakeHome);
+                Environment.SetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName, connectionName);
+            }
+        }
+
+        [Test]
+        public void TestThrowExceptionIfTomlFromNotFoundFromDbConnection()
+        {
+            var snowflakeHome = Environment.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome);
+            Environment.SetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome, Path.Combine(s_workingDirectory, "InvalidFolder"));
+            try
+            {
+                using (var conn = new SnowflakeDbConnection())
+                {
+                    Assert.Throws<SnowflakeDbException>(() => conn.Open(), "Error: Required property ACCOUNT is not provided");
+                }
+            }
+            finally
+            {
+                Environment.SetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome, snowflakeHome);
+            }
+        }
+
+        private static void CreateTomlConfigBaseOnConnectionString(string connectionString)
+        {
+            var tomlModel = new TomlTable();
+            var properties = SFSessionProperties.ParseConnectionString(connectionString, null);
+
+            var defaultTomlTable = new TomlTable();
+            tomlModel.Add("default", defaultTomlTable);
+
+            foreach (var property in properties)
+            {
+                defaultTomlTable.Add(property.Key.ToString(), property.Value);
+            }
+
+            var filePath = Path.Combine(s_workingDirectory, "connections.toml");
+
+            if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+            {
+                using (var writer = File.CreateText(filePath))
+                {
+                    writer.Write(Toml.FromModel(tomlModel));
+                }
+            }
+            else
+            {
+                using (var writer = File.CreateText(filePath))
+                {
+                    writer.Write(string.Empty);
+                }
+                Syscall.chmod(filePath, FilePermissions.S_IRUSR | FilePermissions.S_IWUSR);
+                using (var writer = File.CreateText(filePath))
+                {
+                    writer.Write(Toml.FromModel(tomlModel));
+                }
+                Syscall.chmod(filePath, FilePermissions.S_IRUSR | FilePermissions.S_IWUSR);
+            }
+        }
+    }
+
+}
+
+
diff --git a/Snowflake.Data.Tests/SFBaseTest.cs b/Snowflake.Data.Tests/SFBaseTest.cs
index 1e8e13018..2784f0e25 100755
--- a/Snowflake.Data.Tests/SFBaseTest.cs
+++ b/Snowflake.Data.Tests/SFBaseTest.cs
@@ -421,10 +421,14 @@ public class IgnoreOnEnvIsAttribute : Attribute, ITestAction
         private readonly string _key;

         private readonly string[] _values;
-        public IgnoreOnEnvIsAttribute(string key, string[] values)
+
+        private readonly string _reason;
+
+        public IgnoreOnEnvIsAttribute(string key, string[] values, string reason = null)
         {
             _key = key;
             _values = values;
+            _reason = reason;
         }

         public void BeforeTest(ITest test)
@@ -433,7 +437,7 @@ public void BeforeTest(ITest test)
             {
                 if (Environment.GetEnvironmentVariable(_key) == value)
                 {
-                    Assert.Ignore("Test is ignored when environment variable {0} is {1} ", _key, value);
+                    Assert.Ignore("Test is ignored when environment variable {0} is {1}. 
{2}", _key, value, _reason); } } } @@ -468,4 +472,11 @@ public void AfterTest(ITest test) public ActionTargets Targets => ActionTargets.Test | ActionTargets.Suite; } + + public class IgnoreOnCI : IgnoreOnEnvIsAttribute + { + public IgnoreOnCI(string reason = null) : base("CI", new[] { "true" }, reason) + { + } + } } diff --git a/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj b/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj index cc895154e..86da12b20 100644 --- a/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj +++ b/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj @@ -19,6 +19,7 @@ + diff --git a/Snowflake.Data.Tests/UnitTests/Configuration/EasyLoggingConfigFinderTest.cs b/Snowflake.Data.Tests/UnitTests/Configuration/EasyLoggingConfigFinderTest.cs index b23fbbf0e..4b9e36d47 100644 --- a/Snowflake.Data.Tests/UnitTests/Configuration/EasyLoggingConfigFinderTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Configuration/EasyLoggingConfigFinderTest.cs @@ -45,7 +45,7 @@ public void Setup() MockHomeDirectory(); MockExecutionDirectory(); } - + [Test] public void TestThatTakesFilePathFromTheInput() { @@ -53,10 +53,10 @@ public void TestThatTakesFilePathFromTheInput() MockFileFromEnvironmentalVariable(); MockFileOnDriverPath(); MockFileOnHomePath(); - + // act var filePath = t_finder.FindConfigFilePath(InputConfigFilePath); - + // assert Assert.AreEqual(InputConfigFilePath, filePath); t_fileOperations.VerifyNoOtherCalls(); @@ -71,14 +71,14 @@ public void TestThatTakesFilePathFromEnvironmentVariableIfInputNotPresent( MockFileFromEnvironmentalVariable(); MockFileOnDriverPath(); MockFileOnHomePath(); - + // act var filePath = t_finder.FindConfigFilePath(inputFilePath); - + // assert Assert.AreEqual(EnvironmentalConfigFilePath, filePath); } - + [Test] public void TestThatTakesFilePathFromDriverLocationWhenNoInputParameterNorEnvironmentVariable() { @@ -88,20 +88,20 @@ public void TestThatTakesFilePathFromDriverLocationWhenNoInputParameterNorEnviro // act var filePath = t_finder.FindConfigFilePath(null); - + // assert Assert.AreEqual(s_driverConfigFilePath, filePath); } - + [Test] public void TestThatTakesFilePathFromHomeLocationWhenNoInputParamEnvironmentVarNorDriverLocation() { // arrange MockFileOnHomePath(); - + // act var filePath = t_finder.FindConfigFilePath(null); - + // assert Assert.AreEqual(s_homeConfigFilePath, filePath); } @@ -138,13 +138,13 @@ public void TestThatConfigFileIsNotUsedIfOthersCanModifyTheConfigFile() Assert.IsNotNull(thrown); Assert.AreEqual(thrown.Message, $"Error due to other users having permission to modify the config file: {s_homeConfigFilePath}"); } - + [Test] public void TestThatReturnsNullIfNoWayOfGettingTheFile() { // act var filePath = t_finder.FindConfigFilePath(null); - + // assert Assert.IsNull(filePath); } @@ -157,7 +157,7 @@ public void TestThatDoesNotFailWhenSearchForOneOfDirectoriesFails() // act var filePath = t_finder.FindConfigFilePath(null); - + // assert Assert.IsNull(filePath); t_environmentOperations.Verify(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile), Times.Once); @@ -186,7 +186,7 @@ public void TestThatDoesNotFailWhenHomeDirectoryDoesNotExist() // act var filePath = t_finder.FindConfigFilePath(null); - + // assert Assert.IsNull(filePath); t_environmentOperations.Verify(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile), Times.Once); @@ -220,7 +220,7 @@ private static void MockExecutionDirectory() .Setup(e => e.GetExecutionDirectory()) .Returns(DriverDirectory); } - + private static void MockFileOnHomePathDoesNotExist() { 
t_fileOperations diff --git a/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionTest.cs b/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionTest.cs new file mode 100644 index 000000000..18ba3539d --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionTest.cs @@ -0,0 +1,61 @@ + + +using System; +using System.IO; +using Mono.Unix; + +namespace Snowflake.Data.Tests.UnitTests +{ + using Core; + using Core.Tools; + using Moq; + using NUnit.Framework; + using Snowflake.Data.Client; + + public class SnowflakeDbConnectionTest + { + [Test] + public void TestFillConnectionStringFromTomlConfig() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.IsAny(), It.IsAny>())) + .Returns("[default]\naccount=\"testaccount\"\nuser=\"testuser\"\npassword=\"testpassword\"\n"); + var tomlConnectionBuilder = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + using (var conn = new SnowflakeDbConnection(tomlConnectionBuilder)) + { + conn.FillConnectionStringFromTomlConfigIfNotSet(); + // Assert + Assert.AreEqual("account=testaccount;user=testuser;password=testpassword;", conn.ConnectionString); + } + } + + [Test] + public void TestTomlConfigurationDoesNotOverrideExistingConnectionString() + { + // Arrange + var connectionTest = "account=user1account;user=user1;password=user1password;"; + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.IsAny())) + .Returns("[default]\naccount=\"testaccount\"\nuser=\"testuser\"\npassword=\"testpassword\"\n"); + var tomlConnectionBuilder = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + using (var conn = new SnowflakeDbConnection(tomlConnectionBuilder)) + { + conn.ConnectionString = connectionTest; + conn.FillConnectionStringFromTomlConfigIfNotSet(); + // Assert + Assert.AreEqual(connectionTest, conn.ConnectionString); + } + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/SnowflakeTomlConnectionBuilderTest.cs b/Snowflake.Data.Tests/UnitTests/SnowflakeTomlConnectionBuilderTest.cs new file mode 100644 index 000000000..24c2cb259 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/SnowflakeTomlConnectionBuilderTest.cs @@ -0,0 +1,514 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + +using Mono.Unix; +using Snowflake.Data.Client; + +namespace Snowflake.Data.Tests.UnitTests +{ + using System; + using System.IO; + using Moq; + using NUnit.Framework; + using Core.Tools; + using Snowflake.Data.Core; + + [TestFixture] + class TomlConnectionBuilderTest + { + private const string BasicTomlConfig = @" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""defaultpassword"" +[testconnection] +account = ""testaccountname"" +user = ""testusername"" +password = ""testpassword"" +[otherconnection] +account = ""otheraccountname"" +user = ""otherusername"" +password = ""otherpassword"""; + + [Test] + public void TestConnectionWithReadFromDefaultValuesInSnowflakeTomlConnectionBuilder() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(BasicTomlConfig); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual("account=defaultaccountname;user=defaultusername;password=defaultpassword;", connectionString); + } + + [Test] + public void TestConnectionFromCustomSnowflakeHome() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome)) + .Returns($"{Path.DirectorySeparatorChar}customsnowhome"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains("customsnowhome")), It.IsAny>())) + .Returns(BasicTomlConfig); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual("account=defaultaccountname;user=defaultusername;password=defaultpassword;", connectionString); + } + + [Test] + public void TestConnectionWithUserConnectionNameFromEnvVariable() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("testconnection"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(BasicTomlConfig); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual("account=testaccountname;user=testusername;password=testpassword;", connectionString); + } + + [Test] + public void 
TestConnectionWithUserConnectionNameFromEnvVariableWithMultipleConnections() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("otherconnection"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(BasicTomlConfig); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual("account=otheraccountname;user=otherusername;password=otherpassword;", connectionString); + } + + [Test] + public void TestConnectionWithUserConnectionName() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("otherconnection"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(BasicTomlConfig); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml("testconnection"); + + // Assert + Assert.AreEqual("account=testaccountname;user=testusername;password=testpassword;", connectionString); + } + + + [Test] + [TestCase("database = \"mydb\"", "DB=mydb;")] + public void TestConnectionMapPropertiesFromTomlKeyValues(string tomlKeyValue, string connectionStringValue) + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns($@" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""defaultpassword"" +{tomlKeyValue} +"); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual($"account=defaultaccountname;user=defaultusername;password=defaultpassword;{connectionStringValue}", connectionString); + } + + [Test] + public void TestConnectionConfigurationFileDoesNotExistsShouldReturnEmpty() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeHome)) + .Returns($"{Path.DirectorySeparatorChar}notexistenttestpath"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => 
f.Exists(It.IsAny())).Returns(false); + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual(string.Empty, connectionString); + } + + [Test] + public void TestConnectionWithInvalidConnectionName() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("wrongconnectionname"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(BasicTomlConfig); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act and assert + Assert.Throws(() => reader.GetConnectionStringFromToml(), "Specified connection name does not exist in connections.toml"); + } + + [Test] + public void TestConnectionWithNonExistingDefaultConnection() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns("[qa]\naccount = \"qaaccountname\"\nuser = \"qausername\"\npassword = \"qapassword\""); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual(string.Empty, connectionString); + } + + + [Test] + public void TestConnectionWithSpecifiedConnectionEmpty() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("testconnection1"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(@" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""defaultpassword"" +[testconnection1] +[testconnection2] +account = ""testaccountname"" +user = ""testusername"" +password = ""testpassword"""); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual(string.Empty, connectionString); + } + + [Test] + public void TestConnectionWithOauthAuthenticatorTokenFromFile() + { + // Arrange + var tokenFilePath = "/Users/testuser/token"; + var testToken = "token1234"; + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + 
.Returns("oauthconnection"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(tokenFilePath, It.IsAny>())).Returns(testToken); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(@$" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""defaultpassword"" +[oauthconnection] +account = ""testaccountname"" +authenticator = ""oauth"" +token_file_path = ""{tokenFilePath}"""); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual($"account=testaccountname;authenticator=oauth;token={testToken};", connectionString); + } + + [Test] + public void TestConnectionWithOauthAuthenticatorThrowsExceptionIfTokenFilePathNotExists() + { + // Arrange + var tokenFilePath = "/Users/testuser/token"; + var defaultToken = "defaultToken1234"; + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("oauthconnection"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(tokenFilePath)).Returns(false); + mockFileOperations.Setup(f => f.Exists(It.Is(p => !p.Equals(tokenFilePath)))).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(@$" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""defaultpassword"" +[oauthconnection] +account = ""testaccountname"" +authenticator = ""oauth"" +token_file_path = ""{tokenFilePath}"""); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains("/snowflake/session/token")), It.IsAny>())).Returns(defaultToken); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act and assert + var exception = Assert.Throws(() => reader.GetConnectionStringFromToml()); + Assert.IsTrue(exception.Message.StartsWith("Error: Invalid parameter value /Users/testuser/token for token_file_path")); + } + + [Test] + public void TestConnectionWithOauthAuthenticatorFromDefaultPathShouldBeLoadedIfTokenFilePathNotSpecified() + { + // Arrange + var defaultToken = "defaultToken1234"; + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("oauthconnection"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(@$" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""defaultpassword"" +[oauthconnection] +account = ""testaccountname"" +authenticator = ""oauth"""); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => 
p.Contains("/snowflake/session/token")), It.IsAny>())).Returns(defaultToken); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual($"account=testaccountname;authenticator=oauth;token={defaultToken};", connectionString); + } + + [Test] + public void TestConnectionWithOauthAuthenticatorShouldNotIncludeTokenIfNotStoredDefaultPath() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("oauthconnection"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.Is(p => p.Contains("/snowflake/session/token")))).Returns(false); + mockFileOperations.Setup(f => f.Exists(It.Is(p => !string.IsNullOrEmpty(p) && !p.Contains("/snowflake/session/token")))).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(@$" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""defaultpassword"" +[oauthconnection] +account = ""testaccountname"" +authenticator = ""oauth"""); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual($"account=testaccountname;authenticator=oauth;", connectionString); + } + + + [Test] + public void TestConnectionWithOauthAuthenticatorShouldNotLoadFromFileIsSpecifiedInTokenProperty() + { + // Arrange + var tokenFilePath = "/Users/testuser/token"; + var tokenFromToml = "tomlToken1234"; + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("oauthconnection"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(@$" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""defaultpassword"" +[oauthconnection] +account = ""testaccountname"" +authenticator = ""oauth"" +token = ""{tokenFromToml}"" +token_file_path = ""{tokenFilePath}"""); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual($"account=testaccountname;authenticator=oauth;token={tokenFromToml};", connectionString); + } + + [Test] + public void TestConnectionWithOauthAuthenticatorShouldNotIncludeTokenIfNullOrEmpty() + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations + .Setup(e => e.GetEnvironmentVariable(TomlConnectionBuilder.SnowflakeDefaultConnectionName)) + .Returns("oauthconnection"); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); 
+ mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns(@$" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""defaultpassword"" +[oauthconnection] +account = ""testaccountname"" +authenticator = ""oauth"""); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains("/snowflake/session/token")), It.IsAny>())).Returns(string.Empty); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + + // Assert + Assert.AreEqual($"account=testaccountname;authenticator=oauth;", connectionString); + } + + [Test] + [TestCase("\\\"password;default\\\"", "password;default")] + [TestCase("\\\"\\\"\\\"password;default\\\"", "\"password;default")] + [TestCase("p\\\"assworddefault", "p\"assworddefault")] + [TestCase("password\\\"default", "password\"default")] + [TestCase("password\'default", "password\'default")] + [TestCase("password=default", "password=default")] + [TestCase("\\\"pa=ss\\\"\\\"word;def\'ault\\\"", "pa=ss\"word;def\'ault")] + public void TestConnectionMapPropertiesWithSpecialCharacters(string passwordValueWithSpecialCharacter, string expectedValue) + { + // Arrange + var mockFileOperations = new Mock(); + var mockEnvironmentOperations = new Mock(); + mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns($"{Path.DirectorySeparatorChar}home"); + mockFileOperations.Setup(f => f.Exists(It.IsAny())).Returns(true); + mockFileOperations.Setup(f => f.ReadAllText(It.Is(p => p.Contains(".snowflake")), It.IsAny>())) + .Returns($@" +[default] +account = ""defaultaccountname"" +user = ""defaultusername"" +password = ""{passwordValueWithSpecialCharacter}"" +"); + + var reader = new TomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object); + + // Act + var connectionString = reader.GetConnectionStringFromToml(); + var properties = SFSessionProperties.ParseConnectionString(connectionString, null); + + // Assert + Assert.AreEqual(expectedValue, properties[SFSessionProperty.PASSWORD]); + } + } + +} diff --git a/Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs new file mode 100644 index 000000000..b8b311357 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + +using System.Collections.Generic; +using Snowflake.Data.Core; +using System.IO; +using System.Runtime.InteropServices; +using Mono.Unix; +using Mono.Unix.Native; +using NUnit.Framework; +using Snowflake.Data.Core.Tools; +using static Snowflake.Data.Tests.UnitTests.Configuration.EasyLoggingConfigGenerator; +using System.Security; + +namespace Snowflake.Data.Tests.Tools +{ + [TestFixture, NonParallelizable] + public class FileOperationsTest + { + private static FileOperations s_fileOperations; + private static readonly string s_workingDirectory = Path.Combine(Path.GetTempPath(), "file_operations_test_", Path.GetRandomFileName()); + + [OneTimeSetUp] + public static void BeforeAll() + { + if (!Directory.Exists(s_workingDirectory)) + { + Directory.CreateDirectory(s_workingDirectory); + } + + s_fileOperations = new FileOperations(); + } + + [OneTimeTearDown] + public static void AfterAll() + { + Directory.Delete(s_workingDirectory, true); + } + + [Test] + public void TestReadAllTextOnWindows() + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test only runs on Windows"); + } + + var content = "random text"; + var filePath = CreateConfigTempFile(s_workingDirectory, content); + + // act + var result = s_fileOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions); + + // assert + Assert.AreEqual(content, result); + } + + [Test] + public void TestReadAllTextCheckingPermissionsUsingTomlConfigurationFileValidations( + [ValueSource(nameof(UserAllowedFilePermissions))] + FileAccessPermissions userAllowedFilePermissions) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + var content = "random text"; + var filePath = CreateConfigTempFile(s_workingDirectory, content); + var filePermissions = userAllowedFilePermissions; + + Syscall.chmod(filePath, (FilePermissions)filePermissions); + + // act + var result = s_fileOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions); + + // assert + Assert.AreEqual(content, result); + } + + [Test] + public void TestShouldThrowExceptionIfOtherPermissionsIsSetWhenReadConfigurationFile( + [ValueSource(nameof(UserAllowedFilePermissions))] + FileAccessPermissions userAllowedFilePermissions) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + var content = "random text"; + var filePath = CreateConfigTempFile(s_workingDirectory, content); + var filePermissions = userAllowedFilePermissions | FileAccessPermissions.OtherReadWriteExecute; + + Syscall.chmod(filePath, (FilePermissions)filePermissions); + + // act and assert + Assert.Throws(() => s_fileOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions), + "Attempting to read a file with too broad permissions assigned"); + } + + + public static IEnumerable UserAllowedFilePermissions() + { + yield return FileAccessPermissions.UserRead; + yield return FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite; + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs index fde51602c..14e2df121 100644 --- a/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs @@ -1,9 +1,11 @@ using System.Collections.Generic; using System.IO; using System.Runtime.InteropServices; +using System.Security; using Mono.Unix; using Mono.Unix.Native; 
using NUnit.Framework; +using Snowflake.Data.Core; using Snowflake.Data.Core.Tools; using static Snowflake.Data.Tests.UnitTests.Configuration.EasyLoggingConfigGenerator; @@ -14,7 +16,7 @@ public class UnixOperationsTest { private static UnixOperations s_unixOperations; private static readonly string s_workingDirectory = Path.Combine(Path.GetTempPath(), "easy_logging_test_configs_", Path.GetRandomFileName()); - + [OneTimeSetUp] public static void BeforeAll() { @@ -34,7 +36,7 @@ public static void AfterAll() return; Directory.Delete(s_workingDirectory, true); } - + [Test] public void TestDetectGroupOrOthersWritablePermissions( [ValueSource(nameof(GroupOrOthersWritablePermissions))] FilePermissions groupOrOthersWritablePermissions, @@ -45,23 +47,23 @@ public void TestDetectGroupOrOthersWritablePermissions( { Assert.Ignore("skip test on Windows"); } - + // arrange var filePath = CreateConfigTempFile(s_workingDirectory, "random text"); var readWriteUserPermissions = FilePermissions.S_IRUSR | FilePermissions.S_IWUSR; var filePermissions = readWriteUserPermissions | groupOrOthersWritablePermissions | groupNotWritablePermissions | otherNotWritablePermissions; Syscall.chmod(filePath, filePermissions); - + // act var result = s_unixOperations.CheckFileHasAnyOfPermissions(filePath, FileAccessPermissions.GroupWrite | FileAccessPermissions.OtherWrite); - + // assert Assert.IsTrue(result); } [Test] public void TestDetectGroupOrOthersNotWritablePermissions( - [ValueSource(nameof(UserPermissions))] FilePermissions userPermissions, + [ValueSource(nameof(UserPermissions))] FilePermissions userPermissions, [ValueSource(nameof(GroupNotWritablePermissions))] FilePermissions groupNotWritablePermissions, [ValueSource(nameof(OtherNotWritablePermissions))] FilePermissions otherNotWritablePermissions) { @@ -69,18 +71,60 @@ public void TestDetectGroupOrOthersNotWritablePermissions( { Assert.Ignore("skip test on Windows"); } - + var filePath = CreateConfigTempFile(s_workingDirectory, "random text"); var filePermissions = userPermissions | groupNotWritablePermissions | otherNotWritablePermissions; Syscall.chmod(filePath, filePermissions); - + // act var result = s_unixOperations.CheckFileHasAnyOfPermissions(filePath, FileAccessPermissions.GroupWrite | FileAccessPermissions.OtherWrite); - + // assert Assert.IsFalse(result); } + [Test] + public void TestReadAllTextCheckingPermissionsUsingTomlConfigurationFileValidations( + [ValueSource(nameof(UserAllowedPermissions))] FilePermissions userAllowedPermissions) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + var content = "random text"; + var filePath = CreateConfigTempFile(s_workingDirectory, content); + Syscall.chmod(filePath, userAllowedPermissions); + + // act + var result = s_unixOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions); + + // assert + Assert.AreEqual(content, result); + } + + [Test] + public void TestFailIfGroupOrOthersHavePermissionsToFileWithTomlConfigurationValidations([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions, + [ValueSource(nameof(GroupPermissions))] FilePermissions groupPermissions, + [ValueSource(nameof(OthersPermissions))] FilePermissions othersPermissions) + { + if(groupPermissions == 0 && othersPermissions == 0) + { + Assert.Ignore("Skip test when group and others have no permissions"); + } + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + var content = 
"random text"; + var filePath = CreateConfigTempFile(s_workingDirectory, content); + + var filePermissions = userPermissions | groupPermissions | othersPermissions; + Syscall.chmod(filePath, filePermissions); + + // act and assert + Assert.Throws(() => s_unixOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions), "Attempting to read a file with too broad permissions assigned"); + } public static IEnumerable UserPermissions() { @@ -89,14 +133,32 @@ public static IEnumerable UserPermissions() yield return FilePermissions.S_IXUSR; yield return FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR; } - + + public static IEnumerable GroupPermissions() + { + yield return 0; + yield return FilePermissions.S_IRGRP; + yield return FilePermissions.S_IWGRP; + yield return FilePermissions.S_IXGRP; + yield return FilePermissions.S_IRGRP | FilePermissions.S_IWGRP | FilePermissions.S_IXGRP; + } + + public static IEnumerable OthersPermissions() + { + yield return 0; + yield return FilePermissions.S_IROTH; + yield return FilePermissions.S_IWOTH; + yield return FilePermissions.S_IXOTH; + yield return FilePermissions.S_IROTH | FilePermissions.S_IWOTH | FilePermissions.S_IXOTH; + } + public static IEnumerable GroupOrOthersWritablePermissions() { yield return FilePermissions.S_IWGRP; yield return FilePermissions.S_IWOTH; yield return FilePermissions.S_IWGRP | FilePermissions.S_IWOTH; } - + public static IEnumerable GroupNotWritablePermissions() { yield return 0; @@ -112,5 +174,24 @@ public static IEnumerable OtherNotWritablePermissions() yield return FilePermissions.S_IXOTH; yield return FilePermissions.S_IROTH | FilePermissions.S_IXOTH; } + + public static IEnumerable UserReadWritePermissions() + { + yield return FilePermissions.S_IRUSR | FilePermissions.S_IWUSR; + } + + public static IEnumerable UserAllowedPermissions() + { + yield return FilePermissions.S_IRUSR; + yield return FilePermissions.S_IRUSR | FilePermissions.S_IWUSR; + } + + public static IEnumerable GroupOrOthersReadablePermissions() + { + yield return 0; + yield return FilePermissions.S_IRGRP; + yield return FilePermissions.S_IROTH; + yield return FilePermissions.S_IRGRP | FilePermissions.S_IROTH; + } } } diff --git a/Snowflake.Data/Client/SnowflakeDbConnection.cs b/Snowflake.Data/Client/SnowflakeDbConnection.cs index 9acb24f06..70fa642ea 100755 --- a/Snowflake.Data/Client/SnowflakeDbConnection.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnection.cs @@ -3,12 +3,12 @@ */ using System; +using System.Data; using System.Data.Common; -using Snowflake.Data.Core; using System.Security; -using System.Threading.Tasks; -using System.Data; using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Core; using Snowflake.Data.Log; namespace Snowflake.Data.Client @@ -37,6 +37,8 @@ public class SnowflakeDbConnection : DbConnection // Will fix that in a separated PR though as it's a different issue private static Boolean _isArrayBindStageCreated; + private readonly TomlConnectionBuilder _tomlConnectionBuilder; + protected enum TransactionRollbackStatus { Undefined, // used to indicate ignored transaction status when pool disabled @@ -44,8 +46,18 @@ protected enum TransactionRollbackStatus Failure } - public SnowflakeDbConnection() + public SnowflakeDbConnection() : this(TomlConnectionBuilder.Instance) { + } + + public SnowflakeDbConnection(string connectionString) : this() + { + ConnectionString = connectionString; + } + + internal SnowflakeDbConnection(TomlConnectionBuilder tomlConnectionBuilder) 
+ { + _tomlConnectionBuilder = tomlConnectionBuilder; _connectionState = ConnectionState.Closed; _connectionTimeout = int.Parse(SFSessionProperty.CONNECTION_TIMEOUT.GetAttribute(). @@ -54,11 +66,6 @@ public SnowflakeDbConnection() ExplicitTransaction = null; } - public SnowflakeDbConnection(string connectionString) : this() - { - ConnectionString = connectionString; - } - public override string ConnectionString { get; set; @@ -268,6 +275,7 @@ public override void Open() } try { + FillConnectionStringFromTomlConfigIfNotSet(); OnSessionConnecting(); SfSession = SnowflakeDbConnectionPool.GetSession(ConnectionString, Password); if (SfSession == null) @@ -292,6 +300,14 @@ public override void Open() } } + internal void FillConnectionStringFromTomlConfigIfNotSet() + { + if (string.IsNullOrEmpty(ConnectionString)) + { + ConnectionString = _tomlConnectionBuilder.GetConnectionStringFromToml(); + } + } + public override Task OpenAsync(CancellationToken cancellationToken) { logger.Debug("Open Connection Async."); @@ -302,6 +318,7 @@ public override Task OpenAsync(CancellationToken cancellationToken) } registerConnectionCancellationCallback(cancellationToken); OnSessionConnecting(); + FillConnectionStringFromTomlConfigIfNotSet(); return SnowflakeDbConnectionPool .GetSessionAsync(ConnectionString, Password, cancellationToken) .ContinueWith(previousTask => diff --git a/Snowflake.Data/Core/TomlConnectionBuilder.cs b/Snowflake.Data/Core/TomlConnectionBuilder.cs new file mode 100644 index 000000000..a8c2396b1 --- /dev/null +++ b/Snowflake.Data/Core/TomlConnectionBuilder.cs @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Security; +using System.Text; +using Mono.Unix; +using Mono.Unix.Native; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; +using Tomlyn; +using Tomlyn.Model; + +namespace Snowflake.Data.Core +{ + internal class TomlConnectionBuilder + { + private const string DefaultConnectionName = "default"; + private const string DefaultSnowflakeFolder = ".snowflake"; + private const string DefaultTokenPath = "/snowflake/session/token"; + + internal const string SnowflakeDefaultConnectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME"; + internal const string SnowflakeHome = "SNOWFLAKE_HOME"; + + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private readonly Dictionary _tomlToNetPropertiesMapper = new Dictionary(StringComparer.InvariantCultureIgnoreCase) + { + { "DATABASE", "DB" } + }; + + private readonly FileOperations _fileOperations; + private readonly EnvironmentOperations _environmentOperations; + + public static readonly TomlConnectionBuilder Instance = new TomlConnectionBuilder(); + + private TomlConnectionBuilder() : this(FileOperations.Instance, EnvironmentOperations.Instance) + { + } + + internal TomlConnectionBuilder(FileOperations fileOperations, EnvironmentOperations environmentOperations) + { + _fileOperations = fileOperations; + _environmentOperations = environmentOperations; + } + + public string GetConnectionStringFromToml(string connectionName = null) + { + var tomlPath = ResolveConnectionTomlFile(); + var connectionToml = GetTomlTableFromConfig(tomlPath, connectionName); + s_logger.Info($"Reading connection parameters from file using key: {connectionName} and path: {tomlPath}"); + return connectionToml == null ? 
string.Empty : GetConnectionStringFromTomlTable(connectionToml);
+        }
+
+        private string GetConnectionStringFromTomlTable(TomlTable connectionToml)
+        {
+            var connectionStringBuilder = new StringBuilder();
+            var tokenFilePathValue = string.Empty;
+            var isOauth = connectionToml.TryGetValue("authenticator", out var authenticator) && authenticator.ToString().Equals("oauth");
+            foreach (var property in connectionToml.Keys)
+            {
+                var propertyValue = (string)connectionToml[property];
+                if (isOauth && property.Equals("token_file_path", StringComparison.InvariantCultureIgnoreCase))
+                {
+                    tokenFilePathValue = propertyValue;
+                    continue;
+                }
+                var mappedProperty = _tomlToNetPropertiesMapper.TryGetValue(property, out var mapped) ? mapped : property;
+                connectionStringBuilder.Append($"{mappedProperty}={propertyValue};");
+            }
+
+            AppendTokenFromFileIfNotGivenExplicitly(connectionToml, isOauth, connectionStringBuilder, tokenFilePathValue);
+            return connectionStringBuilder.ToString();
+        }
+
+        private void AppendTokenFromFileIfNotGivenExplicitly(TomlTable connectionToml, bool isOauth,
+            StringBuilder connectionStringBuilder, string tokenFilePathValue)
+        {
+            if (!isOauth || connectionToml.ContainsKey("token"))
+            {
+                return;
+            }
+
+            s_logger.Info($"Trying to load token from file {tokenFilePathValue}");
+            var token = LoadTokenFromFile(tokenFilePathValue);
+            if (!string.IsNullOrEmpty(token))
+            {
+                connectionStringBuilder.Append($"token={token};");
+            }
+            else
+            {
+                s_logger.Warn("The token has an empty value");
+            }
+        }
+
+        private string LoadTokenFromFile(string tokenFilePathValue)
+        {
+            string tokenFile;
+            if (string.IsNullOrEmpty(tokenFilePathValue))
+            {
+                tokenFile = DefaultTokenPath;
+            }
+            else
+            {
+                if (!_fileOperations.Exists(tokenFilePathValue))
+                {
+                    s_logger.Info($"Specified token file {tokenFilePathValue} does not exist.");
+                    throw new SnowflakeDbException(SFError.INVALID_CONNECTION_PARAMETER_VALUE, tokenFilePathValue, "token_file_path");
+                }
+
+                tokenFile = tokenFilePathValue;
+            }
+            s_logger.Info($"Reading token from file path: {tokenFile}");
+            return _fileOperations.Exists(tokenFile) ? _fileOperations.ReadAllText(tokenFile, ValidateFilePermissions) : null;
+        }
+
+        private TomlTable GetTomlTableFromConfig(string tomlPath, string connectionName)
+        {
+            if (!_fileOperations.Exists(tomlPath))
+            {
+                return null;
+            }
+
+            var tomlContent = _fileOperations.ReadAllText(tomlPath, ValidateFilePermissions) ?? string.Empty;
+            var toml = Toml.ToModel(tomlContent);
+            if (string.IsNullOrEmpty(connectionName))
+            {
+                connectionName = _environmentOperations.GetEnvironmentVariable(SnowflakeDefaultConnectionName) ?? DefaultConnectionName;
+            }
+
+            var connectionExists = toml.TryGetValue(connectionName, out var connection);
+            // Do not treat a missing default connection as an error: the user may not intend to use TOML configuration
+            // and simply forgot to provide a connection string. That error is raised later, when the undefined
+            // connection string is used.
+            if (!connectionExists && connectionName != DefaultConnectionName)
+            {
+                throw new Exception("Specified connection name does not exist in connections.toml");
+            }
+
+            var result = connection as TomlTable;
+            return result;
+        }
+
+        private string ResolveConnectionTomlFile()
+        {
+            var defaultDirectory = Path.Combine(HomeDirectoryProvider.HomeDirectory(_environmentOperations), DefaultSnowflakeFolder);
+            var tomlFolder = _environmentOperations.GetEnvironmentVariable(SnowflakeHome) ??
defaultDirectory;
+            var tomlPath = Path.Combine(tomlFolder, "connections.toml");
+            return tomlPath;
+        }
+
+        internal static void ValidateFilePermissions(UnixStream stream)
+        {
+            var allowedPermissions = new FileAccessPermissions[]
+            {
+                FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite,
+                FileAccessPermissions.UserRead
+            };
+            if (stream.OwnerUser.UserId != Syscall.geteuid())
+                throw new SecurityException("Attempting to read a file not owned by the effective user of the current process");
+            if (stream.OwnerGroup.GroupId != Syscall.getegid())
+                throw new SecurityException("Attempting to read a file not owned by the effective group of the current process");
+            if (!allowedPermissions.Any(a => stream.FileAccessPermissions == a))
+                throw new SecurityException("Attempting to read a file with too broad permissions assigned");
+        }
+    }
+}
diff --git a/Snowflake.Data/Core/Tools/FileOperations.cs b/Snowflake.Data/Core/Tools/FileOperations.cs
index 9efe481bd..577bd54ee 100644
--- a/Snowflake.Data/Core/Tools/FileOperations.cs
+++ b/Snowflake.Data/Core/Tools/FileOperations.cs
@@ -2,17 +2,33 @@
  * Copyright (c) 2023 Snowflake Computing Inc. All rights reserved.
  */

+using System;
 using System.IO;
+using System.Runtime.InteropServices;
+using Mono.Unix;

 namespace Snowflake.Data.Core.Tools
 {
+
     internal class FileOperations
     {
         public static readonly FileOperations Instance = new FileOperations();
+        private readonly UnixOperations _unixOperations = UnixOperations.Instance;

         public virtual bool Exists(string path)
         {
             return File.Exists(path);
         }
+
+        public virtual string ReadAllText(string path)
+        {
+            return ReadAllText(path, null);
+        }
+
+        public virtual string ReadAllText(string path, Action<UnixStream> validator)
+        {
+            var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || validator == null ? File.ReadAllText(path) : _unixOperations.ReadAllText(path, validator);
+            return contentFile;
+        }
     }
 }
diff --git a/Snowflake.Data/Core/Tools/UnixOperations.cs b/Snowflake.Data/Core/Tools/UnixOperations.cs
index cb44099b7..655b708ea 100644
--- a/Snowflake.Data/Core/Tools/UnixOperations.cs
+++ b/Snowflake.Data/Core/Tools/UnixOperations.cs
@@ -2,6 +2,10 @@
  * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
 */

+using System;
+using System.IO;
+using System.Security;
+using System.Text;
 using Mono.Unix;
 using Mono.Unix.Native;

@@ -27,5 +31,19 @@ public virtual bool CheckFileHasAnyOfPermissions(string path, FileAccessPermissi
         var fileInfo = new UnixFileInfo(path);
         return (permissions & fileInfo.FileAccessPermissions) != 0;
     }
+
+        public string ReadAllText(string path, Action<UnixStream> validator)
+        {
+            var fileInfo = new UnixFileInfo(path: path);
+
+            using (var handle = fileInfo.OpenRead())
+            {
+                validator?.Invoke(handle);
+                using (var streamReader = new StreamReader(handle, Encoding.UTF8))
+                {
+                    return streamReader.ReadToEnd();
+                }
+            }
+        }
     }
 }
diff --git a/Snowflake.Data/Snowflake.Data.csproj b/Snowflake.Data/Snowflake.Data.csproj
index a0b09fade..f17124419 100644
--- a/Snowflake.Data/Snowflake.Data.csproj
+++ b/Snowflake.Data/Snowflake.Data.csproj
@@ -28,6 +28,7 @@
+
diff --git a/ci/test.sh b/ci/test.sh
index b8ee8aec0..aaa2aa51b 100755
--- a/ci/test.sh
+++ b/ci/test.sh
@@ -48,6 +48,7 @@ for name in "${!TARGET_TEST_IMAGES[@]}"; do
         -e RUNNER_TRACKING_ID \
         -e JOB_NAME \
         -e BUILD_NUMBER \
+        -e CI \
         ${TEST_IMAGE_NAMES[$name]} \
         /mnt/host/ci/container/test_component.sh
     echo "[INFO] Test Results: $WORKSPACE/junit-dotnet.xml"
diff --git a/doc/Connecting.md b/doc/Connecting.md
index 576120f79..0999d6a58 100644
--- a/doc/Connecting.md
+++ b/doc/Connecting.md
@@ -295,3 +295,90 @@ Examples:
 - `myaccount.snowflakecomputing.com` (Not bypassed).
 - `*myaccount.snowflakecomputing.com` (Bypassed).
+### Snowflake credentials using a configuration file
+
+The .NET driver allows you to define connections in a configuration file. A connection defined this way can use any connection parameter the driver supports, and those values are used to build the connection string.
+
+The .NET driver looks for `connections.toml` in the following locations, in order:
+
+* The folder pointed to by the `SNOWFLAKE_HOME` environment variable. Set this variable to use a different location.
+* Otherwise, the `connections.toml` file in the `.snowflake` subfolder of the home directory, that is, depending on your operating system:
+  * MacOS/Linux: `~/.snowflake/connections.toml`
+  * Windows: `%USERPROFILE%\.snowflake\connections.toml`
+
+On MacOS and Linux, the .NET driver requires the `connections.toml` file to be readable and writable by the file owner only. To set the required file permissions, execute the following commands:
+
+``` bash
+chown $USER connections.toml
+chmod 0600 connections.toml
+```
+
+To use this mechanism in C#, do not set a connection string; the driver then reads the connection definition from the configuration file.
+
+``` toml
+[myconnection]
+account = "myaccount"
+user = "jdoe"
+password = "xyz1234"
+```
+
+```cs
+using (IDbConnection conn = new SnowflakeDbConnection())
+{
+    conn.Open(); // Reads connection definition from configuration file.
+
+    conn.Close();
+}
+```
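+
+If you want the driver to read `connections.toml` from a specific folder, you can set `SNOWFLAKE_HOME` before opening the connection. This is a minimal sketch; the `/etc/snowflake` path below is an assumed example, not a driver default:
+
+```cs
+// Assumed example path; any folder containing a connections.toml file works.
+Environment.SetEnvironmentVariable("SNOWFLAKE_HOME", "/etc/snowflake");
+
+using (IDbConnection conn = new SnowflakeDbConnection())
+{
+    conn.Open(); // Reads /etc/snowflake/connections.toml
+    conn.Close();
+}
+```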
+
+By default, the driver uses the connection named `default`. You can change this by setting the SNOWFLAKE_DEFAULT_CONNECTION_NAME environment variable, as shown:
+
+```bash
+export SNOWFLAKE_DEFAULT_CONNECTION_NAME=my_prod_connection
+```
+
+For example, to use the `[myconnection]` definition shown above, set the variable to `myconnection`.
+
+The following examples show how to include different types of special characters in a TOML key-value pair string:
+
+- To include a single quote (') character:
+
+  ```toml
+  [default]
+  host = "fakeaccount.snowflakecomputing.com"
+  user = "fakeuser"
+  password = "fake\'password"
+  ```
+
+- To include a double quote (") character:
+
+  ```toml
+  [default]
+  host = "fakeaccount.snowflakecomputing.com"
+  user = "fakeuser"
+  password = "fake\"password"
+  ```
+
+  - If the double quote appears in a value that also contains characters requiring the whole value to be wrapped in double quotes (for example `;`), escape each literal double quote as \"\":
+
+  ```toml
+  [default]
+  host = "fakeaccount.snowflakecomputing.com"
+  user = "fakeuser"
+  password = "\";fake\"\"password\""
+  ```
+
+- To include a semicolon (;):
+
+  ```toml
+  [default]
+  host = "fakeaccount.snowflakecomputing.com"
+  user = "fakeuser"
+  password = "\";fakepassword\""
+  ```
+
+- To include an equal sign (=):
+
+  ```toml
+  [default]
+  host = "fakeaccount.snowflakecomputing.com"
+  user = "fakeuser"
+  password = "fake=password"
+  ```
diff --git a/snowflake-connector-net.sln.DotSettings b/snowflake-connector-net.sln.DotSettings
index b3644095f..fd86da92d 100644
--- a/snowflake-connector-net.sln.DotSettings
+++ b/snowflake-connector-net.sln.DotSettings
@@ -1,5 +1,8 @@
 
+	False
+	True
 	True
+	CI
 	False
 	<Policy Inspect="True" Prefix="" Suffix="" Style="AaBb" />
 	<Policy Inspect="True" Prefix="s_" Suffix="" Style="aaBb"><ExtraRule Prefix="t_" Suffix="" Style="aaBb" /></Policy>

From 05da00a935fbaf42f25c2ddb98cbf999c5909ee2 Mon Sep 17 00:00:00 2001
From: Krzysztof Nozderko
Date: Fri, 8 Nov 2024 11:10:58 +0100
Subject: [PATCH 20/20] SNOW-1524245 Set initialisation vector length for Gcm
 encryption to 12 bytes (#1056)

---
 .../UnitTests/GcmEncryptionProviderTest.cs      |  2 +-
 .../Core/FileTransfer/GcmEncryptionProvider.cs  | 13 +++++++------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/Snowflake.Data.Tests/UnitTests/GcmEncryptionProviderTest.cs b/Snowflake.Data.Tests/UnitTests/GcmEncryptionProviderTest.cs
index 60c0c2059..53bc7c27e 100644
--- a/Snowflake.Data.Tests/UnitTests/GcmEncryptionProviderTest.cs
+++ b/Snowflake.Data.Tests/UnitTests/GcmEncryptionProviderTest.cs
@@ -14,7 +14,7 @@ public class GcmEncryptionProviderTest
 {
     private const string PlainText = "there is no rose without thorns";
     private static readonly byte[] s_plainTextBytes = Encoding.UTF8.GetBytes(PlainText);
-    private static readonly byte[] s_qsmkBytes = TestDataGenarator.NextBytes(GcmEncryptionProvider.BlockSizeInBytes);
+    private static readonly byte[] s_qsmkBytes = TestDataGenarator.NextBytes(GcmEncryptionProvider.TagSizeInBytes);
     private static readonly string s_qsmk = Convert.ToBase64String(s_qsmkBytes);
     private static readonly string s_queryId = Guid.NewGuid().ToString();
     private const long SmkId = 1234L;
diff --git a/Snowflake.Data/Core/FileTransfer/GcmEncryptionProvider.cs b/Snowflake.Data/Core/FileTransfer/GcmEncryptionProvider.cs
index 50b80dd05..b7ad2cda0 100644
--- a/Snowflake.Data/Core/FileTransfer/GcmEncryptionProvider.cs
+++ b/Snowflake.Data/Core/FileTransfer/GcmEncryptionProvider.cs
@@ -10,8 +10,9 @@ namespace Snowflake.Data.Core.FileTransfer
 {
     internal class GcmEncryptionProvider
     {
-        private const int AesBlockSize = 128;
-        internal const int BlockSizeInBytes = AesBlockSize / 8;
+        private const int TagSizeInBits = 128;
+        internal const int TagSizeInBytes = TagSizeInBits / 8;
+        private const int InitVectorSizeInBytes = 12;
         private const string AesGcmNoPaddingCipher = "AES/GCM/NoPadding";

         private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger<GcmEncryptionProvider>();
@@ -57,8 +58,8 @@ public static Stream Encrypt(
             int masterKeySize = decodedMasterKey.Length;
             s_logger.Debug($"Master key size : {masterKeySize}");

-            var contentIV = new byte[BlockSizeInBytes];
-            var keyIV = new byte[BlockSizeInBytes];
+            var contentIV = new byte[InitVectorSizeInBytes];
+            var keyIV = new byte[InitVectorSizeInBytes];
             var fileKeyBytes = new byte[masterKeySize]; // choose a random file key and encrypt it with the QSMK using GCM
             s_random.NextBytes(contentIV);
             s_random.NextBytes(keyIV);
@@ -179,8 +180,8 @@ private static IBufferedCipher BuildAesGcmNoPaddingCipher(bool forEncryption, by
             var cipher = CipherUtilities.GetCipher(AesGcmNoPaddingCipher);
             KeyParameter keyParameter = new KeyParameter(keyBytes);
             var keyParameterAead = aadData == null
-                ? new AeadParameters(keyParameter, AesBlockSize, initialisationVector)
-                : new AeadParameters(keyParameter, AesBlockSize, initialisationVector, aadData);
+                ? new AeadParameters(keyParameter, TagSizeInBits, initialisationVector)
+                : new AeadParameters(keyParameter, TagSizeInBits, initialisationVector, aadData);
             cipher.Init(forEncryption, keyParameterAead);
             return cipher;
         }
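
For context on the change above: the old code derived both the initialisation vector and the authentication tag from the AES block size, while AES-GCM's recommended nonce (IV) length is 12 bytes and the tag remains 128 bits. The following is a minimal standalone sketch of that parameter split, using the BouncyCastle API as the driver does; the key, nonce, and class name here are illustrative, not the driver's implementation:

```cs
using System;
using System.Text;
using Org.BouncyCastle.Crypto.Engines;
using Org.BouncyCastle.Crypto.Modes;
using Org.BouncyCastle.Crypto.Parameters;
using Org.BouncyCastle.Security;

class GcmIvTagSketch
{
    static void Main()
    {
        var random = new SecureRandom();
        var key = new byte[32];   // 256-bit AES key (illustrative)
        var nonce = new byte[12]; // 12-byte IV, matching InitVectorSizeInBytes
        random.NextBytes(key);
        random.NextBytes(nonce);

        var plainText = Encoding.UTF8.GetBytes("there is no rose without thorns");

        // The tag size (128 bits) is configured independently of the nonce length.
        var cipher = new GcmBlockCipher(new AesEngine());
        cipher.Init(true, new AeadParameters(new KeyParameter(key), 128, nonce));

        var output = new byte[cipher.GetOutputSize(plainText.Length)];
        var len = cipher.ProcessBytes(plainText, 0, plainText.Length, output, 0);
        cipher.DoFinal(output, len); // writes the remaining bytes plus the 16-byte tag

        Console.WriteLine(Convert.ToBase64String(output));
    }
}
```

The patch therefore keeps the tag at 128 bits (`TagSizeInBits`) and only shrinks the IV buffers to 12 bytes, which is the NIST-recommended nonce size for GCM.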