diff --git a/api/websocket/session/session.go b/api/websocket/session/session.go index e7cbaf33d..37fd6e88d 100644 --- a/api/websocket/session/session.go +++ b/api/websocket/session/session.go @@ -14,9 +14,10 @@ const ( ) type Session struct { - sync.Mutex - ws *websocket.Conn - sSessionId string + ws *websocket.Conn + sessionID string + + sync.RWMutex clientChordID []byte clientPubKey []byte clientAddrStr *string @@ -24,14 +25,14 @@ type Session struct { } func (s *Session) GetSessionId() string { - return s.sSessionId + return s.sessionID } func newSession(wsConn *websocket.Conn) (session *Session, err error) { - sSessionId := uuid.NewUUID().String() + sessionID := uuid.NewUUID().String() session = &Session{ - ws: wsConn, - sSessionId: sSessionId, + ws: wsConn, + sessionID: sessionID, } return session, nil } @@ -43,12 +44,12 @@ func (s *Session) close() { s.ws.Close() s.ws = nil } - s.sSessionId = "" + s.sessionID = "" } func (s *Session) Send(msgType int, data []byte) error { - s.Lock() - defer s.Unlock() + s.RLock() + defer s.RUnlock() if s.ws == nil { return errors.New("Websocket is null") } @@ -71,7 +72,7 @@ func (s *Session) Ping() error { func (s *Session) SetSessionId(sessionId string) { s.Lock() defer s.Unlock() - s.sSessionId = sessionId + s.sessionID = sessionId } func (s *Session) SetClient(chordID, pubKey []byte, addrStr *string, isTls bool) { @@ -83,31 +84,45 @@ func (s *Session) SetClient(chordID, pubKey []byte, addrStr *string, isTls bool) s.isTlsClient = isTls } -func (s *Session) IsClient() bool { +func (s *Session) isClient() bool { return s.clientChordID != nil && s.clientPubKey != nil && s.clientAddrStr != nil } +func (s *Session) IsClient() bool { + s.RLock() + defer s.RUnlock() + return s.isClient() +} + func (s *Session) GetID() []byte { - if !s.IsClient() { + s.RLock() + defer s.RUnlock() + if !s.isClient() { return nil } return s.clientChordID } func (s *Session) GetPubKey() []byte { - if !s.IsClient() { + s.RLock() + defer s.RUnlock() + if !s.isClient() { return nil } return s.clientPubKey } func (s *Session) GetAddrStr() *string { - if !s.IsClient() { + s.RLock() + defer s.RUnlock() + if !s.isClient() { return nil } return s.clientAddrStr } func (s *Session) IsTlsClient() bool { + s.RLock() + defer s.RUnlock() return s.isTlsClient }