Skip to content

Commit

Permalink
Merge pull request #378 from illia-li/il/fix/marshal/duration
Browse files Browse the repository at this point in the history
Fix `duration` marshal, unmarshall functions
  • Loading branch information
dkropachev authored Dec 29, 2024
2 parents fffa208 + b977347 commit daef54f
Show file tree
Hide file tree
Showing 11 changed files with 1,283 additions and 401 deletions.
200 changes: 23 additions & 177 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@ import (
"fmt"
"github.com/gocql/gocql/serialization/boolean"
"math"
"math/big"
"math/bits"
"reflect"
"strings"
"time"
"unsafe"

"github.com/gocql/gocql/serialization/ascii"
Expand All @@ -26,6 +23,7 @@ import (
"github.com/gocql/gocql/serialization/date"
"github.com/gocql/gocql/serialization/decimal"
"github.com/gocql/gocql/serialization/double"
"github.com/gocql/gocql/serialization/duration"
"github.com/gocql/gocql/serialization/float"
"github.com/gocql/gocql/serialization/inet"
"github.com/gocql/gocql/serialization/smallint"
Expand All @@ -39,7 +37,6 @@ import (
)

var (
bigOne = big.NewInt(1)
emptyValue reflect.Value
)

Expand Down Expand Up @@ -195,7 +192,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeDate:
return marshalDate(value)
case TypeDuration:
return marshalDuration(info, value)
return marshalDuration(value)
}

// detect protocol 2 UDT
Expand Down Expand Up @@ -307,7 +304,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeDate:
return unmarshalDate(data, value)
case TypeDuration:
return unmarshalDuration(info, data, value)
return unmarshalDuration(data, value)
}

// detect protocol 2 UDT
Expand Down Expand Up @@ -430,17 +427,6 @@ func marshalInt(value interface{}) ([]byte, error) {
return data, nil
}

func encInt(x int32) []byte {
return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}
}

func decInt(x []byte) int32 {
if len(x) != 4 {
return 0
}
return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3])
}

func marshalBigInt(value interface{}) ([]byte, error) {
data, err := bigint.Marshal(value)
if err != nil {
Expand All @@ -457,11 +443,6 @@ func marshalCounter(value interface{}) ([]byte, error) {
return data, nil
}

func encBigInt(x int64) []byte {
return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32),
byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}
}

func unmarshalCounter(data []byte, value interface{}) error {
err := counter.Unmarshal(data, value)
if err != nil {
Expand Down Expand Up @@ -587,46 +568,6 @@ func unmarshalDecimal(data []byte, value interface{}) error {
return nil
}

// decBigInt2C sets the value of n to the big-endian two's complement
// value stored in the given data. If data[0]&80 != 0, the number
// is negative. If data is empty, the result will be 0.
func decBigInt2C(data []byte, n *big.Int) *big.Int {
if n == nil {
n = new(big.Int)
}
n.SetBytes(data)
if len(data) > 0 && data[0]&0x80 > 0 {
n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8))
}
return n
}

// encBigInt2C returns the big-endian two's complement
// form of n.
func encBigInt2C(n *big.Int) []byte {
switch n.Sign() {
case 0:
return []byte{0}
case 1:
b := n.Bytes()
if b[0]&0x80 > 0 {
b = append([]byte{0}, b...)
}
return b
case -1:
length := uint(n.BitLen()/8+1) * 8
b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes()
// When the most significant bit is on a byte
// boundary, we can get some extra significant
// bits, so strip them off when that happens.
if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 {
b = b[1:]
}
return b
}
return nil
}

func marshalTime(value interface{}) ([]byte, error) {
data, err := cqltime.Marshal(value)
if err != nil {
Expand Down Expand Up @@ -675,131 +616,36 @@ func unmarshalDate(data []byte, value interface{}) error {
return nil
}

func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case int64:
return encVints(0, 0, v), nil
case time.Duration:
return encVints(0, 0, v.Nanoseconds()), nil
case string:
d, err := time.ParseDuration(v)
if err != nil {
return nil, err
}
return encVints(0, 0, d.Nanoseconds()), nil
func marshalDuration(value interface{}) ([]byte, error) {
switch uv := value.(type) {
case Duration:
return encVints(v.Months, v.Days, v.Nanoseconds), nil
}

if value == nil {
return nil, nil
value = duration.Duration(uv)
case *Duration:
value = (*duration.Duration)(uv)
}

rv := reflect.ValueOf(value)
switch rv.Type().Kind() {
case reflect.Int64:
return encVints(0, 0, rv.Int()), nil
data, err := duration.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}
return nil, marshalErrorf("can not marshal %T into %s", value, info)
return data, nil
}

func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
func unmarshalDuration(data []byte, value interface{}) error {
switch uv := value.(type) {
case *Duration:
if len(data) == 0 {
*v = Duration{
Months: 0,
Days: 0,
Nanoseconds: 0,
}
return nil
}
months, days, nanos, err := decVints(data)
if err != nil {
return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error())
}
*v = Duration{
Months: months,
Days: days,
Nanoseconds: nanos,
value = (*duration.Duration)(uv)
case **Duration:
if uv == nil {
value = (**duration.Duration)(nil)
} else {
value = (**duration.Duration)(unsafe.Pointer(uv))
}
return nil
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

func decVints(data []byte) (int32, int32, int64, error) {
month, i, err := decVint(data, 0)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error())
}
days, i, err := decVint(data, i)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error())
}
nanos, _, err := decVint(data, i)
err := duration.Unmarshal(data, value)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error())
}
return int32(month), int32(days), nanos, err
}

func decVint(data []byte, start int) (int64, int, error) {
if len(data) <= start {
return 0, 0, errors.New("unexpected eof")
}
firstByte := data[start]
if firstByte&0x80 == 0 {
return decIntZigZag(uint64(firstByte)), start + 1, nil
}
numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24
ret := uint64(firstByte & (0xff >> uint(numBytes)))
if len(data) < start+numBytes+1 {
return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data))
}
for i := start; i < start+numBytes; i++ {
ret <<= 8
ret |= uint64(data[i+1] & 0xff)
}
return decIntZigZag(ret), start + numBytes + 1, nil
}

func decIntZigZag(n uint64) int64 {
return int64((n >> 1) ^ -(n & 1))
}

func encIntZigZag(n int64) uint64 {
return uint64((n >> 63) ^ (n << 1))
}

func encVints(months int32, days int32, nanos int64) []byte {
buf := append(encVint(int64(months)), encVint(int64(days))...)
return append(buf, encVint(nanos)...)
}

func encVint(v int64) []byte {
vEnc := encIntZigZag(v)
lead0 := bits.LeadingZeros64(vEnc)
numBytes := (639 - lead0*9) >> 6

// It can be 1 or 0 is v ==0
if numBytes <= 1 {
return []byte{byte(vEnc)}
}
extraBytes := numBytes - 1
var buf = make([]byte, numBytes)
for i := extraBytes; i >= 0; i-- {
buf[i] = byte(vEnc)
vEnc >>= 8
return wrapUnmarshalError(err, "unmarshal error")
}
buf[0] |= byte(^(0xff >> uint(extraBytes)))
return buf
return nil
}

func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error {
Expand Down
Loading

0 comments on commit daef54f

Please sign in to comment.