diff --git a/chotki.go b/chotki.go index ff3269c..0baed9d 100644 --- a/chotki.go +++ b/chotki.go @@ -11,8 +11,8 @@ import ( "github.com/cockroachdb/pebble" "github.com/cockroachdb/pebble/vfs" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" ) type Packet []byte diff --git a/chotki_test.go b/chotki_test.go index 6d83eff..7d7e23b 100644 --- a/chotki_test.go +++ b/chotki_test.go @@ -7,7 +7,7 @@ import ( "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" "github.com/stretchr/testify/assert" ) diff --git a/counter.go b/counter.go index 1ff6326..dd48629 100644 --- a/counter.go +++ b/counter.go @@ -2,7 +2,7 @@ package chotki import ( "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) type Counter64 int64 diff --git a/counter_test.go b/counter_test.go index 3933578..fc7460a 100644 --- a/counter_test.go +++ b/counter_test.go @@ -2,7 +2,7 @@ package chotki import ( "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" "github.com/stretchr/testify/assert" "testing" ) diff --git a/examples/object_example.go b/examples/object_example.go index 95890ed..28268b4 100644 --- a/examples/object_example.go +++ b/examples/object_example.go @@ -4,7 +4,7 @@ import ( "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) const ExampleName = 1 diff --git a/examples/object_example_test.go b/examples/object_example_test.go index f031923..9bdfb16 100644 --- a/examples/object_example_test.go +++ b/examples/object_example_test.go @@ -7,7 +7,7 @@ import ( "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" "github.com/stretchr/testify/assert" ) diff --git a/examples/objects_test.go b/examples/objects_test.go index c37dd63..7766d30 100644 --- a/examples/objects_test.go +++ b/examples/objects_test.go @@ -6,7 +6,7 @@ import ( "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" "github.com/stretchr/testify/assert" ) diff --git a/go.mod b/go.mod index efc1913..ef675a9 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,6 @@ go 1.21.4 require ( github.com/cockroachdb/pebble v1.1.0 github.com/ergochat/readline v0.1.0 - github.com/learn-decentralized-systems/toyqueue v0.1.5 - github.com/learn-decentralized-systems/toytlv v0.2.1 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.4 ) diff --git a/go.sum b/go.sum index 2149c7b..b0eda58 100644 --- a/go.sum +++ b/go.sum @@ -177,10 +177,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/learn-decentralized-systems/toyqueue v0.1.5 h1:X2EQEWj2dyaE5BUkE58aXsMG7mq8Uv6CpiFCErAnCMQ= -github.com/learn-decentralized-systems/toyqueue v0.1.5/go.mod h1:T5PrFDCcxA1O7hb2MAlHYYFA89ry8hvXUuwg+drS1UQ= -github.com/learn-decentralized-systems/toytlv v0.2.1 h1:nk+gjjE9JZ659kkbxIlv/H/gF5Wtst1Dbn7KckqdFOQ= -github.com/learn-decentralized-systems/toytlv v0.2.1/go.mod h1:+xzKS/La5vCkdyIdOFDb2NVPGF808tG5n5b3Ufxkorg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= diff --git a/log0.go b/log0.go index 39e5a4b..fdcb799 100644 --- a/log0.go +++ b/log0.go @@ -2,8 +2,8 @@ package chotki import ( "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" ) const id1 = rdx.ID0 + rdx.ProInc diff --git a/objects.go b/objects.go index 3043b2a..19b62e9 100644 --- a/objects.go +++ b/objects.go @@ -4,8 +4,8 @@ import ( "fmt" "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" "github.com/pkg/errors" "unicode/utf8" ) diff --git a/op.go b/op.go index 6b2d0d1..d9286c0 100644 --- a/op.go +++ b/op.go @@ -2,7 +2,7 @@ package chotki import ( "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) func ParsePacket(pack []byte) (lit byte, id, ref rdx.ID, body []byte, err error) { diff --git a/packets.go b/packets.go index 73fb6b8..d95c83c 100644 --- a/packets.go +++ b/packets.go @@ -5,7 +5,7 @@ import ( "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) func (cho *Chotki) UpdateVTree(id, ref rdx.ID, pb *pebble.Batch) (err error) { diff --git a/rdx/ELM.go b/rdx/ELM.go index b7290be..d52ec17 100644 --- a/rdx/ELM.go +++ b/rdx/ELM.go @@ -3,7 +3,7 @@ package rdx import ( "bytes" "errors" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" "slices" "sort" ) diff --git a/rdx/ELM_test.go b/rdx/ELM_test.go index c1e74ad..080eef5 100644 --- a/rdx/ELM_test.go +++ b/rdx/ELM_test.go @@ -1,8 +1,8 @@ package rdx import ( - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" "github.com/stretchr/testify/assert" "testing" ) diff --git a/rdx/FIRST.go b/rdx/FIRST.go index 564b8aa..5ed972a 100644 --- a/rdx/FIRST.go +++ b/rdx/FIRST.go @@ -4,9 +4,9 @@ import ( "bytes" "errors" "fmt" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) // Common LWW functions diff --git a/rdx/FIRST_test.go b/rdx/FIRST_test.go index 6ea26ee..bd81c8a 100644 --- a/rdx/FIRST_test.go +++ b/rdx/FIRST_test.go @@ -3,7 +3,7 @@ package rdx import ( "testing" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" "github.com/stretchr/testify/assert" ) diff --git a/rdx/NZ.go b/rdx/NZ.go index 50bc30f..d28102b 100644 --- a/rdx/NZ.go +++ b/rdx/NZ.go @@ -2,7 +2,7 @@ package rdx import ( "fmt" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) // N is an increment-only uint64 counter diff --git a/rdx/NZ_test.go b/rdx/NZ_test.go index ade0446..a64d297 100644 --- a/rdx/NZ_test.go +++ b/rdx/NZ_test.go @@ -1,7 +1,7 @@ package rdx import ( - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" "github.com/stretchr/testify/assert" "testing" ) diff --git a/rdx/X.go b/rdx/X.go index d1aea10..1d29da9 100644 --- a/rdx/X.go +++ b/rdx/X.go @@ -2,7 +2,7 @@ package rdx import ( hex2 "encoding/hex" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) func Xparse(rdt byte, val string) (tlv []byte) { diff --git a/rdx/id.go b/rdx/id.go index c593861..04abd00 100644 --- a/rdx/id.go +++ b/rdx/id.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) /* diff --git a/rdx/rdx.go b/rdx/rdx.go index 9bc9d4c..32647a9 100644 --- a/rdx/rdx.go +++ b/rdx/rdx.go @@ -2,7 +2,7 @@ package rdx import ( "errors" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" ) const ( diff --git a/rdx/vv.go b/rdx/vv.go index 90cc716..3510911 100644 --- a/rdx/vv.go +++ b/rdx/vv.go @@ -4,7 +4,7 @@ import ( "errors" "slices" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) // VV is a version vector, max ids seen from each known replica. diff --git a/repl/commands.go b/repl/commands.go index 2dee05b..fa41564 100644 --- a/repl/commands.go +++ b/repl/commands.go @@ -9,8 +9,8 @@ import ( "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" ) var HelpCreate = errors.New("create zone/1 {Name:\"Name\",Description:\"long text\"}") diff --git a/repl/repl.go b/repl/repl.go index bdce599..582c373 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -7,7 +7,7 @@ import ( "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" "github.com/ergochat/readline" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" "io" "os" "strings" diff --git a/sync.go b/sync.go index c4b5973..467425b 100644 --- a/sync.go +++ b/sync.go @@ -5,8 +5,8 @@ import ( "fmt" "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" "io" "os" "sync" diff --git a/toyqueue/drainfeed.go b/toyqueue/drainfeed.go new file mode 100644 index 0000000..621ad12 --- /dev/null +++ b/toyqueue/drainfeed.go @@ -0,0 +1,65 @@ +package toyqueue + +import "io" + +// Records (a batch of) as a very universal primitive, especially +// for database/network op/packet processing. Batching allows +// for writev() and other performance optimizations. ALso, if +// you have cryptography, blobs are way handier than structs. +// Records converts easily to net.Buffers. +type Records [][]byte + +type Feeder interface { + // Feed reads and returns records. + // The EoF convention follows that of io.Reader: + // can either return `records, EoF` or + // `records, nil` followed by `nil/{}, EoF` + Feed() (recs Records, err error) +} + +type FeedCloser interface { + Feeder + io.Closer +} + +type FeedSeeker interface { + Feeder + io.Seeker +} + +type FeedSeekCloser interface { + Feeder + io.Seeker + io.Closer +} + +type Drainer interface { + Drain(recs Records) error +} + +type DrainSeeker interface { + Drainer + io.Seeker +} + +type DrainCloser interface { + Drainer + io.Closer +} + +type DrainSeekCloser interface { + Drainer + io.Seeker + io.Closer +} + +type FeedDrainer interface { + Feeder + Drainer +} + +type FeedDrainCloser interface { + Feeder + Drainer + io.Closer +} diff --git a/toyqueue/fanout.go b/toyqueue/fanout.go new file mode 100644 index 0000000..ef6f643 --- /dev/null +++ b/toyqueue/fanout.go @@ -0,0 +1,110 @@ +package toyqueue + +import ( + "errors" + "sync" +) + +type multiDrain struct { + drains []DrainCloser + lock sync.Mutex +} + +type Fanout struct { + multiDrain + feeder FeedCloser +} + +func FanoutFeeder(feeder FeedCloser) *Fanout { + return &Fanout{ + feeder: feeder, + } +} + +func FanoutQueue(limit int) (fanout *Fanout, queue DrainCloser) { + q := RecordQueue{Limit: limit} + fanout = &Fanout{feeder: &q} + queue = &q + return +} + +func (f2ds *multiDrain) AddDrain(drain DrainCloser) { + f2ds.lock.Lock() + f2ds.drains = append(f2ds.drains, drain) + f2ds.lock.Unlock() +} + +var ErrNotKnown = errors.New("unknown drain") + +func (f2ds *multiDrain) findDrain(drain DrainCloser) int { + i := 0 + l := len(f2ds.drains) + for i < l && f2ds.drains[i] != drain { + i++ + } + return i +} + +func (f2ds *multiDrain) RemoveDrain(drain DrainCloser) (err error) { + f2ds.lock.Lock() + l := len(f2ds.drains) + i := f2ds.findDrain(drain) + if i < l { + f2ds.drains[i] = f2ds.drains[l-1] + f2ds.drains = f2ds.drains[:l-1] + } else { + err = ErrNotKnown + } + f2ds.lock.Unlock() + return +} + +func (f2ds *multiDrain) HasDrain(drain DrainCloser) (has bool) { + f2ds.lock.Lock() + has = f2ds.findDrain(drain) < len(f2ds.drains) + f2ds.lock.Unlock() + return +} + +// Run shovels the data from the feeder to the drains. +func (f2ds *Fanout) Run() { + var ferr, derr error + for ferr == nil && derr == nil { + var recs Records + recs, ferr = f2ds.feeder.Feed() + if len(recs) > 0 { + f2ds.lock.Lock() + ds := f2ds.drains + f2ds.lock.Unlock() + for i := 0; i < len(ds) && derr == nil; i++ { + derr = ds[i].Drain(recs) + } + } + } + _ = f2ds.feeder.Close() + f2ds.lock.Lock() + ds := f2ds.drains + f2ds.drains = nil + f2ds.feeder = nil + f2ds.lock.Unlock() + for _, drain := range ds { + _ = drain.Close() + } +} + +type feederDrainer struct { + feed Feeder + drain Drainer +} + +func (fd *feederDrainer) Feed() (recs Records, err error) { + return fd.feed.Feed() +} + +func (fd *feederDrainer) Drain(recs Records) error { + return fd.drain.Drain(recs) +} + +func JoinedFeedDrainer(feeder Feeder, drainer Drainer) FeedDrainer { + return &feederDrainer{feed: feeder, drain: drainer} +} diff --git a/toyqueue/fanout_test.go b/toyqueue/fanout_test.go new file mode 100644 index 0000000..2f1278b --- /dev/null +++ b/toyqueue/fanout_test.go @@ -0,0 +1,63 @@ +package toyqueue + +import ( + "sync" + "testing" +) + +type counterFeed struct { + Counter int +} + +func (c *counterFeed) Feed() (Records, error) { + if c.Counter == 0 { + return nil, ErrClosed + } else { + //fmt.Printf("feed to send: %d\n", c.Counter) + c.Counter-- + return Records{[]byte{'C'}}, nil + } +} + +func (c *counterFeed) Close() error { + c.Counter = 0 + //fmt.Printf("feed closed\n") + return nil +} + +type counterDrain struct { + counter int + closed bool + group *sync.WaitGroup +} + +func (c *counterDrain) Drain(records Records) error { + c.counter += len(records) + //fmt.Printf("drain received: %d\n", c.counter) + return nil +} + +func (c *counterDrain) Close() error { + c.closed = true + if c.group != nil { + c.group.Add(-1) + } + //fmt.Printf("drain closed\n") + return nil +} + +func TestFanout(t *testing.T) { + var f2d Fanout + f := counterFeed{Counter: 5} + f2d.feeder = &f + wait := sync.WaitGroup{} + wait.Add(3) + c1 := counterDrain{group: &wait} + c2 := counterDrain{group: &wait} + c3 := counterDrain{group: &wait} + f2d.AddDrain(&c1) + f2d.AddDrain(&c2) + go f2d.Run() + f2d.AddDrain(&c3) + wait.Wait() +} diff --git a/toyqueue/queue.go b/toyqueue/queue.go new file mode 100644 index 0000000..7b95f64 --- /dev/null +++ b/toyqueue/queue.go @@ -0,0 +1,153 @@ +package toyqueue + +import ( + "errors" + "sync" +) + +func (recs Records) recrem(total int64) (prelen int, prerem int64) { + for len(recs) > prelen && int64(len(recs[prelen])) <= total { + total -= int64(len(recs[prelen])) + prelen++ + } + prerem = total + return +} + +func (recs Records) WholeRecordPrefix(limit int64) (prefix Records, remainder int64) { + prelen, remainder := recs.recrem(limit) + prefix = recs[:prelen] + return +} + +func (recs Records) ExactSuffix(total int64) (suffix Records) { + prelen, prerem := recs.recrem(total) + suffix = recs[prelen:] + if prerem != 0 { // damages the original, hence copy + edited := make(Records, 1, len(suffix)) + edited[0] = suffix[0][prerem:] + suffix = append(edited, suffix[1:]...) + } + return +} + +func (recs Records) TotalLen() (total int64) { + for _, r := range recs { + total += int64(len(r)) + } + return +} + +type RecordQueue struct { + recs Records + lock sync.Mutex + cond sync.Cond + Limit int +} + +var ErrWouldBlock = errors.New("the queue is over capacity") +var ErrClosed = errors.New("queue is closed") + +func (q *RecordQueue) Drain(recs Records) error { + q.lock.Lock() + was0 := len(q.recs) == 0 + if len(q.recs)+len(recs) > q.Limit { + q.lock.Unlock() + if q.Limit == 0 { + return ErrClosed + } + return ErrWouldBlock + } + q.recs = append(q.recs, recs...) + if was0 && q.cond.L != nil { + q.cond.Broadcast() + } + q.lock.Unlock() + return nil +} + +func (q *RecordQueue) Close() error { + q.Limit = 0 + return nil +} + +func (q *RecordQueue) Feed() (recs Records, err error) { + q.lock.Lock() + if len(q.recs) == 0 { + err = ErrWouldBlock + if q.Limit == 0 { + err = ErrClosed + } + q.lock.Unlock() + return + } + wasfull := len(q.recs) >= q.Limit + recs = q.recs + q.recs = q.recs[len(q.recs):] + if wasfull && q.cond.L != nil { + q.cond.Broadcast() + } + q.lock.Unlock() + return +} + +func (q *RecordQueue) Blocking() FeedDrainCloser { + if q.cond.L == nil { + q.cond.L = &q.lock + } + return &blockingRecordQueue{q} +} + +type blockingRecordQueue struct { + queue *RecordQueue +} + +func (bq *blockingRecordQueue) Close() error { + return bq.queue.Close() +} + +func (bq *blockingRecordQueue) Drain(recs Records) error { + q := bq.queue + q.lock.Lock() + for len(recs) > 0 { + was0 := len(q.recs) == 0 + for q.Limit <= len(q.recs) { + if q.Limit == 0 { + q.lock.Unlock() + return ErrClosed + } + q.cond.Wait() + } + qcap := q.Limit - len(q.recs) + if qcap > len(recs) { + qcap = len(recs) + } + q.recs = append(q.recs, recs[:qcap]...) + recs = recs[qcap:] + if was0 { + q.cond.Broadcast() + } + } + q.lock.Unlock() + return nil +} + +func (bq *blockingRecordQueue) Feed() (recs Records, err error) { + q := bq.queue + q.lock.Lock() + wasfull := len(q.recs) >= q.Limit + for len(q.recs) == 0 { + if q.Limit == 0 { + q.lock.Unlock() + return nil, ErrClosed + } + q.cond.Wait() + } + recs = q.recs + q.recs = q.recs[len(q.recs):] + if wasfull { + q.cond.Broadcast() + } + q.lock.Unlock() + return +} diff --git a/toyqueue/queue_test.go b/toyqueue/queue_test.go new file mode 100644 index 0000000..8d7d902 --- /dev/null +++ b/toyqueue/queue_test.go @@ -0,0 +1,62 @@ +package toyqueue + +import ( + "encoding/binary" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestBlockingRecordQueue_Drain(t *testing.T) { + const N = 1 << 10 // 8K + const K = 1 << 4 // 16 + + orig := RecordQueue{Limit: 1024} + queue := orig.Blocking() + + for k := 0; k < K; k++ { + go func(k int) { + i := uint64(k) << 32 + for n := uint64(0); n < N; n++ { + var b [8]byte + binary.LittleEndian.PutUint64(b[:], i|n) + err := queue.Drain(Records{b[:]}) + assert.Nil(t, err) + } + }(k) + } + + check := [K]int{} + for i := uint64(0); i < N*K; { + nums, err := queue.Feed() + assert.Nil(t, err) + for _, num := range nums { + assert.Equal(t, 8, len(num)) + j := binary.LittleEndian.Uint64(num) + k := int(j >> 32) + n := int(j & 0xffffffff) + assert.Equal(t, check[k], n) + check[k] = n + 1 + i++ + } + } + + recs := [][]byte{{'a'}} + assert.Nil(t, queue.Close()) + err := queue.Drain(recs) + assert.Equal(t, ErrClosed, err) + _, err2 := queue.Feed() + assert.Equal(t, ErrClosed, err2) + +} + +func TestTwoWayQueue_Drain(t *testing.T) { + a, b := BlockingRecordQueuePair(1) + recs := Records{{'a'}} + go func() { + err := a.Drain(recs) + assert.Nil(t, err) + }() + recs2, err := b.Feed() + assert.Nil(t, err) + assert.Equal(t, recs, recs2) +} diff --git a/toyqueue/twoway.go b/toyqueue/twoway.go new file mode 100644 index 0000000..e539dd3 --- /dev/null +++ b/toyqueue/twoway.go @@ -0,0 +1,38 @@ +package toyqueue + +type twoWayQueue struct { + in DrainCloser + out FeedCloser +} + +func RecordQueuePair(limit int) (i, o FeedDrainCloser) { + a := RecordQueue{Limit: limit} + b := RecordQueue{Limit: limit} + i = &twoWayQueue{in: &a, out: &b} + o = &twoWayQueue{in: &b, out: &a} + return +} + +func BlockingRecordQueuePair(limit int) (i, o FeedDrainCloser) { + _a, _b := RecordQueue{Limit: limit}, RecordQueue{Limit: limit} + a, b := _a.Blocking(), _b.Blocking() + i = &twoWayQueue{in: a, out: b} + o = &twoWayQueue{in: b, out: a} + return +} + +func (tw *twoWayQueue) Feed() (recs Records, err error) { + return tw.out.Feed() +} + +func (tw *twoWayQueue) Drain(recs Records) error { + return tw.in.Drain(recs) +} + +func (tw *twoWayQueue) Close() (err error) { + err = tw.in.Close() + if err == nil { + err = tw.out.Close() + } + return +} diff --git a/toyqueue/twoway_test.go b/toyqueue/twoway_test.go new file mode 100644 index 0000000..40bc5e7 --- /dev/null +++ b/toyqueue/twoway_test.go @@ -0,0 +1,26 @@ +package toyqueue + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestTwoWayQueue_Feed(t *testing.T) { + a, b := BlockingRecordQueuePair(2) + err := a.Drain(Records{[]byte{'A'}, []byte{'B', 'B'}}) + assert.Nil(t, err) + go func() { + time.Sleep(time.Millisecond * 10) // well... + recs, err := b.Feed() + assert.Nil(t, err) + assert.Equal(t, 2, len(recs)) + assert.Equal(t, int64(3), recs.TotalLen()) + recs, err = b.Feed() + assert.Nil(t, err) + assert.Equal(t, 1, len(recs)) + assert.Equal(t, int64(3), recs.TotalLen()) + }() + err = a.Drain(Records{[]byte{'C', 'C', 'C'}}) + assert.Nil(t, err) +} diff --git a/toyqueue/util.go b/toyqueue/util.go new file mode 100644 index 0000000..728bbb6 --- /dev/null +++ b/toyqueue/util.go @@ -0,0 +1,46 @@ +package toyqueue + +func Relay(feeder Feeder, drainer Drainer) error { + recs, err := feeder.Feed() + if err != nil { + if len(recs) > 0 { + _ = drainer.Drain(recs) + } + return err + } + err = drainer.Drain(recs) + return err +} + +func Pump(feeder Feeder, drainer Drainer) (err error) { + for err == nil { + err = Relay(feeder, drainer) + } + return +} + +func PumpN(feeder Feeder, drainer Drainer, n int) (err error) { + for err == nil && n > 0 { + err = Relay(feeder, drainer) + n-- + } + return +} + +func PumpThenClose(feed FeedCloser, drain DrainCloser) error { + var ferr, derr error + for ferr == nil && derr == nil { + var recs Records + recs, ferr = feed.Feed() + if len(recs) > 0 { // e.g. Feed() may return data AND EOF + derr = drain.Drain(recs) + } + } + _ = feed.Close() + _ = drain.Close() + if ferr != nil { + return ferr + } else { + return derr + } +} diff --git a/toyqueue/util_test.go b/toyqueue/util_test.go new file mode 100644 index 0000000..8fb0705 --- /dev/null +++ b/toyqueue/util_test.go @@ -0,0 +1,54 @@ +package toyqueue + +import ( + "github.com/stretchr/testify/assert" + "io" + "testing" +) + +type sliceFeedDrainer struct { + data []byte + res []byte +} + +func (fd *sliceFeedDrainer) Close() error { + fd.res = append(fd.res, '(') + fd.res = append(fd.res, fd.data...) + fd.res = append(fd.res, ')') + return nil +} + +func (fd *sliceFeedDrainer) Drain(recs Records) error { + for _, rec := range recs { + fd.data = append(fd.data, rec...) + } + return nil +} + +func (fd *sliceFeedDrainer) Feed() (recs Records, err error) { + for i := 0; i < 3 && len(fd.data) > 0; i++ { + recs = append(recs, fd.data[0:1]) + fd.data = fd.data[1:] + } + if len(fd.data) == 0 { + err = io.EOF + } + return +} + +func TestPump(t *testing.T) { + sfd := sliceFeedDrainer{ + data: []byte("Hello world"), + } + err := PumpN(&sfd, &sfd, 2) + assert.Nil(t, err) + assert.Equal(t, sfd.data, []byte("worldHello ")) + + fro := sliceFeedDrainer{ + data: []byte("Hello world"), + } + to := sliceFeedDrainer{} + err = PumpThenClose(&fro, &to) + assert.Equal(t, err, io.EOF) + assert.Equal(t, []byte("(Hello world)"), to.res) +} diff --git a/toytlv/reader.go b/toytlv/reader.go new file mode 100644 index 0000000..bb727c9 --- /dev/null +++ b/toytlv/reader.go @@ -0,0 +1,180 @@ +package toytlv + +import ( + "github.com/drpcorg/chotki/toyqueue" + "io" +) + +// Feeder reads TLV records from an io.Reader stream. +// Note that Feeder is buffered, i.e. it reads ahead. +// When doing Seek() on a file, recreate Feeder, that is cheap. +type Reader2Feeder struct { + pre []byte + Reader io.Reader +} + +type ReadSeeker2FeedSeeker struct { + pre []byte + Reader io.ReadSeeker +} + +type ReadCloser2FeedCloser struct { + pre []byte + Reader io.ReadCloser +} + +type ReadSeekCloser2FeedSeekCloser struct { + pre []byte + Reader io.ReadSeekCloser +} + +const DefaultPreBufLength = 4096 +const MinRecommendedRead = 512 +const MinRecommendedWrite = 400 + +func (fs *ReadSeeker2FeedSeeker) Seek(offset int64, whence int) (int64, error) { + fs.pre = nil + return fs.Reader.Seek(offset, whence) +} + +func (fs *ReadSeekCloser2FeedSeekCloser) Seek(offset int64, whence int) (int64, error) { + fs.pre = nil + return fs.Reader.Seek(offset, whence) +} + +func (fs *ReadCloser2FeedCloser) Close() error { + fs.pre = nil + return fs.Reader.Close() +} + +func (fs *ReadSeekCloser2FeedSeekCloser) Close() error { + fs.pre = nil + return fs.Reader.Close() +} + +func (fs *Reader2Feeder) Feed() (recs toyqueue.Records, err error) { + fs.pre, recs, err = feed(fs.pre, fs.Reader) + return +} + +func (fs *ReadSeeker2FeedSeeker) Feed() (recs toyqueue.Records, err error) { + fs.pre, recs, err = feed(fs.pre, fs.Reader) + return +} + +func (fs *ReadCloser2FeedCloser) Feed() (recs toyqueue.Records, err error) { + fs.pre, recs, err = feed(fs.pre, fs.Reader) + return +} + +func (fs *ReadSeekCloser2FeedSeekCloser) Feed() (recs toyqueue.Records, err error) { + fs.pre, recs, err = feed(fs.pre, fs.Reader) + return +} + +func fill(past []byte, tolen int, reader io.Reader) (data []byte, err error) { + data = past + l := len(data) + c := cap(data) + if c-l < MinRecommendedRead || c < tolen { + newcap := DefaultPreBufLength + if newcap < tolen { + newcap = tolen + } + newpre := make([]byte, newcap) + copy(newpre, data) + newpre = newpre[:l] + data = newpre + l = len(data) + c = cap(data) + } + for len(data) < tolen { + vac := data[l:c] + var n int + n, err = reader.Read(vac) + if err != nil { + break + } + data = data[0 : l+n] + } + return +} + +func feed(past []byte, reader io.Reader) (rest []byte, tlv toyqueue.Records, err error) { + rest = past + var hdrlen, bodylen int + var lit byte + lit, hdrlen, bodylen = ProbeHeader(rest) + for lit == 0 || hdrlen+bodylen > len(rest) { + tolen := len(rest) + 1 + if lit != 0 { + tolen = hdrlen + bodylen + } + rest, err = fill(rest, tolen, reader) + if err != nil { + return + } + lit, hdrlen, bodylen = ProbeHeader(rest) + } + for lit >= 'A' && lit <= 'Z' && hdrlen+bodylen <= len(rest) { + tlv = append(tlv, rest[0:hdrlen+bodylen]) + rest = rest[hdrlen+bodylen:] + lit, hdrlen, bodylen = ProbeHeader(rest) + } + if lit == '-' { + err = ErrBadRecord + } + return +} + +type Writer2Drainer struct { + Writer io.Writer +} + +type WritCloser2DrainCloser struct { + Writer io.WriteCloser +} + +func next(rest []byte, more toyqueue.Records) (cur []byte, left toyqueue.Records) { + cur, left = rest, more + if len(cur) >= MinRecommendedWrite { + return + } + for len(cur) < MinRecommendedWrite && len(left) > 0 { + cur = append(cur, left[0]...) + left = left[1:] + } + return +} + +// Having no writev() we do the next best thing: bundle writes +func (d *Writer2Drainer) Drain(recs toyqueue.Records) error { + var cur []byte + for len(cur) > 0 || len(recs) > 0 { + cur, recs = next(cur, recs) + n, err := d.Writer.Write(cur) + if err != nil { + return err + } + cur = cur[n:] + } + return nil +} + +// Having no writev() we do the next best thing: bundle writes +func (d *WritCloser2DrainCloser) Drain(recs toyqueue.Records) error { + var cur []byte + for len(cur) > 0 || len(recs) > 0 { + cur, recs = next(cur, recs) + n, err := d.Writer.Write(cur) + if err != nil { + return err + } + cur = cur[n:] + } + return nil +} + +func (dc *WritCloser2DrainCloser) Close() error { + return dc.Writer.Close() +} diff --git a/toytlv/tcp.go b/toytlv/tcp.go new file mode 100644 index 0000000..6358674 --- /dev/null +++ b/toytlv/tcp.go @@ -0,0 +1,347 @@ +package toytlv + +import ( + "errors" + "fmt" + "github.com/drpcorg/chotki/toyqueue" + "io" + "net" + "os" + "sync" + "time" +) + +const MaxOutQueueLen = 1 << 20 // 16MB of pointers is a lot + +type TCPConn struct { + depot *TCPDepot + addr string + conn net.Conn + inout toyqueue.FeedDrainCloser + wake *sync.Cond + outmx sync.Mutex + Reconnect bool + KeepAlive bool +} + +type Jack func(conn net.Conn) toyqueue.FeedDrainCloser + +// A TCP server/client for the use case of real-time async communication. +// Differently from the case of request-response (like HTTP), we do not +// wait for a request, then dedicating a thread to processing, then sending +// back the resulting response. Instead, we constantly fan sendQueue tons of +// tiny messages. That dictates different work patterns than your typical +// HTTP/RPC server as, for example, we cannot let one slow receiver delay +// event transmission to all the other receivers. +type TCPDepot struct { + conns map[string]*TCPConn + listens map[string]net.Listener + conmx sync.Mutex + jack Jack +} + +func (de *TCPDepot) Open(jack Jack) { + de.conmx.Lock() + de.conns = make(map[string]*TCPConn) + de.listens = make(map[string]net.Listener) + de.conmx.Unlock() + de.jack = jack +} + +func (de *TCPDepot) Close() { + for _, lstn := range de.listens { + _ = lstn.Close() + } + de.listens = nil + for _, con := range de.conns { + con.Close() + } + de.conmx.Lock() + de.conns = make(map[string]*TCPConn) + de.listens = make(map[string]net.Listener) + de.conmx.Unlock() +} + +func (tcp *TCPConn) Close() { + // TODO writer closes on complete | 1 sec expired + tcp.outmx.Lock() + if tcp.conn != nil { + _ = tcp.conn.Close() + tcp.conn = nil + tcp.wake.Broadcast() + } + tcp.outmx.Unlock() +} + +var ErrAddressUnknown = errors.New("address unknown") + +const MAX_RETRY_PERIOD = time.Minute +const MIN_RETRY_PERIOD = time.Second / 2 + +// attrib?! +func (de *TCPDepot) Connect(addr string) (err error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return err + } + peer := TCPConn{ + depot: de, + conn: conn, + addr: addr, + inout: de.jack(conn), + } + peer.wake = sync.NewCond(&peer.outmx) + de.conmx.Lock() + de.conns[addr] = &peer + de.conmx.Unlock() + go peer.KeepTalking() + return nil +} + +var ErrDisconnected = errors.New("disconnected by user") + +func (tcp *TCPConn) KeepTalking() { + talk_backoff := MIN_RETRY_PERIOD + conn_backoff := MIN_RETRY_PERIOD + for { + + conntime := time.Now() + go tcp.doWrite() + err := tcp.Read() + + if !tcp.Reconnect { + break + } + + atLeast5min := conntime.Add(time.Minute * 5) + if atLeast5min.After(time.Now()) { + talk_backoff *= 2 // connected, tried to talk, failed => wait more + if talk_backoff > MAX_RETRY_PERIOD { + talk_backoff = MAX_RETRY_PERIOD + } + } + + for tcp.conn == nil { + time.Sleep(conn_backoff + talk_backoff) + tcp.conn, err = net.Dial("tcp", tcp.addr) + if err != nil { + conn_backoff = conn_backoff * 2 + if conn_backoff > MAX_RETRY_PERIOD/2 { + conn_backoff = MAX_RETRY_PERIOD + } + } else { + conn_backoff = MIN_RETRY_PERIOD + } + } + + } +} + +// Write what we believe is a valid ToyTLV frame. +// Provided for io.Writer compatibility +func (tcp *TCPConn) Write(data []byte) (n int, err error) { + err = tcp.Drain(toyqueue.Records{data}) + if err == nil { + n = len(data) + } + return +} + +func (tcp *TCPConn) Drain(recs toyqueue.Records) (err error) { + return tcp.inout.Drain(recs) +} + +func (tcp *TCPConn) Feed() (recs toyqueue.Records, err error) { + return tcp.inout.Feed() +} + +func (de *TCPDepot) DrainTo(recs toyqueue.Records, addr string) error { + de.conmx.Lock() + conn, ok := de.conns[addr] + de.conmx.Unlock() + if !ok { + return ErrAddressUnknown + } + return conn.Drain(recs) +} + +func (de *TCPDepot) Disconnect(addr string) (err error) { + de.conmx.Lock() + tcp, ok := de.conns[addr] + de.conmx.Unlock() + if !ok { + return ErrAddressUnknown + } + tcp.Close() + de.conmx.Lock() + delete(de.conns, addr) + de.conmx.Unlock() + return nil +} + +func (de *TCPDepot) Listen(addr string) (err error) { + listener, err := net.Listen("tcp", addr) + if err != nil { + return + } + de.conmx.Lock() + pre, ok := de.listens[addr] + if ok { + _ = pre.Close() + } + de.listens[addr] = listener + de.conmx.Unlock() + go de.KeepListening(addr) + return +} + +func (de *TCPDepot) StopListening(addr string) error { + de.conmx.Lock() + listener, ok := de.listens[addr] + delete(de.listens, addr) + de.conmx.Unlock() + if !ok { + return ErrAddressUnknown + } + return listener.Close() +} + +func (de *TCPDepot) KeepListening(addr string) { + for { + de.conmx.Lock() + listener, ok := de.listens[addr] + de.conmx.Unlock() + if !ok { + break + } + conn, err := listener.Accept() + if err != nil { + break + } + addr := conn.RemoteAddr().String() + peer := TCPConn{ + depot: de, + conn: conn, + addr: addr, + inout: de.jack(conn), + } + peer.wake = sync.NewCond(&peer.outmx) + de.conmx.Lock() + de.conns[addr] = &peer + de.conmx.Unlock() + + go peer.doWrite() + go peer.doRead() + + } +} + +func (tcp *TCPConn) doRead() { + err := tcp.Read() + if err != nil && err != ErrDisconnected { + } +} + +func (tcp *TCPConn) doWrite() { + conn := tcp.conn + var err error + var recs toyqueue.Records + for conn != nil && err == nil { + recs, err = tcp.inout.Feed() + b := net.Buffers(recs) + for len(b) > 0 && err == nil { + _, err = b.WriteTo(conn) + } + } + if err != nil { + tcp.Close() // TODO err + } +} + +const TYPICAL_MTU = 1500 + +func (tcp *TCPConn) Read() (err error) { + var buf []byte + conn := tcp.conn + for conn != nil { + buf, err = AppendRead(buf, conn, TYPICAL_MTU) + if err != nil { + break + } + var recs toyqueue.Records + recs, buf, err = Split(buf) + if len(recs) == 0 { + time.Sleep(time.Millisecond) + continue + } + if err != nil { + break + } + + err = tcp.inout.Drain(recs) + if err != nil { + break + } + + conn = tcp.conn + } + + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, err.Error()) + tcp.Close() + } + return +} + +func ReadBuf(buf []byte, rdr io.Reader) ([]byte, error) { + avail := cap(buf) - len(buf) + if avail < 512 { + l := 4096 + if len(buf) > 2048 { + l = len(buf) * 2 + } + newbuf := make([]byte, l) + copy(newbuf[:], buf) + buf = newbuf[:len(buf)] + } + idle := buf[len(buf):cap(buf)] + n, err := rdr.Read(idle) + if err != nil { + return buf, err + } + if n == 0 { + return buf, io.EOF + } + buf = buf[:len(buf)+n] + return buf, nil +} + +func RoundPage(l int) int { + if (l & 0xfff) != 0 { + l = (l & ^0xfff) + 0x1000 + } + return l +} + +// AppendRead reads data from io.Reader into the *spare space* of the provided buffer, +// i.e. those cap(buf)-len(buf) vacant bytes. If the spare space is smaller than +// lenHint, allocates (as reading less bytes might be unwise). +func AppendRead(buf []byte, rdr io.Reader, lenHint int) ([]byte, error) { + avail := cap(buf) - len(buf) + if avail < lenHint { + want := RoundPage(len(buf) + lenHint) + newbuf := make([]byte, want) + copy(newbuf[:], buf) + buf = newbuf[:len(buf)] + } + idle := buf[len(buf):cap(buf)] + n, err := rdr.Read(idle) + if err != nil { + return buf, err + } + if n == 0 { + return buf, io.EOF + } + buf = buf[:len(buf)+n] + return buf, nil +} diff --git a/toytlv/tcp_test.go b/toytlv/tcp_test.go new file mode 100644 index 0000000..57b47e4 --- /dev/null +++ b/toytlv/tcp_test.go @@ -0,0 +1,86 @@ +package toytlv + +import ( + "github.com/drpcorg/chotki/toyqueue" + "github.com/stretchr/testify/assert" + "net" + "sync" + "testing" +) + +// 1. create a server, create a client, echo +// 2. create a server, client, connect, disconn, reconnect +// 3. create a server, client, conn, stop the serv, relaunch, reconnect + +type TestConsumer struct { + rcvd toyqueue.Records + mx sync.Mutex + co sync.Cond +} + +func (c *TestConsumer) Drain(recs toyqueue.Records) error { + c.mx.Lock() + c.rcvd = append(c.rcvd, recs...) + c.co.Signal() + c.mx.Unlock() + return nil +} + +func (c *TestConsumer) Feed() (recs toyqueue.Records, err error) { + c.mx.Lock() + if len(c.rcvd) == 0 { + c.co.Wait() + } + recs = c.rcvd + c.rcvd = c.rcvd[len(c.rcvd):] + c.mx.Unlock() + return +} + +func (c *TestConsumer) Close() error { + return nil +} + +func TestTCPDepot_Connect(t *testing.T) { + + loop := "127.0.0.1:12345" + + tc := TestConsumer{} + tc.co.L = &tc.mx + depot := TCPDepot{} + addr := "" + depot.Open(func(conn net.Conn) toyqueue.FeedDrainCloser { + a := conn.RemoteAddr().String() + if a != loop { + addr = a + } + return &tc + }) + + err := depot.Listen(loop) + assert.Nil(t, err) + + err = depot.Connect(loop) + assert.Nil(t, err) + + // send a record + recsto := toyqueue.Records{Record('M', []byte("Hi there"))} + err = depot.DrainTo(recsto, loop) + rec, err := tc.Feed() + lit, body, rest := TakeAny(rec[0]) + assert.Equal(t, uint8('M'), lit) + assert.Equal(t, "Hi there", string(body)) + assert.Equal(t, 0, len(rest)) + + // respond to that + recsback := toyqueue.Records{Record('M', []byte("Re: Hi there"))} + err = depot.DrainTo(recsback, addr) + rerec, err := tc.Feed() + relit, rebody, rerest := TakeAny(rerec[0]) + assert.Equal(t, uint8('M'), relit) + assert.Equal(t, "Re: Hi there", string(rebody)) + assert.Equal(t, 0, len(rerest)) + + depot.Close() + +} diff --git a/toytlv/tlv.go b/toytlv/tlv.go new file mode 100644 index 0000000..5225ef0 --- /dev/null +++ b/toytlv/tlv.go @@ -0,0 +1,317 @@ +package toytlv + +import ( + "encoding/binary" + "errors" + "github.com/drpcorg/chotki/toyqueue" +) + +const CaseBit uint8 = 'a' - 'A' + +var ErrIncomplete = errors.New("incomplete data") +var ErrBadRecord = errors.New("bad TLV record format") + +// ProbeHeader probes a TLV record header. Return values: +// - 0 0 0 incomplete header +// - '-' 0 0 bad format +// - 'A' 2 123 success +func ProbeHeader(data []byte) (lit byte, hdrlen, bodylen int) { + if len(data) == 0 { + return 0, 0, 0 + } + dlit := data[0] + if dlit >= '0' && dlit <= '9' { // tiny + lit = '0' + bodylen = int(dlit - '0') + hdrlen = 1 + } else if dlit >= 'a' && dlit <= 'z' { // short + if len(data) < 2 { + return + } + lit = dlit - CaseBit + hdrlen = 2 + bodylen = int(data[1]) + } else if dlit >= 'A' && dlit <= 'Z' { // long + if len(data) < 5 { + return + } + bl := binary.LittleEndian.Uint32(data[1:5]) + if bl > 0x7fffffff { + lit = '-' + return + } + lit = dlit + bodylen = int(bl) + hdrlen = 5 + } else { + lit = '-' + } + return +} + +// Incomplete returns the number of supposedly yet-unread bytes. +// 0 for complete, -1 for bad format, +// >0 for least-necessary read to complete either header or record. +func Incomplete(data []byte) int { + if len(data) == 0 { + return 1 // get something + } + dlit := data[0] + bodylen := 1 + if dlit >= '0' && dlit <= '9' { // tiny + bodylen = int(dlit - '0') + } else if dlit >= 'a' && dlit <= 'z' { // short + if len(data) < 2 { + bodylen = 2 + } else { + bodylen = int(data[1]) + 2 + } + } else if dlit >= 'A' && dlit <= 'Z' { // long + if len(data) < 5 { + bodylen = 5 + } else { + bl := binary.LittleEndian.Uint32(data[1:5]) + if bl > 0x7fffffff { + return -1 + } + bodylen = int(bl) + 5 + } + } else { + return -1 + } + if bodylen > len(data) { + return bodylen - len(data) + } else { + return 0 + } +} + +func Split(data []byte) (recs toyqueue.Records, rest []byte, err error) { + rest = data + for len(rest) > 0 { + lit, hlen, blen := ProbeHeader(rest) + if lit == '-' { + if len(recs) == 0 { + err = ErrBadRecord + } + return + } + if lit == 0 { + return + } + if hlen+blen > len(rest) { + break + } + recs = append(recs, rest[:hlen+blen]) + rest = rest[hlen+blen:] + } + return +} + +func ProbeHeaders(lits string, data []byte) int { + rest := data + for i := 0; i < len(lits); i++ { + l, hl, bl := ProbeHeader(rest) + if l != lits[i] { + return -1 + } + rest = rest[hl+bl:] + } + return len(data) - len(rest) +} + +// Feeds the header into the buffer. +// Subtle: lower-case lit allows for defaulting, uppercase must be explicit. +func AppendHeader(into []byte, lit byte, bodylen int) (ret []byte) { + biglit := lit &^ CaseBit + if biglit < 'A' || biglit > 'Z' { + panic("ToyTLV record type is A..Z") + } + if bodylen < 10 && (lit&CaseBit) != 0 { + ret = append(into, byte('0'+bodylen)) + } else if bodylen > 0xff { + if bodylen > 0x7fffffff { + panic("oversized TLV record") + } + ret = append(into, biglit) + ret = binary.LittleEndian.AppendUint32(ret, uint32(bodylen)) + } else { + ret = append(into, lit|CaseBit, byte(bodylen)) + } + return ret +} + +// Take is used to read safe TLV inputs (e.g. from own storage) with +// record types known in advance. +func Take(lit byte, data []byte) (body, rest []byte) { + flit, hdrlen, bodylen := ProbeHeader(data) + if flit == 0 || hdrlen+bodylen > len(data) { + return nil, data // Incomplete + } + if flit != lit && flit != '0' { + return nil, nil // BadRecord + } + body = data[hdrlen : hdrlen+bodylen] + rest = data[hdrlen+bodylen:] + return +} + +// TakeAny is used for safe TLV inputs when record types can vary. +func TakeAny(data []byte) (lit byte, body, rest []byte) { + if len(data) == 0 { + return 0, nil, nil + } + lit = data[0] & ^CaseBit + body, rest = Take(lit, data) + return +} + +// TakeWary reads TLV records of known type from unsafe input. +func TakeWary(lit byte, data []byte) (body, rest []byte, err error) { + flit, hdrlen, bodylen := ProbeHeader(data) + if flit == 0 || hdrlen+bodylen > len(data) { + return nil, data, ErrIncomplete + } + if flit != lit && flit != '0' { + return nil, nil, ErrBadRecord + } + body = data[hdrlen : hdrlen+bodylen] + rest = data[hdrlen+bodylen:] + return +} + +// TakeWary reads TLV records of arbitrary type from unsafe input. +func TakeAnyWary(data []byte) (lit byte, body, rest []byte, err error) { + if len(data) == 0 { + return 0, nil, nil, ErrIncomplete + } + lit = data[0] & ^CaseBit + body, rest = Take(lit, data) + return +} + +func TakeRecord(lit byte, data []byte) (rec, rest []byte) { + flit, hdrlen, bodylen := ProbeHeader(data) + if flit == 0 || hdrlen+bodylen > len(data) { + return nil, data // Incomplete + } + if flit != lit && flit != '0' { + return nil, nil // BadRecord + } + rec = data[0 : hdrlen+bodylen] + rest = data[hdrlen+bodylen:] + return +} + +func TakeAnyRecord(data []byte) (lit byte, rec, rest []byte) { + lit, hdrlen, bodylen := ProbeHeader(data) + if lit == 0 || hdrlen+bodylen > len(data) { + return 0, nil, data // Incomplete + } + if lit == '-' { + return '-', nil, nil // BadRecord + } + rec = data[0 : hdrlen+bodylen] + rest = data[hdrlen+bodylen:] + return +} + +func TotalLen(inputs [][]byte) (sum int) { + for _, input := range inputs { + sum += len(input) + } + return +} + +func Lit(rec []byte) byte { + b := rec[0] + if b >= 'a' && b <= 'z' { + return b - CaseBit + } else if b >= 'A' && b <= 'Z' { + return b + } else if b >= '0' && b <= '9' { + return '0' + } else { + return '-' + } +} + +// Append appends a record to the buffer; note that uppercase type +// is always explicit, lowercase can be defaulted. +func Append(into []byte, lit byte, body ...[]byte) (res []byte) { + total := TotalLen(body) + res = AppendHeader(into, lit, total) + for _, b := range body { + res = append(res, b...) + } + return res +} + +// Record composes a record of a given type +func Record(lit byte, body ...[]byte) []byte { + total := TotalLen(body) + ret := make([]byte, 0, total+5) + ret = AppendHeader(ret, lit, total) + for _, b := range body { + ret = append(ret, b...) + } + return ret +} + +func AppendTiny(into []byte, lit byte, body []byte) (res []byte) { + if len(body) > 9 { + return Append(into, lit, body) + } + res = append(into, '0'+byte(len(body))) + res = append(res, body...) + return +} + +func TinyRecord(lit byte, body []byte) (tiny []byte) { + var data [10]byte + return AppendTiny(data[:0], lit, body) +} + +func Join(records ...[]byte) (ret toyqueue.Records) { + for _, rec := range records { + ret = append(ret, rec) + } + return +} + +func Records(lit byte, bodies ...[]byte) (recs toyqueue.Records) { + for _, body := range bodies { + recs = append(recs, Record(lit, body)) + } + return +} + +func Concat(msg ...[]byte) []byte { + total := TotalLen(msg) + ret := make([]byte, 0, total) + for _, b := range msg { + ret = append(ret, b...) + } + return ret +} + +// OpenHeader opens a streamed TLV record; use append() to create the +// record body, then call CloseHeader(&buf, bookmark) +func OpenHeader(buf []byte, lit byte) (bookmark int, res []byte) { + lit &= ^CaseBit + if lit < 'A' || lit > 'Z' { + panic("TLV liters are uppercase A-Z") + } + res = append(buf, lit) + blanclen := []byte{0, 0, 0, 0} + res = append(res, blanclen...) + return len(res), res +} + +// CloseHeader closes a streamed TLV record +func CloseHeader(buf []byte, bookmark int) { + if bookmark < 5 || len(buf) < bookmark { + panic("check the API docs") + } + binary.LittleEndian.PutUint32(buf[bookmark-4:bookmark], uint32(len(buf)-bookmark)) +} diff --git a/toytlv/tlv_test.go b/toytlv/tlv_test.go new file mode 100644 index 0000000..b16760e --- /dev/null +++ b/toytlv/tlv_test.go @@ -0,0 +1,115 @@ +package toytlv + +import ( + "github.com/stretchr/testify/assert" + "io" + "os" + "testing" +) + +func TestTLVAppend(t *testing.T) { + buf := []byte{} + buf = Append(buf, 'A', []byte{'A'}) + buf = Append(buf, 'b', []byte{'B', 'B'}) + correct2 := []byte{'a', 1, 'A', '2', 'B', 'B'} + assert.Equal(t, correct2, buf, "basic TLV fail") + + var c256 [256]byte + for n, _ := range c256 { + c256[n] = 'c' + } + buf = Append(buf, 'C', c256[:]) + assert.Equal(t, len(correct2)+1+4+len(c256), len(buf)) + assert.Equal(t, uint8(67), buf[len(correct2)]) + assert.Equal(t, uint8(1), buf[len(correct2)+2]) + + lit, body, buf, err := TakeAnyWary(buf) + assert.Nil(t, err) + assert.Equal(t, uint8('A'), lit) + assert.Equal(t, []byte{'A'}, body) + + body2, buf, err2 := TakeWary('B', buf) + assert.Nil(t, err2) + assert.Equal(t, []byte{'B', 'B'}, body2) +} + +func TestFeedHeader(t *testing.T) { + buf := []byte{} + l, buf := OpenHeader(buf, 'A') + text := "some text" + buf = append(buf, text...) + CloseHeader(buf, l) + lit, body, rest, err := TakeAnyWary(buf) + assert.Nil(t, err) + assert.Equal(t, uint8('A'), lit) + assert.Equal(t, text, string(body)) + assert.Equal(t, 0, len(rest)) +} + +func TestTLVReader_ReadRecord(t *testing.T) { + const K = 1000 + const L = 512 + _ = os.Remove("tlv") + file, err := os.OpenFile("tlv", os.O_CREATE|os.O_TRUNC|os.O_RDWR, os.ModePerm) + assert.Nil(t, err) + writer := Writer2Drainer{ + Writer: file, + } + var lo [L]byte + for i := 0; i < L; i++ { + lo[i] = byte(i) + } + var sho = [1]byte{'A'} + for i := 0; i < K; i++ { + err = writer.Drain( + Join( + Record('L', lo[:]), + Record('S', sho[:]), + ), + ) + assert.Nil(t, err) + } + assert.Nil(t, err) + info, err := file.Stat() + assert.Nil(t, err) + assert.Equal(t, int64((2+1)*K+(5+len(lo))*K), info.Size()) + _ = file.Close() + + file2, err := os.Open("tlv") + assert.Nil(t, err) + reader := Reader2Feeder{ + Reader: file2, + } + i := 0 + for i < K*2 { + + recs, err := reader.Feed() + assert.Nil(t, err) + for _, rec := range recs { + lit, body, rest, err := TakeAnyWary(rec) + assert.Nil(t, err) + assert.Equal(t, 0, len(rest)) + if (i & 1) == 0 { + assert.Equal(t, byte('L'), lit) + assert.Equal(t, lo[:], body) + } else { + assert.Equal(t, byte('S'), lit) + assert.Equal(t, sho[:], body) + } + i++ + } + + } + + recs, err := reader.Feed() + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, len(recs)) + + _ = os.Remove("tlv") +} + +func TestTinyRecord(t *testing.T) { + body := "12" + tiny := TinyRecord('X', []byte(body)) + assert.Equal(t, "212", string(tiny)) +}