Skip to content

Commit

Permalink
Will put message into buffer if client timeout
Browse files Browse the repository at this point in the history
This should help prevent client message lost when client conn
accidentally drop off.

Signed-off-by: Yilun <[email protected]>
  • Loading branch information
yilunzhang committed Mar 26, 2020
1 parent a473ad1 commit a7ee196
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 31 deletions.
3 changes: 0 additions & 3 deletions api/websocket/messagebuffer/messagebuffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ func NewMessageBuffer() *MessageBuffer {

// AddMessage adds a message to message buffer
func (messageBuffer *MessageBuffer) AddMessage(clientID []byte, msg *pb.Relay) {
if msg.MaxHoldingSeconds == 0 {
return
}
clientIDStr := hex.EncodeToString(clientID)
messageBuffer.Lock()
defer messageBuffer.Unlock()
Expand Down
47 changes: 47 additions & 0 deletions api/websocket/server/delayedchan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package server

import (
"time"
)

type DelayedChan struct {
buffer chan *delayedValue
delay time.Duration
}

type delayedValue struct {
value interface{}
releaseTime time.Time
}

func NewDelayedChan(size int, delay time.Duration) *DelayedChan {
buffer := make(chan *delayedValue, size)
return &DelayedChan{
buffer: buffer,
delay: delay,
}
}

func (dc *DelayedChan) Push(v interface{}) bool {
dv := &delayedValue{
value: v,
releaseTime: time.Now().Add(dc.delay),
}
select {
case dc.buffer <- dv:
return true
default:
return false
}
}

func (dc *DelayedChan) Pop() (interface{}, bool) {
dv, ok := <-dc.buffer
if !ok {
return nil, false
}
if dv.releaseTime.After(time.Now()) {
time.Sleep(time.Until(dv.releaseTime))
}
return dv.value, true
}
38 changes: 36 additions & 2 deletions api/websocket/server/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"time"

"github.com/gogo/protobuf/proto"
"github.com/nknorg/nkn/pb"
Expand Down Expand Up @@ -65,7 +66,7 @@ func (ws *WsServer) sendOutboundRelayMessage(srcAddrStrPtr *string, msg *pb.Outb
func (ws *WsServer) sendInboundMessage(clientID string, inboundMsg *pb.InboundMessage) bool {
clients := ws.SessionList.GetSessionsById(clientID)
if clients == nil {
log.Infof("Client Not Online: %s", clientID)
log.Debugf("Client Not Online: %s", clientID)
return false
}

Expand Down Expand Up @@ -124,11 +125,44 @@ func (ws *WsServer) sendInboundRelayMessage(relayMessage *pb.Relay) {
sigChainLen: int(relayMessage.SigChainLen),
})
}
} else {
if time.Duration(relayMessage.MaxHoldingSeconds) > pongTimeout/time.Second {
ok := ws.messageDeliveredCache.Push(relayMessage)
if !ok {
log.Warningf("MessageDeliveredCache full, discarding messages.")
}
}
} else if relayMessage.MaxHoldingSeconds > 0 {
ws.messageBuffer.AddMessage(clientID, relayMessage)
}
}

func (ws *WsServer) startCheckingLostMessages() {
for {
v, ok := ws.messageDeliveredCache.Pop()
if !ok {
break
}
if relayMessage, ok := v.(*pb.Relay); ok {
clientID := relayMessage.DestId
clients := ws.SessionList.GetSessionsById(hex.EncodeToString(clientID))
if len(clients) > 0 {
threshold := time.Now().Add(-pongTimeout)
success := false
for _, client := range clients {
if client.GetLastReadTime().After(threshold) {
success = true
break
}
}
if success {
continue
}
}
ws.messageBuffer.AddMessage(clientID, relayMessage)
}
}
}

func (ws *WsServer) handleReceipt(receipt *pb.Receipt) error {
v, ok := ws.sigChainCache.Get(receipt.PrevSignature)
if !ok {
Expand Down
45 changes: 26 additions & 19 deletions api/websocket/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ const (
pingInterval = 8 * time.Second
pongTimeout = 10 * time.Second // should be greater than pingInterval
maxMessageSize = config.MaxClientMessageSize
messageDeliveredCacheSize = 65536
)

type Handler struct {
Expand All @@ -50,29 +51,31 @@ type Handler struct {

type WsServer struct {
sync.RWMutex
Upgrader websocket.Upgrader
listener net.Listener
tlsListener net.Listener
server *http.Server
tlsServer *http.Server
SessionList *session.SessionList
ActionMap map[string]Handler
TxHashMap map[string]string //key: txHash value:sessionid
localNode *node.LocalNode
wallet vault.Wallet
messageBuffer *messagebuffer.MessageBuffer
sigChainCache Cache
Upgrader websocket.Upgrader
listener net.Listener
tlsListener net.Listener
server *http.Server
tlsServer *http.Server
SessionList *session.SessionList
ActionMap map[string]Handler
TxHashMap map[string]string //key: txHash value:sessionid
localNode *node.LocalNode
wallet vault.Wallet
messageBuffer *messagebuffer.MessageBuffer
messageDeliveredCache *DelayedChan
sigChainCache Cache
}

func InitWsServer(localNode *node.LocalNode, wallet vault.Wallet) *WsServer {
ws := &WsServer{
Upgrader: websocket.Upgrader{},
SessionList: session.NewSessionList(),
TxHashMap: make(map[string]string),
localNode: localNode,
wallet: wallet,
messageBuffer: messagebuffer.NewMessageBuffer(),
sigChainCache: NewGoCache(sigChainCacheExpiration, sigChainCacheCleanupInterval),
Upgrader: websocket.Upgrader{},
SessionList: session.NewSessionList(),
TxHashMap: make(map[string]string),
localNode: localNode,
wallet: wallet,
messageBuffer: messagebuffer.NewMessageBuffer(),
messageDeliveredCache: NewDelayedChan(messageDeliveredCacheSize, pongTimeout),
sigChainCache: NewGoCache(sigChainCacheExpiration, sigChainCacheCleanupInterval),
}
return ws
}
Expand Down Expand Up @@ -108,6 +111,8 @@ func (ws *WsServer) Start() error {
ws.tlsServer = &http.Server{Handler: http.HandlerFunc(ws.websocketHandler)}
go ws.tlsServer.Serve(ws.tlsListener)

go ws.startCheckingLostMessages()

return nil
}

Expand Down Expand Up @@ -259,6 +264,7 @@ func (ws *WsServer) websocketHandler(w http.ResponseWriter, r *http.Request) {
wsConn.SetReadDeadline(time.Now().Add(pongTimeout))
wsConn.SetPongHandler(func(string) error {
wsConn.SetReadDeadline(time.Now().Add(pongTimeout))
sess.UpdateLastReadTime()
return nil
})

Expand Down Expand Up @@ -289,6 +295,7 @@ func (ws *WsServer) websocketHandler(w http.ResponseWriter, r *http.Request) {
}

wsConn.SetReadDeadline(time.Now().Add(pongTimeout))
sess.UpdateLastReadTime()

err = ws.OnDataHandle(sess, messageType, bysMsg, r)
if err != nil {
Expand Down
29 changes: 22 additions & 7 deletions api/websocket/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ const (
)

type Session struct {
ws *websocket.Conn
sessionID string

sync.RWMutex
sessionID string
clientChordID []byte
clientPubKey []byte
clientAddrStr *string
isTlsClient bool
lastReadTime time.Time

wsLock sync.Mutex
ws *websocket.Conn
}

func (s *Session) GetSessionId() string {
Expand All @@ -31,8 +33,9 @@ func (s *Session) GetSessionId() string {
func newSession(wsConn *websocket.Conn) (session *Session, err error) {
sessionID := uuid.NewUUID().String()
session = &Session{
ws: wsConn,
sessionID: sessionID,
ws: wsConn,
sessionID: sessionID,
lastReadTime: time.Now(),
}
return session, nil
}
Expand All @@ -48,8 +51,8 @@ func (s *Session) close() {
}

func (s *Session) Send(msgType int, data []byte) error {
s.RLock()
defer s.RUnlock()
s.wsLock.Lock()
defer s.wsLock.Unlock()
if s.ws == nil {
return errors.New("Websocket is null")
}
Expand Down Expand Up @@ -126,3 +129,15 @@ func (s *Session) IsTlsClient() bool {
defer s.RUnlock()
return s.isTlsClient
}

func (s *Session) GetLastReadTime() time.Time {
s.RLock()
defer s.RUnlock()
return s.lastReadTime
}

func (s *Session) UpdateLastReadTime() {
s.Lock()
s.lastReadTime = time.Now()
s.Unlock()
}

0 comments on commit a7ee196

Please sign in to comment.