Skip to content

Commit

Permalink
scan to any were implemented for all simple types
Browse files Browse the repository at this point in the history
  • Loading branch information
tengu-alt committed Oct 21, 2024
1 parent 974fa12 commit 16cc778
Show file tree
Hide file tree
Showing 2 changed files with 263 additions and 1 deletion.
162 changes: 161 additions & 1 deletion cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import (
"time"
"unicode"

inf "gopkg.in/inf.v0"
"gopkg.in/inf.v0"
)

func TestEmptyHosts(t *testing.T) {
Expand Down Expand Up @@ -3288,3 +3288,163 @@ func TestQuery_NamedValues(t *testing.T) {
t.Fatal(err)
}
}

func TestScanToAny(t *testing.T) {
session := createSession(t)
defer session.Close()
ctx := context.Background()

dataTypes := []struct {
tableName string
createQuery string
insertQuery string
expectedVal interface{}
}{
{
"scan_to_any_varchar",
"CREATE TABLE IF NOT EXISTS scan_to_any_varchar (id int PRIMARY KEY, val varchar)",
"INSERT INTO scan_to_any_varchar (id, val) VALUES (?, ?)",
"test",
},
{
"scan_to_any_bool",
"CREATE TABLE IF NOT EXISTS scan_to_any_bool (id int PRIMARY KEY, val boolean)",
"INSERT INTO scan_to_any_bool (id, val) VALUES (?, ?)",
true,
},
{
"scan_to_any_int",
"CREATE TABLE IF NOT EXISTS scan_to_any_int (id int PRIMARY KEY, val int)",
"INSERT INTO scan_to_any_int (id, val) VALUES (?, ?)",
42,
},
{
"scan_to_any_float",
"CREATE TABLE IF NOT EXISTS scan_to_any_float (id int PRIMARY KEY, val float)",
"INSERT INTO scan_to_any_float (id, val) VALUES (?, ?)",
float32(3.14),
},
{
"scan_to_any_double",
"CREATE TABLE IF NOT EXISTS scan_to_any_double (id int PRIMARY KEY, val double)",
"INSERT INTO scan_to_any_double (id, val) VALUES (?, ?)",
3.14159,
},
{
"scan_to_any_decimal",
"CREATE TABLE IF NOT EXISTS scan_to_any_decimal (id int PRIMARY KEY, val decimal)",
"INSERT INTO scan_to_any_decimal (id, val) VALUES (?, ?)",
inf.NewDec(12345, 2), // Example decimal value
},
{
"scan_to_any_time",
"CREATE TABLE IF NOT EXISTS scan_to_any_time (id int PRIMARY KEY, val time)",
"INSERT INTO scan_to_any_time (id, val) VALUES (?, ?)",
time.Duration(1000),
},
{
"scan_to_any_timestamp",
"CREATE TABLE IF NOT EXISTS scan_to_any_timestamp (id int PRIMARY KEY, val timestamp)",
"INSERT INTO scan_to_any_timestamp (id, val) VALUES (?, ?)",
time.Now().UTC().Truncate(time.Millisecond),
},
{
"scan_to_any_inet",
"CREATE TABLE IF NOT EXISTS scan_to_any_inet (id int PRIMARY KEY, val inet)",
"INSERT INTO scan_to_any_inet (id, val) VALUES (?, ?)",
net.ParseIP("192.168.0.1"),
},
{
"scan_to_any_uuid",
"CREATE TABLE IF NOT EXISTS scan_to_any_uuid (id int PRIMARY KEY, val uuid)",
"INSERT INTO scan_to_any_uuid (id, val) VALUES (?, ?)",
TimeUUID().String(),
},
{
"scan_to_any_date",
"CREATE TABLE IF NOT EXISTS scan_to_any_date (id int PRIMARY KEY, val date)",
"INSERT INTO scan_to_any_date (id, val) VALUES (?, ?)",
time.Now().UTC().Truncate(time.Hour * 24),
},
{
"scan_to_any_duration",
"CREATE TABLE IF NOT EXISTS scan_to_any_duration (id int PRIMARY KEY, val duration)",
"INSERT INTO scan_to_any_duration (id, val) VALUES (?, ?)",
Duration{0, 0, 123},
},
}

for _, dt := range dataTypes {
t.Run(fmt.Sprintf("Test_%s", dt.tableName), func(t *testing.T) {
if err := session.Query(dt.createQuery).WithContext(ctx).Exec(); err != nil {
t.Fatal(err)
}

if err := session.Query(dt.insertQuery, 1, dt.expectedVal).WithContext(ctx).Exec(); err != nil {
t.Fatal(err)
}

var out interface{}
if err := session.Query(fmt.Sprintf("SELECT val FROM %s WHERE id = 1", dt.tableName)).WithContext(ctx).Scan(&out); err != nil {
t.Fatal(err)
}

if err := session.Query(fmt.Sprintf("DROP TABLE %s", dt.tableName)).WithContext(ctx).Exec(); err != nil {
t.Fatal(err)
}

switch dt.tableName {
case "scan_to_any_decimal":
result, ok := out.(inf.Dec)
if !ok {
t.Fatal("expected inf.Dec, got", out)
}
expected := inf.NewDec(12345, 2)

if result.Cmp(expected) != 0 {
t.Fatalf("expected %v, got %v", expected, out)
}
case "scan_to_any_inet":
result, ok := out.(net.IP)
if !ok {
t.Fatal("expected net.IP, got", out)
}
expected, ok := dt.expectedVal.(net.IP)
if !ok {
t.Fatal("expected net.IP, got", dt.expectedVal)
}
if result.String() != expected.String() {
t.Fatalf("expected %v, got %v", expected, out)
}
case "scan_to_any_date":
result, ok := out.(time.Time)
if !ok {
t.Fatal("expected time.Time, got", out)
}
expected, ok := dt.expectedVal.(time.Time)
if !ok {
t.Fatal("expected time.Time, got", dt.expectedVal)
}
if result.String() != expected.String() {
t.Fatalf("expected %v, got %v", expected, out)
}
case "scan_to_any_duration":
result, ok := out.(Duration)
if !ok {
t.Fatal("expected time.Duration, got", out)
}
expected, ok := dt.expectedVal.(Duration)
if !ok {
t.Fatal("expected time.Duration, got", dt.expectedVal)
}
if result != expected {
t.Fatalf("expected %v, got %v", expected, out)
}
default:
if out != dt.expectedVal {
t.Fatalf("expected %v, got %v", dt.expectedVal, out)
}
}
})
}
}
102 changes: 102 additions & 0 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,23 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
// date | *time.Time | time of beginning of the day (in UTC)
// date | *string | formatted with 2006-01-02 format
// duration | *gocql.Duration |
//
// Scan into interface{} implemented by unmarshal into default type:
//
// CQL type | Go type
// Varchar | string
// Varint | bigInt
// IntLike | int
// Boolean | bool
// Float | float32
// Double | float64
// Decimal | infDec
// Time | time.Duration
// Timestamp | time.Time
// Date | time.Time
// Duration | Duration
// UUID | string
// Inet | net.IP
func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
if v, ok := value.(Unmarshaler); ok {
return v.UnmarshalCQL(info, data)
Expand Down Expand Up @@ -350,6 +367,9 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error {
*v = nil
}
return nil
case *interface{}:
*v = string(data)
return nil
}

rv := reflect.ValueOf(value)
Expand Down Expand Up @@ -743,6 +763,8 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error {
*v = bytesToUint64(data[1:])
return nil
}
case *interface{}:
return unmarshalBigInt(info, data, value)
}

if len(data) > 8 {
Expand Down Expand Up @@ -904,6 +926,12 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
case *string:
*v = strconv.FormatInt(int64Val, 10)
return nil
case *interface{}:
if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) {
return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, info.Type())
}
*v = int(int64Val)
return nil
}

rv := reflect.ValueOf(value)
Expand Down Expand Up @@ -1055,6 +1083,9 @@ func unmarshalBool(info TypeInfo, data []byte, value interface{}) error {
case *bool:
*v = decBool(data)
return nil
case *interface{}:
*v = decBool(data)
return nil
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
Expand Down Expand Up @@ -1105,6 +1136,9 @@ func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error {
case *float32:
*v = math.Float32frombits(uint32(decInt(data)))
return nil
case *interface{}:
*v = math.Float32frombits(uint32(decInt(data)))
return nil
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
Expand Down Expand Up @@ -1146,6 +1180,9 @@ func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error {
case *float64:
*v = math.Float64frombits(uint64(decBigInt(data)))
return nil
case *interface{}:
*v = math.Float64frombits(uint64(decBigInt(data)))
return nil
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
Expand Down Expand Up @@ -1196,6 +1233,14 @@ func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error {
unscaled := decBigInt2C(data[4:], nil)
*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
return nil
case *interface{}:
if len(data) < 4 {
return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data))
}
scale := decInt(data[0:4])
unscaled := decBigInt2C(data[4:], nil)
*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
return nil
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}
Expand Down Expand Up @@ -1302,6 +1347,9 @@ func unmarshalTime(info TypeInfo, data []byte, value interface{}) error {
case *time.Duration:
*v = time.Duration(decBigInt(data))
return nil
case *interface{}:
*v = time.Duration(decBigInt(data))
return nil
}

rv := reflect.ValueOf(value)
Expand Down Expand Up @@ -1334,6 +1382,16 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error {
nsec := (x - sec*1000) * 1000000
*v = time.Unix(sec, nsec).In(time.UTC)
return nil
case *interface{}:
if len(data) == 0 {
*v = time.Time{}
return nil
}
x := decBigInt(data)
sec := x / 1000
nsec := (x - sec*1000) * 1000000
*v = time.Unix(sec, nsec).In(time.UTC)
return nil
}

rv := reflect.ValueOf(value)
Expand Down Expand Up @@ -1419,6 +1477,16 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
*v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02")
return nil
case *interface{}:
if len(data) == 0 {
*v = time.Time{}
return nil
}
var origin uint32 = 1 << 31
var current uint32 = binary.BigEndian.Uint32(data)
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
*v = time.UnixMilli(timestamp).In(time.UTC)
return nil
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}
Expand Down Expand Up @@ -1478,6 +1546,25 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error {
Nanoseconds: nanos,
}
return nil
case *interface{}:
if len(data) == 0 {
*v = Duration{
Months: 0,
Days: 0,
Nanoseconds: 0,
}
return nil
}
months, days, nanos, err := decVints(data)
if err != nil {
return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error())
}
*v = Duration{
Months: months,
Days: days,
Nanoseconds: nanos,
}
return nil
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}
Expand Down Expand Up @@ -1914,6 +2001,9 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
case *[]byte:
*v = u[:]
return nil
case *interface{}:
*v = u.String()
return nil
}
return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
}
Expand Down Expand Up @@ -1996,6 +2086,18 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error {
}
*v = ip.String()
return nil
case *interface{}:
if x := len(data); !(x == 4 || x == 16) {
return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x)
}
buf := copyBytes(data)
ip := net.IP(buf)
if v4 := ip.To4(); v4 != nil {
*v = v4
return nil
}
*v = ip
return nil
}
return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
}
Expand Down

0 comments on commit 16cc778

Please sign in to comment.