diff --git a/aw.go b/aw.go index ec4f96d..76446bc 100644 --- a/aw.go +++ b/aw.go @@ -15,6 +15,10 @@ import ( "go.uber.org/zap" ) +type ( + ContentResolver = dht.ContentResolver +) + type Builder struct { opts Options @@ -75,10 +79,12 @@ func (builder *Builder) WithAddr(addr wire.Address) *Builder { } return builder } + func (builder *Builder) WithHost(host string) *Builder { builder.trans.TCPServerOpts = builder.trans.TCPServerOpts.WithHost(host) return builder } + func (builder *Builder) WithPort(port uint16) *Builder { builder.trans.TCPServerOpts = builder.trans.TCPServerOpts.WithPort(port) return builder @@ -155,6 +161,10 @@ func (node *Node) Gossiper() *gossip.Gossiper { return node.gossiper } +func (node *Node) Sync(ctx context.Context, subnet, hash id.Hash, dataType uint8) ([]byte, error) { + return node.gossiper.Sync(ctx, subnet, hash, dataType) +} + func (node *Node) Identity() id.Signatory { return node.peer.Identity() } diff --git a/aw_test.go b/aw_test.go index 6a2f7c7..dbd72b6 100644 --- a/aw_test.go +++ b/aw_test.go @@ -1,7 +1,9 @@ package aw_test import ( + "bytes" "context" + "crypto/sha256" "fmt" "math/rand" "sync/atomic" @@ -9,6 +11,7 @@ import ( "github.com/renproject/aw" "github.com/renproject/aw/dht" + "github.com/renproject/aw/gossip" "github.com/renproject/aw/wire" "github.com/renproject/id" @@ -32,15 +35,23 @@ var _ = Describe("Airwave", func() { defer cancel() port1 := uint16(3000 + r.Int()%3000) + addr1 := wire.NewUnsignedAddress(wire.TCP, fmt.Sprintf("0.0.0.0:%v", port1), uint64(time.Now().UnixNano())) + privKey1 := id.NewPrivKey() + Expect(addr1.Sign(privKey1)).To(Succeed()) node1 := aw.New(). - WithAddr(wire.NewUnsignedAddress(wire.TCP, fmt.Sprintf("0.0.0.0:%v", port1), uint64(time.Now().UnixNano()))). + WithPrivKey(privKey1). + WithAddr(addr1). WithHost("0.0.0.0"). WithPort(port1). Build() port2 := uint16(3000 + r.Int()%3000) + addr2 := wire.NewUnsignedAddress(wire.TCP, fmt.Sprintf("0.0.0.0:%v", port2), uint64(time.Now().UnixNano())) + privKey2 := id.NewPrivKey() + Expect(addr2.Sign(privKey2)).To(Succeed()) node2 := aw.New(). - WithAddr(wire.NewUnsignedAddress(wire.TCP, fmt.Sprintf("0.0.0.0:%v", port2), uint64(time.Now().UnixNano()))). + WithPrivKey(privKey2). + WithAddr(addr2). WithHost("0.0.0.0"). WithPort(port2). WithContentResolver( @@ -93,4 +104,74 @@ var _ = Describe("Airwave", func() { }) }) }) + + Context("when gossiping", func() { + Context("when fully connected", func() { + It("should return content from all nodes", func() { + defer time.Sleep(time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Initialise nodes. + n := 3 + nodes := make([]*aw.Node, n) + addrs := make([]wire.Address, n) + for i := range nodes { + port := uint16(3000 + i) + addrs[i] = wire.NewUnsignedAddress(wire.TCP, fmt.Sprintf("0.0.0.0:%v", port), uint64(time.Now().UnixNano())) + privKey := id.NewPrivKey() + Expect(addrs[i].Sign(privKey)).To(Succeed()) + node := aw.New(). + WithPrivKey(privKey). + WithAddr(addrs[i]). + WithHost("0.0.0.0"). + WithPort(port). + Build() + nodes[i] = node + } + + // Connect nodes in a fully connected cyclic graph. + for i := range nodes { + for j := range nodes { + if i == j { + continue + } + nodes[i].DHT().InsertAddr(addrs[j]) + } + } + + // Run the nodes. + for i := range nodes { + go nodes[i].Run(ctx) + } + + // Sleep for enough time for nodes to find each other by pinging + // each other. + time.Sleep(100 * time.Millisecond) + + contentHash := sha256.Sum256([]byte("hello!")) + contentType := uint8(1) + content := []byte("hello!") + nodes[0].Broadcast(ctx, gossip.DefaultSubnet, contentType, content) + + found := map[id.Signatory]struct{}{} + for { + time.Sleep(time.Millisecond) + for i := range nodes { + data, ok := nodes[i].DHT().Content(contentHash, contentType) + if !ok { + continue + } + if bytes.Equal(content, data) { + found[nodes[i].Identity()] = struct{}{} + } + } + if len(found) == n { + return + } + } + }) + }) + }) }) diff --git a/handshake/ecdsa.go b/handshake/ecdsa.go index bb4875d..87efe7f 100644 --- a/handshake/ecdsa.go +++ b/handshake/ecdsa.go @@ -95,8 +95,7 @@ func (handshaker *ecdsaHandshaker) AcceptHandshake(ctx context.Context, conn net // // 2 // - err = writePubKeyWithSignature(conn, &handshaker.opts.PrivKey.PublicKey, handshaker.opts.PrivKey) - if err != nil { + if err := writePubKeyWithSignature(conn, &handshaker.opts.PrivKey.PublicKey, handshaker.opts.PrivKey); err != nil { return nil, fmt.Errorf("writing server pubkey with signature: %v", err) } diff --git a/tcp/client.go b/tcp/client.go index 093bfcf..e27fa1f 100644 --- a/tcp/client.go +++ b/tcp/client.go @@ -4,13 +4,13 @@ import ( "bufio" "context" "fmt" + "math" "net" "sync" "time" "github.com/renproject/aw/handshake" "github.com/renproject/aw/wire" - "github.com/renproject/surge" "go.uber.org/zap" ) @@ -23,6 +23,8 @@ var ( // takes longer than this duration, it will be dropped (and a new attempt // may begin). DefaultClientTimeToDial = 15 * time.Second + // DefaultClientTimeToDialBackoff is set to 1.2. + DefaultClientTimeToDialBackoff = 1.2 // DefaultClientMaxDialAttempts is set to 5. If the first 5 dial attempts // fail, the connection will be dropped. DefaultClientMaxDialAttempts = 5 @@ -49,6 +51,9 @@ type ClientOptions struct { // be started, assuming that the client has not been attempting dials for // longer than the TimeToLive. TimeToDial time.Duration + // TimeToDialBackoff when establishing new connections. This is the + // multiplier added to the TimeToDial variable in the case an attempt fails. + TimeToDialBackoff float64 // MaxDialAttempts when establishing new connections. MaxDialAttempts int // MaxCapacity of messages that can be bufferred while waiting to write @@ -66,12 +71,13 @@ func DefaultClientOptions() ClientOptions { panic(err) } return ClientOptions{ - Logger: logger, - TimeToLive: DefaultClientTimeToLive, - TimeToDial: DefaultClientTimeToDial, - MaxDialAttempts: DefaultClientMaxDialAttempts, - MaxCapacity: DefaultClientMaxCapacity, - MaxConnections: DefaultClientMaxConnections, + Logger: logger, + TimeToLive: DefaultClientTimeToLive, + TimeToDial: DefaultClientTimeToDial, + TimeToDialBackoff: DefaultClientTimeToDialBackoff, + MaxDialAttempts: DefaultClientMaxDialAttempts, + MaxCapacity: DefaultClientMaxCapacity, + MaxConnections: DefaultClientMaxConnections, } } @@ -85,6 +91,16 @@ func (opts ClientOptions) WithTimeToLive(ttl time.Duration) ClientOptions { return opts } +func (opts ClientOptions) WithTimeToDial(ttd time.Duration) ClientOptions { + opts.TimeToDial = ttd + return opts +} + +func (opts ClientOptions) WithTimeToDialBackoff(backoff float64) ClientOptions { + opts.TimeToDialBackoff = backoff + return opts +} + func (opts ClientOptions) WithMaxDialAttempts(dialAttempts int) ClientOptions { opts.MaxDialAttempts = dialAttempts return opts @@ -275,18 +291,17 @@ func (client *Client) run(ctx context.Context, deadline *time.Timer, addr string // Run the message loop until the context is cancelled, or the deadline is // exceeded. rw := bufio.NewReadWriter( - bufio.NewReaderSize(conn, surge.MaxBytes), - bufio.NewWriterSize(conn, surge.MaxBytes), + bufio.NewReaderSize(conn, 1024), + bufio.NewWriterSize(conn, 1024), ) - buf := make([]byte, surge.MaxBytes) if lastMessage != nil { // Failing to write a message should not result in the // connection/channel being killed, and should not result in all pending // messages being lost. Therefore, we consume the error, and return a // nil-error. client.opts.Logger.Info("resending last message sent") - if err := client.write(session, rw, *lastMessage, buf); err != nil { + if err := client.write(session, rw, *lastMessage); err != nil { client.opts.Logger.Error("writing", zap.Error(err)) return lastMessage, nil } @@ -331,7 +346,7 @@ func (client *Client) run(ctx context.Context, deadline *time.Timer, addr string // connection/channel being killed, and should not result in all pending // messages being lost. Therefore, we consume the error, and return a // nil-error. - if err := client.write(session, rw, msg, buf); err != nil { + if err := client.write(session, rw, msg); err != nil { client.opts.Logger.Error("writing", zap.Error(err)) return &msg, nil } @@ -339,7 +354,7 @@ func (client *Client) run(ctx context.Context, deadline *time.Timer, addr string } } -func (client *Client) write(session handshake.Session, rw *bufio.ReadWriter, msg wire.Message, buf []byte) error { +func (client *Client) write(session handshake.Session, rw *bufio.ReadWriter, msg wire.Message) error { var err error // @@ -352,11 +367,7 @@ func (client *Client) write(session handshake.Session, rw *bufio.ReadWriter, msg return fmt.Errorf("encrypting message: %v", err) } // Write the encrypted message to the connection. - msgBytes, err := surge.ToBinary(msg) - if err != nil { - return fmt.Errorf("marshaling message: %v", err) - } - if _, err := rw.Write(msgBytes); err != nil { + if err := msg.Write(rw); err != nil { return fmt.Errorf("writing message: %v", err) } if err := rw.Flush(); err != nil { @@ -368,14 +379,10 @@ func (client *Client) write(session handshake.Session, rw *bufio.ReadWriter, msg // // Read an encrypted response from the connection. - if _, err := rw.Read(buf); err != nil { - return fmt.Errorf("reading response: %v", err) - } response := wire.Message{} - if err := surge.FromBinary(&response, buf); err != nil { + if err := response.Read(rw); err != nil { return fmt.Errorf("unmarshaling response: %v", err) } - buf = buf[:0] // Check that the response version is supported. switch response.Version { case wire.V1: @@ -433,15 +440,14 @@ func (client *Client) dial(ctx context.Context, addr string) (net.Conn, handshak dialCtx, dialCancel := context.WithTimeout(ctx, client.opts.TimeToLive) defer dialCancel() - // Remember the number of dial attempts for this connection. If this - // exceeds MaxDialAttempts, the connection will be dropped. + // Remember the number of dial attempts for this connection. If this exceeds + // MaxDialAttempts, the connection will be dropped. numDialAttempts := 0 for { select { case <-dialCtx.Done(): - // dial must only return an error if it cannot successful - // dial within the TimeToLive duration. Othewise, it will keep - // retrying. + // dial must only return an error if it cannot successful dial + // within the TimeToLive duration. Othewise, it will keep retrying. return nil, nil, dialCtx.Err() default: } @@ -449,43 +455,46 @@ func (client *Client) dial(ctx context.Context, addr string) (net.Conn, handshak // Increment the number of dial attempts. numDialAttempts++ - // Attempt to dial a new connection for TimeToDial seconds. If we are - // not successful, then wait until the end of the TimeToDial second - // timeout and try again. - innerDialCtx, innerDialCancel := context.WithTimeout(dialCtx, client.opts.TimeToDial) - conn, err := new(net.Dialer).DialContext(innerDialCtx, "tcp", addr) + // If the number of dial attempts has exceeded the maximum, return an + // error. + if numDialAttempts > client.opts.MaxDialAttempts { + return nil, nil, fmt.Errorf("dialing: exceeded max dial attempts") + } + + // Attempt to dial a new connection for TimeToDial duration. If we are + // not successful, then wait until the end of the TimeToDial timeout and + // try again. + backoff := math.Pow(client.opts.TimeToDialBackoff, float64(numDialAttempts)) + innerDialCtx, innerDialCancel := context.WithTimeout(dialCtx, client.opts.TimeToDial*time.Duration(backoff)) + conn, session, err := func() (net.Conn, handshake.Session, error) { + conn, err := new(net.Dialer).DialContext(innerDialCtx, "tcp", addr) + if err != nil { + return nil, nil, fmt.Errorf("dialing: %v", err) + } + + // Handshake with the server to establish authentication and + // encryption on the connection. + session, err := client.handshaker.Handshake(dialCtx, conn) + if err != nil { + return conn, nil, fmt.Errorf("handshaking: %v", err) + } + if session == nil { + return conn, nil, fmt.Errorf("handshaking: nil") + } + return conn, session, nil + }() if err != nil { + client.opts.Logger.Error(err.Error()) + // Make sure to wait until the entire TimeToDial seconds has passed, // otherwise we might attempt to re-dial too quickly. <-innerDialCtx.Done() innerDialCancel() - // If the number of dial attempts has exceeded the maximum, return - // an error. - if numDialAttempts >= client.opts.MaxDialAttempts { - return nil, nil, fmt.Errorf("dialing: exceeded max dial attempts: %v", err) - } - - client.opts.Logger.Warn("dialing", zap.Error(err)) - continue - } - - // Handshake with the server to establish authentication and encryption - // on the connection. - session, err := client.handshaker.Handshake(dialCtx, conn) - if err != nil { - innerDialCancel() - client.opts.Logger.Error("handshaking", zap.Error(err)) - if err := conn.Close(); err != nil { - client.opts.Logger.Error("closing connection", zap.Error(err)) - } - continue - } - if session == nil { - innerDialCancel() - client.opts.Logger.Error("handshaking: nil") - if err := conn.Close(); err != nil { - client.opts.Logger.Error("closing connection", zap.Error(err)) + if conn != nil { + if err := conn.Close(); err != nil { + client.opts.Logger.Error("closing connection", zap.Error(err)) + } } continue } diff --git a/tcp/server.go b/tcp/server.go index a5675fb..afccbe3 100644 --- a/tcp/server.go +++ b/tcp/server.go @@ -12,7 +12,6 @@ import ( "github.com/renproject/aw/handshake" "github.com/renproject/aw/wire" "github.com/renproject/id" - "github.com/renproject/surge" "go.uber.org/zap" "golang.org/x/time/rate" ) @@ -250,9 +249,8 @@ func (server *Server) handle(ctx context.Context, conn net.Conn) { // Read messages from the client until the time-to-live expires, or an error // is encountered when trying to read. server.opts.Logger.Info("handling", zap.String("remote", remoteAddr)) - buf := make([]byte, surge.MaxBytes) - bufReader := bufio.NewReaderSize(conn, surge.MaxBytes) - bufWriter := bufio.NewWriterSize(conn, surge.MaxBytes) + bufReader := bufio.NewReaderSize(conn, 1024) + bufWriter := bufio.NewWriterSize(conn, 1024) for { // We have "time-to-live" amount of time to read a message and write a // response to the message. @@ -262,12 +260,8 @@ func (server *Server) handle(ctx context.Context, conn net.Conn) { } // Read message from connection. - if _, err := bufReader.Read(buf); err != nil { - server.opts.Logger.Error("bad message", zap.Error(err)) - return - } msg := wire.Message{} - if err := surge.FromBinary(&msg, buf); err != nil { + if err := msg.Read(bufReader); err != nil { server.opts.Logger.Info("closing connection", zap.String("remote", conn.RemoteAddr().String())) return } @@ -331,15 +325,10 @@ func (server *Server) handle(ctx context.Context, conn net.Conn) { // sending it. response.Data, err = session.Encrypt(response.Data) if err != nil { - server.opts.Logger.Error("bad response: %v", zap.Error(err)) - return - } - responseBytes, err := surge.ToBinary(response) - if err != nil { - server.opts.Logger.Info("closing connection", zap.String("remote", conn.RemoteAddr().String())) + server.opts.Logger.Error("bad response", zap.Error(err)) return } - if _, err := bufWriter.Write(responseBytes); err != nil { + if err := response.Write(bufWriter); err != nil { server.opts.Logger.Error("bad response", zap.Error(err)) return } diff --git a/wire/msg.go b/wire/msg.go index abaa6ba..bc65475 100644 --- a/wire/msg.go +++ b/wire/msg.go @@ -2,6 +2,8 @@ package wire import ( "bytes" + "encoding/binary" + "fmt" "io" "github.com/renproject/id" @@ -51,6 +53,9 @@ const ( PullAck = Type(7) ) +// Data is an alias of a byte slice. +type Data = []byte + // A Message defines all of the information needed to gossip information on the // wire. type Message struct { @@ -59,7 +64,7 @@ type Message struct { // Version. Version Version `json:"version"` Type Type `json:"type"` - Data []byte `json:"data"` + Data Data `json:"data"` } // Equal compares one Message to another. It returns true if they are equal, @@ -68,30 +73,43 @@ func (msg Message) Equal(other *Message) bool { return msg.Version == other.Version && msg.Type == other.Type && bytes.Equal(msg.Data, other.Data) } -// Marshal this Message into binary. -func (msg Message) Marshal(buf []byte, rem int) ([]byte, int, error) { - var err error - if buf, rem, err = surge.MarshalU8(uint8(msg.Version), buf, rem); err != nil { - return buf, rem, err +// Write the message into an I/O writer. +func (msg Message) Write(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, uint8(msg.Version)); err != nil { + return err } - if buf, rem, err = surge.MarshalU8(uint8(msg.Type), buf, rem); err != nil { - return buf, rem, err + if err := binary.Write(w, binary.BigEndian, uint8(msg.Type)); err != nil { + return err } - return surge.MarshalBytes(msg.Data, buf, rem) + if err := binary.Write(w, binary.BigEndian, uint32(len(msg.Data))); err != nil { + return err + } + if err := binary.Write(w, binary.BigEndian, msg.Data); err != nil { + return err + } + return nil } -// Unmarshal from binary into this Message. -func (msg *Message) Unmarshal(buf []byte, rem int) ([]byte, int, error) { - var err error - buf, rem, err = surge.UnmarshalU8((*uint8)(&msg.Version), buf, rem) - if err != nil { - return buf, rem, err +// Read the message from an I/O writer. +func (msg *Message) Read(r io.Reader) error { + if err := binary.Read(r, binary.BigEndian, (*uint8)(&msg.Version)); err != nil { + return err } - buf, rem, err = surge.UnmarshalU8((*uint8)(&msg.Type), buf, rem) - if err != nil { - return buf, rem, err + if err := binary.Read(r, binary.BigEndian, (*uint8)(&msg.Type)); err != nil { + return err + } + dataLen := uint32(0) + if err := binary.Read(r, binary.BigEndian, &dataLen); err != nil { + return err + } + if dataLen > uint32(surge.MaxBytes) { + return fmt.Errorf("message length exceeds max bytes") + } + msg.Data = make([]byte, dataLen) + if err := binary.Read(r, binary.BigEndian, &msg.Data); err != nil { + return err } - return surge.UnmarshalBytes(&msg.Data, buf, rem) + return nil } type PingV1 struct {