diff --git a/DuckDB.NET.Bindings/NativeMethods.cs b/DuckDB.NET.Bindings/NativeMethods.cs index 4f040cc7..ded5c428 100644 --- a/DuckDB.NET.Bindings/NativeMethods.cs +++ b/DuckDB.NET.Bindings/NativeMethods.cs @@ -102,7 +102,7 @@ public static class DataChunks public static extern IntPtr DuckDBVectorGetData(IntPtr vector); [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_vector_get_validity")] - public static extern UIntPtr DuckDBVectorGetValidity(IntPtr vector); + public static extern IntPtr DuckDBVectorGetValidity(IntPtr vector); [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_list_vector_get_child")] public static extern long DuckDBListVectorGetChild(IntPtr vector); @@ -113,7 +113,7 @@ public static class DataChunks [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_struct_vector_get_child")] public static extern long DuckDBStructVectorGetChild(IntPtr vector, long index); } - + public static class Types { [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_value_boolean")] @@ -169,7 +169,7 @@ public static class Types [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_value_timestamp")] public static extern DuckDBTimestampStruct DuckDBValueTimestamp([In, Out] DuckDBResult result, long col, long row); - + [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_result_get_chunk")] public static extern IntPtr DuckDBResultGetChunk(DuckDBResult result, long chunkIndex); @@ -255,7 +255,7 @@ public static class PreparedStatements public static extern DuckDBState DuckDBExecutePrepared(DuckDBPreparedStatement preparedStatement, [In, Out] DuckDBResult result); } - public static class ExtractStatements + public static class ExtractStatements { [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_extract_statements")] public static extern int DuckDBExtractStatements(DuckDBNativeConnection connection, string query, out DuckDBExtractedStatements extractedStatements); diff --git a/DuckDB.NET.Data/DuckDBDataReader.cs b/DuckDB.NET.Data/DuckDBDataReader.cs index 931d3150..8c7b4c45 100644 --- a/DuckDB.NET.Data/DuckDBDataReader.cs +++ b/DuckDB.NET.Data/DuckDBDataReader.cs @@ -26,7 +26,8 @@ public class DuckDBDataReader : DbDataReader private int fieldCount; private int recordsAffected; - private Dictionary vectors = new(); + private readonly Dictionary vectors = new(); + private readonly Dictionary vectorValidityMask = new(); internal DuckDBDataReader(DuckDbCommand command, List queryResults, CommandBehavior behavior) { @@ -50,6 +51,7 @@ private void InitReaderData() for (int i = 0; i < columnCount; i++) { vectors[i] = NativeMethods.DataChunks.DuckDBDataChunkGetVector(chunk, i); + vectorValidityMask[i] = (NativeMethods.DataChunks.DuckDBVectorGetValidity(vectors[i])); } currentRow = -1; @@ -120,6 +122,7 @@ private DuckDBTimeOnly GetTimeOnly(int ordinal) public override decimal GetDecimal(int ordinal) { + return 0; return decimal.Parse(GetString(ordinal), CultureInfo.InvariantCulture); } @@ -153,8 +156,7 @@ public override Type GetFieldType(int ordinal) DuckDBType.DuckdbTypeVarchar => typeof(string), DuckDBType.DuckdbTypeDecimal => typeof(decimal), DuckDBType.DuckdbTypeBlob => typeof(Stream), - var type => throw new ArgumentException( - $"Unrecognised type {type} ({(int)type}) in column {ordinal + 1}") + var type => throw new ArgumentException($"Unrecognised type {type} ({(int)type}) in column {ordinal + 1}") }; } @@ -235,6 +237,11 @@ public override int GetOrdinal(string name) public override string GetString(int ordinal) { + if (IsDBNull(ordinal)) + { + return null; + } + var data = NativeMethods.DataChunks.DuckDBVectorGetData(vectors[ordinal]); data += currentRow * Marshal.SizeOf(); @@ -309,9 +316,14 @@ public override Stream GetStream(int ordinal) public override bool IsDBNull(int ordinal) { - return false; - var nullMask = NativeMethods.Query.DuckDBNullmaskData(currentResult, ordinal); - return Marshal.ReadByte(nullMask, currentRow) != 0; + var validityMaskEntryIndex = currentRow / 64; + var validityBitIndex = currentRow % 64; + + var validityMaskEntryPtr = vectorValidityMask[ordinal] + validityMaskEntryIndex; + var validityBit = 1ul << validityBitIndex; + + var isValid = (Marshal.PtrToStructure(validityMaskEntryPtr) & validityBit) != 0; + return !isValid; } public override int FieldCount => fieldCount; @@ -329,11 +341,11 @@ public override bool IsDBNull(int ordinal) public override bool NextResult() { currentResultIndex++; - + if (currentResultIndex < queryResults.Count) { currentResult = queryResults[currentResultIndex]; - + InitReaderData(); return true; }