Skip to content

Commit

Permalink
Use parameter logical type to create parameter values for prepared st…
Browse files Browse the repository at this point in the history
…atements
  • Loading branch information
Giorgi committed Nov 27, 2024
1 parent 09e2dd0 commit 465f11e
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,8 @@ public static class PreparedStatements

[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_execute_prepared_streaming")]
public static extern DuckDBState DuckDBExecutePreparedStreaming(DuckDBPreparedStatement preparedStatement, out DuckDBResult result);

[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_param_logical_type")]
public static extern DuckDBLogicalType DuckDBParamLogicalType(DuckDBPreparedStatement preparedStatement, long index);
}
}
16 changes: 7 additions & 9 deletions DuckDB.NET.Data/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Numerics;
using System.Reflection;

namespace DuckDB.NET.Data.Extensions;

internal static class TypeExtensions
{
private static readonly HashSet<Type> FloatingNumericTypes = new()
{
typeof(decimal), typeof(float), typeof(double)
};
private static readonly HashSet<Type> FloatingNumericTypes = [typeof(decimal), typeof(float), typeof(double)];

private static readonly HashSet<Type> IntegralNumericTypes = new()
{
private static readonly HashSet<Type> IntegralNumericTypes =
[
typeof(byte), typeof(sbyte),
typeof(short), typeof(ushort),
typeof(int), typeof(uint),
typeof(long),typeof(ulong),
typeof(long), typeof(ulong),
typeof(BigInteger)
};
];

public static bool IsNull(this object? value) => value is null or DBNull;
public static bool IsNull([NotNullWhen(false)] this object? value) => value is null or DBNull;

public static (bool isNullableValueType, Type type) IsNullableValueType<T>()
{
Expand Down
60 changes: 60 additions & 0 deletions DuckDB.NET.Data/Internal/ClrToDuckDBConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,64 @@ private static DuckDBValue CreateCollectionValue<T>(DuckDBType duckDBType, IColl

return NativeMethods.Value.DuckDBCreateListValue(listItemType, values, collection.Count);
}

public static DuckDBValue ToDuckDBValue(this object? item, DuckDBLogicalType logicalType)
{
if (item.IsNull())
{
return NativeMethods.Value.DuckDBCreateNullValue();
}

var duckDBType = NativeMethods.LogicalType.DuckDBGetTypeId(logicalType);

return (duckDBType, item) switch
{
(DuckDBType.Boolean, bool value) => NativeMethods.Value.DuckDBCreateBool(value),

(DuckDBType.TinyInt, _) => NativeMethods.Value.DuckDBCreateInt8(ConvertTo<sbyte>()),
(DuckDBType.SmallInt, _) => NativeMethods.Value.DuckDBCreateInt16(ConvertTo<short>()),
(DuckDBType.Integer, _) => NativeMethods.Value.DuckDBCreateInt32(ConvertTo<int>()),
(DuckDBType.BigInt, _) => NativeMethods.Value.DuckDBCreateInt64(ConvertTo<long>()),

(DuckDBType.UnsignedTinyInt, _) => NativeMethods.Value.DuckDBCreateUInt8(ConvertTo<byte>()),
(DuckDBType.UnsignedSmallInt, _) => NativeMethods.Value.DuckDBCreateUInt16(ConvertTo<ushort>()),
(DuckDBType.UnsignedInteger, _) => NativeMethods.Value.DuckDBCreateUInt32(ConvertTo<uint>()),
(DuckDBType.UnsignedBigInt, _) => NativeMethods.Value.DuckDBCreateUInt64(ConvertTo<ulong>()),

(DuckDBType.Float, float value) => NativeMethods.Value.DuckDBCreateFloat(value),
(DuckDBType.Double, double value) => NativeMethods.Value.DuckDBCreateDouble(value),

(DuckDBType.Decimal, decimal value) => DecimalToDuckDBValue(value),
(DuckDBType.HugeInt, BigInteger value) => NativeMethods.Value.DuckDBCreateHugeInt(new DuckDBHugeInt(value)),

(DuckDBType.Varchar, string value) => StringToDuckDBValue(value),
(DuckDBType.Uuid, Guid value) => GuidToDuckDBValue(value),

(DuckDBType.Timestamp, DateTime value) => NativeMethods.Value.DuckDBCreateTimestamp(NativeMethods.DateTimeHelpers.DuckDBToTimestamp(DuckDBTimestamp.FromDateTime(value))),
(DuckDBType.Interval, TimeSpan value) => NativeMethods.Value.DuckDBCreateInterval(value),
(DuckDBType.Date, DuckDBDateOnly value) => NativeMethods.Value.DuckDBCreateDate(NativeMethods.DateTimeHelpers.DuckDBToDate(value)),
(DuckDBType.Time, DuckDBTimeOnly value) => NativeMethods.Value.DuckDBCreateTime(NativeMethods.DateTimeHelpers.DuckDBToTime(value)),
#if NET6_0_OR_GREATER
(DuckDBType.Date, DateOnly value) => NativeMethods.Value.DuckDBCreateDate(NativeMethods.DateTimeHelpers.DuckDBToDate(value)),
(DuckDBType.Time, TimeOnly value) => NativeMethods.Value.DuckDBCreateTime(NativeMethods.DateTimeHelpers.DuckDBToTime(value)),
#endif
(DuckDBType.TimeTz, DateTimeOffset value) => DateTimeOffsetToDuckDBValue(value),
(DuckDBType.Blob, byte[] value) => NativeMethods.Value.DuckDBCreateBlob(value, value.Length),
(DuckDBType.List, ICollection value) => CreateCollectionValue(value),
(DuckDBType.Array, ICollection value) => CreateCollectionValue(value),
_ => throw new InvalidOperationException($"Cannot bind parameter type {item.GetType().FullName} to column of type {duckDBType}")
};

T ConvertTo<T>()
{
try
{
return (T)Convert.ChangeType(item, typeof(T));
}
catch (Exception)
{
throw new ArgumentOutOfRangeException($"Cannot bind parameter type {item.GetType().FullName} to column of type {duckDBType}");
}
}
}
}
10 changes: 9 additions & 1 deletion DuckDB.NET.Data/Internal/IsExternalInit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ class StackTraceHiddenAttribute : Attribute { }
namespace System.Diagnostics.CodeAnalysis
{
[AttributeUsage(AttributeTargets.Method)]
class DoesNotReturnAttribute: Attribute { }
class DoesNotReturnAttribute : Attribute { }

[AttributeUsage(AttributeTargets.Parameter, Inherited = false)]
class NotNullWhenAttribute : Attribute
{
public NotNullWhenAttribute(bool returnValue) => ReturnValue = returnValue;

public bool ReturnValue { get; }
}
}
#endif
7 changes: 6 additions & 1 deletion DuckDB.NET.Data/Internal/PreparedStatement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ private static void BindParameters(DuckDBPreparedStatement preparedStatement, Du
{
BindParameter(preparedStatement, index, param);
}
else
{
throw new DuckDBException($"Cannot get index for parameter {param.ParameterName}");
}
}
}
else
Expand All @@ -110,7 +114,8 @@ private static void BindParameters(DuckDBPreparedStatement preparedStatement, Du

private static void BindParameter(DuckDBPreparedStatement preparedStatement, long index, DuckDBParameter parameter)
{
using var duckDBValue = parameter.Value.ToDuckDBValue();
using var parameterLogicalType = NativeMethods.PreparedStatements.DuckDBParamLogicalType(preparedStatement, index);
using var duckDBValue = parameter.Value.ToDuckDBValue(parameterLogicalType);

var result = NativeMethods.PreparedStatements.DuckDBBindValue(preparedStatement, index, duckDBValue);

Expand Down
2 changes: 1 addition & 1 deletion DuckDB.NET.Test/DuckDBClientFactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void UseFactory()

connection.ConnectionString = connectionStringBuilder.ConnectionString;

command.CommandText = "Select ?";
command.CommandText = "Select ?::integer";
command.Connection = connection;
parameter.Value = 42;
command.Parameters.Add(parameter);
Expand Down
2 changes: 1 addition & 1 deletion DuckDB.NET.Test/ExtractStatementsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void NotExistingTableThrowsException()
[Fact]
public void MissingParametersThrowsException()
{
Command.CommandText = "Select ?1; Select ?1, ?2";
Command.CommandText = "Select ?1::integer; Select ?1::integer, ?2::integer";
Command.Parameters.Add(new DuckDBParameter(42));

var dataReader = Command.ExecuteReader();
Expand Down
4 changes: 2 additions & 2 deletions DuckDB.NET.Test/Parameters/DateTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void BindWithCastTest(int year, int mon, int day)
{
var expectedValue = new DateTime(year, mon, day);

Command.CommandText = "SELECT ?;";
Command.CommandText = "SELECT ?::DATE;";
Command.Parameters.Add(new DuckDBParameter((DuckDBDateOnly)expectedValue));

var scalar = Command.ExecuteScalar();
Expand Down Expand Up @@ -104,7 +104,7 @@ public void BindDateOnly(int year, int mon, int day)
{
var expectedValue = new DateOnly(year, mon, day);

Command.CommandText = "SELECT ?;";
Command.CommandText = "SELECT ?::DATE;";
Command.Parameters.Add(new DuckDBParameter(expectedValue));

var scalar = Command.ExecuteScalar();
Expand Down
2 changes: 1 addition & 1 deletion DuckDB.NET.Test/Parameters/DecimalParameterTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void DecimalTests(decimal[] values, int precision, int scale)

foreach (var value in values)
{
Command.CommandText = "Insert Into DecimalValuesTests (key, value) values (1, ?)";
Command.CommandText = $"Insert Into DecimalValuesTests (key, value) values (1, ?::decimal({precision}, {scale}))";
Command.Parameters.Add(new DuckDBParameter(value));
Command.ExecuteNonQuery();

Expand Down
20 changes: 10 additions & 10 deletions DuckDB.NET.Test/Parameters/IntegerParametersTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ namespace DuckDB.NET.Test.Parameters;

public class IntegerParametersTests(DuckDBDatabaseFixture db) : DuckDBTestBase(db)
{
private void TestBind<TValue>(TValue expectedValue, DuckDBParameter parameter, Func<DuckDBDataReader, TValue> getValue)
private void TestBind<TValue>(string duckDbType, TValue expectedValue, DuckDBParameter parameter, Func<DuckDBDataReader, TValue> getValue)
{
Command.CommandText = "SELECT ?;";
Command.CommandText = $"SELECT ?::{duckDbType};";
Command.Parameters.Add(parameter);

var scalar = Command.ExecuteScalar();
Expand Down Expand Up @@ -86,7 +86,7 @@ void TestReadValueAs<T>(DuckDBDataReader reader) where T : INumberBase<T>
public void ByteTest(byte value)
{
TestSimple<byte>("UTINYINT", value, r => r.GetByte(0));
TestBind(value, new DuckDBParameter(value), r => r.GetByte(0));
TestBind("UTINYINT", value, new DuckDBParameter(value), r => r.GetByte(0));
}

[Theory]
Expand All @@ -96,7 +96,7 @@ public void ByteTest(byte value)
public void SByteTest(sbyte value)
{
TestSimple<sbyte>("TINYINT", value, r => r.GetFieldValue<sbyte>(0));
TestBind(value, new DuckDBParameter(value), r => r.GetFieldValue<sbyte>(0));
TestBind("TINYINT", value, new DuckDBParameter(value), r => r.GetFieldValue<sbyte>(0));
}

[Theory]
Expand All @@ -106,7 +106,7 @@ public void SByteTest(sbyte value)
public void UInt16Test(ushort value)
{
TestSimple<ushort>("USMALLINT", value, r => r.GetFieldValue<ushort>(0));
TestBind(value, new DuckDBParameter(value), r => r.GetFieldValue<ushort>(0));
TestBind("USMALLINT", value, new DuckDBParameter(value), r => r.GetFieldValue<ushort>(0));
}

[Theory]
Expand All @@ -116,7 +116,7 @@ public void UInt16Test(ushort value)
public void Int16Test(short value)
{
TestSimple<short>("SMALLINT", value, r => r.GetInt16(0));
TestBind(value, new DuckDBParameter(value), r => r.GetInt16(0));
TestBind("SMALLINT", value, new DuckDBParameter(value), r => r.GetInt16(0));
}

[Theory]
Expand All @@ -126,7 +126,7 @@ public void Int16Test(short value)
public void UInt32Test(uint value)
{
TestSimple<uint>("UINTEGER", value, r => r.GetFieldValue<uint>(0));
TestBind(value, new DuckDBParameter(value), r => r.GetFieldValue<uint>(0));
TestBind("UINTEGER", value, new DuckDBParameter(value), r => r.GetFieldValue<uint>(0));
}

[Theory]
Expand All @@ -136,7 +136,7 @@ public void UInt32Test(uint value)
public void Int32Test(int value)
{
TestSimple<int>("INTEGER", value, r => r.GetInt32(0));
TestBind(value, new DuckDBParameter(value), r => r.GetInt32(0));
TestBind("INTEGER", value, new DuckDBParameter(value), r => r.GetInt32(0));

TestSimple<int>("INTEGER", value, r => r.GetFieldValue<int?>(0));
}
Expand All @@ -148,7 +148,7 @@ public void Int32Test(int value)
public void UInt64Test(ulong value)
{
TestSimple<ulong>("UBIGINT", value, r => r.GetFieldValue<ulong>(0));
TestBind(value, new DuckDBParameter(value), r => r.GetFieldValue<ulong>(0));
TestBind("UBIGINT", value, new DuckDBParameter(value), r => r.GetFieldValue<ulong>(0));

TestSimple("UBIGINT", value, r => r.GetFieldValue<ulong?>(0));
}
Expand All @@ -160,7 +160,7 @@ public void UInt64Test(ulong value)
public void Int64Test(long value)
{
TestSimple<long>("BIGINT", value, r => r.GetInt64(0));
TestBind(value, new DuckDBParameter(value), r => r.GetInt64(0));
TestBind("BIGINT", value, new DuckDBParameter(value), r => r.GetInt64(0));

TestSimple("BIGINT", value, r => r.GetFieldValue<long?>(0));
}
Expand Down
30 changes: 14 additions & 16 deletions DuckDB.NET.Test/Parameters/ParameterCollectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ namespace DuckDB.NET.Test.Parameters;
public class ParameterCollectionTests(DuckDBDatabaseFixture db) : DuckDBTestBase(db)
{
[Theory]
[InlineData("SELECT ?1;")]
[InlineData("SELECT ?;")]
[InlineData("SELECT $1;")]
[InlineData("SELECT ?1::INT;")]
[InlineData("SELECT ?::INT;")]
[InlineData("SELECT $1::INT;")]
public void BindSingleValueTest(string query)
{
Command.Parameters.Add(new DuckDBParameter("1", 42));
Expand Down Expand Up @@ -65,7 +65,7 @@ public void ParameterConstructorTests()
Command.CommandText = "CREATE TABLE ParameterConstructorTests (key INTEGER, value double, State Boolean, ErrorCode Long, value2 float)";
Command.ExecuteNonQuery();

Command.CommandText = "Insert Into ParameterConstructorTests values (?,?,?,?,?)";
Command.CommandText = "Insert Into ParameterConstructorTests values (?::INT,?::double,?::Boolean,?::BIGINT,?::float)";
Command.Parameters.Add(new DuckDBParameter(DbType.Double, 2.4));
Command.Parameters.Add(new DuckDBParameter(true));
Command.Parameters.Insert(0, new DuckDBParameter(2));
Expand Down Expand Up @@ -159,10 +159,10 @@ public void BindMultipleValuesTestNamedParameters(string queryStatement)
}

[Theory]
[InlineData("INSERT INTO ParametersTestInvalidOrderKeyValue (KEY, VALUE) VALUES (?2, ?1)")]
[InlineData("INSERT INTO ParametersTestInvalidOrderKeyValue (KEY, VALUE) VALUES ($2, $1)")]
[InlineData("UPDATE ParametersTestInvalidOrderKeyValue SET Key = ?2, Value = ?1;")]
[InlineData("UPDATE ParametersTestInvalidOrderKeyValue SET Key = $2, Value = $1;")]
[InlineData("INSERT INTO ParametersTestInvalidOrderKeyValue (KEY, VALUE) VALUES (?2::VARCHAR, ?1::INT)")]
[InlineData("INSERT INTO ParametersTestInvalidOrderKeyValue (KEY, VALUE) VALUES ($2::VARCHAR, $1::INT)")]
[InlineData("UPDATE ParametersTestInvalidOrderKeyValue SET Key = ?2::INT, Value = ?1::VARCHAR;")]
[InlineData("UPDATE ParametersTestInvalidOrderKeyValue SET Key = $2::INT, Value = $1::VARCHAR;")]
public void BindMultipleValuesInvalidOrderTest(string queryStatement)
{
using var defer = new Defer(() => Connection.Execute("DROP TABLE ParametersTestInvalidOrderKeyValue;"));
Expand All @@ -175,14 +175,12 @@ public void BindMultipleValuesInvalidOrderTest(string queryStatement)
Command.CommandText = queryStatement;
Command.Parameters.Add(new DuckDBParameter("param1", 42));
Command.Parameters.Add(new DuckDBParameter("param2", "hello"));
Command.Invoking(cmd => cmd.ExecuteNonQuery())
.Should().ThrowExactly<DuckDBException>();
Command.Invoking(cmd => cmd.ExecuteNonQuery()).Should().ThrowExactly<DuckDBException>();

Command.Parameters.Clear();
Command.Parameters.Add(new DuckDBParameter(42));
Command.Parameters.Add(new DuckDBParameter("hello"));
Command.Invoking(cmd => cmd.ExecuteNonQuery())
.Should().ThrowExactly<DuckDBException>();
Command.Invoking(cmd => cmd.ExecuteNonQuery()).Should().ThrowExactly<InvalidOperationException>();
}

[Theory]
Expand Down Expand Up @@ -282,7 +280,7 @@ public void BindUnreferencedNamedParameterInParameterlessQueryTest()
[Fact]
public void BindUnreferencedNamedParameterTest()
{
Command.CommandText = "SELECT $used";
Command.CommandText = "SELECT $used::INT";
Command.Parameters.Add(new DuckDBParameter("unused", 24));
Command.Parameters.Add(new DuckDBParameter("used", 42));
var scalar = Command.ExecuteScalar();
Expand All @@ -301,14 +299,14 @@ public void BindUnreferencedPositionalParameterTest()
[Fact]
public void BindUnicodeParameterTest()
{
Command.CommandText = "SELECT $数字";
Command.CommandText = "SELECT $数字::INT";
Command.Parameters.Add(new DuckDBParameter("数字",42));
Command.ExecuteScalar().Should().Be(42);
}

[Theory]
[InlineData("SELECT $2 - $1", 18)]
[InlineData("SELECT ? - ?", -18)]
[InlineData("SELECT $2::INT - $1::INT", 18)]
[InlineData("SELECT ?::INT - ?::INT", -18)]
public void BindParameterWithPositionOrAutoIncrement(string query, int result)
{
Command.Parameters.Add(new DuckDBParameter(24));
Expand Down
6 changes: 3 additions & 3 deletions DuckDB.NET.Test/Parameters/TimeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void BindWithCastTest(int hour, int minute, int second, int microsecond)
var expectedValue = new DateTime(DateTime.MinValue.Year, DateTime.MinValue.Month, DateTime.MinValue.Day,
hour, minute, second).AddTicks(microsecond * 10);

Command.CommandText = "SELECT ?;";
Command.CommandText = "SELECT ?::TIME;";
Command.Parameters.Add(new DuckDBParameter((DuckDBTimeOnly)expectedValue));

var scalar = Command.ExecuteScalar();
Expand Down Expand Up @@ -160,7 +160,7 @@ public void QueryTimeTzReaderTest(int hour, int minute, int second, int microsec
var timeSpan = new TimeSpan(offsetHours, offsetHours >= 0 ? offsetMinutes : -offsetMinutes, 0);
dateTimeOffset.Offset.Should().Be(timeSpan);

Command.CommandText = "SELECT ?";
Command.CommandText = "SELECT ?::TIMETZ";
Command.Parameters.Add(new DuckDBParameter(dateTimeOffset));

using var reader = Command.ExecuteReader();
Expand All @@ -181,7 +181,7 @@ public void BindTimeOnly(int hour, int minute, int second, int microsecond)
{
var expectedValue = new TimeOnly(hour, minute, second,0).Add(TimeSpan.FromMicroseconds(microsecond));

Command.CommandText = "SELECT ?;";
Command.CommandText = "SELECT ?::TIME;";
Command.Parameters.Add(new DuckDBParameter(expectedValue));

var scalar = Command.ExecuteScalar();
Expand Down
2 changes: 1 addition & 1 deletion DuckDB.NET.Test/Parameters/TimestampTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public void BindTest(int year, int mon, int day, int hour, int minute, int secon
{
var expectedValue = new DateTime(year, mon, day, hour, minute, second).AddTicks(microsecond * 10);

Command.CommandText = "SELECT ?;";
Command.CommandText = "SELECT ?::TimeStamp;";
Command.Parameters.Add(new DuckDBParameter(expectedValue));

var scalar = Command.ExecuteScalar();
Expand Down

0 comments on commit 465f11e

Please sign in to comment.