diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..16b87a2a9 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -44,7 +44,7 @@ import ( "time" "unicode" - inf "gopkg.in/inf.v0" + "gopkg.in/inf.v0" ) func TestEmptyHosts(t *testing.T) { @@ -3288,3 +3288,163 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +func TestScanToAny(t *testing.T) { + session := createSession(t) + defer session.Close() + ctx := context.Background() + + dataTypes := []struct { + tableName string + createQuery string + insertQuery string + expectedVal interface{} + }{ + { + "scan_to_any_varchar", + "CREATE TABLE IF NOT EXISTS scan_to_any_varchar (id int PRIMARY KEY, val varchar)", + "INSERT INTO scan_to_any_varchar (id, val) VALUES (?, ?)", + "test", + }, + { + "scan_to_any_bool", + "CREATE TABLE IF NOT EXISTS scan_to_any_bool (id int PRIMARY KEY, val boolean)", + "INSERT INTO scan_to_any_bool (id, val) VALUES (?, ?)", + true, + }, + { + "scan_to_any_int", + "CREATE TABLE IF NOT EXISTS scan_to_any_int (id int PRIMARY KEY, val int)", + "INSERT INTO scan_to_any_int (id, val) VALUES (?, ?)", + 42, + }, + { + "scan_to_any_float", + "CREATE TABLE IF NOT EXISTS scan_to_any_float (id int PRIMARY KEY, val float)", + "INSERT INTO scan_to_any_float (id, val) VALUES (?, ?)", + float32(3.14), + }, + { + "scan_to_any_double", + "CREATE TABLE IF NOT EXISTS scan_to_any_double (id int PRIMARY KEY, val double)", + "INSERT INTO scan_to_any_double (id, val) VALUES (?, ?)", + 3.14159, + }, + { + "scan_to_any_decimal", + "CREATE TABLE IF NOT EXISTS scan_to_any_decimal (id int PRIMARY KEY, val decimal)", + "INSERT INTO scan_to_any_decimal (id, val) VALUES (?, ?)", + inf.NewDec(12345, 2), // Example decimal value + }, + { + "scan_to_any_time", + "CREATE TABLE IF NOT EXISTS scan_to_any_time (id int PRIMARY KEY, val time)", + "INSERT INTO scan_to_any_time (id, val) VALUES (?, ?)", + time.Duration(1000), + }, + { + "scan_to_any_timestamp", + "CREATE TABLE IF NOT EXISTS scan_to_any_timestamp (id int PRIMARY KEY, val timestamp)", + "INSERT INTO scan_to_any_timestamp (id, val) VALUES (?, ?)", + time.Now().UTC().Truncate(time.Millisecond), + }, + { + "scan_to_any_inet", + "CREATE TABLE IF NOT EXISTS scan_to_any_inet (id int PRIMARY KEY, val inet)", + "INSERT INTO scan_to_any_inet (id, val) VALUES (?, ?)", + net.ParseIP("192.168.0.1"), + }, + { + "scan_to_any_uuid", + "CREATE TABLE IF NOT EXISTS scan_to_any_uuid (id int PRIMARY KEY, val uuid)", + "INSERT INTO scan_to_any_uuid (id, val) VALUES (?, ?)", + TimeUUID().String(), + }, + { + "scan_to_any_date", + "CREATE TABLE IF NOT EXISTS scan_to_any_date (id int PRIMARY KEY, val date)", + "INSERT INTO scan_to_any_date (id, val) VALUES (?, ?)", + time.Now().UTC().Truncate(time.Hour * 24), + }, + { + "scan_to_any_duration", + "CREATE TABLE IF NOT EXISTS scan_to_any_duration (id int PRIMARY KEY, val duration)", + "INSERT INTO scan_to_any_duration (id, val) VALUES (?, ?)", + Duration{0, 0, 123}, + }, + } + + for _, dt := range dataTypes { + t.Run(fmt.Sprintf("Test_%s", dt.tableName), func(t *testing.T) { + if err := session.Query(dt.createQuery).WithContext(ctx).Exec(); err != nil { + t.Fatal(err) + } + + if err := session.Query(dt.insertQuery, 1, dt.expectedVal).WithContext(ctx).Exec(); err != nil { + t.Fatal(err) + } + + var out interface{} + if err := session.Query(fmt.Sprintf("SELECT val FROM %s WHERE id = 1", dt.tableName)).WithContext(ctx).Scan(&out); err != nil { + t.Fatal(err) + } + + if err := session.Query(fmt.Sprintf("DROP TABLE %s", dt.tableName)).WithContext(ctx).Exec(); err != nil { + t.Fatal(err) + } + + switch dt.tableName { + case "scan_to_any_decimal": + result, ok := out.(inf.Dec) + if !ok { + t.Fatal("expected inf.Dec, got", out) + } + expected := inf.NewDec(12345, 2) + + if result.Cmp(expected) != 0 { + t.Fatalf("expected %v, got %v", expected, out) + } + case "scan_to_any_inet": + result, ok := out.(net.IP) + if !ok { + t.Fatal("expected net.IP, got", out) + } + expected, ok := dt.expectedVal.(net.IP) + if !ok { + t.Fatal("expected net.IP, got", dt.expectedVal) + } + if result.String() != expected.String() { + t.Fatalf("expected %v, got %v", expected, out) + } + case "scan_to_any_date": + result, ok := out.(time.Time) + if !ok { + t.Fatal("expected time.Time, got", out) + } + expected, ok := dt.expectedVal.(time.Time) + if !ok { + t.Fatal("expected time.Time, got", dt.expectedVal) + } + if result.String() != expected.String() { + t.Fatalf("expected %v, got %v", expected, out) + } + case "scan_to_any_duration": + result, ok := out.(Duration) + if !ok { + t.Fatal("expected time.Duration, got", out) + } + expected, ok := dt.expectedVal.(Duration) + if !ok { + t.Fatal("expected time.Duration, got", dt.expectedVal) + } + if result != expected { + t.Fatalf("expected %v, got %v", expected, out) + } + default: + if out != dt.expectedVal { + t.Fatalf("expected %v, got %v", dt.expectedVal, out) + } + } + }) + } +} diff --git a/marshal.go b/marshal.go index 4d0adb923..09dfbab27 100644 --- a/marshal.go +++ b/marshal.go @@ -222,6 +222,23 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { // date | *time.Time | time of beginning of the day (in UTC) // date | *string | formatted with 2006-01-02 format // duration | *gocql.Duration | +// +// Scan into interface{} implemented by unmarshal into default type: +// +// CQL type | Go type +// Varchar | string +// Varint | bigInt +// IntLike | int +// Boolean | bool +// Float | float32 +// Double | float64 +// Decimal | infDec +// Time | time.Duration +// Timestamp | time.Time +// Date | time.Time +// Duration | Duration +// UUID | string +// Inet | net.IP func Unmarshal(info TypeInfo, data []byte, value interface{}) error { if v, ok := value.(Unmarshaler); ok { return v.UnmarshalCQL(info, data) @@ -350,6 +367,9 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { *v = nil } return nil + case *interface{}: + *v = string(data) + return nil } rv := reflect.ValueOf(value) @@ -743,6 +763,8 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { *v = bytesToUint64(data[1:]) return nil } + case *interface{}: + return unmarshalBigInt(info, data, value) } if len(data) > 8 { @@ -904,6 +926,12 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac case *string: *v = strconv.FormatInt(int64Val, 10) return nil + case *interface{}: + if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) { + return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, info.Type()) + } + *v = int(int64Val) + return nil } rv := reflect.ValueOf(value) @@ -1055,6 +1083,9 @@ func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { case *bool: *v = decBool(data) return nil + case *interface{}: + *v = decBool(data) + return nil } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { @@ -1105,6 +1136,9 @@ func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { case *float32: *v = math.Float32frombits(uint32(decInt(data))) return nil + case *interface{}: + *v = math.Float32frombits(uint32(decInt(data))) + return nil } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { @@ -1146,6 +1180,9 @@ func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { case *float64: *v = math.Float64frombits(uint64(decBigInt(data))) return nil + case *interface{}: + *v = math.Float64frombits(uint64(decBigInt(data))) + return nil } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { @@ -1196,6 +1233,14 @@ func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { unscaled := decBigInt2C(data[4:], nil) *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) return nil + case *interface{}: + if len(data) < 4 { + return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) + } + scale := decInt(data[0:4]) + unscaled := decBigInt2C(data[4:], nil) + *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) + return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } @@ -1302,6 +1347,9 @@ func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { case *time.Duration: *v = time.Duration(decBigInt(data)) return nil + case *interface{}: + *v = time.Duration(decBigInt(data)) + return nil } rv := reflect.ValueOf(value) @@ -1334,6 +1382,16 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { nsec := (x - sec*1000) * 1000000 *v = time.Unix(sec, nsec).In(time.UTC) return nil + case *interface{}: + if len(data) == 0 { + *v = time.Time{} + return nil + } + x := decBigInt(data) + sec := x / 1000 + nsec := (x - sec*1000) * 1000000 + *v = time.Unix(sec, nsec).In(time.UTC) + return nil } rv := reflect.ValueOf(value) @@ -1419,6 +1477,16 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { timestamp := (int64(current) - int64(origin)) * millisecondsInADay *v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02") return nil + case *interface{}: + if len(data) == 0 { + *v = time.Time{} + return nil + } + var origin uint32 = 1 << 31 + var current uint32 = binary.BigEndian.Uint32(data) + timestamp := (int64(current) - int64(origin)) * millisecondsInADay + *v = time.UnixMilli(timestamp).In(time.UTC) + return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } @@ -1478,6 +1546,25 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { Nanoseconds: nanos, } return nil + case *interface{}: + if len(data) == 0 { + *v = Duration{ + Months: 0, + Days: 0, + Nanoseconds: 0, + } + return nil + } + months, days, nanos, err := decVints(data) + if err != nil { + return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error()) + } + *v = Duration{ + Months: months, + Days: days, + Nanoseconds: nanos, + } + return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } @@ -1914,6 +2001,9 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { case *[]byte: *v = u[:] return nil + case *interface{}: + *v = u.String() + return nil } return unmarshalErrorf("can not unmarshal X %s into %T", info, value) } @@ -1996,6 +2086,18 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { } *v = ip.String() return nil + case *interface{}: + if x := len(data); !(x == 4 || x == 16) { + return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x) + } + buf := copyBytes(data) + ip := net.IP(buf) + if v4 := ip.To4(); v4 != nil { + *v = v4 + return nil + } + *v = ip + return nil } return unmarshalErrorf("cannot unmarshal %s into %T", info, value) }