Skip to content

Commit

Permalink
fix date marshal, unmarshal functions
Browse files Browse the repository at this point in the history
  • Loading branch information
illia-li committed Dec 13, 2024
1 parent 466d662 commit 2e3502b
Show file tree
Hide file tree
Showing 5 changed files with 735 additions and 71 deletions.
84 changes: 13 additions & 71 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package gocql

import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"math"
Expand All @@ -23,6 +22,7 @@ import (
"github.com/gocql/gocql/serialization/counter"
"github.com/gocql/gocql/serialization/cqlint"
"github.com/gocql/gocql/serialization/cqltime"
"github.com/gocql/gocql/serialization/date"
"github.com/gocql/gocql/serialization/decimal"
"github.com/gocql/gocql/serialization/double"
"github.com/gocql/gocql/serialization/float"
Expand Down Expand Up @@ -192,7 +192,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeUDT:
return marshalUDT(info, value)
case TypeDate:
return marshalDate(info, value)
return marshalDate(value)
case TypeDuration:
return marshalDuration(info, value)
}
Expand Down Expand Up @@ -304,7 +304,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeUDT:
return unmarshalUDT(info, data, value)
case TypeDate:
return unmarshalDate(info, data, value)
return unmarshalDate(data, value)
case TypeDuration:
return unmarshalDuration(info, data, value)
}
Expand Down Expand Up @@ -700,78 +700,20 @@ func unmarshalTimestamp(data []byte, value interface{}) error {
return nil
}

const millisecondsInADay int64 = 24 * 60 * 60 * 1000

func marshalDate(info TypeInfo, value interface{}) ([]byte, error) {
var timestamp int64
switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case int64:
timestamp = v
x := timestamp/millisecondsInADay + int64(1<<31)
return encInt(int32(x)), nil
case time.Time:
if v.IsZero() {
return []byte{}, nil
}
timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
x := timestamp/millisecondsInADay + int64(1<<31)
return encInt(int32(x)), nil
case *time.Time:
if v.IsZero() {
return []byte{}, nil
}
timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
x := timestamp/millisecondsInADay + int64(1<<31)
return encInt(int32(x)), nil
case string:
if v == "" {
return []byte{}, nil
}
t, err := time.Parse("2006-01-02", v)
if err != nil {
return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info)
}
timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6)
x := timestamp/millisecondsInADay + int64(1<<31)
return encInt(int32(x)), nil
}

if value == nil {
return nil, nil
func marshalDate(value interface{}) ([]byte, error) {
data, err := date.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}
return nil, marshalErrorf("can not marshal %T into %s", value, info)
return data, nil
}

func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
case *time.Time:
if len(data) == 0 {
*v = time.Time{}
return nil
}
var origin uint32 = 1 << 31
var current uint32 = binary.BigEndian.Uint32(data)
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
*v = time.UnixMilli(timestamp).In(time.UTC)
return nil
case *string:
if len(data) == 0 {
*v = ""
return nil
}
var origin uint32 = 1 << 31
var current uint32 = binary.BigEndian.Uint32(data)
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
*v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02")
return nil
func unmarshalDate(data []byte, value interface{}) error {
err := date.Unmarshal(data, value)
if err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
return nil
}

func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) {
Expand Down
42 changes: 42 additions & 0 deletions serialization/date/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package date

import (
"reflect"
"time"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int32:
return EncInt32(v)
case int64:
return EncInt64(v)
case uint32:
return EncUint32(v)
case string:
return EncString(v)
case time.Time:
return EncTime(v)

case *int32:
return EncInt32R(v)
case *int64:
return EncInt64R(v)
case *uint32:
return EncUint32R(v)
case *string:
return EncStringR(v)
case *time.Time:
return EncTimeR(v)
default:
// Custom types (type MyDate uint32) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
226 changes: 226 additions & 0 deletions serialization/date/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package date

import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
)

const (
millisecondsInADay int64 = 24 * 60 * 60 * 1000
centerEpoch int64 = 1 << 31
maxYear int = 5881580
minYear int = -5877641
maxMilliseconds int64 = 185542587100800000
minMilliseconds int64 = -185542587187200000
)

var (
maxDate = time.Date(5881580, 07, 11, 0, 0, 0, 0, time.UTC)
minDate = time.Date(-5877641, 06, 23, 0, 0, 0, 0, time.UTC)
)

func errWrongStringFormat(v interface{}) error {
return fmt.Errorf(`failed to marshal date: the (%T)(%[1]v) should have fromat "2006-01-02"`, v)
}

func EncInt32(v int32) ([]byte, error) {
return encInt32(v), nil
}

func EncInt32R(v *int32) ([]byte, error) {
if v == nil {
return nil, nil
}
return encInt32(*v), nil
}

func EncInt64(v int64) ([]byte, error) {
if v > maxMilliseconds || v < minMilliseconds {
return nil, fmt.Errorf("failed to marshal date: the (int64)(%v) value out of range", v)
}
return encInt64(days(v)), nil
}

func EncInt64R(v *int64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt64(*v)
}

func EncUint32(v uint32) ([]byte, error) {
return encUint32(v), nil
}

func EncUint32R(v *uint32) ([]byte, error) {
if v == nil {
return nil, nil
}
return encUint32(*v), nil
}

func EncTime(v time.Time) ([]byte, error) {
if v.After(maxDate) || v.Before(minDate) {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%s) value should be in the range from -5877641-06-23 to 5881580-07-11", v, v.Format("2006-01-02"))
}
return encTime(v), nil
}

func EncTimeR(v *time.Time) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncTime(*v)
}

func EncString(v string) ([]byte, error) {
if v == "" {
return nil, nil
}
var err error
var y, m, d int
var t time.Time
switch ps := strings.Split(v, "-"); len(ps) {
case 3:
if y, err = strconv.Atoi(ps[0]); err != nil {
return nil, errWrongStringFormat(v)
}
if m, err = strconv.Atoi(ps[1]); err != nil {
return nil, errWrongStringFormat(v)
}
if d, err = strconv.Atoi(ps[2]); err != nil {
return nil, errWrongStringFormat(v)
}
case 4:
if y, err = strconv.Atoi(ps[1]); err != nil || ps[0] != "" {
return nil, errWrongStringFormat(v)
}
y = -y
if m, err = strconv.Atoi(ps[2]); err != nil {
return nil, errWrongStringFormat(v)
}
if d, err = strconv.Atoi(ps[3]); err != nil {
return nil, errWrongStringFormat(v)
}
default:
return nil, errWrongStringFormat(v)
}
if y > maxYear || y < minYear {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v)
}
t = time.Date(y, time.Month(m), d, 0, 0, 0, 0, time.UTC)
if t.After(maxDate) || t.Before(minDate) {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v)
}
return encTime(t), nil
}

func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Int32:
return encInt64(v.Int()), nil
case reflect.Int64:
val := v.Int()
if val > maxMilliseconds || val < minMilliseconds {
return nil, fmt.Errorf("failed to marshal date: the value (%T)(%[1]v) out of range", v.Interface())
}
return encInt64(days(val)), nil
case reflect.Uint32:
val := v.Uint()
return []byte{byte(val >> 24), byte(val >> 16), byte(val >> 8), byte(val)}, nil
case reflect.String:
return encReflectString(v)
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal date: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal date: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}

func encReflectString(v reflect.Value) ([]byte, error) {
val := v.String()
if val == "" {
return nil, nil
}
var err error
var y, m, d int
var t time.Time
ps := strings.Split(val, "-")
switch len(ps) {
case 3:
if y, err = strconv.Atoi(ps[0]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
if m, err = strconv.Atoi(ps[1]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
if d, err = strconv.Atoi(ps[2]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
case 4:
if y, err = strconv.Atoi(ps[1]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
y = -y
if m, err = strconv.Atoi(ps[2]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
if d, err = strconv.Atoi(ps[3]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
default:
return nil, errWrongStringFormat(v.Interface())
}
if y > maxYear || y < minYear {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v.Interface())
}
t = time.Date(y, time.Month(m), d, 0, 0, 0, 0, time.UTC)
if t.After(maxDate) || t.Before(minDate) {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v.Interface())
}
return encTime(t), nil
}

func encInt64(v int64) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

func encInt32(v int32) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

func encUint32(v uint32) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

func encTime(v time.Time) []byte {
if v.IsZero() {
return nil
}
d := days(v.UnixMilli())
return []byte{byte(d >> 24), byte(d >> 16), byte(d >> 8), byte(d)}
}

func days(v int64) int64 {
return v/millisecondsInADay + centerEpoch
}
Loading

0 comments on commit 2e3502b

Please sign in to comment.