Skip to content

Commit

Permalink
Use periodical ping/pong message to detect client timeout rapidly
Browse files Browse the repository at this point in the history
Signed-off-by: Yilun <[email protected]>
  • Loading branch information
yilunzhang committed Mar 26, 2020
1 parent 6db4962 commit 9ac08a7
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 86 deletions.
117 changes: 58 additions & 59 deletions api/websocket/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/tls"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net"
Expand Down Expand Up @@ -34,9 +35,12 @@ import (
)

const (
TlsPort uint16 = 443
sigChainCacheExpiration = config.ConsensusTimeout
sigChainCacheCleanupInterval = time.Second
TlsPort = 443
sigChainCacheExpiration = config.ConsensusTimeout
sigChainCacheCleanupInterval = time.Second
pingInterval = 8 * time.Second
pongTimeout = 10 * time.Second // should be greater than pingInterval
maxMessageSize = config.MaxClientMessageSize
)

type Handler struct {
Expand Down Expand Up @@ -98,8 +102,6 @@ func (ws *WsServer) Start() error {

event.Queue.Subscribe(event.SendInboundMessageToClient, ws.sendInboundRelayMessageToClient)

go ws.checkSessionsTimeout()

ws.server = &http.Server{Handler: http.HandlerFunc(ws.websocketHandler)}
go ws.server.Serve(ws.listener)

Expand Down Expand Up @@ -230,60 +232,67 @@ func (ws *WsServer) Restart() {
}()
}

func (ws *WsServer) checkSessionsTimeout() {
ticker := time.NewTicker(time.Second * 10)
defer ticker.Stop()
for {
select {
case <-ticker.C:
var closeList []*session.Session
ws.SessionList.ForEachSession(func(s *session.Session) {
if s.SessionTimeoverCheck() {
resp := common.ResponsePack(common.SESSION_EXPIRED)
ws.respondToSession(s, resp)
closeList = append(closeList, s)
}
})
for _, s := range closeList {
ws.SessionList.CloseSession(s)
}
}
}

}

//websocketHandler
func (ws *WsServer) websocketHandler(w http.ResponseWriter, r *http.Request) {
wsConn, err := ws.Upgrader.Upgrade(w, r, nil)

if err != nil {
log.Error("websocket Upgrader: ", err)
return
}
defer wsConn.Close()
nsSession, err := ws.SessionList.NewSession(wsConn)

sess, err := ws.SessionList.NewSession(wsConn)
if err != nil {
log.Error("websocket NewSession:", err)
return
}

defer func() {
ws.deleteTxHashs(nsSession.GetSessionId())
ws.SessionList.CloseSession(nsSession)
ws.deleteTxHashs(sess.GetSessionId())
ws.SessionList.CloseSession(sess)
if err := recover(); err != nil {
log.Error("websocket recover:", err)
}
}()

wsConn.SetReadLimit(maxMessageSize)
wsConn.SetReadDeadline(time.Now().Add(pongTimeout))
wsConn.SetPongHandler(func(string) error {
wsConn.SetReadDeadline(time.Now().Add(pongTimeout))
return nil
})

done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
var err error
for {
select {
case <-ticker.C:
err = sess.Ping()
if err != nil {
return
}
case <-done:
return
}
}
}()

for {
messageType, bysMsg, err := wsConn.ReadMessage()
if err != nil {
log.Debugf("websocket read message error: %v", err)
break
}

if ws.OnDataHandle(nsSession, messageType, bysMsg, r) {
nsSession.UpdateActiveTime()
wsConn.SetReadDeadline(time.Now().Add(pongTimeout))

err = ws.OnDataHandle(sess, messageType, bysMsg, r)
if err != nil {
log.Error(err)
}
}
}
Expand All @@ -301,13 +310,12 @@ func (ws *WsServer) IsValidMsg(reqMsg map[string]interface{}) bool {
return true
}

func (ws *WsServer) OnDataHandle(curSession *session.Session, messageType int, bysMsg []byte, r *http.Request) bool {
func (ws *WsServer) OnDataHandle(curSession *session.Session, messageType int, bysMsg []byte, r *http.Request) error {
if messageType == websocket.BinaryMessage {
msg := &pb.ClientMessage{}
err := proto.Unmarshal(bysMsg, msg)
if err != nil {
log.Error("Parse client message error:", err)
return false
return fmt.Errorf("Parse client message error: %v", err)
}

var r io.Reader = bytes.NewReader(msg.Message)
Expand All @@ -316,78 +324,69 @@ func (ws *WsServer) OnDataHandle(curSession *session.Session, messageType int, b
case pb.COMPRESSION_ZLIB:
r, err = zlib.NewReader(r)
if err != nil {
log.Errorf("Create zlib reader error: %v", err)
return false
return fmt.Errorf("Create zlib reader error: %v", err)
}
defer r.(io.ReadCloser).Close()
default:
log.Errorf("Unsupported message compression type %v", msg.CompressionType)
return false
return fmt.Errorf("Unsupported message compression type %v", msg.CompressionType)
}

b, err := ioutil.ReadAll(io.LimitReader(r, config.MaxClientMessageSize+1))
if err != nil {
log.Errorf("ReadAll from reader error: %v", err)
return false
return fmt.Errorf("ReadAll from reader error: %v", err)
}
if len(b) > config.MaxClientMessageSize {
log.Errorf("Max client message size reached.")
return false
return fmt.Errorf("Max client message size reached.")
}

switch msg.MessageType {
case pb.OUTBOUND_MESSAGE:
outboundMsg := &pb.OutboundMessage{}
err = proto.Unmarshal(b, outboundMsg)
if err != nil {
log.Errorf("Unmarshal outbound message error: %v", err)
return false
return fmt.Errorf("Unmarshal outbound message error: %v", err)
}
ws.sendOutboundRelayMessage(curSession.GetAddrStr(), outboundMsg)
case pb.RECEIPT:
receipt := &pb.Receipt{}
err = proto.Unmarshal(b, receipt)
if err != nil {
log.Errorf("Unmarshal receipt error: %v", err)
return false
return fmt.Errorf("Unmarshal receipt error: %v", err)
}
err = ws.handleReceipt(receipt)
if err != nil {
log.Errorf("Handle receipt error: %v", err)
return false
return fmt.Errorf("Handle receipt error: %v", err)
}
default:
log.Errorf("unsupported client message type %v", msg.MessageType)
return false
return fmt.Errorf("unsupported client message type %v", msg.MessageType)
}

return true
return nil
}

var req = make(map[string]interface{})

if err := json.Unmarshal(bysMsg, &req); err != nil {
resp := common.ResponsePack(common.ILLEGAL_DATAFORMAT)
ws.respondToSession(curSession, resp)
log.Error("websocket OnDataHandle:", err)
return false
return fmt.Errorf("websocket OnDataHandle: %v", err)
}
actionName, ok := req["Action"].(string)
if !ok {
resp := common.ResponsePack(common.INVALID_METHOD)
ws.respondToSession(curSession, resp)
return false
return nil
}
action, ok := ws.ActionMap[actionName]
if !ok {
resp := common.ResponsePack(common.INVALID_METHOD)
ws.respondToSession(curSession, resp)
return false
return nil
}
if !ws.IsValidMsg(req) {
resp := common.ResponsePack(common.INVALID_PARAMS)
ws.respondToSession(curSession, resp)
return true
return nil
}
if height, ok := req["Height"].(float64); ok {
req["Height"] = strconv.FormatInt(int64(height), 10)
Expand All @@ -408,7 +407,7 @@ func (ws *WsServer) OnDataHandle(curSession *session.Session, messageType int, b
}
ws.respondToSession(curSession, resp)

return true
return nil
}

func (ws *WsServer) SetTxHashMap(txhash string, sessionid string) {
Expand Down
42 changes: 15 additions & 27 deletions api/websocket/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,56 +9,51 @@ import (
"github.com/pborman/uuid"
)

const (
writeTimeout = 10 * time.Second
)

type Session struct {
sync.Mutex
mConnection *websocket.Conn
nLastActive int64
ws *websocket.Conn
sSessionId string
clientChordID []byte
clientPubKey []byte
clientAddrStr *string
isTlsClient bool
}

const sessionTimeOut int64 = 120

func (s *Session) GetSessionId() string {
return s.sSessionId
}

func newSession(wsConn *websocket.Conn) (session *Session, err error) {
sSessionId := uuid.NewUUID().String()
session = &Session{
mConnection: wsConn,
nLastActive: time.Now().Unix(),
sSessionId: sSessionId,
ws: wsConn,
sSessionId: sSessionId,
}
return session, nil
}

func (s *Session) close() {
s.Lock()
defer s.Unlock()
if s.mConnection != nil {
s.mConnection.Close()
s.mConnection = nil
if s.ws != nil {
s.ws.Close()
s.ws = nil
}
s.sSessionId = ""
}

func (s *Session) UpdateActiveTime() {
s.Lock()
defer s.Unlock()
s.nLastActive = time.Now().Unix()
}

func (s *Session) Send(msgType int, data []byte) error {
s.Lock()
defer s.Unlock()
if s.mConnection == nil {
if s.ws == nil {
return errors.New("Websocket is null")
}
return s.mConnection.WriteMessage(msgType, data)
s.ws.SetWriteDeadline(time.Now().Add(writeTimeout))
return s.ws.WriteMessage(msgType, data)
}

func (s *Session) SendText(data []byte) error {
Expand All @@ -69,15 +64,8 @@ func (s *Session) SendBinary(data []byte) error {
return s.Send(websocket.BinaryMessage, data)
}

func (s *Session) SessionTimeoverCheck() bool {
if s.IsClient() {
return false
}
nCurTime := time.Now().Unix()
if nCurTime-s.nLastActive > sessionTimeOut { //sec
return true
}
return false
func (s *Session) Ping() error {
return s.Send(websocket.PingMessage, nil)
}

func (s *Session) SetSessionId(sessionId string) {
Expand Down

0 comments on commit 9ac08a7

Please sign in to comment.