Skip to content

Commit

Permalink
Better encoder (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg authored Aug 16, 2023
1 parent 74f64ea commit d2eda37
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 111 deletions.
19 changes: 18 additions & 1 deletion bson.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package bson

import "bytes"
import (
"bytes"
"sort"
)

// Marshaler is the interface implemented by types that
// can marshal themselves into valid BSON.
Expand Down Expand Up @@ -52,3 +55,17 @@ type e struct {
//
// bson.M{"hello": "world", "foo": "bar", "pi": 3.14159}
type M map[string]any

func (m M) AsD() D {
d := make(D, len(m))
i := 0
for k, v := range m {
d[i] = e{K: k, V: v}
i++
}

sort.Slice(d, func(i, j int) bool {
return d[i].K < d[j].K
})
return d
}
110 changes: 65 additions & 45 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package bson

import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math"
"time"
"strconv"
)

// Encoder writes BSON values to an output stream.
Expand All @@ -24,67 +22,89 @@ func NewEncoder(w io.Writer) *Encoder {
}

// Encode writes the BSON encoding of v to the stream.
func (e *Encoder) Encode(v any) error {
e.buf.Reset()
if err := e.marshal(v); err != nil {
func (enc *Encoder) Encode(v any) error {
enc.buf.Reset()
if err := enc.marshal(v); err != nil {
return fmt.Errorf("encode failed: %w", err)
}
_, err := e.w.Write(e.buf.Bytes())
_, err := enc.w.Write(enc.buf.Bytes())
return err
}

func (e *Encoder) marshal(v any) error {
func (enc *Encoder) marshal(v any) error {
switch v := v.(type) {
case Marshaler:
raw, err := v.MarshalBSON()
if err != nil {
return err
}
e.buf.Write(raw)
enc.buf.Write(raw)
case A:
enc.marshalArray(v)
case D:
enc.marshalDoc(v)
case M:
enc.marshalDoc(v.AsD())
default:
return fmt.Errorf("type %T is not supported yet", v)
}
return nil
}

case []byte:
var b [4]byte
binary.LittleEndian.PutUint32(b[:], uint32(len(v)+1))
e.buf.Write(b[:])
e.buf.WriteByte(0x80) // TODO(cristaloleg): better binary type?
e.buf.Write(v)
func (enc *Encoder) marshalArray(arr A) error {
doc := make(D, len(arr))
for i := range arr {
doc[i] = e{
K: strconv.Itoa(i),
V: arr[i],
}
}
return enc.marshalDoc(doc)
}

case string:
var b [4]byte
binary.LittleEndian.PutUint32(b[:], uint32(len(v)+1))
e.buf.Write(b[:])
e.buf.Write([]byte(v))
e.buf.WriteByte(0)
func (enc *Encoder) marshalDoc(doc D) error {
// TODO(cristaloleg): prealloc or smarter way.
var elist bytes.Buffer

case time.Time:
var b [8]byte
binary.LittleEndian.PutUint64(b[:], uint64(v.UnixMilli()))
e.buf.Write(b[:])
for i := range doc {
pair := doc[i]
key := doc[i].K

case int32:
var b [4]byte
binary.LittleEndian.PutUint32(b[:], uint32(v))
e.buf.Write(b[:])
switch v := pair.V.(type) {
case string:
enc.writeKey(&elist, TypeString, key)
b := putUint32(uint32(len(v) + 1))
elist.Write(b[:])
elist.WriteString(v)
elist.WriteByte(0)

case int64:
var b [8]byte
binary.LittleEndian.PutUint64(b[:], uint64(v))
e.buf.Write(b[:])
case int32:
enc.writeKey(&elist, TypeInt32, key)
b := putUint32(uint32(v))
elist.Write(b[:])

case float64:
var b [8]byte
binary.LittleEndian.PutUint64(b[:], math.Float64bits(float64(v)))
e.buf.Write(b[:])
case int64:
enc.writeKey(&elist, TypeInt64, key)
b := putUint64(uint64(v))
elist.Write(b[:])

case bool:
var b [1]byte
if v {
b[0] = 1
case bool:
enc.writeKey(&elist, TypeBool, key)
elist.WriteByte(putBool(v))
}
e.buf.Write(b[:])

default:
return fmt.Errorf("type %T is not supported yet", v)
}

size := 4 + elist.Len() + 1 // header + len + null.
b := putUint32(uint32(size))
enc.buf.Write(b[:])

io.Copy(enc.buf, &elist)
enc.buf.WriteByte(0)
return nil
}

func (enc *Encoder) writeKey(buf *bytes.Buffer, t Type, s string) {
buf.WriteByte(byte(t))
buf.WriteString(s)
buf.WriteByte(0)
}
107 changes: 42 additions & 65 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,118 +2,95 @@ package bson

import (
"bytes"
"math"
"testing"
"time"
)

func TestEncode(t *testing.T) {
func TestEncodeA(t *testing.T) {
var buf bytes.Buffer
enc := NewEncoder(&buf)

err := enc.Encode(int32(123456789))
mustOk(t, err)
wantBytes(t, buf.Bytes(), "15cd5b07")
buf.Reset()

err = enc.Encode(int64(123456789123456789))
mustOk(t, err)
wantBytes(t, buf.Bytes(), "155fd0ac4b9bb601")
buf.Reset()

err = enc.Encode(true)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "01")
buf.Reset()
}

func TestEncodeBytes(t *testing.T) {
var err error
var buf bytes.Buffer
enc := NewEncoder(&buf)
var arr A

err = enc.Encode([]byte("foo"))
arr = A{}
err = enc.Encode(arr)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "0400000080666f6f")
wantBytes(t, buf.Bytes(), "0500000000")
buf.Reset()

err = enc.Encode([]byte{0x00})
arr = A{int32(30)}
err = enc.Encode(arr)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "020000008000")
wantBytes(t, buf.Bytes(), "0c0000001030001e00000000")
buf.Reset()
}

func TestEncodeString(t *testing.T) {
var err error
var buf bytes.Buffer
enc := NewEncoder(&buf)

err = enc.Encode("foo")
arr = A{"a", int32(10), "c"}
err = enc.Encode(arr)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "04000000666f6f00")
wantBytes(t, buf.Bytes(), "1e0000000230000200000061001031000a00000002320002000000630000")
buf.Reset()

err = enc.Encode("")
arr = A{"a", int32(10), "c", true, "b", int64(10203040)}
err = enc.Encode(arr)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "0100000000")
wantBytes(t, buf.Bytes(), "360000000230000200000061001031000a00000002320002000000630008330001023400020000006200123500a0af9b000000000000")
buf.Reset()
}

func TestEncodeNumbers(t *testing.T) {
func TestEncodeD(t *testing.T) {
var buf bytes.Buffer
enc := NewEncoder(&buf)

err := enc.Encode(float64(3.14159))
mustOk(t, err)
wantBytes(t, buf.Bytes(), "6e861bf0f9210940")
buf.Reset()

err = enc.Encode(float64(0))
mustOk(t, err)
wantBytes(t, buf.Bytes(), "0000000000000000")
buf.Reset()
var err error
var doc D

err = enc.Encode(math.NaN())
doc = D{}
err = enc.Encode(doc)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "010000000000f87f")
wantBytes(t, buf.Bytes(), "0500000000")
buf.Reset()

err = enc.Encode(math.Inf(+1))
doc = D{{"a", int32(10)}}
err = enc.Encode(doc)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "000000000000f07f")
wantBytes(t, buf.Bytes(), "0c0000001061000a00000000")
buf.Reset()

err = enc.Encode(math.Inf(-1))
doc = D{{"a", int32(10)}, {"c", true}}
err = enc.Encode(doc)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "000000000000f0ff")
wantBytes(t, buf.Bytes(), "100000001061000a0000000863000100")
buf.Reset()

err = enc.Encode(42.13)
doc = D{{"a", int32(10)}, {"c", true}, {"b", int64(10203040)}}
err = enc.Encode(doc)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "713d0ad7a3104540")
wantBytes(t, buf.Bytes(), "1b0000001061000a00000008630001126200a0af9b000000000000")
buf.Reset()
}

func TestEncodeTime(t *testing.T) {
var err error
func TestEncodeM(t *testing.T) {
var buf bytes.Buffer
enc := NewEncoder(&buf)

var ts time.Time
err = enc.Encode(ts)
var err error
var maa M

maa = M{"a": int32(10)}
err = enc.Encode(maa)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "0028d3ed7cc7ffff")
wantBytes(t, buf.Bytes(), "0c0000001061000a00000000")
buf.Reset()

ts = time.Unix(0, 0)
err = enc.Encode(ts)
maa = M{"a": int32(10), "c": true}
err = enc.Encode(maa)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "0000000000000000")
wantBytes(t, buf.Bytes(), "100000001061000a0000000863000100")
buf.Reset()

ts = time.Date(2023, 8, 17, 10, 20, 30, 0, time.UTC)
err = enc.Encode(ts)
maa = M{"a": int32(10), "c": true, "b": int64(10203040)}
err = enc.Encode(maa)
mustOk(t, err)
wantBytes(t, buf.Bytes(), "b0cd02038a010000")
wantBytes(t, buf.Bytes(), "1b0000001061000a000000126200a0af9b00000000000863000100")
buf.Reset()
}
29 changes: 29 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,32 @@ func must[T any](v T, err error) T {
}
return v
}

func putBool(b bool) byte {
if b {
return 1
}
return 0
}

func putUint32(v uint32) [4]byte {
var b [4]byte
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
return b
}

func putUint64(v uint64) [8]byte {
var b [8]byte
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
b[4] = byte(v >> 32)
b[5] = byte(v >> 40)
b[6] = byte(v >> 48)
b[7] = byte(v >> 56)
return b
}

0 comments on commit d2eda37

Please sign in to comment.