diff --git a/peer.go b/peer.go index aa7c271..b044caf 100644 --- a/peer.go +++ b/peer.go @@ -40,6 +40,9 @@ type PeerOptions struct { BroadcasterStore kv.Table // Defaults to using in memory store SignVerifier protocol.SignVerifier // Defaults to nil Runners []Runner // Defaults to nil + + ServerOptions tcp.ServerOptions + ClientOptions tcp.ClientOptions } type Peer interface { @@ -244,11 +247,6 @@ func newErrInvalidPeerOptions(err error) error { } func NewTCPPeer(options PeerOptions, events EventSender, cap, port int) Peer { - var handshaker handshake.Handshaker - if options.SignVerifier != nil { - handshaker = handshake.New(options.SignVerifier) - } - if options.DHTStore == nil { options.DHTStore = kv.NewTable(kv.NewMemDB(kv.GobCodec), "dht") } @@ -257,6 +255,26 @@ func NewTCPPeer(options PeerOptions, events EventSender, cap, port int) Peer { options.BroadcasterStore = kv.NewTable(kv.NewMemDB(kv.GobCodec), "broadcaster") } + if options.ServerOptions.Logger == nil { + options.ServerOptions.Logger = options.Logger + } + + if options.ServerOptions.Port == 0 { + options.ServerOptions.Port = port + } + + if options.ServerOptions.Handshaker == nil && options.SignVerifier != nil { + options.ServerOptions.Handshaker = handshake.New(options.SignVerifier) + } + + if options.ClientOptions.Logger == nil { + options.ClientOptions.Logger = options.Logger + } + + if options.ClientOptions.Handshaker == nil && options.SignVerifier != nil { + options.ClientOptions.Handshaker = handshake.New(options.SignVerifier) + } + serverMessages := make(chan protocol.MessageOnTheWire, cap) clientMessages := make(chan protocol.MessageOnTheWire, cap) @@ -266,18 +284,8 @@ func NewTCPPeer(options PeerOptions, events EventSender, cap, port int) Peer { } options.Runners = append(options.Runners, - tcp.NewServer(tcp.ServerOptions{ - Logger: options.Logger, - Timeout: time.Minute, - Handshaker: handshaker, - Port: port, - }, serverMessages), - tcp.NewClient(tcp.NewClientConns(tcp.ClientOptions{ - Logger: options.Logger, - Timeout: 10 * time.Second, - Handshaker: handshaker, - MaxConnections: 200, - }), dht, clientMessages), + tcp.NewServer(options.ServerOptions, serverMessages), + tcp.NewClient(tcp.NewClientConns(options.ClientOptions), dht, clientMessages), ) return New( diff --git a/peer_test.go b/peer_test.go index f70a86a..21e1045 100644 --- a/peer_test.go +++ b/peer_test.go @@ -2,6 +2,7 @@ package aw_test import ( "context" + "encoding/base64" "fmt" "time" @@ -10,6 +11,7 @@ import ( . "github.com/renproject/aw" "github.com/renproject/aw/protocol" + "github.com/renproject/aw/tcp" "github.com/renproject/aw/testutil" "github.com/renproject/phi/co" "github.com/sirupsen/logrus" @@ -130,7 +132,7 @@ var _ = Describe("airwaves peer", func() { }) Context("when updating peer address", func() { - FIt("should be able to send messages to the new address", func() { + It("should be able to send messages to the new address", func() { logger := logrus.StandardLogger() peer1Events := make(chan protocol.Event, 65535) @@ -144,30 +146,39 @@ var _ = Describe("airwaves peer", func() { updatedPeer2Address.Nonce = 1 peer1 := NewTCPPeer(PeerOptions{ - Logger: logger, + Logger: logger.WithField("peer", 1), Me: peer1Address, BootstrapAddresses: PeerAddresses{peer2Address}, Codec: codec, BootstrapDuration: 3 * time.Second, + ClientOptions: tcp.ClientOptions{ + MaxRetries: 60, + }, }, peer1Events, 65535, 8080) peer2 := NewTCPPeer(PeerOptions{ - Logger: logger, + Logger: logger.WithField("peer", 2), Me: peer2Address, BootstrapAddresses: PeerAddresses{peer1Address}, Codec: codec, BootstrapDuration: 3 * time.Second, + ClientOptions: tcp.ClientOptions{ + MaxRetries: 60, + }, }, peer2Events, 65535, 8081) updatedPeer2 := NewTCPPeer(PeerOptions{ - Logger: logger, + Logger: logger.WithField("updated peer", 2), Me: updatedPeer2Address, BootstrapAddresses: PeerAddresses{peer1Address}, Codec: codec, BootstrapDuration: 3 * time.Second, + ClientOptions: tcp.ClientOptions{ + MaxRetries: 60, + }, }, updatedPeer2Events, 65535, 8082) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -176,14 +187,16 @@ var _ = Describe("airwaves peer", func() { cancel() }() + newCtx, newCancel := context.WithCancel(context.Background()) + co.ParBegin( func() { - peer1.Run(context.Background()) + peer1.Run(newCtx) }, func() { peer2.Run(ctx) fmt.Println("peer 2 restarted") - updatedPeer2.Run(context.Background()) + updatedPeer2.Run(newCtx) }, func() { <-ctx.Done() @@ -192,19 +205,29 @@ var _ = Describe("airwaves peer", func() { <-ctx2.Done() cancel() }() - if err := peer1.Cast(ctx2, testutil.SimplePeerID("peer_2"), []byte("hello")); err != nil { + // After peer 2 receives this message, its server should + // shut down as the context has expired. + if err := peer1.Cast(ctx2, peer2Address.PeerID(), []byte("hello")); err != nil { + panic(err) + } + time.Sleep(time.Second) + if err := peer1.Cast(ctx2, peer2Address.PeerID(), []byte("hello")); err != nil { panic(err) } }, func() { - for event := range peer1Events { - fmt.Println(event) - } + event := <-peer1Events + _, ok := event.(protocol.EventPeerChanged) + Expect(ok).To(BeTrue()) }, func() { - for event := range updatedPeer2Events { - fmt.Println(event) - } + event := <-updatedPeer2Events + msg, ok := event.(protocol.EventMessageReceived) + Expect(ok).To(BeTrue()) + msgBytes, err := base64.StdEncoding.DecodeString(msg.Message.String()) + Expect(err).ToNot(HaveOccurred()) + Expect(msgBytes).To(Equal([]byte("hello"))) + newCancel() }, ) }) diff --git a/tcp/client.go b/tcp/client.go index aa7af0d..48bc00d 100644 --- a/tcp/client.go +++ b/tcp/client.go @@ -24,6 +24,8 @@ type ClientOptions struct { Timeout time.Duration // MaxConnections to remote servers that the Client will maintain. MaxConnections int + // MaxRetries if the message cannot be sent. + MaxRetries int // Handshaker handles the handshake process between peers. Default: no handshake Handshaker handshake.Handshaker } @@ -60,7 +62,7 @@ func NewClientConns(options ClientOptions) *ClientConns { if options.Timeout == 0 { options.Timeout = 10 * time.Second } - if options.MaxConnections == 256 { + if options.MaxConnections == 0 { options.MaxConnections = 256 } @@ -87,7 +89,11 @@ func (clientConns *ClientConns) Write(ctx context.Context, addr net.Addr, messag clientConns.connsMu.RLock() conn := clientConns.conns[addr.String()] clientConns.connsMu.RUnlock() + + clientConns.connsMu.RLock() if conn != nil && conn.conn != nil { + clientConns.connsMu.RUnlock() + // Mutex on the conn conn.mu.Lock() defer conn.mu.Unlock() @@ -101,6 +107,7 @@ func (clientConns *ClientConns) Write(ctx context.Context, addr net.Addr, messag } return nil } + clientConns.connsMu.RUnlock() // Protect the cache from concurrent writes and establish a connection that // can be dialed @@ -128,7 +135,11 @@ func (clientConns *ClientConns) Write(ctx context.Context, addr net.Addr, messag if err != nil { return err } + + clientConns.connsMu.RLock() if conn.conn != nil { + clientConns.connsMu.RUnlock() + // Mutex on the conn conn.mu.Lock() defer conn.mu.Unlock() @@ -142,10 +153,14 @@ func (clientConns *ClientConns) Write(ctx context.Context, addr net.Addr, messag } return nil } + clientConns.connsMu.RUnlock() // Double-check the connection, because while waiting to acquire the write lock // another goroutine may have already dialed the remote server + clientConns.connsMu.RLock() if conn.conn != nil { + clientConns.connsMu.RUnlock() + // Mutex on the conn conn.mu.Lock() defer conn.mu.Unlock() @@ -159,6 +174,7 @@ func (clientConns *ClientConns) Write(ctx context.Context, addr net.Addr, messag } return nil } + clientConns.connsMu.RUnlock() // A new connection needs to be dialed, so we lock the connection to prevent // multiple dials against the same remote server @@ -166,7 +182,9 @@ func (clientConns *ClientConns) Write(ctx context.Context, addr net.Addr, messag defer conn.mu.Unlock() // Dial + clientConns.connsMu.Lock() conn.conn, err = net.DialTimeout("tcp", addr.String(), clientConns.options.Timeout) + clientConns.connsMu.Unlock() if err != nil { return err } @@ -297,7 +315,7 @@ func (client *Client) sendMessageOnTheWire(ctx context.Context, to net.Addr, mes go func() { begin := time.Now() delay := time.Duration(1000) - for i := 0; i < 60; i++ { + for i := 0; i < client.conns.options.MaxRetries; i++ { // Dial client.conns.options.Logger.Warnf("retrying write to tcp connection to %v with delay of %.4f second(s)", to.String(), time.Now().Sub(begin).Seconds()) err := client.conns.Write(ctx, to, message) diff --git a/tcp/server.go b/tcp/server.go index a404644..a9e94e4 100644 --- a/tcp/server.go +++ b/tcp/server.go @@ -28,6 +28,13 @@ type Server struct { } func NewServer(options ServerOptions, messages protocol.MessageSender) *Server { + if options.Logger == nil { + panic("pre-condition violation: logger is nil") + } + if options.Timeout == 0 { + options.Timeout = time.Minute + } + return &Server{ options: options, messages: messages, diff --git a/testutil/addr.go b/testutil/addr.go index 275c43f..923877c 100644 --- a/testutil/addr.go +++ b/testutil/addr.go @@ -82,7 +82,7 @@ func (address SimpleTCPPeerAddress) IsNewer(peerAddress protocol.PeerAddress) bo if !ok { return false } - return peerAddr.Nonce > address.Nonce + return address.Nonce > peerAddr.Nonce } func Remove(addrs protocol.PeerAddresses, i int) protocol.PeerAddresses {