Skip to content

Commit

Permalink
zmtp: improve perfs of Connection.read{,Multipart} (#71)
Browse files Browse the repository at this point in the history
* zmtp: improve perfs of Connection.(read|send){,Multipart}

This CL uses io.ReadFull to make sure all requested bytes are read from
an io.Reader.
It's also using binary.ByteOrder.Uint64 and binary.ByteOrder.PutUint64
directly instead of going the round about way through (slow) reflection.

Fixes #67.
Fixes #61.

* zmtp: reduce number of allocs in Connection.SendCommand

* zmtp: reduce number of allocs in Connection.writeMetadata

* zmtp: removed reflection from de/serializing greetings

* zmtp: remove slow reflection in Connection.recvMetadata
  • Loading branch information
sbinet authored and Luna Duclos committed Apr 19, 2018
1 parent d84a741 commit 706c95d
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 82 deletions.
142 changes: 61 additions & 81 deletions zmtp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package zmtp

import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -113,7 +112,7 @@ func (c *Connection) sendGreeting(asServer bool) error {
}
toNullPaddedString(string(c.securityMechanism.Type()), greeting.Mechanism[:])

if err := binary.Write(c.rw, byteOrder, &greeting); err != nil {
if err := greeting.marshal(c.rw); err != nil {
return err
}

Expand All @@ -123,7 +122,7 @@ func (c *Connection) sendGreeting(asServer bool) error {
func (c *Connection) recvGreeting(asServer bool) error {
var greeting greeting

if err := binary.Read(c.rw, byteOrder, &greeting); err != nil {
if err := greeting.unmarshal(c.rw); err != nil {
return fmt.Errorf("Error while reading: %v", err)
}

Expand Down Expand Up @@ -179,10 +178,23 @@ func (c *Connection) sendMetadata(socketType SocketType, socketID SocketIdentity
}

func (c *Connection) writeMetadata(buffer *bytes.Buffer, name string, value string) {
buffer.WriteByte(byte(len(name)))
buffer.WriteString(name)
binary.Write(buffer, byteOrder, uint32(len(value)))
buffer.WriteString(value)
var (
p = 0
nameLen = len(name)
valueLen = len(value)
buf = make([]byte, 1+nameLen+4+valueLen)
)
buf[p] = byte(nameLen)
p++
p += copy(buf[p:], name)
byteOrder.PutUint32(buf[p:p+4], uint32(valueLen))
p += 4
copy(buf[p:], value)

_, err := buffer.Write(buf)
if err != nil {
panic(err)
}
}

func (c *Connection) recvMetadata() (map[string]string, error) {
Expand Down Expand Up @@ -220,10 +232,7 @@ func (c *Connection) recvMetadata() (map[string]string, error) {
i += keyLength

// Value length
var rawValueLength uint32
if err := binary.Read(bytes.NewBuffer(command.Body[i:i+4]), byteOrder, &rawValueLength); err != nil {
return nil, err
}
rawValueLength := byteOrder.Uint32(command.Body[i : i+4])

if uint64(rawValueLength) > uint64(maxInt) {
return nil, fmt.Errorf("Length of value %v overflows integer max length %v on this platform", rawValueLength, maxInt)
Expand Down Expand Up @@ -256,17 +265,19 @@ func (c *Connection) recvMetadata() (map[string]string, error) {

// SendCommand sends a ZMTP command over a Connection
func (c *Connection) SendCommand(commandName string, body []byte) error {
if len(commandName) > 255 {
cmdLen := len(commandName)
if cmdLen > 255 {
return errors.New("Command names may not be longer than 255 characters")
}

// Make the buffer of the correct length and reset it
buffer := new(bytes.Buffer)
buffer.WriteByte(byte(len(commandName)))
buffer.Write([]byte(commandName))
buffer.Write(body)
bodyLen := len(body)

return c.send(true, buffer.Bytes())
buf := make([]byte, 1+cmdLen+bodyLen) // FIXME(sbinet): maybe use a pool of []byte ?
buf[0] = byte(cmdLen)
copy(buf[1:], []byte(commandName))
copy(buf[1+cmdLen:], body)

return c.send(true, buf)
}

// SendFrame sends a ZMTP frame over a Connection
Expand Down Expand Up @@ -299,11 +310,13 @@ func (c *Connection) send(isCommand bool, body []byte) error {
}

if isLong {
if err := binary.Write(c.rw, byteOrder, int64(len(body))); err != nil {
var buf [8]byte
byteOrder.PutUint64(buf[:], uint64(len(body)))
if _, err := c.rw.Write(buf[:]); err != nil {
return err
}
} else {
if err := binary.Write(c.rw, byteOrder, uint8(len(body))); err != nil {
if _, err := c.rw.Write([]byte{uint8(len(body))}); err != nil {
return err
}
}
Expand Down Expand Up @@ -362,14 +375,9 @@ func (c *Connection) read() (bool, []byte, error) {
var longLength [8]byte

// Read out the header
readLength := uint64(0)
for readLength != 2 {
l, err := c.rw.Read(header[readLength:])
if err != nil {
return false, nil, err
}

readLength += uint64(l)
_, err := io.ReadFull(c.rw, header[:])
if err != nil {
return false, nil, err
}

bitFlags := header[0]
Expand All @@ -392,19 +400,12 @@ func (c *Connection) read() (bool, []byte, error) {
// We already have the first byte, so assign it, and then read the rest
longLength[0] = header[1]

readLength := 1
for readLength != 8 {
l, err := c.rw.Read(longLength[readLength:])
if err != nil {
return false, nil, err
}

readLength += l
}

if err := binary.Read(bytes.NewBuffer(longLength[:]), byteOrder, &bodyLength); err != nil {
_, err := io.ReadFull(c.rw, longLength[1:])
if err != nil {
return false, nil, err
}

bodyLength = byteOrder.Uint64(longLength[:])
} else {
// Short message length is just 1 byte, read it
bodyLength = uint64(header[1])
Expand All @@ -414,18 +415,12 @@ func (c *Connection) read() (bool, []byte, error) {
return false, nil, fmt.Errorf("Body length %v overflows max int64 value %v", bodyLength, maxInt64)
}

buffer := new(bytes.Buffer)
readLength = 0
for readLength < bodyLength {
l, err := buffer.ReadFrom(io.LimitReader(c.rw, int64(bodyLength)-int64(readLength)))
if err != nil {
return false, nil, err
}

readLength += uint64(l)
buf := make([]byte, bodyLength)
_, err = io.ReadFull(c.rw, buf)
if err != nil {
return false, nil, err
}

return isCommand, buffer.Bytes(), nil
return isCommand, buf, nil
}

func (c *Connection) parseCommand(body []byte) (*Command, error) {
Expand Down Expand Up @@ -482,11 +477,13 @@ func (c *Connection) sendMultipart(isCommand bool, bs [][]byte) error {
}

if isLong {
if err := binary.Write(c.rw, byteOrder, int64(len(part))); err != nil {
var buf [8]byte
byteOrder.PutUint64(buf[:], uint64(len(part)))
if _, err := c.rw.Write(buf[:]); err != nil {
return err
}
} else {
if err := binary.Write(c.rw, byteOrder, uint8(len(part))); err != nil {
if _, err := c.rw.Write([]byte{uint8(len(part))}); err != nil {
return err
}
}
Expand Down Expand Up @@ -551,14 +548,9 @@ func (c *Connection) readMultipart() (bool, [][]byte, error) {

for hasMore {
// Read out the header
readLength := uint64(0)
for readLength != 2 {
l, err := c.rw.Read(header[readLength:])
if err != nil {
return false, nil, err
}

readLength += uint64(l)
_, err := io.ReadFull(c.rw, header[:])
if err != nil {
return false, nil, err
}

bitFlags := header[0]
Expand All @@ -576,19 +568,12 @@ func (c *Connection) readMultipart() (bool, [][]byte, error) {
// We already have the first byte, so assign it, and then read the rest
longLength[0] = header[1]

readLength := 1
for readLength != 8 {
l, err := c.rw.Read(longLength[readLength:])
if err != nil {
return false, nil, err
}

readLength += l
}

if err := binary.Read(bytes.NewBuffer(longLength[:]), byteOrder, &bodyLength); err != nil {
_, err := io.ReadFull(c.rw, longLength[1:])
if err != nil {
return false, nil, err
}

bodyLength = byteOrder.Uint64(longLength[:])
} else {
// Short message length is just 1 byte, read it
bodyLength = uint64(header[1])
Expand All @@ -598,17 +583,12 @@ func (c *Connection) readMultipart() (bool, [][]byte, error) {
return false, nil, fmt.Errorf("Body length %v overflows max int64 value %v", bodyLength, maxInt64)
}

buffer := new(bytes.Buffer)
readLength = 0
for readLength < bodyLength {
l, err := buffer.ReadFrom(io.LimitReader(c.rw, int64(bodyLength)-int64(readLength)))
if err != nil {
return false, nil, err
}

readLength += uint64(l)
buf := make([]byte, bodyLength)
_, err = io.ReadFull(c.rw, buf)
if err != nil {
return false, nil, err
}
frames = append(frames, buffer.Bytes())
frames = append(frames, buf)
}

return isCommand, frames, nil
Expand Down
38 changes: 37 additions & 1 deletion zmtp/protocol.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package zmtp

import "encoding/binary"
import (
"encoding/binary"
"io"
)

const (
majorVersion uint8 = 3
Expand Down Expand Up @@ -58,6 +61,39 @@ type greeting struct {
_ [31]byte
}

func (g *greeting) unmarshal(r io.Reader) error {
var buf [64]byte
_, err := io.ReadFull(r, buf[:])
if err != nil {
return err
}
g.SignaturePrefix = buf[0]
// padding 1 ignored
g.SignatureSuffix = buf[9]
g.Version[0] = buf[10]
g.Version[1] = buf[11]
copy(g.Mechanism[:], buf[12:32])
g.ServerFlag = buf[32]
// padding 2 ignored

return nil
}

func (g *greeting) marshal(w io.Writer) error {
var buf [64]byte
buf[0] = g.SignaturePrefix
// padding 1 ignored
buf[9] = g.SignatureSuffix
buf[10] = g.Version[0]
buf[11] = g.Version[1]
copy(buf[12:32], g.Mechanism[:])
buf[32] = g.ServerFlag
// padding 2 ignored

_, err := w.Write(buf[:])
return err
}

// Command represents an underlying ZMTP command
type Command struct {
Index int
Expand Down

0 comments on commit 706c95d

Please sign in to comment.