diff --git a/.gitignore b/.gitignore index 17f1ccdc..fe7f1f38 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ _generated/*_gen.go _generated/*_gen_test.go msgp/defgen_test.go msgp/cover.out -*~ \ No newline at end of file +*~ +.idea diff --git a/Makefile b/Makefile index 81b8b126..9ae72018 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ # normal `go install`. # generated integration test files -GGEN = ./_generated/generated.go ./_generated/generated_test.go +GGEN = ./_generated/generated.go ./_generated/generated_test.go ./_generated/issue94_gen.go ./_generated/issue94_gen_test.go # generated unit test files MGEN = ./msgp/defgen_test.go diff --git a/README.md b/README.md index a7cc849c..8b186c5b 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,16 @@ type Person struct { unexported bool // this field is also ignored } ``` +If you need to have numeric labels for a struct fields, you could set it as following: +```go +type Person struct { + Name string `msg:"0x01,int"` + Address string `msg:"0b10,int"` + Email string `msg:"03,int"` + Age int `msg:"4,int"` +} +``` +> Note that field labels with `uint` value will be serialized as `msgp.fixint` in case when label value is <= (1<<7)-1 By default, the code generator will satisfy `msgp.Sizer`, `msgp.Encodable`, `msgp.Decodable`, `msgp.Marshaler`, and `msgp.Unmarshaler`. Carefully-designed applications can use these methods to do diff --git a/_generated/def.go b/_generated/def.go index 13b37e99..516317eb 100644 --- a/_generated/def.go +++ b/_generated/def.go @@ -40,10 +40,13 @@ type Fixed struct { type TestType struct { F *float64 `msg:"float"` Els map[string]string `msg:"elements"` + Els2 map[int]int `msg:"elements_2"` + Els3 map[uint]uint `msg:"elements_3"` + Els4 map[uint][]byte `msg:"elements_4"` Obj struct { // test anonymous struct - ValueA string `msg:"value_a"` - ValueB []byte `msg:"value_b"` - } `msg:"object"` + ValueA string `msg:"value_a"` + ValueB []byte `msg:"value_b"` + } `msg:"object"` Child *TestType `msg:"child"` Time time.Time `msg:"time"` Any interface{} `msg:"any"` @@ -53,6 +56,27 @@ type TestType struct { Slice2 []string } +type TestNumericLabels struct { + One string `msg:"0x01,int"` + Two string `msg:"0xffffffffffffffff,uint"` +} + +type TestOnlyIntLabels struct { + A string `msg:"0xfa,int"` + B string `msg:"0xfb,int"` + C string `msg:"0xfc,int"` +} + +type TestIntLiterals struct { + A string `msg:"0x01,int"` + B string `msg:"03,int"` + C string `msg:"4,int"` +} + +type SingleFieldNumeric struct { + Message string `msg:"0x00,uint"` +} + //msgp:tuple Object type Object struct { ObjectNo string `msg:"objno"` diff --git a/gen/decode.go b/gen/decode.go index 5367ad3e..55426342 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -7,33 +7,24 @@ import ( func decode(w io.Writer) *decodeGen { return &decodeGen{ - p: printer{w: w}, - hasfield: false, + p: printer{w: w}, } } type decodeGen struct { passes - p printer - hasfield bool + fields + p printer } func (d *decodeGen) Method() Method { return Decode } -func (d *decodeGen) needsField() { - if d.hasfield { - return - } - d.p.print("\nvar field []byte; _ = field") - d.hasfield = true -} - func (d *decodeGen) Execute(p Elem) error { p = d.applyall(p) if p == nil { return nil } - d.hasfield = false + d.fields.drop() if !d.p.ok() { return d.p.err } @@ -63,14 +54,32 @@ func (d *decodeGen) gStruct(s *Struct) { return } -func (d *decodeGen) assignAndCheck(name string, typ string) { +func (d *decodeGen) assignAndCheck(name string, base string) { if !d.p.ok() { return } - d.p.printf("\n%s, err = dc.Read%s()", name, typ) + if base == mapKey { + d.p.printf("\n%s, err = dc.ReadMapKeyPtr()", name) + } else { + d.p.printf("\n%s, err = dc.Read%s()", name, base) + } + d.p.print(errcheck) } +func (u *decodeGen) nextTypeAndCheck(name string) { + if !u.p.ok() { + return + } + u.p.printf("\n%s, err = dc.NextType()", name) + u.p.print(errcheck) +} + +func (u *decodeGen) skipAndCheck() { + u.p.print("\nerr = dc.Skip()") + u.p.print(errcheck) +} + func (d *decodeGen) structAsTuple(s *Struct) { nfields := len(s.Fields) @@ -87,25 +96,7 @@ func (d *decodeGen) structAsTuple(s *Struct) { } func (d *decodeGen) structAsMap(s *Struct) { - d.needsField() - sz := randIdent() - d.p.declare(sz, u32) - d.assignAndCheck(sz, mapHeader) - - d.p.printf("\nfor %s > 0 {\n%s--", sz, sz) - d.assignAndCheck("field", mapKey) - d.p.print("\nswitch msgp.UnsafeString(field) {") - for i := range s.Fields { - d.p.printf("\ncase \"%s\":", s.Fields[i].FieldTag) - next(d, s.Fields[i].FieldElem) - if !d.p.ok() { - return - } - } - d.p.print("\ndefault:\nerr = dc.Skip()") - d.p.print(errcheck) - d.p.closeblock() // close switch - d.p.closeblock() // close for loop + genStructFieldsParser(d, d.p, s.Fields) } func (d *decodeGen) gBase(b *BaseElem) { @@ -166,9 +157,9 @@ func (d *decodeGen) gMap(m *Map) { // for element in map, read string/value // pair and assign d.p.printf("\nfor %s > 0 {\n%s--", sz, sz) - d.p.declare(m.Keyidx, "string") + d.p.declare(m.Keyidx, m.Key.TypeName()) d.p.declare(m.Validx, m.Value.TypeName()) - d.assignAndCheck(m.Keyidx, stringTyp) + next(d, m.Key) next(d, m.Value) d.p.mapAssign(m) d.p.closeblock() diff --git a/gen/elem.go b/gen/elem.go index 250c187e..f9727483 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -236,6 +236,7 @@ type Map struct { common Keyidx string // key variable name Validx string // value variable name + Key Elem // key element Value Elem // value element } @@ -250,6 +251,7 @@ ridx: goto ridx } + m.Key.SetVarname(m.Keyidx) m.Value.SetVarname(m.Validx) } @@ -257,7 +259,7 @@ func (m *Map) TypeName() string { if m.common.alias != "" { return m.common.alias } - m.common.Alias("map[string]" + m.Value.TypeName()) + m.common.Alias("map[" + m.Key.TypeName() + "]" + m.Value.TypeName()) return m.common.alias } @@ -396,9 +398,9 @@ func (s *Struct) Complexity() int { } type StructField struct { - FieldTag string // the string inside the `msg:""` tag - FieldName string // the name of the struct field - FieldElem Elem // the field type + FieldTag interface{} // the label of the struct field in msgpack + FieldName string // the name of the struct field + FieldElem Elem // the field type } // BaseElem is an element that diff --git a/gen/encode.go b/gen/encode.go index b3415286..90d9a657 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -2,8 +2,9 @@ package gen import ( "fmt" - "github.com/tinylib/msgp/msgp" "io" + + "github.com/tinylib/msgp/msgp" ) func encode(w io.Writer) *encodeGen { @@ -101,19 +102,7 @@ func (e *encodeGen) appendraw(bts []byte) { } func (e *encodeGen) structmap(s *Struct) { - nfields := len(s.Fields) - data := msgp.AppendMapHeader(nil, uint32(nfields)) - e.p.printf("\n// map header, size %d", nfields) - e.Fuse(data) - for i := range s.Fields { - if !e.p.ok() { - return - } - data = msgp.AppendString(nil, s.Fields[i].FieldTag) - e.p.printf("\n// write %q", s.Fields[i].FieldTag) - e.Fuse(data) - next(e, s.Fields[i].FieldElem) - } + genStructFieldsSerializer(e, e.p, s.Fields) } func (e *encodeGen) gMap(m *Map) { @@ -125,7 +114,7 @@ func (e *encodeGen) gMap(m *Map) { e.writeAndCheck(mapHeader, lenAsUint32, vname) e.p.printf("\nfor %s, %s := range %s {", m.Keyidx, m.Validx, vname) - e.writeAndCheck(stringTyp, literalFmt, m.Keyidx) + next(e, m.Key) next(e, m.Value) e.p.closeblock() } diff --git a/gen/marshal.go b/gen/marshal.go index 8e9a4765..522cacc9 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -2,8 +2,9 @@ package gen import ( "fmt" - "github.com/tinylib/msgp/msgp" "io" + + "github.com/tinylib/msgp/msgp" ) func marshal(w io.Writer) *marshalGen { @@ -96,21 +97,7 @@ func (m *marshalGen) tuple(s *Struct) { } func (m *marshalGen) mapstruct(s *Struct) { - data := make([]byte, 0, 64) - data = msgp.AppendMapHeader(data, uint32(len(s.Fields))) - m.p.printf("\n// map header, size %d", len(s.Fields)) - m.Fuse(data) - for i := range s.Fields { - if !m.p.ok() { - return - } - data = msgp.AppendString(nil, s.Fields[i].FieldTag) - - m.p.printf("\n// string %q", s.Fields[i].FieldTag) - m.Fuse(data) - - next(m, s.Fields[i].FieldElem) - } + genStructFieldsSerializer(m, m.p, s.Fields) } // append raw data @@ -130,7 +117,7 @@ func (m *marshalGen) gMap(s *Map) { vname := s.Varname() m.rawAppend(mapHeader, lenAsUint32, vname) m.p.printf("\nfor %s, %s := range %s {", s.Keyidx, s.Validx, vname) - m.rawAppend(stringTyp, literalFmt, s.Keyidx) + next(m, s.Key) next(m, s.Value) m.p.closeblock() } diff --git a/gen/size.go b/gen/size.go index 5c71ec72..09cf9ed0 100644 --- a/gen/size.go +++ b/gen/size.go @@ -2,9 +2,10 @@ package gen import ( "fmt" - "github.com/tinylib/msgp/msgp" "io" "strconv" + + "github.com/tinylib/msgp/msgp" ) type sizeState uint8 @@ -107,7 +108,7 @@ func (s *sizeGen) gStruct(st *Struct) { s.addConstant(strconv.Itoa(len(data))) for i := range st.Fields { data = data[:0] - data = msgp.AppendString(data, st.Fields[i].FieldTag) + data, s.p.err = msgp.AppendIntf(data, st.Fields[i].FieldTag) s.addConstant(strconv.Itoa(len(data))) next(s, st.Fields[i].FieldElem) } @@ -168,9 +169,10 @@ func (s *sizeGen) gMap(m *Map) { vn := m.Varname() s.p.printf("\nif %s != nil {", vn) s.p.printf("\nfor %s, %s := range %s {", m.Keyidx, m.Validx, vn) + s.p.printf("\n_ = %s", m.Keyidx) // we may not use the key s.p.printf("\n_ = %s", m.Validx) // we may not use the value - s.p.printf("\ns += msgp.StringPrefixSize + len(%s)", m.Keyidx) - s.state = expr + s.state = add + next(s, m.Key) next(s, m.Value) s.p.closeblock() s.p.closeblock() @@ -238,8 +240,12 @@ func fixedsizeExpr(e Elem) (string, bool) { mhdr := msgp.AppendMapHeader(nil, uint32(len(e.Fields))) hdrlen += len(mhdr) var strbody []byte + var err error for _, f := range e.Fields { - strbody = msgp.AppendString(strbody[:0], f.FieldTag) + strbody, err = msgp.AppendIntf(strbody[:0], f.FieldTag) + if err != nil { + return "", false + } hdrlen += len(strbody) } return fmt.Sprintf("%d + %s", hdrlen, str), true diff --git a/gen/spec.go b/gen/spec.go index 26d26df6..18069ccb 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -3,6 +3,10 @@ package gen import ( "fmt" "io" + "math" + + "github.com/tinylib/msgp/internal/log" + "github.com/tinylib/msgp/msgp" ) const ( @@ -13,7 +17,7 @@ const ( quotedFmt = `"%s"` mapHeader = "MapHeader" arrayHeader = "ArrayHeader" - mapKey = "MapKeyPtr" + mapKey = "MapKey" stringTyp = "String" u32 = "uint32" ) @@ -190,6 +194,28 @@ func (p *passes) applyall(e Elem) Elem { return e } +type fields struct { + cache map[string]bool +} + +func (f *fields) declareOnce(p printer, name, typ string) { + key := name + "." + typ + if f.cache == nil { + f.cache = make(map[string]bool) + } else if f.cache[key] { + return + } + + p.printf("\nvar %s %s;", name, typ) + f.cache[key] = true +} + +func (f *fields) drop() { + for k := range f.cache { + delete(f.cache, k) + } +} + type traversal interface { gMap(*Map) gSlice(*Slice) @@ -199,6 +225,181 @@ type traversal interface { gStruct(*Struct) } +type declarer interface { + declareOnce(p printer, name, typ string) +} + +type assigner interface { + assignAndCheck(name, base string) + nextTypeAndCheck(name string) + skipAndCheck() +} + +type fuser interface { + Fuse([]byte) +} + +type traversalAssigner interface { + traversal + assigner + declarer +} + +type traversalFuser interface { + traversal + fuser +} + +func genStructFieldsSerializer(t traversalFuser, p printer, fields []StructField) { + data := msgp.AppendMapHeader(nil, uint32(len(fields))) + p.printf("\n// map header, size %d", len(fields)) + t.Fuse(data) + for _, f := range fields { + if !p.ok() { + return + } + + data, p.err = msgp.AppendIntf(nil, f.FieldTag) + p.printf( + "\n// [field %q] write label `%v` as msgp.%s", + f.FieldName, f.FieldTag, msgp.NextType(data), + ) + t.Fuse(data) + + next(t, f.FieldElem) + } +} + +func genStructFieldsParser(t traversalAssigner, p printer, fields []StructField) { + const ( + fieldBytes = "fieldBytes" + fieldInt = "fieldInt" + fieldUint = "fieldUint" + typ = "typ" + ) + + groups := groupFieldsByType(fields) + hasUint := len(groups[msgp.UintType]) > 0 + hasInt := len(groups[msgp.IntType]) > 0 + hasStr := len(groups[msgp.StrType]) > 0 + singleType := len(groups) == 1 + + if hasStr { + t.declareOnce(p, fieldBytes, "[]byte") + } + if hasUint { + t.declareOnce(p, fieldUint, "uint64") + + // Append to int fields also uint fields that do not overflow int64. + // This is necessary because uint field with value <= (1<<7)-1 could be serialized as fixint. + // and become a msgp.IntType. This is done for best compatibility with other libraries + // (and with other languages). That is, some endpoint could serialize uint16 key with value <= (1<<7)-1 as + // real msgpack uint16, but also could serialize it like fixint. + for _, f := range groups[msgp.UintType] { + v := f.FieldTag.(uint64) + if v <= math.MaxInt64 { + groups[msgp.IntType] = append(groups[msgp.IntType], f) + hasInt = true + } + } + } + if hasInt { + t.declareOnce(p, fieldInt, "int64") + } + if !singleType || hasUint { + t.declareOnce(p, typ, "msgp.Type") + } + + sz := randIdent() + p.declare(sz, u32) + t.assignAndCheck(sz, mapHeader) + p.printf("\nfor %s > 0 {\n%s--", sz, sz) + switch { + case singleType && hasStr: + t.assignAndCheck(fieldBytes, mapKey) + switchFieldKeysStr(t, p, fields, fieldBytes) + + case singleType && !hasUint && hasInt: + t.assignAndCheck(fieldInt, "Int64") + switchFieldKeys(t, p, fields, fieldInt) + + default: + // switch on inferred type of next field + t.nextTypeAndCheck(typ) + p.printf("\nswitch %s {", typ) + if hasUint { + p.print("\ncase msgp.UintType:") + t.assignAndCheck(fieldUint, "Uint64") + switchFieldKeys(t, p, groups[msgp.UintType], fieldUint) + } + if hasInt { + p.print("\ncase msgp.IntType:") + t.assignAndCheck(fieldInt, "Int64") + switchFieldKeys(t, p, groups[msgp.IntType], fieldInt) + } + if hasStr { + // double case is done for backward compatibility with previous implementation + p.print("\ncase msgp.StrType, msgp.BinType:") + t.assignAndCheck(fieldBytes, mapKey) + switchFieldKeysStr(t, p, groups[msgp.StrType], fieldBytes) + } + p.print("\ndefault:") + t.skipAndCheck() + p.closeblock() // close switch + } + p.closeblock() // close loop +} + +func groupFieldsByType(fields []StructField) map[msgp.Type][]StructField { + groups := make(map[msgp.Type][]StructField, len(fields)) + for _, f := range fields { + var t msgp.Type + switch f.FieldTag.(type) { + case int, int8, int16, int32, int64: + t = msgp.IntType + case uint, uint8, uint16, uint32, uint64: + t = msgp.UintType + case string: + t = msgp.StrType + default: + log.Fatalf( + "could not generate code to work with field's %q label: has unknown type %T", + f.FieldName, f.FieldTag, + ) + } + groups[t] = append(groups[t], f) + } + return groups +} + +func switchFieldKeysStr(t traversalAssigner, p printer, fields []StructField, label string) { + p.printf("\nswitch msgp.UnsafeString(%s) {", label) + for _, f := range fields { + p.printf("\ncase \"%s\":", f.FieldTag) + next(t, f.FieldElem) + if !p.ok() { + return + } + } + p.print("\ndefault:") + t.skipAndCheck() + p.closeblock() // close switch +} + +func switchFieldKeys(t traversalAssigner, p printer, fields []StructField, label string) { + p.printf("\nswitch %s {", label) + for _, f := range fields { + p.printf("\ncase %v:", f.FieldTag) + next(t, f.FieldElem) + if !p.ok() { + return + } + } + p.print("\ndefault:") + t.skipAndCheck() + p.closeblock() // close switch +} + // type-switch dispatch to the correct // method given the type of 'e' func next(t traversal, e Elem) { diff --git a/gen/unmarshal.go b/gen/unmarshal.go index eadbd841..cbc68151 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -13,22 +13,14 @@ func unmarshal(w io.Writer) *unmarshalGen { type unmarshalGen struct { passes - p printer - hasfield bool + fields + p printer } func (u *unmarshalGen) Method() Method { return Unmarshal } -func (u *unmarshalGen) needsField() { - if u.hasfield { - return - } - u.p.print("\nvar field []byte; _ = field") - u.hasfield = true -} - func (u *unmarshalGen) Execute(p Elem) error { - u.hasfield = false + u.fields.drop() if !u.p.ok() { return u.p.err } @@ -51,7 +43,23 @@ func (u *unmarshalGen) assignAndCheck(name string, base string) { if !u.p.ok() { return } - u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", name, base) + if base == mapKey { + u.p.printf("\n%s, bts, err = msgp.ReadMapKeyZC(bts)", name) + } else { + u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", name, base) + } + u.p.print(errcheck) +} + +func (u *unmarshalGen) nextTypeAndCheck(name string) { + if !u.p.ok() { + return + } + u.p.printf("\n%s = msgp.NextType(bts)", name) +} + +func (u *unmarshalGen) skipAndCheck() { + u.p.print("\nbts, err = msgp.Skip(bts)") u.p.print(errcheck) } @@ -68,7 +76,6 @@ func (u *unmarshalGen) gStruct(s *Struct) { } func (u *unmarshalGen) tuple(s *Struct) { - // open block sz := randIdent() u.p.declare(sz, u32) @@ -83,25 +90,7 @@ func (u *unmarshalGen) tuple(s *Struct) { } func (u *unmarshalGen) mapstruct(s *Struct) { - u.needsField() - sz := randIdent() - u.p.declare(sz, u32) - u.assignAndCheck(sz, mapHeader) - - u.p.printf("\nfor %s > 0 {", sz) - u.p.printf("\n%s--; field, bts, err = msgp.ReadMapKeyZC(bts)", sz) - u.p.print(errcheck) - u.p.print("\nswitch msgp.UnsafeString(field) {") - for i := range s.Fields { - if !u.p.ok() { - return - } - u.p.printf("\ncase \"%s\":", s.Fields[i].FieldTag) - next(u, s.Fields[i].FieldElem) - } - u.p.print("\ndefault:\nbts, err = msgp.Skip(bts)") - u.p.print(errcheck) - u.p.print("\n}\n}") // close switch and for loop + genStructFieldsParser(u, u.p, s.Fields) } func (u *unmarshalGen) gBase(b *BaseElem) { @@ -180,8 +169,10 @@ func (u *unmarshalGen) gMap(m *Map) { // loop and get key,value u.p.printf("\nfor %s > 0 {", sz) - u.p.printf("\nvar %s string; var %s %s; %s--", m.Keyidx, m.Validx, m.Value.TypeName(), sz) - u.assignAndCheck(m.Keyidx, stringTyp) + u.p.printf("\n%s--", sz) + u.p.declare(m.Keyidx, m.Key.TypeName()) + u.p.declare(m.Validx, m.Value.TypeName()) + next(u, m.Key) next(u, m.Value) u.p.mapAssign(m) u.p.closeblock() diff --git a/internal/log/log.go b/internal/log/log.go new file mode 100644 index 00000000..26b8a6cf --- /dev/null +++ b/internal/log/log.go @@ -0,0 +1,56 @@ +package log + +import ( + "fmt" + "github.com/ttacon/chalk" + "os" + "strings" +) + +var logctx []string + +// push logging state +func PushState(s string) { + logctx = append(logctx, s) +} + +// pop logging state +func PopState() { + logctx = logctx[:len(logctx)-1] +} + +func Infof(s string, v ...interface{}) { + PushState(s) + fmt.Printf(chalk.Green.Color(strings.Join(logctx, ": ")), v...) + PopState() +} + +func Infoln(s string) { + PushState(s) + fmt.Println(chalk.Green.Color(strings.Join(logctx, ": "))) + PopState() +} + +func Warnf(s string, v ...interface{}) { + PushState(s) + fmt.Printf(chalk.Yellow.Color(strings.Join(logctx, ": ")), v...) + PopState() +} + +func Warnln(s string) { + PushState(s) + fmt.Println(chalk.Yellow.Color(strings.Join(logctx, ": "))) + PopState() +} + +func Fatal(s string) { + PushState(s) + fmt.Print(chalk.Red.Color(strings.Join(logctx, ": "))) + os.Exit(1) +} + +func Fatalf(s string, v ...interface{}) { + PushState(s) + fmt.Printf(chalk.Red.Color(strings.Join(logctx, ": ")), v...) + os.Exit(1) +} diff --git a/msgp/file_test.go b/msgp/file_test.go index 1cc01cec..9a3c3aed 100644 --- a/msgp/file_test.go +++ b/msgp/file_test.go @@ -5,10 +5,11 @@ package msgp_test import ( "bytes" "crypto/rand" - "github.com/tinylib/msgp/msgp" prand "math/rand" "os" "testing" + + "github.com/tinylib/msgp/msgp" ) type rawBytes []byte diff --git a/msgp/read.go b/msgp/read.go index a493f941..b3aaa350 100644 --- a/msgp/read.go +++ b/msgp/read.go @@ -146,6 +146,56 @@ func (m *Reader) Read(p []byte) (int, error) { return m.R.Read(p) } +// CopyNext reads the next object from m without decoding it and writes it to w. +// It avoids unnecessary copies internally. +func (m *Reader) CopyNext(w io.Writer) (int64, error) { + sz, o, err := getNextSize(m.R) + if err != nil { + return 0, err + } + + var n int64 + // Opportunistic optimization: if we can fit the whole thing in the m.R + // buffer, then just get a pointer to that, and pass it to w.Write, + // avoiding an allocation. + if int(sz) <= m.R.BufferSize() { + var nn int + var buf []byte + buf, err = m.R.Next(int(sz)) + if err != nil { + if err == io.ErrUnexpectedEOF { + err = ErrShortBytes + } + return 0, err + } + nn, err = w.Write(buf) + n += int64(nn) + } else { + // Fall back to io.CopyN. + // May avoid allocating if w is a ReaderFrom (e.g. bytes.Buffer) + n, err = io.CopyN(w, m.R, int64(sz)) + if err == io.ErrUnexpectedEOF { + err = ErrShortBytes + } + } + if err != nil { + return n, err + } else if n < int64(sz) { + return n, io.ErrShortWrite + } + + // for maps and slices, read elements + for x := uintptr(0); x < o; x++ { + var n2 int64 + n2, err = m.CopyNext(w) + if err != nil { + return n, err + } + n += n2 + } + return n, nil +} + // ReadFull implements `io.ReadFull` func (m *Reader) ReadFull(p []byte) (int, error) { return m.R.ReadFull(p) @@ -194,8 +244,10 @@ func (m *Reader) IsNil() bool { return err == nil && p[0] == mnil } +// getNextSize returns the size of the next object on the wire. // returns (obj size, obj elements, error) // only maps and arrays have non-zero obj elements +// for maps and arrays, obj size does not include elements // // use uintptr b/c it's guaranteed to be large enough // to hold whatever we can fit in memory. @@ -1087,6 +1139,33 @@ func (m *Reader) ReadMapStrIntf(mp map[string]interface{}) (err error) { return } +// ReadMapIntfIntf reads a MessagePack map into a map[interface{}]interface{}. +// (You must pass a non-nil map into the function.) +func (m *Reader) ReadMapIntfIntf(mp map[interface{}]interface{}) (err error) { + var sz uint32 + sz, err = m.ReadMapHeader() + if err != nil { + return + } + for key := range mp { + delete(mp, key) + } + for i := uint32(0); i < sz; i++ { + var key interface{} + var val interface{} + key, err = m.ReadIntf() + if err != nil { + return + } + val, err = m.ReadIntf() + if err != nil { + return + } + mp[key] = val + } + return +} + // ReadTime reads a time.Time object from the reader. // The returned time's location will be set to time.Local. func (m *Reader) ReadTime() (t time.Time, err error) { @@ -1111,7 +1190,7 @@ func (m *Reader) ReadTime() (t time.Time, err error) { // ReadIntf reads out the next object as a raw interface{}. // Arrays are decoded as []interface{}, and maps are decoded -// as map[string]interface{}. Integers are decoded as int64 +// as map[interface{}]interface{}. Integers are decoded as int64 // and unsigned integers are decoded as uint64. func (m *Reader) ReadIntf() (i interface{}, err error) { var t Type @@ -1172,8 +1251,8 @@ func (m *Reader) ReadIntf() (i interface{}, err error) { return case MapType: - mp := make(map[string]interface{}) - err = m.ReadMapStrIntf(mp) + mp := make(map[interface{}]interface{}) + err = m.ReadMapIntfIntf(mp) i = mp return diff --git a/msgp/read_test.go b/msgp/read_test.go index aa191439..fd6a9922 100644 --- a/msgp/read_test.go +++ b/msgp/read_test.go @@ -2,6 +2,7 @@ package msgp import ( "bytes" + "fmt" "io" "math" "math/rand" @@ -31,11 +32,15 @@ func TestReadIntf(t *testing.T) { time.Now(), "hello!", []byte("hello!"), - map[string]interface{}{ + map[interface{}]interface{}{ "thing-1": "thing-1-value", "thing-2": int64(800), "thing-3": []byte("some inner bytes..."), "thing-4": false, + int64(1): "thing-1-value", + int64(2): int64(800), + int64(3): []byte("some inner bytes..."), + int64(4): false, }, } @@ -60,12 +65,30 @@ func TestReadIntf(t *testing.T) { t.Errorf("Test case: %d: %s", i, err) } if !reflect.DeepEqual(v, ts) { - t.Errorf("%v in; %v out", ts, v) + // if v and ts are maps + if m, ok := v.(map[interface{}]interface{}); ok { + t.Errorf( + "\n%s\n%s\n", + dumpMap("v", m), + dumpMap("ts", ts.(map[interface{}]interface{})), + ) + } else { + t.Errorf("in: %#v; out: %#v", ts, v) + } } } } +func dumpMap(label string, m map[interface{}]interface{}) string { + buf := &bytes.Buffer{} + fmt.Fprintf(buf, "map %q contents:\n", label) + for k, v := range m { + fmt.Fprintf(buf, "%#v (%t) -> %#v (%t)\n", k, k, v, v) + } + return buf.String() +} + func TestReadMapHeader(t *testing.T) { tests := []struct { Sz uint32 @@ -722,3 +745,48 @@ func BenchmarkSkip(b *testing.B) { } } } + +func TestCopyNext(t *testing.T) { + var buf bytes.Buffer + en := NewWriter(&buf) + + en.WriteMapHeader(6) + + en.WriteString("thing_one") + en.WriteString("value_one") + + en.WriteString("thing_two") + en.WriteFloat64(3.14159) + + en.WriteString("some_bytes") + en.WriteBytes([]byte("nkl4321rqw908vxzpojnlk2314rqew098-s09123rdscasd")) + + en.WriteString("the_time") + en.WriteTime(time.Now()) + + en.WriteString("what?") + en.WriteBool(true) + + en.WriteString("ext") + en.WriteExtension(&RawExtension{Type: 55, Data: []byte("raw data!!!")}) + + en.Flush() + + // Read from a copy of the original buf. + de := NewReader(bytes.NewReader(buf.Bytes())) + + w := new(bytes.Buffer) + + n, err := de.CopyNext(w) + if err != nil { + t.Fatal(err) + } + if n != int64(buf.Len()) { + t.Fatalf("CopyNext returned the wrong value (%d != %d)", + n, buf.Len()) + } + + if !bytes.Equal(buf.Bytes(), w.Bytes()) { + t.Fatalf("not equal! %v, %v", buf.Bytes(), w.Bytes()) + } +} diff --git a/msgp/write.go b/msgp/write.go index 0245c1bd..2fecfe1a 100644 --- a/msgp/write.go +++ b/msgp/write.go @@ -558,7 +558,7 @@ func (mw *Writer) WriteMapStrStr(mp map[string]string) (err error) { return nil } -// WriteMapStrIntf writes a map[string]interface to the writer +// WriteMapStrIntf writes a map[string]interface{} to the writer. func (mw *Writer) WriteMapStrIntf(mp map[string]interface{}) (err error) { err = mw.WriteMapHeader(uint32(len(mp))) if err != nil { @@ -577,6 +577,25 @@ func (mw *Writer) WriteMapStrIntf(mp map[string]interface{}) (err error) { return } +// WriteMapIntfIntf writes a map[interface{}]interface{} to the writer. +func (mw *Writer) WriteMapIntfIntf(mp map[interface{}]interface{}) (err error) { + err = mw.WriteMapHeader(uint32(len(mp))) + if err != nil { + return + } + for key, val := range mp { + err = mw.WriteIntf(key) + if err != nil { + return + } + err = mw.WriteIntf(val) + if err != nil { + return + } + } + return +} + // WriteTime writes a time.Time object to the wire. // // Time is encoded as Unix time, which means that @@ -665,6 +684,8 @@ func (mw *Writer) WriteIntf(v interface{}) error { return mw.WriteMapStrStr(v) case map[string]interface{}: return mw.WriteMapStrIntf(v) + case map[interface{}]interface{}: + return mw.WriteMapIntfIntf(v) case time.Time: return mw.WriteTime(v) } diff --git a/parse/directives.go b/parse/directives.go index fb78974b..dcd4ecd6 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -2,9 +2,11 @@ package parse import ( "fmt" - "github.com/tinylib/msgp/gen" "go/ast" "strings" + + "github.com/tinylib/msgp/gen" + "github.com/tinylib/msgp/internal/log" ) const linePrefix = "//msgp:" @@ -30,12 +32,12 @@ var passDirectives = map[string]passDirective{ } func passignore(m gen.Method, text []string, p *gen.Printer) error { - pushstate(m.String()) + log.PushState(m.String()) for _, a := range text { p.ApplyDirective(m, gen.IgnoreTypename(a)) - infof("ignoring %s\n", a) + log.Infof("ignoring %s\n", a) } - popstate() + log.PopState() return nil } @@ -76,7 +78,7 @@ func applyShim(text []string, f *FileSet) error { be.ShimToBase = methods[0] be.ShimFromBase = methods[1] - infof("%s -> %s\n", name, be.Value.String()) + log.Infof("%s -> %s\n", name, be.Value.String()) f.findShim(name, be) return nil @@ -91,7 +93,7 @@ func ignore(text []string, f *FileSet) error { name := strings.TrimSpace(item) if _, ok := f.Identities[name]; ok { delete(f.Identities, name) - infof("ignoring %s\n", name) + log.Infof("ignoring %s\n", name) } } return nil @@ -107,9 +109,9 @@ func astuple(text []string, f *FileSet) error { if el, ok := f.Identities[name]; ok { if st, ok := el.(*gen.Struct); ok { st.AsTuple = true - infoln(name) + log.Infoln(name) } else { - warnf("%s: only structs can be tuples\n", name) + log.Warnf("%s: only structs can be tuples\n", name) } } } diff --git a/parse/getast.go b/parse/getast.go index 355ad772..dbd9639b 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -8,10 +8,11 @@ import ( "os" "reflect" "sort" + "strconv" "strings" "github.com/tinylib/msgp/gen" - "github.com/ttacon/chalk" + "github.com/tinylib/msgp/internal/log" ) // A FileSet is the in-memory representation of a @@ -31,8 +32,8 @@ type FileSet struct { // If unexport is false, only exported identifiers are included in the FileSet. // If the resulting FileSet would be empty, an error is returned. func File(name string, unexported bool) (*FileSet, error) { - pushstate(name) - defer popstate() + log.PushState(name) + defer log.PopState() fs := &FileSet{ Specs: make(map[string]ast.Expr), Identities: make(map[string]gen.Elem), @@ -58,13 +59,13 @@ func File(name string, unexported bool) (*FileSet, error) { } fs.Package = one.Name for _, fl := range one.Files { - pushstate(fl.Name.Name) + log.PushState(fl.Name.Name) fs.Directives = append(fs.Directives, yieldComments(fl.Comments)...) if !unexported { ast.FileExports(fl) } fs.getTypeSpecs(fl) - popstate() + log.PopState() } } else { f, err := parser.ParseFile(fset, name, nil, parser.ParseComments) @@ -99,12 +100,12 @@ func (f *FileSet) applyDirectives() { chunks := strings.Split(d, " ") if len(chunks) > 0 { if fn, ok := directives[chunks[0]]; ok { - pushstate(chunks[0]) + log.PushState(chunks[0]) err := fn(chunks, f) if err != nil { - warnln(err.Error()) + log.Warnln(err.Error()) } - popstate() + log.PopState() } else { newdirs = append(newdirs, d) } @@ -158,7 +159,7 @@ func (f *FileSet) resolve(ls linkset) { // what's left can't be resolved for name, elem := range ls { - warnf("couldn't resolve type %s (%s)\n", name, elem.TypeName()) + log.Warnf("couldn't resolve type %s (%s)\n", name, elem.TypeName()) } } @@ -169,11 +170,11 @@ func (f *FileSet) process() { deferred := make(linkset) parse: for name, def := range f.Specs { - pushstate(name) + log.PushState(name) el := f.parseExpr(def) if el == nil { - warnln("failed to parse") - popstate() + log.Warnln("failed to parse") + log.PopState() continue parse } // push unresolved identities into @@ -181,12 +182,12 @@ parse: // we've handled every possible named type. if be, ok := el.(*gen.BaseElem); ok && be.Value == gen.IDENT { deferred[name] = be - popstate() + log.PopState() continue parse } el.Alias(name) f.Identities[name] = el - popstate() + log.PopState() } if len(deferred) > 0 { @@ -227,21 +228,21 @@ loop: } m := strToMethod(chunks[0]) if m == 0 { - warnf("unknown pass name: %q\n", chunks[0]) + log.Warnf("unknown pass name: %q\n", chunks[0]) continue loop } if fn, ok := passDirectives[chunks[1]]; ok { - pushstate(chunks[1]) + log.PushState(chunks[1]) err := fn(m, chunks[2:], p) if err != nil { - warnf("error applying directive: %s\n", err) + log.Warnf("error applying directive: %s\n", err) } - popstate() + log.PopState() } else { - warnf("unrecognized directive %q\n", chunks[1]) + log.Warnf("unrecognized directive %q\n", chunks[1]) } } else { - warnf("empty directive: %q\n", d) + log.Warnf("empty directive: %q\n", d) } } } @@ -256,9 +257,9 @@ func (f *FileSet) PrintTo(p *gen.Printer) error { for _, name := range names { el := f.Identities[name] el.SetVarname("z") - pushstate(el.TypeName()) + log.PushState(el.TypeName()) err := p.Print(el) - popstate() + log.PopState() if err != nil { return err } @@ -319,14 +320,14 @@ func (fs *FileSet) parseFieldList(fl *ast.FieldList) []gen.StructField { } out := make([]gen.StructField, 0, fl.NumFields()) for _, field := range fl.List { - pushstate(fieldName(field)) + log.PushState(fieldName(field)) fds := fs.getField(field) if len(fds) > 0 { out = append(out, fds...) } else { - warnln("ignored.") + log.Warnln("ignored") } - popstate() + log.PopState() } return out } @@ -339,14 +340,32 @@ func (fs *FileSet) getField(f *ast.Field) []gen.StructField { if f.Tag != nil { body := reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get("msg") tags := strings.Split(body, ",") - if len(tags) == 2 && tags[1] == "extension" { - extension = true - } + // ignore "-" fields if tags[0] == "-" { return nil } - sf[0].FieldTag = tags[0] + + if tags[0] != "" { + sf[0].FieldTag = tags[0] + } + + if len(tags) > 1 { + last := len(tags) - 1 + extension = tags[last] == "extension" + + var err error + switch tags[1] { + case "uint": + sf[0].FieldTag, err = strconv.ParseUint(tags[0], 0, 64) + case "int": + sf[0].FieldTag, err = strconv.ParseInt(tags[0], 0, 64) + } + if err != nil { + log.Warnf("could not parse field label %q as msgp.%s: %s\n", tags[0], tags[1], err) + return nil + } + } } ex := fs.parseExpr(f.Type) @@ -374,7 +393,7 @@ func (fs *FileSet) getField(f *ast.Field) []gen.StructField { return sf } sf[0].FieldElem = ex - if sf[0].FieldTag == "" { + if sf[0].FieldTag == nil { sf[0].FieldTag = sf[0].FieldName } @@ -385,13 +404,13 @@ func (fs *FileSet) getField(f *ast.Field) []gen.StructField { if b, ok := ex.Value.(*gen.BaseElem); ok { b.Value = gen.Ext } else { - warnln("couldn't cast to extension.") + log.Warnln("couldn't cast to extension.") return nil } case *gen.BaseElem: ex.Value = gen.Ext default: - warnln("couldn't cast to extension.") + log.Warnln("couldn't cast to extension.") return nil } } @@ -456,9 +475,12 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { switch e := e.(type) { case *ast.MapType: - if k, ok := e.Key.(*ast.Ident); ok && k.Name == "string" { - if in := fs.parseExpr(e.Value); in != nil { - return &gen.Map{Value: in} + key := fs.parseExpr(e.Key) + val := fs.parseExpr(e.Value) + if key != nil && val != nil { + return &gen.Map{ + Key: key, + Value: val, } } return nil @@ -466,12 +488,12 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { case *ast.Ident: b := gen.Ident(e.Name) - // work to resove this expression + // work to resolve this expression // can be done later, once we've resolved // everything else. if b.Value == gen.IDENT { if _, ok := fs.Specs[e.Name]; !ok { - warnf("non-local identifier: %s\n", e.Name) + log.Warnf("non-local identifier: %s\n", e.Name) } } return b @@ -545,45 +567,3 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { return nil } } - -func infof(s string, v ...interface{}) { - pushstate(s) - fmt.Printf(chalk.Green.Color(strings.Join(logctx, ": ")), v...) - popstate() -} - -func infoln(s string) { - pushstate(s) - fmt.Println(chalk.Green.Color(strings.Join(logctx, ": "))) - popstate() -} - -func warnf(s string, v ...interface{}) { - pushstate(s) - fmt.Printf(chalk.Yellow.Color(strings.Join(logctx, ": ")), v...) - popstate() -} - -func warnln(s string) { - pushstate(s) - fmt.Println(chalk.Yellow.Color(strings.Join(logctx, ": "))) - popstate() -} - -func fatalf(s string, v ...interface{}) { - pushstate(s) - fmt.Printf(chalk.Red.Color(strings.Join(logctx, ": ")), v...) - popstate() -} - -var logctx []string - -// push logging state -func pushstate(s string) { - logctx = append(logctx, s) -} - -// pop logging state -func popstate() { - logctx = logctx[:len(logctx)-1] -} diff --git a/parse/inline.go b/parse/inline.go index 85d60c92..3a03f274 100644 --- a/parse/inline.go +++ b/parse/inline.go @@ -2,6 +2,7 @@ package parse import ( "github.com/tinylib/msgp/gen" + "github.com/tinylib/msgp/internal/log" ) // This file defines when and how we @@ -32,7 +33,7 @@ const maxComplex = 5 // given name and replace them with be func (f *FileSet) findShim(id string, be *gen.BaseElem) { for name, el := range f.Identities { - pushstate(name) + log.PushState(name) switch el := el.(type) { case *gen.Struct: for i := range el.Fields { @@ -47,7 +48,7 @@ func (f *FileSet) findShim(id string, be *gen.BaseElem) { case *gen.Ptr: f.nextShim(&el.Value, id, be) } - popstate() + log.PopState() } // we'll need this at the top level as well f.Identities[id] = be @@ -77,7 +78,7 @@ func (f *FileSet) nextShim(ref *gen.Elem, id string, be *gen.BaseElem) { // propInline identifies and inlines candidates func (f *FileSet) propInline() { for name, el := range f.Identities { - pushstate(name) + log.PushState(name) switch el := el.(type) { case *gen.Struct: for i := range el.Fields { @@ -92,7 +93,7 @@ func (f *FileSet) propInline() { case *gen.Ptr: f.nextInline(&el.Value, name) } - popstate() + log.PopState() } } @@ -109,7 +110,7 @@ func (f *FileSet) nextInline(ref *gen.Elem, root string) { typ := el.TypeName() if el.Value == gen.IDENT && typ != root { if node, ok := f.Identities[typ]; ok && node.Complexity() < maxComplex { - infof("inlining %s\n", typ) + log.Infof("inlining %s\n", typ) // This should never happen; it will cause // infinite recursion. @@ -125,7 +126,7 @@ func (f *FileSet) nextInline(ref *gen.Elem, root string) { // this is the point at which we're sure that // we've got a type that isn't a primitive, // a library builtin, or a processed type - warnf("unresolved identifier: %s\n", typ) + log.Warnf("unresolved identifier: %s\n", typ) } } case *gen.Struct: diff --git a/printer/print.go b/printer/print.go index 4766871f..501c6f6b 100644 --- a/printer/print.go +++ b/printer/print.go @@ -3,18 +3,15 @@ package printer import ( "bytes" "fmt" - "github.com/tinylib/msgp/gen" - "github.com/tinylib/msgp/parse" - "github.com/ttacon/chalk" - "golang.org/x/tools/imports" "io" "io/ioutil" "strings" -) -func infof(s string, v ...interface{}) { - fmt.Printf(chalk.Magenta.Color(s), v...) -} + "github.com/tinylib/msgp/gen" + "github.com/tinylib/msgp/internal/log" + "github.com/tinylib/msgp/parse" + "golang.org/x/tools/imports" +) // PrintFile prints the methods for the provided list // of elements to the given file name and canonical @@ -38,7 +35,7 @@ func PrintFile(file string, f *parse.FileSet, mode gen.Method) error { if err != nil { return err } - infof(">>> Wrote and formatted \"%s\"\n", testfile) + log.Infof(">>> Wrote and formatted \"%s\"\n", testfile) } err = <-res if err != nil { @@ -59,7 +56,7 @@ func goformat(file string, data []byte) <-chan error { out := make(chan error, 1) go func(file string, data []byte, end chan error) { end <- format(file, data) - infof(">>> Wrote and formatted \"%s\"\n", file) + log.Infof(">>> Wrote and formatted \"%s\"\n", file) }(file, data, out) return out }