Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unable to scan int into interface{} issue fix #1815

Open
wants to merge 1 commit into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 161 additions & 1 deletion cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import (
"time"
"unicode"

inf "gopkg.in/inf.v0"
"gopkg.in/inf.v0"
)

func TestEmptyHosts(t *testing.T) {
Expand Down Expand Up @@ -3288,3 +3288,163 @@ func TestQuery_NamedValues(t *testing.T) {
t.Fatal(err)
}
}

func TestScanToAny(t *testing.T) {
session := createSession(t)
defer session.Close()
ctx := context.Background()

dataTypes := []struct {
tableName string
createQuery string
insertQuery string
expectedVal interface{}
}{
{
"scan_to_any_varchar",
"CREATE TABLE IF NOT EXISTS scan_to_any_varchar (id int PRIMARY KEY, val varchar)",
"INSERT INTO scan_to_any_varchar (id, val) VALUES (?, ?)",
"test",
},
{
"scan_to_any_bool",
"CREATE TABLE IF NOT EXISTS scan_to_any_bool (id int PRIMARY KEY, val boolean)",
"INSERT INTO scan_to_any_bool (id, val) VALUES (?, ?)",
true,
},
{
"scan_to_any_int",
"CREATE TABLE IF NOT EXISTS scan_to_any_int (id int PRIMARY KEY, val int)",
"INSERT INTO scan_to_any_int (id, val) VALUES (?, ?)",
42,
},
{
"scan_to_any_float",
"CREATE TABLE IF NOT EXISTS scan_to_any_float (id int PRIMARY KEY, val float)",
"INSERT INTO scan_to_any_float (id, val) VALUES (?, ?)",
float32(3.14),
},
{
"scan_to_any_double",
"CREATE TABLE IF NOT EXISTS scan_to_any_double (id int PRIMARY KEY, val double)",
"INSERT INTO scan_to_any_double (id, val) VALUES (?, ?)",
3.14159,
},
{
"scan_to_any_decimal",
"CREATE TABLE IF NOT EXISTS scan_to_any_decimal (id int PRIMARY KEY, val decimal)",
"INSERT INTO scan_to_any_decimal (id, val) VALUES (?, ?)",
inf.NewDec(12345, 2), // Example decimal value
},
{
"scan_to_any_time",
"CREATE TABLE IF NOT EXISTS scan_to_any_time (id int PRIMARY KEY, val time)",
"INSERT INTO scan_to_any_time (id, val) VALUES (?, ?)",
time.Duration(1000),
},
{
"scan_to_any_timestamp",
"CREATE TABLE IF NOT EXISTS scan_to_any_timestamp (id int PRIMARY KEY, val timestamp)",
"INSERT INTO scan_to_any_timestamp (id, val) VALUES (?, ?)",
time.Now().UTC().Truncate(time.Millisecond),
},
{
"scan_to_any_inet",
"CREATE TABLE IF NOT EXISTS scan_to_any_inet (id int PRIMARY KEY, val inet)",
"INSERT INTO scan_to_any_inet (id, val) VALUES (?, ?)",
net.ParseIP("192.168.0.1"),
},
{
"scan_to_any_uuid",
"CREATE TABLE IF NOT EXISTS scan_to_any_uuid (id int PRIMARY KEY, val uuid)",
"INSERT INTO scan_to_any_uuid (id, val) VALUES (?, ?)",
TimeUUID().String(),
},
{
"scan_to_any_date",
"CREATE TABLE IF NOT EXISTS scan_to_any_date (id int PRIMARY KEY, val date)",
"INSERT INTO scan_to_any_date (id, val) VALUES (?, ?)",
time.Now().UTC().Truncate(time.Hour * 24),
},
{
"scan_to_any_duration",
"CREATE TABLE IF NOT EXISTS scan_to_any_duration (id int PRIMARY KEY, val duration)",
"INSERT INTO scan_to_any_duration (id, val) VALUES (?, ?)",
Duration{0, 0, 123},
},
}

for _, dt := range dataTypes {
t.Run(fmt.Sprintf("Test_%s", dt.tableName), func(t *testing.T) {
if err := session.Query(dt.createQuery).WithContext(ctx).Exec(); err != nil {
t.Fatal(err)
}

if err := session.Query(dt.insertQuery, 1, dt.expectedVal).WithContext(ctx).Exec(); err != nil {
t.Fatal(err)
}

var out interface{}
if err := session.Query(fmt.Sprintf("SELECT val FROM %s WHERE id = 1", dt.tableName)).WithContext(ctx).Scan(&out); err != nil {
t.Fatal(err)
}

if err := session.Query(fmt.Sprintf("DROP TABLE %s", dt.tableName)).WithContext(ctx).Exec(); err != nil {
t.Fatal(err)
}

switch dt.tableName {
case "scan_to_any_decimal":
result, ok := out.(inf.Dec)
if !ok {
t.Fatal("expected inf.Dec, got", out)
}
expected := inf.NewDec(12345, 2)

if result.Cmp(expected) != 0 {
t.Fatalf("expected %v, got %v", expected, out)
}
case "scan_to_any_inet":
result, ok := out.(net.IP)
if !ok {
t.Fatal("expected net.IP, got", out)
}
expected, ok := dt.expectedVal.(net.IP)
if !ok {
t.Fatal("expected net.IP, got", dt.expectedVal)
}
if result.String() != expected.String() {
t.Fatalf("expected %v, got %v", expected, out)
}
case "scan_to_any_date":
result, ok := out.(time.Time)
if !ok {
t.Fatal("expected time.Time, got", out)
}
expected, ok := dt.expectedVal.(time.Time)
if !ok {
t.Fatal("expected time.Time, got", dt.expectedVal)
}
if result.String() != expected.String() {
t.Fatalf("expected %v, got %v", expected, out)
}
case "scan_to_any_duration":
result, ok := out.(Duration)
if !ok {
t.Fatal("expected time.Duration, got", out)
}
expected, ok := dt.expectedVal.(Duration)
if !ok {
t.Fatal("expected time.Duration, got", dt.expectedVal)
}
if result != expected {
t.Fatalf("expected %v, got %v", expected, out)
}
default:
if out != dt.expectedVal {
t.Fatalf("expected %v, got %v", dt.expectedVal, out)
}
}
})
}
}
102 changes: 102 additions & 0 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,23 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
// date | *time.Time | time of beginning of the day (in UTC)
// date | *string | formatted with 2006-01-02 format
// duration | *gocql.Duration |
//
// Scan into interface{} implemented by unmarshal into default type:
//
// CQL type | Go type
// Varchar | string
// Varint | bigInt
// IntLike | int
// Boolean | bool
// Float | float32
// Double | float64
// Decimal | infDec
// Time | time.Duration
// Timestamp | time.Time
// Date | time.Time
// Duration | Duration
// UUID | string
// Inet | net.IP
func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
if v, ok := value.(Unmarshaler); ok {
return v.UnmarshalCQL(info, data)
Expand Down Expand Up @@ -350,6 +367,9 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error {
*v = nil
}
return nil
case *interface{}:
*v = string(data)
return nil
}

rv := reflect.ValueOf(value)
Expand Down Expand Up @@ -743,6 +763,8 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error {
*v = bytesToUint64(data[1:])
return nil
}
case *interface{}:
return unmarshalBigInt(info, data, value)
}

if len(data) > 8 {
Expand Down Expand Up @@ -904,6 +926,12 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
case *string:
*v = strconv.FormatInt(int64Val, 10)
return nil
case *interface{}:
if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) {
return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, info.Type())
}
*v = int(int64Val)
return nil
}

rv := reflect.ValueOf(value)
Expand Down Expand Up @@ -1055,6 +1083,9 @@ func unmarshalBool(info TypeInfo, data []byte, value interface{}) error {
case *bool:
*v = decBool(data)
return nil
case *interface{}:
*v = decBool(data)
return nil
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
Expand Down Expand Up @@ -1105,6 +1136,9 @@ func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error {
case *float32:
*v = math.Float32frombits(uint32(decInt(data)))
return nil
case *interface{}:
*v = math.Float32frombits(uint32(decInt(data)))
return nil
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
Expand Down Expand Up @@ -1146,6 +1180,9 @@ func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error {
case *float64:
*v = math.Float64frombits(uint64(decBigInt(data)))
return nil
case *interface{}:
*v = math.Float64frombits(uint64(decBigInt(data)))
return nil
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
Expand Down Expand Up @@ -1196,6 +1233,14 @@ func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error {
unscaled := decBigInt2C(data[4:], nil)
*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
return nil
case *interface{}:
if len(data) < 4 {
return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data))
}
scale := decInt(data[0:4])
unscaled := decBigInt2C(data[4:], nil)
*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
return nil
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}
Expand Down Expand Up @@ -1302,6 +1347,9 @@ func unmarshalTime(info TypeInfo, data []byte, value interface{}) error {
case *time.Duration:
*v = time.Duration(decBigInt(data))
return nil
case *interface{}:
*v = time.Duration(decBigInt(data))
return nil
}

rv := reflect.ValueOf(value)
Expand Down Expand Up @@ -1334,6 +1382,16 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error {
nsec := (x - sec*1000) * 1000000
*v = time.Unix(sec, nsec).In(time.UTC)
return nil
case *interface{}:
if len(data) == 0 {
*v = time.Time{}
return nil
}
x := decBigInt(data)
sec := x / 1000
nsec := (x - sec*1000) * 1000000
*v = time.Unix(sec, nsec).In(time.UTC)
return nil
}

rv := reflect.ValueOf(value)
Expand Down Expand Up @@ -1419,6 +1477,16 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
*v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02")
return nil
case *interface{}:
if len(data) == 0 {
*v = time.Time{}
return nil
}
var origin uint32 = 1 << 31
var current uint32 = binary.BigEndian.Uint32(data)
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
*v = time.UnixMilli(timestamp).In(time.UTC)
return nil
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}
Expand Down Expand Up @@ -1478,6 +1546,25 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error {
Nanoseconds: nanos,
}
return nil
case *interface{}:
if len(data) == 0 {
*v = Duration{
Months: 0,
Days: 0,
Nanoseconds: 0,
}
return nil
}
months, days, nanos, err := decVints(data)
if err != nil {
return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error())
}
*v = Duration{
Months: months,
Days: days,
Nanoseconds: nanos,
}
return nil
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}
Expand Down Expand Up @@ -1914,6 +2001,9 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
case *[]byte:
*v = u[:]
return nil
case *interface{}:
*v = u.String()
return nil
}
return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
}
Expand Down Expand Up @@ -1996,6 +2086,18 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error {
}
*v = ip.String()
return nil
case *interface{}:
if x := len(data); !(x == 4 || x == 16) {
return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x)
}
buf := copyBytes(data)
ip := net.IP(buf)
if v4 := ip.To4(); v4 != nil {
*v = v4
return nil
}
*v = ip
return nil
}
return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
}
Expand Down