Skip to content

Commit

Permalink
Merge pull request scylladb#370 from illia-li/il/fix/marshal/cqltime
Browse files Browse the repository at this point in the history
Fix `time` marshal, unmarshall functions
  • Loading branch information
dkropachev authored Dec 9, 2024
2 parents 8f09151 + 514bf27 commit 0b40e7a
Show file tree
Hide file tree
Showing 8 changed files with 489 additions and 222 deletions.
60 changes: 14 additions & 46 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/gocql/gocql/serialization/blob"
"github.com/gocql/gocql/serialization/counter"
"github.com/gocql/gocql/serialization/cqlint"
"github.com/gocql/gocql/serialization/cqltime"
"github.com/gocql/gocql/serialization/decimal"
"github.com/gocql/gocql/serialization/double"
"github.com/gocql/gocql/serialization/float"
Expand Down Expand Up @@ -170,7 +171,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeDecimal:
return marshalDecimal(value)
case TypeTime:
return marshalTime(info, value)
return marshalTime(value)
case TypeTimestamp:
return marshalTimestamp(info, value)
case TypeList, TypeSet:
Expand Down Expand Up @@ -284,7 +285,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeDecimal:
return unmarshalDecimal(data, value)
case TypeTime:
return unmarshalTime(info, data, value)
return unmarshalTime(data, value)
case TypeTimestamp:
return unmarshalTimestamp(info, data, value)
case TypeList, TypeSet:
Expand Down Expand Up @@ -666,28 +667,20 @@ func encBigInt2C(n *big.Int) []byte {
return nil
}

func marshalTime(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case int64:
return encBigInt(v), nil
case time.Duration:
return encBigInt(v.Nanoseconds()), nil
}

if value == nil {
return nil, nil
func marshalTime(value interface{}) ([]byte, error) {
data, err := cqltime.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}
return data, nil
}

rv := reflect.ValueOf(value)
switch rv.Type().Kind() {
case reflect.Int64:
return encBigInt(rv.Int()), nil
func unmarshalTime(data []byte, value interface{}) error {
err := cqltime.Unmarshal(data, value)
if err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}
return nil, marshalErrorf("can not marshal %T into %s", value, info)
return nil
}

func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) {
Expand Down Expand Up @@ -718,31 +711,6 @@ func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) {
return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

func unmarshalTime(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
case *int64:
*v = decBigInt(data)
return nil
case *time.Duration:
*v = time.Duration(decBigInt(data))
return nil
}

rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
}
rv = rv.Elem()
switch rv.Type().Kind() {
case reflect.Int64:
rv.SetInt(decBigInt(data))
return nil
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
Expand Down
61 changes: 0 additions & 61 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,20 +122,6 @@ var marshalTests = []struct {
nil,
nil,
},
{
NativeType{proto: 4, typ: TypeTime},
[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
time.Duration(int64(1376387523000)),
nil,
nil,
},
{
NativeType{proto: 4, typ: TypeTime},
[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
int64(1376387523000),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeTimestamp},
[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
Expand Down Expand Up @@ -551,13 +537,6 @@ var marshalTests = []struct {
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeTime},
encBigInt(1000),
time.Duration(1000),
nil,
nil,
},
}

var unmarshalTests = []struct {
Expand Down Expand Up @@ -884,46 +863,6 @@ func TestMarshalPointer(t *testing.T) {
}
}

func TestMarshalTime(t *testing.T) {
durationS := "1h10m10s"
duration, _ := time.ParseDuration(durationS)
expectedData := encBigInt(duration.Nanoseconds())
var marshalTimeTests = []struct {
Info TypeInfo
Data []byte
Value interface{}
}{
{
NativeType{proto: 4, typ: TypeTime},
expectedData,
duration.Nanoseconds(),
},
{
NativeType{proto: 4, typ: TypeTime},
expectedData,
duration,
},
{
NativeType{proto: 4, typ: TypeTime},
expectedData,
&duration,
},
}

for i, test := range marshalTimeTests {
t.Log(i, test)
data, err := Marshal(test.Info, test.Value)
if err != nil {
t.Errorf("marshalTest[%d]: %v", i, err)
continue
}
if !bytes.Equal(data, test.Data) {
t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i,
test.Data, decInt(test.Data), data, decInt(data), test.Value)
}
}
}

func TestMarshalTimestamp(t *testing.T) {
var marshalTimestampTests = []struct {
Info TypeInfo
Expand Down
30 changes: 30 additions & 0 deletions serialization/cqltime/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package cqltime

import (
"reflect"
"time"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int64:
return EncInt64(v)
case *int64:
return EncInt64R(v)
case time.Duration:
return EncDuration(v)
case *time.Duration:
return EncDurationR(v)

default:
// Custom types (type MyTime int64) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
76 changes: 76 additions & 0 deletions serialization/cqltime/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package cqltime

import (
"fmt"
"reflect"
"time"
)

const (
maxValInt64 int64 = 86399999999999
minValInt64 int64 = 0
maxValDur time.Duration = 86399999999999
minValDur time.Duration = 0
)

var (
errOutRangeInt64 = fmt.Errorf("failed to marshal time: the (int64) should be in the range 0 to 86399999999999")
errOutRangeDur = fmt.Errorf("failed to marshal time: the (time.Duration) should be in the range 0 to 86399999999999")
)

func EncInt64(v int64) ([]byte, error) {
if v > maxValInt64 || v < minValInt64 {
return nil, errOutRangeInt64
}
return encInt64(v), nil
}

func EncInt64R(v *int64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt64(*v)
}

func EncDuration(v time.Duration) ([]byte, error) {
if v > maxValDur || v < minValDur {
return nil, errOutRangeDur
}
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}

func EncDurationR(v *time.Duration) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncDuration(*v)
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Int64:
val := v.Int()
if val > maxValInt64 || val < minValInt64 {
return nil, fmt.Errorf("failed to marshal time: the (%T) should be in the range 0 to 86399999999999", v.Interface())
}
return encInt64(val), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal time: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal time: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}

func encInt64(v int64) []byte {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
36 changes: 36 additions & 0 deletions serialization/cqltime/unmarshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package cqltime

import (
"fmt"
"reflect"
"time"
)

func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil

case *int64:
return DecInt64(data, v)
case **int64:
return DecInt64R(data, v)
case *time.Duration:
return DecDuration(data, v)
case **time.Duration:
return DecDurationR(data, v)
default:

// Custom types (type MyTime int64) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal time: unsupported value type (%T)(%[1]v)", value)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}
Loading

0 comments on commit 0b40e7a

Please sign in to comment.