diff --git a/src/dbspecific.h b/src/dbspecific.h index d9bc4dc4..e6af0a93 100644 --- a/src/dbspecific.h +++ b/src/dbspecific.h @@ -11,6 +11,7 @@ // SQL Server +#define SQL_SS_VARIANT -150 // SQL Server 2008 SQL_VARIANT type #define SQL_SS_XML -152 // SQL Server 2005 XML type #define SQL_DB2_DECFLOAT -360 // IBM DB/2 DECFLOAT type #define SQL_DB2_XML -370 // IBM DB/2 XML type diff --git a/src/getdata.cpp b/src/getdata.cpp index 626f41b1..c401a35a 100644 --- a/src/getdata.cpp +++ b/src/getdata.cpp @@ -38,6 +38,7 @@ void GetData_init() } static byte* ReallocOrFreeBuffer(byte* pb, Py_ssize_t cbNeed); +PyObject *GetData_SqlVariant(Cursor *cur, Py_ssize_t iCol); inline bool IsBinaryType(SQLSMALLINT sqltype) { @@ -534,28 +535,30 @@ static PyObject* GetDataTimestamp(Cursor* cur, Py_ssize_t iCol) switch (cur->colinfos[iCol].sql_type) { - case SQL_TYPE_TIME: - { - int micros = (int)(value.fraction / 1000); // nanos --> micros - return PyTime_FromTime(value.hour, value.minute, value.second, micros); - } - - case SQL_TYPE_DATE: - return PyDate_FromDate(value.year, value.month, value.day); - - case SQL_TYPE_TIMESTAMP: - { - if (value.year < 1) - { - value.year = 1; - } - else if (value.year > 9999) - { - value.year = 9999; - } - } + case SQL_TYPE_TIME: + { + int micros = (int)(value.fraction / 1000); // nanos --> micros + return PyTime_FromTime(value.hour, value.minute, value.second, micros); + } + + case SQL_TYPE_DATE: + case SQL_DATE: + return PyDate_FromDate(value.year, value.month, value.day); + + case SQL_TYPE_TIMESTAMP: + case SQL_TIMESTAMP: + { + if (value.year < 1) + { + value.year = 1; + } + else if (value.year > 9999) + { + value.year = 9999; + } + } } - + int micros = (int)(value.fraction / 1000); // nanos --> micros @@ -645,6 +648,7 @@ PyObject* PythonTypeFromSqlType(Cursor* cur, SQLSMALLINT type) break; case SQL_TYPE_DATE: + case SQL_DATE: pytype = (PyObject*)PyDateTimeAPI->DateType; break; @@ -654,6 +658,7 @@ PyObject* PythonTypeFromSqlType(Cursor* cur, SQLSMALLINT type) break; case SQL_TYPE_TIMESTAMP: + case SQL_TIMESTAMP: pytype = (PyObject*)PyDateTimeAPI->DateTimeType; break; @@ -744,15 +749,55 @@ PyObject* GetData(Cursor* cur, Py_ssize_t iCol) return GetDataDouble(cur, iCol); + case SQL_DATE: case SQL_TYPE_DATE: case SQL_TYPE_TIME: + case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: return GetDataTimestamp(cur, iCol); case SQL_SS_TIME2: return GetSqlServerTime(cur, iCol); + + case SQL_SS_VARIANT: + return GetData_SqlVariant(cur, iCol); } return RaiseErrorV("HY106", ProgrammingError, "ODBC SQL type %d is not yet supported. column-index=%zd type=%d", (int)pinfo->sql_type, iCol, (int)pinfo->sql_type); } + +PyObject *GetData_SqlVariant(Cursor *cur, Py_ssize_t iCol) { + char pBuff; + + SQLLEN indicator, variantType; + SQLRETURN retcode; + + PyObject *decodeResult; + + // Call SQLGetData on the current column with a data length of 0. According to MS, this makes + // the ODBC driver read the sql_variant header which contains the underlying data type + pBuff = 0; + indicator = 0; + retcode = SQLGetData(cur->hstmt, static_cast(iCol + 1), SQL_C_BINARY, + &pBuff, 0, &indicator); + if (!SQL_SUCCEEDED(retcode)) + return RaiseErrorFromHandle(cur->cnxn, "SQLGetData", cur->cnxn->hdbc, cur->hstmt); + + // Get the SQL_CA_SS_VARIANT_TYPE field for the column which will contain the underlying data type + variantType = 0; + retcode = SQLColAttribute(cur->hstmt, iCol + 1, SQL_CA_SS_VARIANT_TYPE, NULL, 0, NULL, &variantType); + if (!SQL_SUCCEEDED(retcode)) + return RaiseErrorFromHandle(cur->cnxn, "SQLColAttribute", cur->cnxn->hdbc, cur->hstmt); + + // Replace the original SQL_VARIANT data type with the underlying data type then call GetData() again + cur->colinfos[iCol].sql_type = static_cast(variantType); + decodeResult = GetData(cur, iCol); + + // Restore the original SQL_VARIANT data type so that the next decode will call this method again + cur->colinfos[iCol].sql_type = static_cast(SQL_SS_VARIANT); + + return decodeResult; + + // NOTE: we don't free the hstmt here as it's managed by the cursor +} diff --git a/src/pyodbc.h b/src/pyodbc.h index a18529e9..febcc677 100644 --- a/src/pyodbc.h +++ b/src/pyodbc.h @@ -76,6 +76,10 @@ typedef unsigned long long UINT64; #define SQL_CA_SS_CATALOG_NAME 1225 #endif +#ifndef SQL_CA_SS_VARIANT_TYPE +#define SQL_CA_SS_VARIANT_TYPE 1215 +#endif + inline bool IsSet(DWORD grf, DWORD flags) { return (grf & flags) == flags; @@ -117,7 +121,7 @@ inline void DebugTrace(const char* szFmt, ...) { UNUSED(szFmt); } // issue #880: entry missing from iODBC sqltypes.h #ifndef BYTE - typedef unsigned char BYTE; + typedef unsigned char BYTE; #endif bool PyMem_Realloc(BYTE** pp, size_t newlen); // A wrapper around realloc with a safer interface. If it is successful, *pp is updated to the diff --git a/tests/sqlserver_test.py b/tests/sqlserver_test.py index 9602de4a..ab51ce0d 100755 --- a/tests/sqlserver_test.py +++ b/tests/sqlserver_test.py @@ -1614,6 +1614,44 @@ def test_tvp_diffschema(cursor: pyodbc.Cursor): _test_tvp(cursor, True) +@pytest.mark.skipif(SQLSERVER_YEAR < 2000, reason='sql_variant not supported until 2000') +def test_sql_variant(cursor: pyodbc.Cursor): + """ + Tests decoding of the sql_variant data type as performed by the GetData_SqlVariant() method. + """ + + cursor.execute("create table t1 (a sql_variant)") + + # insert a number of values of disparate types. this is not exhaustive as not all + # types that can be contained within a sql_variant field are supported by pyodbc + cursor.execute("insert into t1 values (456.7)") + cursor.execute("insert into t1 values ('a string')") + cursor.execute("insert into t1 values (CAST('2024-06-03' AS DATE))") + cursor.execute("insert into t1 values (CAST('2024-06-03 23:46:03.000' AS DATETIME))") + cursor.execute("insert into t1 values (CAST('binary data' AS VARBINARY(200)))") + cursor.execute( + "insert into t1 values (CAST('0592b437-745f-4b2c-a997-97022c624cf6' AS UNIQUEIDENTIFIER))" + ) + + # select all of the values we inserted and ensure they have the correct types + results = [record[0] for record in cursor.execute("select a from t1").fetchall()] + for index, assertion_tuple in enumerate( + [ + (Decimal, Decimal("456.7")), + (str, "a string"), + (date, date(2024, 6, 3)), + (datetime, datetime(2024, 6, 3, 23, 46, 3)), + (bytes, b'binary data'), + (uuid.UUID, uuid.UUID("0592b437-745f-4b2c-a997-97022c624cf6")) + ] + ): + # pylint: disable=unidiomatic-typecheck + expected_type, expected_value = assertion_tuple + + assert type(results[index]) == expected_type + assert results[index] == expected_value + + def get_sqlserver_version(cursor: pyodbc.Cursor): """