From 8fd36b69f1e8c868b1f2ed8c32e6cf718b21bdc8 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 26 Mar 2024 09:23:27 +0100 Subject: [PATCH] Add types to Go SQL driver This adds two interfaces, one for the type names and one for nullable to the Go SQL driver for Vitess. This is missing and means this type information can't be easily retrieved from result sets. Signed-off-by: Dirkjan Bussink --- go/vt/vitessdriver/rows.go | 77 ++++++++++++++++++++ go/vt/vitessdriver/rows_test.go | 120 ++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+) diff --git a/go/vt/vitessdriver/rows.go b/go/vt/vitessdriver/rows.go index a2438bb891c..1af88e64ec3 100644 --- a/go/vt/vitessdriver/rows.go +++ b/go/vt/vitessdriver/rows.go @@ -119,3 +119,80 @@ func (ri *rows) ColumnTypeScanType(index int) reflect.Type { return typeUnknown } } + +func (ri *rows) ColumnTypeDatabaseTypeName(index int) string { + field := ri.qr.Fields[index] + switch field.GetType() { + case query.Type_INT8: + return "TINYINT" + case query.Type_UINT8: + return "UNSIGNED TINYINT" + case query.Type_INT16: + return "SMALLINT" + case query.Type_UINT16: + return "UNSIGNED SMALLINT" + case query.Type_YEAR: + return "YEAR" + case query.Type_INT24: + return "MEDIUMINT" + case query.Type_UINT24: + return "UNSIGNED MEDIUMINT" + case query.Type_INT32: + return "INT" + case query.Type_UINT32: + return "UNSIGNED INT" + case query.Type_INT64: + return "BIGINT" + case query.Type_UINT64: + return "UNSIGNED BIGINT" + case query.Type_FLOAT32: + return "FLOAT" + case query.Type_FLOAT64: + return "DOUBLE" + case query.Type_DECIMAL: + return "DECIMAL" + case query.Type_VARCHAR: + return "VARCHAR" + case query.Type_TEXT: + return "TEXT" + case query.Type_BLOB: + return "BLOB" + case query.Type_VARBINARY: + return "VARBINARY" + case query.Type_CHAR: + return "CHAR" + case query.Type_BINARY: + return "BINARY" + case query.Type_BIT: + return "BIT" + case query.Type_ENUM: + return "ENUM" + case query.Type_SET: + return "SET" + case query.Type_HEXVAL: + return "VARBINARY" + case query.Type_HEXNUM: + return "VARBINARY" + case query.Type_BITNUM: + return "VARBINARY" + case query.Type_GEOMETRY: + return "GEOMETRY" + case query.Type_JSON: + return "JSON" + case query.Type_TIMESTAMP: + return "TIMESTAMP" + case query.Type_DATE: + return "DATE" + case query.Type_TIME: + return "TIME" + case query.Type_DATETIME: + return "DATETIME" + default: + return "" + } +} + +func (ri *rows) ColumnTypeNullable(index int) (nullable, ok bool) { + field := ri.qr.Fields[index] + return field.GetFlags()&uint32(query.MySqlFlag_NOT_NULL_FLAG) == 0, true +} diff --git a/go/vt/vitessdriver/rows_test.go b/go/vt/vitessdriver/rows_test.go index 13584e70dd8..bb196da30c3 100644 --- a/go/vt/vitessdriver/rows_test.go +++ b/go/vt/vitessdriver/rows_test.go @@ -226,3 +226,123 @@ func TestColumnTypeScanType(t *testing.T) { assert.Equal(t, ri.ColumnTypeScanType(i), wantTypes[i], fmt.Sprintf("unexpected type %v, wanted %v", ri.ColumnTypeScanType(i), wantTypes[i])) } } + +// Test that the ColumnTypeScanType function returns the correct reflection type for each +// sql type. The sql type in turn comes from a table column's type. +func TestColumnTypeDatabaseTypeName(t *testing.T) { + var r = sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "field1", + Type: sqltypes.Int8, + }, + { + Name: "field2", + Type: sqltypes.Uint8, + }, + { + Name: "field3", + Type: sqltypes.Int16, + }, + { + Name: "field4", + Type: sqltypes.Uint16, + }, + { + Name: "field5", + Type: sqltypes.Int24, + }, + { + Name: "field6", + Type: sqltypes.Uint24, + }, + { + Name: "field7", + Type: sqltypes.Int32, + }, + { + Name: "field8", + Type: sqltypes.Uint32, + }, + { + Name: "field9", + Type: sqltypes.Int64, + }, + { + Name: "field10", + Type: sqltypes.Uint64, + }, + { + Name: "field11", + Type: sqltypes.Float32, + }, + { + Name: "field12", + Type: sqltypes.Float64, + }, + { + Name: "field13", + Type: sqltypes.VarBinary, + }, + { + Name: "field14", + Type: sqltypes.Datetime, + }, + }, + } + + ri := newRows(&r, &converter{}).(driver.RowsColumnTypeDatabaseTypeName) + defer ri.Close() + + wantTypes := []string{ + "TINYINT", + "UNSIGNED TINYINT", + "SMALLINT", + "UNSIGNED SMALLINT", + "MEDIUMINT", + "UNSIGNED MEDIUMINT", + "INT", + "UNSIGNED INT", + "BIGINT", + "UNSIGNED BIGINT", + "FLOAT", + "DOUBLE", + "VARBINARY", + "DATETIME", + } + + for i := 0; i < len(wantTypes); i++ { + assert.Equal(t, ri.ColumnTypeDatabaseTypeName(i), wantTypes[i], fmt.Sprintf("unexpected type %v, wanted %v", ri.ColumnTypeDatabaseTypeName(i), wantTypes[i])) + } +} + +// Test that the ColumnTypeScanType function returns the correct reflection type for each +// sql type. The sql type in turn comes from a table column's type. +func TestColumnTypeNullable(t *testing.T) { + var r = sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "field1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG), + }, + { + Name: "field2", + Type: sqltypes.Int64, + }, + }, + } + + ri := newRows(&r, &converter{}).(driver.RowsColumnTypeNullable) + defer ri.Close() + + nullable := []bool{ + false, + true, + } + + for i := 0; i < len(nullable); i++ { + null, _ := ri.ColumnTypeNullable(i) + assert.Equal(t, null, nullable[i], fmt.Sprintf("unexpected type %v, wanted %v", null, nullable[i])) + } +}