-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
068b534
commit 7ca1d91
Showing
4 changed files
with
157 additions
and
138 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,163 +1,211 @@ | ||
package bson | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"io" | ||
"math" | ||
"reflect" | ||
"sort" | ||
"strconv" | ||
) | ||
|
||
// Encoder writes BSON values to an output stream. | ||
type Encoder struct { | ||
w io.Writer | ||
buf *bytes.Buffer | ||
buf []byte | ||
} | ||
|
||
// NewEncoder returns a new encoder that writes to w. | ||
func NewEncoder(w io.Writer) *Encoder { | ||
return &Encoder{ | ||
w: w, | ||
buf: bytes.NewBuffer(make([]byte, 0, 512)), | ||
buf: make([]byte, 0, 512), | ||
} | ||
} | ||
|
||
// Encode writes the BSON encoding of v to the stream. | ||
func (enc *Encoder) Encode(v any) error { | ||
enc.buf.Reset() | ||
enc.buf = enc.buf[:0] | ||
if err := enc.marshal(v); err != nil { | ||
return fmt.Errorf("encode failed: %w", err) | ||
} | ||
_, err := enc.w.Write(enc.buf.Bytes()) | ||
_, err := enc.w.Write(enc.buf) | ||
return err | ||
} | ||
|
||
func (enc *Encoder) marshal(v any) error { | ||
switch v := v.(type) { | ||
case Marshaler: | ||
if v, ok := v.(Marshaler); ok { | ||
raw, err := v.MarshalBSON() | ||
if err != nil { | ||
return err | ||
} | ||
enc.buf.Write(raw) | ||
case A: | ||
enc.marshalArray(enc.buf, v) | ||
case D: | ||
enc.marshalDoc(enc.buf, v) | ||
case M: | ||
enc.marshalDoc(enc.buf, v.AsD()) | ||
enc.buf = append(enc.buf, raw...) | ||
return nil | ||
} | ||
|
||
var err error | ||
switch rv := reflect.ValueOf(v); rv.Kind() { | ||
// TODO(cristaloleg): add reflect.Struct | ||
case reflect.Map: | ||
_, err = enc.writeMap(rv) | ||
case reflect.Array, reflect.Slice: | ||
_, err = enc.writeSlice(rv) | ||
default: | ||
return enc.marshalReflect(enc.buf, v) | ||
return fmt.Errorf("type %T is not supported yet", v) | ||
} | ||
return nil | ||
return err | ||
} | ||
|
||
func (enc *Encoder) marshalArray(w io.Writer, arr A) error { | ||
doc := make(D, len(arr)) | ||
for i := range arr { | ||
// TODO(cristaloleg): doc[i] value box-unbox can be omitted. | ||
func (enc *Encoder) writeMap(v reflect.Value) (int, error) { | ||
start := len(enc.buf) | ||
enc.buf = append(enc.buf, 0, 0, 0, 0) | ||
count := 4 + 1 // sizeof(int) + sizeof(\0) | ||
|
||
doc := make(D, v.Len()) | ||
for i, iter := 0, v.MapRange(); iter.Next(); i++ { | ||
doc[i] = e{ | ||
K: strconv.Itoa(i), | ||
V: arr[i], | ||
K: iter.Key().String(), | ||
V: iter.Value().Interface(), | ||
} | ||
} | ||
return enc.marshalDoc(w, doc) | ||
} | ||
|
||
func (enc *Encoder) marshalDoc(w io.Writer, doc D) error { | ||
// TODO(cristaloleg): prealloc or smarter way. | ||
elist := bytes.NewBuffer(make([]byte, 0, 128)) | ||
|
||
for i := range doc { | ||
key := doc[i].K | ||
val := doc[i].V | ||
|
||
switch v := val.(type) { | ||
case string: | ||
enc.writeKey(elist, TypeString, key) | ||
b := putUint32(uint32(len(v) + 1)) | ||
elist.Write(b[:]) | ||
elist.WriteString(v) | ||
elist.WriteByte(0) | ||
|
||
case int32: | ||
enc.writeKey(elist, TypeInt32, key) | ||
b := putUint32(uint32(v)) | ||
elist.Write(b[:]) | ||
|
||
case int64: | ||
enc.writeKey(elist, TypeInt64, key) | ||
b := putUint64(uint64(v)) | ||
elist.Write(b[:]) | ||
|
||
case bool: | ||
enc.writeKey(elist, TypeBool, key) | ||
elist.WriteByte(putBool(v)) | ||
|
||
default: | ||
var err error | ||
switch rv := reflect.ValueOf(val); rv.Kind() { | ||
case reflect.Map: | ||
enc.writeKey(elist, TypeDocument, key) | ||
err = enc.marshalMap(elist, rv) | ||
|
||
case reflect.Array, reflect.Slice: | ||
enc.writeKey(elist, TypeArray, key) | ||
err = enc.marshalSlice(elist, rv) | ||
|
||
default: | ||
return fmt.Errorf("type %T is not supported yet", v) | ||
} | ||
if err != nil { | ||
return err | ||
} | ||
// TODO(cristaloleg): use generic sort. | ||
sort.Slice(doc, func(i, j int) bool { | ||
return doc[i].K < doc[j].K | ||
}) | ||
|
||
for i := 0; i < len(doc); i++ { | ||
n, err := enc.writeValue(doc[i].K, reflect.ValueOf(doc[i].V)) | ||
if err != nil { | ||
return 0, err | ||
} | ||
count += n | ||
} | ||
|
||
size := 4 + elist.Len() + 1 // header + len + null. | ||
b := putUint32(uint32(size)) | ||
w.Write(b[:]) | ||
|
||
io.Copy(w, elist) | ||
w.Write([]byte{0}) | ||
return nil | ||
enc.buf = append(enc.buf, 0) | ||
enc.buf[start] = byte(count) | ||
enc.buf[start+1] = byte(count >> 8) | ||
enc.buf[start+2] = byte(count >> 16) | ||
enc.buf[start+3] = byte(count >> 24) | ||
return count, nil | ||
} | ||
|
||
func (enc *Encoder) writeKey(buf *bytes.Buffer, t Type, s string) { | ||
buf.WriteByte(byte(t)) | ||
buf.WriteString(s) | ||
buf.WriteByte(0) | ||
func (enc *Encoder) writeSlice(v reflect.Value) (int, error) { | ||
start := len(enc.buf) | ||
enc.buf = append(enc.buf, 0, 0, 0, 0) | ||
count := 4 + 1 // sizeof(int) + sizeof(\0) | ||
|
||
n := v.Len() | ||
for i := 0; i < n; i++ { | ||
val := v.Index(i) | ||
|
||
n, err := enc.writeValue(strconv.Itoa(i), val) | ||
if err != nil { | ||
return 0, err | ||
} | ||
count += n | ||
} | ||
|
||
enc.buf = append(enc.buf, 0) | ||
enc.buf[start] = byte(count) | ||
enc.buf[start+1] = byte(count >> 8) | ||
enc.buf[start+2] = byte(count >> 16) | ||
enc.buf[start+3] = byte(count >> 24) | ||
return count, nil | ||
} | ||
|
||
func (enc *Encoder) marshalReflect(w io.Writer, v any) error { | ||
switch rv := reflect.ValueOf(v); rv.Kind() { | ||
// TODO(cristaloleg): probably split into simple & compound types. | ||
func (enc *Encoder) writeValue(ename string, v reflect.Value) (int, error) { | ||
if v.Kind() == reflect.Interface { | ||
return enc.writeValue(ename, v.Elem()) | ||
} | ||
|
||
var count int | ||
switch v.Kind() { | ||
// TODO(cristaloleg): add reflect.Struct | ||
case reflect.Map: | ||
return enc.marshalMap(w, rv) | ||
case reflect.String: | ||
count += enc.writeElem(TypeString, ename) | ||
count += enc.writeString(v.String()) | ||
case reflect.Int32: | ||
count += enc.writeElem(TypeInt32, ename) | ||
count += enc.writeInt32(int32(v.Int())) | ||
case reflect.Int64: | ||
count += enc.writeElem(TypeInt64, ename) | ||
count += enc.writeInt64(int64(v.Int())) | ||
case reflect.Bool: | ||
count += enc.writeElem(TypeBool, ename) | ||
count += enc.writeBool(v.Bool()) | ||
case reflect.Float64: | ||
count += enc.writeElem(TypeDouble, ename) | ||
count += enc.writeInt64(int64(math.Float64bits(v.Float()))) | ||
|
||
case reflect.Array, reflect.Slice: | ||
return enc.marshalSlice(w, rv) | ||
count += enc.writeElem(TypeArray, ename) | ||
n, err := enc.writeSlice(v) | ||
if err != nil { | ||
return 0, err | ||
} | ||
count += n | ||
|
||
case reflect.Map: | ||
count += enc.writeElem(TypeDocument, ename) | ||
n, err := enc.writeMap(v) | ||
if err != nil { | ||
return 0, err | ||
} | ||
count += n | ||
|
||
default: | ||
return fmt.Errorf("type %T is not supported yet", v) | ||
return 0, fmt.Errorf("type %T is not supported", v) | ||
} | ||
return count, nil | ||
} | ||
|
||
func (enc *Encoder) marshalMap(w io.Writer, v reflect.Value) error { | ||
doc := make(D, v.Len()) | ||
for i, iter := 0, v.MapRange(); iter.Next(); i++ { | ||
doc[i] = e{ | ||
K: iter.Key().String(), | ||
V: iter.Value().Interface(), | ||
} | ||
} | ||
return enc.marshalDoc(w, doc) | ||
func (enc *Encoder) writeElem(typ Type, key string) int { | ||
enc.buf = append(enc.buf, byte(typ)) | ||
enc.buf = append(enc.buf, key...) | ||
enc.buf = append(enc.buf, 0) | ||
return 1 + len(key) + 1 | ||
} | ||
|
||
func (enc *Encoder) marshalSlice(w io.Writer, v reflect.Value) error { | ||
doc := make(D, v.Len()) | ||
for i := 0; i < v.Len(); i++ { | ||
doc[i] = e{ | ||
K: strconv.Itoa(i), | ||
V: v.Index(i).Interface(), | ||
} | ||
func (enc *Encoder) writeString(s string) int { | ||
size := len(s) + 1 | ||
enc.writeInt32(int32(size)) | ||
enc.buf = append(enc.buf, s...) | ||
enc.buf = append(enc.buf, 0) | ||
return 4 + size | ||
} | ||
|
||
func (enc *Encoder) writeInt32(v int32) int { | ||
enc.buf = append(enc.buf, | ||
byte(v), | ||
byte(v>>8), | ||
byte(v>>16), | ||
byte(v>>24), | ||
) | ||
return 4 | ||
} | ||
|
||
func (enc *Encoder) writeInt64(v int64) int { | ||
enc.buf = append(enc.buf, | ||
byte(v), | ||
byte(v>>8), | ||
byte(v>>16), | ||
byte(v>>24), | ||
byte(v>>32), | ||
byte(v>>40), | ||
byte(v>>48), | ||
byte(v>>56), | ||
) | ||
return 8 | ||
} | ||
|
||
func (enc *Encoder) writeBool(b bool) int { | ||
var v byte | ||
if b { | ||
v = 1 | ||
} | ||
return enc.marshalDoc(w, doc) | ||
enc.buf = append(enc.buf, v) | ||
return 1 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters